Index: third_party/protobuf/python/google/protobuf/symbol_database.py |
diff --git a/third_party/protobuf/python/google/protobuf/symbol_database.py b/third_party/protobuf/python/google/protobuf/symbol_database.py |
index ecbef2111425fcb1ba3549d5ced77fb73d22d7a2..87760f263043d31e2e96e41d16239bdb8516a57b 100644 |
--- a/third_party/protobuf/python/google/protobuf/symbol_database.py |
+++ b/third_party/protobuf/python/google/protobuf/symbol_database.py |
@@ -30,9 +30,11 @@ |
"""A database of Python protocol buffer generated symbols. |
-SymbolDatabase is the MessageFactory for messages generated at compile time, |
-and makes it easy to create new instances of a registered type, given only the |
-type's protocol buffer symbol name. |
+SymbolDatabase makes it easy to create new instances of a registered type, given |
+only the type's protocol buffer symbol name. Once all symbols are registered, |
+they can be accessed using either the MessageFactory interface which |
+SymbolDatabase exposes, or the DescriptorPool interface of the underlying |
+pool. |
Example usage: |
@@ -59,17 +61,27 @@ Example usage: |
from google.protobuf import descriptor_pool |
-from google.protobuf import message_factory |
-class SymbolDatabase(message_factory.MessageFactory): |
- """A database of Python generated symbols.""" |
+class SymbolDatabase(object): |
+ """A database of Python generated symbols. |
+ |
+ SymbolDatabase also models message_factory.MessageFactory. |
+ |
+ The symbol database can be used to keep a global registry of all protocol |
+ buffer types used within a program. |
+ """ |
+ |
+ def __init__(self, pool=None): |
+ """Constructor.""" |
+ |
+ self._symbols = {} |
+ self._symbols_by_file = {} |
+ self.pool = pool or descriptor_pool.Default() |
def RegisterMessage(self, message): |
"""Registers the given message type in the local database. |
- Calls to GetSymbol() and GetMessages() will return messages registered here. |
- |
Args: |
message: a message.Message, to be registered. |
@@ -78,7 +90,10 @@ class SymbolDatabase(message_factory.MessageFactory): |
""" |
desc = message.DESCRIPTOR |
- self._classes[desc.full_name] = message |
+ self._symbols[desc.full_name] = message |
+ if desc.file.name not in self._symbols_by_file: |
+ self._symbols_by_file[desc.file.name] = {} |
+ self._symbols_by_file[desc.file.name][desc.full_name] = message |
self.pool.AddDescriptor(desc) |
return message |
@@ -121,47 +136,47 @@ class SymbolDatabase(message_factory.MessageFactory): |
KeyError: if the symbol could not be found. |
""" |
- return self._classes[symbol] |
+ return self._symbols[symbol] |
+ |
+ def GetPrototype(self, descriptor): |
+ """Builds a proto2 message class based on the passed in descriptor. |
+ |
+ Passing a descriptor with a fully qualified name matching a previous |
+ invocation will cause the same class to be returned. |
+ |
+ Args: |
+ descriptor: The descriptor to build from. |
+ |
+ Returns: |
+ A class describing the passed in descriptor. |
+ """ |
+ |
+ return self.GetSymbol(descriptor.full_name) |
def GetMessages(self, files): |
- # TODO(amauryfa): Fix the differences with MessageFactory. |
- """Gets all registered messages from a specified file. |
+ """Gets all the messages from a specified file. |
+ |
+ This will find and resolve dependencies, failing if they are not registered |
+ in the symbol database. |
- Only messages already created and registered will be returned; (this is the |
- case for imported _pb2 modules) |
- But unlike MessageFactory, this version also returns already defined nested |
- messages, but does not register any message extensions. |
Args: |
files: The file names to extract messages from. |
Returns: |
- A dictionary mapping proto names to the message classes. |
+ A dictionary mapping proto names to the message classes. This will include |
+ any dependent messages as well as any messages defined in the same file as |
+ a specified message. |
Raises: |
KeyError: if a file could not be found. |
""" |
- def _GetAllMessageNames(desc): |
- """Walk a message Descriptor and recursively yields all message names.""" |
- yield desc.full_name |
- for msg_desc in desc.nested_types: |
- for full_name in _GetAllMessageNames(msg_desc): |
- yield full_name |
- |
result = {} |
- for file_name in files: |
- file_desc = self.pool.FindFileByName(file_name) |
- for msg_desc in file_desc.message_types_by_name.values(): |
- for full_name in _GetAllMessageNames(msg_desc): |
- try: |
- result[full_name] = self._classes[full_name] |
- except KeyError: |
- # This descriptor has no registered class, skip it. |
- pass |
+ for f in files: |
+ result.update(self._symbols_by_file[f]) |
return result |
- |
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) |