Skip to content

Commit

Permalink
p2p: check protonil for all receives (#2346)
Browse files Browse the repository at this point in the history
Check all libp2p protos received over the wire for nil fields.

Also extend `protonil` to detect nil list eleements and nil map values.

category: misc
ticket: none
  • Loading branch information
corverroos committed Jun 22, 2023
1 parent 3234df5 commit 347a50a
Show file tree
Hide file tree
Showing 13 changed files with 351 additions and 174 deletions.
3 changes: 2 additions & 1 deletion app/peerinfo/adhoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ func DoOnce(ctx context.Context, tcpNode host.Host, peerID peer.ID) (*pbv1.PeerI
rtt = d
}

req := new(pbv1.PeerInfo) // TODO(corver): Populate request fields and make them required.
resp := new(pbv1.PeerInfo)
err := p2p.SendReceive(ctx, tcpNode, peerID, &pbv1.PeerInfo{}, resp, protocolID1,
err := p2p.SendReceive(ctx, tcpNode, peerID, req, resp, protocolID1,
p2p.WithSendReceiveRTT(rttCallback), p2p.WithDelimitedProtocol(protocolID2))
if err != nil {
return nil, 0, false, err
Expand Down
3 changes: 3 additions & 0 deletions app/peerinfo/peerinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ func (p *PeerInfo) sendOnce(ctx context.Context, now time.Time) {
p2p.WithSendReceiveRTT(rttCallback), p2p.WithDelimitedProtocol(protocolID2))
if err != nil {
return // Logging handled by send func.
} else if resp.SentAt == nil || resp.StartedAt == nil {
log.Error(ctx, "Invalid peerinfo response", err, z.Str("peer", p2p.PeerName(peerID)))
return
}

expectedSentAt := time.Now().Add(-rtt / 2)
Expand Down
33 changes: 18 additions & 15 deletions app/peerinfo/peerinfopb/v1/peerinfo.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions app/peerinfo/peerinfopb/v1/peerinfo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ option go_package = "github.com/obolnetwork/charon/app/peerinfo/peerinfopb/v1";
import "google/protobuf/timestamp.proto";

message PeerInfo {
string charon_version = 1;
bytes lock_hash = 2;
google.protobuf.Timestamp sent_at = 3;
string git_hash = 4;
google.protobuf.Timestamp started_at = 5;
string charon_version = 1;
bytes lock_hash = 2;
optional google.protobuf.Timestamp sent_at = 3;
string git_hash = 4;
optional google.protobuf.Timestamp started_at = 5;

// TODO(corver): Always populate timestamps when sending, then make them required after subsequent release.
}
56 changes: 54 additions & 2 deletions app/protonil/protonil.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ const maxFieldNumber = 64
// Note this only applies to "message" fields, not primitive scalars or "map" or "list" fields
// since their zero values are valid.
func Check(msg proto.Message) error {
if msg == nil {
return errors.New("nil protobuf message")
}

rMsg := msg.ProtoReflect()
if !rMsg.IsValid() {
return errors.New("nil protobuf message")
Expand All @@ -46,8 +50,44 @@ func Check(msg proto.Message) error {
}
checked++

if field.IsMap() || field.IsList() {
// Nil maps and lists are equivalent to empty maps and lists.
// Check the values of map fields.
if field.IsMap() {
var err error
rMsg.Get(field).Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool {
value, ok := valueToMsg(val)
if !ok {
// Not a message value type.
return false
}

err = Check(value.Interface())

return err == nil
})

if err != nil {
return errors.Wrap(err, "map value", z.Any("map", field.Name()))
}

continue
}

// Check elements of list fields.
if field.IsList() {
list := rMsg.Get(field).List()
for i := 0; i < list.Len(); i++ {
elem, ok := valueToMsg(list.Get(i))
if !ok {
// Not a message element type.
break
}

if err := Check(elem.Interface()); err != nil {
return errors.Wrap(err, "list element",
z.Any("list", field.Name()), z.Int("index", i))
}
}

continue
}

Expand Down Expand Up @@ -80,3 +120,15 @@ func Check(msg proto.Message) error {

return nil
}

// valueToMsg converts a protoreflect.Value to a protoreflect.Message if possible.
func valueToMsg(val protoreflect.Value) (protoreflect.Message, bool) {
iface := val.Interface()
if iface == nil {
return nil, false
}

elemMsg, ok := iface.(protoreflect.Message)

return elemMsg, ok
}
62 changes: 51 additions & 11 deletions app/protonil/protonil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ import (
func TestCheck(t *testing.T) {
tests := []struct {
name string
m1 *v1.M1
msg proto.Message
wantErr string
}{
{
name: "nil",
m1: nil,
msg: nil,
wantErr: "nil protobuf message",
},
{
name: "zero m1, nil m2",
m1: &v1.M1{},
name: "zero msg, nil m2",
msg: &v1.M1{},
wantErr: "nil proto field",
},
{
name: "all populated",
m1: &v1.M1{
Name: "m1",
msg: &v1.M1{
Name: "msg",
M2: &v1.M2{
Name: "m2",
M3: &v1.M3{Name: "m3"},
Expand All @@ -52,8 +52,8 @@ func TestCheck(t *testing.T) {
},
{
name: "optionals nil",
m1: &v1.M1{
Name: "m1",
msg: &v1.M1{
Name: "msg",
M2: &v1.M2{
Name: "m2",
M3: &v1.M3{Name: "m3"},
Expand All @@ -65,8 +65,8 @@ func TestCheck(t *testing.T) {
},
{
name: "nil m3 in optional m2",
m1: &v1.M1{
Name: "m1",
msg: &v1.M1{
Name: "msg",
M2: &v1.M2{
Name: "m2",
M3: &v1.M3{Name: "m3"},
Expand All @@ -78,10 +78,49 @@ func TestCheck(t *testing.T) {
},
wantErr: "inner message field: nil proto field",
},
{
name: "zero m4",
msg: &v1.M4{},
wantErr: "",
},
{
name: "m4 with non-empty containers",
msg: &v1.M4{
M3Map: map[string]*v1.M3{
"k0": {Name: "v0"},
"k1": {Name: "v1"},
},
M3List: []*v1.M3{
{Name: "elem0"},
{Name: "elem1"},
},
},
wantErr: "",
},
{
name: "m4 with nil map value",
msg: &v1.M4{
M3Map: map[string]*v1.M3{
"k0": nil,
"k1": {Name: "v1"},
},
},
wantErr: "map value: nil protobuf message",
},
{
name: "m4 with nil list element",
msg: &v1.M4{
M3List: []*v1.M3{
nil,
{Name: "elem1"},
},
},
wantErr: "list element: nil protobuf message",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := protonil.Check(test.m1)
err := protonil.Check(test.msg)
if test.wantErr != "" {
require.ErrorContains(t, err, test.wantErr)
} else {
Expand All @@ -96,6 +135,7 @@ func TestFuzz(t *testing.T) {
new(v1.M1),
new(v1.M2),
new(v1.M3),
new(v1.M4),
new(manifestpb.Cluster),
new(manifestpb.SignedMutation),
new(manifestpb.SignedMutationList),
Expand Down
Loading

0 comments on commit 347a50a

Please sign in to comment.