Skip to content

Commit

Permalink
feat: support for restarting server protocols (#611)
Browse files Browse the repository at this point in the history
This allows a connected peer to start a protocol, stop it, and later
start it again within the same connection

Fixes #452
  • Loading branch information
agaffney authored May 12, 2024
1 parent 99f51e0 commit c111882
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 34 deletions.
20 changes: 20 additions & 0 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,26 @@ func (m *Muxer) RegisterProtocol(
return senderChan, receiverChan, m.doneChan
}

func (m *Muxer) UnregisterProtocol(
protocolId uint16,
protocolRole ProtocolRole,
) {
m.protocolReceiversMutex.Lock()
protocolRoles, ok := m.protocolReceivers[protocolId]
if !ok {
return
}
recvChan, ok := protocolRoles[protocolRole]
if !ok {
return
}
// Signal shutdown to protocol
close(recvChan)
// Remove mapping
delete(protocolRoles, protocolRole)
m.protocolReceiversMutex.Unlock()
}

// Send takes a populated Segment and writes it to the connection. A mutex is used to prevent more than
// one protocol from sending at once
func (m *Muxer) Send(msg *Segment) error {
Expand Down
19 changes: 15 additions & 4 deletions protocol/blockfetch/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,37 @@ type Server struct {
*protocol.Protocol
config *Config
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
}

func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
s := &Server{
config: cfg,
// Save this for re-use later
protoOptions: protoOptions,
}
s.callbackContext = CallbackContext{
Server: s,
ConnectionId: protoOptions.ConnectionId,
}
s.initProtocol()
return s
}

func (s *Server) initProtocol() {
protoConfig := protocol.ProtocolConfig{
Name: ProtocolName,
ProtocolId: ProtocolId,
Muxer: protoOptions.Muxer,
ErrorChan: protoOptions.ErrorChan,
Mode: protoOptions.Mode,
Muxer: s.protoOptions.Muxer,
ErrorChan: s.protoOptions.ErrorChan,
Mode: s.protoOptions.Mode,
Role: protocol.ProtocolRoleServer,
MessageHandlerFunc: s.messageHandler,
MessageFromCborFunc: NewMsgFromCbor,
StateMap: StateMap,
InitialState: StateIdle,
}
s.Protocol = protocol.New(protoConfig)
return s
}

func (s *Server) NoBlocks() error {
Expand Down Expand Up @@ -107,5 +114,9 @@ func (s *Server) handleRequestRange(msg protocol.Message) error {
}

func (s *Server) handleClientDone() error {
// Restart protocol
s.Protocol.Stop()
s.initProtocol()
s.Protocol.Start()
return nil
}
39 changes: 26 additions & 13 deletions protocol/chainsync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,40 +27,49 @@ type Server struct {
*protocol.Protocol
config *Config
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
stateContext any
}

// NewServer returns a new ChainSync server object
func NewServer(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
if protoOptions.Mode == protocol.ProtocolModeNodeToNode {
// Use node-to-node protocol ID
ProtocolId = ProtocolIdNtN
msgFromCborFunc = NewMsgFromCborNtN
}
s := &Server{
config: cfg,
// Save these for re-use later
protoOptions: protoOptions,
stateContext: stateContext,
}
s.callbackContext = CallbackContext{
Server: s,
ConnectionId: protoOptions.ConnectionId,
}
s.initProtocol()
return s
}

func (s *Server) initProtocol() {
// Use node-to-client protocol ID
ProtocolId := ProtocolIdNtC
msgFromCborFunc := NewMsgFromCborNtC
if s.protoOptions.Mode == protocol.ProtocolModeNodeToNode {
// Use node-to-node protocol ID
ProtocolId = ProtocolIdNtN
msgFromCborFunc = NewMsgFromCborNtN
}
protoConfig := protocol.ProtocolConfig{
Name: ProtocolName,
ProtocolId: ProtocolId,
Muxer: protoOptions.Muxer,
ErrorChan: protoOptions.ErrorChan,
Mode: protoOptions.Mode,
Muxer: s.protoOptions.Muxer,
ErrorChan: s.protoOptions.ErrorChan,
Mode: s.protoOptions.Mode,
Role: protocol.ProtocolRoleServer,
MessageHandlerFunc: s.messageHandler,
MessageFromCborFunc: msgFromCborFunc,
StateMap: StateMap,
StateContext: stateContext,
StateContext: s.stateContext,
InitialState: stateIdle,
}
s.Protocol = protocol.New(protoConfig)
return s
}

func (s *Server) RollBackward(point common.Point, tip Tip) error {
Expand Down Expand Up @@ -147,5 +156,9 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error {
}

func (s *Server) handleDone() error {
// Restart protocol
s.Protocol.Stop()
s.initProtocol()
s.Protocol.Start()
return nil
}
19 changes: 15 additions & 4 deletions protocol/peersharing/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,38 @@ type Server struct {
*protocol.Protocol
config *Config
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
}

// NewServer returns a new PeerSharing server object
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
s := &Server{
config: cfg,
// Save this for re-use later
protoOptions: protoOptions,
}
s.callbackContext = CallbackContext{
Server: s,
ConnectionId: protoOptions.ConnectionId,
}
s.initProtocol()
return s
}

func (s *Server) initProtocol() {
protoConfig := protocol.ProtocolConfig{
Name: ProtocolName,
ProtocolId: ProtocolId,
Muxer: protoOptions.Muxer,
ErrorChan: protoOptions.ErrorChan,
Mode: protoOptions.Mode,
Muxer: s.protoOptions.Muxer,
ErrorChan: s.protoOptions.ErrorChan,
Mode: s.protoOptions.Mode,
Role: protocol.ProtocolRoleServer,
MessageHandlerFunc: s.handleMessage,
MessageFromCborFunc: NewMsgFromCbor,
StateMap: StateMap,
InitialState: stateIdle,
}
s.Protocol = protocol.New(protoConfig)
return s
}

func (s *Server) handleMessage(msg protocol.Message) error {
Expand Down Expand Up @@ -88,5 +95,9 @@ func (s *Server) handleShareRequest(msg protocol.Message) error {
}

func (s *Server) handleDone(msg protocol.Message) error {
// Restart protocol
s.Protocol.Stop()
s.initProtocol()
s.Protocol.Start()
return nil
}
16 changes: 16 additions & 0 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Protocol struct {
sendReadyChan chan bool
stateTransitionChan chan<- protocolStateTransition
onceStart sync.Once
onceStop sync.Once
}

// ProtocolConfig provides the configuration for Protocol
Expand Down Expand Up @@ -147,6 +148,21 @@ func (p *Protocol) Start() {
})
}

// Stop shuts down the mini-protocol
func (p *Protocol) Stop() {
p.onceStop.Do(func() {
// Unregister protocol from muxer
muxerProtocolRole := muxer.ProtocolRoleInitiator
if p.config.Role == ProtocolRoleServer {
muxerProtocolRole = muxer.ProtocolRoleResponder
}
p.config.Muxer.RegisterProtocol(
p.config.ProtocolId,
muxerProtocolRole,
)
})
}

// Mode returns the protocol mode
func (p *Protocol) Mode() ProtocolMode {
return p.config.Mode
Expand Down
31 changes: 18 additions & 13 deletions protocol/txsubmission/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ type Server struct {
*protocol.Protocol
config *Config
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
ackCount int
stateDone bool
requestTxIdsResultChan chan []TxIdAndSize
requestTxsResultChan chan []TxBody
onceStart sync.Once
Expand All @@ -36,28 +36,34 @@ type Server struct {
// NewServer returns a new TxSubmission server object
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
s := &Server{
config: cfg,
config: cfg,
// Save this for re-use later
protoOptions: protoOptions,
requestTxIdsResultChan: make(chan []TxIdAndSize),
requestTxsResultChan: make(chan []TxBody),
}
s.callbackContext = CallbackContext{
Server: s,
ConnectionId: protoOptions.ConnectionId,
}
s.initProtocol()
return s
}

func (s *Server) initProtocol() {
protoConfig := protocol.ProtocolConfig{
Name: ProtocolName,
ProtocolId: ProtocolId,
Muxer: protoOptions.Muxer,
ErrorChan: protoOptions.ErrorChan,
Mode: protoOptions.Mode,
Muxer: s.protoOptions.Muxer,
ErrorChan: s.protoOptions.ErrorChan,
Mode: s.protoOptions.Mode,
Role: protocol.ProtocolRoleServer,
MessageHandlerFunc: s.messageHandler,
MessageFromCborFunc: NewMsgFromCbor,
StateMap: StateMap,
InitialState: stateInit,
}
s.Protocol = protocol.New(protoConfig)
return s
}

func (s *Server) Start() {
Expand Down Expand Up @@ -98,9 +104,6 @@ func (s *Server) RequestTxIds(
blocking bool,
reqCount int,
) ([]TxIdAndSize, error) {
if s.stateDone {
return nil, protocol.ProtocolShuttingDownError
}
msg := NewMsgRequestTxIds(blocking, uint16(s.ackCount), uint16(reqCount))
if err := s.SendMessage(msg); err != nil {
return nil, err
Expand All @@ -117,9 +120,6 @@ func (s *Server) RequestTxIds(

// RequestTxs requests the content of the requested TX identifiers from the remote node's mempool
func (s *Server) RequestTxs(txIds []TxId) ([]TxBody, error) {
if s.stateDone {
return nil, protocol.ProtocolShuttingDownError
}
msg := NewMsgRequestTxs(txIds)
if err := s.SendMessage(msg); err != nil {
return nil, err
Expand Down Expand Up @@ -147,7 +147,12 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error {
}

func (s *Server) handleDone() error {
s.stateDone = true
// Restart protocol
s.Protocol.Stop()
s.initProtocol()
s.requestTxIdsResultChan = make(chan []TxIdAndSize)
s.requestTxsResultChan = make(chan []TxBody)
s.Protocol.Start()
return nil
}

Expand Down

0 comments on commit c111882

Please sign in to comment.