Index: third_party/protobuf/python/google/protobuf/pyext/message.cc |
diff --git a/third_party/protobuf/python/google/protobuf/pyext/message.cc b/third_party/protobuf/python/google/protobuf/pyext/message.cc |
index 4f3abc84bedce3584629ef4d98d52792e3dcbdaa..83c151ff626aabab287e8bfe39b4b8b80e99791f 100644 |
--- a/third_party/protobuf/python/google/protobuf/pyext/message.cc |
+++ b/third_party/protobuf/python/google/protobuf/pyext/message.cc |
@@ -63,12 +63,11 @@ |
#include <google/protobuf/pyext/repeated_composite_container.h> |
#include <google/protobuf/pyext/repeated_scalar_container.h> |
#include <google/protobuf/pyext/map_container.h> |
-#include <google/protobuf/pyext/message_factory.h> |
-#include <google/protobuf/pyext/safe_numerics.h> |
#include <google/protobuf/pyext/scoped_pyobject_ptr.h> |
#include <google/protobuf/stubs/strutil.h> |
#if PY_MAJOR_VERSION >= 3 |
+ #define PyInt_Check PyLong_Check |
#define PyInt_AsLong PyLong_AsLong |
#define PyInt_FromLong PyLong_FromLong |
#define PyInt_FromSize_t PyLong_FromSize_t |
@@ -92,6 +91,8 @@ namespace protobuf { |
namespace python { |
static PyObject* kDESCRIPTOR; |
+static PyObject* k_extensions_by_name; |
+static PyObject* k_extensions_by_number; |
PyObject* EnumTypeWrapper_class; |
static PyObject* PythonMessage_class; |
static PyObject* kEmptyWeakref; |
@@ -126,6 +127,19 @@ static bool AddFieldNumberToClass( |
// Finalize the creation of the Message class. |
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { |
+ // If there are extension_ranges, the message is "extendable", and extension |
+ // classes will register themselves in this class. |
+ if (descriptor->extension_range_count() > 0) { |
+ ScopedPyObjectPtr by_name(PyDict_New()); |
+ if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { |
+ return -1; |
+ } |
+ ScopedPyObjectPtr by_number(PyDict_New()); |
+ if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { |
+ return -1; |
+ } |
+ } |
+ |
// For each field set: cls.<field>_FIELD_NUMBER = <number> |
for (int i = 0; i < descriptor->field_count(); ++i) { |
if (!AddFieldNumberToClass(cls, descriptor->field(i))) { |
@@ -230,12 +244,6 @@ static PyObject* New(PyTypeObject* type, |
return NULL; |
} |
- // Messages have no __dict__ |
- ScopedPyObjectPtr slots(PyTuple_New(0)); |
- if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) { |
- return NULL; |
- } |
- |
// Build the arguments to the base metaclass. |
// We change the __bases__ classes. |
ScopedPyObjectPtr new_args; |
@@ -292,19 +300,16 @@ static PyObject* New(PyTypeObject* type, |
newtype->message_descriptor = descriptor; |
// TODO(amauryfa): Don't always use the canonical pool of the descriptor, |
// use the MessageFactory optionally passed in the class dict. |
- PyDescriptorPool* py_descriptor_pool = |
- GetDescriptorPool_FromPool(descriptor->file()->pool()); |
- if (py_descriptor_pool == NULL) { |
+ newtype->py_descriptor_pool = GetDescriptorPool_FromPool( |
+ descriptor->file()->pool()); |
+ if (newtype->py_descriptor_pool == NULL) { |
return NULL; |
} |
- newtype->py_message_factory = py_descriptor_pool->py_message_factory; |
- Py_INCREF(newtype->py_message_factory); |
+ Py_INCREF(newtype->py_descriptor_pool); |
- // Register the message in the MessageFactory. |
- // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the |
- // MessageFactory is fully implemented in C++. |
- if (message_factory::RegisterMessageClass(newtype->py_message_factory, |
- descriptor, newtype) < 0) { |
+ // Add the message to the DescriptorPool. |
+ if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, |
+ descriptor, newtype) < 0) { |
return NULL; |
} |
@@ -316,8 +321,8 @@ static PyObject* New(PyTypeObject* type, |
} |
static void Dealloc(CMessageClass *self) { |
- Py_XDECREF(self->py_message_descriptor); |
- Py_XDECREF(self->py_message_factory); |
+ Py_DECREF(self->py_message_descriptor); |
+ Py_DECREF(self->py_descriptor_pool); |
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); |
} |
@@ -342,61 +347,6 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { |
#endif // PY_MAJOR_VERSION >= 3 |
} |
-// The _extensions_by_name dictionary is built on every access. |
-// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() |
-static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { |
- const PyDescriptorPool* pool = self->py_message_factory->pool; |
- |
- std::vector<const FieldDescriptor*> extensions; |
- pool->pool->FindAllExtensions(self->message_descriptor, &extensions); |
- |
- ScopedPyObjectPtr result(PyDict_New()); |
- for (int i = 0; i < extensions.size(); i++) { |
- ScopedPyObjectPtr extension( |
- PyFieldDescriptor_FromDescriptor(extensions[i])); |
- if (extension == NULL) { |
- return NULL; |
- } |
- if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), |
- extension.get()) < 0) { |
- return NULL; |
- } |
- } |
- return result.release(); |
-} |
- |
-// The _extensions_by_number dictionary is built on every access. |
-// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() |
-static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { |
- const PyDescriptorPool* pool = self->py_message_factory->pool; |
- |
- std::vector<const FieldDescriptor*> extensions; |
- pool->pool->FindAllExtensions(self->message_descriptor, &extensions); |
- |
- ScopedPyObjectPtr result(PyDict_New()); |
- for (int i = 0; i < extensions.size(); i++) { |
- ScopedPyObjectPtr extension( |
- PyFieldDescriptor_FromDescriptor(extensions[i])); |
- if (extension == NULL) { |
- return NULL; |
- } |
- ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); |
- if (number == NULL) { |
- return NULL; |
- } |
- if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { |
- return NULL; |
- } |
- } |
- return result.release(); |
-} |
- |
-static PyGetSetDef Getters[] = { |
- {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, |
- {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, |
- {NULL} |
-}; |
- |
} // namespace message_meta |
PyTypeObject CMessageClass_Type = { |
@@ -429,7 +379,7 @@ PyTypeObject CMessageClass_Type = { |
0, // tp_iternext |
0, // tp_methods |
0, // tp_members |
- message_meta::Getters, // tp_getset |
+ 0, // tp_getset |
0, // tp_base |
0, // tp_dict |
0, // tp_descr_get |
@@ -565,10 +515,23 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { |
// --------------------------------------------------------------------- |
+// Constants used for integer type range checking. |
+PyObject* kPythonZero; |
+PyObject* kint32min_py; |
+PyObject* kint32max_py; |
+PyObject* kuint32max_py; |
+PyObject* kint64min_py; |
+PyObject* kint64max_py; |
+PyObject* kuint64max_py; |
+ |
PyObject* EncodeError_class; |
PyObject* DecodeError_class; |
PyObject* PickleError_class; |
+// Constant PyString values used for GetAttr/GetItem. |
+static PyObject* k_cdescriptor; |
+static PyObject* kfull_name; |
+ |
/* Is 64bit */ |
void FormatTypeError(PyObject* arg, char* expected_types) { |
PyObject* repr = PyObject_Repr(arg); |
@@ -582,126 +545,68 @@ void FormatTypeError(PyObject* arg, char* expected_types) { |
} |
} |
-void OutOfRangeError(PyObject* arg) { |
- PyObject *s = PyObject_Str(arg); |
- if (s) { |
- PyErr_Format(PyExc_ValueError, |
- "Value out of range: %s", |
- PyString_AsString(s)); |
- Py_DECREF(s); |
- } |
-} |
- |
-template<class RangeType, class ValueType> |
-bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { |
- if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { |
- if (PyErr_ExceptionMatches(PyExc_OverflowError)) { |
- // Replace it with the same ValueError as pure python protos instead of |
- // the default one. |
- PyErr_Clear(); |
- OutOfRangeError(arg); |
- } // Otherwise propagate existing error. |
- return false; |
- } |
- if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { |
- OutOfRangeError(arg); |
- return false; |
- } |
- return true; |
-} |
- |
template<class T> |
-bool CheckAndGetInteger(PyObject* arg, T* value) { |
- // The fast path. |
+bool CheckAndGetInteger( |
+ PyObject* arg, T* value, PyObject* min, PyObject* max) { |
+ bool is_long = PyLong_Check(arg); |
#if PY_MAJOR_VERSION < 3 |
- // For the typical case, offer a fast path. |
- if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { |
- long int_result = PyInt_AsLong(arg); |
- if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { |
- *value = static_cast<T>(int_result); |
- return true; |
- } else { |
- OutOfRangeError(arg); |
- return false; |
- } |
- } |
-#endif |
- // This effectively defines an integer as "an object that can be cast as |
- // an integer and can be used as an ordinal number". |
- // This definition includes everything that implements numbers.Integral |
- // and shouldn't cast the net too wide. |
- if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { |
+ if (!PyInt_Check(arg) && !is_long) { |
FormatTypeError(arg, "int, long"); |
return false; |
} |
- |
- // Now we have an integral number so we can safely use PyLong_ functions. |
- // We need to treat the signed and unsigned cases differently in case arg is |
- // holding a value above the maximum for signed longs. |
- if (std::numeric_limits<T>::min() == 0) { |
- // Unsigned case. |
- unsigned PY_LONG_LONG ulong_result; |
- if (PyLong_Check(arg)) { |
- ulong_result = PyLong_AsUnsignedLongLong(arg); |
- } else { |
- // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very |
- // picky about the exact type. |
- PyObject* casted = PyNumber_Long(arg); |
- if GOOGLE_PREDICT_FALSE(casted == NULL) { |
- // Propagate existing error. |
- return false; |
- } |
- ulong_result = PyLong_AsUnsignedLongLong(casted); |
- Py_DECREF(casted); |
- } |
- if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, |
- ulong_result)) { |
- *value = static_cast<T>(ulong_result); |
- } else { |
- return false; |
- } |
- } else { |
- // Signed case. |
- PY_LONG_LONG long_result; |
- PyNumberMethods *nb; |
- if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { |
- // PyLong_AsLongLong requires it to be a long or to have an __int__() |
- // method. |
- long_result = PyLong_AsLongLong(arg); |
- } else { |
- // Valid subclasses of numbers.Integral should have a __long__() method |
- // so fall back to that. |
- PyObject* casted = PyNumber_Long(arg); |
- if GOOGLE_PREDICT_FALSE(casted == NULL) { |
- // Propagate existing error. |
- return false; |
+ if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { |
+#else |
+ if (!is_long) { |
+ FormatTypeError(arg, "int"); |
+ return false; |
+ } |
+ if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || |
+ PyObject_RichCompareBool(max, arg, Py_GE) != 1) { |
+#endif |
+ if (!PyErr_Occurred()) { |
+ PyObject *s = PyObject_Str(arg); |
+ if (s) { |
+ PyErr_Format(PyExc_ValueError, |
+ "Value out of range: %s", |
+ PyString_AsString(s)); |
+ Py_DECREF(s); |
} |
- long_result = PyLong_AsLongLong(casted); |
- Py_DECREF(casted); |
} |
- if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { |
- *value = static_cast<T>(long_result); |
+ return false; |
+ } |
+#if PY_MAJOR_VERSION < 3 |
+ if (!is_long) { |
+ *value = static_cast<T>(PyInt_AsLong(arg)); |
+ } else // NOLINT |
+#endif |
+ { |
+ if (min == kPythonZero) { |
+ *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); |
} else { |
- return false; |
+ *value = static_cast<T>(PyLong_AsLongLong(arg)); |
} |
} |
- |
return true; |
} |
// These are referenced by repeated_scalar_container, and must |
// be explicitly instantiated. |
-template bool CheckAndGetInteger<int32>(PyObject*, int32*); |
-template bool CheckAndGetInteger<int64>(PyObject*, int64*); |
-template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); |
-template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); |
+template bool CheckAndGetInteger<int32>( |
+ PyObject*, int32*, PyObject*, PyObject*); |
+template bool CheckAndGetInteger<int64>( |
+ PyObject*, int64*, PyObject*, PyObject*); |
+template bool CheckAndGetInteger<uint32>( |
+ PyObject*, uint32*, PyObject*, PyObject*); |
+template bool CheckAndGetInteger<uint64>( |
+ PyObject*, uint64*, PyObject*, PyObject*); |
bool CheckAndGetDouble(PyObject* arg, double* value) { |
- *value = PyFloat_AsDouble(arg); |
- if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { |
+ if (!PyInt_Check(arg) && !PyLong_Check(arg) && |
+ !PyFloat_Check(arg)) { |
FormatTypeError(arg, "int, long, float"); |
return false; |
} |
+ *value = PyFloat_AsDouble(arg); |
return true; |
} |
@@ -715,13 +620,11 @@ bool CheckAndGetFloat(PyObject* arg, float* value) { |
} |
bool CheckAndGetBool(PyObject* arg, bool* value) { |
- long long_value = PyInt_AsLong(arg); |
- if (long_value == -1 && PyErr_Occurred()) { |
+ if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { |
FormatTypeError(arg, "int, long, bool"); |
return false; |
} |
- *value = static_cast<bool>(long_value); |
- |
+ *value = static_cast<bool>(PyInt_AsLong(arg)); |
return true; |
} |
@@ -849,9 +752,15 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, |
namespace cmessage { |
-PyMessageFactory* GetFactoryForMessage(CMessage* message) { |
+PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { |
+ // No need to check the type: the type of instances of CMessage is always |
+ // an instance of CMessageClass. Let's prove it with a debug-only check. |
GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); |
- return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory; |
+ return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool; |
+} |
+ |
+MessageFactory* GetFactoryForMessage(CMessage* message) { |
+ return GetDescriptorPoolForMessage(message)->message_factory; |
} |
static int MaybeReleaseOverlappingOneofField( |
@@ -904,8 +813,7 @@ static Message* GetMutableMessage( |
return NULL; |
} |
return reflection->MutableMessage( |
- parent_message, parent_field, |
- GetFactoryForMessage(parent)->message_factory); |
+ parent_message, parent_field, GetFactoryForMessage(parent)); |
} |
struct FixupMessageReference : public ChildVisitor { |
@@ -1053,7 +961,20 @@ int InternalDeleteRepeatedField( |
int min, max; |
length = reflection->FieldSize(*message, field_descriptor); |
- if (PySlice_Check(slice)) { |
+ if (PyInt_Check(slice) || PyLong_Check(slice)) { |
+ from = to = PyLong_AsLong(slice); |
+ if (from < 0) { |
+ from = to = length + from; |
+ } |
+ step = 1; |
+ min = max = from; |
+ |
+ // Range check. |
+ if (from < 0 || from >= length) { |
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range"); |
+ return -1; |
+ } |
+ } else if (PySlice_Check(slice)) { |
from = to = step = slice_length = 0; |
PySlice_GetIndicesEx( |
#if PY_MAJOR_VERSION < 3 |
@@ -1070,23 +991,8 @@ int InternalDeleteRepeatedField( |
max = from; |
} |
} else { |
- from = to = PyLong_AsLong(slice); |
- if (from == -1 && PyErr_Occurred()) { |
- PyErr_SetString(PyExc_TypeError, "list indices must be integers"); |
- return -1; |
- } |
- |
- if (from < 0) { |
- from = to = length + from; |
- } |
- step = 1; |
- min = max = from; |
- |
- // Range check. |
- if (from < 0 || from >= length) { |
- PyErr_Format(PyExc_IndexError, "list assignment index out of range"); |
- return -1; |
- } |
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers"); |
+ return -1; |
} |
Py_ssize_t i = from; |
@@ -1135,12 +1041,7 @@ int InternalDeleteRepeatedField( |
} |
// Initializes fields of a message. Used in constructors. |
-int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { |
- if (args != NULL && PyTuple_Size(args) != 0) { |
- PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); |
- return -1; |
- } |
- |
+int InitAttributes(CMessage* self, PyObject* kwargs) { |
if (kwargs == NULL) { |
return 0; |
} |
@@ -1266,9 +1167,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { |
} |
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); |
if (PyDict_Check(value)) { |
- // Make the message exist even if the dict is empty. |
- AssureWritable(cmessage); |
- if (InitAttributes(cmessage, NULL, value) < 0) { |
+ if (InitAttributes(cmessage, value) < 0) { |
return -1; |
} |
} else { |
@@ -1327,7 +1226,7 @@ static PyObject* New(PyTypeObject* cls, |
if (message_descriptor == NULL) { |
return NULL; |
} |
- const Message* default_message = type->py_message_factory->message_factory |
+ const Message* default_message = type->py_descriptor_pool->message_factory |
->GetPrototype(message_descriptor); |
if (default_message == NULL) { |
PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); |
@@ -1346,7 +1245,12 @@ static PyObject* New(PyTypeObject* cls, |
// The __init__ method of Message classes. |
// It initializes fields from keywords passed to the constructor. |
static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { |
- return InitAttributes(self, args, kwargs); |
+ if (PyTuple_Size(args) != 0) { |
+ PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); |
+ return -1; |
+ } |
+ |
+ return InitAttributes(self, kwargs); |
} |
// --------------------------------------------------------------------- |
@@ -1388,9 +1292,6 @@ struct ClearWeakReferences : public ChildVisitor { |
}; |
static void Dealloc(CMessage* self) { |
- if (self->weakreflist) { |
- PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self)); |
- } |
// Null out all weak references from children to this message. |
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); |
if (self->extensions) { |
@@ -1558,20 +1459,18 @@ PyObject* HasField(CMessage* self, PyObject* arg) { |
} |
PyObject* ClearExtension(CMessage* self, PyObject* extension) { |
- const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); |
- if (descriptor == NULL) { |
- return NULL; |
- } |
if (self->extensions != NULL) { |
- PyObject* value = PyDict_GetItem(self->extensions->values, extension); |
- if (value != NULL) { |
- if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { |
- return NULL; |
- } |
- PyDict_DelItem(self->extensions->values, extension); |
+ return extension_dict::ClearExtension(self->extensions, extension); |
+ } else { |
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); |
+ if (descriptor == NULL) { |
+ return NULL; |
+ } |
+ if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { |
+ return NULL; |
} |
} |
- return ClearFieldByDescriptor(self, descriptor); |
+ Py_RETURN_NONE; |
} |
PyObject* HasExtension(CMessage* self, PyObject* extension) { |
@@ -1657,7 +1556,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { |
Message* ReleaseMessage(CMessage* self, |
const Descriptor* descriptor, |
const FieldDescriptor* field_descriptor) { |
- MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory; |
+ MessageFactory* message_factory = GetFactoryForMessage(self); |
Message* released_message = self->message->GetReflection()->ReleaseMessage( |
self->message, field_descriptor, message_factory); |
// ReleaseMessage will return NULL which differs from |
@@ -1694,20 +1593,23 @@ struct ReleaseChild : public ChildVisitor { |
parent_(parent) {} |
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { |
- return repeated_composite_container::Release(container); |
+ return repeated_composite_container::Release( |
+ reinterpret_cast<RepeatedCompositeContainer*>(container)); |
} |
int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { |
- return repeated_scalar_container::Release(container); |
+ return repeated_scalar_container::Release( |
+ reinterpret_cast<RepeatedScalarContainer*>(container)); |
} |
int VisitMapContainer(MapContainer* container) { |
- return container->Release(); |
+ return reinterpret_cast<MapContainer*>(container)->Release(); |
} |
int VisitCMessage(CMessage* cmessage, |
const FieldDescriptor* field_descriptor) { |
- return ReleaseSubMessage(parent_, field_descriptor, cmessage); |
+ return ReleaseSubMessage(parent_, field_descriptor, |
+ reinterpret_cast<CMessage*>(cmessage)); |
} |
CMessage* parent_; |
@@ -1725,19 +1627,12 @@ int InternalReleaseFieldByDescriptor( |
PyObject* ClearFieldByDescriptor( |
CMessage* self, |
- const FieldDescriptor* field_descriptor) { |
- if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { |
+ const FieldDescriptor* descriptor) { |
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) { |
return NULL; |
} |
AssureWritable(self); |
- Message* message = self->message; |
- message->GetReflection()->ClearField(message, field_descriptor); |
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && |
- !message->GetReflection()->SupportsUnknownEnumValues()) { |
- UnknownFieldSet* unknown_field_set = |
- message->GetReflection()->MutableUnknownFields(message); |
- unknown_field_set->DeleteByNumber(field_descriptor->number()); |
- } |
+ self->message->GetReflection()->ClearField(self->message, descriptor); |
Py_RETURN_NONE; |
} |
@@ -1773,17 +1668,27 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { |
arg = arg_in_oneof.get(); |
} |
- // Release the field if it exists in the dict of composite fields. |
- if (self->composite_fields) { |
- PyObject* value = PyDict_GetItem(self->composite_fields, arg); |
- if (value != NULL) { |
- if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { |
- return NULL; |
- } |
- PyDict_DelItem(self->composite_fields, arg); |
+ PyObject* composite_field = self->composite_fields ? |
+ PyDict_GetItem(self->composite_fields, arg) : NULL; |
+ |
+ // Only release the field if there's a possibility that there are |
+ // references to it. |
+ if (composite_field != NULL) { |
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor, |
+ composite_field) < 0) { |
+ return NULL; |
} |
+ PyDict_DelItem(self->composite_fields, arg); |
+ } |
+ message->GetReflection()->ClearField(message, field_descriptor); |
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && |
+ !message->GetReflection()->SupportsUnknownEnumValues()) { |
+ UnknownFieldSet* unknown_field_set = |
+ message->GetReflection()->MutableUnknownFields(message); |
+ unknown_field_set->DeleteByNumber(field_descriptor->number()); |
} |
- return ClearFieldByDescriptor(self, field_descriptor); |
+ |
+ Py_RETURN_NONE; |
} |
PyObject* Clear(CMessage* self) { |
@@ -1994,15 +1899,11 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { |
// get OOM errors. The protobuf APIs do not provide any tools for processing |
// protobufs in chunks. If you have protos this big you should break them up if |
// it is at all convenient to do so. |
-#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS |
-static bool allow_oversize_protos = true; |
-#else |
static bool allow_oversize_protos = false; |
-#endif |
// Provide a method in the module to set allow_oversize_protos to a boolean |
// value. This method returns the newly value of allow_oversize_protos. |
-PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { |
+static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { |
if (!arg || !PyBool_Check(arg)) { |
PyErr_SetString(PyExc_TypeError, |
"Argument to SetAllowOversizeProtos must be boolean"); |
@@ -2029,8 +1930,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { |
if (allow_oversize_protos) { |
input.SetTotalBytesLimit(INT_MAX, INT_MAX); |
} |
- PyMessageFactory* factory = GetFactoryForMessage(self); |
- input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); |
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); |
+ input.SetExtensionRegistry(pool->pool, pool->message_factory); |
bool success = self->message->MergePartialFromCodedStream(&input); |
if (success) { |
return PyInt_FromLong(input.CurrentPosition()); |
@@ -2051,29 +1952,99 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { |
return PyLong_FromLong(self->message->ByteSize()); |
} |
-PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { |
+static PyObject* RegisterExtension(PyObject* cls, |
+ PyObject* extension_handle) { |
const FieldDescriptor* descriptor = |
GetExtensionDescriptor(extension_handle); |
if (descriptor == NULL) { |
return NULL; |
} |
- if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { |
- PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", |
- cls->ob_type->tp_name); |
+ |
+ ScopedPyObjectPtr extensions_by_name( |
+ PyObject_GetAttr(cls, k_extensions_by_name)); |
+ if (extensions_by_name == NULL) { |
+ PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); |
return NULL; |
} |
- CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); |
- if (message_class == NULL) { |
+ ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); |
+ if (full_name == NULL) { |
return NULL; |
} |
+ |
// If the extension was already registered, check that it is the same. |
- const FieldDescriptor* existing_extension = |
- message_class->py_message_factory->pool->pool->FindExtensionByNumber( |
- descriptor->containing_type(), descriptor->number()); |
- if (existing_extension != NULL && existing_extension != descriptor) { |
- PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); |
+ PyObject* existing_extension = |
+ PyDict_GetItem(extensions_by_name.get(), full_name.get()); |
+ if (existing_extension != NULL) { |
+ const FieldDescriptor* existing_extension_descriptor = |
+ GetExtensionDescriptor(existing_extension); |
+ if (existing_extension_descriptor != descriptor) { |
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); |
+ return NULL; |
+ } |
+ // Nothing else to do. |
+ Py_RETURN_NONE; |
+ } |
+ |
+ if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), |
+ extension_handle) < 0) { |
+ return NULL; |
+ } |
+ |
+ // Also store a mapping from extension number to implementing class. |
+ ScopedPyObjectPtr extensions_by_number( |
+ PyObject_GetAttr(cls, k_extensions_by_number)); |
+ if (extensions_by_number == NULL) { |
+ PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); |
return NULL; |
} |
+ |
+ ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); |
+ if (number == NULL) { |
+ return NULL; |
+ } |
+ |
+ // If the extension was already registered by number, check that it is the |
+ // same. |
+ existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); |
+ if (existing_extension != NULL) { |
+ const FieldDescriptor* existing_extension_descriptor = |
+ GetExtensionDescriptor(existing_extension); |
+ if (existing_extension_descriptor != descriptor) { |
+ const Descriptor* msg_desc = GetMessageDescriptor( |
+ reinterpret_cast<PyTypeObject*>(cls)); |
+ PyErr_Format( |
+ PyExc_ValueError, |
+ "Extensions \"%s\" and \"%s\" both try to extend message type " |
+ "\"%s\" with field number %ld.", |
+ existing_extension_descriptor->full_name().c_str(), |
+ descriptor->full_name().c_str(), |
+ msg_desc->full_name().c_str(), |
+ PyInt_AsLong(number.get())); |
+ return NULL; |
+ } |
+ // Nothing else to do. |
+ Py_RETURN_NONE; |
+ } |
+ if (PyDict_SetItem(extensions_by_number.get(), number.get(), |
+ extension_handle) < 0) { |
+ return NULL; |
+ } |
+ |
+ // Check if it's a message set |
+ if (descriptor->is_extension() && |
+ descriptor->containing_type()->options().message_set_wire_format() && |
+ descriptor->type() == FieldDescriptor::TYPE_MESSAGE && |
+ descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { |
+ ScopedPyObjectPtr message_name(PyString_FromStringAndSize( |
+ descriptor->message_type()->full_name().c_str(), |
+ descriptor->message_type()->full_name().size())); |
+ if (message_name == NULL) { |
+ return NULL; |
+ } |
+ PyDict_SetItem(extensions_by_name.get(), message_name.get(), |
+ extension_handle); |
+ } |
+ |
Py_RETURN_NONE; |
} |
@@ -2110,7 +2081,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { |
static PyObject* GetExtensionDict(CMessage* self, void *closure); |
static PyObject* ListFields(CMessage* self) { |
- std::vector<const FieldDescriptor*> fields; |
+ vector<const FieldDescriptor*> fields; |
self->message->GetReflection()->ListFields(*self->message, &fields); |
// Normally, the list will be exactly the size of the fields. |
@@ -2140,8 +2111,8 @@ static PyObject* ListFields(CMessage* self) { |
// is no message class and we cannot retrieve the value. |
// TODO(amauryfa): consider building the class on the fly! |
if (fields[i]->message_type() != NULL && |
- message_factory::GetMessageClass( |
- GetFactoryForMessage(self), |
+ cdescriptor_pool::GetMessageClass( |
+ GetDescriptorPoolForMessage(self), |
fields[i]->message_type()) == NULL) { |
PyErr_Clear(); |
continue; |
@@ -2201,7 +2172,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) { |
PyObject* FindInitializationErrors(CMessage* self) { |
Message* message = self->message; |
- std::vector<string> errors; |
+ vector<string> errors; |
message->FindInitializationErrors(&errors); |
PyObject* error_list = PyList_New(errors.size()); |
@@ -2338,12 +2309,12 @@ PyObject* InternalGetScalar(const Message* message, |
PyObject* InternalGetSubMessage( |
CMessage* self, const FieldDescriptor* field_descriptor) { |
const Reflection* reflection = self->message->GetReflection(); |
- PyMessageFactory* factory = GetFactoryForMessage(self); |
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); |
const Message& sub_message = reflection->GetMessage( |
- *self->message, field_descriptor, factory->message_factory); |
+ *self->message, field_descriptor, pool->message_factory); |
- CMessageClass* message_class = message_factory::GetMessageClass( |
- factory, field_descriptor->message_type()); |
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass( |
+ pool, field_descriptor->message_type()); |
if (message_class == NULL) { |
return NULL; |
} |
@@ -2593,24 +2564,11 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) { |
return NULL; |
} |
-static PyObject* GetExtensionsByName(CMessage *self, void *closure) { |
- return message_meta::GetExtensionsByName( |
- reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); |
-} |
- |
-static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { |
- return message_meta::GetExtensionsByNumber( |
- reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); |
-} |
- |
static PyGetSetDef Getters[] = { |
{"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, |
- {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, |
- {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, |
{NULL} |
}; |
- |
static PyMethodDef Methods[] = { |
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, |
"Makes a deep copy of the class." }, |
@@ -2701,8 +2659,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { |
const Descriptor* entry_type = field_descriptor->message_type(); |
const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); |
if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- CMessageClass* value_class = message_factory::GetMessageClass( |
- GetFactoryForMessage(self), value_type->message_type()); |
+ CMessageClass* value_class = cdescriptor_pool::GetMessageClass( |
+ GetDescriptorPoolForMessage(self), value_type->message_type()); |
if (value_class == NULL) { |
return NULL; |
} |
@@ -2724,8 +2682,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { |
if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { |
PyObject* py_container = NULL; |
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { |
- CMessageClass* message_class = message_factory::GetMessageClass( |
- GetFactoryForMessage(self), field_descriptor->message_type()); |
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass( |
+ GetDescriptorPoolForMessage(self), field_descriptor->message_type()); |
if (message_class == NULL) { |
return NULL; |
} |
@@ -2820,7 +2778,7 @@ PyTypeObject CMessage_Type = { |
0, // tp_traverse |
0, // tp_clear |
(richcmpfunc)cmessage::RichCompare, // tp_richcompare |
- offsetof(CMessage, weakreflist), // tp_weaklistoffset |
+ 0, // tp_weaklistoffset |
0, // tp_iter |
0, // tp_iternext |
cmessage::Methods, // tp_methods |
@@ -2867,11 +2825,30 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { |
return cmsg->message; |
} |
+static const char module_docstring[] = |
+"python-proto2 is a module that can be used to enhance proto2 Python API\n" |
+"performance.\n" |
+"\n" |
+"It provides access to the protocol buffers C++ reflection API that\n" |
+"implements the basic protocol buffer functions."; |
+ |
void InitGlobals() { |
// TODO(gps): Check all return values in this function for NULL and propagate |
// the error (MemoryError) on up to result in an import failure. These should |
// also be freed and reset to NULL during finalization. |
+ kPythonZero = PyInt_FromLong(0); |
+ kint32min_py = PyInt_FromLong(kint32min); |
+ kint32max_py = PyInt_FromLong(kint32max); |
+ kuint32max_py = PyLong_FromLongLong(kuint32max); |
+ kint64min_py = PyLong_FromLongLong(kint64min); |
+ kint64max_py = PyLong_FromLongLong(kint64max); |
+ kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); |
+ |
kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); |
+ k_cdescriptor = PyString_FromString("_cdescriptor"); |
+ kfull_name = PyString_FromString("full_name"); |
+ k_extensions_by_name = PyString_FromString("_extensions_by_name"); |
+ k_extensions_by_number = PyString_FromString("_extensions_by_number"); |
PyObject *dummy_obj = PySet_New(NULL); |
kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); |
@@ -2889,11 +2866,6 @@ bool InitProto2MessageModule(PyObject *m) { |
return false; |
} |
- // Initialize types and globals in message_factory.cc |
- if (!InitMessageFactory()) { |
- return false; |
- } |
- |
// Initialize constants defined in this file. |
InitGlobals(); |
@@ -2911,6 +2883,25 @@ bool InitProto2MessageModule(PyObject *m) { |
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set |
// it here as well to document that subclasses need to set it. |
PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); |
+ // Subclasses with message extensions will override _extensions_by_name and |
+ // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. |
+ // All other classes can share this same immutable mapping. |
+ ScopedPyObjectPtr empty_dict(PyDict_New()); |
+ if (empty_dict == NULL) { |
+ return false; |
+ } |
+ ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); |
+ if (immutable_dict == NULL) { |
+ return false; |
+ } |
+ if (PyDict_SetItem(CMessage_Type.tp_dict, |
+ k_extensions_by_name, immutable_dict.get()) < 0) { |
+ return false; |
+ } |
+ if (PyDict_SetItem(CMessage_Type.tp_dict, |
+ k_extensions_by_number, immutable_dict.get()) < 0) { |
+ return false; |
+ } |
PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); |
@@ -2956,15 +2947,69 @@ bool InitProto2MessageModule(PyObject *m) { |
} |
// Initialize Map container types. |
- if (!InitMapContainers()) { |
- return false; |
+ { |
+ // ScalarMapContainer_Type derives from our MutableMapping type. |
+ ScopedPyObjectPtr containers(PyImport_ImportModule( |
+ "google.protobuf.internal.containers")); |
+ if (containers == NULL) { |
+ return false; |
+ } |
+ |
+ ScopedPyObjectPtr mutable_mapping( |
+ PyObject_GetAttrString(containers.get(), "MutableMapping")); |
+ if (mutable_mapping == NULL) { |
+ return false; |
+ } |
+ |
+ if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { |
+ return false; |
+ } |
+ |
+ Py_INCREF(mutable_mapping.get()); |
+#if PY_MAJOR_VERSION >= 3 |
+ PyObject* bases = PyTuple_New(1); |
+ PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); |
+ |
+ ScalarMapContainer_Type = |
+ PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases); |
+ PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type); |
+#else |
+ ScalarMapContainer_Type.tp_base = |
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); |
+ |
+ if (PyType_Ready(&ScalarMapContainer_Type) < 0) { |
+ return false; |
+ } |
+ |
+ PyModule_AddObject(m, "ScalarMapContainer", |
+ reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); |
+#endif |
+ |
+ if (PyType_Ready(&MapIterator_Type) < 0) { |
+ return false; |
+ } |
+ |
+ PyModule_AddObject(m, "MapIterator", |
+ reinterpret_cast<PyObject*>(&MapIterator_Type)); |
+ |
+ |
+#if PY_MAJOR_VERSION >= 3 |
+ MessageMapContainer_Type = |
+ PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); |
+ PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); |
+#else |
+ Py_INCREF(mutable_mapping.get()); |
+ MessageMapContainer_Type.tp_base = |
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); |
+ |
+ if (PyType_Ready(&MessageMapContainer_Type) < 0) { |
+ return false; |
+ } |
+ |
+ PyModule_AddObject(m, "MessageMapContainer", |
+ reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); |
+#endif |
} |
- PyModule_AddObject(m, "ScalarMapContainer", |
- reinterpret_cast<PyObject*>(ScalarMapContainer_Type)); |
- PyModule_AddObject(m, "MessageMapContainer", |
- reinterpret_cast<PyObject*>(MessageMapContainer_Type)); |
- PyModule_AddObject(m, "MapIterator", |
- reinterpret_cast<PyObject*>(&MapIterator_Type)); |
if (PyType_Ready(&ExtensionDict_Type) < 0) { |
return false; |
@@ -2999,10 +3044,6 @@ bool InitProto2MessageModule(PyObject *m) { |
&PyFileDescriptor_Type)); |
PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>( |
&PyOneofDescriptor_Type)); |
- PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>( |
- &PyServiceDescriptor_Type)); |
- PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>( |
- &PyMethodDescriptor_Type)); |
PyObject* enum_type_wrapper = PyImport_ImportModule( |
"google.protobuf.internal.enum_type_wrapper"); |
@@ -3040,4 +3081,53 @@ bool InitProto2MessageModule(PyObject *m) { |
} // namespace python |
} // namespace protobuf |
+static PyMethodDef ModuleMethods[] = { |
+ {"SetAllowOversizeProtos", |
+ (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, |
+ METH_O, "Enable/disable oversize proto parsing."}, |
+ { NULL, NULL} |
+}; |
+ |
+#if PY_MAJOR_VERSION >= 3 |
+static struct PyModuleDef _module = { |
+ PyModuleDef_HEAD_INIT, |
+ "_message", |
+ google::protobuf::python::module_docstring, |
+ -1, |
+ ModuleMethods, /* m_methods */ |
+ NULL, |
+ NULL, |
+ NULL, |
+ NULL |
+}; |
+#define INITFUNC PyInit__message |
+#define INITFUNC_ERRORVAL NULL |
+#else // Python 2 |
+#define INITFUNC init_message |
+#define INITFUNC_ERRORVAL |
+#endif |
+ |
+extern "C" { |
+ PyMODINIT_FUNC INITFUNC(void) { |
+ PyObject* m; |
+#if PY_MAJOR_VERSION >= 3 |
+ m = PyModule_Create(&_module); |
+#else |
+ m = Py_InitModule3("_message", ModuleMethods, |
+ google::protobuf::python::module_docstring); |
+#endif |
+ if (m == NULL) { |
+ return INITFUNC_ERRORVAL; |
+ } |
+ |
+ if (!google::protobuf::python::InitProto2MessageModule(m)) { |
+ Py_DECREF(m); |
+ return INITFUNC_ERRORVAL; |
+ } |
+ |
+#if PY_MAJOR_VERSION >= 3 |
+ return m; |
+#endif |
+ } |
+} |
} // namespace google |