Skip to content

Commit

Permalink
fix compile errors from previous commit; fix bug in protomessage.Walk…
Browse files Browse the repository at this point in the history
… related to map types w/ non-message values; consolidate isMessageKind function into internal package
  • Loading branch information
jhump committed Feb 7, 2024
1 parent 4d6547d commit c9ae7ca
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 42 deletions.
6 changes: 6 additions & 0 deletions internal/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,9 @@ func PathKey(path protoreflect.SourcePath) string {
}
return string(b)
}

// IsMessageKind returns true if the given kind indicates a message type. Both
// messages and groups are message types.
func IsMessageKind(k protoreflect.Kind) bool {
return k == protoreflect.MessageKind || k == protoreflect.GroupKind
}
6 changes: 3 additions & 3 deletions protobuilder/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protoresolve"
"github.com/jhump/protoreflect/v2/protomessage"
)

// EnumRange is a range of enum numbers. The first element is the start
Expand Down Expand Up @@ -69,7 +69,7 @@ func FromEnum(ed protoreflect.EnumDescriptor) (*EnumBuilder, error) {
func fromEnum(ed protoreflect.EnumDescriptor, localEnums map[protoreflect.EnumDescriptor]*EnumBuilder) (*EnumBuilder, error) {
eb := NewEnum(ed.Name())
var err error
eb.Options, err = protoresolve.As[*descriptorpb.EnumOptions](ed.Options())
eb.Options, err = protomessage.As[*descriptorpb.EnumOptions](ed.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -389,7 +389,7 @@ func FromEnumValue(evd protoreflect.EnumValueDescriptor) (*EnumValueBuilder, err
func fromEnumValue(evd protoreflect.EnumValueDescriptor) (*EnumValueBuilder, error) {
evb := NewEnumValue(evd.Name())
var err error
evb.Options, err = protoresolve.As[*descriptorpb.EnumValueOptions](evd.Options())
evb.Options, err = protomessage.As[*descriptorpb.EnumValueOptions](evd.Options())
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions protobuilder/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protoresolve"
"github.com/jhump/protoreflect/v2/protomessage"
"github.com/jhump/protoreflect/v2/protowrap"
)

Expand Down Expand Up @@ -189,7 +189,7 @@ func fromField(fld protoreflect.FieldDescriptor) (*FieldBuilder, error) {
ft := fieldTypeFromDescriptor(fld)
flb := NewField(fld.Name(), ft)
var err error
flb.Options, err = protoresolve.As[*descriptorpb.FieldOptions](fld.Options())
flb.Options, err = protomessage.As[*descriptorpb.FieldOptions](fld.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -639,7 +639,7 @@ func FromOneof(ood protoreflect.OneofDescriptor) (*OneofBuilder, error) {
func fromOneof(ood protoreflect.OneofDescriptor) (*OneofBuilder, error) {
oob := NewOneof(ood.Name())
var err error
oob.Options, err = protoresolve.As[*descriptorpb.OneofOptions](ood.Options())
oob.Options, err = protomessage.As[*descriptorpb.OneofOptions](ood.Options())
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions protobuilder/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protomessage"
"github.com/jhump/protoreflect/v2/protoresolve"
)

Expand Down Expand Up @@ -88,7 +89,7 @@ func FromFile(fd protoreflect.FileDescriptor) (*FileBuilder, error) {
fb.Syntax = fd.Syntax()
fb.Package = fd.Package()
var err error
fb.Options, err = protoresolve.As[*descriptorpb.FileOptions](fd.Options())
fb.Options, err = protomessage.As[*descriptorpb.FileOptions](fd.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -791,7 +792,7 @@ func isExtendeeMessageSet(flb *FieldBuilder) bool {
if flb.localExtendee != nil {
return flb.localExtendee.Options.GetMessageSetWireFormat()
}
opts, _ := protoresolve.As[*descriptorpb.MessageOptions](flb.foreignExtendee.Options())
opts, _ := protomessage.As[*descriptorpb.MessageOptions](flb.foreignExtendee.Options())
return opts.GetMessageSetWireFormat()
}

Expand Down
6 changes: 3 additions & 3 deletions protobuilder/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protoresolve"
"github.com/jhump/protoreflect/v2/protomessage"
)

// FieldRange is a range of field numbers. The first element is the start
Expand Down Expand Up @@ -90,14 +90,14 @@ func fromMessage(md protoreflect.MessageDescriptor,

mb := NewMessage(md.Name())
var err error
mb.Options, err = protoresolve.As[*descriptorpb.MessageOptions](md.Options())
mb.Options, err = protomessage.As[*descriptorpb.MessageOptions](md.Options())
if err != nil {
return nil, err
}
ranges := md.ExtensionRanges()
mb.ExtensionRanges = make([]ExtensionRange, ranges.Len())
for i, length := 0, ranges.Len(); i < length; i++ {
opts, err := protoresolve.As[*descriptorpb.ExtensionRangeOptions](md.ExtensionRangeOptions(i))
opts, err := protomessage.As[*descriptorpb.ExtensionRangeOptions](md.ExtensionRangeOptions(i))
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions protobuilder/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protoresolve"
"github.com/jhump/protoreflect/v2/protomessage"
)

// ServiceBuilder is a builder used to construct a protoreflect.ServiceDescriptor.
Expand Down Expand Up @@ -57,7 +57,7 @@ func FromService(sd protoreflect.ServiceDescriptor) (*ServiceBuilder, error) {
func fromService(sd protoreflect.ServiceDescriptor) (*ServiceBuilder, error) {
sb := NewService(sd.Name())
var err error
sb.Options, err = protoresolve.As[*descriptorpb.ServiceOptions](sd.Options())
sb.Options, err = protomessage.As[*descriptorpb.ServiceOptions](sd.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -294,7 +294,7 @@ func fromMethod(mtd protoreflect.MethodDescriptor) (*MethodBuilder, error) {
resp := RpcTypeImportedMessage(mtd.Output(), mtd.IsStreamingServer())
mtb := NewMethod(mtd.Name(), req, resp)
var err error
mtb.Options, err = protoresolve.As[*descriptorpb.MethodOptions](mtd.Options())
mtb.Options, err = protomessage.As[*descriptorpb.MethodOptions](mtd.Options())
if err != nil {
return nil, err
}
Expand Down
30 changes: 15 additions & 15 deletions protomessage/walk.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package protomessage

import "google.golang.org/protobuf/reflect/protoreflect"
import (
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/jhump/protoreflect/v2/internal"
)

// Walk traverses the given root messages, iterating through its fields and
// through all values in maps and lists, calling the given action for all
Expand All @@ -22,15 +26,7 @@ func walk(root protoreflect.Message, path []any, action func(path []any, val pro
root.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
path = append(path, field.Number())
switch {
case field.IsMap() && isMessageKind(field.MapValue().Kind()):
mapVal := val.Map()
mapVal.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
path = append(path, key, protoreflect.FieldNumber(2) /* field 2 is the value in an entry */)
ok = walk(val.Message(), path, action)
path = path[:len(path)-2]
return ok
})
case field.IsList() && isMessageKind(field.Kind()):
case field.IsList() && internal.IsMessageKind(field.Kind()):
listVal := val.List()
for i, length := 0, listVal.Len(); i < length; i++ {
path = append(path, i)
Expand All @@ -40,15 +36,19 @@ func walk(root protoreflect.Message, path []any, action func(path []any, val pro
break
}
}
case isMessageKind(field.Kind()):
case field.IsMap() && internal.IsMessageKind(field.MapValue().Kind()):
mapVal := val.Map()
mapVal.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
path = append(path, key, protoreflect.FieldNumber(2) /* field 2 is the value in an entry */)
ok = walk(val.Message(), path, action)
path = path[:len(path)-2]
return ok
})
case !field.IsMap() && internal.IsMessageKind(field.Kind()):
ok = walk(val.Message(), path, action)
}
path = path[:len(path)-1] // pop field number
return ok
})
return ok
}

func isMessageKind(k protoreflect.Kind) bool {
return k == protoreflect.MessageKind || k == protoreflect.GroupKind
}
3 changes: 2 additions & 1 deletion protoprint/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"google.golang.org/protobuf/types/dynamicpb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/protomessage"
"github.com/jhump/protoreflect/v2/protoresolve"
"github.com/jhump/protoreflect/v2/sourcelocation"
)
Expand Down Expand Up @@ -893,7 +894,7 @@ func (p *Printer) printMessageBody(
}

func isMessageSet(msg protoreflect.MessageDescriptor) bool {
opts, _ := protoresolve.As[*descriptorpb.MessageOptions](msg.Options())
opts, _ := protomessage.As[*descriptorpb.MessageOptions](msg.Options())
return opts.GetMessageSetWireFormat()
}

Expand Down
6 changes: 3 additions & 3 deletions protoresolve/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/types/known/typepb"
"google.golang.org/protobuf/types/known/wrapperspb"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/internal/wrappers"
)

Expand Down Expand Up @@ -1349,14 +1350,13 @@ func base(name string) string {
}

func newMessageValueForField(msg protoreflect.Message, field protoreflect.FieldDescriptor) protoreflect.Message {
isMessageKind := field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind
switch {
case field.IsList() && isMessageKind:
case field.IsList() && internal.IsMessageKind(field.Kind()):
return msg.NewField(field).List().NewElement().Message()
case field.IsMap():
// For maps, create a dynamic message representing the map entry
return dynamicpb.NewMessage(field.Message())
case isMessageKind:
case internal.IsMessageKind(field.Kind()):
return msg.NewField(field).Message()
default:
switch field.Kind() {
Expand Down
6 changes: 4 additions & 2 deletions protoresolve/reparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package protoresolve
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/jhump/protoreflect/v2/internal"
)

// ReparseUnrecognized is a helper function for re-parsing unknown fields of a message,
Expand All @@ -20,7 +22,7 @@ func ReparseUnrecognized(msg proto.Message, resolver SerializationResolver) {

func reparseUnrecognized(msg protoreflect.Message, resolver SerializationResolver) {
msg.Range(func(fld protoreflect.FieldDescriptor, val protoreflect.Value) bool {
if fld.Kind() != protoreflect.MessageKind && fld.Kind() != protoreflect.GroupKind {
if !internal.IsMessageKind(fld.Kind()) {
return true
}
if fld.IsList() {
Expand All @@ -30,7 +32,7 @@ func reparseUnrecognized(msg protoreflect.Message, resolver SerializationResolve
}
} else if fld.IsMap() {
mapVal := fld.MapValue()
if mapVal.Kind() != protoreflect.MessageKind && mapVal.Kind() != protoreflect.GroupKind {
if !internal.IsMessageKind(mapVal.Kind()) {
return true
}
m := val.Map()
Expand Down
11 changes: 4 additions & 7 deletions protoresolve/reparse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"

"github.com/jhump/protoreflect/v2/internal"
"github.com/jhump/protoreflect/v2/internal/testdata"
"github.com/jhump/protoreflect/v2/protoresolve"
)
Expand Down Expand Up @@ -51,15 +52,15 @@ func hasUnrecognized(msg protoreflect.Message) bool {
var foundUnrecognized bool
msg.Range(func(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
switch {
case fd.IsList() && isMessageKind(fd.Kind()):
case fd.IsList() && internal.IsMessageKind(fd.Kind()):
l := val.List()
for i, length := 0, l.Len(); i < length; i++ {
if hasUnrecognized(l.Get(i).Message()) {
foundUnrecognized = true
return false
}
}
case fd.IsMap() && isMessageKind(fd.MapValue().Kind()):
case fd.IsMap() && internal.IsMessageKind(fd.MapValue().Kind()):
val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool {
if hasUnrecognized(val.Message()) {
foundUnrecognized = true
Expand All @@ -70,7 +71,7 @@ func hasUnrecognized(msg protoreflect.Message) bool {
if foundUnrecognized {
return false
}
case !fd.IsMap() && isMessageKind(fd.Kind()):
case !fd.IsMap() && internal.IsMessageKind(fd.Kind()):
if hasUnrecognized(val.Message()) {
foundUnrecognized = true
return false
Expand All @@ -80,7 +81,3 @@ func hasUnrecognized(msg protoreflect.Message) bool {
})
return foundUnrecognized
}

func isMessageKind(k protoreflect.Kind) bool {
return k == protoreflect.MessageKind || k == protoreflect.GroupKind
}

0 comments on commit c9ae7ca

Please sign in to comment.