Skip to content

Commit

Permalink
client: set auth header to localhost for unix target (#3730)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettGutierrez1 authored Jul 21, 2020
1 parent 5f0e728 commit a5a36bd
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 12 deletions.
3 changes: 3 additions & 0 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *

// Determine the resolver to use.
cc.parsedTarget = grpcutil.ParseTarget(cc.target)
unixScheme := strings.HasPrefix(cc.target, "unix:")
channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
if resolverBuilder == nil {
Expand All @@ -267,6 +268,8 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.authority != "" {
cc.authority = cc.dopts.authority
} else if unixScheme {
cc.authority = "localhost"
} else {
// Use endpoint from "scheme://authority/endpoint" as the default
// authority for ClientConn.
Expand Down
104 changes: 104 additions & 0 deletions test/authority_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package test

import (
"context"
"fmt"
"os"
"testing"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)

func runUnixTest(t *testing.T, address, target, expectedAuthority string) {
if err := os.RemoveAll(address); err != nil {
t.Fatalf("Error removing socket file %v: %v\n", address, err)
}
us := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
}
auths, ok := md[":authority"]
if !ok {
return nil, status.Error(codes.InvalidArgument, "no authority header")
}
if len(auths) < 1 {
return nil, status.Error(codes.InvalidArgument, "no authority header")
}
if auths[0] != expectedAuthority {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
}
return &testpb.Empty{}, nil
},
network: "unix",
address: address,
target: target,
}
if err := us.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
return
}
defer us.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := us.client.EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}

func (s) TestUnix(t *testing.T) {
tests := []struct {
name string
address string
target string
authority string
}{
{
name: "Unix1",
address: "sock.sock",
target: "unix:sock.sock",
authority: "localhost",
},
{
name: "Unix2",
address: "/tmp/sock.sock",
target: "unix:/tmp/sock.sock",
authority: "localhost",
},
{
name: "Unix3",
address: "/tmp/sock.sock",
target: "unix:///tmp/sock.sock",
authority: "localhost",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
runUnixTest(t, test.address, test.target, test.authority)
})
}
}
42 changes: 30 additions & 12 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4991,7 +4991,11 @@ type stubServer struct {
cc *grpc.ClientConn
s *grpc.Server

addr string // address of listener
// Parameters for Listen and Dial. Defaults will be used if these are empty
// before Start.
network string
address string
target string

cleanups []func() // Lambdas executed in Stop(); populated by Start().

Expand All @@ -5012,14 +5016,21 @@ func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer

// Start starts the server and creates a client connected to it.
func (ss *stubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error {
r := manual.NewBuilderWithScheme("whatever")
ss.r = r
if ss.network == "" {
ss.network = "tcp"
}
if ss.address == "" {
ss.address = "localhost:0"
}
if ss.target == "" {
ss.r = manual.NewBuilderWithScheme("whatever")
}

lis, err := net.Listen("tcp", "localhost:0")
lis, err := net.Listen(ss.network, ss.address)
if err != nil {
return fmt.Errorf(`net.Listen("tcp", "localhost:0") = %v`, err)
return fmt.Errorf("net.Listen(%q, %q) = %v", ss.network, ss.address, err)
}
ss.addr = lis.Addr().String()
ss.address = lis.Addr().String()
ss.cleanups = append(ss.cleanups, func() { lis.Close() })

s := grpc.NewServer(sopts...)
Expand All @@ -5028,15 +5039,20 @@ func (ss *stubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption)
ss.cleanups = append(ss.cleanups, s.Stop)
ss.s = s

target := ss.r.Scheme() + ":///" + ss.addr
opts := append([]grpc.DialOption{grpc.WithInsecure()}, dopts...)
if ss.r != nil {
ss.target = ss.r.Scheme() + ":///" + ss.address
opts = append(opts, grpc.WithResolvers(ss.r))
}

opts := append([]grpc.DialOption{grpc.WithInsecure(), grpc.WithResolvers(r)}, dopts...)
cc, err := grpc.Dial(target, opts...)
cc, err := grpc.Dial(ss.target, opts...)
if err != nil {
return fmt.Errorf("grpc.Dial(%q) = %v", target, err)
return fmt.Errorf("grpc.Dial(%q) = %v", ss.target, err)
}
ss.cc = cc
ss.r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: ss.addr}}})
if ss.r != nil {
ss.r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: ss.address}}})
}
if err := ss.waitForReady(cc); err != nil {
return err
}
Expand All @@ -5048,7 +5064,9 @@ func (ss *stubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption)
}

func (ss *stubServer) newServiceConfig(sc string) {
ss.r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: ss.addr}}, ServiceConfig: parseCfg(ss.r, sc)})
if ss.r != nil {
ss.r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: ss.address}}, ServiceConfig: parseCfg(ss.r, sc)})
}
}

func (ss *stubServer) waitForReady(cc *grpc.ClientConn) error {
Expand Down

0 comments on commit a5a36bd

Please sign in to comment.