Skip to content

Commit

Permalink
Create HeadPoller for Multi-Node (#12871)
Browse files Browse the repository at this point in the history
* Create polling transformer

* Update poller

* Rename to HeadPoller

* lint

* update poller

* Update head poller

* Update poller

* lint

* Refactor Poller

* Update poller_test.go

* Update poller

* Synchronize tests

* Refactor with timeout

* Check test logs

* Update Poller

* Update poller_test.go

* Update poller_test.go

* Simplify poller

* Set logging to warn
  • Loading branch information
DylanTinianov committed Apr 26, 2024
1 parent c98ea64 commit 7338448
Show file tree
Hide file tree
Showing 2 changed files with 305 additions and 0 deletions.
98 changes: 98 additions & 0 deletions common/client/poller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package client

import (
"context"
"sync"
"time"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"

"github.com/smartcontractkit/chainlink/v2/common/types"
)

// Poller is a component that polls a function at a given interval
// and delivers the result to a channel. It is used by multinode to poll
// for new heads and implements the Subscription interface.
type Poller[T any] struct {
services.StateMachine
pollingInterval time.Duration
pollingFunc func(ctx context.Context) (T, error)
pollingTimeout time.Duration
logger logger.Logger
channel chan<- T
errCh chan error

stopCh services.StopChan
wg sync.WaitGroup
}

// NewPoller creates a new Poller instance
func NewPoller[
T any,
](pollingInterval time.Duration, pollingFunc func(ctx context.Context) (T, error), pollingTimeout time.Duration, channel chan<- T, logger logger.Logger) Poller[T] {
return Poller[T]{
pollingInterval: pollingInterval,
pollingFunc: pollingFunc,
pollingTimeout: pollingTimeout,
channel: channel,
logger: logger,
errCh: make(chan error),
stopCh: make(chan struct{}),
}
}

var _ types.Subscription = &Poller[any]{}

func (p *Poller[T]) Start() error {
return p.StartOnce("Poller", func() error {
p.wg.Add(1)
go p.pollingLoop()
return nil
})
}

// Unsubscribe cancels the sending of events to the data channel
func (p *Poller[T]) Unsubscribe() {
_ = p.StopOnce("Poller", func() error {
close(p.stopCh)
p.wg.Wait()
close(p.errCh)
return nil
})
}

func (p *Poller[T]) Err() <-chan error {
return p.errCh
}

func (p *Poller[T]) pollingLoop() {
defer p.wg.Done()

ticker := time.NewTicker(p.pollingInterval)
defer ticker.Stop()

for {
select {
case <-p.stopCh:
return
case <-ticker.C:
// Set polling timeout
pollingCtx, cancelPolling := context.WithTimeout(context.Background(), p.pollingTimeout)
p.stopCh.CtxCancel(pollingCtx, cancelPolling)
// Execute polling function
result, err := p.pollingFunc(pollingCtx)
cancelPolling()
if err != nil {
p.logger.Warnf("polling error: %v", err)
continue
}
// Send result to channel or block if channel is full
select {
case p.channel <- result:
case <-p.stopCh:
return
}
}
}
}
207 changes: 207 additions & 0 deletions common/client/poller_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package client

import (
"context"
"fmt"
"math/big"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

func Test_Poller(t *testing.T) {
lggr := logger.Test(t)

t.Run("Test multiple start", func(t *testing.T) {
pollFunc := func(ctx context.Context) (Head, error) {
return nil, nil
}

channel := make(chan Head, 1)
defer close(channel)

poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
err := poller.Start()
require.NoError(t, err)

err = poller.Start()
require.Error(t, err)
poller.Unsubscribe()
})

t.Run("Test polling for heads", func(t *testing.T) {
// Mock polling function that returns a new value every time it's called
var pollNumber int
pollLock := sync.Mutex{}
pollFunc := func(ctx context.Context) (Head, error) {
pollLock.Lock()
defer pollLock.Unlock()
pollNumber++
h := head{
BlockNumber: int64(pollNumber),
BlockDifficulty: big.NewInt(int64(pollNumber)),
}
return h.ToMockHead(t), nil
}

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

// Create poller and start to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

// Receive updates from the poller
pollCount := 0
pollMax := 50
for ; pollCount < pollMax; pollCount++ {
h := <-channel
assert.Equal(t, int64(pollCount+1), h.BlockNumber())
}
})

t.Run("Test polling errors", func(t *testing.T) {
// Mock polling function that returns an error
var pollNumber int
pollLock := sync.Mutex{}
pollFunc := func(ctx context.Context) (Head, error) {
pollLock.Lock()
defer pollLock.Unlock()
pollNumber++
return nil, fmt.Errorf("polling error %d", pollNumber)
}

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, olggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

// Ensure that all errors were logged as expected
logsSeen := func() bool {
for pollCount := 0; pollCount < 50; pollCount++ {
numLogs := observedLogs.FilterMessage(fmt.Sprintf("polling error: polling error %d", pollCount+1)).Len()
if numLogs != 1 {
return false
}
}
return true
}
require.Eventually(t, logsSeen, time.Second, time.Millisecond)
})

t.Run("Test polling timeout", func(t *testing.T) {
pollFunc := func(ctx context.Context) (Head, error) {
if <-ctx.Done(); true {
return nil, ctx.Err()
}
return nil, nil
}

// Set instant timeout
pollingTimeout := time.Duration(0)

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr)
require.NoError(t, poller.Start())
defer poller.Unsubscribe()

// Ensure that timeout errors were logged as expected
logsSeen := func() bool {
return observedLogs.FilterMessage("polling error: context deadline exceeded").Len() >= 1
}
require.Eventually(t, logsSeen, time.Second, time.Millisecond)
})

t.Run("Test unsubscribe during polling", func(t *testing.T) {
wait := make(chan struct{})
pollFunc := func(ctx context.Context) (Head, error) {
close(wait)
// Block in polling function until context is cancelled
if <-ctx.Done(); true {
return nil, ctx.Err()
}
return nil, nil
}

// Set long timeout
pollingTimeout := time.Minute

// data channel to receive updates from the poller
channel := make(chan Head, 1)
defer close(channel)

olggr, observedLogs := logger.TestObserved(t, zap.WarnLevel)

// Create poller and subscribe to receive data
poller := NewPoller[Head](time.Millisecond, pollFunc, pollingTimeout, channel, olggr)
require.NoError(t, poller.Start())

// Unsubscribe while blocked in polling function
<-wait
poller.Unsubscribe()

// Ensure error was logged
logsSeen := func() bool {
return observedLogs.FilterMessage("polling error: context canceled").Len() >= 1
}
require.Eventually(t, logsSeen, time.Second, time.Millisecond)
})
}

func Test_Poller_Unsubscribe(t *testing.T) {
lggr := logger.Test(t)
pollFunc := func(ctx context.Context) (Head, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
h := head{
BlockNumber: 0,
BlockDifficulty: big.NewInt(0),
}
return h.ToMockHead(t), nil
}
}

t.Run("Test multiple unsubscribe", func(t *testing.T) {
channel := make(chan Head, 1)
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
err := poller.Start()
require.NoError(t, err)

<-channel
poller.Unsubscribe()
poller.Unsubscribe()
})

t.Run("Test unsubscribe with closed channel", func(t *testing.T) {
channel := make(chan Head, 1)
poller := NewPoller[Head](time.Millisecond, pollFunc, time.Second, channel, lggr)
err := poller.Start()
require.NoError(t, err)

<-channel
close(channel)
poller.Unsubscribe()
})
}

0 comments on commit 7338448

Please sign in to comment.