Skip to content

Commit

Permalink
Nextgen Proto Pythonic API: Struct/ListValue assignment and creation
Browse files Browse the repository at this point in the history
Python dict is now able to be assigned (by create and copy, not reference) and compared with the Protobuf Struct field.
Python list is now able to be assigned (by create and copy, not reference) and compared with the Protobuf ListValue field.

example usage:
  dictionary = {'key1': 5.0, 'key2': {'subkey': 11.0, 'k': False},}
  list_value = [6, 'seven', True, False, None, dictionary]
  msg = more_messages_pb2.WKTMessage(
      optional_struct=dictionary, optional_list_value=list_value
  )
  self.assertEqual(msg.optional_struct, dictionary)
  self.assertEqual(msg.optional_list_value, list_value)

PiperOrigin-RevId: 646099987
  • Loading branch information
anandolee authored and copybara-github committed Jun 24, 2024
1 parent 0302c4c commit e17821c
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 103 deletions.
6 changes: 6 additions & 0 deletions python/google/protobuf/internal/descriptor_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.protobuf.internal import testing_refleaks

from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import unittest_features_pb2
from google.protobuf import unittest_import_pb2
Expand Down Expand Up @@ -439,6 +440,7 @@ def testAddSerializedFile(self):
self.testFindMessageTypeByName()
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb)
file_json = self.pool.AddSerializedFile(
more_messages_pb2.DESCRIPTOR.serialized_pb)
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
Expand Down Expand Up @@ -550,6 +552,9 @@ def testComplexNesting(self):
timestamp_pb2.DESCRIPTOR.serialized_pb)
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
duration_pb2.DESCRIPTOR.serialized_pb)
struct_desc = descriptor_pb2.FileDescriptorProto.FromString(
struct_pb2.DESCRIPTOR.serialized_pb
)
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
more_messages_pb2.DESCRIPTOR.serialized_pb)
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
Expand All @@ -558,6 +563,7 @@ def testComplexNesting(self):
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
self.pool.Add(timestamp_desc)
self.pool.Add(duration_desc)
self.pool.Add(struct_desc)
self.pool.Add(more_messages_desc)
self.pool.Add(test1_desc)
self.pool.Add(test2_desc)
Expand Down
3 changes: 3 additions & 0 deletions python/google/protobuf/internal/more_messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ syntax = "proto2";
package google.protobuf.internal;

import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";

// A message where tag numbers are listed out of order, to allow us to test our
Expand Down Expand Up @@ -355,4 +356,6 @@ message ConflictJsonName {
message WKTMessage {
optional Timestamp optional_timestamp = 1;
optional Duration optional_duration = 2;
optional Struct optional_struct = 3;
optional ListValue optional_list_value = 4;
}
59 changes: 44 additions & 15 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
_ExtensionDict = extension_dict._ExtensionDict

class GeneratedProtocolMessageType(type):
Expand Down Expand Up @@ -515,37 +517,47 @@ def init(self, **kwargs):
# field=None is the same as no field at all.
continue
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
field_copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
if _IsMapField(field):
if _IsMessageMapField(field):
for key in field_value:
copy[key].MergeFrom(field_value[key])
field_copy[key].MergeFrom(field_value[key])
else:
copy.update(field_value)
field_copy.update(field_value)
else:
for val in field_value:
if isinstance(val, dict):
copy.add(**val)
field_copy.add(**val)
else:
copy.add().MergeFrom(val)
field_copy.add().MergeFrom(val)
else: # Scalar
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
field_value = [_GetIntegerEnumValue(field.enum_type, val)
for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
field_copy.extend(field_value)
self._fields[field] = field_copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
field_copy = field._default_constructor(self)
new_val = None
if isinstance(field_value, message_mod.Message):
new_val = field_value
elif isinstance(field_value, dict):
new_val = field.message_type._concrete_class(**field_value)
elif field.message_type.full_name == 'google.protobuf.Timestamp':
copy.FromDatetime(field_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
copy.FromTimedelta(field_value)
if field.message_type.full_name == _StructFullTypeName:
field_copy.Clear()
if len(field_value) == 1 and 'fields' in field_value:
try:
field_copy.update(field_value)
except:
# Fall back to init normal message field
field_copy.Clear()
new_val = field.message_type._concrete_class(**field_value)
else:
field_copy.update(field_value)
else:
new_val = field.message_type._concrete_class(**field_value)
elif hasattr(field_copy, '_internal_assign'):
field_copy._internal_assign(field_value)
else:
raise TypeError(
'Message field {0}.{1} must be initialized with a '
Expand All @@ -558,10 +570,10 @@ def init(self, **kwargs):

if new_val:
try:
copy.MergeFrom(new_val)
field_copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
self._fields[field] = field_copy
else:
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
field_value = _GetIntegerEnumValue(field.enum_type, field_value)
Expand Down Expand Up @@ -777,6 +789,14 @@ def setter(self, new_value):
elif field.message_type.full_name == 'google.protobuf.Duration':
getter(self)
self._fields[field].FromTimedelta(new_value)
elif field.message_type.full_name == _StructFullTypeName:
getter(self)
self._fields[field].Clear()
self._fields[field].update(new_value)
elif field.message_type.full_name == _ListValueFullTypeName:
getter(self)
self._fields[field].Clear()
self._fields[field].extend(new_value)
else:
raise AttributeError(
'Assignment not allowed to composite field '
Expand Down Expand Up @@ -978,6 +998,15 @@ def _InternalUnpackAny(msg):
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def __eq__(self, other):
if self.DESCRIPTOR.full_name == _ListValueFullTypeName and isinstance(
other, list
):
return self._internal_compare(other)
if self.DESCRIPTOR.full_name == _StructFullTypeName and isinstance(
other, dict
):
return self._internal_compare(other)

if (not isinstance(other, message_mod.Message) or
other.DESCRIPTOR != self.DESCRIPTOR):
return NotImplemented
Expand Down
40 changes: 40 additions & 0 deletions python/google/protobuf/internal/well_known_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ def FromDatetime(self, dt):
self.seconds = seconds
self.nanos = nanos

def _internal_assign(self, dt):
self.FromDatetime(dt)

def __add__(self, value) -> datetime.datetime:
if isinstance(value, Duration):
return self.ToDatetime() + value.ToTimedelta()
Expand Down Expand Up @@ -442,6 +445,9 @@ def FromTimedelta(self, td):
'object got {0}: {1}'.format(type(td).__name__, e)
) from e

def _internal_assign(self, td):
self.FromTimedelta(td)

def _NormalizeDuration(self, seconds, nanos):
"""Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
Expand Down Expand Up @@ -550,6 +556,24 @@ def __len__(self):
def __iter__(self):
return iter(self.fields)

def _internal_assign(self, dictionary):
self.Clear()
self.update(dictionary)

def _internal_compare(self, other):
size = len(self)
if size != len(other):
return False
for key, value in self.items():
if key not in other:
return False
if isinstance(other[key], (dict, list)):
if not value._internal_compare(other[key]):
return False
elif value != other[key]:
return False
return True

def keys(self): # pylint: disable=invalid-name
return self.fields.keys()

Expand Down Expand Up @@ -605,6 +629,22 @@ def __setitem__(self, index, value):
def __delitem__(self, key):
del self.values[key]

def _internal_assign(self, elem_seq):
self.Clear()
self.extend(elem_seq)

def _internal_compare(self, other):
size = len(self)
if size != len(other):
return False
for i in range(size):
if isinstance(other[i], (dict, list)):
if not self[i]._internal_compare(other[i]):
return False
elif self[i] != other[i]:
return False
return True

def items(self):
for i in range(len(self)):
yield self[i]
Expand Down
67 changes: 67 additions & 0 deletions python/google/protobuf/internal/well_known_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,73 @@ def testStructAssignment(self):
s2['x'] = s1['x']
self.assertEqual(s1['x'], s2['x'])

dictionary = {
'key1': 5.0,
'key2': 'abc',
'key3': {'subkey': 11.0, 'k': False},
}
msg = more_messages_pb2.WKTMessage()
msg.optional_struct = dictionary
self.assertEqual(msg.optional_struct, dictionary)

# Tests assign is not merge
dictionary2 = {
'key4': {'subkey': 11.0, 'k': True},
}
msg.optional_struct = dictionary2
self.assertEqual(msg.optional_struct, dictionary2)

# Tests assign empty
msg2 = more_messages_pb2.WKTMessage()
self.assertNotIn('optional_struct', msg2)
msg2.optional_struct = {}
self.assertIn('optional_struct', msg2)
self.assertEqual(msg2.optional_struct, {})

def testListValueAssignment(self):
list_value = [6, 'seven', True, False, None, {}]
msg = more_messages_pb2.WKTMessage()
msg.optional_list_value = list_value
self.assertEqual(msg.optional_list_value, list_value)

def testStructConstruction(self):
dictionary = {
'key1': 5.0,
'key2': 'abc',
'key3': {'subkey': 11.0, 'k': False},
}
list_value = [6, 'seven', True, False, None, dictionary]
msg = more_messages_pb2.WKTMessage(
optional_struct=dictionary, optional_list_value=list_value
)
self.assertEqual(len(msg.optional_struct), len(dictionary))
self.assertEqual(msg.optional_struct, dictionary)
self.assertEqual(len(msg.optional_list_value), len(list_value))
self.assertEqual(msg.optional_list_value, list_value)

msg2 = more_messages_pb2.WKTMessage(
optional_struct={}, optional_list_value=[]
)
self.assertIn('optional_struct', msg2)
self.assertIn('optional_list_value', msg2)
self.assertEqual(msg2.optional_struct, {})
self.assertEqual(msg2.optional_list_value, [])

def testSpecialStructConstruct(self):
dictionary = {'key1': 6.0}
msg = more_messages_pb2.WKTMessage(optional_struct=dictionary)
self.assertEqual(msg.optional_struct, dictionary)

dictionary2 = {'fields': 7.0}
msg2 = more_messages_pb2.WKTMessage(optional_struct=dictionary2)
self.assertEqual(msg2.optional_struct, dictionary2)

# Construct Struct as normal message
value_msg = struct_pb2.Value(number_value=5.0)
dictionary3 = {'fields': {'key1': value_msg}}
msg3 = more_messages_pb2.WKTMessage(optional_struct=dictionary3)
self.assertEqual(msg3.optional_struct, {'key1': 5.0})

def testMergeFrom(self):
struct = struct_pb2.Struct()
struct_class = struct.__class__
Expand Down
Loading

0 comments on commit e17821c

Please sign in to comment.