OLD | NEW |
(Empty) | |
| 1 import base64 |
| 2 import hashlib |
| 3 import re |
| 4 import logging |
| 5 import json |
| 6 import sys |
| 7 import six |
| 8 |
| 9 from binascii import a2b_base64 |
| 10 |
| 11 from Crypto.PublicKey import RSA |
| 12 from Crypto.PublicKey.RSA import importKey |
| 13 from Crypto.PublicKey.RSA import _RSAobj |
| 14 from Crypto.Util.asn1 import DerSequence |
| 15 |
| 16 from requests import request |
| 17 |
| 18 from jwkest import base64url_to_long |
| 19 from jwkest import as_bytes |
| 20 from jwkest import base64_to_long |
| 21 from jwkest import long_to_base64 |
| 22 from jwkest import JWKESTException |
| 23 from jwkest import b64d |
| 24 from jwkest import b64e |
| 25 from jwkest.ecc import NISTEllipticCurve |
| 26 from jwkest.jwt import b2s_conv |
| 27 |
| 28 if sys.version > '3': |
| 29 long = int |
| 30 else: |
| 31 from __builtin__ import long |
| 32 |
| 33 __author__ = 'rohe0002' |
| 34 |
| 35 logger = logging.getLogger(__name__) |
| 36 |
| 37 PREFIX = "-----BEGIN CERTIFICATE-----" |
| 38 POSTFIX = "-----END CERTIFICATE-----" |
| 39 |
| 40 |
| 41 class JWKException(JWKESTException): |
| 42 pass |
| 43 |
| 44 |
| 45 class FormatError(JWKException): |
| 46 pass |
| 47 |
| 48 |
| 49 class SerializationNotPossible(JWKException): |
| 50 pass |
| 51 |
| 52 |
| 53 class DeSerializationNotPossible(JWKException): |
| 54 pass |
| 55 |
| 56 |
| 57 class HeaderError(JWKESTException): |
| 58 pass |
| 59 |
| 60 |
| 61 def dicthash(d): |
| 62 return hash(repr(sorted(d.items()))) |
| 63 |
| 64 |
| 65 def intarr2str(arr): |
| 66 return "".join([chr(c) for c in arr]) |
| 67 |
| 68 |
| 69 def sha256_digest(msg): |
| 70 return hashlib.sha256(as_bytes(msg)).digest() |
| 71 |
| 72 |
| 73 def sha384_digest(msg): |
| 74 return hashlib.sha384(as_bytes(msg)).digest() |
| 75 |
| 76 |
| 77 def sha512_digest(msg): |
| 78 return hashlib.sha512(as_bytes(msg)).digest() |
| 79 |
| 80 |
| 81 # ============================================================================= |
| 82 |
| 83 |
| 84 def import_rsa_key_from_file(filename): |
| 85 return RSA.importKey(open(filename, 'r').read()) |
| 86 |
| 87 |
| 88 def import_rsa_key(key): |
| 89 """ |
| 90 Extract an RSA key from a PEM-encoded certificate |
| 91 |
| 92 :param key: RSA key encoded in standard form |
| 93 :return: RSA key instance |
| 94 """ |
| 95 return importKey(key) |
| 96 |
| 97 |
| 98 def der2rsa(der): |
| 99 # Extract subjectPublicKeyInfo field from X.509 certificate (see RFC3280) |
| 100 cert = DerSequence() |
| 101 cert.decode(der) |
| 102 tbs_certificate = DerSequence() |
| 103 tbs_certificate.decode(cert[0]) |
| 104 subject_public_key_info = tbs_certificate[6] |
| 105 |
| 106 # Initialize RSA key |
| 107 return RSA.importKey(subject_public_key_info) |
| 108 |
| 109 |
| 110 def pem_cert2rsa(pem_file): |
| 111 # Convert from PEM to DER |
| 112 pem = open(pem_file).read() |
| 113 lines = pem.replace(" ", '').split() |
| 114 return der2rsa(a2b_base64(''.join(lines[1:-1]))) |
| 115 |
| 116 |
| 117 def der_cert2rsa(der): |
| 118 """ |
| 119 Extract an RSA key from a DER certificate |
| 120 |
| 121 @param der: DER-encoded certificate |
| 122 @return: RSA instance |
| 123 """ |
| 124 pem = re.sub(r'[^A-Za-z0-9+/]', '', der) |
| 125 return der2rsa(base64.b64decode(pem)) |
| 126 |
| 127 |
| 128 def load_x509_cert(url, spec2key): |
| 129 """ |
| 130 Get and transform a X509 cert into a key |
| 131 |
| 132 :param url: Where the X509 cert can be found |
| 133 :param spec2key: A dictionary over keys already seen |
| 134 :return: List of 2-tuples (keytype, key) |
| 135 """ |
| 136 try: |
| 137 r = request("GET", url, allow_redirects=True) |
| 138 if r.status_code == 200: |
| 139 cert = str(r.text) |
| 140 try: |
| 141 _key = spec2key[cert] |
| 142 except KeyError: |
| 143 _key = import_rsa_key(cert) |
| 144 spec2key[cert] = _key |
| 145 return [("rsa", _key)] |
| 146 else: |
| 147 raise Exception("HTTP Get error: %s" % r.status_code) |
| 148 except Exception as err: # not a RSA key |
| 149 logger.warning("Can't load key: %s" % err) |
| 150 return [] |
| 151 |
| 152 |
| 153 def rsa_load(filename): |
| 154 """Read a PEM-encoded RSA key pair from a file.""" |
| 155 pem = open(filename, 'r').read() |
| 156 return import_rsa_key(pem) |
| 157 |
| 158 |
| 159 def rsa_eq(key1, key2): |
| 160 # Check if two RSA keys are in fact the same |
| 161 if key1.n == key2.n and key1.e == key2.e: |
| 162 return True |
| 163 else: |
| 164 return False |
| 165 |
| 166 |
| 167 def key_eq(key1, key2): |
| 168 if type(key1) == type(key2): |
| 169 if isinstance(key1, str): |
| 170 return key1 == key2 |
| 171 elif isinstance(key1, RSA): |
| 172 return rsa_eq(key1, key2) |
| 173 |
| 174 return False |
| 175 |
| 176 |
| 177 def x509_rsa_load(txt): |
| 178 """ So I get the same output format as loads produces |
| 179 :param txt: |
| 180 :return: |
| 181 """ |
| 182 return [("rsa", import_rsa_key(txt))] |
| 183 |
| 184 |
| 185 class Key(object): |
| 186 """ |
| 187 Basic JSON Web key class |
| 188 """ |
| 189 members = ["kty", "alg", "use", "kid", "x5c", "x5t", "x5u"] |
| 190 longs = [] |
| 191 public_members = ["kty", "alg", "use", "kid", "x5c", "x5t", "x5u"] |
| 192 |
| 193 def __init__(self, kty="", alg="", use="", kid="", key=None, x5c=None, |
| 194 x5t="", x5u="", **kwargs): |
| 195 self.key = key |
| 196 self.extra_args = kwargs |
| 197 |
| 198 # want kty, alg, use and kid to be strings |
| 199 if isinstance(kty, six.string_types): |
| 200 self.kty = kty |
| 201 else: |
| 202 self.kty = kty.decode("utf8") |
| 203 |
| 204 if isinstance(alg, six.string_types): |
| 205 self.alg = alg |
| 206 else: |
| 207 self.alg = alg.decode("utf8") |
| 208 |
| 209 if isinstance(use, six.string_types): |
| 210 self.use = use |
| 211 else: |
| 212 self.use = use.decode("utf8") |
| 213 |
| 214 if isinstance(kid, six.string_types): |
| 215 self.kid = kid |
| 216 else: |
| 217 self.kid = kid.decode("utf8") |
| 218 |
| 219 self.x5c = x5c or [] |
| 220 self.x5t = x5t |
| 221 self.x5u = x5u |
| 222 self.inactive_since = 0 |
| 223 |
| 224 def to_dict(self): |
| 225 """ |
| 226 A wrapper for to_dict the makes sure that all the private information |
| 227 as well as extra arguments are included. This method should *not* be |
| 228 used for exporting information about the key. |
| 229 """ |
| 230 res = self.serialize(private=True) |
| 231 res.update(self.extra_args) |
| 232 return res |
| 233 |
| 234 def common(self): |
| 235 res = {"kty": self.kty} |
| 236 if self.use: |
| 237 res["use"] = self.use |
| 238 if self.kid: |
| 239 res["kid"] = self.kid |
| 240 if self.alg: |
| 241 res["alg"] = self.alg |
| 242 return res |
| 243 |
| 244 def __str__(self): |
| 245 return str(self.to_dict()) |
| 246 |
| 247 def deserialize(self): |
| 248 """ |
| 249 Starting with information gathered from the on-the-wire representation |
| 250 initiate an appropriate key. |
| 251 """ |
| 252 pass |
| 253 |
| 254 def serialize(self, private=False): |
| 255 """ |
| 256 map key characteristics into attribute values that can be used |
| 257 to create an on-the-wire representation of the key |
| 258 """ |
| 259 pass |
| 260 |
| 261 def get_key(self, **kwargs): |
| 262 return self.key |
| 263 |
| 264 def verify(self): |
| 265 """ |
| 266 Verify that the information gathered from the on-the-wire |
| 267 representation is of the right types. |
| 268 This is supposed to be run before the info is deserialized. |
| 269 """ |
| 270 for param in self.longs: |
| 271 item = getattr(self, param) |
| 272 if not item or isinstance(item, six.integer_types): |
| 273 continue |
| 274 |
| 275 if isinstance(item, bytes): |
| 276 item = item.decode('utf-8') |
| 277 setattr(self, param, item) |
| 278 |
| 279 try: |
| 280 _ = base64url_to_long(item) |
| 281 except Exception: |
| 282 return False |
| 283 else: |
| 284 if [e for e in ['+', '/', '='] if e in item]: |
| 285 return False |
| 286 |
| 287 if self.kid: |
| 288 try: |
| 289 assert isinstance(self.kid, six.string_types) |
| 290 except AssertionError: |
| 291 raise HeaderError("kid of wrong value type") |
| 292 return True |
| 293 |
| 294 def __eq__(self, other): |
| 295 try: |
| 296 assert isinstance(other, Key) |
| 297 assert list(self.__dict__.keys()) == list(other.__dict__.keys()) |
| 298 |
| 299 for key in self.public_members: |
| 300 assert getattr(other, key) == getattr(self, key) |
| 301 except AssertionError: |
| 302 return False |
| 303 else: |
| 304 return True |
| 305 |
| 306 def keys(self): |
| 307 return list(self.to_dict().keys()) |
| 308 |
| 309 |
| 310 def deser(val): |
| 311 if isinstance(val, str): |
| 312 _val = val.encode("utf-8") |
| 313 else: |
| 314 _val = val |
| 315 |
| 316 return base64_to_long(_val) |
| 317 |
| 318 |
| 319 class RSAKey(Key): |
| 320 """ |
| 321 JSON Web key representation of a RSA key |
| 322 """ |
| 323 members = Key.members |
| 324 members.extend(["n", "e", "d", "p", "q"]) |
| 325 longs = ["n", "e", "d", "p", "q", "dp", "dq", "di", "qi"] |
| 326 public_members = Key.public_members |
| 327 public_members.extend(["n", "e"]) |
| 328 |
| 329 def __init__(self, kty="RSA", alg="", use="", kid="", key=None, |
| 330 x5c=None, x5t="", x5u="", n="", e="", d="", p="", q="", |
| 331 dp="", dq="", di="", qi="", **kwargs): |
| 332 Key.__init__(self, kty, alg, use, kid, key, x5c, x5t, x5u, **kwargs) |
| 333 self.n = n |
| 334 self.e = e |
| 335 self.d = d |
| 336 self.p = p |
| 337 self.q = q |
| 338 self.dp = dp |
| 339 self.dq = dq |
| 340 self.di = di |
| 341 self.qi = qi |
| 342 |
| 343 if not self.key and self.n and self.e: |
| 344 self.deserialize() |
| 345 elif self.key and not (self.n and self.e): |
| 346 self._split() |
| 347 |
| 348 def deserialize(self): |
| 349 if self.n and self.e: |
| 350 try: |
| 351 for param in self.longs: |
| 352 item = getattr(self, param) |
| 353 if not item or isinstance(item, six.integer_types): |
| 354 continue |
| 355 else: |
| 356 try: |
| 357 val = long(deser(item)) |
| 358 except Exception: |
| 359 raise |
| 360 else: |
| 361 setattr(self, param, val) |
| 362 |
| 363 lst = [self.n, self.e] |
| 364 if self.d: |
| 365 lst.append(self.d) |
| 366 if self.p: |
| 367 lst.append(self.p) |
| 368 if self.q: |
| 369 lst.append(self.q) |
| 370 self.key = RSA.construct(tuple(lst)) |
| 371 else: |
| 372 self.key = RSA.construct(lst) |
| 373 except ValueError as err: |
| 374 raise DeSerializationNotPossible("%s" % err) |
| 375 elif self.x5c: |
| 376 if self.x5t: # verify the cert |
| 377 pass |
| 378 |
| 379 cert = "\n".join([PREFIX, str(self.x5c[0]), POSTFIX]) |
| 380 self.key = import_rsa_key(cert) |
| 381 self._split() |
| 382 if len(self.x5c) > 1: # verify chain |
| 383 pass |
| 384 else: |
| 385 raise DeSerializationNotPossible() |
| 386 |
| 387 def serialize(self, private=False): |
| 388 if not self.key: |
| 389 raise SerializationNotPossible() |
| 390 |
| 391 res = self.common() |
| 392 |
| 393 public_longs = list(set(self.public_members) & set(self.longs)) |
| 394 for param in public_longs: |
| 395 item = getattr(self, param) |
| 396 if item: |
| 397 res[param] = long_to_base64(item) |
| 398 |
| 399 if private: |
| 400 for param in self.longs: |
| 401 if not private and param in ["d", "p", "q", "dp", "dq", "di", |
| 402 "qi"]: |
| 403 continue |
| 404 item = getattr(self, param) |
| 405 if item: |
| 406 res[param] = long_to_base64(item) |
| 407 return res |
| 408 |
| 409 def _split(self): |
| 410 self.n = self.key.n |
| 411 self.e = self.key.e |
| 412 try: |
| 413 self.d = self.key.d |
| 414 except AttributeError: |
| 415 pass |
| 416 else: |
| 417 for param in ["p", "q"]: |
| 418 try: |
| 419 val = getattr(self.key, param) |
| 420 except AttributeError: |
| 421 pass |
| 422 else: |
| 423 if val: |
| 424 setattr(self, param, val) |
| 425 |
| 426 def load(self, filename): |
| 427 """ |
| 428 Load the key from a file. |
| 429 |
| 430 :param filename: File name |
| 431 """ |
| 432 self.key = rsa_load(filename) |
| 433 self._split() |
| 434 return self |
| 435 |
| 436 def load_key(self, key): |
| 437 """ |
| 438 Use this RSA key |
| 439 |
| 440 :param key: An RSA key instance |
| 441 """ |
| 442 self.key = key |
| 443 self._split() |
| 444 return self |
| 445 |
| 446 def encryption_key(self, **kwargs): |
| 447 """ |
| 448 Make sure there is a key instance present that can be used for |
| 449 encrypting/signing. |
| 450 """ |
| 451 if not self.key: |
| 452 self.deserialize() |
| 453 |
| 454 return self.key |
| 455 |
| 456 |
| 457 class ECKey(Key): |
| 458 """ |
| 459 JSON Web key representation of a Elliptic curve key |
| 460 """ |
| 461 members = ["kty", "alg", "use", "kid", "crv", "x", "y", "d"] |
| 462 longs = ['x', 'y', 'd'] |
| 463 public_members = ["kty", "alg", "use", "kid", "crv", "x", "y"] |
| 464 |
| 465 def __init__(self, kty="EC", alg="", use="", kid="", key=None, |
| 466 crv="", x="", y="", d="", curve=None, **kwargs): |
| 467 Key.__init__(self, kty, alg, use, kid, key, **kwargs) |
| 468 self.crv = crv |
| 469 self.x = x |
| 470 self.y = y |
| 471 self.d = d |
| 472 self.curve = curve |
| 473 |
| 474 # Initiated guess as to what state the key is in |
| 475 # To be usable for encryption/signing/.. it has to be deserialized |
| 476 if self.crv and not self.curve: |
| 477 self.verify() |
| 478 self.deserialize() |
| 479 |
| 480 def deserialize(self): |
| 481 """ |
| 482 Starting with information gathered from the on-the-wire representation |
| 483 of an elliptic curve key initiate an Elliptic Curve. |
| 484 """ |
| 485 try: |
| 486 if not isinstance(self.x, six.integer_types): |
| 487 self.x = deser(self.x) |
| 488 if not isinstance(self.y, six.integer_types): |
| 489 self.y = deser(self.y) |
| 490 except TypeError: |
| 491 raise DeSerializationNotPossible() |
| 492 except ValueError as err: |
| 493 raise DeSerializationNotPossible("%s" % err) |
| 494 |
| 495 self.curve = NISTEllipticCurve.by_name(self.crv) |
| 496 if self.d: |
| 497 try: |
| 498 if isinstance(self.d, six.string_types): |
| 499 self.d = deser(self.d) |
| 500 except ValueError as err: |
| 501 raise DeSerializationNotPossible(str(err)) |
| 502 |
| 503 def get_key(self, private=False, **kwargs): |
| 504 if private: |
| 505 return self.d |
| 506 else: |
| 507 return self.x, self.y |
| 508 |
| 509 def serialize(self, private=False): |
| 510 if not self.crv and not self.curve: |
| 511 raise SerializationNotPossible() |
| 512 |
| 513 res = self.common() |
| 514 res.update({ |
| 515 "crv": self.curve.name(), |
| 516 "x": long_to_base64(self.x), |
| 517 "y": long_to_base64(self.y) |
| 518 }) |
| 519 |
| 520 if private and self.d: |
| 521 res["d"] = long_to_base64(self.d) |
| 522 |
| 523 return res |
| 524 |
| 525 def load_key(self, key): |
| 526 self.curve = key |
| 527 self.d, (self.x, self.y) = key.key_pair() |
| 528 return self |
| 529 |
| 530 def decryption_key(self): |
| 531 return self.get_key(private=True) |
| 532 |
| 533 def encryption_key(self, private=False, **kwargs): |
| 534 # both for encryption and decryption. |
| 535 return self.get_key(private=private) |
| 536 |
| 537 |
| 538 ALG2KEYLEN = { |
| 539 "A128KW": 16, |
| 540 "A192KW": 24, |
| 541 "A256KW": 32, |
| 542 "HS256": 32, |
| 543 "HS384": 48, |
| 544 "HS512": 64 |
| 545 } |
| 546 |
| 547 |
| 548 class SYMKey(Key): |
| 549 members = ["kty", "alg", "use", "kid", "k"] |
| 550 public_members = members[:] |
| 551 |
| 552 def __init__(self, kty="oct", alg="", use="", kid="", key=None, |
| 553 x5c=None, x5t="", x5u="", k="", mtrl="", **kwargs): |
| 554 Key.__init__(self, kty, alg, use, kid, as_bytes(key), x5c, x5t, x5u, **k
wargs) |
| 555 self.k = k |
| 556 if not self.key and self.k: |
| 557 if isinstance(self.k, str): |
| 558 self.k = self.k.encode("utf-8") |
| 559 self.key = b64d(bytes(self.k)) |
| 560 |
| 561 def deserialize(self): |
| 562 self.key = b64d(bytes(self.k)) |
| 563 |
| 564 def serialize(self, private=True): |
| 565 res = self.common() |
| 566 res["k"] = b64e(bytes(self.key)) |
| 567 return res |
| 568 |
| 569 def encryption_key(self, alg, **kwargs): |
| 570 if not self.key: |
| 571 self.deserialize() |
| 572 |
| 573 tsize = ALG2KEYLEN[alg] |
| 574 _keylen = len(self.key) |
| 575 |
| 576 if _keylen <= 32: |
| 577 # SHA256 |
| 578 _enc_key = sha256_digest(self.key)[:tsize] |
| 579 elif _keylen <= 48: |
| 580 # SHA384 |
| 581 _enc_key = sha384_digest(self.key)[:tsize] |
| 582 elif _keylen <= 64: |
| 583 # SHA512 |
| 584 _enc_key = sha512_digest(self.key)[:tsize] |
| 585 else: |
| 586 raise JWKException("No support for symmetric keys > 512 bits") |
| 587 |
| 588 return _enc_key |
| 589 |
| 590 # ----------------------------------------------------------------------------- |
| 591 |
| 592 |
| 593 def keyitems2keyreps(keyitems): |
| 594 keys = [] |
| 595 for key_type, _keys in list(keyitems.items()): |
| 596 if key_type.upper() == "RSA": |
| 597 keys.extend([RSAKey(key=k) for k in _keys]) |
| 598 elif key_type.lower() == "oct": |
| 599 keys.extend([SYMKey(key=k) for k in _keys]) |
| 600 elif key_type.upper() == "EC": |
| 601 keys.extend([ECKey(key=k) for k in _keys]) |
| 602 else: |
| 603 keys.extend([Key(key=k) for k in _keys]) |
| 604 return keys |
| 605 |
| 606 |
| 607 def keyrep(kspec, enc="utf-8"): |
| 608 """ |
| 609 Instantiate a Key given a set of key/word arguments |
| 610 |
| 611 :param kspec: Key specification, arguments to the Key initialization |
| 612 :param enc: The encoding of the strings. If it's JSON which is the default |
| 613 the encoding is utf-8. |
| 614 :return: Key instance |
| 615 """ |
| 616 if enc: |
| 617 _kwargs = {} |
| 618 for key, val in kspec.items(): |
| 619 if isinstance(val, str): |
| 620 _kwargs[key] = val.encode(enc) |
| 621 else: |
| 622 _kwargs[key] = val |
| 623 else: |
| 624 _kwargs = kspec |
| 625 |
| 626 if kspec["kty"] == "RSA": |
| 627 item = RSAKey(**_kwargs) |
| 628 elif kspec["kty"] == "oct": |
| 629 item = SYMKey(**_kwargs) |
| 630 elif kspec["kty"] == "EC": |
| 631 item = ECKey(**_kwargs) |
| 632 else: |
| 633 item = Key(**_kwargs) |
| 634 return item |
| 635 |
| 636 |
| 637 def jwk_wrap(key, use="", kid=""): |
| 638 """ |
| 639 Instantiated a Key instance with the given key |
| 640 |
| 641 :param key: The keys to wrap |
| 642 :param use: What the key are expected to be use for |
| 643 :param kid: A key id |
| 644 :return: The Key instance |
| 645 """ |
| 646 if isinstance(key, _RSAobj): |
| 647 kspec = RSAKey(use=use, kid=kid).load_key(key) |
| 648 elif isinstance(key, str): |
| 649 kspec = SYMKey(key=key, use=use, kid=kid) |
| 650 elif isinstance(key, NISTEllipticCurve): |
| 651 kspec = ECKey(use=use, kid=kid).load_key(key) |
| 652 else: |
| 653 raise Exception("Unknown key type:key="+str(type(key))) |
| 654 |
| 655 kspec.serialize() |
| 656 return kspec |
| 657 |
| 658 |
| 659 class KEYS(object): |
| 660 def __init__(self): |
| 661 self._keys = [] |
| 662 |
| 663 def load_dict(self, dikt): |
| 664 for kspec in dikt["keys"]: |
| 665 self._keys.append(keyrep(kspec)) |
| 666 |
| 667 def load_jwks(self, jwks): |
| 668 """ |
| 669 Load and create keys from a JWKS JSON representation |
| 670 |
| 671 Expects something on this form:: |
| 672 |
| 673 {"keys": |
| 674 [ |
| 675 {"kty":"EC", |
| 676 "crv":"P-256", |
| 677 "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", |
| 678 "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", |
| 679 "use":"enc", |
| 680 "kid":"1"}, |
| 681 |
| 682 {"kty":"RSA", |
| 683 "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFb....." |
| 684 "e":"AQAB", |
| 685 "kid":"2011-04-29"} |
| 686 ] |
| 687 } |
| 688 |
| 689 :param jwks: The JWKS JSON string representation |
| 690 :return: list of 2-tuples containing key, type |
| 691 """ |
| 692 return self.load_dict(json.loads(jwks)) |
| 693 |
| 694 def dump_jwks(self): |
| 695 """ |
| 696 :return: A JWKS representation of the held keys |
| 697 """ |
| 698 res = [] |
| 699 for key in self._keys: |
| 700 res.append(b2s_conv(key.serialize())) |
| 701 |
| 702 return json.dumps({"keys": res}) |
| 703 |
| 704 def load_from_url(self, url, verify=True): |
| 705 """ |
| 706 Get and transform a JWKS into keys |
| 707 |
| 708 :param url: Where the JWKS can be found |
| 709 :param verify: SSL cert verification |
| 710 :return: list of keys |
| 711 """ |
| 712 |
| 713 r = request("GET", url, allow_redirects=True, verify=verify) |
| 714 if r.status_code == 200: |
| 715 return self.load_jwks(r.text) |
| 716 else: |
| 717 raise Exception("HTTP Get error: %s" % r.status_code) |
| 718 |
| 719 def __getitem__(self, item): |
| 720 """ |
| 721 Get all keys of a specific key type |
| 722 |
| 723 :param kty: Key type |
| 724 :return: list of keys |
| 725 """ |
| 726 kty = item.lower() |
| 727 return [k for k in self._keys if k.kty.lower() == kty] |
| 728 |
| 729 def __iter__(self): |
| 730 for k in self._keys: |
| 731 yield k |
| 732 |
| 733 def __len__(self): |
| 734 return len(self._keys) |
| 735 |
| 736 def keys(self): |
| 737 return list(set([k.kty for k in self._keys])) |
| 738 |
| 739 def __repr__(self): |
| 740 return self.dump_jwks() |
| 741 |
| 742 def __str__(self): |
| 743 return self.__repr__() |
| 744 |
| 745 def kids(self): |
| 746 return [k.kid for k in self._keys if k.kid] |
| 747 |
| 748 def by_kid(self, kid): |
| 749 return [k for k in self._keys if kid == k.kid] |
| 750 |
| 751 def wrap_add(self, keyinst, use="", kid=''): |
| 752 self._keys.append(jwk_wrap(keyinst, use, kid)) |
| 753 |
| 754 def as_dict(self): |
| 755 _res = {} |
| 756 for kty, k in [(k.kty, k) for k in self._keys]: |
| 757 if kty not in ["RSA", "EC", "oct"]: |
| 758 if kty in ["rsa", "ec"]: |
| 759 kty = kty.upper() |
| 760 else: |
| 761 kty = kty.lower() |
| 762 |
| 763 try: |
| 764 _res[kty].append(k) |
| 765 except KeyError: |
| 766 _res[kty] = [k] |
| 767 return _res |
| 768 |
| 769 def add(self, item, enc="utf-8"): |
| 770 self._keys.append(keyrep(item, enc)) |
OLD | NEW |