Skip to content

Commit

Permalink
client: Use goroutines instead of TryLock
Browse files Browse the repository at this point in the history
Get rid of TryLock by using proper Go structures.

Signed-off-by: Jussi Maki <jussi@isovalent.com>
Signed-off-by: Jarno Rajahalme <jarno@isovalent.com>
  • Loading branch information
jrajahalme committed Nov 8, 2023
1 parent 8374ecb commit 6a4a248
Showing 1 changed file with 118 additions and 123 deletions.
241 changes: 118 additions & 123 deletions shared_client.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

//go:build go1.18
// +build go1.18

package dns

import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
)
Expand All @@ -27,9 +25,10 @@ func NewSharedClients() *SharedClients {
}
}

// GetSharedClient gets or creates an instance of SharedClient keyed with 'key'.
// if 'key' is an empty sting, a new client is always created and it is not actually shared.
// The returned 'closer' must be called once the client is no longer needed.
// GetSharedClient gets or creates an instance of SharedClient keyed with 'key'. if 'key' is an
// empty sting, a new client is always created and it is not actually shared. The returned 'closer'
// must be called once the client is no longer needed. Conversely, the returned 'client' must not be
// used after the closer is called.
func (s *SharedClients) GetSharedClient(key string, conf *Client, serverAddrStr string) (client *SharedClient, closer func()) {
s.Lock()
defer s.Unlock()
Expand All @@ -56,20 +55,26 @@ func (s *SharedClients) GetSharedClient(key string, conf *Client, serverAddrStr
if key != "" {
delete(s.clients, key)
}
// connection must be closed while holding the mutex to avoid a race where a
// new client is created with the same key before this one is closed, which
// could happen if the mutex is released before this Close call.
if client.conn != nil {
client.conn.Close()
}
// connection close must be completed while holding the mutex to avoid a
// race where a new client is created with the same key before this one is
// closed, which could happen if the mutex is released before this Close
// call.
close(client.requestsToSend)
client.conn = nil
client.wg.Wait()
}
}
}

var errNoReader = errors.New("Reader stopped")
type request struct {
ctx context.Context
msg *Msg
ch chan sharedClientResponse
}

type Response struct {
*Msg
type sharedClientResponse struct {
msg *Msg
rtt time.Duration
err error
}

Expand All @@ -79,137 +84,127 @@ type SharedClient struct {

*Client

refcount int // protected by SharedClient's lock

// this mutex protects writes on 'conn' and all access to 'reqs'
sync.Mutex
reqs map[uint16]chan Response // outstanding requests

// 'readerLock' mutex is used to serialize reads on 'conn'. It is always taken and released
// while holding the main lock but the main lock can be released and re-acquired while
// holding 'readerLock' mutex.
readerLock sync.Mutex
requestsToSend chan request
responses chan sharedClientResponse

// Client's connection shared among all requests from the same source address/port. The
// locks above are used to serialize reads and writes on this connection, but reads and
// writes can happen at the same time.
conn *Conn
sync.Mutex // protects fields below
refcount int
conn *Conn
wg sync.WaitGroup
}

func newSharedClient(conf *Client, serverAddr string) *SharedClient {
return &SharedClient{
serverAddr: serverAddr,
Client: conf,
reqs: make(map[uint16]chan Response),
serverAddr: serverAddr,
Client: conf,
requestsToSend: make(chan request, 1),
responses: make(chan sharedClientResponse, 1),
}
}

// ExchangeShared writes the request to the Client's connection and co-operatively
// reads responses from the connection and distributes them to the requestors.
// At most one caller is reading from Client's connection at any time.
// ExchangeShared dials a connection to the server on first invocation, and starts a handler
// goroutines to send and receive responses, distributing them to appropriate concurrent caller
// based on the DNS message Id.
func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err error) {
return c.ExchangeSharedContext(context.Background(), m)
}

// ExchangeSharedContext writes the request to the Client's connection and co-operatively
// reads responses from the connection and distributes them to the requestors.
// At most one caller is reading from Client's connection at any time.
func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Msg, rtt time.Duration, err error) {
// Lock allows only one request to be written at a time, but that can happen
// concurrently with reading.
c.Lock()
defer c.Unlock()
if _, exists := c.reqs[m.Id]; exists {
return nil, 0, fmt.Errorf("duplicate request: %d", m.Id)
// handler is started when the connection is dialed
func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan request, responses chan sharedClientResponse) {
defer wg.Done()

// Receive loop
wg.Add(1)
go func() {
defer wg.Done()
defer close(responses)
for {
r, err := conn.ReadMsg()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
responses <- sharedClientResponse{nil, 0, err}
} else {
responses <- sharedClientResponse{r, 0, nil}
}
}
}()

type waiter struct {
ch chan sharedClientResponse
start time.Time
}
waitingResponses := make(map[uint16]waiter)
defer func() {
conn.Close()

// Dial if needed
if c.conn == nil {
c.conn, err = c.DialContext(ctx, c.serverAddr)
if err != nil {
return nil, 0, fmt.Errorf("failed to dial connection to %v: %w", c.serverAddr, err)
// Drain responses send by receive loop to allow it to exit.
for range responses {
}
}

// Create channel for the response with buffer of one, so that write to it
// does not block if we happen to do it ourselves.
respCh := make(chan Response, 1)
c.reqs[m.Id] = respCh

// Send while holding the client lock, as Client is not made to be usable from
// concurrent goroutines.
start := time.Now()
err = c.SendContext(ctx, m, c.conn, start)
if err != nil {
return nil, 0, err
}
for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
close(waiter.ch)
}
}()

// Wait for the response
var resp Response
for {
// Try taking the reader lock
if c.readerLock.TryLock() {
// We are responsible for reading responses for all users
// of this client until we get our own response or an error occurs.
var err error
for err == nil {
// Release the client lock for the duration of the blocking read
// operation to allow concurrent writes to the underlying
// connection.
var r *Msg
c.Unlock()
// This ReadMsg() will eventually fail due to the read deadline set
// by 'Client' on the underlying connection when sending the
// (last) request.
r, err = c.conn.ReadMsg()
c.Lock()
if err != nil {
break
}
// Locate the request for this response, skipping if not found
ch, exists := c.reqs[r.Id]
if !exists {
continue
}
// Pass the response to the waiting requester
delete(c.reqs, r.Id)
ch <- Response{Msg: r}
if r.Id == m.Id {
// Got our response, quit reading and tell others that
// its their turn to read.
err = errNoReader
}
select {
case req, ok := <-requests:
if !ok {
return
}
// Releasing the reader lock before sending errors on waiter's channels
// so that when they get them, one of them can take the reader lock.
c.readerLock.Unlock()
if errors.Is(err, errNoReader) {
// Can only wake one waiting requester to do the reading as the
// channel buffer length is one, otherwise the channel could get
// full while the request is still waiting for a lock.
for _, ch := range c.reqs {
ch <- Response{err: err}
break
}
start := time.Now()
err := client.SendContext(req.ctx, req.msg, conn, start)
if err != nil {
req.ch <- sharedClientResponse{nil, 0, err}
close(req.ch)
} else {
// Other errors are sent to all recipients
for id, ch := range c.reqs {
delete(c.reqs, id)
ch <- Response{err: err}
waitingResponses[req.msg.Id] = waiter{req.ch, start}
}

case resp := <-responses:
if resp.err != nil {
// ReadMsg failed, but we cannot match it to a request,
// so complete all pending requests.
for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, resp.err}
close(waiter.ch)
}
waitingResponses = make(map[uint16]waiter)
} else if resp.msg != nil {
if waiter, ok := waitingResponses[resp.msg.Id]; ok {
delete(waitingResponses, resp.msg.Id)
resp.rtt = time.Since(waiter.start)
waiter.ch <- resp
close(waiter.ch)
}
}
}
// Get the response of error from the current reader.
// Unlock for the blocking duration to allow concurrent writes
// on the client's connection.
c.Unlock()
resp = <-respCh
c.Lock()
if !errors.Is(resp.err, errNoReader) {
// error other than errNoReader received
break
}
}

func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Msg, rtt time.Duration, err error) {
c.Lock()
if c.conn == nil {
c.conn, err = c.DialContext(ctx, c.serverAddr)
if err != nil {
c.Unlock()
return nil, 0, fmt.Errorf("failed to dial connection to %v: %w", c.serverAddr, err)
}
// Trying again
// Start handler for sending and receiving.
c.wg.Add(1)
go handler(&c.wg, c.Client, c.conn, c.requestsToSend, c.responses)
}
c.Unlock()

respCh := make(chan sharedClientResponse)
c.requestsToSend <- request{
ctx: ctx,
msg: m,
ch: respCh,
}
return resp.Msg, time.Since(start), resp.err
resp := <-respCh
return resp.msg, resp.rtt, resp.err
}

0 comments on commit 6a4a248

Please sign in to comment.