Index: third_party/protobuf/python/google/protobuf/text_format.py |
diff --git a/third_party/protobuf/python/google/protobuf/text_format.py b/third_party/protobuf/python/google/protobuf/text_format.py |
index 8d256076c28ad4c3245aaccbbbd8284aeb4bc961..6f1e3c8b725ea2f90b27ad62714de4d4df24a94c 100755 |
--- a/third_party/protobuf/python/google/protobuf/text_format.py |
+++ b/third_party/protobuf/python/google/protobuf/text_format.py |
@@ -99,7 +99,7 @@ class TextWriter(object): |
def MessageToString(message, as_utf8=False, as_one_line=False, |
pointy_brackets=False, use_index_order=False, |
- float_format=None): |
+ float_format=None, use_field_number=False): |
"""Convert protobuf message to text format. |
Floating point values can be formatted compactly with 15 digits of |
@@ -118,15 +118,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False, |
field number order. |
float_format: If set, use this to specify floating point number formatting |
(per the "Format Specification Mini-Language"); otherwise, str() is used. |
+ use_field_number: If True, print field numbers instead of names. |
Returns: |
A string of the text formatted protocol buffer message. |
""" |
out = TextWriter(as_utf8) |
- PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
+ printer = _Printer(out, 0, as_utf8, as_one_line, |
+ pointy_brackets, use_index_order, float_format, |
+ use_field_number) |
+ printer.PrintMessage(message) |
result = out.getvalue() |
out.close() |
if as_one_line: |
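As a usage note (not part of the patch): the new flag only changes how field names are rendered. A minimal sketch, assuming a hypothetical generated module `test_pb2` whose `TestMessage` declares `optional int32 foo = 1;`:

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module, not part of this change

msg = test_pb2.TestMessage(foo=42)
print(text_format.MessageToString(msg))                         # foo: 42
print(text_format.MessageToString(msg, use_field_number=True))  # 1: 42
```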
@@ -142,133 +143,187 @@ def _IsMapEntry(field): |
def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, |
pointy_brackets=False, use_index_order=False, |
- float_format=None): |
- fields = message.ListFields() |
- if use_index_order: |
- fields.sort(key=lambda x: x[0].index) |
- for field, value in fields: |
- if _IsMapEntry(field): |
- for key in sorted(value): |
- # This is slow for maps with submessage entires because it copies the |
- # entire tree. Unfortunately this would take significant refactoring |
- # of this file to work around. |
- # |
- # TODO(haberman): refactor and optimize if this becomes an issue. |
- entry_submsg = field.message_type._concrete_class( |
- key=key, value=value[key]) |
- PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, float_format=float_format) |
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: |
- for element in value: |
- PrintField(field, element, out, indent, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
- else: |
- PrintField(field, value, out, indent, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
+ float_format=None, use_field_number=False): |
+ printer = _Printer(out, indent, as_utf8, as_one_line, |
+ pointy_brackets, use_index_order, float_format, |
+ use_field_number) |
+ printer.PrintMessage(message) |
def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, |
pointy_brackets=False, use_index_order=False, float_format=None): |
- """Print a single field name/value pair. For repeated fields, the value |
- should be a single element. |
- """ |
- |
- out.write(' ' * indent) |
- if field.is_extension: |
- out.write('[') |
- if (field.containing_type.GetOptions().message_set_wire_format and |
- field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and |
- field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): |
- out.write(field.message_type.full_name) |
- else: |
- out.write(field.full_name) |
- out.write(']') |
- elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: |
- # For groups, use the capitalized name. |
- out.write(field.message_type.name) |
- else: |
- out.write(field.name) |
- |
- if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
- # The colon is optional in this case, but our cross-language golden files |
- # don't include it. |
- out.write(': ') |
- |
- PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
- if as_one_line: |
- out.write(' ') |
- else: |
- out.write('\n') |
+ """Print a single field name/value pair.""" |
+ printer = _Printer(out, indent, as_utf8, as_one_line, |
+ pointy_brackets, use_index_order, float_format) |
+ printer.PrintField(field, value) |
def PrintFieldValue(field, value, out, indent=0, as_utf8=False, |
as_one_line=False, pointy_brackets=False, |
use_index_order=False, |
float_format=None): |
- """Print a single field value (not including name). For repeated fields, |
- the value should be a single element.""" |
+ """Print a single field value (not including name).""" |
+ printer = _Printer(out, indent, as_utf8, as_one_line, |
+ pointy_brackets, use_index_order, float_format) |
+ printer.PrintFieldValue(field, value) |
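The three module-level helpers above are now thin wrappers around `_Printer`. A sketch of the delegation, reusing the hypothetical `test_pb2.TestMessage` from earlier; `TextWriter` is the module's own output buffer, as used by `MessageToString`:

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module

msg = test_pb2.TestMessage(foo=42)
out_a = text_format.TextWriter(False)
out_b = text_format.TextWriter(False)

# These write identical text: the wrapper just builds a _Printer and delegates.
text_format.PrintMessage(msg, out_a, indent=2, as_one_line=True)
text_format._Printer(out_b, 2, as_one_line=True).PrintMessage(msg)
assert out_a.getvalue() == out_b.getvalue()
```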
- if pointy_brackets: |
- openb = '<' |
- closeb = '>' |
- else: |
- openb = '{' |
- closeb = '}' |
- |
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
- if as_one_line: |
- out.write(' %s ' % openb) |
- PrintMessage(value, out, indent, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
- out.write(closeb) |
- else: |
- out.write(' %s\n' % openb) |
- PrintMessage(value, out, indent + 2, as_utf8, as_one_line, |
- pointy_brackets=pointy_brackets, |
- use_index_order=use_index_order, |
- float_format=float_format) |
- out.write(' ' * indent + closeb) |
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: |
- enum_value = field.enum_type.values_by_number.get(value, None) |
- if enum_value is not None: |
- out.write(enum_value.name) |
+ |
+class _Printer(object): |
+ """Text format printer for protocol message.""" |
+ |
+ def __init__(self, out, indent=0, as_utf8=False, as_one_line=False, |
+ pointy_brackets=False, use_index_order=False, float_format=None, |
+ use_field_number=False): |
+ """Initialize the Printer. |
+ |
+ Floating point values can be formatted compactly with 15 digits of |
+ precision (which is the most that IEEE 754 "double" can guarantee) |
+ using float_format='.15g'. To ensure that converting to text and back to a |
+ proto will result in an identical value, float_format='.17g' should be used. |
+ |
+ Args: |
+ out: To record the text format result. |
+ indent: The indent level for pretty print. |
+ as_utf8: Produce text output in UTF8 format. |
+ as_one_line: Don't introduce newlines between fields. |
+ pointy_brackets: If True, use angle brackets instead of curly braces for |
+ nesting. |
+ use_index_order: If True, print fields of a proto message using the order |
+ defined in source code instead of the field number. By default, use the |
+ field number order. |
+ float_format: If set, use this to specify floating point number formatting |
+ (per the "Format Specification Mini-Language"); otherwise, str() is |
+ used. |
+ use_field_number: If True, print field numbers instead of names. |
+ """ |
+ self.out = out |
+ self.indent = indent |
+ self.as_utf8 = as_utf8 |
+ self.as_one_line = as_one_line |
+ self.pointy_brackets = pointy_brackets |
+ self.use_index_order = use_index_order |
+ self.float_format = float_format |
+ self.use_field_number = use_field_number |
+ |
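To make a few of these switches concrete (illustrative only, not from this patch):

```python
from google.protobuf import text_format

# `msg` is a hypothetical message whose only populated field is a submessage
# `inner` containing `x: 1`; construction is omitted.
text_format.MessageToString(msg)                        # 'inner {\n  x: 1\n}\n'
text_format.MessageToString(msg, pointy_brackets=True)  # 'inner <\n  x: 1\n>\n'
text_format.MessageToString(msg, as_one_line=True)      # 'inner { x: 1 }'
```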
+ def PrintMessage(self, message): |
+ """Convert protobuf message to text format. |
+ |
+ Args: |
+ message: The protocol buffers message. |
+ """ |
+ fields = message.ListFields() |
+ if self.use_index_order: |
+ fields.sort(key=lambda x: x[0].index) |
+ for field, value in fields: |
+ if _IsMapEntry(field): |
+ for key in sorted(value): |
+          # This is slow for maps with submessage entries because it copies the |

+ # entire tree. Unfortunately this would take significant refactoring |
+ # of this file to work around. |
+ # |
+ # TODO(haberman): refactor and optimize if this becomes an issue. |
+ entry_submsg = field.message_type._concrete_class( |
+ key=key, value=value[key]) |
+ self.PrintField(field, entry_submsg) |
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: |
+ for element in value: |
+ self.PrintField(field, element) |
+ else: |
+ self.PrintField(field, value) |
+ |
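A sketch of the map handling above (illustrative; no such field exists in this patch): for a hypothetical `map<string, int32> counts = 1;` holding `{'b': 2, 'a': 1}`, keys are sorted and each one is printed as a synthetic key/value entry message.

```python
# Expected text output for the hypothetical map field (comment sketch only):
#
#   counts {
#     key: "a"
#     value: 1
#   }
#   counts {
#     key: "b"
#     value: 2
#   }
```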
+ def PrintField(self, field, value): |
+ """Print a single field name/value pair.""" |
+ out = self.out |
+ out.write(' ' * self.indent) |
+ if self.use_field_number: |
+ out.write(str(field.number)) |
else: |
- out.write(str(value)) |
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: |
- out.write('\"') |
- if isinstance(value, six.text_type): |
- out_value = value.encode('utf-8') |
+ if field.is_extension: |
+ out.write('[') |
+ if (field.containing_type.GetOptions().message_set_wire_format and |
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and |
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): |
+ out.write(field.message_type.full_name) |
+ else: |
+ out.write(field.full_name) |
+ out.write(']') |
+ elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: |
+ # For groups, use the capitalized name. |
+ out.write(field.message_type.name) |
+ else: |
+ out.write(field.name) |
+ |
+ if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
+ # The colon is optional in this case, but our cross-language golden files |
+ # don't include it. |
+ out.write(': ') |
+ |
+ self.PrintFieldValue(field, value) |
+ if self.as_one_line: |
+ out.write(' ') |
else: |
- out_value = value |
- if field.type == descriptor.FieldDescriptor.TYPE_BYTES: |
- # We need to escape non-UTF8 chars in TYPE_BYTES field. |
- out_as_utf8 = False |
+ out.write('\n') |
+ |
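For comparison, an illustrative sketch of what PrintField emits with and without the new flag, assuming a hypothetical message with `optional string name = 1;` set to `"x"` and a registered extension `optional int32 pkg.ext = 100;` set to `5`:

```python
from google.protobuf import text_format

# `msg` is the hypothetical message described above; construction is omitted.
text_format.MessageToString(msg)                          # 'name: "x"\n[pkg.ext]: 5\n'
text_format.MessageToString(msg, use_field_number=True)   # '1: "x"\n100: 5\n'
```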
+ def PrintFieldValue(self, field, value): |
+ """Print a single field value (not including name). |
+ |
+ For repeated fields, the value should be a single element. |
+ |
+ Args: |
+ field: The descriptor of the field to be printed. |
+ value: The value of the field. |
+ """ |
+ out = self.out |
+ if self.pointy_brackets: |
+ openb = '<' |
+ closeb = '>' |
else: |
- out_as_utf8 = as_utf8 |
- out.write(text_encoding.CEscape(out_value, out_as_utf8)) |
- out.write('\"') |
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: |
- if value: |
- out.write('true') |
+ openb = '{' |
+ closeb = '}' |
+ |
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
+ if self.as_one_line: |
+ out.write(' %s ' % openb) |
+ self.PrintMessage(value) |
+ out.write(closeb) |
+ else: |
+ out.write(' %s\n' % openb) |
+ self.indent += 2 |
+ self.PrintMessage(value) |
+ self.indent -= 2 |
+ out.write(' ' * self.indent + closeb) |
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: |
+ enum_value = field.enum_type.values_by_number.get(value, None) |
+ if enum_value is not None: |
+ out.write(enum_value.name) |
+ else: |
+ out.write(str(value)) |
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: |
+ out.write('\"') |
+ if isinstance(value, six.text_type): |
+ out_value = value.encode('utf-8') |
+ else: |
+ out_value = value |
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES: |
+ # We need to escape non-UTF8 chars in TYPE_BYTES field. |
+ out_as_utf8 = False |
+ else: |
+ out_as_utf8 = self.as_utf8 |
+ out.write(text_encoding.CEscape(out_value, out_as_utf8)) |
+ out.write('\"') |
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: |
+ if value: |
+ out.write('true') |
+ else: |
+ out.write('false') |
+ elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None: |
+ out.write('{1:{0}}'.format(self.float_format, value)) |
else: |
- out.write('false') |
- elif field.cpp_type in _FLOAT_TYPES and float_format is not None: |
- out.write('{1:{0}}'.format(float_format, value)) |
- else: |
- out.write(str(value)) |
+ out.write(str(value)) |
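The float branch defers to Python's format() mini-language, which can be checked with plain Python; this is why the class docstring recommends '.15g' for compact output and '.17g' for exact round-tripping (outputs shown for Python 3):

```python
value = 0.1 + 0.2
print('{1:{0}}'.format('.15g', value))  # 0.3
print('{1:{0}}'.format('.17g', value))  # 0.30000000000000004
print(str(value))                       # 0.30000000000000004 (default when float_format is None)
```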
-def Parse(text, message, allow_unknown_extension=False): |
+def Parse(text, message, |
+ allow_unknown_extension=False, allow_field_number=False): |
"""Parses an text representation of a protocol message into a message. |
Args: |
@@ -276,6 +331,7 @@ def Parse(text, message, allow_unknown_extension=False): |
message: A protocol buffer message to merge into. |
allow_unknown_extension: if True, skip over missing extensions and keep |
parsing |
+ allow_field_number: if True, both field number and field name are allowed. |
Returns: |
The same message passed as argument. |
@@ -285,10 +341,12 @@ def Parse(text, message, allow_unknown_extension=False): |
""" |
if not isinstance(text, str): |
text = text.decode('utf-8') |
- return ParseLines(text.split('\n'), message, allow_unknown_extension) |
+ return ParseLines(text.split('\n'), message, allow_unknown_extension, |
+ allow_field_number) |
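A hedged usage sketch of the new parsing flag, again assuming the hypothetical `test_pb2.TestMessage` with `optional int32 foo = 1;`:

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module

msg = test_pb2.TestMessage()
text_format.Parse('1: 42', msg, allow_field_number=True)  # field given by number
assert msg.foo == 42
# Without allow_field_number=True, the leading '1' is rejected with a ParseError.
```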
-def Merge(text, message, allow_unknown_extension=False): |
+def Merge(text, message, allow_unknown_extension=False, |
+ allow_field_number=False): |
"""Parses an text representation of a protocol message into a message. |
Like Parse(), but allows repeated values for a non-repeated field, and uses |
@@ -299,6 +357,7 @@ def Merge(text, message, allow_unknown_extension=False): |
message: A protocol buffer message to merge into. |
allow_unknown_extension: if True, skip over missing extensions and keep |
parsing |
+ allow_field_number: if True, both field number and field name are allowed. |
Returns: |
The same message passed as argument. |
@@ -306,10 +365,12 @@ def Merge(text, message, allow_unknown_extension=False): |
Raises: |
ParseError: On text parsing problems. |
""" |
- return MergeLines(text.split('\n'), message, allow_unknown_extension) |
+ return MergeLines(text.split('\n'), message, allow_unknown_extension, |
+ allow_field_number) |
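The Parse/Merge distinction in practice (same hypothetical TestMessage; proto2 semantics, so field presence is tracked):

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module

msg = test_pb2.TestMessage()
text_format.Merge('foo: 1 foo: 2', msg)  # Merge tolerates the repeat; last value wins
assert msg.foo == 2

try:
  text_format.Parse('foo: 1 foo: 2', test_pb2.TestMessage())
except text_format.ParseError:
  pass  # Parse rejects repeated values for a non-repeated proto2 field
```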
-def ParseLines(lines, message, allow_unknown_extension=False): |
+def ParseLines(lines, message, allow_unknown_extension=False, |
+ allow_field_number=False): |
"""Parses an text representation of a protocol message into a message. |
Args: |
@@ -317,6 +378,7 @@ def ParseLines(lines, message, allow_unknown_extension=False): |
message: A protocol buffer message to merge into. |
allow_unknown_extension: if True, skip over missing extensions and keep |
parsing |
+ allow_field_number: if True, both field number and field name are allowed. |
Returns: |
The same message passed as argument. |
@@ -324,11 +386,12 @@ def ParseLines(lines, message, allow_unknown_extension=False): |
Raises: |
ParseError: On text parsing problems. |
""" |
- _ParseOrMerge(lines, message, False, allow_unknown_extension) |
- return message |
+ parser = _Parser(allow_unknown_extension, allow_field_number) |
+ return parser.ParseLines(lines, message) |
-def MergeLines(lines, message, allow_unknown_extension=False): |
+def MergeLines(lines, message, allow_unknown_extension=False, |
+ allow_field_number=False): |
"""Parses an text representation of a protocol message into a message. |
Args: |
@@ -336,6 +399,7 @@ def MergeLines(lines, message, allow_unknown_extension=False): |
message: A protocol buffer message to merge into. |
allow_unknown_extension: if True, skip over missing extensions and keep |
parsing |
+ allow_field_number: if True, both field number and field name are allowed. |
Returns: |
The same message passed as argument. |
@@ -343,108 +407,174 @@ def MergeLines(lines, message, allow_unknown_extension=False): |
Raises: |
ParseError: On text parsing problems. |
""" |
- _ParseOrMerge(lines, message, True, allow_unknown_extension) |
- return message |
+ parser = _Parser(allow_unknown_extension, allow_field_number) |
+ return parser.MergeLines(lines, message) |
-def _ParseOrMerge(lines, |
- message, |
- allow_multiple_scalars, |
- allow_unknown_extension=False): |
- """Converts an text representation of a protocol message into a message. |
+class _Parser(object): |
+ """Text format parser for protocol message.""" |
- Args: |
- lines: Lines of a message's text representation. |
- message: A protocol buffer message to merge into. |
- allow_multiple_scalars: Determines if repeated values for a non-repeated |
- field are permitted, e.g., the string "foo: 1 foo: 2" for a |
- required/optional field named "foo". |
- allow_unknown_extension: if True, skip over missing extensions and keep |
- parsing |
+ def __init__(self, allow_unknown_extension=False, allow_field_number=False): |
+ self.allow_unknown_extension = allow_unknown_extension |
+ self.allow_field_number = allow_field_number |
- Raises: |
- ParseError: On text parsing problems. |
- """ |
- tokenizer = _Tokenizer(lines) |
- while not tokenizer.AtEnd(): |
- _MergeField(tokenizer, message, allow_multiple_scalars, |
- allow_unknown_extension) |
+ def ParseFromString(self, text, message): |
+    """Parses a text representation of a protocol message into a message.""" |
+ if not isinstance(text, str): |
+ text = text.decode('utf-8') |
+ return self.ParseLines(text.split('\n'), message) |
+ def ParseLines(self, lines, message): |
+    """Parses a text representation of a protocol message into a message.""" |
+ self._allow_multiple_scalars = False |
+ self._ParseOrMerge(lines, message) |
+ return message |
-def _MergeField(tokenizer, |
- message, |
- allow_multiple_scalars, |
- allow_unknown_extension=False): |
- """Merges a single protocol message field into a message. |
+ def MergeFromString(self, text, message): |
+    """Merges a text representation of a protocol message into a message.""" |
+    return self.MergeLines(text.split('\n'), message) |
- Args: |
- tokenizer: A tokenizer to parse the field name and values. |
- message: A protocol message to record the data. |
- allow_multiple_scalars: Determines if repeated values for a non-repeated |
- field are permitted, e.g., the string "foo: 1 foo: 2" for a |
- required/optional field named "foo". |
- allow_unknown_extension: if True, skip over missing extensions and keep |
- parsing |
+ def MergeLines(self, lines, message): |
+    """Merges a text representation of a protocol message into a message.""" |
+ self._allow_multiple_scalars = True |
+ self._ParseOrMerge(lines, message) |
+ return message |
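The four public entry points differ only in which strictness flag they set before handing off to _ParseOrMerge; a sketch (`lines` and `msg` hypothetical):

```python
from google.protobuf import text_format

parser = text_format._Parser(allow_unknown_extension=False,
                             allow_field_number=True)
parser.ParseLines(lines, msg)  # strict: duplicate non-repeated scalars raise ParseError
parser.MergeLines(lines, msg)  # lenient: the last value wins
```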
- Raises: |
- ParseError: In case of text parsing problems. |
- """ |
- message_descriptor = message.DESCRIPTOR |
- if (hasattr(message_descriptor, 'syntax') and |
- message_descriptor.syntax == 'proto3'): |
- # Proto3 doesn't represent presence so we can't test if multiple |
- # scalars have occurred. We have to allow them. |
- allow_multiple_scalars = True |
- if tokenizer.TryConsume('['): |
- name = [tokenizer.ConsumeIdentifier()] |
- while tokenizer.TryConsume('.'): |
- name.append(tokenizer.ConsumeIdentifier()) |
- name = '.'.join(name) |
- |
- if not message_descriptor.is_extendable: |
- raise tokenizer.ParseErrorPreviousToken( |
- 'Message type "%s" does not have extensions.' % |
- message_descriptor.full_name) |
- # pylint: disable=protected-access |
- field = message.Extensions._FindExtensionByName(name) |
- # pylint: enable=protected-access |
- if not field: |
- if allow_unknown_extension: |
- field = None |
+ def _ParseOrMerge(self, lines, message): |
+    """Converts a text representation of a protocol message into a message. |
+ |
+ Args: |
+ lines: Lines of a message's text representation. |
+ message: A protocol buffer message to merge into. |
+ |
+ Raises: |
+ ParseError: On text parsing problems. |
+ """ |
+ tokenizer = _Tokenizer(lines) |
+ while not tokenizer.AtEnd(): |
+ self._MergeField(tokenizer, message) |
+ |
+ def _MergeField(self, tokenizer, message): |
+ """Merges a single protocol message field into a message. |
+ |
+ Args: |
+ tokenizer: A tokenizer to parse the field name and values. |
+ message: A protocol message to record the data. |
+ |
+ Raises: |
+ ParseError: In case of text parsing problems. |
+ """ |
+ message_descriptor = message.DESCRIPTOR |
+ if (hasattr(message_descriptor, 'syntax') and |
+ message_descriptor.syntax == 'proto3'): |
+ # Proto3 doesn't represent presence so we can't test if multiple |
+ # scalars have occurred. We have to allow them. |
+ self._allow_multiple_scalars = True |
+ if tokenizer.TryConsume('['): |
+ name = [tokenizer.ConsumeIdentifier()] |
+ while tokenizer.TryConsume('.'): |
+ name.append(tokenizer.ConsumeIdentifier()) |
+ name = '.'.join(name) |
+ |
+ if not message_descriptor.is_extendable: |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Message type "%s" does not have extensions.' % |
+ message_descriptor.full_name) |
+ # pylint: disable=protected-access |
+ field = message.Extensions._FindExtensionByName(name) |
+ # pylint: enable=protected-access |
+ if not field: |
+ if self.allow_unknown_extension: |
+ field = None |
+ else: |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Extension "%s" not registered.' % name) |
+ elif message_descriptor != field.containing_type: |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Extension "%s" does not extend message type "%s".' % ( |
+ name, message_descriptor.full_name)) |
+ |
+ tokenizer.Consume(']') |
+ |
+ else: |
+ name = tokenizer.ConsumeIdentifier() |
+ if self.allow_field_number and name.isdigit(): |
+ number = ParseInteger(name, True, True) |
+ field = message_descriptor.fields_by_number.get(number, None) |
+ if not field and message_descriptor.is_extendable: |
+ field = message.Extensions._FindExtensionByNumber(number) |
else: |
+ field = message_descriptor.fields_by_name.get(name, None) |
+ |
+ # Group names are expected to be capitalized as they appear in the |
+ # .proto file, which actually matches their type names, not their field |
+ # names. |
+ if not field: |
+ field = message_descriptor.fields_by_name.get(name.lower(), None) |
+ if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: |
+ field = None |
+ |
+ if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and |
+ field.message_type.name != name): |
+ field = None |
+ |
+ if not field: |
raise tokenizer.ParseErrorPreviousToken( |
- 'Extension "%s" not registered.' % name) |
- elif message_descriptor != field.containing_type: |
- raise tokenizer.ParseErrorPreviousToken( |
- 'Extension "%s" does not extend message type "%s".' % ( |
- name, message_descriptor.full_name)) |
+ 'Message type "%s" has no field named "%s".' % ( |
+ message_descriptor.full_name, name)) |
+ |
+ if field: |
+ if not self._allow_multiple_scalars and field.containing_oneof: |
+ # Check if there's a different field set in this oneof. |
+ # Note that we ignore the case if the same field was set before, and we |
+ # apply _allow_multiple_scalars to non-scalar fields as well. |
+ which_oneof = message.WhichOneof(field.containing_oneof.name) |
+ if which_oneof is not None and which_oneof != field.name: |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Field "%s" is specified along with field "%s", another member ' |
+ 'of oneof "%s" for message type "%s".' % ( |
+ field.name, which_oneof, field.containing_oneof.name, |
+ message_descriptor.full_name)) |
+ |
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
+ tokenizer.TryConsume(':') |
+ merger = self._MergeMessageField |
+ else: |
+ tokenizer.Consume(':') |
+ merger = self._MergeScalarField |
- tokenizer.Consume(']') |
+ if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED |
+ and tokenizer.TryConsume('[')): |
+ # Short repeated format, e.g. "foo: [1, 2, 3]" |
+ while True: |
+ merger(tokenizer, message, field) |
+ if tokenizer.TryConsume(']'): break |
+ tokenizer.Consume(',') |
- else: |
- name = tokenizer.ConsumeIdentifier() |
- field = message_descriptor.fields_by_name.get(name, None) |
- |
- # Group names are expected to be capitalized as they appear in the |
- # .proto file, which actually matches their type names, not their field |
- # names. |
- if not field: |
- field = message_descriptor.fields_by_name.get(name.lower(), None) |
- if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: |
- field = None |
- |
- if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and |
- field.message_type.name != name): |
- field = None |
- |
- if not field: |
- raise tokenizer.ParseErrorPreviousToken( |
- 'Message type "%s" has no field named "%s".' % ( |
- message_descriptor.full_name, name)) |
- |
- if field and field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: |
+ else: |
+ merger(tokenizer, message, field) |
+ |
+ else: # Proto field is unknown. |
+ assert self.allow_unknown_extension |
+ _SkipFieldContents(tokenizer) |
+ |
+ # For historical reasons, fields may optionally be separated by commas or |
+ # semicolons. |
+ if not tokenizer.TryConsume(','): |
+ tokenizer.TryConsume(';') |
+ |
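The short repeated format handled above, as a hedged example (hypothetical `repeated int32 values = 1;` on `test_pb2.TestMessage`):

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module

msg = test_pb2.TestMessage()
text_format.Parse('values: [1, 2, 3]', msg)    # short repeated format
text_format.Merge('values: 4 values: 5', msg)  # long form still accepted
assert list(msg.values) == [1, 2, 3, 4, 5]
```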
+ def _MergeMessageField(self, tokenizer, message, field): |
+    """Merges a single message field into a message. |
+ |
+ Args: |
+ tokenizer: A tokenizer to parse the field value. |
+ message: The message of which field is a member. |
+ field: The descriptor of the field to be merged. |
+ |
+ Raises: |
+ ParseError: In case of text parsing problems. |
+ """ |
is_map_entry = _IsMapEntry(field) |
- tokenizer.TryConsume(':') |
if tokenizer.TryConsume('<'): |
end_token = '>' |
@@ -456,6 +586,7 @@ def _MergeField(tokenizer, |
if field.is_extension: |
sub_message = message.Extensions[field].add() |
elif is_map_entry: |
+ # pylint: disable=protected-access |
sub_message = field.message_type._concrete_class() |
else: |
sub_message = getattr(message, field.name).add() |
@@ -468,9 +599,8 @@ def _MergeField(tokenizer, |
while not tokenizer.TryConsume(end_token): |
if tokenizer.AtEnd(): |
- raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) |
- _MergeField(tokenizer, sub_message, allow_multiple_scalars, |
- allow_unknown_extension) |
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) |
+ self._MergeField(tokenizer, sub_message) |
if is_map_entry: |
value_cpptype = field.message_type.fields_by_name['value'].cpp_type |
@@ -479,26 +609,70 @@ def _MergeField(tokenizer, |
value.MergeFrom(sub_message.value) |
else: |
getattr(message, field.name)[sub_message.key] = sub_message.value |
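Message-typed fields accept either delimiter, and the colon before the opening bracket is optional; a sketch assuming a hypothetical `optional Nested child = 1;` where `Nested` has `optional string name = 1;`:

```python
from google.protobuf import text_format
import test_pb2  # hypothetical generated module

msg = test_pb2.TestMessage()
text_format.Merge('child { name: "a" }', msg)
text_format.Merge('child: < name: "b" >', msg)  # merged into the same submessage
assert msg.child.name == 'b'
```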
- elif field: |
- tokenizer.Consume(':') |
- if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and |
- tokenizer.TryConsume('[')): |
- # Short repeated format, e.g. "foo: [1, 2, 3]" |
- while True: |
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) |
- if tokenizer.TryConsume(']'): |
- break |
- tokenizer.Consume(',') |
+ |
+ def _MergeScalarField(self, tokenizer, message, field): |
+ """Merges a single scalar field into a message. |
+ |
+ Args: |
+ tokenizer: A tokenizer to parse the field value. |
+ message: A protocol message to record the data. |
+ field: The descriptor of the field to be merged. |
+ |
+ Raises: |
+ ParseError: In case of text parsing problems. |
+ RuntimeError: On runtime errors. |
+ """ |
+ _ = self.allow_unknown_extension |
+ value = None |
+ |
+ if field.type in (descriptor.FieldDescriptor.TYPE_INT32, |
+ descriptor.FieldDescriptor.TYPE_SINT32, |
+ descriptor.FieldDescriptor.TYPE_SFIXED32): |
+ value = tokenizer.ConsumeInt32() |
+ elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, |
+ descriptor.FieldDescriptor.TYPE_SINT64, |
+ descriptor.FieldDescriptor.TYPE_SFIXED64): |
+ value = tokenizer.ConsumeInt64() |
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, |
+ descriptor.FieldDescriptor.TYPE_FIXED32): |
+ value = tokenizer.ConsumeUint32() |
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, |
+ descriptor.FieldDescriptor.TYPE_FIXED64): |
+ value = tokenizer.ConsumeUint64() |
+ elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, |
+ descriptor.FieldDescriptor.TYPE_DOUBLE): |
+ value = tokenizer.ConsumeFloat() |
+ elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: |
+ value = tokenizer.ConsumeBool() |
+ elif field.type == descriptor.FieldDescriptor.TYPE_STRING: |
+ value = tokenizer.ConsumeString() |
+ elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: |
+ value = tokenizer.ConsumeByteString() |
+ elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: |
+ value = tokenizer.ConsumeEnum(field) |
else: |
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) |
- else: # Proto field is unknown. |
- assert allow_unknown_extension |
- _SkipFieldContents(tokenizer) |
+ raise RuntimeError('Unknown field type %d' % field.type) |
- # For historical reasons, fields may optionally be separated by commas or |
- # semicolons. |
- if not tokenizer.TryConsume(','): |
- tokenizer.TryConsume(';') |
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: |
+ if field.is_extension: |
+ message.Extensions[field].append(value) |
+ else: |
+ getattr(message, field.name).append(value) |
+ else: |
+ if field.is_extension: |
+ if not self._allow_multiple_scalars and message.HasExtension(field): |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Message type "%s" should not have multiple "%s" extensions.' % |
+ (message.DESCRIPTOR.full_name, field.full_name)) |
+ else: |
+ message.Extensions[field] = value |
+ else: |
+ if not self._allow_multiple_scalars and message.HasField(field.name): |
+ raise tokenizer.ParseErrorPreviousToken( |
+ 'Message type "%s" should not have multiple "%s" fields.' % |
+ (message.DESCRIPTOR.full_name, field.name)) |
+ else: |
+ setattr(message, field.name, value) |
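Not part of the patch, but the elif chain in _MergeScalarField is effectively a field-type to tokenizer-method table; a compressed sketch of the same dispatch under that assumption:

```python
from google.protobuf import descriptor

_FD = descriptor.FieldDescriptor
_SCALAR_CONSUMERS = {
    _FD.TYPE_INT32: 'ConsumeInt32', _FD.TYPE_SINT32: 'ConsumeInt32',
    _FD.TYPE_SFIXED32: 'ConsumeInt32',
    _FD.TYPE_INT64: 'ConsumeInt64', _FD.TYPE_SINT64: 'ConsumeInt64',
    _FD.TYPE_SFIXED64: 'ConsumeInt64',
    _FD.TYPE_UINT32: 'ConsumeUint32', _FD.TYPE_FIXED32: 'ConsumeUint32',
    _FD.TYPE_UINT64: 'ConsumeUint64', _FD.TYPE_FIXED64: 'ConsumeUint64',
    _FD.TYPE_FLOAT: 'ConsumeFloat', _FD.TYPE_DOUBLE: 'ConsumeFloat',
    _FD.TYPE_BOOL: 'ConsumeBool',
    _FD.TYPE_STRING: 'ConsumeString', _FD.TYPE_BYTES: 'ConsumeByteString',
}

def _consume_scalar(tokenizer, field):
  """Reads one scalar value for `field` from `tokenizer` (sketch only)."""
  if field.type == _FD.TYPE_ENUM:
    return tokenizer.ConsumeEnum(field)
  try:
    return getattr(tokenizer, _SCALAR_CONSUMERS[field.type])()
  except KeyError:
    raise RuntimeError('Unknown field type %d' % field.type)
```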
def _SkipFieldContents(tokenizer): |
@@ -571,10 +745,10 @@ def _SkipFieldValue(tokenizer): |
Raises: |
ParseError: In case an invalid field value is found. |
""" |
- # String tokens can come in multiple adjacent string literals. |
+ # String/bytes tokens can come in multiple adjacent string literals. |
# If we can consume one, consume as many as we can. |
- if tokenizer.TryConsumeString(): |
- while tokenizer.TryConsumeString(): |
+ if tokenizer.TryConsumeByteString(): |
+ while tokenizer.TryConsumeByteString(): |
pass |
return |
@@ -585,73 +759,6 @@ def _SkipFieldValue(tokenizer): |
raise ParseError('Invalid field value: ' + tokenizer.token) |
-def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): |
- """Merges a single protocol message scalar field into a message. |
- |
- Args: |
- tokenizer: A tokenizer to parse the field value. |
- message: A protocol message to record the data. |
- field: The descriptor of the field to be merged. |
- allow_multiple_scalars: Determines if repeated values for a non-repeated |
- field are permitted, e.g., the string "foo: 1 foo: 2" for a |
- required/optional field named "foo". |
- |
- Raises: |
- ParseError: In case of text parsing problems. |
- RuntimeError: On runtime errors. |
- """ |
- value = None |
- |
- if field.type in (descriptor.FieldDescriptor.TYPE_INT32, |
- descriptor.FieldDescriptor.TYPE_SINT32, |
- descriptor.FieldDescriptor.TYPE_SFIXED32): |
- value = tokenizer.ConsumeInt32() |
- elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, |
- descriptor.FieldDescriptor.TYPE_SINT64, |
- descriptor.FieldDescriptor.TYPE_SFIXED64): |
- value = tokenizer.ConsumeInt64() |
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, |
- descriptor.FieldDescriptor.TYPE_FIXED32): |
- value = tokenizer.ConsumeUint32() |
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, |
- descriptor.FieldDescriptor.TYPE_FIXED64): |
- value = tokenizer.ConsumeUint64() |
- elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, |
- descriptor.FieldDescriptor.TYPE_DOUBLE): |
- value = tokenizer.ConsumeFloat() |
- elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: |
- value = tokenizer.ConsumeBool() |
- elif field.type == descriptor.FieldDescriptor.TYPE_STRING: |
- value = tokenizer.ConsumeString() |
- elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: |
- value = tokenizer.ConsumeByteString() |
- elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: |
- value = tokenizer.ConsumeEnum(field) |
- else: |
- raise RuntimeError('Unknown field type %d' % field.type) |
- |
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: |
- if field.is_extension: |
- message.Extensions[field].append(value) |
- else: |
- getattr(message, field.name).append(value) |
- else: |
- if field.is_extension: |
- if not allow_multiple_scalars and message.HasExtension(field): |
- raise tokenizer.ParseErrorPreviousToken( |
- 'Message type "%s" should not have multiple "%s" extensions.' % |
- (message.DESCRIPTOR.full_name, field.full_name)) |
- else: |
- message.Extensions[field] = value |
- else: |
- if not allow_multiple_scalars and message.HasField(field.name): |
- raise tokenizer.ParseErrorPreviousToken( |
- 'Message type "%s" should not have multiple "%s" fields.' % |
- (message.DESCRIPTOR.full_name, field.name)) |
- else: |
- setattr(message, field.name, value) |
- |
- |
class _Tokenizer(object): |
"""Protocol buffer text representation tokenizer. |
@@ -882,9 +989,9 @@ class _Tokenizer(object): |
self.NextToken() |
return result |
- def TryConsumeString(self): |
+ def TryConsumeByteString(self): |
try: |
- self.ConsumeString() |
+ self.ConsumeByteString() |
return True |
except ParseError: |
return False |