diff --git a/.changeset/wet-turtles-provide.md b/.changeset/wet-turtles-provide.md new file mode 100644 index 00000000000..6a26eb52d12 --- /dev/null +++ b/.changeset/wet-turtles-provide.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +Copy common transmitter methods into FunctionsContractTransmitter to enable product specific modification diff --git a/core/services/relay/evm/functions.go b/core/services/relay/evm/functions.go index da423c6d5fc..38317a7d2cb 100644 --- a/core/services/relay/evm/functions.go +++ b/core/services/relay/evm/functions.go @@ -21,7 +21,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" - "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" functionsRelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/functions" evmRelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" ) @@ -197,7 +196,13 @@ func newFunctionsContractTransmitter(ctx context.Context, contractVersion uint32 gasLimit = uint64(*ocr2Limit) } - transmitter, err := ocrcommon.NewTransmitter( + functionsTransmitter, err := functionsRelay.NewFunctionsContractTransmitter( + configWatcher.chain.Client(), + OCR2AggregatorTransmissionContractABI, + configWatcher.chain.LogPoller(), + lggr, + nil, + contractVersion, configWatcher.chain.TxManager(), fromAddresses, gasLimit, @@ -207,20 +212,6 @@ func newFunctionsContractTransmitter(ctx context.Context, contractVersion uint32 configWatcher.chain.ID(), ethKeystore, ) - - if err != nil { - return nil, errors.Wrap(err, "failed to create transmitter") - } - - functionsTransmitter, err := functionsRelay.NewFunctionsContractTransmitter( - configWatcher.chain.Client(), - OCR2AggregatorTransmissionContractABI, - transmitter, - configWatcher.chain.LogPoller(), - lggr, - nil, - contractVersion, - ) if err != nil { return nil, err } diff --git a/core/services/relay/evm/functions/contract_transmitter.go b/core/services/relay/evm/functions/contract_transmitter.go index 051b1f0bef9..2f945dca885 100644 --- a/core/services/relay/evm/functions/contract_transmitter.go +++ b/core/services/relay/evm/functions/contract_transmitter.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/chains/evmutil" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" @@ -25,16 +26,19 @@ import ( evmRelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" ) +type roundRobinKeystore interface { + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (address common.Address, err error) +} + +type txManager interface { + CreateTransaction(ctx context.Context, txRequest txmgr.TxRequest) (tx txmgr.Tx, err error) +} + type FunctionsContractTransmitter interface { services.ServiceCtx ocrtypes.ContractTransmitter } -type Transmitter interface { - CreateEthTransaction(ctx context.Context, toAddress common.Address, payload []byte, txMeta *txmgr.TxMeta) error - FromAddress() common.Address -} - type ReportToEthMetadata func([]byte) (*txmgr.TxMeta, error) func reportToEvmTxMetaNoop([]byte) (*txmgr.TxMeta, error) { @@ -42,16 +46,23 @@ func reportToEvmTxMetaNoop([]byte) (*txmgr.TxMeta, error) { } type contractTransmitter struct { - contractAddress atomic.Pointer[common.Address] - contractABI abi.ABI - transmitter Transmitter - transmittedEventSig common.Hash - contractReader contractReader - lp logpoller.LogPoller - lggr logger.Logger - reportToEvmTxMeta ReportToEthMetadata - contractVersion uint32 - reportCodec encoding.ReportCodec + contractAddress atomic.Pointer[common.Address] + contractABI abi.ABI + transmittedEventSig common.Hash + contractReader contractReader + lp logpoller.LogPoller + lggr logger.Logger + reportToEvmTxMeta ReportToEthMetadata + contractVersion uint32 + reportCodec encoding.ReportCodec + txm txManager + fromAddresses []common.Address + gasLimit uint64 + effectiveTransmitterAddress common.Address + strategy types.TxStrategy + checker txmgr.TransmitCheckerSpec + chainID *big.Int + keystore roundRobinKeystore } var _ FunctionsContractTransmitter = &contractTransmitter{} @@ -64,12 +75,24 @@ func transmitterFilterName(addr common.Address) string { func NewFunctionsContractTransmitter( caller contractReader, contractABI abi.ABI, - transmitter Transmitter, lp logpoller.LogPoller, lggr logger.Logger, reportToEvmTxMeta ReportToEthMetadata, contractVersion uint32, + txm txManager, + fromAddresses []common.Address, + gasLimit uint64, + effectiveTransmitterAddress common.Address, + strategy types.TxStrategy, + checker txmgr.TransmitCheckerSpec, + chainID *big.Int, + keystore roundRobinKeystore, ) (*contractTransmitter, error) { + // Ensure that a keystore is provided. + if keystore == nil { + return nil, errors.New("nil keystore provided to transmitter") + } + transmitted, ok := contractABI.Events["Transmitted"] if !ok { return nil, errors.New("invalid ABI, missing transmitted") @@ -87,18 +110,54 @@ func NewFunctionsContractTransmitter( return nil, err } return &contractTransmitter{ - contractABI: contractABI, - transmitter: transmitter, - transmittedEventSig: transmitted.ID, - lp: lp, - contractReader: caller, - lggr: lggr.Named("OCRContractTransmitter"), - reportToEvmTxMeta: reportToEvmTxMeta, - contractVersion: contractVersion, - reportCodec: codec, + contractABI: contractABI, + transmittedEventSig: transmitted.ID, + lp: lp, + contractReader: caller, + lggr: lggr.Named("OCRContractTransmitter"), + reportToEvmTxMeta: reportToEvmTxMeta, + contractVersion: contractVersion, + reportCodec: codec, + txm: txm, + fromAddresses: fromAddresses, + gasLimit: gasLimit, + effectiveTransmitterAddress: effectiveTransmitterAddress, + strategy: strategy, + checker: checker, + chainID: chainID, + keystore: keystore, }, nil } +func (oc *contractTransmitter) createEthTransaction(ctx context.Context, toAddress common.Address, payload []byte, txMeta *txmgr.TxMeta) error { + + roundRobinFromAddress, err := oc.keystore.GetRoundRobinAddress(ctx, oc.chainID, oc.fromAddresses...) + if err != nil { + return errors.Wrap(err, "skipped OCR transmission, error getting round-robin address") + } + + _, err = oc.txm.CreateTransaction(ctx, txmgr.TxRequest{ + FromAddress: roundRobinFromAddress, + ToAddress: toAddress, + EncodedPayload: payload, + FeeLimit: oc.gasLimit, + ForwarderAddress: oc.forwarderAddress(), + Strategy: oc.strategy, + Checker: oc.checker, + Meta: txMeta, + }) + return errors.Wrap(err, "skipped OCR transmission") +} + +func (oc *contractTransmitter) forwarderAddress() common.Address { + for _, a := range oc.fromAddresses { + if a == oc.effectiveTransmitterAddress { + return common.Address{} + } + } + return oc.effectiveTransmitterAddress +} + // Transmit sends the report to the on-chain smart contract's Transmit method. func (oc *contractTransmitter) Transmit(ctx context.Context, reportCtx ocrtypes.ReportContext, report ocrtypes.Report, signatures []ocrtypes.AttributedOnchainSignature) error { var rs [][32]byte @@ -161,7 +220,7 @@ func (oc *contractTransmitter) Transmit(ctx context.Context, reportCtx ocrtypes. } oc.lggr.Debugw("FunctionsContractTransmitter: transmitting report", "contractAddress", destinationContract, "txMeta", txMeta, "payloadSize", len(payload)) - return errors.Wrap(oc.transmitter.CreateEthTransaction(ctx, destinationContract, payload, txMeta), "failed to send Eth transaction") + return errors.Wrap(oc.createEthTransaction(ctx, destinationContract, payload, txMeta), "failed to send Eth transaction") } type contractReader interface { @@ -240,7 +299,7 @@ func (oc *contractTransmitter) LatestConfigDigestAndEpoch(ctx context.Context) ( // FromAccount returns the account from which the transmitter invokes the contract func (oc *contractTransmitter) FromAccount() (ocrtypes.Account, error) { - return ocrtypes.Account(oc.transmitter.FromAddress().String()), nil + return ocrtypes.Account(oc.effectiveTransmitterAddress.String()), nil } func (oc *contractTransmitter) Start(ctx context.Context) error { return nil } diff --git a/core/services/relay/evm/functions/contract_transmitter_test.go b/core/services/relay/evm/functions/contract_transmitter_test.go index e9712a3687c..06342564c32 100644 --- a/core/services/relay/evm/functions/contract_transmitter_test.go +++ b/core/services/relay/evm/functions/contract_transmitter_test.go @@ -1,42 +1,46 @@ package functions_test import ( - "context" "encoding/hex" + "math/big" "strings" "testing" "github.com/ethereum/go-ethereum/accounts/abi" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" + "github.com/smartcontractkit/libocr/offchainreporting2plus/chains/evmutil" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + commontxmmocks "github.com/smartcontractkit/chainlink/v2/common/txmgr/types/mocks" evmclimocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" lpmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" + txmmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" + "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/encoding" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/functions" ) -type mockTransmitter struct { - toAddress gethcommon.Address +func newMockTxStrategy(t *testing.T) *commontxmmocks.TxStrategy { + return commontxmmocks.NewTxStrategy(t) } -func (m *mockTransmitter) CreateEthTransaction(ctx context.Context, toAddress gethcommon.Address, payload []byte, _ *txmgr.TxMeta) error { - m.toAddress = toAddress - return nil -} -func (mockTransmitter) FromAddress() gethcommon.Address { return testutils.NewAddress() } - func TestContractTransmitter_LatestConfigDigestAndEpoch(t *testing.T) { t.Parallel() ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) + cfg := configtest.NewTestGeneralConfig(t) + ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() + digestStr := "000130da6b9315bd59af6b0a3f5463c0d0a39e92eaa34cbcbdbace7b3bfcc776" lggr := logger.TestLogger(t) c := evmclimocks.NewClient(t) @@ -49,11 +53,27 @@ func TestContractTransmitter_LatestConfigDigestAndEpoch(t *testing.T) { c.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(digestAndEpochDontScanLogs, nil).Once() contractABI, err := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI)) require.NoError(t, err) + txm := txmmocks.NewMockEvmTxManager(t) + _, fromAddress := cltest.MustInsertRandomKey(t, ethKeyStore) + gasLimit := uint64(1000) + chainID := big.NewInt(0) + effectiveTransmitterAddress := fromAddress + strategy := newMockTxStrategy(t) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) - functionsTransmitter, err := functions.NewFunctionsContractTransmitter(c, contractABI, &mockTransmitter{}, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { + functionsTransmitter, err := functions.NewFunctionsContractTransmitter(c, contractABI, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { return &txmgr.TxMeta{}, nil - }, 1) + }, + 1, + txm, + []gethcommon.Address{fromAddress}, + gasLimit, + effectiveTransmitterAddress, + strategy, + txmgr.TransmitCheckerSpec{}, + chainID, + ethKeyStore, + ) require.NoError(t, err) require.NoError(t, functionsTransmitter.UpdateRoutes(ctx, gethcommon.Address{}, gethcommon.Address{})) @@ -67,18 +87,36 @@ func TestContractTransmitter_Transmit_V1(t *testing.T) { t.Parallel() ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) + cfg := configtest.NewTestGeneralConfig(t) + ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() + contractVersion := uint32(1) configuredDestAddress, coordinatorAddress := testutils.NewAddress(), testutils.NewAddress() lggr := logger.TestLogger(t) c := evmclimocks.NewClient(t) lp := lpmocks.NewLogPoller(t) contractABI, _ := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI)) + txm := txmmocks.NewMockEvmTxManager(t) + _, fromAddress := cltest.MustInsertRandomKey(t, ethKeyStore) + gasLimit := uint64(1000) + chainID := big.NewInt(0) + effectiveTransmitterAddress := fromAddress + strategy := newMockTxStrategy(t) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) - ocrTransmitter := mockTransmitter{} - ot, err := functions.NewFunctionsContractTransmitter(c, contractABI, &ocrTransmitter, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { + ot, err := functions.NewFunctionsContractTransmitter(c, contractABI, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { return &txmgr.TxMeta{}, nil - }, contractVersion) + }, contractVersion, + txm, + []gethcommon.Address{fromAddress}, + gasLimit, + effectiveTransmitterAddress, + strategy, + txmgr.TransmitCheckerSpec{}, + chainID, + ethKeyStore, + ) require.NoError(t, err) require.NoError(t, ot.UpdateRoutes(ctx, configuredDestAddress, configuredDestAddress)) @@ -94,10 +132,24 @@ func TestContractTransmitter_Transmit_V1(t *testing.T) { require.NoError(t, err) reportBytes, err := codec.EncodeReport(processedRequests) require.NoError(t, err) + rawReportCtx := evmutil.RawReportContext(ocrtypes.ReportContext{}) + var rs [][32]byte + var ss [][32]byte + var vs [32]byte + payload, err := contractABI.Pack("transmit", rawReportCtx, reportBytes, rs, ss, vs) + require.NoError(t, err) // success + txm.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ + FromAddress: fromAddress, + ToAddress: coordinatorAddress, + EncodedPayload: payload, + FeeLimit: gasLimit, + ForwarderAddress: gethcommon.Address{}, + Meta: &txmgr.TxMeta{}, + Strategy: strategy, + }).Return(txmgr.Tx{}, nil).Once() require.NoError(t, ot.Transmit(testutils.Context(t), ocrtypes.ReportContext{}, reportBytes, []ocrtypes.AttributedOnchainSignature{})) - require.Equal(t, coordinatorAddress, ocrTransmitter.toAddress) // failure on too many signatures signatures := []ocrtypes.AttributedOnchainSignature{} @@ -111,18 +163,36 @@ func TestContractTransmitter_Transmit_V1_CoordinatorMismatch(t *testing.T) { t.Parallel() ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) + cfg := configtest.NewTestGeneralConfig(t) + ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() + contractVersion := uint32(1) configuredDestAddress, coordinatorAddress1, coordinatorAddress2 := testutils.NewAddress(), testutils.NewAddress(), testutils.NewAddress() lggr := logger.TestLogger(t) c := evmclimocks.NewClient(t) lp := lpmocks.NewLogPoller(t) contractABI, _ := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI)) + txm := txmmocks.NewMockEvmTxManager(t) + _, fromAddress := cltest.MustInsertRandomKey(t, ethKeyStore) + gasLimit := uint64(1000) + chainID := big.NewInt(0) + effectiveTransmitterAddress := fromAddress + strategy := newMockTxStrategy(t) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) - ocrTransmitter := mockTransmitter{} - ot, err := functions.NewFunctionsContractTransmitter(c, contractABI, &ocrTransmitter, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { + ot, err := functions.NewFunctionsContractTransmitter(c, contractABI, lp, lggr, func(b []byte) (*txmgr.TxMeta, error) { return &txmgr.TxMeta{}, nil - }, contractVersion) + }, contractVersion, + txm, + []gethcommon.Address{fromAddress}, + gasLimit, + effectiveTransmitterAddress, + strategy, + txmgr.TransmitCheckerSpec{}, + chainID, + ethKeyStore, + ) require.NoError(t, err) require.NoError(t, ot.UpdateRoutes(ctx, configuredDestAddress, configuredDestAddress)) @@ -144,7 +214,21 @@ func TestContractTransmitter_Transmit_V1_CoordinatorMismatch(t *testing.T) { require.NoError(t, err) reportBytes, err := codec.EncodeReport(processedRequests) require.NoError(t, err) + rawReportCtx := evmutil.RawReportContext(ocrtypes.ReportContext{}) + var rs [][32]byte + var ss [][32]byte + var vs [32]byte + payload, err := contractABI.Pack("transmit", rawReportCtx, reportBytes, rs, ss, vs) + require.NoError(t, err) + txm.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ + FromAddress: fromAddress, + ToAddress: coordinatorAddress1, + EncodedPayload: payload, + FeeLimit: gasLimit, + ForwarderAddress: gethcommon.Address{}, + Meta: &txmgr.TxMeta{}, + Strategy: strategy, + }).Return(txmgr.Tx{}, nil).Once() require.NoError(t, ot.Transmit(testutils.Context(t), ocrtypes.ReportContext{}, reportBytes, []ocrtypes.AttributedOnchainSignature{})) - require.Equal(t, coordinatorAddress1, ocrTransmitter.toAddress) }