Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(452)

Side by Side Diff: third_party/google-endpoints/jwkest/jws.py

Issue 2666783008: Add google-endpoints to third_party/. (Closed)
Patch Set: Created 3 years, 10 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
« no previous file with comments | « third_party/google-endpoints/jwkest/jwk.py ('k') | third_party/google-endpoints/jwkest/jwt.py » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
(Empty)
1 """JSON Web Token"""
2 import six
3
4 try:
5 from builtins import str
6 from builtins import object
7 except ImportError:
8 pass
9
10 # Most of the code, ideas herein I have borrowed/stolen from other people
11 # Most notably Jeff Lindsay, Ryan Kelly and Richard Barnes
12
13 import json
14 import logging
15
16 import struct
17 from Crypto.Hash import SHA256
18 from Crypto.Hash import SHA384
19 from Crypto.Hash import SHA512
20 from Crypto.Hash import HMAC
21 from Crypto.Signature import PKCS1_v1_5
22 from Crypto.Signature import PKCS1_PSS
23 from Crypto.Util.number import bytes_to_long
24 import sys
25
26 from jwkest import b64d, as_unicode
27 from jwkest import b64e
28 from jwkest import constant_time_compare
29 from jwkest import safe_str_cmp
30 from jwkest import JWKESTException
31 from jwkest import BadSignature
32 from jwkest import UnknownAlgorithm
33 from jwkest.ecc import P256
34 from jwkest.ecc import P384
35 from jwkest.ecc import P521
36
37 from jwkest.jwk import load_x509_cert, KEYS
38 from jwkest.jwk import HeaderError
39 from jwkest.jwk import sha256_digest
40 from jwkest.jwk import sha384_digest
41 from jwkest.jwk import sha512_digest
42 from jwkest.jwk import keyrep
43
44 from jwkest.jwt import JWT
45 from jwkest.jwt import b64encode_item
46
47 logger = logging.getLogger(__name__)
48
49
50 class JWSException(JWKESTException):
51 pass
52
53
54 class NoSuitableSigningKeys(JWSException):
55 pass
56
57
58 class FormatError(JWSException):
59 pass
60
61
62 class WrongTypeOfKey(JWSException):
63 pass
64
65
66 class UnknownSignerAlg(JWSException):
67 pass
68
69
70 class SignerAlgError(JWSException):
71 pass
72
73
74 def left_hash(msg, func="HS256"):
75 """ 128 bits == 16 bytes """
76 if func == 'HS256':
77 return as_unicode(b64e(sha256_digest(msg)[:16]))
78 elif func == 'HS384':
79 return as_unicode(b64e(sha384_digest(msg)[:24]))
80 elif func == 'HS512':
81 return as_unicode(b64e(sha512_digest(msg)[:32]))
82
83
84 def mpint(b):
85 b += b"\x00"
86 return struct.pack(">L", len(b)) + b
87
88
89 def mp2bin(b):
90 # just ignore the length...
91 if b[4] == '\x00':
92 return b[5:]
93 else:
94 return b[4:]
95
96
97 class Signer(object):
98 """Abstract base class for signing algorithms."""
99 def sign(self, msg, key):
100 """Sign ``msg`` with ``key`` and return the signature."""
101 raise NotImplementedError()
102
103 def verify(self, msg, sig, key):
104 """Return True if ``sig`` is a valid signature for ``msg``."""
105 raise NotImplementedError()
106
107
108 class HMACSigner(Signer):
109 def __init__(self, digest):
110 self.digest = digest
111
112 def sign(self, msg, key):
113 h = HMAC.new(key, msg, digestmod=self.digest)
114 return h.digest()
115 # return hmac.new(key, msg, digestmod=self.digest).digest()
116
117 def verify(self, msg, sig, key):
118 if sys.version < '3':
119 if safe_str_cmp(self.sign(msg, key), sig):
120 return True
121 elif constant_time_compare(self.sign(msg, key), sig):
122 return True
123 raise BadSignature(repr(sig))
124
125
126 class RSASigner(Signer):
127 def __init__(self, digest):
128 self.digest = digest
129
130 def sign(self, msg, key):
131 h = self.digest.new(msg)
132 signer = PKCS1_v1_5.new(key)
133 return signer.sign(h)
134
135 def verify(self, msg, sig, key):
136 h = self.digest.new(msg)
137 verifier = PKCS1_v1_5.new(key)
138 try:
139 if verifier.verify(h, sig):
140 return True
141 else:
142 raise BadSignature()
143 except ValueError as e:
144 raise BadSignature(str(e))
145
146
147 class DSASigner(Signer):
148 def __init__(self, digest, sign):
149 self.digest = digest
150 self._sign = sign
151
152 def sign(self, msg, key):
153 # verify the key
154 h = bytes_to_long(self.digest.new(msg).digest())
155 return self._sign.sign(h, key)
156
157 def verify(self, msg, sig, key):
158 h = bytes_to_long(self.digest.new(msg).digest())
159 return self._sign.verify(h, sig, key)
160
161
162 class PSSSigner(Signer):
163 def __init__(self, digest):
164 self.digest = digest
165
166 def sign(self, msg, key):
167 h = self.digest.new(msg)
168 signer = PKCS1_PSS.new(key)
169 return signer.sign(h)
170
171 def verify(self, msg, sig, key):
172 h = self.digest.new(msg)
173 verifier = PKCS1_PSS.new(key)
174 res = verifier.verify(h, sig)
175 if not res:
176 raise BadSignature()
177 else:
178 return True
179
180
181 SIGNER_ALGS = {
182 'HS256': HMACSigner(SHA256),
183 'HS384': HMACSigner(SHA384),
184 'HS512': HMACSigner(SHA512),
185
186 'RS256': RSASigner(SHA256),
187 'RS384': RSASigner(SHA384),
188 'RS512': RSASigner(SHA512),
189
190 'ES256': DSASigner(SHA256, P256),
191 'ES384': DSASigner(SHA384, P384),
192 'ES512': DSASigner(SHA512, P521),
193
194 'PS256': PSSSigner(SHA256),
195 'PS384': PSSSigner(SHA384),
196 'PS512': PSSSigner(SHA512),
197
198 'none': None
199 }
200
201
202 def alg2keytype(alg):
203 if not alg or alg.lower() == "none":
204 return "none"
205 elif alg.startswith("RS") or alg.startswith("PS"):
206 return "RSA"
207 elif alg.startswith("HS") or alg.startswith("A"):
208 return "oct"
209 elif alg.startswith("ES"):
210 return "EC"
211 else:
212 return None
213
214
215 class JWSig(JWT):
216 def sign_input(self):
217 return self.b64part[0] + b'.' + self.b64part[1]
218
219 def signature(self):
220 return self.part[2]
221
222
223 class JWx(object):
224 args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty",
225 "crit"]
226 """
227 :param alg: The signing algorithm
228 :param jku: a URI that refers to a resource for a set of JSON-encoded
229 public keys, one of which corresponds to the key used to digitally
230 sign the JWS
231 :param jwk: A JSON Web Key that corresponds to the key used to
232 digitally sign the JWS
233 :param x5u: a URI that refers to a resource for the X.509 public key
234 certificate or certificate chain [RFC5280] corresponding to the key
235 used to digitally sign the JWS.
236 :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
237 DER encoding of the X.509 certificate [RFC5280] corresponding to
238 the key used to digitally sign the JWS.
239 :param x5c: the X.509 public key certificate or certificate chain
240 corresponding to the key used to digitally sign the JWS.
241 :param kid: a hint indicating which key was used to secure the JWS.
242 :param typ: the type of this object. 'JWS' == JWS Compact Serialization
243 'JWS+JSON' == JWS JSON Serialization
244 :param cty: the type of the secured content
245 :param crit: indicates which extensions that are being used and MUST
246 be understood and processed.
247 :param kwargs: Extra header parameters
248 :return: A class instance
249 """
250
251 def __init__(self, msg=None, with_digest=False, **kwargs):
252 self.msg = msg
253
254 self._dict = {}
255 self.with_digest = with_digest
256 self.jwt = None
257
258 if kwargs:
259 for key in self.args:
260 try:
261 _val = kwargs[key]
262 except KeyError:
263 if key == "alg":
264 self._dict[key] = "none"
265 continue
266
267 if key == "jwk":
268 if isinstance(_val, dict):
269 self._dict["jwk"] = keyrep(_val)
270 elif isinstance(_val, str):
271 self._dict["jwk"] = keyrep(json.loads(_val))
272 else:
273 self._dict["jwk"] = _val
274 elif key == "x5c" or key == "crit":
275 self._dict["x5c"] = _val or []
276 else:
277 self._dict[key] = _val
278
279 def __contains__(self, item):
280 return item in self._dict
281
282 def __getitem__(self, item):
283 return self._dict[item]
284
285 def __setitem__(self, key, value):
286 self._dict[key] = value
287
288 def __getattr__(self, item):
289 try:
290 return self._dict[item]
291 except KeyError:
292 raise AttributeError(item)
293
294 def keys(self):
295 return list(self._dict.keys())
296
297 def headers(self, extra=None):
298 _extra = extra or {}
299 _header = {}
300 for param in self.args:
301 try:
302 _header[param] = _extra[param]
303 except KeyError:
304 try:
305 if self._dict[param]:
306 _header[param] = self._dict[param]
307 except KeyError:
308 pass
309
310 if "jwk" in self:
311 _header["jwk"] = self["jwk"].serialize()
312 elif "jwk" in _extra:
313 _header["jwk"] = extra["jwk"].serialize()
314
315 if "kid" in self:
316 try:
317 assert isinstance(self["kid"], six.string_types)
318 except AssertionError:
319 raise HeaderError("kid of wrong value type")
320
321 return _header
322
323 def _get_keys(self):
324 logger.debug("_get_keys(): self._dict.keys={0}".format(
325 self._dict.keys()))
326
327 if "jwk" in self:
328 return [self["jwk"]]
329 elif "jku" in self:
330 keys = KEYS()
331 keys.load_from_url(self["jku"])
332 return keys.as_dict()
333 elif "x5u" in self:
334 try:
335 return {"rsa": [load_x509_cert(self["x5u"], {})]}
336 except Exception:
337 # ca_chain = load_x509_cert_chain(self["x5u"])
338 pass
339
340 return {}
341
342 def alg2keytype(self, alg):
343 return alg2keytype(alg)
344
345 def _pick_keys(self, keys, use="", alg=""):
346 """
347 The assumption is that upper layer has made certain you only get
348 keys you can use.
349
350 :param keys: A list of KEY instances
351 :return: A list of KEY instances that fulfill the requirements
352 """
353 if not alg:
354 alg = self["alg"]
355
356 if alg == "none":
357 return []
358
359 _k = self.alg2keytype(alg)
360 if _k is None:
361 logger.error("Unknown arlgorithm '%s'" % alg)
362 return []
363
364 logger.debug("Picking key by key type={0}".format(_k))
365 _kty = [_k.lower(), _k.upper(), _k.lower().encode("utf-8"),
366 _k.upper().encode("utf-8")]
367 _keys = [k for k in keys if k.kty in _kty]
368 try:
369 _kid = self["kid"]
370 except KeyError:
371 try:
372 _kid = self.jwt.headers["kid"]
373 except (AttributeError, KeyError):
374 _kid = None
375
376 logger.debug("Picking key based on alg={0}, kid={1} and use={2}".format(
377 alg, _kid, use))
378
379 pkey = []
380 for _key in _keys:
381 logger.debug("KEY: {0}".format(_key))
382 if _kid:
383 try:
384 assert _kid == _key.kid
385 except (KeyError, AttributeError):
386 pass
387 except AssertionError:
388 continue
389
390 if use and _key.use and _key.use != use:
391 continue
392
393 if alg and _key.alg and _key.alg != alg:
394 continue
395
396 pkey.append(_key)
397
398 return pkey
399
400 def _decode(self, payload):
401 _msg = b64d(bytes(payload))
402 if "cty" in self:
403 if self["cty"] == "JWT":
404 _msg = json.loads(_msg)
405 return _msg
406
407 def dump_header(self):
408 return dict([(x, self._dict[x]) for x in self.args if x in self._dict])
409
410
411 class JWS(JWx):
412
413 def alg_keys(self, keys, use, protected=None):
414 try:
415 _alg = self["alg"]
416 except KeyError:
417 self["alg"] = _alg = "none"
418 else:
419 if not _alg:
420 self["alg"] = _alg = "none"
421
422 if keys:
423 keys = self._pick_keys(keys, use=use, alg=_alg)
424 else:
425 keys = self._pick_keys(self._get_keys(), use=use, alg=_alg)
426
427 xargs = protected or {}
428 xargs["alg"] = _alg
429
430 if keys:
431 key = keys[0]
432 if key.kid:
433 xargs["kid"] = key.kid
434 elif not _alg or _alg.lower() == "none":
435 key = None
436 else:
437 if "kid" in self:
438 raise NoSuitableSigningKeys(
439 "No key for algorithm: %s and kid: %s" % (_alg,
440 self["kid"]))
441 else:
442 raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)
443
444 return key, xargs, _alg
445
446 def sign_compact(self, keys=None, protected=None):
447 """
448 Produce a JWS using the JWS Compact Serialization
449
450 :param keys: A dictionary of keys
451 :param protected: The protected headers (a dictionary)
452 :return:
453 """
454
455 key, xargs, _alg = self.alg_keys(keys, 'sig', protected)
456
457 if "typ" in self:
458 xargs["typ"] = self["typ"]
459
460 jwt = JWSig(**xargs)
461 if _alg == "none":
462 return jwt.pack(parts=[self.msg, ""])
463
464 # All other cases
465 try:
466 _signer = SIGNER_ALGS[_alg]
467 except KeyError:
468 raise UnknownAlgorithm(_alg)
469
470 _input = jwt.pack(parts=[self.msg])
471 sig = _signer.sign(_input.encode("utf-8"), key.get_key(alg=_alg, private =True))
472 logger.debug("Signed message using key with kid=%s" % key.kid)
473 return ".".join([_input, b64encode_item(sig).decode("utf-8")])
474
475 def verify_compact(self, jws, keys=None, allow_none=False, sigalg=None):
476 """
477 Verify a JWT signature
478
479 :param jws:
480 :param keys:
481 :param allow_none: If signature algorithm 'none' is allowed
482 :param sigalg: Expected sigalg
483 :return:
484 """
485 jwt = JWSig().unpack(jws)
486 self.jwt = jwt
487
488 try:
489 _alg = jwt.headers["alg"]
490 except KeyError:
491 _alg = None
492 else:
493 if _alg is None or _alg.lower() == "none":
494 if allow_none:
495 self.msg = jwt.payload()
496 return self.msg
497 else:
498 raise SignerAlgError("none not allowed")
499
500 if "alg" in self and _alg:
501 if self["alg"] != _alg:
502 raise SignerAlgError("Wrong signing algorithm")
503
504 if sigalg and sigalg != _alg:
505 raise SignerAlgError("Expected {0} got {1}".format(
506 sigalg, jwt.headers["alg"]))
507
508 self["alg"] = _alg
509
510 if keys:
511 _keys = self._pick_keys(keys)
512 else:
513 _keys = self._pick_keys(self._get_keys())
514
515 if not _keys:
516 if "kid" in self:
517 raise NoSuitableSigningKeys(
518 "No key with kid: %s" % (self["kid"]))
519 elif "kid" in self.jwt.headers:
520 raise NoSuitableSigningKeys(
521 "No key with kid: %s" % (self.jwt.headers["kid"]))
522 else:
523 raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)
524
525 verifier = SIGNER_ALGS[_alg]
526
527 for key in _keys:
528 try:
529 res = verifier.verify(jwt.sign_input(), jwt.signature(),
530 key.get_key(alg=_alg, private=False))
531 except BadSignature:
532 pass
533 else:
534 if res is True:
535 logger.debug(
536 "Verified message using key with kid=%s" % key.kid)
537 self.msg = jwt.payload()
538 return self.msg
539
540 raise BadSignature()
541
542 def sign_json(self, per_signature_header=None, **kwargs):
543 """
544 Produce JWS using the JWS JSON Serialization
545
546 :param per_signature_header: Header parameter values that are to be
547 applied to a specific signature
548 :return:
549 """
550 res = {"signatures": []}
551
552 if per_signature_header is None:
553 per_signature_header = [{"alg": "none"}]
554
555 for _kwa in per_signature_header:
556 _kwa.update(kwargs)
557 _jws = JWS(self.msg, **_kwa)
558 header, payload, signature = _jws.sign_compact().split(".")
559 res["signatures"].append({"header": header,
560 "signature": signature})
561
562 res["payload"] = self.msg
563
564 return res
565
566 def verify_json(self, jws, keys=None, allow_none=False, sigalg=None):
567 """
568
569 :param jws:
570 :param keys:
571 :return:
572 """
573
574 _jwss = json.load(jws)
575
576 try:
577 _payload = _jwss["payload"]
578 except KeyError:
579 raise FormatError("Missing payload")
580
581 try:
582 _signs = _jwss["signatures"]
583 except KeyError:
584 raise FormatError("Missing signatures")
585
586 _claim = None
587 for _sign in _signs:
588 token = b".".join([_sign["protected"].encode(), _payload.encode(), _ sign["signature"].encode()])
589 header = _sign.get("header", {})
590 self.__init__(**header)
591 _tmp = self.verify_compact(token, keys, allow_none, sigalg)
592 if _claim is None:
593 _claim = _tmp
594 else:
595 assert _claim == _tmp
596
597 return _claim
598
599 def is_jws(self, token):
600 """
601
602 :param token:
603 :return:
604 """
605 try:
606 jwt = JWSig().unpack(token)
607 except Exception:
608 return False
609
610 try:
611 assert "alg" in jwt.headers
612 except AssertionError:
613 return False
614 else:
615 if jwt.headers["alg"] is None:
616 jwt.headers["alg"] = "none"
617
618 try:
619 assert jwt.headers["alg"] in SIGNER_ALGS
620 except AssertionError:
621 logger.debug("UnknownSignerAlg: %s" % jwt.headers["alg"])
622 return False
623 else:
624 self.jwt = jwt
625 return True
626
627
628 def factory(token):
629 _jw = JWS()
630 if _jw.is_jws(token):
631 return _jw
632 else:
633 return None
OLDNEW
« no previous file with comments | « third_party/google-endpoints/jwkest/jwk.py ('k') | third_party/google-endpoints/jwkest/jwt.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698