Index: third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
diff --git a/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc b/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
index 9423c1d890b102ff014fa769abd89435b81d8e76..21bbb8c2b1d3d7e0f2073f5bf3b4e072716fa37f 100644 |
--- a/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
+++ b/third_party/protobuf/python/google/protobuf/pyext/extension_dict.cc |
@@ -38,25 +38,14 @@ |
#include <google/protobuf/descriptor.h> |
#include <google/protobuf/dynamic_message.h> |
#include <google/protobuf/message.h> |
-#include <google/protobuf/descriptor.pb.h> |
#include <google/protobuf/pyext/descriptor.h> |
+#include <google/protobuf/pyext/descriptor_pool.h> |
#include <google/protobuf/pyext/message.h> |
-#include <google/protobuf/pyext/message_factory.h> |
#include <google/protobuf/pyext/repeated_composite_container.h> |
#include <google/protobuf/pyext/repeated_scalar_container.h> |
#include <google/protobuf/pyext/scoped_pyobject_ptr.h> |
#include <google/protobuf/stubs/shared_ptr.h> |
-#if PY_MAJOR_VERSION >= 3 |
- #if PY_VERSION_HEX < 0x03030000 |
- #error "Python 3.0 - 3.2 are not supported." |
- #endif |
- #define PyString_AsStringAndSize(ob, charpp, sizep) \ |
- (PyUnicode_Check(ob)? \ |
- ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ |
- PyBytes_AsStringAndSize(ob, (charpp), (sizep))) |
-#endif |
- |
namespace google { |
namespace protobuf { |
namespace python { |
@@ -71,6 +60,35 @@ PyObject* len(ExtensionDict* self) { |
#endif |
} |
+// TODO(tibell): Use VisitCompositeField. |
+int ReleaseExtension(ExtensionDict* self, |
+ PyObject* extension, |
+ const FieldDescriptor* descriptor) { |
+ if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { |
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
+ if (repeated_composite_container::Release( |
+ reinterpret_cast<RepeatedCompositeContainer*>( |
+ extension)) < 0) { |
+ return -1; |
+ } |
+ } else { |
+ if (repeated_scalar_container::Release( |
+ reinterpret_cast<RepeatedScalarContainer*>( |
+ extension)) < 0) { |
+ return -1; |
+ } |
+ } |
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
+ if (cmessage::ReleaseSubMessage( |
+ self->parent, descriptor, |
+ reinterpret_cast<CMessage*>(extension)) < 0) { |
+ return -1; |
+ } |
+ } |
+ |
+ return 0; |
+} |
+ |
PyObject* subscript(ExtensionDict* self, PyObject* key) { |
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); |
if (descriptor == NULL) { |
@@ -101,7 +119,6 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { |
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && |
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- // TODO(plabatut): consider building the class on the fly! |
PyObject* sub_message = cmessage::InternalGetSubMessage( |
self->parent, descriptor); |
if (sub_message == NULL) { |
@@ -113,18 +130,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { |
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { |
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- // On the fly message class creation is needed to support the following |
- // situation: |
- // 1- add FileDescriptor to the pool that contains extensions of a message |
- // defined by another proto file. Do not create any message classes. |
- // 2- instantiate an extended message, and access the extension using |
- // the field descriptor. |
- // 3- the extension submessage fails to be returned, because no class has |
- // been created. |
- // It happens when deserializing text proto format, or when enumerating |
- // fields of a deserialized message. |
- CMessageClass* message_class = message_factory::GetOrCreateMessageClass( |
- cmessage::GetFactoryForMessage(self->parent), |
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass( |
+ cmessage::GetDescriptorPoolForMessage(self->parent), |
descriptor->message_type()); |
if (message_class == NULL) { |
return NULL; |
@@ -176,51 +183,75 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { |
return 0; |
} |
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { |
- char* name; |
- Py_ssize_t name_size; |
- if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { |
+PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { |
+ const FieldDescriptor* descriptor = |
+ cmessage::GetExtensionDescriptor(extension); |
+ if (descriptor == NULL) { |
return NULL; |
} |
- |
- PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; |
- const FieldDescriptor* message_extension = |
- pool->pool->FindExtensionByName(string(name, name_size)); |
- if (message_extension == NULL) { |
- // Is is the name of a message set extension? |
- const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( |
- string(name, name_size)); |
- if (message_descriptor && message_descriptor->extension_count() > 0) { |
- const FieldDescriptor* extension = message_descriptor->extension(0); |
- if (extension->is_extension() && |
- extension->containing_type()->options().message_set_wire_format() && |
- extension->type() == FieldDescriptor::TYPE_MESSAGE && |
- extension->label() == FieldDescriptor::LABEL_OPTIONAL) { |
- message_extension = extension; |
+ PyObject* value = PyDict_GetItem(self->values, extension); |
+ if (self->parent) { |
+ if (value != NULL) { |
+ if (ReleaseExtension(self, value, descriptor) < 0) { |
+ return NULL; |
} |
} |
+ if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( |
+ self->parent, descriptor)) == NULL) { |
+ return NULL; |
+ } |
} |
- if (message_extension == NULL) { |
- Py_RETURN_NONE; |
+ if (PyDict_DelItem(self->values, extension) < 0) { |
+ PyErr_Clear(); |
} |
- |
- return PyFieldDescriptor_FromDescriptor(message_extension); |
+ Py_RETURN_NONE; |
} |
-PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { |
- int64 number = PyLong_AsLong(arg); |
- if (number == -1 && PyErr_Occurred()) { |
+PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { |
+ const FieldDescriptor* descriptor = |
+ cmessage::GetExtensionDescriptor(extension); |
+ if (descriptor == NULL) { |
return NULL; |
} |
+ if (self->parent) { |
+ return cmessage::HasFieldByDescriptor(self->parent, descriptor); |
+ } else { |
+ int exists = PyDict_Contains(self->values, extension); |
+ if (exists < 0) { |
+ return NULL; |
+ } |
+ return PyBool_FromLong(exists); |
+ } |
+} |
- PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; |
- const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( |
- self->parent->message->GetDescriptor(), number); |
- if (message_extension == NULL) { |
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { |
+ ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( |
+ reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); |
+ if (extensions_by_name == NULL) { |
+ return NULL; |
+ } |
+ PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); |
+ if (result == NULL) { |
Py_RETURN_NONE; |
+ } else { |
+ Py_INCREF(result); |
+ return result; |
} |
+} |
- return PyFieldDescriptor_FromDescriptor(message_extension); |
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { |
+ ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( |
+ reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); |
+ if (extensions_by_number == NULL) { |
+ return NULL; |
+ } |
+ PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); |
+ if (result == NULL) { |
+ Py_RETURN_NONE; |
+ } else { |
+ Py_INCREF(result); |
+ return result; |
+ } |
} |
ExtensionDict* NewExtensionDict(CMessage *parent) { |
@@ -251,6 +282,8 @@ static PyMappingMethods MpMethods = { |
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } |
static PyMethodDef Methods[] = { |
+ EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), |
+ EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), |
EDMETHOD(_FindExtensionByName, METH_O, |
"Finds an extension by name."), |
EDMETHOD(_FindExtensionByNumber, METH_O, |