Index: tools/usb_gadget/descriptor_test.py |
diff --git a/tools/usb_gadget/descriptor_test.py b/tools/usb_gadget/descriptor_test.py |
new file mode 100755 |
index 0000000000000000000000000000000000000000..895f0b6495ac803a9225b4f7111ad1287094cf00 |
--- /dev/null |
+++ b/tools/usb_gadget/descriptor_test.py |
@@ -0,0 +1,210 @@ |
+#!/usr/bin/python |
+ |
+import unittest |
+ |
+import descriptor |
+ |
+ |
+class DescriptorWithField(descriptor.Descriptor): |
+ pass |
+ |
+DescriptorWithField.AddField('bField', 'B') |
+ |
+ |
+class DescriptorWithDefault(descriptor.Descriptor): |
+ pass |
+ |
+DescriptorWithDefault.AddField('bDefault', 'B', default=42) |
+ |
+ |
+class DescriptorWithFixed(descriptor.Descriptor): |
+ pass |
+ |
+DescriptorWithFixed.AddFixedField('bFixed', 'B', 42) |
+ |
+ |
+class DescriptorWithComputed(descriptor.Descriptor): |
+ |
+ @property |
+ def foo(self): |
+ return 42 |
+ |
+DescriptorWithComputed.AddComputedField('bComputed', 'B', 'foo') |
+ |
+ |
+class DescriptorWithDescriptors(descriptor.DescriptorContainer): |
+ pass |
+ |
+DescriptorWithDescriptors.AddField('bType', 'B') |
+ |
+ |
+class DescriptorTest(unittest.TestCase): |
+ |
+ def test_default(self): |
+ obj = DescriptorWithDefault() |
+ self.assertEquals(obj.bDefault, 42) |
+ |
+ def test_change_default(self): |
+ obj = DescriptorWithDefault() |
+ obj.bDefault = 1 |
+ self.assertEquals(obj.bDefault, 1) |
+ |
+ def test_override_default(self): |
+ obj = DescriptorWithDefault(bDefault=56) |
+ self.assertEquals(obj.bDefault, 56) |
+ |
+ def test_fixed(self): |
+ obj = DescriptorWithFixed() |
+ self.assertEquals(obj.bFixed, 42) |
+ |
+ def test_set_fixed(self): |
+ with self.assertRaises(RuntimeError): |
+ DescriptorWithFixed(bFixed=1) |
+ |
+ def test_modify_fixed(self): |
+ obj = DescriptorWithFixed() |
+ with self.assertRaises(RuntimeError): |
+ obj.bFixed = 1 |
+ |
+ def test_computed(self): |
+ obj = DescriptorWithComputed() |
+ self.assertEquals(obj.bComputed, 42) |
+ |
+ def test_set_computed(self): |
+ with self.assertRaises(RuntimeError): |
+ DescriptorWithComputed(bComputed=1) |
+ |
+ def test_modify_computed(self): |
+ obj = DescriptorWithComputed() |
+ with self.assertRaises(RuntimeError): |
+ obj.bComputed = 1 |
+ |
+ def test_unexpected(self): |
+ with self.assertRaisesRegexp(TypeError, 'Unexpected'): |
+ DescriptorWithField(bUnexpected=1) |
+ |
+ def test_missing(self): |
+ with self.assertRaisesRegexp(TypeError, 'Missing'): |
+ DescriptorWithField() |
+ |
+ def test_size(self): |
+ obj = DescriptorWithField(bField=42) |
+ self.assertEquals(obj.struct_size, 1) |
+ self.assertEquals(obj.total_size, 1) |
+ |
+ def test_encode(self): |
+ obj = DescriptorWithField(bField=0xff) |
+ self.assertEquals(obj.Encode(), '\xff') |
+ |
+ def test_string(self): |
+ obj = DescriptorWithField(bField=42) |
+ string = str(obj) |
+ self.assertIn('bField', string) |
+ self.assertIn('42', string) |
+ |
+ def test_container(self): |
+ parent = DescriptorWithDescriptors(bType=0) |
+ child1 = DescriptorWithField(bField=1) |
+ parent.Add(child1) |
+ child2 = DescriptorWithField(bField=2) |
+ parent.Add(child2) |
+ self.assertEquals(parent.total_size, 3) |
+ self.assertEquals(parent.Encode(), '\x00\x01\x02') |
+ string = str(parent) |
+ self.assertIn('bType', string) |
+ self.assertIn('bField', string) |
+ |
+ |
+class TestUsbDescriptors(unittest.TestCase): |
+ |
+ def test_device_descriptor(self): |
+ device_desc = descriptor.DeviceDescriptor( |
+ idVendor=0xDEAD, |
+ idProduct=0xBEEF, |
+ bcdDevice=0x0100, |
+ bNumConfigurations=1) |
+ self.assertEquals( |
+ device_desc.Encode(), |
+ '\x12\x01\x00\x02\x00\x00\x00\x40\xAD\xDE\xEF\xBE\x00\x01\x00\x00\x00' |
+ '\x01') |
+ |
+ def test_unique_interfaces(self): |
+ interface_desc1 = descriptor.InterfaceDescriptor(bInterfaceNumber=1) |
+ interface_desc2 = descriptor.InterfaceDescriptor(bInterfaceNumber=1, |
+ bAlternateSetting=1) |
+ interface_desc3 = descriptor.InterfaceDescriptor(bInterfaceNumber=1) |
+ |
+ configuration_desc = descriptor.ConfigurationDescriptor( |
+ bmAttributes=0xC0, |
+ MaxPower=100) |
+ configuration_desc.AddInterface(interface_desc1) |
+ configuration_desc.AddInterface(interface_desc2) |
+ with self.assertRaisesRegexp(RuntimeError, r'Interface 1 \(alternate 0\)'): |
+ configuration_desc.AddInterface(interface_desc3) |
+ |
+ def test_unique_endpoints(self): |
+ endpoint_desc1 = descriptor.EndpointDescriptor( |
+ bEndpointAddress=0x01, |
+ bmAttributes=0x02, |
+ wMaxPacketSize=64, |
+ bInterval=1) |
+ endpoint_desc2 = descriptor.EndpointDescriptor( |
+ bEndpointAddress=0x81, |
+ bmAttributes=0x02, |
+ wMaxPacketSize=64, |
+ bInterval=1) |
+ endpoint_desc3 = descriptor.EndpointDescriptor( |
+ bEndpointAddress=0x01, |
+ bmAttributes=0x01, |
+ wMaxPacketSize=32, |
+ bInterval=10) |
+ |
+ interface_desc = descriptor.InterfaceDescriptor(bInterfaceNumber=1) |
+ interface_desc.AddEndpoint(endpoint_desc1) |
+ interface_desc.AddEndpoint(endpoint_desc2) |
+ with self.assertRaisesRegexp(RuntimeError, 'Endpoint 0x01 already defined'): |
+ interface_desc.AddEndpoint(endpoint_desc3) |
+ |
+ def test_configuration_descriptor(self): |
+ endpoint_desc = descriptor.EndpointDescriptor( |
+ bEndpointAddress=0x01, |
+ bmAttributes=0x02, |
+ wMaxPacketSize=64, |
+ bInterval=1) |
+ encoded_endpoint = '\x07\x05\x01\x02\x40\x00\x01' |
+ self.assertEquals(endpoint_desc.Encode(), encoded_endpoint) |
+ |
+ interface_desc = descriptor.InterfaceDescriptor(bInterfaceNumber=1) |
+ interface_desc.AddEndpoint(endpoint_desc) |
+ self.assertEquals([endpoint_desc], interface_desc.GetEndpoints()) |
+ encoded_interface = ('\x09\x04\x01\x00\x01\xFF\xFF\xFF\x00' + |
+ encoded_endpoint) |
+ self.assertEquals(interface_desc.Encode(), encoded_interface) |
+ |
+ configuration_desc = descriptor.ConfigurationDescriptor( |
+ bmAttributes=0xC0, |
+ MaxPower=100) |
+ configuration_desc.AddInterface(interface_desc) |
+ self.assertEquals([interface_desc], configuration_desc.GetInterfaces()) |
+ encoded_configuration = ('\x09\x02\x19\x00\x01\x01\x00\xC0\x64' + |
+ encoded_interface) |
+ self.assertEquals(configuration_desc.Encode(), encoded_configuration) |
+ |
+ def test_encode_hid_descriptor(self): |
+ hid_desc = descriptor.HidDescriptor() |
+ hid_desc.AddDescriptor(0x22, 0x80) |
+ hid_desc.AddDescriptor(0x23, 0x60) |
+ encoded_desc = '\x0C\x21\x11\x01\x00\x02\x22\x80\x00\x23\x60\x00' |
+ self.assertEquals(hid_desc.Encode(), encoded_desc) |
+ |
+ def test_print_hid_descriptor(self): |
+ hid_desc = descriptor.HidDescriptor() |
+ hid_desc.AddDescriptor(0x22, 0x80) |
+ hid_desc.AddDescriptor(0x23, 0x60) |
+ string = str(hid_desc) |
+ self.assertIn('0x22', string) |
+ self.assertIn('0x23', string) |
+ |
+ |
+if __name__ == '__main__': |
+ unittest.main() |