From f54bd1a806b2b85d8a5cc87d4a9e1f483936554a Mon Sep 17 00:00:00 2001 From: Rod Vagg Date: Fri, 3 Jun 2022 10:45:28 +1000 Subject: [PATCH] feat(ipld): vouchers as plain ipld.Node (#325) * feat(ipld): vouchers as plain ipld.Node * feat: add ValidationResult#Equals() utility * feat(ipld): introduce TypedVoucher tuple type * chore(ipld): ipld.Node -> datamodel.Node * chore: remove RegisterVoucherResultType * fix: minor staticcheck fixes --- README.md | 6 +- benchmarks/benchmark_test.go | 2 +- benchmarks/testinstance/testinstance.go | 3 +- channels/channel_state.go | 79 +++--- channels/channels.go | 42 ++- channels/channels_test.go | 74 ++--- encoding/encoding.go | 171 ------------ encoding/encoding_test.go | 37 --- encoding/testdata/testdata.go | 37 --- encoding/testdata/testdata_cbor_gen.go | 84 ------ go.mod | 4 +- go.sum | 3 +- impl/events.go | 12 +- impl/impl.go | 55 ++-- impl/initiating_test.go | 68 ++--- impl/integration_test.go | 253 ++++++++---------- impl/receiving_requests.go | 50 ++-- impl/responding_test.go | 142 ++++++---- impl/restart.go | 40 +-- impl/restart_integration_test.go | 28 +- impl/utils.go | 33 +-- ipldutils/ipldutils.go | 183 +++++++++++++ manager.go | 40 +-- message.go | 10 +- message/message1_1prime/message.go | 107 ++++---- message/message1_1prime/message_test.go | 99 +++---- message/message1_1prime/schema.go | 29 -- message/message1_1prime/schema.ipldsch | 2 +- message/message1_1prime/transfer_message.go | 29 +- message/message1_1prime/transfer_request.go | 40 ++- .../message1_1prime/transfer_request_test.go | 6 +- message/message1_1prime/transfer_response.go | 23 +- .../message1_1prime/transfer_response_test.go | 6 +- network/libp2p_impl_test.go | 16 +- registry/registry.go | 35 +-- registry/registry_test.go | 20 +- testutil/fakedttype.go | 97 ++++--- testutil/fakedttype_cbor_gen.go | 75 ------ testutil/fakegraphsync.go | 8 +- testutil/faketransport.go | 9 +- testutil/gstestdata.go | 5 +- testutil/message.go | 8 +- testutil/mockchannelstate.go | 14 +- testutil/stubbedvalidator.go | 14 +- testutil/testutil.go | 4 +- transport.go | 3 +- transport/graphsync/graphsync.go | 5 +- types.go | 42 ++- 48 files changed, 938 insertions(+), 1214 deletions(-) delete mode 100644 encoding/encoding.go delete mode 100644 encoding/encoding_test.go delete mode 100644 encoding/testdata/testdata.go delete mode 100644 encoding/testdata/testdata_cbor_gen.go create mode 100644 ipldutils/ipldutils.go delete mode 100644 message/message1_1prime/schema.go delete mode 100644 testutil/fakedttype_cbor_gen.go diff --git a/README.md b/README.md index f2add43b..3ab3f91f 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ func (vl *myValidator) ValidatePush( sender peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, - selector ipld.Node) error { + selector datamodel.Node) error { v := voucher.(*myVoucher) if v.data == "" || v.data != "validpush" { @@ -99,7 +99,7 @@ func (vl *myValidator) ValidatePull( receiver peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, - selector ipld.Node) error { + selector datamodel.Node) error { v := voucher.(*myVoucher) if v.data == "" || v.data != "validpull" { @@ -135,7 +135,7 @@ must be sent with the request. Using the trivial examples above: For more detail, please see the [unit tests](https://github.com/filecoin-project/go-data-transfer/blob/master/impl/impl_test.go). ### Open a Push or Pull Request -For a push or pull request, provide a context, a `datatransfer.Voucher`, a host recipient `peer.ID`, a baseCID `cid.CID` and a selector `ipld.Node`. These +For a push or pull request, provide a context, a `datatransfer.Voucher`, a host recipient `peer.ID`, a baseCID `cid.CID` and a selector `datamodel.Node`. These calls return a `datatransfer.ChannelID` and any error: ```go channelID, err := dtm.OpenPullDataChannel(ctx, recipient, voucher, baseCid, selector) diff --git a/benchmarks/benchmark_test.go b/benchmarks/benchmark_test.go index 56755598..0a170d98 100644 --- a/benchmarks/benchmark_test.go +++ b/benchmarks/benchmark_test.go @@ -105,7 +105,7 @@ func p2pStrestTest(ctx context.Context, b *testing.B, numfiles int, df distFunc, timer := time.NewTimer(30 * time.Second) start := time.Now() for j := 0; j < numfiles; j++ { - _, err := pusher.Manager.OpenPushDataChannel(ctx, receiver.Peer, testutil.NewFakeDTType(), allCids[j], allSelector) + _, err := pusher.Manager.OpenPushDataChannel(ctx, receiver.Peer, testutil.NewTestTypedVoucher(), allCids[j], allSelector) if err != nil { b.Fatalf("received error on request: %s", err.Error()) } diff --git a/benchmarks/testinstance/testinstance.go b/benchmarks/testinstance/testinstance.go index c720f44b..18a8e8de 100644 --- a/benchmarks/testinstance/testinstance.go +++ b/benchmarks/testinstance/testinstance.go @@ -188,8 +188,7 @@ func NewInstance(ctx context.Context, net tn.Network, tempDir string, diskBasedD sv := testutil.NewStubbedValidator() sv.StubSuccessPull() sv.StubSuccessPush() - dt.RegisterVoucherType(testutil.NewFakeDTType(), sv) - dt.RegisterVoucherResultType(testutil.NewFakeDTType()) + dt.RegisterVoucherType(testutil.TestVoucherType, sv) return Instance{ Adapter: dtNet, Peer: p, diff --git a/channels/channel_state.go b/channels/channel_state.go index a24a7feb..31c18d4b 100644 --- a/channels/channel_state.go +++ b/channels/channel_state.go @@ -4,22 +4,19 @@ import ( "bytes" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" + "github.com/ipld/go-ipld-prime/datamodel" basicnode "github.com/ipld/go-ipld-prime/node/basic" peer "github.com/libp2p/go-libp2p-core/peer" datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channels/internal" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" ) // channelState is immutable channel data plus mutable state type channelState struct { ic internal.ChannelState - - // additional voucherResults - voucherResultDecoder DecoderByTypeFunc - voucherDecoder DecoderByTypeFunc } // EmptyChannelState is the zero value for channel state, meaning not present @@ -45,7 +42,7 @@ func (c channelState) BaseCID() cid.Cid { return c.ic.BaseCid } // Selector returns the IPLD selector for this data transfer (represented as // an IPLD node) -func (c channelState) Selector() ipld.Node { +func (c channelState) Selector() datamodel.Node { builder := basicnode.Prototype.Any.NewBuilder() reader := bytes.NewReader(c.ic.Selector.Raw) err := dagcbor.Decode(builder, reader) @@ -56,13 +53,15 @@ func (c channelState) Selector() ipld.Node { } // Voucher returns the voucher for this data transfer -func (c channelState) Voucher() datatransfer.Voucher { +func (c channelState) Voucher() (datatransfer.TypedVoucher, error) { if len(c.ic.Vouchers) == 0 { - return nil + return datatransfer.TypedVoucher{}, nil + } + node, err := ipldutils.DeferredToNode(c.ic.Vouchers[0].Voucher) + if err != nil { + return datatransfer.TypedVoucher{}, err } - decoder, _ := c.voucherDecoder(c.ic.Vouchers[0].Type) - encodable, _ := decoder.DecodeFromCbor(c.ic.Vouchers[0].Voucher.Raw) - return encodable.(datatransfer.Voucher) + return datatransfer.TypedVoucher{Voucher: node, Type: c.ic.Vouchers[0].Type}, nil } // ReceivedCidsTotal returns the number of (non-unique) cids received so far @@ -108,36 +107,46 @@ func (c channelState) Message() string { return c.ic.Message } -func (c channelState) Vouchers() []datatransfer.Voucher { - vouchers := make([]datatransfer.Voucher, 0, len(c.ic.Vouchers)) +func (c channelState) Vouchers() ([]datatransfer.TypedVoucher, error) { + vouchers := make([]datatransfer.TypedVoucher, 0, len(c.ic.Vouchers)) for _, encoded := range c.ic.Vouchers { - decoder, _ := c.voucherDecoder(encoded.Type) - encodable, _ := decoder.DecodeFromCbor(encoded.Voucher.Raw) - vouchers = append(vouchers, encodable.(datatransfer.Voucher)) + node, err := ipldutils.DeferredToNode(encoded.Voucher) + if err != nil { + return nil, err + } + vouchers = append(vouchers, datatransfer.TypedVoucher{Voucher: node, Type: encoded.Type}) } - return vouchers + return vouchers, nil } -func (c channelState) LastVoucher() datatransfer.Voucher { - decoder, _ := c.voucherDecoder(c.ic.Vouchers[len(c.ic.Vouchers)-1].Type) - encodable, _ := decoder.DecodeFromCbor(c.ic.Vouchers[len(c.ic.Vouchers)-1].Voucher.Raw) - return encodable.(datatransfer.Voucher) +func (c channelState) LastVoucher() (datatransfer.TypedVoucher, error) { + ev := c.ic.Vouchers[len(c.ic.Vouchers)-1] + node, err := ipldutils.DeferredToNode(ev.Voucher) + if err != nil { + return datatransfer.TypedVoucher{}, err + } + return datatransfer.TypedVoucher{Voucher: node, Type: ev.Type}, nil } -func (c channelState) LastVoucherResult() datatransfer.VoucherResult { - decoder, _ := c.voucherResultDecoder(c.ic.VoucherResults[len(c.ic.VoucherResults)-1].Type) - encodable, _ := decoder.DecodeFromCbor(c.ic.VoucherResults[len(c.ic.VoucherResults)-1].VoucherResult.Raw) - return encodable.(datatransfer.VoucherResult) +func (c channelState) LastVoucherResult() (datatransfer.TypedVoucher, error) { + evr := c.ic.VoucherResults[len(c.ic.VoucherResults)-1] + node, err := ipldutils.DeferredToNode(evr.VoucherResult) + if err != nil { + return datatransfer.TypedVoucher{}, err + } + return datatransfer.TypedVoucher{Voucher: node, Type: evr.Type}, nil } -func (c channelState) VoucherResults() []datatransfer.VoucherResult { - voucherResults := make([]datatransfer.VoucherResult, 0, len(c.ic.VoucherResults)) +func (c channelState) VoucherResults() ([]datatransfer.TypedVoucher, error) { + voucherResults := make([]datatransfer.TypedVoucher, 0, len(c.ic.VoucherResults)) for _, encoded := range c.ic.VoucherResults { - decoder, _ := c.voucherResultDecoder(encoded.Type) - encodable, _ := decoder.DecodeFromCbor(encoded.VoucherResult.Raw) - voucherResults = append(voucherResults, encodable.(datatransfer.VoucherResult)) + node, err := ipldutils.DeferredToNode(encoded.VoucherResult) + if err != nil { + return nil, err + } + voucherResults = append(voucherResults, datatransfer.TypedVoucher{Voucher: node, Type: encoded.Type}) } - return voucherResults + return voucherResults, nil } func (c channelState) SelfPeer() peer.ID { @@ -174,12 +183,8 @@ func (c channelState) Stages() *datatransfer.ChannelStages { return c.ic.Stages } -func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByTypeFunc, voucherResultDecoder DecoderByTypeFunc) datatransfer.ChannelState { - return channelState{ - ic: c, - voucherResultDecoder: voucherResultDecoder, - voucherDecoder: voucherDecoder, - } +func fromInternalChannelState(c internal.ChannelState) datatransfer.ChannelState { + return channelState{ic: c} } var _ datatransfer.ChannelState = channelState{} diff --git a/channels/channels.go b/channels/channels.go index 0c575f54..4c7f1572 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -7,7 +7,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" peer "github.com/libp2p/go-libp2p-core/peer" cbg "github.com/whyrusleeping/cbor-gen" "golang.org/x/xerrors" @@ -20,11 +20,9 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channels/internal" "github.com/filecoin-project/go-data-transfer/v2/channels/internal/migrations" - "github.com/filecoin-project/go-data-transfer/v2/encoding" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" ) -type DecoderByTypeFunc func(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) - type Notifier func(datatransfer.Event, datatransfer.ChannelState) // ErrNotFound is returned when a channel cannot be found with a given channel ID @@ -46,8 +44,6 @@ var ErrWrongType = errors.New("Cannot change type of implementation specific dat // Channels is a thread safe list of channels type Channels struct { notifier Notifier - voucherDecoder DecoderByTypeFunc - voucherResultDecoder DecoderByTypeFunc blockIndexCache *blockIndexCache progressCache *progressCache stateMachines fsm.Group @@ -65,16 +61,10 @@ type ChannelEnvironment interface { // New returns a new thread safe list of channels func New(ds datastore.Batching, notifier Notifier, - voucherDecoder DecoderByTypeFunc, - voucherResultDecoder DecoderByTypeFunc, env ChannelEnvironment, selfPeer peer.ID) (*Channels, error) { - c := &Channels{ - notifier: notifier, - voucherDecoder: voucherDecoder, - voucherResultDecoder: voucherResultDecoder, - } + c := &Channels{notifier: notifier} c.blockIndexCache = newBlockIndexCache() c.progressCache = newProgressCache() channelMigrations, err := migrations.GetChannelStateMigrations(selfPeer) @@ -121,7 +111,7 @@ func (c *Channels) dispatch(eventName fsm.EventName, channel fsm.StateType) { // CreateNew creates a new channel id and channel state and saves to channels. // returns error if the channel exists already. -func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, baseCid cid.Cid, selector ipld.Node, voucher datatransfer.Voucher, initiator, dataSender, dataReceiver peer.ID) (datatransfer.ChannelID, error) { +func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, baseCid cid.Cid, selector datamodel.Node, voucher datatransfer.TypedVoucher, initiator, dataSender, dataReceiver peer.ID) (datatransfer.ChannelID, error) { var responder peer.ID if dataSender == initiator { responder = dataReceiver @@ -129,11 +119,11 @@ func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, base responder = dataSender } chid := datatransfer.ChannelID{Initiator: initiator, Responder: responder, ID: tid} - voucherBytes, err := encoding.Encode(voucher) + initialVoucher, err := ipldutils.NodeToDeferred(voucher.Voucher) if err != nil { return datatransfer.ChannelID{}, err } - selBytes, err := encoding.Encode(selector) + selBytes, err := ipldutils.NodeToBytes(selector) if err != nil { return datatransfer.ChannelID{}, err } @@ -149,10 +139,8 @@ func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, base Stages: &datatransfer.ChannelStages{}, Vouchers: []internal.EncodedVoucher{ { - Type: voucher.Type(), - Voucher: &cbg.Deferred{ - Raw: voucherBytes, - }, + Type: voucher.Type, + Voucher: initialVoucher, }, }, Status: datatransfer.Requested, @@ -289,21 +277,21 @@ func (c *Channels) ResumeResponder(chid datatransfer.ChannelID) error { } // NewVoucher records a new voucher for this channel -func (c *Channels) NewVoucher(chid datatransfer.ChannelID, voucher datatransfer.Voucher) error { - voucherBytes, err := encoding.Encode(voucher) +func (c *Channels) NewVoucher(chid datatransfer.ChannelID, voucher datatransfer.TypedVoucher) error { + voucherBytes, err := ipldutils.NodeToBytes(voucher.Voucher) if err != nil { return err } - return c.send(chid, datatransfer.NewVoucher, voucher.Type(), voucherBytes) + return c.send(chid, datatransfer.NewVoucher, voucher.Type, voucherBytes) } // NewVoucherResult records a new voucher result for this channel -func (c *Channels) NewVoucherResult(chid datatransfer.ChannelID, voucherResult datatransfer.VoucherResult) error { - voucherResultBytes, err := encoding.Encode(voucherResult) +func (c *Channels) NewVoucherResult(chid datatransfer.ChannelID, voucherResult datatransfer.TypedVoucher) error { + voucherResultBytes, err := ipldutils.NodeToBytes(voucherResult.Voucher) if err != nil { return err } - return c.send(chid, datatransfer.NewVoucherResult, voucherResult.Type(), voucherResultBytes) + return c.send(chid, datatransfer.NewVoucherResult, voucherResult.Type, voucherResultBytes) } // Complete indicates responder has completed sending/receiving data @@ -485,5 +473,5 @@ func (c *Channels) checkChannelExists(chid datatransfer.ChannelID, code datatran // Convert from the internally used channel state format to the externally exposed ChannelState func (c *Channels) fromInternalChannelState(ch internal.ChannelState) datatransfer.ChannelState { - return fromInternalChannelState(ch, c.voucherDecoder, c.voucherResultDecoder) + return fromInternalChannelState(ch) } diff --git a/channels/channels_test.go b/channels/channels_test.go index 8b739282..505afa34 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -16,7 +16,6 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channels" - "github.com/filecoin-project/go-data-transfer/v2/encoding" "github.com/filecoin-project/go-data-transfer/v2/testutil" ) @@ -32,13 +31,13 @@ func TestChannels(t *testing.T) { tid1 := datatransfer.TransferID(0) tid2 := datatransfer.TransferID(1) - fv1 := &testutil.FakeDTType{} - fv2 := &testutil.FakeDTType{} + fv1 := testutil.NewTestTypedVoucher() + fv2 := testutil.NewTestTypedVoucher() cids := testutil.GenerateCids(4) selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() peers := testutil.GeneratePeers(4) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) @@ -80,7 +79,9 @@ func TestChannels(t *testing.T) { require.NotEqual(t, channels.EmptyChannelState, state) require.Equal(t, cids[0], state.BaseCID()) require.Equal(t, selector, state.Selector()) - require.Equal(t, fv1, state.Voucher()) + voucher, err := state.Voucher() + require.NoError(t, err) + require.True(t, fv1.Equals(voucher)) require.Equal(t, peers[0], state.Sender()) require.Equal(t, peers[1], state.Recipient()) @@ -124,7 +125,7 @@ func TestChannels(t *testing.T) { t.Run("datasent/queued when transfer is already finished", func(t *testing.T) { ds := dss.MutexWrap(datastore.NewMapDatastore()) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) @@ -156,7 +157,7 @@ func TestChannels(t *testing.T) { t.Run("updating send/receive values", func(t *testing.T) { ds := dss.MutexWrap(datastore.NewMapDatastore()) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) @@ -218,7 +219,7 @@ func TestChannels(t *testing.T) { t.Run("data limit", func(t *testing.T) { ds := dss.MutexWrap(datastore.NewMapDatastore()) - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) @@ -297,31 +298,53 @@ func TestChannels(t *testing.T) { }) t.Run("new vouchers & voucherResults", func(t *testing.T) { - fv3 := testutil.NewFakeDTType() - fvr1 := testutil.NewFakeDTType() + fv3 := testutil.NewTestTypedVoucher() + fvr1 := testutil.NewTestTypedVoucher() state, err := channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, []datatransfer.Voucher{fv1}, state.Vouchers()) - require.Equal(t, fv1, state.Voucher()) - require.Equal(t, fv1, state.LastVoucher()) + vouchers, err := state.Vouchers() + require.NoError(t, err) + require.Len(t, vouchers, 1) + require.True(t, fv1.Equals(vouchers[0])) + voucher, err := state.Voucher() + require.NoError(t, err) + require.True(t, fv1.Equals(voucher)) + voucher, err = state.LastVoucher() + require.NoError(t, err) + require.True(t, fv1.Equals(voucher)) err = channelList.NewVoucher(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, fv3) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.NewVoucher) - require.Equal(t, []datatransfer.Voucher{fv1, fv3}, state.Vouchers()) - require.Equal(t, fv1, state.Voucher()) - require.Equal(t, fv3, state.LastVoucher()) + vouchers, err = state.Vouchers() + require.NoError(t, err) + require.Len(t, vouchers, 2) + require.True(t, fv1.Equals(vouchers[0])) + require.True(t, fv3.Equals(vouchers[1])) + voucher, err = state.Voucher() + require.NoError(t, err) + require.True(t, fv1.Equals(voucher)) + voucher, err = state.LastVoucher() + require.NoError(t, err) + require.True(t, fv3.Equals(voucher)) state, err = channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, []datatransfer.VoucherResult{}, state.VoucherResults()) + results, err := state.VoucherResults() + require.NoError(t, err) + require.Equal(t, []datatransfer.TypedVoucher{}, results) err = channelList.NewVoucherResult(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, fvr1) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.NewVoucherResult) - require.Equal(t, []datatransfer.VoucherResult{fvr1}, state.VoucherResults()) - require.Equal(t, fvr1, state.LastVoucherResult()) + voucherResults, err := state.VoucherResults() + require.NoError(t, err) + require.Len(t, voucherResults, 1) + require.True(t, fvr1.Equals(voucherResults[0])) + voucherResult, err := state.LastVoucherResult() + require.NoError(t, err) + require.True(t, fvr1.Equals(voucherResult)) }) t.Run("test finality", func(t *testing.T) { @@ -387,7 +410,7 @@ func TestChannels(t *testing.T) { notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} } - channelList, err := channels.New(ds, notifier, decoderByType, decoderByType, &fakeEnv{}, peers[0]) + channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) require.NoError(t, err) err = channelList.Start(ctx) require.NoError(t, err) @@ -469,14 +492,3 @@ func (fe *fakeEnv) ID() peer.ID { func (fe *fakeEnv) CleanupChannel(chid datatransfer.ChannelID) { } - -func decoderByType(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - if identifier == testutil.NewFakeDTType().Type() { - decoder, err := encoding.NewDecoder(testutil.NewFakeDTType()) - if err != nil { - return nil, false - } - return decoder, true - } - return nil, false -} diff --git a/encoding/encoding.go b/encoding/encoding.go deleted file mode 100644 index dec7abcd..00000000 --- a/encoding/encoding.go +++ /dev/null @@ -1,171 +0,0 @@ -package encoding - -import ( - "bytes" - "reflect" - - cbor "github.com/ipfs/go-ipld-cbor" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" - "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/basicnode" - "github.com/ipld/go-ipld-prime/schema" - cborgen "github.com/whyrusleeping/cbor-gen" - "golang.org/x/xerrors" -) - -// Encodable is an object that can be written to CBOR and decoded back -type Encodable interface{} - -// Encode encodes an encodable to CBOR, using the best available path for -// writing to CBOR -func Encode(value Encodable) ([]byte, error) { - if cbgEncodable, ok := value.(cborgen.CBORMarshaler); ok { - buf := new(bytes.Buffer) - err := cbgEncodable.MarshalCBOR(buf) - if err != nil { - return nil, err - } - return buf.Bytes(), nil - } - if ipldEncodable, ok := value.(datamodel.Node); ok { - if tn, ok := ipldEncodable.(schema.TypedNode); ok { - ipldEncodable = tn.Representation() - } - buf := &bytes.Buffer{} - err := dagcbor.Encode(ipldEncodable, buf) - if err != nil { - return nil, err - } - return buf.Bytes(), nil - } - return cbor.DumpObject(value) -} - -func EncodeToNode(encodable Encodable) (datamodel.Node, error) { - byts, err := Encode(encodable) - if err != nil { - return nil, err - } - na := basicnode.Prototype.Any.NewBuilder() - if err := dagcbor.Decode(na, bytes.NewReader(byts)); err != nil { - return nil, err - } - return na.Build(), nil -} - -// Decoder is CBOR decoder for a given encodable type -type Decoder interface { - DecodeFromCbor([]byte) (Encodable, error) - DecodeFromNode(datamodel.Node) (Encodable, error) -} - -// NewDecoder creates a new Decoder that will decode into new instances of the given -// object type. It will use the decoding that is optimal for that type -// It returns error if it's not possible to setup a decoder for this type -func NewDecoder(decodeType Encodable) (Decoder, error) { - // check if type is datamodel.Node, if so, just use style - if ipldDecodable, ok := decodeType.(datamodel.Node); ok { - return &ipldDecoder{ipldDecodable.Prototype()}, nil - } - // check if type is a pointer, as we need that to make new copies - // for cborgen types & regular IPLD types - decodeReflectType := reflect.TypeOf(decodeType) - if decodeReflectType.Kind() != reflect.Ptr { - return nil, xerrors.New("type must be a pointer") - } - // check if type is a cbor-gen type - if _, ok := decodeType.(cborgen.CBORUnmarshaler); ok { - return &cbgDecoder{decodeReflectType}, nil - } - // type does is neither ipld-prime nor cbor-gen, so we need to see if it - // can rountrip with oldschool ipld-format - encoded, err := cbor.DumpObject(decodeType) - if err != nil { - return nil, xerrors.New("Object type did not encode") - } - newDecodable := reflect.New(decodeReflectType.Elem()).Interface() - if err := cbor.DecodeInto(encoded, newDecodable); err != nil { - return nil, xerrors.New("Object type did not decode") - } - return &defaultDecoder{decodeReflectType}, nil -} - -type ipldDecoder struct { - style ipld.NodePrototype -} - -func (decoder *ipldDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - builder := decoder.style.NewBuilder() - buf := bytes.NewReader(encoded) - err := dagcbor.Decode(builder, buf) - if err != nil { - return nil, err - } - return builder.Build(), nil -} - -func (decoder *ipldDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - builder := decoder.style.NewBuilder() - if err := builder.AssignNode(node); err != nil { - return nil, err - } - return builder.Build(), nil -} - -type cbgDecoder struct { - cbgType reflect.Type -} - -func (decoder *cbgDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - decodedValue := reflect.New(decoder.cbgType.Elem()) - decoded, ok := decodedValue.Interface().(cborgen.CBORUnmarshaler) - if !ok || reflect.ValueOf(decoded).IsNil() { - return nil, xerrors.New("problem instantiating decoded value") - } - buf := bytes.NewReader(encoded) - err := decoded.UnmarshalCBOR(buf) - if err != nil { - return nil, err - } - return decoded, nil -} - -func (decoder *cbgDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - if tn, ok := node.(schema.TypedNode); ok { - node = tn.Representation() - } - buf := &bytes.Buffer{} - if err := dagcbor.Encode(node, buf); err != nil { - return nil, err - } - return decoder.DecodeFromCbor(buf.Bytes()) -} - -type defaultDecoder struct { - ptrType reflect.Type -} - -func (decoder *defaultDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { - decodedValue := reflect.New(decoder.ptrType.Elem()) - decoded, ok := decodedValue.Interface().(Encodable) - if !ok || reflect.ValueOf(decoded).IsNil() { - return nil, xerrors.New("problem instantiating decoded value") - } - err := cbor.DecodeInto(encoded, decoded) - if err != nil { - return nil, err - } - return decoded, nil -} - -func (decoder *defaultDecoder) DecodeFromNode(node datamodel.Node) (Encodable, error) { - if tn, ok := node.(schema.TypedNode); ok { - node = tn.Representation() - } - buf := &bytes.Buffer{} - if err := dagcbor.Encode(node, buf); err != nil { - return nil, err - } - return decoder.DecodeFromCbor(buf.Bytes()) -} diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go deleted file mode 100644 index 3d751f7a..00000000 --- a/encoding/encoding_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package encoding_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/filecoin-project/go-data-transfer/v2/encoding" - "github.com/filecoin-project/go-data-transfer/v2/encoding/testdata" -) - -func TestRoundTrip(t *testing.T) { - testCases := map[string]struct { - val encoding.Encodable - }{ - "can encode/decode IPLD prime types": { - val: testdata.Prime, - }, - "can encode/decode cbor-gen types": { - val: testdata.Cbg, - }, - "can encode/decode old ipld format types": { - val: testdata.Standard, - }, - } - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - encoded, err := encoding.Encode(data.val) - require.NoError(t, err) - decoder, err := encoding.NewDecoder(data.val) - require.NoError(t, err) - decoded, err := decoder.DecodeFromCbor(encoded) - require.NoError(t, err) - require.Equal(t, data.val, decoded) - }) - } -} diff --git a/encoding/testdata/testdata.go b/encoding/testdata/testdata.go deleted file mode 100644 index 5bed37ba..00000000 --- a/encoding/testdata/testdata.go +++ /dev/null @@ -1,37 +0,0 @@ -package testdata - -import ( - cbor "github.com/ipfs/go-ipld-cbor" - "github.com/ipld/go-ipld-prime/fluent" - basicnode "github.com/ipld/go-ipld-prime/node/basic" -) - -// Prime = an instance of an ipld prime piece of data -var Prime = fluent.MustBuildMap(basicnode.Prototype.Map, 2, func(na fluent.MapAssembler) { - nva := na.AssembleEntry("X") - nva.AssignInt(100) - nva = na.AssembleEntry("Y") - nva.AssignString("appleSauce") -}) - -type standardType struct { - X int - Y string -} - -func init() { - cbor.RegisterCborType(standardType{}) -} - -// Standard = an instance that is neither ipld prime nor cbor -var Standard *standardType = &standardType{X: 100, Y: "appleSauce"} - -//go:generate cbor-gen-for cbgType - -type cbgType struct { - X uint64 - Y string -} - -// Cbg = an instance of a cbor-gen type -var Cbg *cbgType = &cbgType{X: 100, Y: "appleSauce"} diff --git a/encoding/testdata/testdata_cbor_gen.go b/encoding/testdata/testdata_cbor_gen.go deleted file mode 100644 index 67c6c688..00000000 --- a/encoding/testdata/testdata_cbor_gen.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - -package testdata - -import ( - "fmt" - "io" - - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf - -func (t *cbgType) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write([]byte{130}); err != nil { - return err - } - - // t.X (uint64) (uint64) - - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.X))); err != nil { - return err - } - - // t.Y (string) (string) - if len(t.Y) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.Y was too long") - } - - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Y)))); err != nil { - return err - } - if _, err := w.Write([]byte(t.Y)); err != nil { - return err - } - return nil -} - -func (t *cbgType) UnmarshalCBOR(r io.Reader) error { - br := cbg.GetPeeker(r) - - maj, extra, err := cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajArray { - return fmt.Errorf("cbor input should be of type array") - } - - if extra != 2 { - return fmt.Errorf("cbor input had wrong number of fields") - } - - // t.X (uint64) (uint64) - - { - - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.X = uint64(extra) - - } - // t.Y (string) (string) - - { - sval, err := cbg.ReadString(br) - if err != nil { - return err - } - - t.Y = string(sval) - } - return nil -} diff --git a/go.mod b/go.mod index 252c47ac..8154ae11 100644 --- a/go.mod +++ b/go.mod @@ -20,12 +20,11 @@ require ( github.com/ipfs/go-ipfs-delay v0.0.1 github.com/ipfs/go-ipfs-exchange-offline v0.1.1 github.com/ipfs/go-ipfs-files v0.0.8 - github.com/ipfs/go-ipld-cbor v0.0.5 github.com/ipfs/go-ipld-format v0.2.0 github.com/ipfs/go-log/v2 v2.5.1 github.com/ipfs/go-merkledag v0.5.1 github.com/ipfs/go-unixfs v0.3.1 - github.com/ipld/go-ipld-prime v0.16.0 + github.com/ipld/go-ipld-prime v0.16.1-0.20220519105356-1f1151b69dba github.com/jbenet/go-random v0.0.0-20190219211222-123a90aedc0c github.com/jpillora/backoff v1.0.0 github.com/libp2p/go-libp2p v0.19.4 @@ -71,6 +70,7 @@ require ( github.com/ipfs/go-ipfs-posinfo v0.0.1 // indirect github.com/ipfs/go-ipfs-pq v0.0.2 // indirect github.com/ipfs/go-ipfs-util v0.0.2 // indirect + github.com/ipfs/go-ipld-cbor v0.0.5 // indirect github.com/ipfs/go-ipld-legacy v0.1.0 // indirect github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-metrics-interface v0.0.1 // indirect diff --git a/go.sum b/go.sum index c78f04f5..4edd2462 100644 --- a/go.sum +++ b/go.sum @@ -518,8 +518,9 @@ github.com/ipld/go-codec-dagpb v1.3.1/go.mod h1:ErNNglIi5KMur/MfFE/svtgQthzVvf+4 github.com/ipld/go-ipld-prime v0.9.1-0.20210324083106-dc342a9917db/go.mod h1:KvBLMr4PX1gWptgkzRjVZCrLmSGcZCb/jioOQwCqZN8= github.com/ipld/go-ipld-prime v0.11.0/go.mod h1:+WIAkokurHmZ/KwzDOMUuoeJgaRQktHtEaLglS3ZeV8= github.com/ipld/go-ipld-prime v0.14.0/go.mod h1:9ASQLwUFLptCov6lIYc70GRB4V7UTyLD0IJtrDJe6ZM= -github.com/ipld/go-ipld-prime v0.16.0 h1:RS5hhjB/mcpeEPJvfyj0qbOj/QL+/j05heZ0qa97dVo= github.com/ipld/go-ipld-prime v0.16.0/go.mod h1:axSCuOCBPqrH+gvXr2w9uAOulJqBPhHPT2PjoiiU1qA= +github.com/ipld/go-ipld-prime v0.16.1-0.20220519105356-1f1151b69dba h1:1eimQ/EpBUnxyhvSQ9gxzokN9EiDYHCeZ2URkhADIGQ= +github.com/ipld/go-ipld-prime v0.16.1-0.20220519105356-1f1151b69dba/go.mod h1:/bZAYlzT7SJS4UV0al4q67xgKvenm5hKrPCa2wNGN1U= github.com/ipld/go-ipld-prime/storage/bsadapter v0.0.0-20211210234204-ce2a1c70cd73/go.mod h1:2PJ0JgxyB08t0b2WKrcuqI3di0V+5n6RS/LTUJhkoxY= github.com/jackpal/gateway v1.0.5/go.mod h1:lTpwd4ACLXmpyiCTRtfiNyVnUmqT9RivzCDQetPfnjA= github.com/jackpal/go-nat-pmp v1.0.1/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= diff --git a/impl/events.go b/impl/events.go index ef5c5fb8..3ac2e64c 100644 --- a/impl/events.go +++ b/impl/events.go @@ -38,7 +38,7 @@ func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error { // back a pause to the transport if the data limit is exceeded func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataReceived", trace.WithAttributes( + _, span := otel.Tracer("data-transfer").Start(ctx, "dataReceived", trace.WithAttributes( attribute.String("channelID", chid.String()), attribute.String("link", link.String()), attribute.Int64("index", index), @@ -69,7 +69,7 @@ func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size // machine. ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataQueued", trace.WithAttributes( + _, span := otel.Tracer("data-transfer").Start(ctx, "dataQueued", trace.WithAttributes( attribute.String("channelID", chid.String()), attribute.String("link", link.String()), attribute.Int64("size", int64(size)), @@ -91,7 +91,7 @@ func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataSent", trace.WithAttributes( + _, span := otel.Tracer("data-transfer").Start(ctx, "dataSent", trace.WithAttributes( attribute.String("channelID", chid.String()), attribute.String("link", link.String()), attribute.Int64("size", int64(size)), @@ -153,11 +153,11 @@ func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datat // is there a voucher response in this message? if !response.EmptyVoucherResult() { // if so decode and save it - vresult, err := m.decodeVoucherResult(response) + vresult, err := response.VoucherResult() if err != nil { return err } - err = m.channels.NewVoucherResult(chid, vresult) + err = m.channels.NewVoucherResult(chid, datatransfer.TypedVoucher{Voucher: vresult, Type: response.VoucherResultType()}) if err != nil { return err } @@ -271,7 +271,7 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er log.Infow("received OnChannelCompleted, will send completion message to initiator", "chid", chid) // generate and send the final status message - msg, err := message.CompleteResponse(chst.TransferID(), true, chst.RequiresFinalization(), datatransfer.EmptyTypeIdentifier, nil) + msg, err := message.CompleteResponse(chst.TransferID(), true, chst.RequiresFinalization(), nil) if err != nil { return err } diff --git a/impl/impl.go b/impl/impl.go index efd7e878..61f0fe67 100644 --- a/impl/impl.go +++ b/impl/impl.go @@ -10,7 +10,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" logging "github.com/ipfs/go-log/v2" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" "go.opentelemetry.io/otel" @@ -22,7 +22,6 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channelmonitor" "github.com/filecoin-project/go-data-transfer/v2/channels" - "github.com/filecoin-project/go-data-transfer/v2/encoding" "github.com/filecoin-project/go-data-transfer/v2/message" "github.com/filecoin-project/go-data-transfer/v2/message/types" "github.com/filecoin-project/go-data-transfer/v2/network" @@ -36,7 +35,6 @@ var cancelSendTimeout = 30 * time.Second type manager struct { dataTransferNetwork network.DataTransferNetwork validatedTypes *registry.Registry - resultTypes *registry.Registry transportConfigurers *registry.Registry pubSub *pubsub.PubSub readySub *pubsub.PubSub @@ -96,7 +94,6 @@ func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTran m := &manager{ dataTransferNetwork: dataTransferNetwork, validatedTypes: registry.NewRegistry(), - resultTypes: registry.NewRegistry(), transportConfigurers: registry.NewRegistry(), pubSub: pubsub.New(dispatcher), readySub: pubsub.New(readyDispatcher), @@ -106,7 +103,7 @@ func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTran spansIndex: tracing.NewSpansIndex(), } - channels, err := channels.New(ds, m.notifier, m.voucherDecoder, m.resultTypes.Decoder, &channelEnvironment{m}, dataTransferNetwork.ID()) + channels, err := channels.New(ds, m.notifier, &channelEnvironment{m}, dataTransferNetwork.ID()) if err != nil { return nil, err } @@ -124,10 +121,6 @@ func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTran return m, nil } -func (m *manager) voucherDecoder(voucherType datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - return m.validatedTypes.Decoder(voucherType) -} - func (m *manager) notifier(evt datatransfer.Event, chst datatransfer.ChannelState) { err := m.pubSub.Publish(internalEvent{evt, chst}) if err != nil { @@ -173,7 +166,7 @@ func (m *manager) Stop(ctx context.Context) error { // * voucher type does not implement voucher // * there is a voucher type registered with an identical identifier // * voucherType's Kind is not reflect.Ptr -func (m *manager) RegisterVoucherType(voucherType datatransfer.Voucher, validator datatransfer.RequestValidator) error { +func (m *manager) RegisterVoucherType(voucherType datatransfer.TypeIdentifier, validator datatransfer.RequestValidator) error { err := m.validatedTypes.Register(voucherType, validator) if err != nil { return xerrors.Errorf("error registering voucher type: %w", err) @@ -183,7 +176,7 @@ func (m *manager) RegisterVoucherType(voucherType datatransfer.Voucher, validato // OpenPushDataChannel opens a data transfer that will send data to the recipient peer and // transfer parts of the piece that match the selector -func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) { +func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) { log.Infof("open push channel to %s with base cid %s", requestTo, baseCid) req, err := m.newRequest(ctx, selector, false, voucher, baseCid, requestTo) @@ -197,7 +190,7 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo return chid, err } ctx, span := m.spansIndex.SpanForChannel(ctx, chid) - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) @@ -205,7 +198,7 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo m.dataTransferNetwork.Protect(requestTo, chid.String()) monitoredChan := m.channelMonitor.AddPushChannel(chid) if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) + err = fmt.Errorf("unable to send request: %w", err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) _ = m.channels.Error(chid, err) @@ -224,7 +217,7 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo // OpenPullDataChannel opens a data transfer that will request data from the sending peer and // transfer parts of the piece that match the selector -func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) { +func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) { log.Infof("open pull channel to %s with base cid %s", requestTo, baseCid) req, err := m.newRequest(ctx, selector, true, voucher, baseCid, requestTo) @@ -238,7 +231,7 @@ func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, vo return chid, err } ctx, span := m.spansIndex.SpanForChannel(ctx, chid) - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) @@ -246,7 +239,7 @@ func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, vo m.dataTransferNetwork.Protect(requestTo, chid.String()) monitoredChan := m.channelMonitor.AddPullChannel(chid) if err := m.transport.OpenChannel(ctx, requestTo, chid, cidlink.Link{Cid: baseCid}, selector, nil, req); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) + err = fmt.Errorf("unable to send request: %w", err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) _ = m.channels.Error(chid, err) @@ -262,7 +255,7 @@ func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, vo } // SendVoucher sends an intermediate voucher as needed when the receiver sends a request for revalidation -func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.ChannelID, voucher datatransfer.Voucher) error { +func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher) error { chst, err := m.channels.GetByID(ctx, channelID) if err != nil { return err @@ -270,7 +263,6 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe ctx, _ = m.spansIndex.SpanForChannel(ctx, channelID) ctx, span := otel.Tracer("data-transfer").Start(ctx, "sendVoucher", trace.WithAttributes( attribute.String("channelID", channelID.String()), - attribute.String("voucherType", string(voucher.Type())), )) defer span.End() if channelID.Initiator != m.peerID { @@ -279,14 +271,14 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe span.SetStatus(codes.Error, err.Error()) return err } - updateRequest, err := message.VoucherRequest(channelID.ID, voucher.Type(), voucher) + updateRequest, err := message.VoucherRequest(channelID.ID, &voucher) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err } if err := m.dataTransferNetwork.SendMessage(ctx, chst.OtherPeer(), updateRequest); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) + err = fmt.Errorf("unable to send request: %w", err) _ = m.OnRequestDisconnected(channelID, err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -295,7 +287,7 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe return m.channels.NewVoucher(channelID, voucher) } -func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer.ChannelID, voucherResult datatransfer.VoucherResult) error { +func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer.ChannelID, voucherResult datatransfer.TypedVoucher) error { chst, err := m.channels.GetByID(ctx, channelID) if err != nil { return err @@ -303,7 +295,6 @@ func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer. ctx, _ = m.spansIndex.SpanForChannel(ctx, channelID) ctx, span := otel.Tracer("data-transfer").Start(ctx, "sendVoucherResult", trace.WithAttributes( attribute.String("channelID", channelID.String()), - attribute.String("voucherResultType", string(voucherResult.Type())), )) defer span.End() if channelID.Initiator == m.peerID { @@ -315,9 +306,9 @@ func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer. var updateResponse datatransfer.Response if chst.Status().InFinalization() { - updateResponse, err = message.CompleteResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), voucherResult.Type(), voucherResult) + updateResponse, err = message.CompleteResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), &voucherResult) } else { - updateResponse, err = message.VoucherResultResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), voucherResult.Type(), voucherResult) + updateResponse, err = message.VoucherResultResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), &voucherResult) } if err != nil { @@ -326,7 +317,7 @@ func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer. return err } if err := m.dataTransferNetwork.SendMessage(ctx, chst.OtherPeer(), updateResponse); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) + err = fmt.Errorf("unable to send request: %w", err) _ = m.OnRequestDisconnected(channelID, err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -549,7 +540,7 @@ func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfe } if err := m.dataTransferNetwork.SendMessage(ctx, chid.OtherParty(m.peerID), m.pauseMessage(chid)); err != nil { - err = fmt.Errorf("Unable to send pause message: %w", err) + err = fmt.Errorf("unable to send pause message: %w", err) _ = m.OnRequestDisconnected(chid, err) return err } @@ -600,19 +591,9 @@ func (m *manager) InProgressChannels(ctx context.Context) (map[datatransfer.Chan return m.channels.InProgress() } -// RegisterVoucherResultType allows deserialization of a voucher result, -// so that a listener can read the metadata -func (m *manager) RegisterVoucherResultType(resultType datatransfer.VoucherResult) error { - err := m.resultTypes.Register(resultType, nil) - if err != nil { - return xerrors.Errorf("error registering voucher type: %w", err) - } - return nil -} - // RegisterTransportConfigurer registers the given transport configurer to be run on requests with the given voucher // type -func (m *manager) RegisterTransportConfigurer(voucherType datatransfer.Voucher, configurer datatransfer.TransportConfigurer) error { +func (m *manager) RegisterTransportConfigurer(voucherType datatransfer.TypeIdentifier, configurer datatransfer.TransportConfigurer) error { err := m.transportConfigurers.Register(voucherType, configurer) if err != nil { return xerrors.Errorf("error registering transport configurer: %w", err) diff --git a/impl/initiating_test.go b/impl/initiating_test.go index 0d1274b7..8a393834 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -9,7 +9,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" @@ -52,7 +52,7 @@ func TestDataTransferInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "OpenPullDataTransfer": { @@ -79,7 +79,7 @@ func TestDataTransferInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "SendVoucher with no channel open": { @@ -93,7 +93,7 @@ func TestDataTransferInitiating(t *testing.T) { verify: func(t *testing.T, h *harness) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() err = h.dt.SendVoucher(ctx, channelID, voucher) require.NoError(t, err) require.Len(t, h.network.SentMessages, 2) @@ -103,7 +103,7 @@ func TestDataTransferInitiating(t *testing.T) { require.True(t, ok) require.True(t, receivedRequest.IsVoucher()) require.False(t, receivedRequest.IsCancel()) - testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) + testutil.AssertTestVoucher(t, receivedRequest, voucher) }, }, "SendVoucher with channel open, pull succeeds": { @@ -111,7 +111,7 @@ func TestDataTransferInitiating(t *testing.T) { verify: func(t *testing.T, h *harness) { channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() err = h.dt.SendVoucher(ctx, channelID, voucher) require.NoError(t, err) require.Len(t, h.transport.OpenedChannels, 1) @@ -122,26 +122,16 @@ func TestDataTransferInitiating(t *testing.T) { require.True(t, ok) require.False(t, receivedRequest.IsCancel()) require.True(t, receivedRequest.IsVoucher()) - testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) + testutil.AssertTestVoucher(t, receivedRequest, voucher) }, }, "reregister voucher type again errors": { verify: func(t *testing.T, h *harness) { - voucher := testutil.NewFakeDTType() sv := testutil.NewStubbedValidator() - err := h.dt.RegisterVoucherType(h.voucher, sv) + err := h.dt.RegisterVoucherType(h.voucher.Type, sv) require.NoError(t, err) - err = h.dt.RegisterVoucherType(voucher, sv) - require.EqualError(t, err, "error registering voucher type: identifier already registered: FakeDTType") - }, - }, - "reregister non pointer errors": { - verify: func(t *testing.T, h *harness) { - sv := testutil.NewStubbedValidator() - err := h.dt.RegisterVoucherType(h.voucher, sv) - require.NoError(t, err) - err = h.dt.RegisterVoucherType(testutil.FakeDTType{}, sv) - require.EqualError(t, err, "error registering voucher type: registering entry type FakeDTType: type must be a pointer") + err = h.dt.RegisterVoucherType(testutil.TestVoucherType, sv) + require.EqualError(t, err, "error registering voucher type: identifier already registered: TestVoucher") }, }, "success response": { @@ -150,7 +140,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) + response, err := message.NewResponse(channelID.ID, true, false, nil) require.NoError(t, err) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) @@ -162,7 +152,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, h.voucherResult.Type(), h.voucherResult) + response, err := message.NewResponse(channelID.ID, true, false, &h.voucherResult) require.NoError(t, err) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) @@ -174,7 +164,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) + response, err := message.NewResponse(channelID.ID, true, false, nil) require.NoError(t, err) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) @@ -230,7 +220,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) + response, err := message.NewResponse(channelID.ID, true, false, nil) require.NoError(t, err) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) @@ -284,7 +274,7 @@ func TestDataTransferInitiating(t *testing.T) { "customizing push transfer": { expectedEvents: []datatransfer.EventCode{datatransfer.Open}, verify: func(t *testing.T, h *harness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -304,7 +294,7 @@ func TestDataTransferInitiating(t *testing.T) { "customizing pull transfer": { expectedEvents: []datatransfer.EventCode{datatransfer.Open}, verify: func(t *testing.T, h *harness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -344,9 +334,8 @@ func TestDataTransferInitiating(t *testing.T) { } ev.setup(t, dt) h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() - h.voucherResult = testutil.NewFakeDTType() - err = h.dt.RegisterVoucherResultType(h.voucherResult) + h.voucher = testutil.NewTestTypedVoucher() + h.voucherResult = testutil.NewTestTypedVoucher() require.NoError(t, err) h.baseCid = testutil.GenerateCids(1)[0] verify.verify(t, h) @@ -405,7 +394,7 @@ func TestDataTransferRestartInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "RestartDataTransferChannel: Manager Peer Create Push Restart works": { @@ -441,7 +430,7 @@ func TestDataTransferRestartInitiating(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "RestartDataTransferChannel: Manager Peer Receive Push Restart works ": { @@ -610,17 +599,16 @@ func TestDataTransferRestartInitiating(t *testing.T) { // setup voucher processing h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.voucherValidator)) - h.voucherResult = testutil.NewFakeDTType() - err = h.dt.RegisterVoucherResultType(h.voucherResult) + h.voucher = testutil.NewTestTypedVoucher() + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, h.voucherValidator)) + h.voucherResult = testutil.NewTestTypedVoucher() require.NoError(t, err) h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) // run tests steps and verify @@ -639,9 +627,9 @@ type harness struct { ds datastore.Batching dt datatransfer.Manager voucherValidator *testutil.StubbedValidator - stor ipld.Node - voucher *testutil.FakeDTType - voucherResult *testutil.FakeDTType + stor datamodel.Node + voucher datatransfer.TypedVoucher + voucherResult datatransfer.TypedVoucher baseCid cid.Cid id datatransfer.TransferID diff --git a/impl/integration_test.go b/impl/integration_test.go index 09896c30..53c375cc 100644 --- a/impl/integration_test.go +++ b/impl/integration_test.go @@ -27,6 +27,7 @@ import ( "github.com/ipfs/go-unixfs/importer/balanced" ihelper "github.com/ipfs/go-unixfs/importer/helpers" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" @@ -36,7 +37,6 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channelmonitor" - "github.com/filecoin-project/go-data-transfer/v2/encoding" . "github.com/filecoin-project/go-data-transfer/v2/impl" "github.com/filecoin-project/go-data-transfer/v2/message" "github.com/filecoin-project/go-data-transfer/v2/network" @@ -49,7 +49,8 @@ const loremFile = "lorem.txt" const loremFileTransferBytes = 20439 const loremLargeFile = "lorem_large.txt" -const loremLargeFileTransferBytes = 217452 + +// const loremLargeFileTransferBytes = 217452 // nil means use the default protocols // tests data transfer for the following protocol combinations: @@ -181,7 +182,7 @@ func TestRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) @@ -191,9 +192,8 @@ func TestRoundTrip(t *testing.T) { bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) lsys := storeutil.LinkSystemForBlockstore(bs) sourceDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt1.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { + err := dt1.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + if testVoucher.Equals(voucher) { gsTransport, ok := transport.(*tp.Transport) if ok { err := gsTransport.UseStore(channelID, lsys) @@ -214,9 +214,8 @@ func TestRoundTrip(t *testing.T) { bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) lsys := storeutil.LinkSystemForBlockstore(bs) destDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { + err := dt2.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + if testVoucher.Equals(voucher) { gsTransport, ok := transport.(*tp.Transport) if ok { err := gsTransport.UseStore(channelID, lsys) @@ -232,12 +231,12 @@ func TestRoundTrip(t *testing.T) { var chid datatransfer.ChannelID if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) opens := 0 @@ -325,9 +324,9 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - vouchers := make([]datatransfer.Voucher, 0, data.requestCount) + vouchers := make([]datatransfer.TypedVoucher, 0, data.requestCount) for i := 0; i < data.requestCount; i++ { - vouchers = append(vouchers, testutil.NewFakeDTType()) + vouchers = append(vouchers, testutil.NewTestTypedVoucher()) } sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) @@ -347,16 +346,13 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { linkSystems = append(linkSystems, lsys) } - err = dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok { - for i, voucher := range vouchers { - if fv.Data == voucher.(*testutil.FakeDTType).Data { - gsTransport, ok := transport.(*tp.Transport) - if ok { - err := gsTransport.UseStore(channelID, linkSystems[i]) - require.NoError(t, err) - } + err = dt2.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { + for i, voucher := range vouchers { + if testVoucher.Equals(voucher) { + gsTransport, ok := transport.(*tp.Transport) + if ok { + err := gsTransport.UseStore(channelID, linkSystems[i]) + require.NoError(t, err) } } } @@ -365,14 +361,14 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) for i := 0; i < data.requestCount; i++ { _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, gsData.AllSelector) require.NoError(t, err) } } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) for i := 0; i < data.requestCount; i++ { _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), vouchers[i], rootCid, gsData.AllSelector) require.NoError(t, err) @@ -453,10 +449,9 @@ func TestManyReceiversAtOnce(t *testing.T) { err = receiver.Start(gsData.Ctx) require.NoError(t, err) - err = receiver.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - _, isFv := testVoucher.(*testutil.FakeDTType) + err = receiver.RegisterTransportConfigurer(testutil.TestVoucherType, func(channelID datatransfer.ChannelID, testVoucher datatransfer.TypedVoucher, transport datatransfer.Transport) { gsTransport, isGs := transport.(*tp.Transport) - if isFv && isGs { + if isGs { err := gsTransport.UseStore(channelID, altLinkSystem) require.NoError(t, err) } @@ -488,9 +483,9 @@ func TestManyReceiversAtOnce(t *testing.T) { for _, receiver := range receivers { receiver.SubscribeToEvents(subscriber) } - vouchers := make([]datatransfer.Voucher, 0, data.receiverCount) + vouchers := make([]datatransfer.TypedVoucher, 0, data.receiverCount) for i := 0; i < data.receiverCount; i++ { - vouchers = append(vouchers, testutil.NewFakeDTType()) + vouchers = append(vouchers, testutil.NewTestTypedVoucher()) } sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) @@ -500,7 +495,7 @@ func TestManyReceiversAtOnce(t *testing.T) { if data.isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) for i, receiver := range receivers { _, err = receiver.OpenPullDataChannel(ctx, host1.ID(), vouchers[i], rootCid, gsData.AllSelector) require.NoError(t, err) @@ -508,7 +503,7 @@ func TestManyReceiversAtOnce(t *testing.T) { } else { sv.ExpectSuccessPush() for i, receiver := range receivers { - require.NoError(t, receiver.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, receiver.RegisterVoucherType(testutil.TestVoucherType, sv)) _, err = dt1.OpenPushDataChannel(ctx, hosts[i].ID(), vouchers[i], rootCid, gsData.AllSelector) require.NoError(t, err) } @@ -736,7 +731,7 @@ func TestAutoRestart(t *testing.T) { } initiator.SubscribeToEvents(subscriber) responder.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) @@ -753,8 +748,8 @@ func TestAutoRestart(t *testing.T) { root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile) rootCid := root.(cidlink.Link).Cid - require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, initiator.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, responder.RegisterVoucherType(testutil.TestVoucherType, sv)) // If the test case needs to subscribe to response events, provide // the test case with the responder @@ -775,10 +770,10 @@ func TestAutoRestart(t *testing.T) { var chid datatransfer.ChannelID if tc.isPush { // Open a push channel - chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), voucher, rootCid, gsData.AllSelector) } else { // Open a pull channel - chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) @@ -929,7 +924,7 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { } dataReceived := onDataReceivedChan(dataReceiver) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) @@ -946,16 +941,16 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremLargeFile) rootCid := root.(cidlink.Link).Cid - require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, initiator.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, responder.RegisterVoucherType(testutil.TestVoucherType, sv)) var chid datatransfer.ChannelID if isPush { // Open a push channel - chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), voucher, rootCid, gsData.AllSelector) } else { // Open a pull channel - chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) @@ -987,7 +982,7 @@ func TestAutoRestartAfterBouncingInitiator(t *testing.T) { initiator2GSTspt := gsData.SetupGSTransportHost1() initiator2, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, initiator2GSTspt, restartConf) require.NoError(t, err) - require.NoError(t, initiator2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, initiator2.RegisterVoucherType(testutil.TestVoucherType, sv)) initiator2.SubscribeToEvents(completeSubscriber) testutil.StartAndWaitForReady(ctx, t, initiator2) @@ -1127,7 +1122,7 @@ func TestRoundTripCancelledRequest(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() root, _ := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid @@ -1136,13 +1131,13 @@ func TestRoundTripCancelledRequest(t *testing.T) { if data.isPull { sv.ExpectSuccessPull() sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true}) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) } else { sv.ExpectSuccessPush() sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true}) - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) opens := 0 @@ -1187,21 +1182,23 @@ type retrievalRevalidator struct { providerPausePoint int pausePoints []uint64 leavePausedInitially bool - initialVoucherResult datatransfer.VoucherResult + initialVoucherResult *datatransfer.TypedVoucher requiresFinalization bool } func (r *retrievalRevalidator) ValidatePush( chid datatransfer.ChannelID, sender peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.ValidationResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { vr := datatransfer.ValidationResult{ Accepted: true, RequiresFinalization: r.requiresFinalization, ForcePause: r.leavePausedInitially, - VoucherResult: r.initialVoucherResult, + } + if r.initialVoucherResult != nil { + vr.VoucherResult = r.initialVoucherResult } if len(r.pausePoints) > r.providerPausePoint { vr.DataLimit = r.pausePoints[r.providerPausePoint] @@ -1214,14 +1211,16 @@ func (r *retrievalRevalidator) ValidatePush( func (r *retrievalRevalidator) ValidatePull( chid datatransfer.ChannelID, sender peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.ValidationResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { vr := datatransfer.ValidationResult{ Accepted: true, RequiresFinalization: r.requiresFinalization, ForcePause: r.leavePausedInitially, - VoucherResult: r.initialVoucherResult, + } + if r.initialVoucherResult != nil { + vr.VoucherResult = r.initialVoucherResult } if len(r.pausePoints) > r.providerPausePoint { vr.DataLimit = r.pausePoints[r.providerPausePoint] @@ -1329,26 +1328,24 @@ func TestSimulatedRetrievalFlow(t *testing.T) { errChan := make(chan struct{}, 2) clientPausePoint := 0 clientFinished := make(chan struct{}, 1) - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) + finalVoucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) var clientSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { errChan <- struct{}{} } if event.Code == datatransfer.NewVoucherResult { - lastVoucherResult := channelState.LastVoucherResult() - encodedLVR, err := encoding.Encode(lastVoucherResult) + lastVoucherResult, err := channelState.LastVoucherResult() require.NoError(t, err) - if bytes.Equal(encodedLVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if lastVoucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } if event.Code == datatransfer.DataReceived && clientPausePoint < len(config.pausePoints) && channelState.Received() > config.pausePoints[clientPausePoint] { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) clientPausePoint++ } if channelState.Status() == datatransfer.Completed { @@ -1376,7 +1373,7 @@ func TestSimulatedRetrievalFlow(t *testing.T) { dt1.UpdateValidationStatus(ctx, chid, sv.nextStatus()) } if event.Code == datatransfer.DataLimitExceeded { - dt1.SendVoucherResult(ctx, chid, testutil.NewFakeDTType()) + dt1.SendVoucherResult(ctx, chid, testutil.NewTestTypedVoucher()) } if event.Code == datatransfer.BeginFinalizing { sv.requiresFinalization = false @@ -1390,12 +1387,11 @@ func TestSimulatedRetrievalFlow(t *testing.T) { } } dt1.SubscribeToEvents(providerSubscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - require.NoError(t, dt2.RegisterVoucherResultType(testutil.NewFakeDTType())) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) require.NoError(t, err) for providerFinished != nil || clientFinished != nil { @@ -1495,7 +1491,7 @@ func TestPauseAndResume(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) @@ -1503,12 +1499,12 @@ func TestPauseAndResume(t *testing.T) { var chid datatransfer.ChannelID if isPull { sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) } else { sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) opens := 0 @@ -1606,15 +1602,15 @@ func TestUnrecognizedVoucherRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") root, _ := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1, loremFile) rootCid := root.(cidlink.Link).Cid if isPull { - _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) } else { - _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, rootCid, gsData.AllSelector) } require.NoError(t, err) opens := 0 @@ -1656,8 +1652,8 @@ func TestDataTransferSubscribing(t *testing.T) { dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - voucher := testutil.FakeDTType{Data: "applesauce"} + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) + voucher := testutil.NewTestTypedVoucherWith("applesauce") baseCid := testutil.GenerateCids(1)[0] dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) @@ -1677,7 +1673,7 @@ func TestDataTransferSubscribing(t *testing.T) { } unsub1 := dt1.SubscribeToEvents(subscribe1) unsub2 := dt1.SubscribeToEvents(subscribe2) - _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.AllSelector) + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) select { case <-ctx.Done(): @@ -1706,7 +1702,7 @@ func TestDataTransferSubscribing(t *testing.T) { } unsub3 := dt1.SubscribeToEvents(subscribe3) unsub4 := dt1.SubscribeToEvents(subscribe4) - _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.AllSelector) + _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) select { case <-ctx.Done(): @@ -1776,7 +1772,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -1795,8 +1791,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - voucherResult := testutil.NewFakeDTType() - err = dt1.RegisterVoucherResultType(voucherResult) + voucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) t.Run("when request is initiated", func(t *testing.T) { @@ -1811,7 +1806,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { } requestReceived := messageReceived.message.(datatransfer.Request) - response, err := message.NewResponse(requestReceived.TransferID(), true, false, voucherResult.Type(), voucherResult) + response, err := message.NewResponse(requestReceived.TransferID(), true, false, &voucherResult) require.NoError(t, err) nd, err := response.ToIPLD() require.NoError(t, err) @@ -1830,7 +1825,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { }) t.Run("when no request is initiated", func(t *testing.T) { - response, err := message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, voucher.Type(), voucher) + response, err := message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, &voucher) require.NoError(t, err) nd, err := response.ToIPLD() require.NoError(t, err) @@ -1857,7 +1852,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := testutil.FakeDTType{Data: "applesauce"} + voucher := testutil.NewTestTypedVoucherWith("applesauce") link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -1885,7 +1880,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { } gs1.RegisterIncomingRequestHook(validateHook) - _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, link.(cidlink.Link).Cid, gsData.AllSelector) + _, err := dt1.OpenPushDataChannel(ctx, host2.ID(), voucher, link.(cidlink.Link).Cid, gsData.AllSelector) require.NoError(t, err) select { @@ -1921,12 +1916,13 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { dt1, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, true, voucher.Type(), voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, true, &voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) require.NoError(t, err) nd, err := request.ToIPLD() + require.NoError(t, err) gsRequest := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ Name: extension.ExtensionDataTransfer1_1, Data: nd, @@ -1949,9 +1945,9 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { dt1, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - voucher := testutil.NewFakeDTType() - dtRequest, err := message.NewRequest(id, false, true, voucher.Type(), voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + voucher := testutil.NewTestTypedVoucher() + dtRequest, err := message.NewRequest(id, false, true, &voucher, testutil.GenerateCids(1)[0], gsData.AllSelector) require.NoError(t, err) nd, err := dtRequest.ToIPLD() @@ -2033,24 +2029,20 @@ func TestMultipleMessagesInExtension(t *testing.T) { // In this retrieval flow we expect 2 voucher results: // The first one is sent as a response from the initial request telling the client // the provider has accepted the request and is starting to send blocks - respVoucher := testutil.NewFakeDTType() - encodedRVR, err := encoding.Encode(respVoucher) - require.NoError(t, err) + respVoucher := testutil.NewTestTypedVoucher() // voucher results are sent by the providers to request payment while pausing until a voucher is sent // to revalidate - voucherResults := []datatransfer.VoucherResult{ - &testutil.FakeDTType{Data: "one"}, - &testutil.FakeDTType{Data: "two"}, - &testutil.FakeDTType{Data: "thr"}, - &testutil.FakeDTType{Data: "for"}, - &testutil.FakeDTType{Data: "fiv"}, + voucherResults := []datatransfer.TypedVoucher{ + testutil.NewTestTypedVoucherWith("one"), + testutil.NewTestTypedVoucherWith("two"), + testutil.NewTestTypedVoucherWith("thr"), + testutil.NewTestTypedVoucherWith("for"), + testutil.NewTestTypedVoucherWith("fiv"), } // The final voucher result is sent by the provider to request a last payment voucher - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) - require.NoError(t, err) + finalVoucherResult := testutil.NewTestTypedVoucher() dt2.SubscribeToEvents(func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { @@ -2058,13 +2050,12 @@ func TestMultipleMessagesInExtension(t *testing.T) { } // Here we verify reception of voucherResults by the client if event.Code == datatransfer.NewVoucherResult { - voucherResult := channelState.LastVoucherResult() - encodedVR, err := encoding.Encode(voucherResult) + voucherResult, err := channelState.LastVoucherResult() require.NoError(t, err) // If this voucher result is the response voucher no action is needed // we just know that the provider has accepted the transfer and is sending blocks - if bytes.Equal(encodedVR, encodedRVR) { + if voucherResult.Equals(respVoucher) { // The test will fail if no response voucher is received clientGotResponse <- struct{}{} } @@ -2072,18 +2063,16 @@ func TestMultipleMessagesInExtension(t *testing.T) { // If this voucher is a revalidation request we need to send a new voucher // to revalidate and unpause the transfer if clientPausePoint < 5 { - encodedExpected, err := encoding.Encode(voucherResults[clientPausePoint]) - require.NoError(t, err) - if bytes.Equal(encodedVR, encodedExpected) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(voucherResults[clientPausePoint]) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) clientPausePoint++ } } // If this voucher result is the final voucher result we need // to send a new voucher to unpause the provider and complete the transfer - if bytes.Equal(encodedVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } @@ -2098,7 +2087,7 @@ func TestMultipleMessagesInExtension(t *testing.T) { StubbedValidator: testutil.NewStubbedValidator(), pausePoints: pausePoints, requiresFinalization: true, - initialVoucherResult: respVoucher, + initialVoucherResult: &respVoucher, } dt1.SubscribeToEvents(func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Error { @@ -2122,12 +2111,10 @@ func TestMultipleMessagesInExtension(t *testing.T) { dt1.SendVoucherResult(ctx, chid, finalVoucherResult) } }) - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - - require.NoError(t, dt2.RegisterVoucherResultType(testutil.NewFakeDTType())) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) require.NoError(t, err) // Expect the client to receive a response voucher, the provider to complete the transfer and @@ -2174,22 +2161,18 @@ func TestMultipleParallelTransfers(t *testing.T) { // In this retrieval flow we expect 2 voucher results: // The first one is sent as a response from the initial request telling the client // the provider has accepted the request and is starting to send blocks - respVoucher := testutil.NewFakeDTType() - encodedRVR, err := encoding.Encode(respVoucher) + respVoucher := testutil.NewTestTypedVoucher() require.NoError(t, err) // The final voucher result is sent by the provider to let the client know the deal is completed - finalVoucherResult := testutil.NewFakeDTType() - encodedFVR, err := encoding.Encode(finalVoucherResult) + finalVoucherResult := testutil.NewTestTypedVoucher() require.NoError(t, err) sv := &retrievalRevalidator{ StubbedValidator: testutil.NewStubbedValidator(), - initialVoucherResult: respVoucher, + initialVoucherResult: &respVoucher, } - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - - require.NoError(t, dt2.RegisterVoucherResultType(testutil.NewFakeDTType())) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) // for each size we create a new random DAG of the given size and try to retrieve it for _, size := range sizes { @@ -2217,21 +2200,21 @@ func TestMultipleParallelTransfers(t *testing.T) { } // Here we verify reception of voucherResults by the client if event.Code == datatransfer.NewVoucherResult { - voucherResult := channelState.LastVoucherResult() - encodedVR, err := encoding.Encode(voucherResult) + voucherResult, err := channelState.LastVoucherResult() + require.NoError(t, err) require.NoError(t, err) // If this voucher result is the response voucher no action is needed // we just know that the provider has accepted the transfer and is sending blocks - if bytes.Equal(encodedVR, encodedRVR) { + if voucherResult.Equals(respVoucher) { // The test will fail if no response voucher is received clientGotResponse <- struct{}{} } // If this voucher result is the final voucher result we need // to send a new voucher to unpause the provider and complete the transfer - if bytes.Equal(encodedVR, encodedFVR) { - _ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType()) + if voucherResult.Equals(finalVoucherResult) { + _ = dt2.SendVoucher(ctx, chid, testutil.NewTestTypedVoucher()) } } @@ -2260,7 +2243,7 @@ func TestMultipleParallelTransfers(t *testing.T) { root, origBytes := LoadRandomData(ctx, t, gsData.DagService1, size) rootCid := root.(cidlink.Link).Cid - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), voucher, rootCid, gsData.AllSelector) require.NoError(t, err) close(chidReceived) diff --git a/impl/receiving_requests.go b/impl/receiving_requests.go index cb86afa7..4760fffb 100644 --- a/impl/receiving_requests.go +++ b/impl/receiving_requests.go @@ -4,7 +4,7 @@ import ( "context" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" @@ -41,13 +41,16 @@ func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransf return datatransfer.ValidationResult{}, err } - voucher, err := m.decodeVoucher(incoming) + voucher, err := incoming.TypedVoucher() if err != nil { return datatransfer.ValidationResult{}, err } + processor, ok := m.validatedTypes.Processor(voucher.Type) + if !ok { + return datatransfer.ValidationResult{}, xerrors.Errorf("unknown voucher type: %s", voucher.Type) + } - var validatorFunc func(datatransfer.ChannelID, peer.ID, datatransfer.Voucher, cid.Cid, ipld.Node) (datatransfer.ValidationResult, error) - processor, _ := m.validatedTypes.Processor(voucher.Type()) + var validatorFunc func(datatransfer.ChannelID, peer.ID, datamodel.Node, cid.Cid, datamodel.Node) (datatransfer.ValidationResult, error) validator := processor.(datatransfer.RequestValidator) if incoming.IsPull() { validatorFunc = validator.ValidatePull @@ -55,7 +58,7 @@ func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransf validatorFunc = validator.ValidatePush } - result, err := validatorFunc(chid, chid.Initiator, voucher, incoming.BaseCid(), stor) + result, err := validatorFunc(chid, chid.Initiator, voucher.Voucher, incoming.BaseCid(), stor) // if an error occurred during validation or the request was not accepted, return if err != nil || !result.Accepted { @@ -73,7 +76,16 @@ func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransf } log.Infow("data-transfer request validated, will create & start tracking channel", "channelID", chid, "payloadCid", incoming.BaseCid()) - _, err = m.channels.CreateNew(m.peerID, incoming.TransferID(), incoming.BaseCid(), stor, voucher, chid.Initiator, dataSender, dataReceiver) + _, err = m.channels.CreateNew( + m.peerID, + incoming.TransferID(), + incoming.BaseCid(), + stor, + voucher, + chid.Initiator, + dataSender, + dataReceiver, + ) if err != nil { log.Errorw("failed to create and start tracking channel", "channelID", chid, "err", err) return result, err @@ -97,7 +109,7 @@ func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransf } // configure the transport - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) @@ -173,14 +185,16 @@ func (m *manager) restartRequest(chid datatransfer.ChannelID, } // configure the transport - voucher, err := m.decodeVoucher(incoming) + voucher, err := incoming.Voucher() if err != nil { return stayPaused, result, err } - processor, has := m.transportConfigurers.Processor(voucher.Type()) + voucherType := incoming.VoucherType() + typedVoucher := datatransfer.TypedVoucher{Voucher: voucher, Type: voucherType} + processor, has := m.transportConfigurers.Processor(voucherType) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) - transportConfigurer(chid, voucher, m.transport) + transportConfigurer(chid, typedVoucher, m.transport) } m.dataTransferNetwork.Protect(initiator, chid.String()) return stayPaused, result, nil @@ -189,11 +203,11 @@ func (m *manager) restartRequest(chid datatransfer.ChannelID, // processUpdateVoucher handles an incoming request message with an updated voucher func (m *manager) processUpdateVoucher(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { // decode the voucher and save it on the channel - vouch, err := m.decodeVoucher(request) + voucher, err := request.TypedVoucher() if err != nil { return nil, err } - return nil, m.channels.NewVoucher(chid, vouch) + return nil, m.channels.NewVoucher(chid, voucher) } // receiveUpdateRequest handles an incoming request message with an updated voucher @@ -238,7 +252,7 @@ func (m *manager) requestError(result datatransfer.ValidationResult, resultErr e // recordRejectedValidationEvents sends changes based on an reject validation to the state machine func (m *manager) recordRejectedValidationEvents(chid datatransfer.ChannelID, result datatransfer.ValidationResult) error { if result.VoucherResult != nil { - if err := m.channels.NewVoucherResult(chid, result.VoucherResult); err != nil { + if err := m.channels.NewVoucherResult(chid, *result.VoucherResult); err != nil { return err } } @@ -251,8 +265,8 @@ func (m *manager) recordAcceptedValidationEvents(chst datatransfer.ChannelState, chid := chst.ChannelID() // record the voucher result if present - if result.VoucherResult != nil { - err := m.channels.NewVoucherResult(chid, result.VoucherResult) + if result.VoucherResult != nil && result.VoucherResult.Voucher != nil { + err := m.channels.NewVoucherResult(chid, *result.VoucherResult) if err != nil { return err } @@ -296,7 +310,11 @@ func (m *manager) recordAcceptedValidationEvents(chst datatransfer.ChannelState, // validateRestart looks up the appropriate validator and validates a restart func (m *manager) validateRestart(chst datatransfer.ChannelState) (datatransfer.ValidationResult, error) { - processor, _ := m.validatedTypes.Processor(chst.Voucher().Type()) + chv, err := chst.Voucher() + if err != nil { + return datatransfer.ValidationResult{}, err + } + processor, _ := m.validatedTypes.Processor(chv.Type) validator := processor.(datatransfer.RequestValidator) return validator.ValidateRestart(chst.ChannelID(), chst) diff --git a/impl/responding_test.go b/impl/responding_test.go index 0fe466f7..d3adfb5c 100644 --- a/impl/responding_test.go +++ b/impl/responding_test.go @@ -11,6 +11,7 @@ import ( "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" @@ -40,7 +41,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -48,7 +50,7 @@ func TestDataTransferResponding(t *testing.T) { validation := h.sv.ValidationsReceived[0] assert.False(t, validation.IsPull) assert.Equal(t, h.peers[1], validation.Other) - assert.Equal(t, h.voucher, validation.Voucher) + assert.True(t, ipld.DeepEqual(h.voucher.Voucher, validation.Voucher)) assert.Equal(t, h.baseCid, validation.BaseCid) assert.Equal(t, h.stor, validation.Selector) @@ -73,7 +75,8 @@ func TestDataTransferResponding(t *testing.T) { "new push request rejects": { configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: false, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: false, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -94,7 +97,8 @@ func TestDataTransferResponding(t *testing.T) { "new push request errors": { configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectErrorPush() - sv.StubResult(datatransfer.ValidationResult{VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -115,7 +119,8 @@ func TestDataTransferResponding(t *testing.T) { "new push request pauses": { configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, ForcePause: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -156,7 +161,7 @@ func TestDataTransferResponding(t *testing.T) { validation := h.sv.ValidationsReceived[0] assert.True(t, validation.IsPull) assert.Equal(t, h.peers[1], validation.Other) - assert.Equal(t, h.voucher, validation.Voucher) + assert.True(t, ipld.DeepEqual(h.voucher.Voucher, validation.Voucher)) assert.Equal(t, h.baseCid, validation.BaseCid) assert.Equal(t, h.stor, validation.Selector) require.True(t, response.Accepted()) @@ -227,7 +232,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - newVoucherResult := testutil.NewFakeDTType() + newVoucherResult := testutil.NewTestTypedVoucher() err := h.dt.SendVoucherResult(h.ctx, channelID(h.id, h.peers), newVoucherResult) require.NoError(t, err) }, @@ -239,7 +244,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - newVoucherResult := testutil.NewFakeDTType() + newVoucherResult := testutil.NewTestTypedVoucher() err := h.dt.SendVoucherResult(h.ctx, channelID(h.id, h.peers), newVoucherResult) require.NoError(t, err) }, @@ -251,7 +256,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - newVoucher := testutil.NewFakeDTType() + newVoucher := testutil.NewTestTypedVoucher() err := h.dt.SendVoucher(h.ctx, channelID(h.id, h.peers), newVoucher) require.EqualError(t, err, "cannot send voucher for request we did not initiate") }, @@ -263,7 +268,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - newVoucher := testutil.NewFakeDTType() + newVoucher := testutil.NewTestTypedVoucher() err := h.dt.SendVoucher(h.ctx, channelID(h.id, h.peers), newVoucher) require.EqualError(t, err, "cannot send voucher for request we did not initiate") }, @@ -277,7 +282,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -294,7 +300,8 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.ResumeInitiator}, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -314,7 +321,8 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.ResumeInitiator}, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -336,7 +344,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -362,7 +371,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -379,7 +389,8 @@ func TestDataTransferResponding(t *testing.T) { response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) require.NoError(t, err, nil) require.Nil(t, response) - err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) require.NoError(t, err) require.Len(t, h.transport.ResumedChannels, 1) resCh := h.transport.ResumedChannels[0] @@ -411,7 +422,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) @@ -428,7 +440,8 @@ func TestDataTransferResponding(t *testing.T) { response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) require.NoError(t, err, nil) require.Nil(t, response) - err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: false, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: false, VoucherResult: &vr}) require.NoError(t, err) require.Len(t, h.transport.ClosedChannels, 1) require.Equal(t, h.transport.ClosedChannels[0], channelID(h.id, h.peers)) @@ -461,7 +474,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, DataLimit: 1000, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) @@ -481,7 +495,8 @@ func TestDataTransferResponding(t *testing.T) { response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) require.NoError(t, err, nil) require.Nil(t, response) - err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) require.NoError(t, err) require.Len(t, h.transport.ResumedChannels, 1) resCh := h.transport.ResumedChannels[0] @@ -512,7 +527,8 @@ func TestDataTransferResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, RequiresFinalization: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, RequiresFinalization: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) @@ -533,7 +549,8 @@ func TestDataTransferResponding(t *testing.T) { response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) require.NoError(t, err, nil) require.Nil(t, response) - err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) require.NoError(t, err) require.Len(t, h.network.SentMessages, 2) sentMsg := h.network.SentMessages[1] @@ -572,7 +589,7 @@ func TestDataTransferResponding(t *testing.T) { sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, verify: func(t *testing.T, h *receiverHarness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -597,7 +614,7 @@ func TestDataTransferResponding(t *testing.T) { sv.StubResult(datatransfer.ValidationResult{Accepted: true}) }, verify: func(t *testing.T, h *receiverHarness) { - err := h.dt.RegisterTransportConfigurer(h.voucher, func(channelID datatransfer.ChannelID, voucher datatransfer.Voucher, transport datatransfer.Transport) { + err := h.dt.RegisterTransportConfigurer(h.voucher.Type, func(channelID datatransfer.ChannelID, voucher datatransfer.TypedVoucher, transport datatransfer.Transport) { ft, ok := transport.(*testutil.FakeTransport) if !ok { return @@ -634,26 +651,26 @@ func TestDataTransferResponding(t *testing.T) { } ev.setup(t, dt) h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() + h.voucher = testutil.NewTestTypedVoucher() h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) h.pauseUpdate = message.UpdateRequest(h.id, true) require.NoError(t, err) h.resumeUpdate = message.UpdateRequest(h.id, false) require.NoError(t, err) - updateVoucher := testutil.NewFakeDTType() - h.voucherUpdate, err = message.VoucherRequest(h.id, updateVoucher.Type(), updateVoucher) + updateVoucher := testutil.NewTestTypedVoucher() + h.voucherUpdate, err = message.VoucherRequest(h.id, &updateVoucher) h.cancelUpdate = message.CancelRequest(h.id) require.NoError(t, err) h.sv = testutil.NewStubbedValidator() if verify.configureValidator != nil { verify.configureValidator(h.sv) } - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.sv)) + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, h.sv)) require.NoError(t, err) verify.verify(t, h) h.sv.VerifyExpectations(t) @@ -677,7 +694,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.RestartResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil) + response, err := message.RestartResponse(channelID.ID, true, false, nil) require.NoError(t, err) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) @@ -697,9 +714,11 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPush() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) sv.ExpectSuccessValidateRestart() - sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr = testutil.NewTestTypedVoucher() + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming push @@ -717,8 +736,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) // receive restart push request - req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, h.voucher.Type(), h.voucher, - h.baseCid, h.stor) + req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], req) require.Len(t, h.sv.RevalidationsReceived, 1) @@ -758,9 +776,11 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) sv.ExpectSuccessValidateRestart() - sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr = testutil.NewTestTypedVoucher() + sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -771,7 +791,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Len(t, h.network.SentMessages, 0) // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), restartReq) require.NoError(t, err) @@ -801,7 +821,8 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -812,7 +833,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Len(t, h.network.SentMessages, 0) // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) p := testutil.GeneratePeers(1)[0] chid := datatransfer.ChannelID{ID: h.pullRequest.TransferID(), Initiator: p, Responder: h.peers[0]} @@ -830,7 +851,8 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) sv.ExpectSuccessValidateRestart() sv.StubRestartResult(datatransfer.ValidationResult{Accepted: false}) }, @@ -844,7 +866,7 @@ func TestDataTransferRestartResponding(t *testing.T) { // receive restart pull request h.sv.ExpectErrorPull() - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), restartReq) require.EqualError(t, err, datatransfer.ErrRejected.Error()) @@ -858,7 +880,8 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -871,7 +894,7 @@ func TestDataTransferRestartResponding(t *testing.T) { // receive restart pull request randCid := testutil.GenerateCids(1)[0] - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), h.voucher, randCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &h.voucher, randCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: base cid does not match", chid)) @@ -885,7 +908,8 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -898,10 +922,10 @@ func TestDataTransferRestartResponding(t *testing.T) { // receive restart pull request - restartReq, err := message.NewRequest(h.id, true, true, "rand", h.voucher, h.baseCid, h.stor) + restartReq, err := message.NewRequest(h.id, true, true, &datatransfer.TypedVoucher{Voucher: h.voucher.Voucher, Type: "rand"}, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) - require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: failed to decode request voucher: unknown voucher type: rand", chid)) + require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: channel and request voucher types do not match", chid)) }, }, "restart request fails if voucher does not match": { @@ -912,7 +936,8 @@ func TestDataTransferRestartResponding(t *testing.T) { }, configureValidator: func(sv *testutil.StubbedValidator) { sv.ExpectSuccessPull() - sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: testutil.NewFakeDTType()}) + vr := testutil.NewTestTypedVoucher() + sv.StubResult(datatransfer.ValidationResult{Accepted: true, VoucherResult: &vr}) }, verify: func(t *testing.T, h *receiverHarness) { // receive an incoming pull @@ -924,9 +949,8 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Len(t, h.network.SentMessages, 0) // receive restart pull request - v := testutil.NewFakeDTType() - v.Data = "rand" - restartReq, err := message.NewRequest(h.id, true, true, h.voucher.Type(), v, h.baseCid, h.stor) + v := testutil.NewTestTypedVoucherWith("rand") + restartReq, err := message.NewRequest(h.id, true, true, &v, h.baseCid, h.stor) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(chid, restartReq) require.EqualError(t, err, fmt.Sprintf("restart request for channel %s failed validation: channel and request vouchers do not match", chid)) @@ -986,7 +1010,7 @@ func TestDataTransferRestartResponding(t *testing.T) { receivedSelector, err := request.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, request, h.voucher) + testutil.AssertTestVoucher(t, request, h.voucher) }, }, "ReceiveRestartExistingChannelRequest: Resend Push Request": { @@ -1027,7 +1051,7 @@ func TestDataTransferRestartResponding(t *testing.T) { receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) require.Equal(t, receivedSelector, h.stor) - testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher) + testutil.AssertTestVoucher(t, receivedRequest, h.voucher) }, }, "ReceiveRestartExistingChannelRequest: errors if peer is not the initiator": { @@ -1093,19 +1117,19 @@ func TestDataTransferRestartResponding(t *testing.T) { } ev.setup(t, dt) h.stor = testutil.AllSelector() - h.voucher = testutil.NewFakeDTType() + h.voucher = testutil.NewTestTypedVoucher() h.baseCid = testutil.GenerateCids(1)[0] h.id = datatransfer.TransferID(rand.Int31()) - h.pullRequest, err = message.NewRequest(h.id, false, true, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pullRequest, err = message.NewRequest(h.id, false, true, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) - h.pushRequest, err = message.NewRequest(h.id, false, false, h.voucher.Type(), h.voucher, h.baseCid, h.stor) + h.pushRequest, err = message.NewRequest(h.id, false, false, &h.voucher, h.baseCid, h.stor) require.NoError(t, err) h.sv = testutil.NewStubbedValidator() if verify.configureValidator != nil { verify.configureValidator(h.sv) } - require.NoError(t, h.dt.RegisterVoucherType(h.voucher, h.sv)) + require.NoError(t, h.dt.RegisterVoucherType(h.voucher.Type, h.sv)) verify.verify(t, h) h.sv.VerifyExpectations(t) @@ -1129,8 +1153,8 @@ type receiverHarness struct { sv *testutil.StubbedValidator ds datastore.Batching dt datatransfer.Manager - stor ipld.Node - voucher *testutil.FakeDTType + stor datamodel.Node + voucher datatransfer.TypedVoucher baseCid cid.Cid } diff --git a/impl/restart.go b/impl/restart.go index 9efe0728..eaee5349 100644 --- a/impl/restart.go +++ b/impl/restart.go @@ -1,16 +1,15 @@ package impl import ( - "bytes" "context" + "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channels" - "github.com/filecoin-project/go-data-transfer/v2/encoding" "github.com/filecoin-project/go-data-transfer/v2/message" ) @@ -73,17 +72,20 @@ func (m *manager) restartManagerPeerReceivePull(ctx context.Context, channel dat func (m *manager) openPushRestartChannel(ctx context.Context, channel datatransfer.ChannelState) error { selector := channel.Selector() - voucher := channel.Voucher() + voucher, err := channel.Voucher() + if err != nil { + return err + } baseCid := channel.BaseCID() requestTo := channel.OtherPeer() chid := channel.ChannelID() - req, err := message.NewRequest(chid.ID, true, false, voucher.Type(), voucher, baseCid, selector) + req, err := message.NewRequest(chid.ID, true, false, &voucher, baseCid, selector) if err != nil { return err } - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) @@ -108,17 +110,20 @@ func (m *manager) openPushRestartChannel(ctx context.Context, channel datatransf func (m *manager) openPullRestartChannel(ctx context.Context, channel datatransfer.ChannelState) error { selector := channel.Selector() - voucher := channel.Voucher() + voucher, err := channel.Voucher() + if err != nil { + return err + } baseCid := channel.BaseCID() requestTo := channel.OtherPeer() chid := channel.ChannelID() - req, err := message.NewRequest(chid.ID, true, true, voucher.Type(), voucher, baseCid, selector) + req, err := message.NewRequest(chid.ID, true, true, &voucher, baseCid, selector) if err != nil { return err } - processor, has := m.transportConfigurers.Processor(voucher.Type()) + processor, has := m.transportConfigurers.Processor(voucher.Type) if has { transportConfigurer := processor.(datatransfer.TransportConfigurer) transportConfigurer(chid, voucher, m.transport) @@ -164,24 +169,19 @@ func (m *manager) validateRestartRequest(ctx context.Context, otherPeer peer.ID, } // vouchers should match - reqVoucher, err := m.decodeVoucher(req) + reqVoucher, err := req.Voucher() if err != nil { - return xerrors.Errorf("failed to decode request voucher: %w", err) + return xerrors.Errorf("failed to fetch request voucher: %w", err) } - if reqVoucher.Type() != channel.Voucher().Type() { - return xerrors.New("channel and request voucher types do not match") - } - - reqBz, err := encoding.Encode(reqVoucher) + channelVoucher, err := channel.Voucher() if err != nil { - return xerrors.New("failed to encode request voucher") + return xerrors.Errorf("failed to fetch channel voucher: %w", err) } - channelBz, err := encoding.Encode(channel.Voucher()) - if err != nil { - return xerrors.New("failed to encode channel voucher") + if req.VoucherType() != channelVoucher.Type { + return xerrors.New("channel and request voucher types do not match") } - if !bytes.Equal(reqBz, channelBz) { + if !ipld.DeepEqual(reqVoucher, channelVoucher.Voucher) { return xerrors.New("channel and request vouchers do not match") } diff --git a/impl/restart_integration_test.go b/impl/restart_integration_test.go index 453aac80..1b4a50ec 100644 --- a/impl/restart_integration_test.go +++ b/impl/restart_integration_test.go @@ -41,8 +41,8 @@ func TestRestartPush(t *testing.T) { "Restart peer create push": { stopAt: 20, openPushF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, voucher, rh.rootCid, rh.gsData.AllSelector) require.NoError(rh.t, err) return chid }, @@ -53,7 +53,7 @@ func TestRestartPush(t *testing.T) { tp1 := rh.gsData.SetupGSTransportHost1() rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.DtNet1, tp1) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt1.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt1.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) rh.dt1.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt1.RestartDataTransferChannel(rh.testCtx, chId)) @@ -83,8 +83,8 @@ func TestRestartPush(t *testing.T) { "Restart peer receive push": { stopAt: 20, openPushF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt1.OpenPushDataChannel(rh.testCtx, rh.peer2, voucher, rh.rootCid, rh.gsData.AllSelector) require.NoError(rh.t, err) return chid }, @@ -95,7 +95,7 @@ func TestRestartPush(t *testing.T) { tp2 := rh.gsData.SetupGSTransportHost2() rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.DtNet2, tp2) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt2.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt2.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) rh.dt2.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt2.RestartDataTransferChannel(rh.testCtx, chId)) @@ -294,8 +294,8 @@ func TestRestartPull(t *testing.T) { "Restart peer create pull": { stopAt: 40, openPullF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, voucher, rh.rootCid, rh.gsData.AllSelector) require.NoError(rh.t, err) return chid }, @@ -306,7 +306,7 @@ func TestRestartPull(t *testing.T) { tp2 := rh.gsData.SetupGSTransportHost2() rh.dt2, err = NewDataTransfer(rh.gsData.DtDs2, rh.gsData.DtNet2, tp2) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt2.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt2.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt2) rh.dt2.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt2.RestartDataTransferChannel(rh.testCtx, chId)) @@ -333,8 +333,8 @@ func TestRestartPull(t *testing.T) { "Restart peer receive pull": { stopAt: 40, openPullF: func(rh *restartHarness) datatransfer.ChannelID { - voucher := testutil.FakeDTType{Data: "applesauce"} - chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, &voucher, rh.rootCid, rh.gsData.AllSelector) + voucher := testutil.NewTestTypedVoucherWith("applesauce") + chid, err := rh.dt2.OpenPullDataChannel(rh.testCtx, rh.peer1, voucher, rh.rootCid, rh.gsData.AllSelector) require.NoError(rh.t, err) return chid }, @@ -345,7 +345,7 @@ func TestRestartPull(t *testing.T) { tp1 := rh.gsData.SetupGSTransportHost1() rh.dt1, err = NewDataTransfer(rh.gsData.DtDs1, rh.gsData.DtNet1, tp1) require.NoError(rh.t, err) - require.NoError(rh.t, rh.dt1.RegisterVoucherType(&testutil.FakeDTType{}, rh.sv)) + require.NoError(rh.t, rh.dt1.RegisterVoucherType(testutil.TestVoucherType, rh.sv)) testutil.StartAndWaitForReady(rh.testCtx, t, rh.dt1) rh.dt1.SubscribeToEvents(subscriber) require.NoError(rh.t, rh.dt1.RestartDataTransferChannel(rh.testCtx, chId)) @@ -570,8 +570,8 @@ func newRestartHarness(t *testing.T) *restartHarness { require.NoError(t, err) sv := testutil.NewStubbedValidator() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) + require.NoError(t, dt2.RegisterVoucherType(testutil.TestVoucherType, sv)) sourceDagService := gsData.DagService1 root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, largeFile) diff --git a/impl/utils.go b/impl/utils.go index 4c748acc..518b9e0c 100644 --- a/impl/utils.go +++ b/impl/utils.go @@ -4,9 +4,8 @@ import ( "context" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" - "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/message" @@ -30,10 +29,10 @@ var resumeTransportStatesResponder = statusList{ } // newRequest encapsulates message creation -func (m *manager) newRequest(ctx context.Context, selector ipld.Node, isPull bool, voucher datatransfer.Voucher, baseCid cid.Cid, to peer.ID) (datatransfer.Request, error) { +func (m *manager) newRequest(ctx context.Context, selector datamodel.Node, isPull bool, voucher datatransfer.TypedVoucher, baseCid cid.Cid, to peer.ID) (datatransfer.Request, error) { // Generate a new transfer ID for the request tid := datatransfer.TransferID(m.transferIDGen.next()) - return message.NewRequest(tid, false, isPull, voucher.Type(), voucher, baseCid, selector) + return message.NewRequest(tid, false, isPull, &voucher, baseCid, selector) } func (m *manager) resume(chid datatransfer.ChannelID) error { @@ -84,29 +83,3 @@ func (m *manager) cancelMessage(chid datatransfer.ChannelID) datatransfer.Messag } return message.CancelResponse(chid.ID) } - -func (m *manager) decodeVoucherResult(response datatransfer.Response) (datatransfer.VoucherResult, error) { - vtypStr := datatransfer.TypeIdentifier(response.VoucherResultType()) - decoder, has := m.resultTypes.Decoder(vtypStr) - if !has { - return nil, xerrors.Errorf("unknown voucher result type: %s", vtypStr) - } - encodable, err := response.VoucherResult(decoder) - if err != nil { - return nil, err - } - return encodable.(datatransfer.Registerable), nil -} - -func (m *manager) decodeVoucher(request datatransfer.Request) (datatransfer.Voucher, error) { - vtypStr := datatransfer.TypeIdentifier(request.VoucherType()) - decoder, has := m.validatedTypes.Decoder(vtypStr) - if !has { - return nil, xerrors.Errorf("unknown voucher type: %s", vtypStr) - } - encodable, err := request.Voucher(decoder) - if err != nil { - return nil, err - } - return encodable.(datatransfer.Registerable), nil -} diff --git a/ipldutils/ipldutils.go b/ipldutils/ipldutils.go new file mode 100644 index 00000000..1b5b0be2 --- /dev/null +++ b/ipldutils/ipldutils.go @@ -0,0 +1,183 @@ +package shared + +import ( + "bytes" + "fmt" + "io" + "reflect" + + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/bindnode" + "github.com/ipld/go-ipld-prime/schema" + cbg "github.com/whyrusleeping/cbor-gen" +) + +type typeWithBindnodeSchema interface { + BindnodeSchema() string +} + +// TODO: remove this I think +type typeWithBindnodePostDecode interface { + BindnodePostDecode() error +} + +// We use the prototype map to store TypedPrototype and Type information +// mapped against Go type names so we only have to run the schema parse once. +// Currently there's not much additional benefit of storing this but there +// may be in the future. +var prototype map[string]schema.TypedPrototype = make(map[string]schema.TypedPrototype) + +var bindnodeOptions = []bindnode.Option{} + +func typeName(ptrValue interface{}) string { + val := reflect.ValueOf(ptrValue).Type() + for val.Kind() == reflect.Ptr { + val = val.Elem() + } + return val.Name() +} + +// lookup of cached TypedPrototype (and therefore Type) for a Go type, if not +// found, initial parse and setup and caching of the TypedPrototype will happen +func prototypeFor(typeName string, ptrType interface{}) (schema.TypedPrototype, error) { + proto, ok := prototype[typeName] + if !ok { + schemaType, err := schemaTypeFor(typeName, ptrType) + if err != nil { + return nil, err + } + if schemaType == nil { + return nil, fmt.Errorf("could not find type [%s] in schema", typeName) + } + proto = bindnode.Prototype(ptrType, schemaType, bindnodeOptions...) + prototype[typeName] = proto + } + return proto, nil +} + +// load the schema for a Go type, which must have a BindnodeSchema() method +// attached to it +func schemaTypeFor(typeName string, ptrType interface{}) (schema.Type, error) { + tws, ok := ptrType.(typeWithBindnodeSchema) + if !ok { + return nil, fmt.Errorf("attempted to perform IPLD mapping on type without BindnodeSchema(): %T", ptrType) + } + schema := tws.BindnodeSchema() + typeSystem, err := ipld.LoadSchemaBytes([]byte(schema)) + if err != nil { + return nil, err + } + schemaType := typeSystem.TypeByName(typeName) + if schemaType == nil { + if !ok { + return nil, fmt.Errorf("schema for [%T] does not contain that named type [%s]", ptrType, typeName) + } + } + return schemaType, nil +} + +// FromReader deserializes DAG-CBOR from a Reader and instantiates the Go type +// that's provided as a pointer via the ptrValue argument. +func FromReader(r io.Reader, ptrValue interface{}) (interface{}, error) { + name := typeName(ptrValue) + proto, err := prototypeFor(name, ptrValue) + if err != nil { + return nil, err + } + node, err := ipld.DecodeStreamingUsingPrototype(r, dagcbor.Decode, proto) + if err != nil { + return nil, err + } + typ := bindnode.Unwrap(node) + if twpd, ok := typ.(typeWithBindnodePostDecode); ok { + // we have some more work to do + if err = twpd.BindnodePostDecode(); err != nil { + return nil, err + } + } + return typ, nil +} + +// FromNode converts an datamodel.Node into an appropriate Go type that's provided as +// a pointer via the ptrValue argument +func FromNode(node datamodel.Node, ptrValue interface{}) (interface{}, error) { + name := typeName(ptrValue) + proto, err := prototypeFor(name, ptrValue) + if err != nil { + return nil, err + } + if tn, ok := node.(schema.TypedNode); ok { + node = tn.Representation() + } + builder := proto.Representation().NewBuilder() + err = builder.AssignNode(node) + if err != nil { + return nil, err + } + typ := bindnode.Unwrap(builder.Build()) + if twpd, ok := typ.(typeWithBindnodePostDecode); ok { + // we have some more work to do + if err = twpd.BindnodePostDecode(); err != nil { + return nil, err + } + } + return typ, nil +} + +// ToNode converts a Go type that's provided as a pointer via the ptrValue +// argument to an datamodel.Node. +func ToNode(ptrValue interface{}) (schema.TypedNode, error) { + name := typeName(ptrValue) + proto, err := prototypeFor(name, ptrValue) + if err != nil { + return nil, err + } + return bindnode.Wrap(ptrValue, proto.Type(), bindnodeOptions...), err +} + +// NodeToWriter is a utility method that serializes an datamodel.Node as DAG-CBOR to +// a Writer +func NodeToWriter(node datamodel.Node, w io.Writer) error { + return ipld.EncodeStreaming(w, node, dagcbor.Encode) +} + +// NodeToBytes is a utility method that serializes an datamodel.Node as DAG-CBOR to +// a []byte +func NodeToBytes(node datamodel.Node) ([]byte, error) { + var buf bytes.Buffer + err := NodeToWriter(node, &buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// NodeFromBytes is a utility method that deserializes an untyped datamodel.Node +// from DAG-CBOR format bytes +func NodeFromBytes(b []byte) (datamodel.Node, error) { + return ipld.Decode(b, dagcbor.Decode) +} + +// TypeToWriter is a utility method that serializes a Go type that's provided as a +// pointer via the ptrValue argument as DAG-CBOR to a Writer +func TypeToWriter(ptrValue interface{}, w io.Writer) error { + node, err := ToNode(ptrValue) + if err != nil { + return err + } + return ipld.EncodeStreaming(w, node, dagcbor.Encode) +} + +func NodeToDeferred(node datamodel.Node) (*cbg.Deferred, error) { + byts, err := NodeToBytes(node) + if err != nil { + return nil, err + } + return &cbg.Deferred{Raw: byts}, nil +} + +func DeferredToNode(def *cbg.Deferred) (datamodel.Node, error) { + return NodeFromBytes(def.Raw) +} diff --git a/manager.go b/manager.go index b291dc11..822f76e5 100644 --- a/manager.go +++ b/manager.go @@ -4,7 +4,7 @@ import ( "context" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" ) @@ -15,7 +15,7 @@ type ValidationResult struct { Accepted bool // VoucherResult provides information to the other party about what happened // with the voucher - VoucherResult + VoucherResult *TypedVoucher // ForcePause indicates whether the request should be paused, regardless // of data limit and finalization status ForcePause bool @@ -28,6 +28,16 @@ type ValidationResult struct { RequiresFinalization bool } +// Equals checks the deep equality of two ValidationResult values +func (vr ValidationResult) Equals(vr2 ValidationResult) bool { + return vr.Accepted == vr2.Accepted && + vr.ForcePause == vr2.ForcePause && + vr.DataLimit == vr2.DataLimit && + vr.RequiresFinalization == vr2.RequiresFinalization && + (vr.VoucherResult == nil) == (vr2.VoucherResult == nil) && + (vr.VoucherResult == nil || vr.VoucherResult.Equals(*vr2.VoucherResult)) +} + // LeaveRequestPaused indicates whether all conditions are met to resume a request func (vr ValidationResult) LeaveRequestPaused(chst ChannelState) bool { if vr.ForcePause { @@ -56,9 +66,9 @@ type RequestValidator interface { ValidatePush( chid ChannelID, sender peer.ID, - voucher Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (ValidationResult, error) + selector datamodel.Node) (ValidationResult, error) // ValidatePull validates a pull request received from the peer that will receive data // -- All information about the validation operation is contained in ValidationResult, // including if it was rejected. Information about why a rejection occurred should be embedded @@ -67,9 +77,9 @@ type RequestValidator interface { ValidatePull( chid ChannelID, receiver peer.ID, - voucher Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (ValidationResult, error) + selector datamodel.Node) (ValidationResult, error) // ValidateRestart validates restarting a request // -- All information about the validation operation is contained in ValidationResult, @@ -80,7 +90,7 @@ type RequestValidator interface { } // TransportConfigurer provides a mechanism to provide transport specific configuration for a given voucher type -type TransportConfigurer func(chid ChannelID, voucher Voucher, transport Transport) +type TransportConfigurer func(chid ChannelID, voucher TypedVoucher, transport Transport) // ReadyFunc is function that gets called once when the data transfer module is ready type ReadyFunc func(error) @@ -101,29 +111,25 @@ type Manager interface { // RegisterVoucherType registers a validator for the given voucher type // will error if voucher type does not implement voucher // or if there is a voucher type registered with an identical identifier - RegisterVoucherType(voucherType Voucher, validator RequestValidator) error - - // RegisterVoucherResultType allows deserialization of a voucher result, - // so that a listener can read the metadata - RegisterVoucherResultType(resultType VoucherResult) error + RegisterVoucherType(voucherType TypeIdentifier, validator RequestValidator) error // RegisterTransportConfigurer registers the given transport configurer to be run on requests with the given voucher // type - RegisterTransportConfigurer(voucherType Voucher, configurer TransportConfigurer) error + RegisterTransportConfigurer(voucherType TypeIdentifier, configurer TransportConfigurer) error // open a data transfer that will send data to the recipient peer and // transfer parts of the piece that match the selector - OpenPushDataChannel(ctx context.Context, to peer.ID, voucher Voucher, baseCid cid.Cid, selector ipld.Node) (ChannelID, error) + OpenPushDataChannel(ctx context.Context, to peer.ID, voucher TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (ChannelID, error) // open a data transfer that will request data from the sending peer and // transfer parts of the piece that match the selector - OpenPullDataChannel(ctx context.Context, to peer.ID, voucher Voucher, baseCid cid.Cid, selector ipld.Node) (ChannelID, error) + OpenPullDataChannel(ctx context.Context, to peer.ID, voucher TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (ChannelID, error) // send an intermediate voucher as needed when the receiver sends a request for revalidation - SendVoucher(ctx context.Context, chid ChannelID, voucher Voucher) error + SendVoucher(ctx context.Context, chid ChannelID, voucher TypedVoucher) error // send information from the responder to update the initiator on the state of their voucher - SendVoucherResult(ctx context.Context, chid ChannelID, voucher VoucherResult) error + SendVoucherResult(ctx context.Context, chid ChannelID, voucherResult TypedVoucher) error // Update the validation status for a given channel, to change data limits, finalization, accepted status, and pause state // and send new voucher results as diff --git a/message.go b/message.go index 7f36b33d..d54eabb1 100644 --- a/message.go +++ b/message.go @@ -4,11 +4,8 @@ import ( "io" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/protocol" - - "github.com/filecoin-project/go-data-transfer/v2/encoding" ) var ( @@ -38,9 +35,10 @@ type Request interface { IsPull() bool IsVoucher() bool VoucherType() TypeIdentifier - Voucher(decoder encoding.Decoder) (encoding.Encodable, error) + Voucher() (datamodel.Node, error) + TypedVoucher() (TypedVoucher, error) BaseCid() cid.Cid - Selector() (ipld.Node, error) + Selector() (datamodel.Node, error) IsRestartExistingChannelRequest() bool RestartChannelId() (ChannelID, error) } @@ -52,6 +50,6 @@ type Response interface { IsComplete() bool Accepted() bool VoucherResultType() TypeIdentifier - VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) + VoucherResult() (datamodel.Node, error) EmptyVoucherResult() bool } diff --git a/message/message1_1prime/message.go b/message/message1_1prime/message.go index e533252d..b740bc01 100644 --- a/message/message1_1prime/message.go +++ b/message/message1_1prime/message.go @@ -5,24 +5,25 @@ import ( "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/schema" xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" "github.com/filecoin-project/go-data-transfer/v2/message/types" ) +var emptyTypedVoucher = datatransfer.TypedVoucher{ + Voucher: ipld.Null, + Type: datatransfer.EmptyTypeIdentifier, +} + // NewRequest generates a new request for the data transfer protocol -func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (datatransfer.Request, error) { - vnode, err := encoding.EncodeToNode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, voucher *datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.Request, error) { + if voucher == nil { + voucher = &emptyTypedVoucher } - if baseCid == cid.Undef { return nil, xerrors.Errorf("base CID must be defined") } @@ -34,13 +35,17 @@ func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype d typ = uint64(types.NewMessage) } + if voucher == nil { + voucher = &emptyTypedVoucher + } + return &TransferRequest1_1{ MessageType: typ, Pull: isPull, - VoucherPtr: &vnode, - SelectorPtr: &selector, + VoucherPtr: voucher.Voucher, + SelectorPtr: selector, BaseCidPtr: &baseCid, - VoucherTypeIdentifier: vtype, + VoucherTypeIdentifier: voucher.Type, TransferId: uint64(id), }, nil } @@ -71,32 +76,30 @@ func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Reque } // VoucherRequest generates a new request for the data transfer protocol -func VoucherRequest(id datatransfer.TransferID, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable) (datatransfer.Request, error) { - vnode, err := encoding.EncodeToNode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func VoucherRequest(id datatransfer.TransferID, voucher *datatransfer.TypedVoucher) (datatransfer.Request, error) { + if voucher == nil { + voucher = &emptyTypedVoucher } return &TransferRequest1_1{ MessageType: uint64(types.VoucherMessage), - VoucherPtr: &vnode, - VoucherTypeIdentifier: vtype, + VoucherPtr: voucher.Voucher, + VoucherTypeIdentifier: voucher.Type, TransferId: uint64(id), }, nil } // RestartResponse builds a new Data Transfer response -func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.RestartMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherResultPtr: voucherResult.Voucher, + VoucherTypeIdentifier: voucherResult.Type, }, nil } @@ -108,13 +111,10 @@ func ValidationResultResponse( validationResult datatransfer.ValidationResult, validationErr error, paused bool) (datatransfer.Response, error) { - voucherResultType := datatransfer.EmptyTypeIdentifier + + voucherResult := &emptyTypedVoucher if validationResult.VoucherResult != nil { - voucherResultType = validationResult.VoucherResult.Type() - } - vnode, err := encoding.EncodeToNode(validationResult.VoucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) + voucherResult = validationResult.VoucherResult } return &TransferResponse1_1{ // TODO: when we area able to change the protocol, it would be helpful to record @@ -123,40 +123,38 @@ func ValidationResultResponse( MessageType: uint64(messageType), Paused: paused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, }, nil } // NewResponse builds a new Data Transfer response -func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.NewMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, }, nil } // VoucherResultResponse builds a new response for a voucher result -func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ RequestAccepted: accepted, MessageType: uint64(types.VoucherResultMessage), Paused: isPaused, TransferId: uint64(id), - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, }, nil } @@ -178,30 +176,27 @@ func CancelResponse(id datatransfer.TransferID) datatransfer.Response { } // CompleteResponse returns a new complete response message -func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vnode, err := encoding.EncodeToNode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) +func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { + if voucherResult == nil { + voucherResult = &emptyTypedVoucher } return &TransferResponse1_1{ MessageType: uint64(types.CompleteMessage), RequestAccepted: isAccepted, Paused: isPaused, - VoucherTypeIdentifier: voucherResultType, - VoucherResultPtr: &vnode, + VoucherTypeIdentifier: voucherResult.Type, + VoucherResultPtr: voucherResult.Voucher, TransferId: uint64(id), }, nil } // FromNet can read a network stream to deserialize a GraphSyncMessage func FromNet(r io.Reader) (datatransfer.Message, error) { - builder := Prototype.TransferMessage.Representation().NewBuilder() - err := dagcbor.Decode(builder, r) + tm, err := ipldutils.FromReader(r, &TransferMessage1_1{}) if err != nil { return nil, err } - node := builder.Build() - tresp := bindnode.Unwrap(node).(*TransferMessage1_1) + tresp := tm.(*TransferMessage1_1) if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { return nil, xerrors.Errorf("invalid/malformed message") @@ -218,12 +213,12 @@ func FromIPLD(node datamodel.Node) (datatransfer.Message, error) { if tn, ok := node.(schema.TypedNode); ok { // shouldn't need this if from Graphsync node = tn.Representation() } - builder := Prototype.TransferMessage.Representation().NewBuilder() - err := builder.AssignNode(node) + tm, err := ipldutils.FromNode(node, &TransferMessage1_1{}) if err != nil { return nil, err } - tresp := bindnode.Unwrap(builder.Build()).(*TransferMessage1_1) + tresp := tm.(*TransferMessage1_1) + if (tresp.IsRequest && tresp.Request == nil) || (!tresp.IsRequest && tresp.Response == nil) { return nil, xerrors.Errorf("invalid/malformed message") } diff --git a/message/message1_1prime/message_test.go b/message/message1_1prime/message_test.go index 93e95f27..7ef75e53 100644 --- a/message/message1_1prime/message_test.go +++ b/message/message1_1prime/message_test.go @@ -8,16 +8,13 @@ import ( "testing" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime/codec/dagcbor" basicnode "github.com/ipld/go-ipld-prime/node/basic" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" message1_1 "github.com/filecoin-project/go-data-transfer/v2/message/message1_1prime" "github.com/filecoin-project/go-data-transfer/v2/testutil" ) @@ -27,8 +24,8 @@ func TestNewRequest(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -36,8 +33,7 @@ func TestNewRequest(t *testing.T) { assert.True(t, request.IsPull()) assert.True(t, request.IsRequest()) assert.Equal(t, baseCid.String(), request.BaseCid().String()) - encoding.NewDecoder(request) - testutil.AssertFakeDTVoucher(t, request, voucher) + testutil.AssertTestVoucher(t, request, voucher) receivedSelector, err := request.Selector() require.NoError(t, err) require.Equal(t, selector, receivedSelector) @@ -56,8 +52,8 @@ func TestRestartRequest(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, true, isPull, &voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -65,7 +61,7 @@ func TestRestartRequest(t *testing.T) { assert.True(t, request.IsPull()) assert.True(t, request.IsRequest()) assert.Equal(t, baseCid.String(), request.BaseCid().String()) - testutil.AssertFakeDTVoucher(t, request, voucher) + testutil.AssertTestVoucher(t, request, voucher) receivedSelector, err := request.Selector() require.NoError(t, err) require.Equal(t, selector, receivedSelector) @@ -115,19 +111,9 @@ func TestRestartExistingChannelRequest(t *testing.T) { }) } -func TestTransferRequest_MarshalCBOR(t *testing.T) { - // sanity check MarshalCBOR does its thing w/o error - req, err := NewTestTransferRequest() - require.NoError(t, err) - wbuf := new(bytes.Buffer) - node := bindnode.Wrap(&req, message1_1.Prototype.TransferRequest.Type()) - err = dagcbor.Encode(node.Representation(), wbuf) - require.NoError(t, err) - assert.Greater(t, wbuf.Len(), 0) -} func TestTransferRequest_UnmarshalCBOR(t *testing.T) { t.Run("round-trip", func(t *testing.T) { - req, err := NewTestTransferRequest() + req, err := NewTestTransferRequest("test data here") require.NoError(t, err) wbuf := new(bytes.Buffer) // use ToNet / FromNet @@ -144,14 +130,15 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { assert.Equal(t, req.IsPull(), desReq.IsPull()) assert.Equal(t, req.IsCancel(), desReq.IsCancel()) assert.Equal(t, req.BaseCid(), desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) + testutil.AssertEqualTestVoucher(t, &req, desReq) testutil.AssertEqualSelector(t, &req, desReq) }) t.Run("cbor-gen compat", func(t *testing.T) { - req, err := NewTestTransferRequest() + vouchByts, _ := hex.DecodeString("f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35") + req, err := NewTestTransferRequest(string(vouchByts)) require.NoError(t, err) - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") + msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -164,15 +151,15 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { assert.Equal(t, req.IsCancel(), desReq.IsCancel()) c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") assert.Equal(t, c, desReq.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, &req, desReq) + testutil.AssertEqualTestVoucher(t, &req, desReq) testutil.AssertEqualSelector(t, &req, desReq) }) } func TestResponses(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted + voucherResult := testutil.NewTestTypedVoucher() + response, err := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted require.NoError(t, err) assert.Equal(t, response.TransferID(), id) assert.False(t, response.Accepted()) @@ -180,7 +167,7 @@ func TestResponses(t *testing.T) { assert.False(t, response.IsUpdate()) assert.True(t, response.IsPaused()) assert.False(t, response.IsRequest()) - testutil.AssertFakeDTVoucherResult(t, response, voucherResult) + testutil.AssertTestVoucherResult(t, response, voucherResult) // Sanity check to make sure we can cast to datatransfer.Message msg, ok := response.(datatransfer.Message) require.True(t, ok) @@ -194,8 +181,8 @@ func TestResponses(t *testing.T) { func TestTransferResponse_MarshalCBOR(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted + voucherResult := testutil.NewTestTypedVoucher() + response, err := message1_1.NewResponse(id, true, false, &voucherResult) // accepted require.NoError(t, err) // sanity check that we can marshal data @@ -207,8 +194,8 @@ func TestTransferResponse_MarshalCBOR(t *testing.T) { func TestTransferResponse_UnmarshalCBOR(t *testing.T) { t.Run("round-trip", func(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted + voucherResult := testutil.NewTestTypedVoucher() + response, err := message1_1.NewResponse(id, true, false, &voucherResult) // accepted require.NoError(t, err) wbuf := new(bytes.Buffer) @@ -229,13 +216,11 @@ func TestTransferResponse_UnmarshalCBOR(t *testing.T) { assert.True(t, desResp.IsNew()) assert.False(t, desResp.IsUpdate()) assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) + testutil.AssertTestVoucherResult(t, desResp, voucherResult) }) t.Run("cbor-gen compat", func(t *testing.T) { - voucherResult := testutil.NewFakeDTType() - voucherResult.Data = "\xf5_\xf8\xf1%\b\xb6>\xf2\xbf\xec\xa7Uz\xe9\r\xf61\x1a^\xc1c\x1bJ\x1f\xa8C1\v\xd9ç\x10\xea\xac塽\xd7*п\xe0Iw\x1c\x11\xe7V3\x8b\xd98e\xe6E\xf1\xad웜\x99\xef@\u007f\xbdOƅ\x9ey\x04ŭ}ɽ\x10\xa5\xcc\x16\x97=[(\xec\x1am\xd4=\x9f\x82\xf9\xf1\x8c=\x03A\x8e5" - - msg, _ := hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f56450617573f4665866657249441a4d6582216456526573817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065") + voucherResult := testutil.NewTestTypedVoucherWith("\xf5_\xf8\xf1%\b\xb6>\xf2\xbf\xec\xa7Uz\xe9\r\xf61\x1a^\xc1c\x1bJ\x1f\xa8C1\v\xd9ç\x10\xea\xac塽\xd7*п\xe0Iw\x1c\x11\xe7V3\x8b\xd98e\xe6E\xf1\xad웜\x99\xef@\u007f\xbdOƅ\x9ey\x04ŭ}ɽ\x10\xa5\xcc\x16\x97=[(\xec\x1am\xd4=\x9f\x82\xf9\xf1\x8c=\x03A\x8e5") + msg, _ := hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f56450617573f4665866657249441a4d6582216456526573817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572") desMsg, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) assert.False(t, desMsg.IsRequest()) @@ -250,7 +235,7 @@ func TestTransferResponse_UnmarshalCBOR(t *testing.T) { assert.True(t, desResp.IsNew()) assert.False(t, desResp.IsUpdate()) assert.False(t, desMsg.IsPaused()) - testutil.AssertFakeDTVoucherResult(t, desResp, voucherResult) + testutil.AssertTestVoucherResult(t, desResp, voucherResult) }) } @@ -381,7 +366,7 @@ func TestCancelResponse(t *testing.T) { func TestCompleteResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response, err := message1_1.CompleteResponse(id, true, true, datatransfer.EmptyTypeIdentifier, nil) + response, err := message1_1.CompleteResponse(id, true, true, nil) require.NoError(t, err) assert.Equal(t, response.TransferID(), id) assert.False(t, response.IsNew()) @@ -407,9 +392,9 @@ func TestToNetFromNetEquivalency(t *testing.T) { isPull := false id := datatransfer.TransferID(rand.Int31()) accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + voucherResult := testutil.NewTestTypedVoucher() + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) buf := new(bytes.Buffer) err = request.ToNet(buf) @@ -426,10 +411,10 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) require.Equal(t, deserializedRequest.BaseCid(), request.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) + testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) + response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) require.NoError(t, err) err = response.ToNet(buf) require.NoError(t, err) @@ -444,7 +429,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) + testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) request = message1_1.CancelRequest(id) err = request.ToNet(buf) @@ -465,15 +450,17 @@ func TestToNetFromNetEquivalency(t *testing.T) { isPull := false id := datatransfer.TransferID(1298498081) accepted := false - voucher := testutil.NewFakeDTType() - voucherResult := testutil.NewFakeDTType() - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + vouchByts, _ := hex.DecodeString("f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e35") + voucher := testutil.NewTestTypedVoucherWith(string(vouchByts)) + vouchResultByts, _ := hex.DecodeString("4204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b") + voucherResult := testutil.NewTestTypedVoucherWith(string(vouchResultByts)) + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) buf := new(bytes.Buffer) err = request.ToNet(buf) require.NoError(t, err) require.Greater(t, buf.Len(), 0) - msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706a46616b65445454797065665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") + msg, _ := hex.DecodeString("a36449735271f56752657175657374aa6442436964d82a58230012204bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a6454797065006450617573f46450617274f46450756c6cf46453746f72a1612ea065566f756368817864f55ff8f12508b63ef2bfeca7557ae90df6311a5ec1631b4a1fa843310bd9c3a710eaace5a1bdd72ad0bfe049771c11e756338bd93865e645f1adec9b9c99ef407fbd4fc6859e7904c5ad7dc9bd10a5cc16973d5b28ec1a6dd43d9f82f9f18c3d03418e3564565479706b54657374566f7563686572665866657249441a4d6582216e526573746172744368616e6e656c8360600068526573706f6e7365f6") deserialized, err := message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -486,14 +473,14 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) c, _ := cid.Parse("QmTTA2daxGqo5denp6SwLzzkLJm3fuisYEi9CoWsuHpzfb") assert.Equal(t, c, deserializedRequest.BaseCid()) - testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) + testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) + response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) require.NoError(t, err) err = response.ToNet(buf) require.NoError(t, err) - msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f46450617573f4665866657249441a4d65822164565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706a46616b65445454797065") + msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f46450617573f4665866657249441a4d65822164565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706b54657374566f7563686572") deserialized, err = message1_1.FromNet(bytes.NewReader(msg)) require.NoError(t, err) @@ -505,7 +492,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedResponse.IsRequest(), response.IsRequest()) require.Equal(t, deserializedResponse.IsUpdate(), response.IsUpdate()) require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) - testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) + testutil.AssertEqualTestVoucherResult(t, response, deserializedResponse) request = message1_1.CancelRequest(id) err = request.ToNet(buf) @@ -537,13 +524,13 @@ func TestFromNetMessageValidation(t *testing.T) { assert.Nil(t, msg) } -func NewTestTransferRequest() (message1_1.TransferRequest1_1, error) { +func NewTestTransferRequest(data string) (message1_1.TransferRequest1_1, error) { bcid := testutil.GenerateCids(1)[0] selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - req, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, bcid, selector) + voucher := testutil.NewTestTypedVoucherWith(data) + req, err := message1_1.NewRequest(id, false, isPull, &voucher, bcid, selector) if err != nil { return message1_1.TransferRequest1_1{}, err } diff --git a/message/message1_1prime/schema.go b/message/message1_1prime/schema.go deleted file mode 100644 index c779b1fc..00000000 --- a/message/message1_1prime/schema.go +++ /dev/null @@ -1,29 +0,0 @@ -package message1_1 - -import ( - _ "embed" - - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/node/bindnode" - "github.com/ipld/go-ipld-prime/schema" -) - -//go:embed schema.ipldsch -var embedSchema []byte - -var Prototype struct { - TransferMessage schema.TypedPrototype - TransferRequest schema.TypedPrototype - TransferResponse schema.TypedPrototype -} - -func init() { - ts, err := ipld.LoadSchemaBytes(embedSchema) - if err != nil { - panic(err) - } - - Prototype.TransferMessage = bindnode.Prototype((*TransferMessage1_1)(nil), ts.TypeByName("TransferMessage")) - Prototype.TransferRequest = bindnode.Prototype((*TransferRequest1_1)(nil), ts.TypeByName("TransferRequest")) - Prototype.TransferResponse = bindnode.Prototype((*TransferResponse1_1)(nil), ts.TypeByName("TransferResponse")) -} diff --git a/message/message1_1prime/schema.ipldsch b/message/message1_1prime/schema.ipldsch index 71413514..d5f9f87a 100644 --- a/message/message1_1prime/schema.ipldsch +++ b/message/message1_1prime/schema.ipldsch @@ -30,7 +30,7 @@ type TransferResponse struct { VoucherTypeIdentifier TypeIdentifier (rename "VTyp") } -type TransferMessage struct { +type TransferMessage1_1 struct { IsRequest Bool (rename "IsRq") Request nullable TransferRequest Response nullable TransferResponse diff --git a/message/message1_1prime/transfer_message.go b/message/message1_1prime/transfer_message.go index e51fc9d4..0944212b 100644 --- a/message/message1_1prime/transfer_message.go +++ b/message/message1_1prime/transfer_message.go @@ -1,16 +1,19 @@ package message1_1 import ( + _ "embed" "io" - "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/schema" datatransfer "github.com/filecoin-project/go-data-transfer/v2" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" ) +//go:embed schema.ipldsch +var embedSchema []byte + // TransferMessage1_1 is the transfer message for the 1.1 Data Transfer Protocol. type TransferMessage1_1 struct { IsRequest bool @@ -19,6 +22,10 @@ type TransferMessage1_1 struct { Response *TransferResponse1_1 } +func (tm *TransferMessage1_1) BindnodeSchema() string { + return string(embedSchema) +} + // ========= datatransfer.Message interface // TransferID returns the TransferID of this message @@ -29,16 +36,24 @@ func (tm *TransferMessage1_1) TransferID() datatransfer.TransferID { return tm.Response.TransferID() } -func (tm *TransferMessage1_1) toIPLD() schema.TypedNode { - return bindnode.Wrap(tm, Prototype.TransferMessage.Type()) +func (tm *TransferMessage1_1) toIPLD() (schema.TypedNode, error) { + return ipldutils.ToNode(tm) } -// ToNet serializes a transfer message type. +// ToIPLD converts a transfer message type to an ipld Node func (tm *TransferMessage1_1) ToIPLD() (datamodel.Node, error) { - return tm.toIPLD().Representation(), nil + node, err := tm.toIPLD() + if err != nil { + return nil, err + } + return node.Representation(), nil } // ToNet serializes a transfer message type. func (tm *TransferMessage1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(tm.toIPLD().Representation(), w) + i, err := tm.toIPLD() + if err != nil { + return err + } + return ipldutils.NodeToWriter(i, w) } diff --git a/message/message1_1prime/transfer_request.go b/message/message1_1prime/transfer_request.go index 5d844d92..a01e2452 100644 --- a/message/message1_1prime/transfer_request.go +++ b/message/message1_1prime/transfer_request.go @@ -4,14 +4,13 @@ import ( "io" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/schema" "github.com/libp2p/go-libp2p-core/protocol" xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" "github.com/filecoin-project/go-data-transfer/v2/message/types" ) @@ -23,8 +22,8 @@ type TransferRequest1_1 struct { Pause bool Partial bool Pull bool - SelectorPtr *datamodel.Node - VoucherPtr *datamodel.Node + SelectorPtr datamodel.Node + VoucherPtr datamodel.Node VoucherTypeIdentifier datatransfer.TypeIdentifier TransferId uint64 RestartChannel datatransfer.ChannelID @@ -91,11 +90,22 @@ func (trq *TransferRequest1_1) VoucherType() datatransfer.TypeIdentifier { } // Voucher returns the Voucher bytes -func (trq *TransferRequest1_1) Voucher(decoder encoding.Decoder) (encoding.Encodable, error) { +func (trq *TransferRequest1_1) Voucher() (datamodel.Node, error) { if trq.VoucherPtr == nil { return nil, xerrors.New("No voucher present to read") } - return decoder.DecodeFromNode(*trq.VoucherPtr) + return trq.VoucherPtr, nil +} + +// TypedVoucher is a convenience method that returns the voucher and its typed +// as a TypedVoucher object +// TODO(rvagg): tests for this +func (trq *TransferRequest1_1) TypedVoucher() (datatransfer.TypedVoucher, error) { + voucher, err := trq.Voucher() + if err != nil { + return datatransfer.TypedVoucher{}, err + } + return datatransfer.TypedVoucher{Voucher: voucher, Type: trq.VoucherType()}, nil } func (trq *TransferRequest1_1) EmptyVoucher() bool { @@ -115,7 +125,7 @@ func (trq *TransferRequest1_1) Selector() (datamodel.Node, error) { if trq.SelectorPtr == nil { return nil, xerrors.New("No selector present to read") } - return *trq.SelectorPtr, nil + return trq.SelectorPtr, nil } // IsCancel returns true if this is a cancel request @@ -128,20 +138,28 @@ func (trq *TransferRequest1_1) IsPartial() bool { return trq.Partial } -func (trsp *TransferRequest1_1) toIPLD() schema.TypedNode { +func (trq *TransferRequest1_1) toIPLD() (schema.TypedNode, error) { msg := TransferMessage1_1{ IsRequest: true, - Request: trsp, + Request: trq, Response: nil, } return msg.toIPLD() } func (trq *TransferRequest1_1) ToIPLD() (datamodel.Node, error) { - return trq.toIPLD().Representation(), nil + msg, err := trq.toIPLD() + if err != nil { + return nil, err + } + return msg.Representation(), nil } // ToNet serializes a transfer request. func (trq *TransferRequest1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(trq.toIPLD().Representation(), w) + i, err := trq.toIPLD() + if err != nil { + return err + } + return ipldutils.NodeToWriter(i, w) } diff --git a/message/message1_1prime/transfer_request_test.go b/message/message1_1prime/transfer_request_test.go index 093653ed..2cff27e3 100644 --- a/message/message1_1prime/transfer_request_test.go +++ b/message/message1_1prime/transfer_request_test.go @@ -18,10 +18,10 @@ func TestRequestMessageForProtocol(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() + voucher := testutil.NewTestTypedVoucher() // for the new protocols - request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + request, err := message1_1.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) out12, err := request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_2) @@ -37,5 +37,5 @@ func TestRequestMessageForProtocol(t *testing.T) { n, err := req.Selector() require.NoError(t, err) require.Equal(t, selector, n) - require.Equal(t, voucher.Type(), req.VoucherType()) + require.Equal(t, testutil.TestVoucherType, req.VoucherType()) } diff --git a/message/message1_1prime/transfer_response.go b/message/message1_1prime/transfer_response.go index 32d240d6..1431ff72 100644 --- a/message/message1_1prime/transfer_response.go +++ b/message/message1_1prime/transfer_response.go @@ -3,14 +3,13 @@ package message1_1 import ( "io" - "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/schema" "github.com/libp2p/go-libp2p-core/protocol" xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" + ipldutils "github.com/filecoin-project/go-data-transfer/v2/ipldutils" "github.com/filecoin-project/go-data-transfer/v2/message/types" ) @@ -21,7 +20,7 @@ type TransferResponse1_1 struct { RequestAccepted bool Paused bool TransferId uint64 - VoucherResultPtr *datamodel.Node + VoucherResultPtr datamodel.Node VoucherTypeIdentifier datatransfer.TypeIdentifier } @@ -73,11 +72,11 @@ func (trsp *TransferResponse1_1) VoucherResultType() datatransfer.TypeIdentifier return trsp.VoucherTypeIdentifier } -func (trsp *TransferResponse1_1) VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) { +func (trsp *TransferResponse1_1) VoucherResult() (datamodel.Node, error) { if trsp.VoucherResultPtr == nil { return nil, xerrors.New("No voucher present to read") } - return decoder.DecodeFromNode(*trsp.VoucherResultPtr) + return trsp.VoucherResultPtr, nil } func (trq *TransferResponse1_1) IsRestart() bool { @@ -97,7 +96,7 @@ func (trsp *TransferResponse1_1) MessageForProtocol(targetProtocol protocol.ID) } } -func (trsp *TransferResponse1_1) toIPLD() schema.TypedNode { +func (trsp *TransferResponse1_1) toIPLD() (schema.TypedNode, error) { msg := TransferMessage1_1{ IsRequest: false, Request: nil, @@ -107,10 +106,18 @@ func (trsp *TransferResponse1_1) toIPLD() schema.TypedNode { } func (trsp *TransferResponse1_1) ToIPLD() (datamodel.Node, error) { - return trsp.toIPLD().Representation(), nil + msg, err := trsp.toIPLD() + if err != nil { + return nil, err + } + return msg.Representation(), nil } // ToNet serializes a transfer response. func (trsp *TransferResponse1_1) ToNet(w io.Writer) error { - return dagcbor.Encode(trsp.toIPLD().Representation(), w) + i, err := trsp.toIPLD() + if err != nil { + return err + } + return ipldutils.NodeToWriter(i, w) } diff --git a/message/message1_1prime/transfer_response_test.go b/message/message1_1prime/transfer_response_test.go index dbaf2de0..9b979371 100644 --- a/message/message1_1prime/transfer_response_test.go +++ b/message/message1_1prime/transfer_response_test.go @@ -13,8 +13,8 @@ import ( func TestResponseMessageForProtocol(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted + voucherResult := testutil.NewTestTypedVoucher() + response, err := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted require.NoError(t, err) // v1.2 protocol @@ -25,7 +25,7 @@ func TestResponseMessageForProtocol(t *testing.T) { resp, ok := (out).(datatransfer.Response) require.True(t, ok) require.True(t, resp.IsPaused()) - require.Equal(t, voucherResult.Type(), resp.VoucherResultType()) + require.Equal(t, testutil.TestVoucherType, resp.VoucherResultType()) require.True(t, resp.IsValidationResult()) // random protocol diff --git a/network/libp2p_impl_test.go b/network/libp2p_impl_test.go index e35c250a..89489155 100644 --- a/network/libp2p_impl_test.go +++ b/network/libp2p_impl_test.go @@ -102,8 +102,8 @@ func TestMessageSendAndReceive(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) @@ -124,15 +124,15 @@ func TestMessageSendAndReceive(t *testing.T) { assert.Equal(t, request.IsPull(), receivedRequest.IsPull()) assert.Equal(t, request.IsRequest(), receivedRequest.IsRequest()) assert.True(t, receivedRequest.BaseCid().Equals(request.BaseCid())) - testutil.AssertEqualFakeDTVoucher(t, request, receivedRequest) + testutil.AssertEqualTestVoucher(t, request, receivedRequest) testutil.AssertEqualSelector(t, request, receivedRequest) }) t.Run("Send Response", func(t *testing.T) { accepted := false id := datatransfer.TransferID(rand.Int31()) - voucherResult := testutil.NewFakeDTType() - response, err := message.ValidationResultResponse(types.NewMessage, id, datatransfer.ValidationResult{Accepted: accepted, VoucherResult: voucherResult}, nil, false) + voucherResult := testutil.NewTestTypedVoucher() + response, err := message.ValidationResultResponse(types.NewMessage, id, datatransfer.ValidationResult{Accepted: accepted, VoucherResult: &voucherResult}, nil, false) require.NoError(t, err) require.NoError(t, dtnet2.SendMessage(ctx, host1.ID(), response)) @@ -151,7 +151,7 @@ func TestMessageSendAndReceive(t *testing.T) { assert.Equal(t, response.TransferID(), receivedResponse.TransferID()) assert.Equal(t, response.Accepted(), receivedResponse.Accepted()) assert.Equal(t, response.IsRequest(), receivedResponse.IsRequest()) - testutil.AssertEqualFakeDTVoucherResult(t, response, receivedResponse) + testutil.AssertEqualTestVoucherResult(t, response, receivedResponse) }) t.Run("Send Restart Request", func(t *testing.T) { @@ -273,8 +273,8 @@ func TestSendMessageRetry(t *testing.T) { selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - voucher := testutil.NewFakeDTType() - request, err := message.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + voucher := testutil.NewTestTypedVoucher() + request, err := message.NewRequest(id, false, isPull, &voucher, baseCid, selector) require.NoError(t, err) err = dtnet1.SendMessage(ctx, host2.ID(), request) diff --git a/registry/registry.go b/registry/registry.go index 4237cec7..00f3815d 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -6,7 +6,6 @@ import ( "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" ) // Processor is an interface that processes a certain type of encodable objects @@ -14,11 +13,6 @@ import ( // left to the user of the registry type Processor interface{} -type registryEntry struct { - decoder encoding.Decoder - processor Processor -} - // Registry maintans a register of types of encodable objects and a corresponding // processor for those objects // The encodable types must have a method Type() that specifies and identifier @@ -26,54 +20,41 @@ type registryEntry struct { // on this unique identifier type Registry struct { registryLk sync.RWMutex - entries map[datatransfer.TypeIdentifier]registryEntry + entries map[datatransfer.TypeIdentifier]Processor } // NewRegistry initialzes a new registy func NewRegistry() *Registry { return &Registry{ - entries: make(map[datatransfer.TypeIdentifier]registryEntry), + entries: make(map[datatransfer.TypeIdentifier]Processor), } } // Register registers the given processor for the given entry type -func (r *Registry) Register(entry datatransfer.Registerable, processor Processor) error { - identifier := entry.Type() - decoder, err := encoding.NewDecoder(entry) - if err != nil { - return xerrors.Errorf("registering entry type %s: %w", identifier, err) - } +func (r *Registry) Register(identifier datatransfer.TypeIdentifier, processor Processor) error { r.registryLk.Lock() defer r.registryLk.Unlock() if _, ok := r.entries[identifier]; ok { return xerrors.Errorf("identifier already registered: %s", identifier) } - r.entries[identifier] = registryEntry{decoder, processor} + r.entries[identifier] = processor return nil } -// Decoder gets a decoder for the given identifier -func (r *Registry) Decoder(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) { - r.registryLk.RLock() - entry, has := r.entries[identifier] - r.registryLk.RUnlock() - return entry.decoder, has -} - // Processor gets the processing interface for the given identifer func (r *Registry) Processor(identifier datatransfer.TypeIdentifier) (Processor, bool) { r.registryLk.RLock() entry, has := r.entries[identifier] r.registryLk.RUnlock() - return entry.processor, has + return entry, has } // Each iterates through all of the entries in this registry -func (r *Registry) Each(process func(datatransfer.TypeIdentifier, encoding.Decoder, Processor) error) error { +func (r *Registry) Each(process func(datatransfer.TypeIdentifier, Processor) error) error { r.registryLk.RLock() defer r.registryLk.RUnlock() - for identifier, entry := range r.entries { - err := process(identifier, entry.decoder, entry.processor) + for identifier, processor := range r.entries { + err := process(identifier, processor) if err != nil { return err } diff --git a/registry/registry_test.go b/registry/registry_test.go index 63adf363..84fd95c8 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -12,27 +12,15 @@ import ( func TestRegistry(t *testing.T) { r := registry.NewRegistry() t.Run("it registers", func(t *testing.T) { - err := r.Register(&testutil.FakeDTType{}, func() {}) + err := r.Register(testutil.TestVoucherType, func() {}) require.NoError(t, err) }) t.Run("it errors when registred again", func(t *testing.T) { - err := r.Register(&testutil.FakeDTType{}, func() {}) - require.EqualError(t, err, "identifier already registered: FakeDTType") - }) - t.Run("it errors when decoder setup fails", func(t *testing.T) { - err := r.Register(testutil.FakeDTType{}, func() {}) - require.EqualError(t, err, "registering entry type FakeDTType: type must be a pointer") - }) - t.Run("it reads decoders", func(t *testing.T) { - decoder, has := r.Decoder("FakeDTType") - require.True(t, has) - require.NotNil(t, decoder) - decoder, has = r.Decoder("OtherType") - require.False(t, has) - require.Nil(t, decoder) + err := r.Register(testutil.TestVoucherType, func() {}) + require.EqualError(t, err, "identifier already registered: TestVoucher") }) t.Run("it reads processors", func(t *testing.T) { - processor, has := r.Processor("FakeDTType") + processor, has := r.Processor("TestVoucher") require.True(t, has) require.NotNil(t, processor) processor, has = r.Processor("OtherType") diff --git a/testutil/fakedttype.go b/testutil/fakedttype.go index 20cc219e..03db9273 100644 --- a/testutil/fakedttype.go +++ b/testutil/fakedttype.go @@ -3,71 +3,80 @@ package testutil import ( "testing" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" + basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/stretchr/testify/require" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/encoding" ) -//go:generate cbor-gen-for FakeDTType +const TestVoucherType = datatransfer.TypeIdentifier("TestVoucher") -// FakeDTType simple fake type for using with registries -type FakeDTType struct { - Data string -} - -// Type satisfies registry.Entry -func (ft FakeDTType) Type() datatransfer.TypeIdentifier { - return "FakeDTType" -} - -// AssertFakeDTVoucher asserts that a data transfer requests contains the expected fake data transfer voucher type -func AssertFakeDTVoucher(t *testing.T, request datatransfer.Request, expected *FakeDTType) { - require.Equal(t, datatransfer.TypeIdentifier("FakeDTType"), request.VoucherType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) +// AssertTestVoucher asserts that a data transfer requests contains the expected fake data transfer voucher type +func AssertTestVoucher(t *testing.T, request datatransfer.Request, expected datatransfer.TypedVoucher) { + require.Equal(t, expected.Type, request.VoucherType()) + voucher, err := request.Voucher() require.NoError(t, err) - decoded, err := request.Voucher(fakeDTDecoder) - require.NoError(t, err) - require.Equal(t, expected, decoded) + require.True(t, ipld.DeepEqual(expected.Voucher, voucher)) } -// AssertEqualFakeDTVoucher asserts that two requests have the same fake data transfer voucher -func AssertEqualFakeDTVoucher(t *testing.T, expectedRequest datatransfer.Request, request datatransfer.Request) { +// AssertEqualTestVoucher asserts that two requests have the same fake data transfer voucher +func AssertEqualTestVoucher(t *testing.T, expectedRequest datatransfer.Request, request datatransfer.Request) { require.Equal(t, expectedRequest.VoucherType(), request.VoucherType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) - require.NoError(t, err) - expectedDecoded, err := request.Voucher(fakeDTDecoder) + require.Equal(t, TestVoucherType, request.VoucherType()) + expected, err := expectedRequest.Voucher() require.NoError(t, err) - decoded, err := request.Voucher(fakeDTDecoder) + actual, err := request.Voucher() require.NoError(t, err) - require.Equal(t, expectedDecoded, decoded) + require.True(t, ipld.DeepEqual(expected, actual)) } -// AssertFakeDTVoucherResult asserts that a data transfer response contains the expected fake data transfer voucher result type -func AssertFakeDTVoucherResult(t *testing.T, response datatransfer.Response, expected *FakeDTType) { - require.Equal(t, datatransfer.TypeIdentifier("FakeDTType"), response.VoucherResultType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) - require.NoError(t, err) - decoded, err := response.VoucherResult(fakeDTDecoder) +// AssertTestVoucherResult asserts that a data transfer response contains the expected fake data transfer voucher result type +func AssertTestVoucherResult(t *testing.T, response datatransfer.Response, expected datatransfer.TypedVoucher) { + require.Equal(t, expected.Type, response.VoucherResultType()) + voucherResult, err := response.VoucherResult() require.NoError(t, err) - require.Equal(t, expected, decoded) + require.True(t, ipld.DeepEqual(expected.Voucher, voucherResult)) } -// AssertEqualFakeDTVoucherResult asserts that two responses have the same fake data transfer voucher result -func AssertEqualFakeDTVoucherResult(t *testing.T, expectedResponse datatransfer.Response, response datatransfer.Response) { +// AssertEqualTestVoucherResult asserts that two responses have the same fake data transfer voucher result +func AssertEqualTestVoucherResult(t *testing.T, expectedResponse datatransfer.Response, response datatransfer.Response) { require.Equal(t, expectedResponse.VoucherResultType(), response.VoucherResultType()) - fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) + expectedVoucherResult, err := expectedResponse.VoucherResult() require.NoError(t, err) - expectedDecoded, err := response.VoucherResult(fakeDTDecoder) + actualVoucherResult, err := response.VoucherResult() require.NoError(t, err) - decoded, err := response.VoucherResult(fakeDTDecoder) - require.NoError(t, err) - require.Equal(t, expectedDecoded, decoded) + require.True(t, ipld.DeepEqual(expectedVoucherResult, actualVoucherResult)) +} + +// NewTestVoucher returns a fake voucher with random data +func NewTestVoucher() datamodel.Node { + n, err := qp.BuildList(basicnode.Prototype.Any, 1, func(ma datamodel.ListAssembler) { + qp.ListEntry(ma, qp.String(string(RandomBytes(100)))) + }) + if err != nil { + panic(err) + } + return n } -// NewFakeDTType returns a fake dt type with random data -func NewFakeDTType() *FakeDTType { - return &FakeDTType{Data: string(RandomBytes(100))} +func NewTestTypedVoucher() datatransfer.TypedVoucher { + return datatransfer.TypedVoucher{Voucher: NewTestVoucher(), Type: TestVoucherType} } -var _ datatransfer.Registerable = &FakeDTType{} +// NewTestVoucher returns a fake voucher with random data +func NewTestVoucherWith(data string) datamodel.Node { + n, err := qp.BuildList(basicnode.Prototype.Any, 1, func(ma datamodel.ListAssembler) { + qp.ListEntry(ma, qp.String(data)) + }) + if err != nil { + panic(err) + } + return n +} + +func NewTestTypedVoucherWith(data string) datatransfer.TypedVoucher { + return datatransfer.TypedVoucher{Voucher: NewTestVoucherWith(data), Type: TestVoucherType} +} diff --git a/testutil/fakedttype_cbor_gen.go b/testutil/fakedttype_cbor_gen.go deleted file mode 100644 index d7913605..00000000 --- a/testutil/fakedttype_cbor_gen.go +++ /dev/null @@ -1,75 +0,0 @@ -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - -package testutil - -import ( - "fmt" - "io" - "sort" - - cid "github.com/ipfs/go-cid" - cbg "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" -) - -var _ = xerrors.Errorf -var _ = cid.Undef -var _ = sort.Sort - -var lengthBufFakeDTType = []byte{129} - -func (t *FakeDTType) MarshalCBOR(w io.Writer) error { - if t == nil { - _, err := w.Write(cbg.CborNull) - return err - } - if _, err := w.Write(lengthBufFakeDTType); err != nil { - return err - } - - scratch := make([]byte, 9) - - // t.Data (string) (string) - if len(t.Data) > cbg.MaxLength { - return xerrors.Errorf("Value in field t.Data was too long") - } - - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Data))); err != nil { - return err - } - if _, err := io.WriteString(w, string(t.Data)); err != nil { - return err - } - return nil -} - -func (t *FakeDTType) UnmarshalCBOR(r io.Reader) error { - *t = FakeDTType{} - - br := cbg.GetPeeker(r) - scratch := make([]byte, 8) - - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != cbg.MajArray { - return fmt.Errorf("cbor input should be of type array") - } - - if extra != 1 { - return fmt.Errorf("cbor input had wrong number of fields") - } - - // t.Data (string) (string) - - { - sval, err := cbg.ReadStringBuf(br, scratch) - if err != nil { - return err - } - - t.Data = string(sval) - } - return nil -} diff --git a/testutil/fakegraphsync.go b/testutil/fakegraphsync.go index 42b30873..f0ac8ab0 100644 --- a/testutil/fakegraphsync.go +++ b/testutil/fakegraphsync.go @@ -40,7 +40,7 @@ type ReceivedGraphSyncRequest struct { Ctx context.Context P peer.ID Root ipld.Link - Selector ipld.Node + Selector datamodel.Node Extensions []graphsync.ExtensionData ResponseChan chan graphsync.ResponseProgress ResponseErrChan chan error @@ -193,7 +193,7 @@ func (fgs *FakeGraphSync) AssertDoesNotHavePersistenceOption(t *testing.T, name } // Request initiates a new GraphSync request to the given peer using the given selector spec. -func (fgs *FakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { +func (fgs *FakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector datamodel.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { errors := make(chan error) responses := make(chan graphsync.ResponseProgress) fgs.requests <- ReceivedGraphSyncRequest{ctx, p, root, selector, extensions, responses, errors} @@ -387,7 +387,7 @@ func NewFakeBlockData() graphsync.BlockData { type fakeRequest struct { id graphsync.RequestID root cid.Cid - selector ipld.Node + selector datamodel.Node priority graphsync.Priority requestType graphsync.RequestType extensions map[graphsync.ExtensionName]datamodel.Node @@ -404,7 +404,7 @@ func (fr *fakeRequest) Root() cid.Cid { } // Selector returns the byte representation of the selector for this request -func (fr *fakeRequest) Selector() ipld.Node { +func (fr *fakeRequest) Selector() datamodel.Node { return fr.selector } diff --git a/testutil/faketransport.go b/testutil/faketransport.go index 92441364..08faee74 100644 --- a/testutil/faketransport.go +++ b/testutil/faketransport.go @@ -4,6 +4,7 @@ import ( "context" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" datatransfer "github.com/filecoin-project/go-data-transfer/v2" @@ -14,7 +15,7 @@ type OpenedChannel struct { DataSender peer.ID ChannelID datatransfer.ChannelID Root ipld.Link - Selector ipld.Node + Selector datamodel.Node Channel datatransfer.ChannelState Message datatransfer.Message } @@ -28,7 +29,7 @@ type ResumedChannel struct { // CustomizedTransfer is just a way to record calls made to transport configurer type CustomizedTransfer struct { ChannelID datatransfer.ChannelID - Voucher datatransfer.Voucher + Voucher datatransfer.TypedVoucher } // FakeTransport is a fake transport with mocked results @@ -57,7 +58,7 @@ func NewFakeTransport() *FakeTransport { // Note: from a data transfer symantic standpoint, it doesn't matter if the // request is push or pull -- OpenChannel is called by the party that is // intending to receive data -func (ft *FakeTransport) OpenChannel(ctx context.Context, dataSender peer.ID, channelID datatransfer.ChannelID, root ipld.Link, stor ipld.Node, channel datatransfer.ChannelState, msg datatransfer.Message) error { +func (ft *FakeTransport) OpenChannel(ctx context.Context, dataSender peer.ID, channelID datatransfer.ChannelID, root ipld.Link, stor datamodel.Node, channel datatransfer.ChannelState, msg datatransfer.Message) error { ft.OpenedChannels = append(ft.OpenedChannels, OpenedChannel{dataSender, channelID, root, stor, channel, msg}) return ft.OpenChannelErr } @@ -95,6 +96,6 @@ func (ft *FakeTransport) CleanupChannel(chid datatransfer.ChannelID) { ft.CleanedUpChannels = append(ft.CleanedUpChannels, chid) } -func (ft *FakeTransport) RecordCustomizedTransfer(chid datatransfer.ChannelID, voucher datatransfer.Voucher) { +func (ft *FakeTransport) RecordCustomizedTransfer(chid datatransfer.ChannelID, voucher datatransfer.TypedVoucher) { ft.CustomizedTransfers = append(ft.CustomizedTransfers, CustomizedTransfer{chid, voucher}) } diff --git a/testutil/gstestdata.go b/testutil/gstestdata.go index 3f080c5a..3fa22989 100644 --- a/testutil/gstestdata.go +++ b/testutil/gstestdata.go @@ -29,6 +29,7 @@ import ( "github.com/ipfs/go-unixfs/importer/balanced" ihelper "github.com/ipfs/go-unixfs/importer/helpers" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" cidlink "github.com/ipld/go-ipld-prime/linking/cid" basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/ipld/go-ipld-prime/traversal/selector" @@ -44,7 +45,7 @@ import ( "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" ) -var allSelector ipld.Node +var allSelector datamodel.Node const loremFile = "lorem.txt" @@ -82,7 +83,7 @@ type GraphsyncTestingData struct { GsNet2 gsnet.GraphSyncNetwork DtNet1 network.DataTransferNetwork DtNet2 network.DataTransferNetwork - AllSelector ipld.Node + AllSelector datamodel.Node OrigBytes []byte TempDir1 string TempDir2 string diff --git a/testutil/message.go b/testutil/message.go index c3745be9..14319cc3 100644 --- a/testutil/message.go +++ b/testutil/message.go @@ -13,18 +13,18 @@ import ( // NewDTRequest makes a new DT Request message func NewDTRequest(t *testing.T, transferID datatransfer.TransferID) datatransfer.Request { - voucher := NewFakeDTType() + voucher := NewTestTypedVoucher() baseCid := GenerateCids(1)[0] selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() - r, err := message.NewRequest(transferID, false, false, voucher.Type(), voucher, baseCid, selector) + r, err := message.NewRequest(transferID, false, false, &voucher, baseCid, selector) require.NoError(t, err) return r } // NewDTResponse makes a new DT Request message func NewDTResponse(t *testing.T, transferID datatransfer.TransferID) datatransfer.Response { - vresult := NewFakeDTType() - r, err := message.NewResponse(transferID, false, false, vresult.Type(), vresult) + vresult := NewTestTypedVoucher() + r, err := message.NewResponse(transferID, false, false, &vresult) require.NoError(t, err) return r } diff --git a/testutil/mockchannelstate.go b/testutil/mockchannelstate.go index a9a86586..a5754bde 100644 --- a/testutil/mockchannelstate.go +++ b/testutil/mockchannelstate.go @@ -2,7 +2,7 @@ package testutil import ( cid "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" datatransfer "github.com/filecoin-project/go-data-transfer/v2" @@ -105,11 +105,11 @@ func (m *MockChannelState) BaseCID() cid.Cid { panic("implement me") } -func (m *MockChannelState) Selector() ipld.Node { +func (m *MockChannelState) Selector() datamodel.Node { panic("implement me") } -func (m *MockChannelState) Voucher() datatransfer.Voucher { +func (m *MockChannelState) Voucher() (datatransfer.TypedVoucher, error) { panic("implement me") } @@ -141,19 +141,19 @@ func (m *MockChannelState) Message() string { panic("implement me") } -func (m *MockChannelState) Vouchers() []datatransfer.Voucher { +func (m *MockChannelState) Vouchers() ([]datatransfer.TypedVoucher, error) { panic("implement me") } -func (m *MockChannelState) VoucherResults() []datatransfer.VoucherResult { +func (m *MockChannelState) VoucherResults() ([]datatransfer.TypedVoucher, error) { panic("implement me") } -func (m *MockChannelState) LastVoucher() datatransfer.Voucher { +func (m *MockChannelState) LastVoucher() (datatransfer.TypedVoucher, error) { panic("implement me") } -func (m *MockChannelState) LastVoucherResult() datatransfer.VoucherResult { +func (m *MockChannelState) LastVoucherResult() (datatransfer.TypedVoucher, error) { panic("implement me") } diff --git a/testutil/stubbedvalidator.go b/testutil/stubbedvalidator.go index 30a8bab7..1647bdc0 100644 --- a/testutil/stubbedvalidator.go +++ b/testutil/stubbedvalidator.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" @@ -21,9 +21,9 @@ func NewStubbedValidator() *StubbedValidator { func (sv *StubbedValidator) ValidatePush( chid datatransfer.ChannelID, sender peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.ValidationResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { sv.didPush = true sv.ValidationsReceived = append(sv.ValidationsReceived, ReceivedValidation{false, sender, voucher, baseCid, selector}) return sv.result, sv.pushError @@ -33,9 +33,9 @@ func (sv *StubbedValidator) ValidatePush( func (sv *StubbedValidator) ValidatePull( chid datatransfer.ChannelID, receiver peer.ID, - voucher datatransfer.Voucher, + voucher datamodel.Node, baseCid cid.Cid, - selector ipld.Node) (datatransfer.ValidationResult, error) { + selector datamodel.Node) (datatransfer.ValidationResult, error) { sv.didPull = true sv.ValidationsReceived = append(sv.ValidationsReceived, ReceivedValidation{true, receiver, voucher, baseCid, selector}) return sv.result, sv.pullError @@ -140,9 +140,9 @@ func (sv *StubbedValidator) ExpectSuccessValidateRestart() { type ReceivedValidation struct { IsPull bool Other peer.ID - Voucher datatransfer.Voucher + Voucher datamodel.Node BaseCid cid.Cid - Selector ipld.Node + Selector datamodel.Node } // ReceivedRestartValidation records a call to ValidateRestart diff --git a/testutil/testutil.go b/testutil/testutil.go index 66600730..cfcde8b5 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -9,7 +9,7 @@ import ( blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" - "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/ipld/go-ipld-prime/traversal/selector" "github.com/ipld/go-ipld-prime/traversal/selector/builder" @@ -102,7 +102,7 @@ func AssertEqualSelector(t *testing.T, expectedRequest datatransfer.Request, req } // AllSelector just returns a new instance of a "whole dag selector" -func AllSelector() ipld.Node { +func AllSelector() datamodel.Node { ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any) return ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() diff --git a/transport.go b/transport.go index 81187077..ca002a28 100644 --- a/transport.go +++ b/transport.go @@ -4,6 +4,7 @@ import ( "context" ipld "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" peer "github.com/libp2p/go-libp2p-core/peer" ) @@ -98,7 +99,7 @@ type Transport interface { dataSender peer.ID, channelID ChannelID, root ipld.Link, - stor ipld.Node, + stor datamodel.Node, channel ChannelState, msg Message, ) error diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index 7fc48d41..5b6e5e84 100644 --- a/transport/graphsync/graphsync.go +++ b/transport/graphsync/graphsync.go @@ -11,6 +11,7 @@ import ( "github.com/ipfs/go-graphsync/donotsendfirstblocks" logging "github.com/ipfs/go-log/v2" ipld "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" peer "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -111,7 +112,7 @@ func (t *Transport) OpenChannel( dataSender peer.ID, channelID datatransfer.ChannelID, root ipld.Link, - stor ipld.Node, + stor datamodel.Node, channel datatransfer.ChannelState, msg datatransfer.Message, ) error { @@ -938,7 +939,7 @@ func (c *dtChannel) open( chid datatransfer.ChannelID, dataSender peer.ID, root ipld.Link, - stor ipld.Node, + stor datamodel.Node, channel datatransfer.ChannelState, exts []graphsync.ExtensionData, ) (*gsReq, error) { diff --git a/types.go b/types.go index 9201c5aa..b15d4cde 100644 --- a/types.go +++ b/types.go @@ -6,10 +6,9 @@ import ( "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/libp2p/go-libp2p-core/peer" cbg "github.com/whyrusleeping/cbor-gen" - - "github.com/filecoin-project/go-data-transfer/v2/encoding" ) //go:generate cbor-gen-for ChannelID ChannelStages ChannelStage Log @@ -21,23 +20,18 @@ type TypeIdentifier string // EmptyTypeIdentifier means there is no voucher present const EmptyTypeIdentifier = TypeIdentifier("") -// Registerable is a type of object in a registry. It must be encodable and must -// have a single method that uniquely identifies its type -type Registerable interface { - encoding.Encodable - // Type is a unique string identifier for this voucher type - Type() TypeIdentifier +// TypedVoucher is a voucher or voucher result in IPLD form and an associated +// type identifier for that voucher or voucher result +type TypedVoucher struct { + Voucher datamodel.Node + Type TypeIdentifier } -// Voucher is used to validate -// a data transfer request against the underlying storage or retrieval deal -// that precipitated it. The only requirement is a voucher can read and write -// from bytes, and has a string identifier type -type Voucher Registerable - -// VoucherResult is used to provide option additional information about a -// voucher being rejected or accepted -type VoucherResult Registerable +// Equals is a utility to compare that two TypedVouchers are the same - both type +// and the voucher's IPLD content +func (tv1 TypedVoucher) Equals(tv2 TypedVoucher) bool { + return tv1.Type == tv2.Type && ipld.DeepEqual(tv1.Voucher, tv2.Voucher) +} // TransferID is an identifier for a data transfer, shared between // request/responder and unique to the requester @@ -74,10 +68,10 @@ type Channel interface { // Selector returns the IPLD selector for this data transfer (represented as // an IPLD node) - Selector() ipld.Node + Selector() datamodel.Node - // Voucher returns the voucher for this data transfer - Voucher() Voucher + // Voucher returns the initial voucher for this data transfer + Voucher() (TypedVoucher, error) // Sender returns the peer id for the node that is sending data Sender() peer.ID @@ -118,16 +112,16 @@ type ChannelState interface { Message() string // Vouchers returns all vouchers sent on this channel - Vouchers() []Voucher + Vouchers() ([]TypedVoucher, error) // VoucherResults are results of vouchers sent on the channel - VoucherResults() []VoucherResult + VoucherResults() ([]TypedVoucher, error) // LastVoucher returns the last voucher sent on the channel - LastVoucher() Voucher + LastVoucher() (TypedVoucher, error) // LastVoucherResult returns the last voucher result sent on the channel - LastVoucherResult() VoucherResult + LastVoucherResult() (TypedVoucher, error) // ReceivedCidsTotal returns the number of (non-unique) cids received so far // on the channel - note that a block can exist in more than one place in the DAG