OLD | NEW |
(Empty) | |
| 1 """JSON Web Token""" |
| 2 import six |
| 3 |
| 4 try: |
| 5 from builtins import str |
| 6 from builtins import object |
| 7 except ImportError: |
| 8 pass |
| 9 |
| 10 # Most of the code, ideas herein I have borrowed/stolen from other people |
| 11 # Most notably Jeff Lindsay, Ryan Kelly and Richard Barnes |
| 12 |
| 13 import json |
| 14 import logging |
| 15 |
| 16 import struct |
| 17 from Crypto.Hash import SHA256 |
| 18 from Crypto.Hash import SHA384 |
| 19 from Crypto.Hash import SHA512 |
| 20 from Crypto.Hash import HMAC |
| 21 from Crypto.Signature import PKCS1_v1_5 |
| 22 from Crypto.Signature import PKCS1_PSS |
| 23 from Crypto.Util.number import bytes_to_long |
| 24 import sys |
| 25 |
| 26 from jwkest import b64d, as_unicode |
| 27 from jwkest import b64e |
| 28 from jwkest import constant_time_compare |
| 29 from jwkest import safe_str_cmp |
| 30 from jwkest import JWKESTException |
| 31 from jwkest import BadSignature |
| 32 from jwkest import UnknownAlgorithm |
| 33 from jwkest.ecc import P256 |
| 34 from jwkest.ecc import P384 |
| 35 from jwkest.ecc import P521 |
| 36 |
| 37 from jwkest.jwk import load_x509_cert, KEYS |
| 38 from jwkest.jwk import HeaderError |
| 39 from jwkest.jwk import sha256_digest |
| 40 from jwkest.jwk import sha384_digest |
| 41 from jwkest.jwk import sha512_digest |
| 42 from jwkest.jwk import keyrep |
| 43 |
| 44 from jwkest.jwt import JWT |
| 45 from jwkest.jwt import b64encode_item |
| 46 |
| 47 logger = logging.getLogger(__name__) |
| 48 |
| 49 |
| 50 class JWSException(JWKESTException): |
| 51 pass |
| 52 |
| 53 |
| 54 class NoSuitableSigningKeys(JWSException): |
| 55 pass |
| 56 |
| 57 |
| 58 class FormatError(JWSException): |
| 59 pass |
| 60 |
| 61 |
| 62 class WrongTypeOfKey(JWSException): |
| 63 pass |
| 64 |
| 65 |
| 66 class UnknownSignerAlg(JWSException): |
| 67 pass |
| 68 |
| 69 |
| 70 class SignerAlgError(JWSException): |
| 71 pass |
| 72 |
| 73 |
| 74 def left_hash(msg, func="HS256"): |
| 75 """ 128 bits == 16 bytes """ |
| 76 if func == 'HS256': |
| 77 return as_unicode(b64e(sha256_digest(msg)[:16])) |
| 78 elif func == 'HS384': |
| 79 return as_unicode(b64e(sha384_digest(msg)[:24])) |
| 80 elif func == 'HS512': |
| 81 return as_unicode(b64e(sha512_digest(msg)[:32])) |
| 82 |
| 83 |
| 84 def mpint(b): |
| 85 b += b"\x00" |
| 86 return struct.pack(">L", len(b)) + b |
| 87 |
| 88 |
| 89 def mp2bin(b): |
| 90 # just ignore the length... |
| 91 if b[4] == '\x00': |
| 92 return b[5:] |
| 93 else: |
| 94 return b[4:] |
| 95 |
| 96 |
| 97 class Signer(object): |
| 98 """Abstract base class for signing algorithms.""" |
| 99 def sign(self, msg, key): |
| 100 """Sign ``msg`` with ``key`` and return the signature.""" |
| 101 raise NotImplementedError() |
| 102 |
| 103 def verify(self, msg, sig, key): |
| 104 """Return True if ``sig`` is a valid signature for ``msg``.""" |
| 105 raise NotImplementedError() |
| 106 |
| 107 |
| 108 class HMACSigner(Signer): |
| 109 def __init__(self, digest): |
| 110 self.digest = digest |
| 111 |
| 112 def sign(self, msg, key): |
| 113 h = HMAC.new(key, msg, digestmod=self.digest) |
| 114 return h.digest() |
| 115 # return hmac.new(key, msg, digestmod=self.digest).digest() |
| 116 |
| 117 def verify(self, msg, sig, key): |
| 118 if sys.version < '3': |
| 119 if safe_str_cmp(self.sign(msg, key), sig): |
| 120 return True |
| 121 elif constant_time_compare(self.sign(msg, key), sig): |
| 122 return True |
| 123 raise BadSignature(repr(sig)) |
| 124 |
| 125 |
| 126 class RSASigner(Signer): |
| 127 def __init__(self, digest): |
| 128 self.digest = digest |
| 129 |
| 130 def sign(self, msg, key): |
| 131 h = self.digest.new(msg) |
| 132 signer = PKCS1_v1_5.new(key) |
| 133 return signer.sign(h) |
| 134 |
| 135 def verify(self, msg, sig, key): |
| 136 h = self.digest.new(msg) |
| 137 verifier = PKCS1_v1_5.new(key) |
| 138 try: |
| 139 if verifier.verify(h, sig): |
| 140 return True |
| 141 else: |
| 142 raise BadSignature() |
| 143 except ValueError as e: |
| 144 raise BadSignature(str(e)) |
| 145 |
| 146 |
| 147 class DSASigner(Signer): |
| 148 def __init__(self, digest, sign): |
| 149 self.digest = digest |
| 150 self._sign = sign |
| 151 |
| 152 def sign(self, msg, key): |
| 153 # verify the key |
| 154 h = bytes_to_long(self.digest.new(msg).digest()) |
| 155 return self._sign.sign(h, key) |
| 156 |
| 157 def verify(self, msg, sig, key): |
| 158 h = bytes_to_long(self.digest.new(msg).digest()) |
| 159 return self._sign.verify(h, sig, key) |
| 160 |
| 161 |
| 162 class PSSSigner(Signer): |
| 163 def __init__(self, digest): |
| 164 self.digest = digest |
| 165 |
| 166 def sign(self, msg, key): |
| 167 h = self.digest.new(msg) |
| 168 signer = PKCS1_PSS.new(key) |
| 169 return signer.sign(h) |
| 170 |
| 171 def verify(self, msg, sig, key): |
| 172 h = self.digest.new(msg) |
| 173 verifier = PKCS1_PSS.new(key) |
| 174 res = verifier.verify(h, sig) |
| 175 if not res: |
| 176 raise BadSignature() |
| 177 else: |
| 178 return True |
| 179 |
| 180 |
| 181 SIGNER_ALGS = { |
| 182 'HS256': HMACSigner(SHA256), |
| 183 'HS384': HMACSigner(SHA384), |
| 184 'HS512': HMACSigner(SHA512), |
| 185 |
| 186 'RS256': RSASigner(SHA256), |
| 187 'RS384': RSASigner(SHA384), |
| 188 'RS512': RSASigner(SHA512), |
| 189 |
| 190 'ES256': DSASigner(SHA256, P256), |
| 191 'ES384': DSASigner(SHA384, P384), |
| 192 'ES512': DSASigner(SHA512, P521), |
| 193 |
| 194 'PS256': PSSSigner(SHA256), |
| 195 'PS384': PSSSigner(SHA384), |
| 196 'PS512': PSSSigner(SHA512), |
| 197 |
| 198 'none': None |
| 199 } |
| 200 |
| 201 |
| 202 def alg2keytype(alg): |
| 203 if not alg or alg.lower() == "none": |
| 204 return "none" |
| 205 elif alg.startswith("RS") or alg.startswith("PS"): |
| 206 return "RSA" |
| 207 elif alg.startswith("HS") or alg.startswith("A"): |
| 208 return "oct" |
| 209 elif alg.startswith("ES"): |
| 210 return "EC" |
| 211 else: |
| 212 return None |
| 213 |
| 214 |
| 215 class JWSig(JWT): |
| 216 def sign_input(self): |
| 217 return self.b64part[0] + b'.' + self.b64part[1] |
| 218 |
| 219 def signature(self): |
| 220 return self.part[2] |
| 221 |
| 222 |
| 223 class JWx(object): |
| 224 args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", |
| 225 "crit"] |
| 226 """ |
| 227 :param alg: The signing algorithm |
| 228 :param jku: a URI that refers to a resource for a set of JSON-encoded |
| 229 public keys, one of which corresponds to the key used to digitally |
| 230 sign the JWS |
| 231 :param jwk: A JSON Web Key that corresponds to the key used to |
| 232 digitally sign the JWS |
| 233 :param x5u: a URI that refers to a resource for the X.509 public key |
| 234 certificate or certificate chain [RFC5280] corresponding to the key |
| 235 used to digitally sign the JWS. |
| 236 :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the |
| 237 DER encoding of the X.509 certificate [RFC5280] corresponding to |
| 238 the key used to digitally sign the JWS. |
| 239 :param x5c: the X.509 public key certificate or certificate chain |
| 240 corresponding to the key used to digitally sign the JWS. |
| 241 :param kid: a hint indicating which key was used to secure the JWS. |
| 242 :param typ: the type of this object. 'JWS' == JWS Compact Serialization |
| 243 'JWS+JSON' == JWS JSON Serialization |
| 244 :param cty: the type of the secured content |
| 245 :param crit: indicates which extensions that are being used and MUST |
| 246 be understood and processed. |
| 247 :param kwargs: Extra header parameters |
| 248 :return: A class instance |
| 249 """ |
| 250 |
| 251 def __init__(self, msg=None, with_digest=False, **kwargs): |
| 252 self.msg = msg |
| 253 |
| 254 self._dict = {} |
| 255 self.with_digest = with_digest |
| 256 self.jwt = None |
| 257 |
| 258 if kwargs: |
| 259 for key in self.args: |
| 260 try: |
| 261 _val = kwargs[key] |
| 262 except KeyError: |
| 263 if key == "alg": |
| 264 self._dict[key] = "none" |
| 265 continue |
| 266 |
| 267 if key == "jwk": |
| 268 if isinstance(_val, dict): |
| 269 self._dict["jwk"] = keyrep(_val) |
| 270 elif isinstance(_val, str): |
| 271 self._dict["jwk"] = keyrep(json.loads(_val)) |
| 272 else: |
| 273 self._dict["jwk"] = _val |
| 274 elif key == "x5c" or key == "crit": |
| 275 self._dict["x5c"] = _val or [] |
| 276 else: |
| 277 self._dict[key] = _val |
| 278 |
| 279 def __contains__(self, item): |
| 280 return item in self._dict |
| 281 |
| 282 def __getitem__(self, item): |
| 283 return self._dict[item] |
| 284 |
| 285 def __setitem__(self, key, value): |
| 286 self._dict[key] = value |
| 287 |
| 288 def __getattr__(self, item): |
| 289 try: |
| 290 return self._dict[item] |
| 291 except KeyError: |
| 292 raise AttributeError(item) |
| 293 |
| 294 def keys(self): |
| 295 return list(self._dict.keys()) |
| 296 |
| 297 def headers(self, extra=None): |
| 298 _extra = extra or {} |
| 299 _header = {} |
| 300 for param in self.args: |
| 301 try: |
| 302 _header[param] = _extra[param] |
| 303 except KeyError: |
| 304 try: |
| 305 if self._dict[param]: |
| 306 _header[param] = self._dict[param] |
| 307 except KeyError: |
| 308 pass |
| 309 |
| 310 if "jwk" in self: |
| 311 _header["jwk"] = self["jwk"].serialize() |
| 312 elif "jwk" in _extra: |
| 313 _header["jwk"] = extra["jwk"].serialize() |
| 314 |
| 315 if "kid" in self: |
| 316 try: |
| 317 assert isinstance(self["kid"], six.string_types) |
| 318 except AssertionError: |
| 319 raise HeaderError("kid of wrong value type") |
| 320 |
| 321 return _header |
| 322 |
| 323 def _get_keys(self): |
| 324 logger.debug("_get_keys(): self._dict.keys={0}".format( |
| 325 self._dict.keys())) |
| 326 |
| 327 if "jwk" in self: |
| 328 return [self["jwk"]] |
| 329 elif "jku" in self: |
| 330 keys = KEYS() |
| 331 keys.load_from_url(self["jku"]) |
| 332 return keys.as_dict() |
| 333 elif "x5u" in self: |
| 334 try: |
| 335 return {"rsa": [load_x509_cert(self["x5u"], {})]} |
| 336 except Exception: |
| 337 # ca_chain = load_x509_cert_chain(self["x5u"]) |
| 338 pass |
| 339 |
| 340 return {} |
| 341 |
| 342 def alg2keytype(self, alg): |
| 343 return alg2keytype(alg) |
| 344 |
| 345 def _pick_keys(self, keys, use="", alg=""): |
| 346 """ |
| 347 The assumption is that upper layer has made certain you only get |
| 348 keys you can use. |
| 349 |
| 350 :param keys: A list of KEY instances |
| 351 :return: A list of KEY instances that fulfill the requirements |
| 352 """ |
| 353 if not alg: |
| 354 alg = self["alg"] |
| 355 |
| 356 if alg == "none": |
| 357 return [] |
| 358 |
| 359 _k = self.alg2keytype(alg) |
| 360 if _k is None: |
| 361 logger.error("Unknown arlgorithm '%s'" % alg) |
| 362 return [] |
| 363 |
| 364 logger.debug("Picking key by key type={0}".format(_k)) |
| 365 _kty = [_k.lower(), _k.upper(), _k.lower().encode("utf-8"), |
| 366 _k.upper().encode("utf-8")] |
| 367 _keys = [k for k in keys if k.kty in _kty] |
| 368 try: |
| 369 _kid = self["kid"] |
| 370 except KeyError: |
| 371 try: |
| 372 _kid = self.jwt.headers["kid"] |
| 373 except (AttributeError, KeyError): |
| 374 _kid = None |
| 375 |
| 376 logger.debug("Picking key based on alg={0}, kid={1} and use={2}".format( |
| 377 alg, _kid, use)) |
| 378 |
| 379 pkey = [] |
| 380 for _key in _keys: |
| 381 logger.debug("KEY: {0}".format(_key)) |
| 382 if _kid: |
| 383 try: |
| 384 assert _kid == _key.kid |
| 385 except (KeyError, AttributeError): |
| 386 pass |
| 387 except AssertionError: |
| 388 continue |
| 389 |
| 390 if use and _key.use and _key.use != use: |
| 391 continue |
| 392 |
| 393 if alg and _key.alg and _key.alg != alg: |
| 394 continue |
| 395 |
| 396 pkey.append(_key) |
| 397 |
| 398 return pkey |
| 399 |
| 400 def _decode(self, payload): |
| 401 _msg = b64d(bytes(payload)) |
| 402 if "cty" in self: |
| 403 if self["cty"] == "JWT": |
| 404 _msg = json.loads(_msg) |
| 405 return _msg |
| 406 |
| 407 def dump_header(self): |
| 408 return dict([(x, self._dict[x]) for x in self.args if x in self._dict]) |
| 409 |
| 410 |
| 411 class JWS(JWx): |
| 412 |
| 413 def alg_keys(self, keys, use, protected=None): |
| 414 try: |
| 415 _alg = self["alg"] |
| 416 except KeyError: |
| 417 self["alg"] = _alg = "none" |
| 418 else: |
| 419 if not _alg: |
| 420 self["alg"] = _alg = "none" |
| 421 |
| 422 if keys: |
| 423 keys = self._pick_keys(keys, use=use, alg=_alg) |
| 424 else: |
| 425 keys = self._pick_keys(self._get_keys(), use=use, alg=_alg) |
| 426 |
| 427 xargs = protected or {} |
| 428 xargs["alg"] = _alg |
| 429 |
| 430 if keys: |
| 431 key = keys[0] |
| 432 if key.kid: |
| 433 xargs["kid"] = key.kid |
| 434 elif not _alg or _alg.lower() == "none": |
| 435 key = None |
| 436 else: |
| 437 if "kid" in self: |
| 438 raise NoSuitableSigningKeys( |
| 439 "No key for algorithm: %s and kid: %s" % (_alg, |
| 440 self["kid"])) |
| 441 else: |
| 442 raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg) |
| 443 |
| 444 return key, xargs, _alg |
| 445 |
| 446 def sign_compact(self, keys=None, protected=None): |
| 447 """ |
| 448 Produce a JWS using the JWS Compact Serialization |
| 449 |
| 450 :param keys: A dictionary of keys |
| 451 :param protected: The protected headers (a dictionary) |
| 452 :return: |
| 453 """ |
| 454 |
| 455 key, xargs, _alg = self.alg_keys(keys, 'sig', protected) |
| 456 |
| 457 if "typ" in self: |
| 458 xargs["typ"] = self["typ"] |
| 459 |
| 460 jwt = JWSig(**xargs) |
| 461 if _alg == "none": |
| 462 return jwt.pack(parts=[self.msg, ""]) |
| 463 |
| 464 # All other cases |
| 465 try: |
| 466 _signer = SIGNER_ALGS[_alg] |
| 467 except KeyError: |
| 468 raise UnknownAlgorithm(_alg) |
| 469 |
| 470 _input = jwt.pack(parts=[self.msg]) |
| 471 sig = _signer.sign(_input.encode("utf-8"), key.get_key(alg=_alg, private
=True)) |
| 472 logger.debug("Signed message using key with kid=%s" % key.kid) |
| 473 return ".".join([_input, b64encode_item(sig).decode("utf-8")]) |
| 474 |
| 475 def verify_compact(self, jws, keys=None, allow_none=False, sigalg=None): |
| 476 """ |
| 477 Verify a JWT signature |
| 478 |
| 479 :param jws: |
| 480 :param keys: |
| 481 :param allow_none: If signature algorithm 'none' is allowed |
| 482 :param sigalg: Expected sigalg |
| 483 :return: |
| 484 """ |
| 485 jwt = JWSig().unpack(jws) |
| 486 self.jwt = jwt |
| 487 |
| 488 try: |
| 489 _alg = jwt.headers["alg"] |
| 490 except KeyError: |
| 491 _alg = None |
| 492 else: |
| 493 if _alg is None or _alg.lower() == "none": |
| 494 if allow_none: |
| 495 self.msg = jwt.payload() |
| 496 return self.msg |
| 497 else: |
| 498 raise SignerAlgError("none not allowed") |
| 499 |
| 500 if "alg" in self and _alg: |
| 501 if self["alg"] != _alg: |
| 502 raise SignerAlgError("Wrong signing algorithm") |
| 503 |
| 504 if sigalg and sigalg != _alg: |
| 505 raise SignerAlgError("Expected {0} got {1}".format( |
| 506 sigalg, jwt.headers["alg"])) |
| 507 |
| 508 self["alg"] = _alg |
| 509 |
| 510 if keys: |
| 511 _keys = self._pick_keys(keys) |
| 512 else: |
| 513 _keys = self._pick_keys(self._get_keys()) |
| 514 |
| 515 if not _keys: |
| 516 if "kid" in self: |
| 517 raise NoSuitableSigningKeys( |
| 518 "No key with kid: %s" % (self["kid"])) |
| 519 elif "kid" in self.jwt.headers: |
| 520 raise NoSuitableSigningKeys( |
| 521 "No key with kid: %s" % (self.jwt.headers["kid"])) |
| 522 else: |
| 523 raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg) |
| 524 |
| 525 verifier = SIGNER_ALGS[_alg] |
| 526 |
| 527 for key in _keys: |
| 528 try: |
| 529 res = verifier.verify(jwt.sign_input(), jwt.signature(), |
| 530 key.get_key(alg=_alg, private=False)) |
| 531 except BadSignature: |
| 532 pass |
| 533 else: |
| 534 if res is True: |
| 535 logger.debug( |
| 536 "Verified message using key with kid=%s" % key.kid) |
| 537 self.msg = jwt.payload() |
| 538 return self.msg |
| 539 |
| 540 raise BadSignature() |
| 541 |
| 542 def sign_json(self, per_signature_header=None, **kwargs): |
| 543 """ |
| 544 Produce JWS using the JWS JSON Serialization |
| 545 |
| 546 :param per_signature_header: Header parameter values that are to be |
| 547 applied to a specific signature |
| 548 :return: |
| 549 """ |
| 550 res = {"signatures": []} |
| 551 |
| 552 if per_signature_header is None: |
| 553 per_signature_header = [{"alg": "none"}] |
| 554 |
| 555 for _kwa in per_signature_header: |
| 556 _kwa.update(kwargs) |
| 557 _jws = JWS(self.msg, **_kwa) |
| 558 header, payload, signature = _jws.sign_compact().split(".") |
| 559 res["signatures"].append({"header": header, |
| 560 "signature": signature}) |
| 561 |
| 562 res["payload"] = self.msg |
| 563 |
| 564 return res |
| 565 |
| 566 def verify_json(self, jws, keys=None, allow_none=False, sigalg=None): |
| 567 """ |
| 568 |
| 569 :param jws: |
| 570 :param keys: |
| 571 :return: |
| 572 """ |
| 573 |
| 574 _jwss = json.load(jws) |
| 575 |
| 576 try: |
| 577 _payload = _jwss["payload"] |
| 578 except KeyError: |
| 579 raise FormatError("Missing payload") |
| 580 |
| 581 try: |
| 582 _signs = _jwss["signatures"] |
| 583 except KeyError: |
| 584 raise FormatError("Missing signatures") |
| 585 |
| 586 _claim = None |
| 587 for _sign in _signs: |
| 588 token = b".".join([_sign["protected"].encode(), _payload.encode(), _
sign["signature"].encode()]) |
| 589 header = _sign.get("header", {}) |
| 590 self.__init__(**header) |
| 591 _tmp = self.verify_compact(token, keys, allow_none, sigalg) |
| 592 if _claim is None: |
| 593 _claim = _tmp |
| 594 else: |
| 595 assert _claim == _tmp |
| 596 |
| 597 return _claim |
| 598 |
| 599 def is_jws(self, token): |
| 600 """ |
| 601 |
| 602 :param token: |
| 603 :return: |
| 604 """ |
| 605 try: |
| 606 jwt = JWSig().unpack(token) |
| 607 except Exception: |
| 608 return False |
| 609 |
| 610 try: |
| 611 assert "alg" in jwt.headers |
| 612 except AssertionError: |
| 613 return False |
| 614 else: |
| 615 if jwt.headers["alg"] is None: |
| 616 jwt.headers["alg"] = "none" |
| 617 |
| 618 try: |
| 619 assert jwt.headers["alg"] in SIGNER_ALGS |
| 620 except AssertionError: |
| 621 logger.debug("UnknownSignerAlg: %s" % jwt.headers["alg"]) |
| 622 return False |
| 623 else: |
| 624 self.jwt = jwt |
| 625 return True |
| 626 |
| 627 |
| 628 def factory(token): |
| 629 _jw = JWS() |
| 630 if _jw.is_jws(token): |
| 631 return _jw |
| 632 else: |
| 633 return None |
OLD | NEW |