| OLD | NEW |
| (Empty) | |
| 1 # Copyright 2014 Google Inc. All Rights Reserved. |
| 2 # |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 # you may not use this file except in compliance with the License. |
| 5 # You may obtain a copy of the License at |
| 6 # |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 # |
| 9 # Unless required by applicable law or agreed to in writing, software |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 # See the License for the specific language governing permissions and |
| 13 # limitations under the License. |
| 14 """Common code for converting proto to other formats, such as JSON.""" |
| 15 |
| 16 import base64 |
| 17 import collections |
| 18 import json |
| 19 |
| 20 |
| 21 from gslib.third_party.protorpc import messages |
| 22 from gslib.third_party.protorpc import protojson |
| 23 |
| 24 from gslib.third_party.storage_apitools import exceptions |
| 25 |
# Names exported by a star-import of this module.
__all__ = [
    'CopyProtoMessage',
    'JsonToMessage', 'MessageToJson',
    'DictToMessage', 'MessageToDict',
    'PyValueToMessage', 'MessageToPyValue',
]
| 35 |
| 36 |
# An encoder/decoder pair stored in the registries below.
_Codec = collections.namedtuple('_Codec', 'encoder decoder')
# What a field codec returns: the (possibly transformed) value, plus a flag
# saying whether processing of that value is finished.
CodecResult = collections.namedtuple('CodecResult', 'value complete')


# TODO: Make these non-global.
# Message class -> name of the field holding its unrecognized fields.
_UNRECOGNIZED_FIELD_MAPPINGS = dict()
# Message class -> _Codec.
_CUSTOM_MESSAGE_CODECS = dict()
# Field instance -> _Codec.
_CUSTOM_FIELD_CODECS = dict()
# Field class -> _Codec.
_FIELD_TYPE_CODECS = dict()
| 46 |
| 47 |
def MapUnrecognizedFields(field_name):
  """Build a class decorator that maps unrecognized fields to field_name.

  Args:
    field_name: name of the field on the decorated message class that will
        act as the container for unrecognized fields.

  Returns:
    A decorator that records the mapping for the class it receives and then
    returns that class unchanged.
  """
  def Register(cls):
    _UNRECOGNIZED_FIELD_MAPPINGS.update({cls: field_name})
    return cls
  return Register
| 54 |
| 55 |
def RegisterCustomMessageCodec(encoder, decoder):
  """Build a class decorator that installs a codec for a message class.

  Args:
    encoder: callable stored as the codec's encoder.
    decoder: callable stored as the codec's decoder.

  Returns:
    A decorator that registers the codec under the class it receives and
    returns that class unchanged.
  """
  codec = _Codec(encoder=encoder, decoder=decoder)

  def Register(cls):
    _CUSTOM_MESSAGE_CODECS[cls] = codec
    return cls
  return Register
| 62 |
| 63 |
def RegisterCustomFieldCodec(encoder, decoder):
  """Build a decorator that installs a codec for one specific field.

  Args:
    encoder: callable stored as the codec's encoder.
    decoder: callable stored as the codec's decoder.

  Returns:
    A decorator that registers the codec under the field it receives and
    returns that field unchanged.
  """
  codec = _Codec(encoder=encoder, decoder=decoder)

  def Register(field):
    _CUSTOM_FIELD_CODECS[field] = codec
    return field
  return Register
| 70 |
| 71 |
def RegisterFieldTypeCodec(encoder, decoder):
  """Build a decorator that installs a codec for every field of a type.

  Args:
    encoder: callable stored as the codec's encoder.
    decoder: callable stored as the codec's decoder.

  Returns:
    A decorator that registers the codec under the field type it receives
    and returns that field type unchanged.
  """
  codec = _Codec(encoder=encoder, decoder=decoder)

  def Register(field_type):
    _FIELD_TYPE_CODECS[field_type] = codec
    return field_type
  return Register
| 78 |
| 79 |
# TODO: Delete this function with the switch to proto2.
def CopyProtoMessage(message):
  """Copy a message by encoding it with ProtoJson and decoding the result.

  Args:
    message: the message to copy.

  Returns:
    A new message of the same type as message, decoded from its encoding.
  """
  codec = protojson.ProtoJson()
  encoded = codec.encode_message(message)
  return codec.decode_message(type(message), encoded)
| 84 |
| 85 |
def MessageToJson(message, include_fields=None):
  """Serialize message to a JSON string.

  Args:
    message: the message to encode.
    include_fields: optional iterable of field names that are written into
        the output with a null value; None leaves the encoding untouched.

  Returns:
    The JSON encoding of message.
  """
  encoder = _ProtoJsonApiTools.Get()
  encoded = encoder.encode_message(message)
  return _IncludeFields(encoded, message, include_fields)
| 90 |
| 91 |
def JsonToMessage(message_type, message):
  """Parse the JSON string `message` into an instance of message_type.

  Args:
    message_type: the message class to decode into.
    message: the JSON-encoded message.

  Returns:
    The decoded message.
  """
  decoder = _ProtoJsonApiTools.Get()
  return decoder.decode_message(message_type, message)
| 95 |
| 96 |
# TODO: Do this directly, instead of via JSON.
def DictToMessage(d, message_type):
  """Build a message_type instance from the dictionary d.

  The dictionary is serialized to JSON and then decoded as a message.

  Args:
    d: a JSON-serializable dictionary.
    message_type: the message class to decode into.

  Returns:
    The decoded message.
  """
  serialized = json.dumps(d)
  return JsonToMessage(message_type, serialized)
| 101 |
| 102 |
def MessageToDict(message):
  """Turn message into a dictionary by parsing its JSON encoding.

  Args:
    message: the message to convert.

  Returns:
    The result of json.loads on the message's JSON encoding.
  """
  encoded = MessageToJson(message)
  return json.loads(encoded)
| 106 |
| 107 |
def PyValueToMessage(message_type, value):
  """Build a message_type instance from a plain python value.

  The value is serialized to JSON and then decoded as a message.

  Args:
    message_type: the message class to decode into.
    value: a JSON-serializable python value.

  Returns:
    The decoded message.
  """
  serialized = json.dumps(value)
  return JsonToMessage(message_type, serialized)
| 111 |
| 112 |
def MessageToPyValue(message):
  """Turn message into a plain python value by parsing its JSON encoding.

  Args:
    message: the message to convert.

  Returns:
    The result of json.loads on the message's JSON encoding.
  """
  encoded = MessageToJson(message)
  return json.loads(encoded)
| 116 |
| 117 |
def _IncludeFields(encoded_message, message, include_fields):
  """Force the requested fields to appear in the encoded message.

  Args:
    encoded_message: JSON string produced by encoding message.
    message: the message that was encoded; used to validate field names.
    include_fields: iterable of field names to write into the output with a
        null value, or None to return encoded_message unchanged.

  Returns:
    encoded_message itself when include_fields is None; otherwise a
    re-serialized JSON string with each requested field set to null.

  Raises:
    exceptions.InvalidDataError: if a requested name is not a field of
        message.
  """
  if include_fields is None:
    return encoded_message
  payload = json.loads(encoded_message)
  for requested in include_fields:
    try:
      # Called purely for validation: an unknown name raises KeyError.
      message.field_by_name(requested)
    except KeyError:
      raise exceptions.InvalidDataError(
          'No field named %s in message of type %s' % (
              requested, type(message)))
    else:
      payload[requested] = None
  return json.dumps(payload)
| 132 |
| 133 |
def _GetFieldCodecs(field, attr):
  """Collect the registered codec callables that apply to field.

  Args:
    field: the field being encoded or decoded.
    attr: which side of the codec to fetch, e.g. 'encoder' or 'decoder'.

  Returns:
    A list holding the codec registered for this exact field (if any)
    followed by the codec registered for the field's type (if any).
  """
  registered = (
      _CUSTOM_FIELD_CODECS.get(field),
      _FIELD_TYPE_CODECS.get(type(field)),
  )
  handlers = []
  for codec in registered:
    # A missing registration is None, for which getattr falls back to None.
    handler = getattr(codec, attr, None)
    if handler is not None:
      handlers.append(handler)
  return handlers
| 140 |
| 141 |
class _ProtoJsonApiTools(protojson.ProtoJson):
  """ProtoJson subclass that applies the codecs registered in this module."""

  # Lazily created shared instance; see Get().
  _INSTANCE = None

  @classmethod
  def Get(cls):
    """Return the shared instance, creating it on first use."""
    instance = cls._INSTANCE
    if instance is None:
      instance = cls._INSTANCE = cls()
    return instance

  def decode_message(self, message_type, encoded_message):  # pylint: disable=invalid-name
    """Decode encoded_message into an instance of message_type.

    A custom codec registered for message_type takes over completely;
    otherwise the base decoder runs and unknown fields are remapped.
    """
    custom = _CUSTOM_MESSAGE_CODECS.get(message_type)
    if custom is not None:
      return custom.decoder(encoded_message)
    decoded = super(_ProtoJsonApiTools, self).decode_message(
        message_type, encoded_message)
    return _DecodeUnknownFields(decoded, encoded_message)

  def decode_field(self, field, value):  # pylint: disable=g-bad-name
    """Convert a parsed JSON value into something assignable to field.

    Args:
      field: a messages.Field for the field we're decoding.
      value: a python value we'd like to decode.

    Returns:
      A value suitable for assignment to field.
    """
    # Each registered decoder may transform the value; one that reports
    # completion short-circuits everything that follows.
    for decoder in _GetFieldCodecs(field, 'decoder'):
      outcome = decoder(field, value)
      value = outcome.value
      if outcome.complete:
        return value
    if isinstance(field, messages.MessageField):
      # Nested messages are re-serialized and decoded by this codec.
      return self.decode_message(field.message_type, json.dumps(value))
    return super(_ProtoJsonApiTools, self).decode_field(field, value)

  def encode_message(self, message):  # pylint: disable=invalid-name
    """Encode message (or a FieldList of messages) as a JSON string."""
    if isinstance(message, messages.FieldList):
      encoded_items = [self.encode_message(item) for item in message]
      return '[%s]' % (', '.join(encoded_items))
    custom = _CUSTOM_MESSAGE_CODECS.get(type(message))
    if custom is not None:
      return custom.encoder(message)
    remapped = _EncodeUnknownFields(message)
    return super(_ProtoJsonApiTools, self).encode_message(remapped)

  def encode_field(self, field, value):  # pylint: disable=g-bad-name
    """Convert a field value into something json.dumps can serialize.

    Args:
      field: a messages.Field for the field we're encoding.
      value: a value for field.

    Returns:
      A python value suitable for json.dumps.
    """
    # Each registered encoder may transform the value; one that reports
    # completion short-circuits everything that follows.
    for encoder in _GetFieldCodecs(field, 'encoder'):
      outcome = encoder(field, value)
      value = outcome.value
      if outcome.complete:
        return value
    if isinstance(field, messages.MessageField):
      # Nested messages are encoded by this codec, then parsed back to
      # plain python data before the base class handles them.
      value = json.loads(self.encode_message(value))
    return super(_ProtoJsonApiTools, self).encode_field(field, value)
| 206 |
| 207 |
# TODO: Fold this and _IncludeFields in as codecs.
def _DecodeUnknownFields(message, encoded_message):
  """Move the unknown fields of message into its registered container field.

  Args:
    message: a freshly decoded message.
    encoded_message: the JSON string message was decoded from.

  Returns:
    message itself, modified in place when its type has a registered
    unrecognized-fields mapping and untouched otherwise.

  Raises:
    exceptions.InvalidDataFromServerError: if the mapped field is not a
        message field.
  """
  destination = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
  if destination is None:
    return message
  container_field = message.field_by_name(destination)
  if not isinstance(container_field, messages.MessageField):
    raise exceptions.InvalidDataFromServerError(
        'Unrecognized fields must be mapped to a compound '
        'message type.')
  pair_type = container_field.message_type
  # TODO: Add more error checking around the pair
  # type being exactly what we suspect (field names, etc).
  if not isinstance(pair_type.value, messages.MessageField):
    collected = _DecodeUnrecognizedFields(message, pair_type)
  else:
    collected = _DecodeUnknownMessages(
        message, json.loads(encoded_message), pair_type)
  setattr(message, destination, collected)
  # We could probably get away with not setting this, but
  # why not clear it?
  setattr(message, '_Message__unrecognized_fields', {})
  return message
| 232 |
| 233 |
def _DecodeUnknownMessages(message, encoded_message, pair_type):
  """Build key/value pairs for the unknown message-valued entries.

  Args:
    message: the decoded message; its declared field names are skipped.
    encoded_message: the parsed JSON object (a dict, despite the name) that
        message was decoded from.
    pair_type: the pair message class; its `value` field supplies the
        message type each unknown entry is converted to.

  Returns:
    A list of pair_type instances, one per key of encoded_message that is
    not a declared field of message.
  """
  field_type = pair_type.value.type
  # A set gives constant-time membership tests in the filter below.
  known_names = set(x.name for x in message.all_fields())
  # .items() rather than .iteritems(): the latter does not exist on
  # Python 3, while .items() behaves the same here on both versions.
  return [
      pair_type(key=name, value=PyValueToMessage(field_type, value_dict))
      for name, value_dict in encoded_message.items()
      if name not in known_names
  ]
| 246 |
| 247 |
def _DecodeUnrecognizedFields(message, pair_type):
  """Build key/value pairs from the unrecognized fields stored on message.

  Args:
    message: the decoded message holding unrecognized fields.
    pair_type: the pair message class used to wrap each entry.

  Returns:
    A list of pair_type instances, one per unrecognized field of message.
  """
  decoded_pairs = []
  for field_key in message.all_unrecognized_fields():
    # TODO: Consider validating the variant if
    # the assignment below doesn't take care of it. It may
    # also be necessary to check it in the case that the
    # type has multiple encodings.
    raw_value, _ = message.get_unrecognized_field_info(field_key)
    value_field = pair_type.field_by_name('value')
    if not isinstance(value_field, messages.MessageField):
      payload = raw_value
    else:
      payload = DictToMessage(raw_value, pair_type.value.message_type)
    decoded_pairs.append(pair_type(key=str(field_key), value=payload))
  return decoded_pairs
| 265 |
| 266 |
def _EncodeUnknownFields(message):
  """Turn the container field's pairs back into unrecognized fields.

  Args:
    message: the message about to be encoded.

  Returns:
    message itself when its type has no registered mapping; otherwise a
    copy whose container field is emptied and whose pairs are stored as
    unrecognized fields.

  Raises:
    exceptions.InvalidUserInputError: if the mapped field is not a message
        field.
  """
  source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
  if source is None:
    return message
  # Work on a copy so the caller's message is left as it was.
  copied = CopyProtoMessage(message)
  pairs_field = message.field_by_name(source)
  if not isinstance(pairs_field, messages.MessageField):
    raise exceptions.InvalidUserInputError(
        'Invalid pairs field %s' % pairs_field)
  value_variant = pairs_field.message_type.field_by_name('value').variant
  for pair in getattr(message, source):
    encoded_value = (
        MessageToDict(pair.value)
        if value_variant == messages.Variant.MESSAGE
        else pair.value)
    copied.set_unrecognized_field(pair.key, encoded_value, value_variant)
  setattr(copied, source, [])
  return copied
| 288 |
| 289 |
def _SafeEncodeBytes(field, value):
  """Encode value as urlsafe base64, element by element if repeated.

  Args:
    field: the field being encoded; only its `repeated` flag is read.
    value: the bytes to encode, or an iterable of them for a repeated
        field.

  Returns:
    A CodecResult holding the encoded value with complete=True, or the
    original value with complete=False when encoding raised TypeError.
  """
  try:
    encoded = (
        [base64.urlsafe_b64encode(item) for item in value]
        if field.repeated
        else base64.urlsafe_b64encode(value))
  except TypeError:
    return CodecResult(value=value, complete=False)
  return CodecResult(value=encoded, complete=True)
| 302 |
| 303 |
def _SafeDecodeBytes(unused_field, value):
  """Decode the urlsafe base64 value into bytes, without raising.

  Args:
    unused_field: the field being decoded; ignored.
    value: the urlsafe-base64 text to decode.

  Returns:
    A CodecResult holding the decoded bytes with complete=True, or the
    original value with complete=False when it could not be decoded.
  """
  try:
    result = base64.urlsafe_b64decode(str(value))
    complete = True
  except (TypeError, ValueError):
    # TypeError is what Python 2's base64 raises for malformed input.
    # ValueError covers Python 3's binascii.Error and the
    # UnicodeEncodeError that str() raises for non-ASCII unicode on
    # Python 2; without it those escape this "safe" decoder.
    result = value
    complete = False
  return CodecResult(value=result, complete=complete)
| 313 |
| 314 |
# Install the urlsafe-base64 codec for every messages.BytesField.
RegisterFieldTypeCodec(
    encoder=_SafeEncodeBytes,
    decoder=_SafeDecodeBytes)(messages.BytesField)
| OLD | NEW |