Skip to content

Commit

Permalink
fix: ensure header.Validate is called by the lib (#89)
Browse files Browse the repository at this point in the history
Header implementations should not call Validate themselves as the lib now controls it.

Additionally, we unify the response processing code and Header constructor.

Contains error checks corrections

Closes #78
Based on #88
  • Loading branch information
Wondertan committed Aug 21, 2023
1 parent e500905 commit 6069d6a
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 67 deletions.
12 changes: 10 additions & 2 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
// Header abstracts all methods required to perform header sync.
type Header interface {
// New creates new instance of a header.
// It exists to overcome limitation of Go's type system.
// See:
//https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#pointer-method-example
New() Header
// IsZero reports whether Header is a zero value of it's concrete type.
IsZero() bool
// ChainID returns identifier of the chain (ChainID).
// ChainID returns identifier of the chain.
ChainID() string
// Hash returns hash of a header.
Hash() Hash
Expand All @@ -23,9 +26,14 @@ type Header interface {
Time() time.Time
// Verify validates given untrusted Header against trusted Header.
Verify(Header) error
// Validate performs basic validation to check for missed/incorrect fields.
// Validate performs stateless validation to check for missed/incorrect fields.
Validate() error

encoding.BinaryMarshaler
encoding.BinaryUnmarshaler
}

// New is a generic Header constructor.
func New[H Header]() (h H) {
return h.New().(H)
}
26 changes: 8 additions & 18 deletions p2p/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,28 +313,18 @@ func (ex *Exchange[H]) request(
return nil, err
}

headers := make([]H, 0, len(responses))
for _, response := range responses {
if err = convertStatusCodeToError(response.StatusCode); err != nil {
return nil, err
}
var empty H
header := empty.New()
err := header.UnmarshalBinary(response.Body)
if err != nil {
return nil, err
}
err = validateChainID(ex.Params.chainID, header.(H).ChainID())
hdrs, err := processResponses[H](responses)
if err != nil {
return nil, err
}
for _, hdr := range hdrs {
// TODO(@Wondertan): There should be a unified header validation code path
err = validateChainID(ex.Params.chainID, hdr.ChainID())
if err != nil {
return nil, err
}
headers = append(headers, header.(H))
}

if len(headers) == 0 {
return nil, header.ErrNotFound
}
return headers, nil
return hdrs, nil
}

// shufflePeers changes the order of trusted peers.
Expand Down
71 changes: 42 additions & 29 deletions p2p/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (s *session[H]) doRequest(
log.Debugw("requesting headers from peer failed", "peer", stat.peerID, "err", err)
}

h, err := s.processResponse(r)
h, err := s.processResponses(r)
if err != nil {
logFn := log.Errorw

Expand Down Expand Up @@ -216,39 +216,19 @@ func (s *session[H]) doRequest(
s.queue.push(stat)
}

// processResponse converts HeaderResponse to Header.
func (s *session[H]) processResponse(responses []*p2p_pb.HeaderResponse) ([]H, error) {
if len(responses) == 0 {
return nil, errEmptyResponse
}

headers := make([]H, 0)
for _, resp := range responses {
err := convertStatusCodeToError(resp.StatusCode)
if err != nil {
return nil, err
}

var empty H
header := empty.New()
err = header.UnmarshalBinary(resp.Body)
if err != nil {
return nil, err
}
headers = append(headers, header.(H))
}

if len(headers) == 0 {
return nil, header.ErrNotFound
// processResponses converts HeaderResponse to Header.
func (s *session[H]) processResponses(responses []*p2p_pb.HeaderResponse) ([]H, error) {
hdrs, err := processResponses[H](responses)
if err != nil {
return nil, err
}

err := s.validate(headers)
return headers, err
return hdrs, s.verify(hdrs)
}

// validate checks that the received range of headers is adjacent and is valid against the provided
// verify checks that the received range of headers is adjacent and is valid against the provided
// header.
func (s *session[H]) validate(headers []H) error {
func (s *session[H]) verify(headers []H) error {
// if `s.from` is empty, then additional validation for the header`s range is not needed.
if s.from.IsZero() {
return nil
Expand Down Expand Up @@ -302,3 +282,36 @@ func prepareRequests(from, amount, headersPerPeer uint64) []*p2p_pb.HeaderReques
}
return requests
}

// processResponses converts HeaderResponses to Headers
func processResponses[H header.Header](resps []*p2p_pb.HeaderResponse) ([]H, error) {
if len(resps) == 0 {
return nil, errEmptyResponse
}

hdrs := make([]H, 0)
for _, resp := range resps {
err := convertStatusCodeToError(resp.StatusCode)
if err != nil {
return nil, err
}

hdr := header.New[H]()
err = hdr.UnmarshalBinary(resp.Body)
if err != nil {
return nil, err
}

err = hdr.Validate()
if err != nil {
return nil, err
}

hdrs = append(hdrs, hdr)
}

if len(hdrs) == 0 {
return nil, header.ErrNotFound
}
return hdrs, nil
}
4 changes: 2 additions & 2 deletions p2p/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func Test_Validate(t *testing.T) {
)

headers := suite.GenDummyHeaders(5)
err := ses.validate(headers)
err := ses.verify(headers)
assert.NoError(t, err)
}

Expand All @@ -53,6 +53,6 @@ func Test_ValidateFails(t *testing.T) {
headers := suite.GenDummyHeaders(5)
// break adjacency
headers[2] = headers[4]
err := ses.validate(headers)
err := ses.verify(headers)
assert.Error(t, err)
}
19 changes: 14 additions & 5 deletions p2p/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,28 @@ func (p *Subscriber[H]) Stop(context.Context) error {
// Does not punish peers if *header.VerifyError is given with Uncertain set to true.
func (p *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error {
pval := func(ctx context.Context, p peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
var empty H
maybeHead := empty.New()
err := maybeHead.UnmarshalBinary(msg.Data)
hdr := header.New[H]()
err := hdr.UnmarshalBinary(msg.Data)
if err != nil {
log.Errorw("unmarshalling header",
"from", p.ShortString(),
"err", err)
return pubsub.ValidationReject
}
msg.ValidatorData = maybeHead
// ensure header validity
err = hdr.Validate()
if err != nil {
log.Errorw("invalid header",
"from", p.ShortString(),
"err", err)
return pubsub.ValidationReject
}
// keep the valid header in the msg so Subscriptions can access it without
// additional unmarhalling
msg.ValidatorData = hdr

var verErr *header.VerifyError
err = val(ctx, maybeHead.(H))
err = val(ctx, hdr)
switch {
case err == nil:
return pubsub.ValidationAccept
Expand Down
5 changes: 3 additions & 2 deletions store/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package store

import (
"context"
"errors"

"github.com/celestiaorg/go-header"
)
Expand All @@ -10,10 +11,10 @@ import (
// it initializes the Store by requesting the header with the given hash.
func Init[H header.Header](ctx context.Context, store header.Store[H], ex header.Exchange[H], hash header.Hash) error {
_, err := store.Head(ctx)
switch err {
switch {
default:
return err
case header.ErrNoHead:
case errors.Is(err, header.ErrNoHead):
initial, err := ex.Get(ctx, hash)
if err != nil {
return err
Expand Down
17 changes: 8 additions & 9 deletions store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ func (s *Store[H]) Head(ctx context.Context, _ ...header.HeadOption) (H, error)

var zero H
head, err = s.readHead(ctx)
switch err {
switch {
default:
return zero, err
case datastore.ErrNotFound, header.ErrNotFound:
case errors.Is(err, datastore.ErrNotFound), errors.Is(err, header.ErrNotFound):
return zero, header.ErrNoHead
case nil:
case err == nil:
s.heightSub.SetHeight(uint64(head.Height()))
log.Infow("loaded head", "height", head.Height(), "hash", head.Hash())
return head, nil
Expand All @@ -190,22 +190,21 @@ func (s *Store[H]) Get(ctx context.Context, hash header.Hash) (H, error) {

b, err := s.ds.Get(ctx, datastore.NewKey(hash.String()))
if err != nil {
if err == datastore.ErrNotFound {
if errors.Is(err, datastore.ErrNotFound) {
return zero, header.ErrNotFound
}

return zero, err
}

var empty H
h := empty.New()
h := header.New[H]()
err = h.UnmarshalBinary(b)
if err != nil {
return zero, err
}

s.cache.Add(h.Hash().String(), h)
return h.(H), nil
return h, nil
}

func (s *Store[H]) GetByHeight(ctx context.Context, height uint64) (H, error) {
Expand All @@ -216,7 +215,7 @@ func (s *Store[H]) GetByHeight(ctx context.Context, height uint64) (H, error) {
// if the requested 'height' was not yet published
// we subscribe to it
h, err := s.heightSub.Sub(ctx, height)
if err != errElapsedHeight {
if !errors.Is(err, errElapsedHeight) {
return h, err
}
// otherwise, the errElapsedHeight is thrown,
Expand All @@ -229,7 +228,7 @@ func (s *Store[H]) GetByHeight(ctx context.Context, height uint64) (H, error) {

hash, err := s.heightIndex.HashByHeight(ctx, height)
if err != nil {
if err == datastore.ErrNotFound {
if errors.Is(err, datastore.ErrNotFound) {
return zero, header.ErrNotFound
}

Expand Down

0 comments on commit 6069d6a

Please sign in to comment.