Index: third_party/protobuf/python/google/protobuf/symbol_database.py |
diff --git a/third_party/protobuf/python/google/protobuf/symbol_database.py b/third_party/protobuf/python/google/protobuf/symbol_database.py |
index 87760f263043d31e2e96e41d16239bdb8516a57b..ecbef2111425fcb1ba3549d5ced77fb73d22d7a2 100644 |
--- a/third_party/protobuf/python/google/protobuf/symbol_database.py |
+++ b/third_party/protobuf/python/google/protobuf/symbol_database.py |
@@ -30,11 +30,9 @@ |
"""A database of Python protocol buffer generated symbols. |
-SymbolDatabase makes it easy to create new instances of a registered type, given |
-only the type's protocol buffer symbol name. Once all symbols are registered, |
-they can be accessed using either the MessageFactory interface which |
-SymbolDatabase exposes, or the DescriptorPool interface of the underlying |
-pool. |
+SymbolDatabase is the MessageFactory for messages generated at compile time, |
+and makes it easy to create new instances of a registered type, given only the |
+type's protocol buffer symbol name. |
Example usage: |
@@ -61,27 +59,17 @@ Example usage: |
from google.protobuf import descriptor_pool |
+from google.protobuf import message_factory |
-class SymbolDatabase(object): |
- """A database of Python generated symbols. |
- |
- SymbolDatabase also models message_factory.MessageFactory. |
- |
- The symbol database can be used to keep a global registry of all protocol |
- buffer types used within a program. |
- """ |
- |
- def __init__(self, pool=None): |
- """Constructor.""" |
- |
- self._symbols = {} |
- self._symbols_by_file = {} |
- self.pool = pool or descriptor_pool.Default() |
+class SymbolDatabase(message_factory.MessageFactory): |
+ """A database of Python generated symbols.""" |
def RegisterMessage(self, message): |
"""Registers the given message type in the local database. |
+ Calls to GetSymbol() and GetMessages() will return messages registered here. |
+ |
Args: |
message: a message.Message, to be registered. |
@@ -90,10 +78,7 @@ class SymbolDatabase(object): |
""" |
desc = message.DESCRIPTOR |
- self._symbols[desc.full_name] = message |
- if desc.file.name not in self._symbols_by_file: |
- self._symbols_by_file[desc.file.name] = {} |
- self._symbols_by_file[desc.file.name][desc.full_name] = message |
+ self._classes[desc.full_name] = message |
self.pool.AddDescriptor(desc) |
return message |
@@ -136,47 +121,47 @@ class SymbolDatabase(object): |
KeyError: if the symbol could not be found. |
""" |
- return self._symbols[symbol] |
- |
- def GetPrototype(self, descriptor): |
- """Builds a proto2 message class based on the passed in descriptor. |
- |
- Passing a descriptor with a fully qualified name matching a previous |
- invocation will cause the same class to be returned. |
- |
- Args: |
- descriptor: The descriptor to build from. |
- |
- Returns: |
- A class describing the passed in descriptor. |
- """ |
- |
- return self.GetSymbol(descriptor.full_name) |
+ return self._classes[symbol] |
def GetMessages(self, files): |
- """Gets all the messages from a specified file. |
- |
- This will find and resolve dependencies, failing if they are not registered |
- in the symbol database. |
+ # TODO(amauryfa): Fix the differences with MessageFactory. |
+ """Gets all registered messages from a specified file. |
+ Only messages already created and registered will be returned; (this is the |
+ case for imported _pb2 modules) |
+ But unlike MessageFactory, this version also returns already defined nested |
+ messages, but does not register any message extensions. |
Args: |
files: The file names to extract messages from. |
Returns: |
- A dictionary mapping proto names to the message classes. This will include |
- any dependent messages as well as any messages defined in the same file as |
- a specified message. |
+ A dictionary mapping proto names to the message classes. |
Raises: |
KeyError: if a file could not be found. |
""" |
+ def _GetAllMessageNames(desc): |
+ """Walk a message Descriptor and recursively yields all message names.""" |
+ yield desc.full_name |
+ for msg_desc in desc.nested_types: |
+ for full_name in _GetAllMessageNames(msg_desc): |
+ yield full_name |
+ |
result = {} |
- for f in files: |
- result.update(self._symbols_by_file[f]) |
+ for file_name in files: |
+ file_desc = self.pool.FindFileByName(file_name) |
+ for msg_desc in file_desc.message_types_by_name.values(): |
+ for full_name in _GetAllMessageNames(msg_desc): |
+ try: |
+ result[full_name] = self._classes[full_name] |
+ except KeyError: |
+ # This descriptor has no registered class, skip it. |
+ pass |
return result |
+ |
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) |