Skip to content

Commit

Permalink
add cmux.Close() function
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhilash Gnan authored and soheilhy committed Jan 14, 2021
1 parent 8a8ea3c commit e13d1cb
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions cmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ type CMux interface {
// Serve starts multiplexing the listener. Serve blocks and perhaps
// should be invoked concurrently within a go routine.
Serve() error
// Closes cmux server and stops accepting any connections on listener
Close()
// HandleError registers an error handler that handles listener errors.
HandleError(ErrorHandler)
// sets a timeout for the read of matchers
Expand All @@ -111,6 +113,7 @@ type cMux struct {
donec chan struct{}
sls []matchersListener
readTimeout time.Duration
mu sync.Mutex
}

func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
Expand Down Expand Up @@ -146,7 +149,7 @@ func (m *cMux) Serve() error {
var wg sync.WaitGroup

defer func() {
close(m.donec)
m.closeDoneChanLocked()
wg.Wait()

for _, sl := range m.sls {
Expand All @@ -161,6 +164,11 @@ func (m *cMux) Serve() error {
for {
c, err := m.root.Accept()
if err != nil {
select {
case <-m.getDoneChan():
// cmux was closed with cmux.Close()
return nil
}
if !m.handleErr(err) {
return err
}
Expand Down Expand Up @@ -189,7 +197,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
}
select {
case sl.l.connc <- muc:
case <-donec:
case <-m.getDoneChan():
_ = c.Close()
}
return
Expand All @@ -204,6 +212,35 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
}
}

func (m *cMux) Close() {
m.mu.Lock()
defer m.mu.Unlock()
m.closeDoneChanLocked()
}

func (m *cMux) getDoneChan() chan struct{} {
m.mu.Lock()
defer m.mu.Unlock()
return m.getDoneChanLocked()
}

func (m *cMux) getDoneChanLocked() chan struct{} {
if m.donec == nil {
m.donec = make(chan struct{})
}
return m.donec
}

func (m *cMux) closeDoneChanLocked() {
ch := m.getDoneChanLocked()
select {
case <-ch:
// Already closed. Don't close again
default:
close(ch)
}
}

func (m *cMux) HandleError(h ErrorHandler) {
m.errh = h
}
Expand Down

0 comments on commit e13d1cb

Please sign in to comment.