Index: third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
diff --git a/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc b/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
index 21bbb8c2b1d3d7e0f2073f5bf3b4e072716fa37f..9423c1d890b102ff014fa769abd89435b81d8e76 100644 |
--- a/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
+++ b/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
@@ -38,14 +38,25 @@ |
#include <google/protobuf/descriptor.h> |
#include <google/protobuf/dynamic_message.h> |
#include <google/protobuf/message.h> |
+#include <google/protobuf/descriptor.pb.h> |
#include <google/protobuf/pyext/descriptor.h> |
-#include <google/protobuf/pyext/descriptor_pool.h> |
#include <google/protobuf/pyext/message.h> |
+#include <google/protobuf/pyext/message_factory.h> |
#include <google/protobuf/pyext/repeated_composite_container.h> |
#include <google/protobuf/pyext/repeated_scalar_container.h> |
#include <google/protobuf/pyext/scoped_pyobject_ptr.h> |
#include <google/protobuf/stubs/shared_ptr.h> |
+#if PY_MAJOR_VERSION >= 3 |
+ #if PY_VERSION_HEX < 0x03030000 |
+ #error "Python 3.0 - 3.2 are not supported." |
+ #endif |
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \ |
+ (PyUnicode_Check(ob)? \ |
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ |
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep))) |
+#endif |
+ |
namespace google { |
namespace protobuf { |
namespace python { |
@@ -60,35 +71,6 @@ PyObject* len(ExtensionDict* self) { |
#endif |
} |
-// TODO(tibell): Use VisitCompositeField. |
-int ReleaseExtension(ExtensionDict* self, |
- PyObject* extension, |
- const FieldDescriptor* descriptor) { |
- if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { |
- if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- if (repeated_composite_container::Release( |
- reinterpret_cast<RepeatedCompositeContainer*>( |
- extension)) < 0) { |
- return -1; |
- } |
- } else { |
- if (repeated_scalar_container::Release( |
- reinterpret_cast<RepeatedScalarContainer*>( |
- extension)) < 0) { |
- return -1; |
- } |
- } |
- } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- if (cmessage::ReleaseSubMessage( |
- self->parent, descriptor, |
- reinterpret_cast<CMessage*>(extension)) < 0) { |
- return -1; |
- } |
- } |
- |
- return 0; |
-} |
- |
PyObject* subscript(ExtensionDict* self, PyObject* key) { |
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); |
if (descriptor == NULL) { |
@@ -119,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { |
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && |
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
+ // TODO(plabatut): consider building the class on the fly! |
PyObject* sub_message = cmessage::InternalGetSubMessage( |
self->parent, descriptor); |
if (sub_message == NULL) { |
@@ -130,8 +113,18 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { |
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { |
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- CMessageClass* message_class = cdescriptor_pool::GetMessageClass( |
- cmessage::GetDescriptorPoolForMessage(self->parent), |
+ // On the fly message class creation is needed to support the following |
+ // situation: |
+ // 1- add FileDescriptor to the pool that contains extensions of a message |
+ // defined by another proto file. Do not create any message classes. |
+ // 2- instantiate an extended message, and access the extension using |
+ // the field descriptor. |
+ // 3- the extension submessage fails to be returned, because no class has |
+ // been created. |
+ // It happens when deserializing text proto format, or when enumerating |
+ // fields of a deserialized message. |
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass( |
+ cmessage::GetFactoryForMessage(self->parent), |
descriptor->message_type()); |
if (message_class == NULL) { |
return NULL; |
@@ -183,75 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { |
return 0; |
} |
-PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { |
- const FieldDescriptor* descriptor = |
- cmessage::GetExtensionDescriptor(extension); |
- if (descriptor == NULL) { |
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { |
+ char* name; |
+ Py_ssize_t name_size; |
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { |
return NULL; |
} |
- PyObject* value = PyDict_GetItem(self->values, extension); |
- if (self->parent) { |
- if (value != NULL) { |
- if (ReleaseExtension(self, value, descriptor) < 0) { |
- return NULL; |
+ |
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; |
+ const FieldDescriptor* message_extension = |
+ pool->pool->FindExtensionByName(string(name, name_size)); |
+ if (message_extension == NULL) { |
+ // Is is the name of a message set extension? |
+ const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( |
+ string(name, name_size)); |
+ if (message_descriptor && message_descriptor->extension_count() > 0) { |
+ const FieldDescriptor* extension = message_descriptor->extension(0); |
+ if (extension->is_extension() && |
+ extension->containing_type()->options().message_set_wire_format() && |
+ extension->type() == FieldDescriptor::TYPE_MESSAGE && |
+ extension->label() == FieldDescriptor::LABEL_OPTIONAL) { |
+ message_extension = extension; |
} |
} |
- if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( |
- self->parent, descriptor)) == NULL) { |
- return NULL; |
- } |
} |
- if (PyDict_DelItem(self->values, extension) < 0) { |
- PyErr_Clear(); |
+ if (message_extension == NULL) { |
+ Py_RETURN_NONE; |
} |
- Py_RETURN_NONE; |
-} |
-PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { |
- const FieldDescriptor* descriptor = |
- cmessage::GetExtensionDescriptor(extension); |
- if (descriptor == NULL) { |
- return NULL; |
- } |
- if (self->parent) { |
- return cmessage::HasFieldByDescriptor(self->parent, descriptor); |
- } else { |
- int exists = PyDict_Contains(self->values, extension); |
- if (exists < 0) { |
- return NULL; |
- } |
- return PyBool_FromLong(exists); |
- } |
+ return PyFieldDescriptor_FromDescriptor(message_extension); |
} |
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { |
- ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( |
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); |
- if (extensions_by_name == NULL) { |
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { |
+ int64 number = PyLong_AsLong(arg); |
+ if (number == -1 && PyErr_Occurred()) { |
return NULL; |
} |
- PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); |
- if (result == NULL) { |
- Py_RETURN_NONE; |
- } else { |
- Py_INCREF(result); |
- return result; |
- } |
-} |
-PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { |
- ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( |
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); |
- if (extensions_by_number == NULL) { |
- return NULL; |
- } |
- PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); |
- if (result == NULL) { |
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; |
+ const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( |
+ self->parent->message->GetDescriptor(), number); |
+ if (message_extension == NULL) { |
Py_RETURN_NONE; |
- } else { |
- Py_INCREF(result); |
- return result; |
} |
+ |
+ return PyFieldDescriptor_FromDescriptor(message_extension); |
} |
ExtensionDict* NewExtensionDict(CMessage *parent) { |
@@ -282,8 +251,6 @@ static PyMappingMethods MpMethods = { |
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } |
static PyMethodDef Methods[] = { |
- EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), |
- EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), |
EDMETHOD(_FindExtensionByName, METH_O, |
"Finds an extension by name."), |
EDMETHOD(_FindExtensionByNumber, METH_O, |