Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(37)

Unified Diff: tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/messages_test.py

Issue 1264873003: Add gsutil/third_party to telemetry/third_party/gsutilz/third_party. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Remove httplib2 Created 5 years, 5 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/messages_test.py
diff --git a/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/messages_test.py b/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/messages_test.py
new file mode 100755
index 0000000000000000000000000000000000000000..1e72b131b2918a01de7e045d0f5f59d2cb51b746
--- /dev/null
+++ b/tools/telemetry/third_party/gsutilz/third_party/protorpc/protorpc/messages_test.py
@@ -0,0 +1,2096 @@
+#!/usr/bin/env python
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for protorpc.messages."""
+import six
+
+__author__ = 'rafek@google.com (Rafe Kaplan)'
+
+
+import pickle
+import re
+import sys
+import types
+import unittest
+
+from protorpc import descriptor
+from protorpc import message_types
+from protorpc import messages
+from protorpc import test_util
+
+
+class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
+ test_util.TestCase):
+
+ MODULE = messages
+
+
+class ValidationErrorTest(test_util.TestCase):
+
+ def testStr_NoFieldName(self):
+ """Test string version of ValidationError when no name provided."""
+ self.assertEquals('Validation error',
+ str(messages.ValidationError('Validation error')))
+
+ def testStr_FieldName(self):
+ """Test string version of ValidationError when no name provided."""
+ validation_error = messages.ValidationError('Validation error')
+ validation_error.field_name = 'a_field'
+ self.assertEquals('Validation error', str(validation_error))
+
+
+class EnumTest(test_util.TestCase):
+
+ def setUp(self):
+ """Set up tests."""
+ # Redefine Color class in case so that changes to it (an error) in one test
+ # does not affect other tests.
+ global Color
+ class Color(messages.Enum):
+ RED = 20
+ ORANGE = 2
+ YELLOW = 40
+ GREEN = 4
+ BLUE = 50
+ INDIGO = 5
+ VIOLET = 80
+
+ def testNames(self):
+ """Test that names iterates over enum names."""
+ self.assertEquals(
+ set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']),
+ set(Color.names()))
+
+ def testNumbers(self):
+ """Tests that numbers iterates of enum numbers."""
+ self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers()))
+
+ def testIterate(self):
+ """Test that __iter__ iterates over all enum values."""
+ self.assertEquals(set(Color),
+ set([Color.RED,
+ Color.ORANGE,
+ Color.YELLOW,
+ Color.GREEN,
+ Color.BLUE,
+ Color.INDIGO,
+ Color.VIOLET]))
+
+ def testNaturalOrder(self):
+ """Test that natural order enumeration is in numeric order."""
+ self.assertEquals([Color.ORANGE,
+ Color.GREEN,
+ Color.INDIGO,
+ Color.RED,
+ Color.YELLOW,
+ Color.BLUE,
+ Color.VIOLET],
+ sorted(Color))
+
+ def testByName(self):
+ """Test look-up by name."""
+ self.assertEquals(Color.RED, Color.lookup_by_name('RED'))
+ self.assertRaises(KeyError, Color.lookup_by_name, 20)
+ self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)
+
+ def testByNumber(self):
+ """Test look-up by number."""
+ self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
+ self.assertEquals(Color.RED, Color.lookup_by_number(20))
+ self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)
+
+ def testConstructor(self):
+ """Test that constructor look-up by name or number."""
+ self.assertEquals(Color.RED, Color('RED'))
+ self.assertEquals(Color.RED, Color(u'RED'))
+ self.assertEquals(Color.RED, Color(20))
+ if six.PY2:
+ self.assertEquals(Color.RED, Color(long(20)))
+ self.assertEquals(Color.RED, Color(Color.RED))
+ self.assertRaises(TypeError, Color, 'Not exists')
+ self.assertRaises(TypeError, Color, 'Red')
+ self.assertRaises(TypeError, Color, 100)
+ self.assertRaises(TypeError, Color, 10.0)
+
+ def testLen(self):
+ """Test that len function works to count enums."""
+ self.assertEquals(7, len(Color))
+
+ def testNoSubclasses(self):
+ """Test that it is not possible to sub-class enum classes."""
+ def declare_subclass():
+ class MoreColor(Color):
+ pass
+ self.assertRaises(messages.EnumDefinitionError,
+ declare_subclass)
+
+ def testClassNotMutable(self):
+ """Test that enum classes themselves are not mutable."""
+ self.assertRaises(AttributeError,
+ setattr,
+ Color,
+ 'something_new',
+ 10)
+
+ def testInstancesMutable(self):
+ """Test that enum instances are not mutable."""
+ self.assertRaises(TypeError,
+ setattr,
+ Color.RED,
+ 'something_new',
+ 10)
+
+ def testDefEnum(self):
+ """Test def_enum works by building enum class from dict."""
+ WeekDay = messages.Enum.def_enum({'Monday': 1,
+ 'Tuesday': 2,
+ 'Wednesday': 3,
+ 'Thursday': 4,
+ 'Friday': 6,
+ 'Saturday': 7,
+ 'Sunday': 8},
+ 'WeekDay')
+ self.assertEquals('Wednesday', WeekDay(3).name)
+ self.assertEquals(6, WeekDay('Friday').number)
+ self.assertEquals(WeekDay.Sunday, WeekDay('Sunday'))
+
+ def testNonInt(self):
+ """Test that non-integer values rejection by enum def."""
+ self.assertRaises(messages.EnumDefinitionError,
+ messages.Enum.def_enum,
+ {'Bad': '1'},
+ 'BadEnum')
+
+ def testNegativeInt(self):
+ """Test that negative numbers rejection by enum def."""
+ self.assertRaises(messages.EnumDefinitionError,
+ messages.Enum.def_enum,
+ {'Bad': -1},
+ 'BadEnum')
+
+ def testLowerBound(self):
+ """Test that zero is accepted by enum def."""
+ class NotImportant(messages.Enum):
+ """Testing for value zero"""
+ VALUE = 0
+
+ self.assertEquals(0, int(NotImportant.VALUE))
+
+ def testTooLargeInt(self):
+ """Test that numbers too large are rejected."""
+ self.assertRaises(messages.EnumDefinitionError,
+ messages.Enum.def_enum,
+ {'Bad': (2 ** 29)},
+ 'BadEnum')
+
+ def testRepeatedInt(self):
+ """Test duplicated numbers are forbidden."""
+ self.assertRaises(messages.EnumDefinitionError,
+ messages.Enum.def_enum,
+ {'Ok': 1, 'Repeated': 1},
+ 'BadEnum')
+
+ def testStr(self):
+ """Test converting to string."""
+ self.assertEquals('RED', str(Color.RED))
+ self.assertEquals('ORANGE', str(Color.ORANGE))
+
+ def testInt(self):
+ """Test converting to int."""
+ self.assertEquals(20, int(Color.RED))
+ self.assertEquals(2, int(Color.ORANGE))
+
+ def testRepr(self):
+ """Test enum representation."""
+ self.assertEquals('Color(RED, 20)', repr(Color.RED))
+ self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW))
+
+ def testDocstring(self):
+ """Test that docstring is supported ok."""
+ class NotImportant(messages.Enum):
+ """I have a docstring."""
+
+ VALUE1 = 1
+
+ self.assertEquals('I have a docstring.', NotImportant.__doc__)
+
+ def testDeleteEnumValue(self):
+ """Test that enum values cannot be deleted."""
+ self.assertRaises(TypeError, delattr, Color, 'RED')
+
+ def testEnumName(self):
+ """Test enum name."""
+ module_name = test_util.get_module_name(EnumTest)
+ self.assertEquals('%s.Color' % module_name, Color.definition_name())
+ self.assertEquals(module_name, Color.outer_definition_name())
+ self.assertEquals(module_name, Color.definition_package())
+
+ def testDefinitionName_OverrideModule(self):
+ """Test enum module is overriden by module package name."""
+ global package
+ try:
+ package = 'my.package'
+ self.assertEquals('my.package.Color', Color.definition_name())
+ self.assertEquals('my.package', Color.outer_definition_name())
+ self.assertEquals('my.package', Color.definition_package())
+ finally:
+ del package
+
+ def testDefinitionName_NoModule(self):
+ """Test what happens when there is no module for enum."""
+ class Enum1(messages.Enum):
+ pass
+
+ original_modules = sys.modules
+ sys.modules = dict(sys.modules)
+ try:
+ del sys.modules[__name__]
+ self.assertEquals('Enum1', Enum1.definition_name())
+ self.assertEquals(None, Enum1.outer_definition_name())
+ self.assertEquals(None, Enum1.definition_package())
+ self.assertEquals(six.text_type, type(Enum1.definition_name()))
+ finally:
+ sys.modules = original_modules
+
+ def testDefinitionName_Nested(self):
+ """Test nested Enum names."""
+ class MyMessage(messages.Message):
+
+ class NestedEnum(messages.Enum):
+
+ pass
+
+ class NestedMessage(messages.Message):
+
+ class NestedEnum(messages.Enum):
+
+ pass
+
+ module_name = test_util.get_module_name(EnumTest)
+ self.assertEquals('%s.MyMessage.NestedEnum' % module_name,
+ MyMessage.NestedEnum.definition_name())
+ self.assertEquals('%s.MyMessage' % module_name,
+ MyMessage.NestedEnum.outer_definition_name())
+ self.assertEquals(module_name,
+ MyMessage.NestedEnum.definition_package())
+
+ self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name,
+ MyMessage.NestedMessage.NestedEnum.definition_name())
+ self.assertEquals(
+ '%s.MyMessage.NestedMessage' % module_name,
+ MyMessage.NestedMessage.NestedEnum.outer_definition_name())
+ self.assertEquals(module_name,
+ MyMessage.NestedMessage.NestedEnum.definition_package())
+
+ def testMessageDefinition(self):
+ """Test that enumeration knows its enclosing message definition."""
+ class OuterEnum(messages.Enum):
+ pass
+
+ self.assertEquals(None, OuterEnum.message_definition())
+
+ class OuterMessage(messages.Message):
+
+ class InnerEnum(messages.Enum):
+ pass
+
+ self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition())
+
+ def testComparison(self):
+ """Test comparing various enums to different types."""
+ class Enum1(messages.Enum):
+ VAL1 = 1
+ VAL2 = 2
+
+ class Enum2(messages.Enum):
+ VAL1 = 1
+
+ self.assertEquals(Enum1.VAL1, Enum1.VAL1)
+ self.assertNotEquals(Enum1.VAL1, Enum1.VAL2)
+ self.assertNotEquals(Enum1.VAL1, Enum2.VAL1)
+ self.assertNotEquals(Enum1.VAL1, 'VAL1')
+ self.assertNotEquals(Enum1.VAL1, 1)
+ self.assertNotEquals(Enum1.VAL1, 2)
+ self.assertNotEquals(Enum1.VAL1, None)
+ self.assertNotEquals(Enum1.VAL1, Enum2.VAL1)
+
+ self.assertTrue(Enum1.VAL1 < Enum1.VAL2)
+ self.assertTrue(Enum1.VAL2 > Enum1.VAL1)
+
+ self.assertNotEquals(1, Enum2.VAL1)
+
+ def testPickle(self):
+ """Testing pickling and unpickling of Enum instances."""
+ colors = list(Color)
+ unpickled = pickle.loads(pickle.dumps(colors))
+ self.assertEquals(colors, unpickled)
+ # Unpickling shouldn't create new enum instances.
+ for i, color in enumerate(colors):
+ self.assertTrue(color is unpickled[i])
+
+
+class FieldListTest(test_util.TestCase):
+
+ def setUp(self):
+ self.integer_field = messages.IntegerField(1, repeated=True)
+
+ def testConstructor(self):
+ self.assertEquals([1, 2, 3],
+ messages.FieldList(self.integer_field, [1, 2, 3]))
+ self.assertEquals([1, 2, 3],
+ messages.FieldList(self.integer_field, (1, 2, 3)))
+ self.assertEquals([], messages.FieldList(self.integer_field, []))
+
+ def testNone(self):
+ self.assertRaises(TypeError, messages.FieldList, self.integer_field, None)
+
+ def testDoNotAutoConvertString(self):
+ string_field = messages.StringField(1, repeated=True)
+ self.assertRaises(messages.ValidationError,
+ messages.FieldList, string_field, 'abc')
+
+ def testConstructorCopies(self):
+ a_list = [1, 3, 6]
+ field_list = messages.FieldList(self.integer_field, a_list)
+ self.assertFalse(a_list is field_list)
+ self.assertFalse(field_list is
+ messages.FieldList(self.integer_field, field_list))
+
+ def testNonRepeatedField(self):
+ self.assertRaisesWithRegexpMatch(
+ messages.FieldDefinitionError,
+ 'FieldList may only accept repeated fields',
+ messages.FieldList,
+ messages.IntegerField(1),
+ [])
+
+ def testConstructor_InvalidValues(self):
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ re.escape("Expected type %r "
+ "for IntegerField, found 1 (type %r)"
+ % (six.integer_types, str)),
+ messages.FieldList, self.integer_field, ["1", "2", "3"])
+
+ def testConstructor_Scalars(self):
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ "IntegerField is repeated. Found: 3",
+ messages.FieldList, self.integer_field, 3)
+
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object",
+ messages.FieldList, self.integer_field, iter([1, 2, 3]))
+
+ def testSetSlice(self):
+ field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
+ field_list[1:3] = [10, 20]
+ self.assertEquals([1, 10, 20, 4, 5], field_list)
+
+ def testSetSlice_InvalidValues(self):
+ field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
+
+ def setslice():
+ field_list[1:3] = ['10', '20']
+
+ msg_re = re.escape("Expected type %r "
+ "for IntegerField, found 10 (type %r)"
+ % (six.integer_types, str))
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ msg_re,
+ setslice)
+
+ def testSetItem(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+ field_list[0] = 10
+ self.assertEquals([10], field_list)
+
+ def testSetItem_InvalidValues(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+
+ def setitem():
+ field_list[0] = '10'
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ re.escape("Expected type %r "
+ "for IntegerField, found 10 (type %r)"
+ % (six.integer_types, str)),
+ setitem)
+
+ def testAppend(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+ field_list.append(10)
+ self.assertEquals([2, 10], field_list)
+
+ def testAppend_InvalidValues(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+ field_list.name = 'a_field'
+
+ def append():
+ field_list.append('10')
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ re.escape("Expected type %r "
+ "for IntegerField, found 10 (type %r)"
+ % (six.integer_types, str)),
+ append)
+
+ def testExtend(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+ field_list.extend([10])
+ self.assertEquals([2, 10], field_list)
+
+ def testExtend_InvalidValues(self):
+ field_list = messages.FieldList(self.integer_field, [2])
+
+ def extend():
+ field_list.extend(['10'])
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ re.escape("Expected type %r "
+ "for IntegerField, found 10 (type %r)"
+ % (six.integer_types, str)),
+ extend)
+
+ def testInsert(self):
+ field_list = messages.FieldList(self.integer_field, [2, 3])
+ field_list.insert(1, 10)
+ self.assertEquals([2, 10, 3], field_list)
+
+ def testInsert_InvalidValues(self):
+ field_list = messages.FieldList(self.integer_field, [2, 3])
+
+ def insert():
+ field_list.insert(1, '10')
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ re.escape("Expected type %r "
+ "for IntegerField, found 10 (type %r)"
+ % (six.integer_types, str)),
+ insert)
+
+ def testPickle(self):
+ """Testing pickling and unpickling of disconnected FieldList instances."""
+ field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
+ unpickled = pickle.loads(pickle.dumps(field_list))
+ self.assertEquals(field_list, unpickled)
+ self.assertIsInstance(unpickled.field, messages.IntegerField)
+ self.assertEquals(1, unpickled.field.number)
+ self.assertTrue(unpickled.field.repeated)
+
+
+class FieldTest(test_util.TestCase):
+
+ def ActionOnAllFieldClasses(self, action):
+ """Test all field classes except Message and Enum.
+
+ Message and Enum require separate tests.
+
+ Args:
+ action: Callable that takes the field class as a parameter.
+ """
+ for field_class in (messages.IntegerField,
+ messages.FloatField,
+ messages.BooleanField,
+ messages.BytesField,
+ messages.StringField,
+ ):
+ action(field_class)
+
+ def testNumberAttribute(self):
+ """Test setting the number attribute."""
+ def action(field_class):
+ # Check range.
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ 0)
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ -1)
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ messages.MAX_FIELD_NUMBER + 1)
+
+ # Check reserved.
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ messages.FIRST_RESERVED_FIELD_NUMBER)
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ messages.LAST_RESERVED_FIELD_NUMBER)
+ self.assertRaises(messages.InvalidNumberError,
+ field_class,
+ '1')
+
+ # This one should work.
+ field_class(number=1)
+ self.ActionOnAllFieldClasses(action)
+
+ def testRequiredAndRepeated(self):
+ """Test setting the required and repeated fields."""
+ def action(field_class):
+ field_class(1, required=True)
+ field_class(1, repeated=True)
+ self.assertRaises(messages.FieldDefinitionError,
+ field_class,
+ 1,
+ required=True,
+ repeated=True)
+ self.ActionOnAllFieldClasses(action)
+
+ def testInvalidVariant(self):
+ """Test field with invalid variants."""
+ def action(field_class):
+ if field_class is not message_types.DateTimeField:
+ self.assertRaises(messages.InvalidVariantError,
+ field_class,
+ 1,
+ variant=messages.Variant.ENUM)
+ self.ActionOnAllFieldClasses(action)
+
+ def testDefaultVariant(self):
+ """Test that default variant is used when not set."""
+ def action(field_class):
+ field = field_class(1)
+ self.assertEquals(field_class.DEFAULT_VARIANT, field.variant)
+
+ self.ActionOnAllFieldClasses(action)
+
+ def testAlternateVariant(self):
+ """Test that default variant is used when not set."""
+ field = messages.IntegerField(1, variant=messages.Variant.UINT32)
+ self.assertEquals(messages.Variant.UINT32, field.variant)
+
+ def testDefaultFields_Single(self):
+ """Test default field is correct type (single)."""
+ defaults = {messages.IntegerField: 10,
+ messages.FloatField: 1.5,
+ messages.BooleanField: False,
+ messages.BytesField: b'abc',
+ messages.StringField: u'abc',
+ }
+
+ def action(field_class):
+ field_class(1, default=defaults[field_class])
+ self.ActionOnAllFieldClasses(action)
+
+ # Run defaults test again checking for str/unicode compatiblity.
+ defaults[messages.StringField] = 'abc'
+ self.ActionOnAllFieldClasses(action)
+
+ def testStringField_BadUnicodeInDefault(self):
+ """Test binary values in string field."""
+ self.assertRaisesWithRegexpMatch(
+ messages.InvalidDefaultError,
+ r"Invalid default value for StringField:.*: "
+ r"Field encountered non-ASCII string .*: "
+ r"'ascii' codec can't decode byte 0x89 in position 0: "
+ r"ordinal not in range",
+ messages.StringField, 1, default=b'\x89')
+
+ def testDefaultFields_InvalidSingle(self):
+ """Test default field is correct type (invalid single)."""
+ def action(field_class):
+ self.assertRaises(messages.InvalidDefaultError,
+ field_class,
+ 1,
+ default=object())
+ self.ActionOnAllFieldClasses(action)
+
+ def testDefaultFields_InvalidRepeated(self):
+ """Test default field does not accept defaults."""
+ self.assertRaisesWithRegexpMatch(
+ messages.FieldDefinitionError,
+ 'Repeated fields may not have defaults',
+ messages.StringField, 1, repeated=True, default=[1, 2, 3])
+
+ def testDefaultFields_None(self):
+ """Test none is always acceptable."""
+ def action(field_class):
+ field_class(1, default=None)
+ field_class(1, required=True, default=None)
+ field_class(1, repeated=True, default=None)
+ self.ActionOnAllFieldClasses(action)
+
+ def testDefaultFields_Enum(self):
+ """Test the default for enum fields."""
+ class Symbol(messages.Enum):
+
+ ALPHA = 1
+ BETA = 2
+ GAMMA = 3
+
+ field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
+
+ self.assertEquals(Symbol.ALPHA, field.default)
+
+ def testDefaultFields_EnumStringDelayedResolution(self):
+ """Test that enum fields resolve default strings."""
+ field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
+ 1,
+ default='OPTIONAL')
+
+ self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default)
+
+ def testDefaultFields_EnumIntDelayedResolution(self):
+ """Test that enum fields resolve default integers."""
+ field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
+ 1,
+ default=2)
+
+ self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default)
+
+ def testDefaultFields_EnumOkIfTypeKnown(self):
+ """Test that enum fields accept valid default values when type is known."""
+ field = messages.EnumField(descriptor.FieldDescriptor.Label,
+ 1,
+ default='REPEATED')
+
+ self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default)
+
+ def testDefaultFields_EnumForceCheckIfTypeKnown(self):
+ """Test that enum fields validate default values if type is known."""
+ self.assertRaisesWithRegexpMatch(TypeError,
+ 'No such value for NOT_A_LABEL in '
+ 'Enum Label',
+ messages.EnumField,
+ descriptor.FieldDescriptor.Label,
+ 1,
+ default='NOT_A_LABEL')
+
+ def testDefaultFields_EnumInvalidDelayedResolution(self):
+ """Test that enum fields raise errors upon delayed resolution error."""
+ field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
+ 1,
+ default=200)
+
+ self.assertRaisesWithRegexpMatch(TypeError,
+ 'No such value for 200 in Enum Label',
+ getattr,
+ field,
+ 'default')
+
+ def testValidate_Valid(self):
+ """Test validation of valid values."""
+ values = {messages.IntegerField: 10,
+ messages.FloatField: 1.5,
+ messages.BooleanField: False,
+ messages.BytesField: b'abc',
+ messages.StringField: u'abc',
+ }
+ def action(field_class):
+ # Optional.
+ field = field_class(1)
+ field.validate(values[field_class])
+
+ # Required.
+ field = field_class(1, required=True)
+ field.validate(values[field_class])
+
+ # Repeated.
+ field = field_class(1, repeated=True)
+ field.validate([])
+ field.validate(())
+ field.validate([values[field_class]])
+ field.validate((values[field_class],))
+
+ # Right value, but not repeated.
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ values[field_class])
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ values[field_class])
+
+ self.ActionOnAllFieldClasses(action)
+
+ def testValidate_Invalid(self):
+ """Test validation of valid values."""
+ values = {messages.IntegerField: "10",
+ messages.FloatField: 1,
+ messages.BooleanField: 0,
+ messages.BytesField: 10.20,
+ messages.StringField: 42,
+ }
+ def action(field_class):
+ # Optional.
+ field = field_class(1)
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ values[field_class])
+
+ # Required.
+ field = field_class(1, required=True)
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ values[field_class])
+
+ # Repeated.
+ field = field_class(1, repeated=True)
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ [values[field_class]])
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ (values[field_class],))
+ self.ActionOnAllFieldClasses(action)
+
+ def testValidate_None(self):
+ """Test that None is valid for non-required fields."""
+ def action(field_class):
+ # Optional.
+ field = field_class(1)
+ field.validate(None)
+
+ # Required.
+ field = field_class(1, required=True)
+ self.assertRaisesWithRegexpMatch(messages.ValidationError,
+ 'Required field is missing',
+ field.validate,
+ None)
+
+ # Repeated.
+ field = field_class(1, repeated=True)
+ field.validate(None)
+ self.assertRaisesWithRegexpMatch(messages.ValidationError,
+ 'Repeated values for %s may '
+ 'not be None' % field_class.__name__,
+ field.validate,
+ [None])
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ (None,))
+ self.ActionOnAllFieldClasses(action)
+
+ def testValidateElement(self):
+ """Test validation of valid values."""
+ values = {messages.IntegerField: 10,
+ messages.FloatField: 1.5,
+ messages.BooleanField: False,
+ messages.BytesField: 'abc',
+ messages.StringField: u'abc',
+ }
+ def action(field_class):
+ # Optional.
+ field = field_class(1)
+ field.validate_element(values[field_class])
+
+ # Required.
+ field = field_class(1, required=True)
+ field.validate_element(values[field_class])
+
+ # Repeated.
+ field = field_class(1, repeated=True)
+ self.assertRaises(message.VAlidationError,
+ field.validate_element,
+ [])
+ self.assertRaises(message.VAlidationError,
+ field.validate_element,
+ ())
+ field.validate_element(values[field_class])
+ field.validate_element(values[field_class])
+
+ # Right value, but repeated.
+ self.assertRaises(messages.ValidationError,
+ field.validate_element,
+ [values[field_class]])
+ self.assertRaises(messages.ValidationError,
+ field.validate_element,
+ (values[field_class],))
+
+ def testReadOnly(self):
+ """Test that objects are all read-only."""
+ def action(field_class):
+ field = field_class(10)
+ self.assertRaises(AttributeError,
+ setattr,
+ field,
+ 'number',
+ 20)
+ self.assertRaises(AttributeError,
+ setattr,
+ field,
+ 'anything_else',
+ 'whatever')
+ self.ActionOnAllFieldClasses(action)
+
+ def testMessageField(self):
+ """Test the construction of message fields."""
+ self.assertRaises(messages.FieldDefinitionError,
+ messages.MessageField,
+ str,
+ 10)
+
+ self.assertRaises(messages.FieldDefinitionError,
+ messages.MessageField,
+ messages.Message,
+ 10)
+
+ class MyMessage(messages.Message):
+ pass
+
+ field = messages.MessageField(MyMessage, 10)
+ self.assertEquals(MyMessage, field.type)
+
+ def testMessageField_ForwardReference(self):
+ """Test the construction of forward reference message fields."""
+ global MyMessage
+ global ForwardMessage
+ try:
+ class MyMessage(messages.Message):
+
+ self_reference = messages.MessageField('MyMessage', 1)
+ forward = messages.MessageField('ForwardMessage', 2)
+ nested = messages.MessageField('ForwardMessage.NestedMessage', 3)
+ inner = messages.MessageField('Inner', 4)
+
+ class Inner(messages.Message):
+
+ sibling = messages.MessageField('Sibling', 1)
+
+ class Sibling(messages.Message):
+
+ pass
+
+ class ForwardMessage(messages.Message):
+
+ class NestedMessage(messages.Message):
+
+ pass
+
+ self.assertEquals(MyMessage,
+ MyMessage.field_by_name('self_reference').type)
+
+ self.assertEquals(ForwardMessage,
+ MyMessage.field_by_name('forward').type)
+
+ self.assertEquals(ForwardMessage.NestedMessage,
+ MyMessage.field_by_name('nested').type)
+
+ self.assertEquals(MyMessage.Inner,
+ MyMessage.field_by_name('inner').type)
+
+ self.assertEquals(MyMessage.Sibling,
+ MyMessage.Inner.field_by_name('sibling').type)
+ finally:
+ try:
+ del MyMessage
+ del ForwardMessage
+ except:
+ pass
+
+ def testMessageField_WrongType(self):
+ """Test that forward referencing the wrong type raises an error."""
+ global AnEnum
+ try:
+ class AnEnum(messages.Enum):
+ pass
+
+ class AnotherMessage(messages.Message):
+
+ a_field = messages.MessageField('AnEnum', 1)
+
+ self.assertRaises(messages.FieldDefinitionError,
+ getattr,
+ AnotherMessage.field_by_name('a_field'),
+ 'type')
+ finally:
+ del AnEnum
+
+ def testMessageFieldValidate(self):
+ """Test validation on message field."""
+ class MyMessage(messages.Message):
+ pass
+
+ class AnotherMessage(messages.Message):
+ pass
+
+ field = messages.MessageField(MyMessage, 10)
+ field.validate(MyMessage())
+
+ self.assertRaises(messages.ValidationError,
+ field.validate,
+ AnotherMessage())
+
+ def testMessageFieldMessageType(self):
+ """Test message_type property."""
+ class MyMessage(messages.Message):
+ pass
+
+ class HasMessage(messages.Message):
+ field = messages.MessageField(MyMessage, 1)
+
+ self.assertEqual(HasMessage.field.type, HasMessage.field.message_type)
+
+ def testMessageFieldValueFromMessage(self):
+ class MyMessage(messages.Message):
+ pass
+
+ class HasMessage(messages.Message):
+ field = messages.MessageField(MyMessage, 1)
+
+ instance = MyMessage()
+
+ self.assertTrue(instance is HasMessage.field.value_from_message(instance))
+
+ def testMessageFieldValueFromMessageWrongType(self):
+ class MyMessage(messages.Message):
+ pass
+
+ class HasMessage(messages.Message):
+ field = messages.MessageField(MyMessage, 1)
+
+ self.assertRaisesWithRegexpMatch(
+ messages.DecodeError,
+ 'Expected type MyMessage, got int: 10',
+ HasMessage.field.value_from_message, 10)
+
+ def testMessageFieldValueToMessage(self):
+ class MyMessage(messages.Message):
+ pass
+
+ class HasMessage(messages.Message):
+ field = messages.MessageField(MyMessage, 1)
+
+ instance = MyMessage()
+
+ self.assertTrue(instance is HasMessage.field.value_to_message(instance))
+
+ def testMessageFieldValueToMessageWrongType(self):
+ class MyMessage(messages.Message):
+ pass
+
+ class MyOtherMessage(messages.Message):
+ pass
+
+ class HasMessage(messages.Message):
+ field = messages.MessageField(MyMessage, 1)
+
+ instance = MyOtherMessage()
+
+ self.assertRaisesWithRegexpMatch(
+ messages.EncodeError,
+ 'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
+ HasMessage.field.value_to_message, instance)
+
+ def testIntegerField_AllowLong(self):
+ """Test that the integer field allows for longs."""
+ if six.PY2:
+ messages.IntegerField(10, default=long(10))
+
+ def testMessageFieldValidate_Initialized(self):
+ """Test validation on message field."""
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1, required=True)
+
+ field = messages.MessageField(MyMessage, 10)
+
+ # Will validate messages where is_initialized() is False.
+ message = MyMessage()
+ field.validate(message)
+ message.field1 = 20
+ field.validate(message)
+
+ def testEnumField(self):
+ """Test the construction of enum fields."""
+ self.assertRaises(messages.FieldDefinitionError,
+ messages.EnumField,
+ str,
+ 10)
+
+ self.assertRaises(messages.FieldDefinitionError,
+ messages.EnumField,
+ messages.Enum,
+ 10)
+
+ class Color(messages.Enum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+ field = messages.EnumField(Color, 10)
+ self.assertEquals(Color, field.type)
+
+ class Another(messages.Enum):
+ VALUE = 1
+
+ self.assertRaises(messages.InvalidDefaultError,
+ messages.EnumField,
+ Color,
+ 10,
+ default=Another.VALUE)
+
+ def testEnumField_ForwardReference(self):
+ """Test the construction of forward reference enum fields."""
+ global MyMessage
+ global ForwardEnum
+ global ForwardMessage
+ try:
+ class MyMessage(messages.Message):
+
+ forward = messages.EnumField('ForwardEnum', 1)
+ nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
+ inner = messages.EnumField('Inner', 3)
+
+ class Inner(messages.Enum):
+ pass
+
+ class ForwardEnum(messages.Enum):
+ pass
+
+ class ForwardMessage(messages.Message):
+
+ class NestedEnum(messages.Enum):
+ pass
+
+ self.assertEquals(ForwardEnum,
+ MyMessage.field_by_name('forward').type)
+
+ self.assertEquals(ForwardMessage.NestedEnum,
+ MyMessage.field_by_name('nested').type)
+
+ self.assertEquals(MyMessage.Inner,
+ MyMessage.field_by_name('inner').type)
+ finally:
+ try:
+ del MyMessage
+ del ForwardEnum
+ del ForwardMessage
+ except:
+ pass
+
+ def testEnumField_WrongType(self):
+ """Test that forward referencing the wrong type raises an error."""
+ global AMessage
+ try:
+ class AMessage(messages.Message):
+ pass
+
+ class AnotherMessage(messages.Message):
+
+ a_field = messages.EnumField('AMessage', 1)
+
+ self.assertRaises(messages.FieldDefinitionError,
+ getattr,
+ AnotherMessage.field_by_name('a_field'),
+ 'type')
+ finally:
+ del AMessage
+
+ def testMessageDefinition(self):
+ """Test that message definition is set on fields."""
+ class MyMessage(messages.Message):
+
+ my_field = messages.StringField(1)
+
+ self.assertEquals(MyMessage,
+ MyMessage.field_by_name('my_field').message_definition())
+
+ def testNoneAssignment(self):
+ """Test that assigning None does not change comparison."""
+ class MyMessage(messages.Message):
+
+ my_field = messages.StringField(1)
+
+ m1 = MyMessage()
+ m2 = MyMessage()
+ m2.my_field = None
+ self.assertEquals(m1, m2)
+
+ def testNonAsciiStr(self):
+ """Test validation fails for non-ascii StringField values."""
+ class Thing(messages.Message):
+ string_field = messages.StringField(2)
+
+ thing = Thing()
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ 'Field string_field encountered non-ASCII string',
+ setattr, thing, 'string_field', test_util.BINARY)
+
+
+class MessageTest(test_util.TestCase):
+ """Tests for message class."""
+
+ def CreateMessageClass(self):
+ """Creates a simple message class with 3 fields.
+
+ Fields are defined in alphabetical order but with conflicting numeric
+ order.
+ """
+ class ComplexMessage(messages.Message):
+ a3 = messages.IntegerField(3)
+ b1 = messages.StringField(1)
+ c2 = messages.StringField(2)
+
+ return ComplexMessage
+
+ def testSameNumbers(self):
+ """Test that cannot assign two fields with same numbers."""
+
+ def action():
+ class BadMessage(messages.Message):
+ f1 = messages.IntegerField(1)
+ f2 = messages.IntegerField(1)
+ self.assertRaises(messages.DuplicateNumberError,
+ action)
+
+ def testStrictAssignment(self):
+ """Tests that cannot assign to unknown or non-reserved attributes."""
+ class SimpleMessage(messages.Message):
+ field = messages.IntegerField(1)
+
+ simple_message = SimpleMessage()
+ self.assertRaises(AttributeError,
+ setattr,
+ simple_message,
+ 'does_not_exist',
+ 10)
+
+ def testListAssignmentDoesNotCopy(self):
+ class SimpleMessage(messages.Message):
+ repeated = messages.IntegerField(1, repeated=True)
+
+ message = SimpleMessage()
+ original = message.repeated
+ message.repeated = []
+ self.assertFalse(original is message.repeated)
+
+ def testValidate_Optional(self):
+ """Tests validation of optional fields."""
+ class SimpleMessage(messages.Message):
+ non_required = messages.IntegerField(1)
+
+ simple_message = SimpleMessage()
+ simple_message.check_initialized()
+ simple_message.non_required = 10
+ simple_message.check_initialized()
+
+ def testValidate_Required(self):
+ """Tests validation of required fields."""
+ class SimpleMessage(messages.Message):
+ required = messages.IntegerField(1, required=True)
+
+ simple_message = SimpleMessage()
+ self.assertRaises(messages.ValidationError,
+ simple_message.check_initialized)
+ simple_message.required = 10
+ simple_message.check_initialized()
+
+ def testValidate_Repeated(self):
+ """Tests validation of repeated fields."""
+ class SimpleMessage(messages.Message):
+ repeated = messages.IntegerField(1, repeated=True)
+
+ simple_message = SimpleMessage()
+
+ # Check valid values.
+ for valid_value in [], [10], [10, 20], (), (10,), (10, 20):
+ simple_message.repeated = valid_value
+ simple_message.check_initialized()
+
+ # Check cleared.
+ simple_message.repeated = []
+ simple_message.check_initialized()
+
+ # Check invalid values.
+ for invalid_value in 10, ['10', '20'], [None], (None,):
+ self.assertRaises(messages.ValidationError,
+ setattr, simple_message, 'repeated', invalid_value)
+
+ def testIsInitialized(self):
+ """Tests is_initialized."""
+ class SimpleMessage(messages.Message):
+ required = messages.IntegerField(1, required=True)
+
+ simple_message = SimpleMessage()
+ self.assertFalse(simple_message.is_initialized())
+
+ simple_message.required = 10
+
+ self.assertTrue(simple_message.is_initialized())
+
+ def testIsInitializedNestedField(self):
+ """Tests is_initialized for nested fields."""
+ class SimpleMessage(messages.Message):
+ required = messages.IntegerField(1, required=True)
+
+ class NestedMessage(messages.Message):
+ simple = messages.MessageField(SimpleMessage, 1)
+
+ simple_message = SimpleMessage()
+ self.assertFalse(simple_message.is_initialized())
+ nested_message = NestedMessage(simple=simple_message)
+ self.assertFalse(nested_message.is_initialized())
+
+ simple_message.required = 10
+
+ self.assertTrue(simple_message.is_initialized())
+ self.assertTrue(nested_message.is_initialized())
+
+ def testInitializeNestedFieldFromDict(self):
+ """Tests initializing nested fields from dict."""
+ class SimpleMessage(messages.Message):
+ required = messages.IntegerField(1, required=True)
+
+ class NestedMessage(messages.Message):
+ simple = messages.MessageField(SimpleMessage, 1)
+
+ class RepeatedMessage(messages.Message):
+ simple = messages.MessageField(SimpleMessage, 1, repeated=True)
+
+ nested_message1 = NestedMessage(simple={'required': 10})
+ self.assertTrue(nested_message1.is_initialized())
+ self.assertTrue(nested_message1.simple.is_initialized())
+
+ nested_message2 = NestedMessage()
+ nested_message2.simple = {'required': 10}
+ self.assertTrue(nested_message2.is_initialized())
+ self.assertTrue(nested_message2.simple.is_initialized())
+
+ repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]
+
+ repeated_message1 = RepeatedMessage(simple=repeated_values)
+ self.assertEquals(3, len(repeated_message1.simple))
+ self.assertFalse(repeated_message1.is_initialized())
+
+ repeated_message1.simple[0].required = 0
+ self.assertTrue(repeated_message1.is_initialized())
+
+ repeated_message2 = RepeatedMessage()
+ repeated_message2.simple = repeated_values
+ self.assertEquals(3, len(repeated_message2.simple))
+ self.assertFalse(repeated_message2.is_initialized())
+
+ repeated_message2.simple[0].required = 0
+ self.assertTrue(repeated_message2.is_initialized())
+
+ def testNestedMethodsNotAllowed(self):
+ """Test that method definitions on Message classes are not allowed."""
+ def action():
+ class WithMethods(messages.Message):
+ def not_allowed(self):
+ pass
+
+ self.assertRaises(messages.MessageDefinitionError,
+ action)
+
+ def testNestedAttributesNotAllowed(self):
+ """Test that attribute assignment on Message classes are not allowed."""
+ def int_attribute():
+ class WithMethods(messages.Message):
+ not_allowed = 1
+
+ def string_attribute():
+ class WithMethods(messages.Message):
+ not_allowed = 'not allowed'
+
+ def enum_attribute():
+ class WithMethods(messages.Message):
+ not_allowed = Color.RED
+
+ for action in (int_attribute, string_attribute, enum_attribute):
+ self.assertRaises(messages.MessageDefinitionError,
+ action)
+
+ def testNameIsSetOnFields(self):
+ """Make sure name is set on fields after Message class init."""
+ class HasNamedFields(messages.Message):
+ field = messages.StringField(1)
+
+ self.assertEquals('field', HasNamedFields.field_by_number(1).name)
+
+ def testSubclassingMessageDisallowed(self):
+ """Not permitted to create sub-classes of message classes."""
+ class SuperClass(messages.Message):
+ pass
+
+ def action():
+ class SubClass(SuperClass):
+ pass
+
+ self.assertRaises(messages.MessageDefinitionError,
+ action)
+
+ def testAllFields(self):
+ """Test all_fields method."""
+ ComplexMessage = self.CreateMessageClass()
+ fields = list(ComplexMessage.all_fields())
+
+ # Order does not matter, so sort now.
+ fields = sorted(fields, key=lambda f: f.name)
+
+ self.assertEquals(3, len(fields))
+ self.assertEquals('a3', fields[0].name)
+ self.assertEquals('b1', fields[1].name)
+ self.assertEquals('c2', fields[2].name)
+
+ def testFieldByName(self):
+ """Test getting field by name."""
+ ComplexMessage = self.CreateMessageClass()
+
+ self.assertEquals(3, ComplexMessage.field_by_name('a3').number)
+ self.assertEquals(1, ComplexMessage.field_by_name('b1').number)
+ self.assertEquals(2, ComplexMessage.field_by_name('c2').number)
+
+ self.assertRaises(KeyError,
+ ComplexMessage.field_by_name,
+ 'unknown')
+
+ def testFieldByNumber(self):
+ """Test getting field by number."""
+ ComplexMessage = self.CreateMessageClass()
+
+ self.assertEquals('a3', ComplexMessage.field_by_number(3).name)
+ self.assertEquals('b1', ComplexMessage.field_by_number(1).name)
+ self.assertEquals('c2', ComplexMessage.field_by_number(2).name)
+
+ self.assertRaises(KeyError,
+ ComplexMessage.field_by_number,
+ 4)
+
+ def testGetAssignedValue(self):
+ """Test getting the assigned value of a field."""
+ class SomeMessage(messages.Message):
+ a_value = messages.StringField(1, default=u'a default')
+
+ message = SomeMessage()
+ self.assertEquals(None, message.get_assigned_value('a_value'))
+
+ message.a_value = u'a string'
+ self.assertEquals(u'a string', message.get_assigned_value('a_value'))
+
+ message.a_value = u'a default'
+ self.assertEquals(u'a default', message.get_assigned_value('a_value'))
+
+ self.assertRaisesWithRegexpMatch(
+ AttributeError,
+ 'Message SomeMessage has no field no_such_field',
+ message.get_assigned_value,
+ 'no_such_field')
+
+ def testReset(self):
+ """Test resetting a field value."""
+ class SomeMessage(messages.Message):
+ a_value = messages.StringField(1, default=u'a default')
+ repeated = messages.IntegerField(2, repeated=True)
+
+ message = SomeMessage()
+
+ self.assertRaises(AttributeError, message.reset, 'unknown')
+
+ self.assertEquals(u'a default', message.a_value)
+ message.reset('a_value')
+ self.assertEquals(u'a default', message.a_value)
+
+ message.a_value = u'a new value'
+ self.assertEquals(u'a new value', message.a_value)
+ message.reset('a_value')
+ self.assertEquals(u'a default', message.a_value)
+
+ message.repeated = [1, 2, 3]
+ self.assertEquals([1, 2, 3], message.repeated)
+ saved = message.repeated
+ message.reset('repeated')
+ self.assertEquals([], message.repeated)
+ self.assertIsInstance(message.repeated, messages.FieldList)
+ self.assertEquals([1, 2, 3], saved)
+
+ def testAllowNestedEnums(self):
+ """Test allowing nested enums in a message definition."""
+ class Trade(messages.Message):
+ class Duration(messages.Enum):
+ GTC = 1
+ DAY = 2
+
+ class Currency(messages.Enum):
+ USD = 1
+ GBP = 2
+ INR = 3
+
+ # Sorted by name order seems to be the only feasible option.
+ self.assertEquals(['Currency', 'Duration'], Trade.__enums__)
+
+ # Message definition will now be set on Enumerated objects.
+ self.assertEquals(Trade, Trade.Duration.message_definition())
+
+ def testAllowNestedMessages(self):
+ """Test allowing nested messages in a message definition."""
+ class Trade(messages.Message):
+ class Lot(messages.Message):
+ pass
+
+ class Agent(messages.Message):
+ pass
+
+ # Sorted by name order seems to be the only feasible option.
+ self.assertEquals(['Agent', 'Lot'], Trade.__messages__)
+ self.assertEquals(Trade, Trade.Agent.message_definition())
+ self.assertEquals(Trade, Trade.Lot.message_definition())
+
+ # But not Message itself.
+ def action():
+ class Trade(messages.Message):
+ NiceTry = messages.Message
+ self.assertRaises(messages.MessageDefinitionError, action)
+
+ def testDisallowClassAssignments(self):
+ """Test setting class attributes may not happen."""
+ class MyMessage(messages.Message):
+ pass
+
+ self.assertRaises(AttributeError,
+ setattr,
+ MyMessage,
+ 'x',
+ 'do not assign')
+
+ def testEquality(self):
+ """Test message class equality."""
+ # Comparison against enums must work.
+ class MyEnum(messages.Enum):
+ val1 = 1
+ val2 = 2
+
+ # Comparisons against nested messages must work.
+ class AnotherMessage(messages.Message):
+ string = messages.StringField(1)
+
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1)
+ field2 = messages.EnumField(MyEnum, 2)
+ field3 = messages.MessageField(AnotherMessage, 3)
+
+ message1 = MyMessage()
+
+ self.assertNotEquals('hi', message1)
+ self.assertNotEquals(AnotherMessage(), message1)
+ self.assertEquals(message1, message1)
+
+ message2 = MyMessage()
+
+ self.assertEquals(message1, message2)
+
+ message1.field1 = 10
+ self.assertNotEquals(message1, message2)
+
+ message2.field1 = 20
+ self.assertNotEquals(message1, message2)
+
+ message2.field1 = 10
+ self.assertEquals(message1, message2)
+
+ message1.field2 = MyEnum.val1
+ self.assertNotEquals(message1, message2)
+
+ message2.field2 = MyEnum.val2
+ self.assertNotEquals(message1, message2)
+
+ message2.field2 = MyEnum.val1
+ self.assertEquals(message1, message2)
+
+ message1.field3 = AnotherMessage()
+ message1.field3.string = 'value1'
+ self.assertNotEquals(message1, message2)
+
+ message2.field3 = AnotherMessage()
+ message2.field3.string = 'value2'
+ self.assertNotEquals(message1, message2)
+
+ message2.field3.string = 'value1'
+ self.assertEquals(message1, message2)
+
+ def testEqualityWithUnknowns(self):
+ """Test message class equality with unknown fields."""
+
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1)
+
+ message1 = MyMessage()
+ message2 = MyMessage()
+ self.assertEquals(message1, message2)
+ message1.set_unrecognized_field('unknown1', 'value1',
+ messages.Variant.STRING)
+ self.assertEquals(message1, message2)
+
+ message1.set_unrecognized_field('unknown2', ['asdf', 3],
+ messages.Variant.STRING)
+ message1.set_unrecognized_field('unknown3', 4.7,
+ messages.Variant.DOUBLE)
+ self.assertEquals(message1, message2)
+
+ def testUnrecognizedFieldInvalidVariant(self):
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1)
+
+ message1 = MyMessage()
+ self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
+ {'unhandled': 'type'}, None)
+ self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
+ {'unhandled': 'type'}, 123)
+
+ def testRepr(self):
+ """Test represtation of Message object."""
+ class MyMessage(messages.Message):
+ integer_value = messages.IntegerField(1)
+ string_value = messages.StringField(2)
+ unassigned = messages.StringField(3)
+ unassigned_with_default = messages.StringField(4, default=u'a default')
+
+ my_message = MyMessage()
+ my_message.integer_value = 42
+ my_message.string_value = u'A string'
+
+ pat = re.compile(r"<MyMessage\n integer_value: 42\n"
+ " string_value: [u]?'A string'>")
+ self.assertTrue(pat.match(repr(my_message)) is not None)
+
+ def testValidation(self):
+ """Test validation of message values."""
+ # Test optional.
+ class SubMessage(messages.Message):
+ pass
+
+ class Message(messages.Message):
+ val = messages.MessageField(SubMessage, 1)
+
+ message = Message()
+
+ message_field = messages.MessageField(Message, 1)
+ message_field.validate(message)
+ message.val = SubMessage()
+ message_field.validate(message)
+ self.assertRaises(messages.ValidationError,
+ setattr, message, 'val', [SubMessage()])
+
+ # Test required.
+ class Message(messages.Message):
+ val = messages.MessageField(SubMessage, 1, required=True)
+
+ message = Message()
+
+ message_field = messages.MessageField(Message, 1)
+ message_field.validate(message)
+ message.val = SubMessage()
+ message_field.validate(message)
+ self.assertRaises(messages.ValidationError,
+ setattr, message, 'val', [SubMessage()])
+
+ # Test repeated.
+ class Message(messages.Message):
+ val = messages.MessageField(SubMessage, 1, repeated=True)
+
+ message = Message()
+
+ message_field = messages.MessageField(Message, 1)
+ message_field.validate(message)
+ self.assertRaisesWithRegexpMatch(
+ messages.ValidationError,
+ "Field val is repeated. Found: <SubMessage>",
+ setattr, message, 'val', SubMessage())
+ message.val = [SubMessage()]
+ message_field.validate(message)
+
+ def testDefinitionName(self):
+ """Test message name."""
+ class MyMessage(messages.Message):
+ pass
+
+ module_name = test_util.get_module_name(FieldTest)
+ self.assertEquals('%s.MyMessage' % module_name,
+ MyMessage.definition_name())
+ self.assertEquals(module_name, MyMessage.outer_definition_name())
+ self.assertEquals(module_name, MyMessage.definition_package())
+
+ self.assertEquals(six.text_type, type(MyMessage.definition_name()))
+ self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
+ self.assertEquals(six.text_type, type(MyMessage.definition_package()))
+
+ def testDefinitionName_OverrideModule(self):
+ """Test message module is overriden by module package name."""
+ class MyMessage(messages.Message):
+ pass
+
+ global package
+ package = 'my.package'
+
+ try:
+ self.assertEquals('my.package.MyMessage', MyMessage.definition_name())
+ self.assertEquals('my.package', MyMessage.outer_definition_name())
+ self.assertEquals('my.package', MyMessage.definition_package())
+
+ self.assertEquals(six.text_type, type(MyMessage.definition_name()))
+ self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
+ self.assertEquals(six.text_type, type(MyMessage.definition_package()))
+ finally:
+ del package
+
+ def testDefinitionName_NoModule(self):
+ """Test what happens when there is no module for message."""
+ class MyMessage(messages.Message):
+ pass
+
+ original_modules = sys.modules
+ sys.modules = dict(sys.modules)
+ try:
+ del sys.modules[__name__]
+ self.assertEquals('MyMessage', MyMessage.definition_name())
+ self.assertEquals(None, MyMessage.outer_definition_name())
+ self.assertEquals(None, MyMessage.definition_package())
+
+ self.assertEquals(six.text_type, type(MyMessage.definition_name()))
+ finally:
+ sys.modules = original_modules
+
+ def testDefinitionName_Nested(self):
+ """Test nested message names."""
+ class MyMessage(messages.Message):
+
+ class NestedMessage(messages.Message):
+
+ class NestedMessage(messages.Message):
+
+ pass
+
+ module_name = test_util.get_module_name(MessageTest)
+ self.assertEquals('%s.MyMessage.NestedMessage' % module_name,
+ MyMessage.NestedMessage.definition_name())
+ self.assertEquals('%s.MyMessage' % module_name,
+ MyMessage.NestedMessage.outer_definition_name())
+ self.assertEquals(module_name,
+ MyMessage.NestedMessage.definition_package())
+
+ self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name,
+ MyMessage.NestedMessage.NestedMessage.definition_name())
+ self.assertEquals(
+ '%s.MyMessage.NestedMessage' % module_name,
+ MyMessage.NestedMessage.NestedMessage.outer_definition_name())
+ self.assertEquals(
+ module_name,
+ MyMessage.NestedMessage.NestedMessage.definition_package())
+
+
+ def testMessageDefinition(self):
+ """Test that enumeration knows its enclosing message definition."""
+ class OuterMessage(messages.Message):
+
+ class InnerMessage(messages.Message):
+ pass
+
+ self.assertEquals(None, OuterMessage.message_definition())
+ self.assertEquals(OuterMessage,
+ OuterMessage.InnerMessage.message_definition())
+
+ def testConstructorKwargs(self):
+ """Test kwargs via constructor."""
+ class SomeMessage(messages.Message):
+ name = messages.StringField(1)
+ number = messages.IntegerField(2)
+
+ expected = SomeMessage()
+ expected.name = 'my name'
+ expected.number = 200
+ self.assertEquals(expected, SomeMessage(name='my name', number=200))
+
+ def testConstructorNotAField(self):
+ """Test kwargs via constructor with wrong names."""
+ class SomeMessage(messages.Message):
+ pass
+
+ self.assertRaisesWithRegexpMatch(
+ AttributeError,
+ 'May not assign arbitrary value does_not_exist to message SomeMessage',
+ SomeMessage,
+ does_not_exist=10)
+
+ def testGetUnsetRepeatedValue(self):
+ class SomeMessage(messages.Message):
+ repeated = messages.IntegerField(1, repeated=True)
+
+ instance = SomeMessage()
+ self.assertEquals([], instance.repeated)
+ self.assertTrue(isinstance(instance.repeated, messages.FieldList))
+
+ def testCompareAutoInitializedRepeatedFields(self):
+ class SomeMessage(messages.Message):
+ repeated = messages.IntegerField(1, repeated=True)
+
+ message1 = SomeMessage(repeated=[])
+ message2 = SomeMessage()
+ self.assertEquals(message1, message2)
+
+ def testUnknownValues(self):
+ """Test message class equality with unknown fields."""
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1)
+
+ message = MyMessage()
+ self.assertEquals([], message.all_unrecognized_fields())
+ self.assertEquals((None, None),
+ message.get_unrecognized_field_info('doesntexist'))
+ self.assertEquals((None, None),
+ message.get_unrecognized_field_info(
+ 'doesntexist', None, None))
+ self.assertEquals(('defaultvalue', 'defaultwire'),
+ message.get_unrecognized_field_info(
+ 'doesntexist', 'defaultvalue', 'defaultwire'))
+ self.assertEquals((3, None),
+ message.get_unrecognized_field_info(
+ 'doesntexist', value_default=3))
+
+ message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
+ self.assertEquals(1, len(message.all_unrecognized_fields()))
+ self.assertTrue('exists' in message.all_unrecognized_fields())
+ self.assertEquals((9.5, messages.Variant.DOUBLE),
+ message.get_unrecognized_field_info('exists'))
+ self.assertEquals((9.5, messages.Variant.DOUBLE),
+ message.get_unrecognized_field_info('exists', 'type',
+ 1234))
+ self.assertEquals((1234, None),
+ message.get_unrecognized_field_info('doesntexist', 1234))
+
+ message.set_unrecognized_field('another', 'value', messages.Variant.STRING)
+ self.assertEquals(2, len(message.all_unrecognized_fields()))
+ self.assertTrue('exists' in message.all_unrecognized_fields())
+ self.assertTrue('another' in message.all_unrecognized_fields())
+ self.assertEquals((9.5, messages.Variant.DOUBLE),
+ message.get_unrecognized_field_info('exists'))
+ self.assertEquals(('value', messages.Variant.STRING),
+ message.get_unrecognized_field_info('another'))
+
+ message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
+ messages.Variant.STRING)
+ self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
+ message.get_unrecognized_field_info('typetest1'))
+ message.set_unrecognized_field('typetest2', '', messages.Variant.STRING)
+ self.assertEquals(('', messages.Variant.STRING),
+ message.get_unrecognized_field_info('typetest2'))
+
+ def testPickle(self):
+ """Testing pickling and unpickling of Message instances."""
+ global MyEnum
+ global AnotherMessage
+ global MyMessage
+
+ class MyEnum(messages.Enum):
+ val1 = 1
+ val2 = 2
+
+ class AnotherMessage(messages.Message):
+ string = messages.StringField(1, repeated=True)
+
+ class MyMessage(messages.Message):
+ field1 = messages.IntegerField(1)
+ field2 = messages.EnumField(MyEnum, 2)
+ field3 = messages.MessageField(AnotherMessage, 3)
+
+ message = MyMessage(field1=1, field2=MyEnum.val2,
+ field3=AnotherMessage(string=['a', 'b', 'c']))
+ message.set_unrecognized_field('exists', 'value', messages.Variant.STRING)
+ message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
+ messages.Variant.STRING)
+ unpickled = pickle.loads(pickle.dumps(message))
+ self.assertEquals(message, unpickled)
+ self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)
+ self.assertTrue('exists' in message.all_unrecognized_fields())
+ self.assertEquals(('value', messages.Variant.STRING),
+ message.get_unrecognized_field_info('exists'))
+ self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
+ message.get_unrecognized_field_info('repeated'))
+
+
+class FindDefinitionTest(test_util.TestCase):
+ """Test finding definitions relative to various definitions and modules."""
+
+ def setUp(self):
+ """Set up module-space. Starts off empty."""
+ self.modules = {}
+
+ def DefineModule(self, name):
+ """Define a module and its parents in module space.
+
+ Modules that are already defined in self.modules are not re-created.
+
+ Args:
+ name: Fully qualified name of modules to create.
+
+ Returns:
+ Deepest nested module. For example:
+
+ DefineModule('a.b.c') # Returns c.
+ """
+ name_path = name.split('.')
+ full_path = []
+ for node in name_path:
+ full_path.append(node)
+ full_name = '.'.join(full_path)
+ self.modules.setdefault(full_name, types.ModuleType(full_name))
+ return self.modules[name]
+
+ def DefineMessage(self, module, name, children={}, add_to_module=True):
+ """Define a new Message class in the context of a module.
+
+ Used for easily describing complex Message hierarchy. Message is defined
+ including all child definitions.
+
+ Args:
+ module: Fully qualified name of module to place Message class in.
+ name: Name of Message to define within module.
+ children: Define any level of nesting of children definitions. To define
+ a message, map the name to another dictionary. The dictionary can
+ itself contain additional definitions, and so on. To map to an Enum,
+ define the Enum class separately and map it by name.
+ add_to_module: If True, new Message class is added to module. If False,
+ new Message is not added.
+ """
+ # Make sure module exists.
+ module_instance = self.DefineModule(module)
+
+ # Recursively define all child messages.
+ for attribute, value in children.items():
+ if isinstance(value, dict):
+ children[attribute] = self.DefineMessage(
+ module, attribute, value, False)
+
+ # Override default __module__ variable.
+ children['__module__'] = module
+
+ # Instantiate and possibly add to module.
+ message_class = type(name, (messages.Message,), dict(children))
+ if add_to_module:
+ setattr(module_instance, name, message_class)
+ return message_class
+
+ def Importer(self, module, globals='', locals='', fromlist=None):
+ """Importer function.
+
+ Acts like __import__. Only loads modules from self.modules. Does not
+ try to load real modules defined elsewhere. Does not try to handle relative
+ imports.
+
+ Args:
+ module: Fully qualified name of module to load from self.modules.
+ """
+ if fromlist is None:
+ module = module.split('.')[0]
+ try:
+ return self.modules[module]
+ except KeyError:
+ raise ImportError()
+
+ def testNoSuchModule(self):
+ """Test searching for definitions that do no exist."""
+ self.assertRaises(messages.DefinitionNotFoundError,
+ messages.find_definition,
+ 'does.not.exist',
+ importer=self.Importer)
+
+ def testRefersToModule(self):
+ """Test that referring to a module does not return that module."""
+ self.DefineModule('i.am.a.module')
+ self.assertRaises(messages.DefinitionNotFoundError,
+ messages.find_definition,
+ 'i.am.a.module',
+ importer=self.Importer)
+
+ def testNoDefinition(self):
+ """Test not finding a definition in an existing module."""
+ self.DefineModule('i.am.a.module')
+ self.assertRaises(messages.DefinitionNotFoundError,
+ messages.find_definition,
+ 'i.am.a.module.MyMessage',
+ importer=self.Importer)
+
+ def testNotADefinition(self):
+ """Test trying to fetch something that is not a definition."""
+ module = self.DefineModule('i.am.a.module')
+ setattr(module, 'A', 'a string')
+ self.assertRaises(messages.DefinitionNotFoundError,
+ messages.find_definition,
+ 'i.am.a.module.A',
+ importer=self.Importer)
+
+ def testGlobalFind(self):
+ """Test finding definitions from fully qualified module names."""
+ A = self.DefineMessage('a.b.c', 'A', {})
+ self.assertEquals(A, messages.find_definition('a.b.c.A',
+ importer=self.Importer))
+ B = self.DefineMessage('a.b.c', 'B', {'C':{}})
+ self.assertEquals(B.C, messages.find_definition('a.b.c.B.C',
+ importer=self.Importer))
+
+ def testRelativeToModule(self):
+ """Test finding definitions relative to modules."""
+ # Define modules.
+ a = self.DefineModule('a')
+ b = self.DefineModule('a.b')
+ c = self.DefineModule('a.b.c')
+
+ # Define messages.
+ A = self.DefineMessage('a', 'A')
+ B = self.DefineMessage('a.b', 'B')
+ C = self.DefineMessage('a.b.c', 'C')
+ D = self.DefineMessage('a.b.d', 'D')
+
+ # Find A, B, C and D relative to a.
+ self.assertEquals(A, messages.find_definition(
+ 'A', a, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'b.B', a, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'b.c.C', a, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'b.d.D', a, importer=self.Importer))
+
+ # Find A, B, C and D relative to b.
+ self.assertEquals(A, messages.find_definition(
+ 'A', b, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'B', b, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'c.C', b, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'd.D', b, importer=self.Importer))
+
+ # Find A, B, C and D relative to c. Module d is the same case as c.
+ self.assertEquals(A, messages.find_definition(
+ 'A', c, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'B', c, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'C', c, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'd.D', c, importer=self.Importer))
+
+ def testRelativeToMessages(self):
+ """Test finding definitions relative to Message definitions."""
+ A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
+ B = A.B
+ C = A.B.C
+ D = A.B.D
+
+ # Find relative to A.
+ self.assertEquals(A, messages.find_definition(
+ 'A', A, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'B', A, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'B.C', A, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'B.D', A, importer=self.Importer))
+
+ # Find relative to B.
+ self.assertEquals(A, messages.find_definition(
+ 'A', B, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'B', B, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'C', B, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'D', B, importer=self.Importer))
+
+ # Find relative to C.
+ self.assertEquals(A, messages.find_definition(
+ 'A', C, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'B', C, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'C', C, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'D', C, importer=self.Importer))
+
+ # Find relative to C searching from c.
+ self.assertEquals(A, messages.find_definition(
+ 'b.A', C, importer=self.Importer))
+ self.assertEquals(B, messages.find_definition(
+ 'b.A.B', C, importer=self.Importer))
+ self.assertEquals(C, messages.find_definition(
+ 'b.A.B.C', C, importer=self.Importer))
+ self.assertEquals(D, messages.find_definition(
+ 'b.A.B.D', C, importer=self.Importer))
+
+ def testAbsoluteReference(self):
+ """Test finding absolute definition names."""
+ # Define modules.
+ a = self.DefineModule('a')
+ b = self.DefineModule('a.a')
+
+ # Define messages.
+ aA = self.DefineMessage('a', 'A')
+ aaA = self.DefineMessage('a.a', 'A')
+
+ # Always find a.A.
+ self.assertEquals(aA, messages.find_definition('.a.A', None,
+ importer=self.Importer))
+ self.assertEquals(aA, messages.find_definition('.a.A', a,
+ importer=self.Importer))
+ self.assertEquals(aA, messages.find_definition('.a.A', aA,
+ importer=self.Importer))
+ self.assertEquals(aA, messages.find_definition('.a.A', aaA,
+ importer=self.Importer))
+
+ def testFindEnum(self):
+ """Test that Enums are found."""
+ class Color(messages.Enum):
+ pass
+ A = self.DefineMessage('a', 'A', {'Color': Color})
+
+ self.assertEquals(
+ Color,
+ messages.find_definition('Color', A, importer=self.Importer))
+
+ def testFalseScope(self):
+ """Test that Message definitions nested in strange objects are hidden."""
+ global X
+ class X(object):
+ class A(messages.Message):
+ pass
+
+ self.assertRaises(TypeError, messages.find_definition, 'A', X)
+ self.assertRaises(messages.DefinitionNotFoundError,
+ messages.find_definition,
+ 'X.A', sys.modules[__name__])
+
+ def testSearchAttributeFirst(self):
+ """Make sure not faked out by module, but continues searching."""
+ A = self.DefineMessage('a', 'A')
+ module_A = self.DefineModule('a.A')
+
+ self.assertEquals(A, messages.find_definition(
+ 'a.A', None, importer=self.Importer))
+
+
+class FindDefinitionUnicodeTests(test_util.TestCase):
+
+ # TODO(craigcitro): Fix this test and re-enable it.
+ def notatestUnicodeString(self):
+ """Test using unicode names."""
+ from protorpc import registry
+ self.assertEquals('ServiceMapping',
+ messages.find_definition(
+ u'protorpc.registry.ServiceMapping',
+ None).__name__)
+
+
+def main():
+ unittest.main()
+
+
+if __name__ == '__main__':
+ main()

Powered by Google App Engine
This is Rietveld 408576698