| Index: third_party/google-endpoints/jwkest/jws.py
|
| diff --git a/third_party/google-endpoints/jwkest/jws.py b/third_party/google-endpoints/jwkest/jws.py
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..2a130355e2273e5528965e2e7e44b624b4035366
|
| --- /dev/null
|
| +++ b/third_party/google-endpoints/jwkest/jws.py
|
| @@ -0,0 +1,633 @@
|
| +"""JSON Web Token"""
|
| +import six
|
| +
|
| +try:
|
| + from builtins import str
|
| + from builtins import object
|
| +except ImportError:
|
| + pass
|
| +
|
| +# Most of the code, ideas herein I have borrowed/stolen from other people
|
| +# Most notably Jeff Lindsay, Ryan Kelly and Richard Barnes
|
| +
|
| +import json
|
| +import logging
|
| +
|
| +import struct
|
| +from Crypto.Hash import SHA256
|
| +from Crypto.Hash import SHA384
|
| +from Crypto.Hash import SHA512
|
| +from Crypto.Hash import HMAC
|
| +from Crypto.Signature import PKCS1_v1_5
|
| +from Crypto.Signature import PKCS1_PSS
|
| +from Crypto.Util.number import bytes_to_long
|
| +import sys
|
| +
|
| +from jwkest import b64d, as_unicode
|
| +from jwkest import b64e
|
| +from jwkest import constant_time_compare
|
| +from jwkest import safe_str_cmp
|
| +from jwkest import JWKESTException
|
| +from jwkest import BadSignature
|
| +from jwkest import UnknownAlgorithm
|
| +from jwkest.ecc import P256
|
| +from jwkest.ecc import P384
|
| +from jwkest.ecc import P521
|
| +
|
| +from jwkest.jwk import load_x509_cert, KEYS
|
| +from jwkest.jwk import HeaderError
|
| +from jwkest.jwk import sha256_digest
|
| +from jwkest.jwk import sha384_digest
|
| +from jwkest.jwk import sha512_digest
|
| +from jwkest.jwk import keyrep
|
| +
|
| +from jwkest.jwt import JWT
|
| +from jwkest.jwt import b64encode_item
|
| +
|
| +logger = logging.getLogger(__name__)
|
| +
|
| +
|
| +class JWSException(JWKESTException):
|
| + pass
|
| +
|
| +
|
| +class NoSuitableSigningKeys(JWSException):
|
| + pass
|
| +
|
| +
|
| +class FormatError(JWSException):
|
| + pass
|
| +
|
| +
|
| +class WrongTypeOfKey(JWSException):
|
| + pass
|
| +
|
| +
|
| +class UnknownSignerAlg(JWSException):
|
| + pass
|
| +
|
| +
|
| +class SignerAlgError(JWSException):
|
| + pass
|
| +
|
| +
|
| +def left_hash(msg, func="HS256"):
|
| + """ 128 bits == 16 bytes """
|
| + if func == 'HS256':
|
| + return as_unicode(b64e(sha256_digest(msg)[:16]))
|
| + elif func == 'HS384':
|
| + return as_unicode(b64e(sha384_digest(msg)[:24]))
|
| + elif func == 'HS512':
|
| + return as_unicode(b64e(sha512_digest(msg)[:32]))
|
| +
|
| +
|
| +def mpint(b):
|
| + b += b"\x00"
|
| + return struct.pack(">L", len(b)) + b
|
| +
|
| +
|
| +def mp2bin(b):
|
| + # just ignore the length...
|
| + if b[4] == '\x00':
|
| + return b[5:]
|
| + else:
|
| + return b[4:]
|
| +
|
| +
|
| +class Signer(object):
|
| + """Abstract base class for signing algorithms."""
|
| + def sign(self, msg, key):
|
| + """Sign ``msg`` with ``key`` and return the signature."""
|
| + raise NotImplementedError()
|
| +
|
| + def verify(self, msg, sig, key):
|
| + """Return True if ``sig`` is a valid signature for ``msg``."""
|
| + raise NotImplementedError()
|
| +
|
| +
|
| +class HMACSigner(Signer):
|
| + def __init__(self, digest):
|
| + self.digest = digest
|
| +
|
| + def sign(self, msg, key):
|
| + h = HMAC.new(key, msg, digestmod=self.digest)
|
| + return h.digest()
|
| + # return hmac.new(key, msg, digestmod=self.digest).digest()
|
| +
|
| + def verify(self, msg, sig, key):
|
| + if sys.version < '3':
|
| + if safe_str_cmp(self.sign(msg, key), sig):
|
| + return True
|
| + elif constant_time_compare(self.sign(msg, key), sig):
|
| + return True
|
| + raise BadSignature(repr(sig))
|
| +
|
| +
|
| +class RSASigner(Signer):
|
| + def __init__(self, digest):
|
| + self.digest = digest
|
| +
|
| + def sign(self, msg, key):
|
| + h = self.digest.new(msg)
|
| + signer = PKCS1_v1_5.new(key)
|
| + return signer.sign(h)
|
| +
|
| + def verify(self, msg, sig, key):
|
| + h = self.digest.new(msg)
|
| + verifier = PKCS1_v1_5.new(key)
|
| + try:
|
| + if verifier.verify(h, sig):
|
| + return True
|
| + else:
|
| + raise BadSignature()
|
| + except ValueError as e:
|
| + raise BadSignature(str(e))
|
| +
|
| +
|
| +class DSASigner(Signer):
|
| + def __init__(self, digest, sign):
|
| + self.digest = digest
|
| + self._sign = sign
|
| +
|
| + def sign(self, msg, key):
|
| + # verify the key
|
| + h = bytes_to_long(self.digest.new(msg).digest())
|
| + return self._sign.sign(h, key)
|
| +
|
| + def verify(self, msg, sig, key):
|
| + h = bytes_to_long(self.digest.new(msg).digest())
|
| + return self._sign.verify(h, sig, key)
|
| +
|
| +
|
| +class PSSSigner(Signer):
|
| + def __init__(self, digest):
|
| + self.digest = digest
|
| +
|
| + def sign(self, msg, key):
|
| + h = self.digest.new(msg)
|
| + signer = PKCS1_PSS.new(key)
|
| + return signer.sign(h)
|
| +
|
| + def verify(self, msg, sig, key):
|
| + h = self.digest.new(msg)
|
| + verifier = PKCS1_PSS.new(key)
|
| + res = verifier.verify(h, sig)
|
| + if not res:
|
| + raise BadSignature()
|
| + else:
|
| + return True
|
| +
|
| +
|
| +SIGNER_ALGS = {
|
| + 'HS256': HMACSigner(SHA256),
|
| + 'HS384': HMACSigner(SHA384),
|
| + 'HS512': HMACSigner(SHA512),
|
| +
|
| + 'RS256': RSASigner(SHA256),
|
| + 'RS384': RSASigner(SHA384),
|
| + 'RS512': RSASigner(SHA512),
|
| +
|
| + 'ES256': DSASigner(SHA256, P256),
|
| + 'ES384': DSASigner(SHA384, P384),
|
| + 'ES512': DSASigner(SHA512, P521),
|
| +
|
| + 'PS256': PSSSigner(SHA256),
|
| + 'PS384': PSSSigner(SHA384),
|
| + 'PS512': PSSSigner(SHA512),
|
| +
|
| + 'none': None
|
| +}
|
| +
|
| +
|
| +def alg2keytype(alg):
|
| + if not alg or alg.lower() == "none":
|
| + return "none"
|
| + elif alg.startswith("RS") or alg.startswith("PS"):
|
| + return "RSA"
|
| + elif alg.startswith("HS") or alg.startswith("A"):
|
| + return "oct"
|
| + elif alg.startswith("ES"):
|
| + return "EC"
|
| + else:
|
| + return None
|
| +
|
| +
|
| +class JWSig(JWT):
|
| + def sign_input(self):
|
| + return self.b64part[0] + b'.' + self.b64part[1]
|
| +
|
| + def signature(self):
|
| + return self.part[2]
|
| +
|
| +
|
| +class JWx(object):
|
| + args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty",
|
| + "crit"]
|
| + """
|
| + :param alg: The signing algorithm
|
| + :param jku: a URI that refers to a resource for a set of JSON-encoded
|
| + public keys, one of which corresponds to the key used to digitally
|
| + sign the JWS
|
| + :param jwk: A JSON Web Key that corresponds to the key used to
|
| + digitally sign the JWS
|
| + :param x5u: a URI that refers to a resource for the X.509 public key
|
| + certificate or certificate chain [RFC5280] corresponding to the key
|
| + used to digitally sign the JWS.
|
| + :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
|
| + DER encoding of the X.509 certificate [RFC5280] corresponding to
|
| + the key used to digitally sign the JWS.
|
| + :param x5c: the X.509 public key certificate or certificate chain
|
| + corresponding to the key used to digitally sign the JWS.
|
| + :param kid: a hint indicating which key was used to secure the JWS.
|
| + :param typ: the type of this object. 'JWS' == JWS Compact Serialization
|
| + 'JWS+JSON' == JWS JSON Serialization
|
| + :param cty: the type of the secured content
|
| + :param crit: indicates which extensions that are being used and MUST
|
| + be understood and processed.
|
| + :param kwargs: Extra header parameters
|
| + :return: A class instance
|
| + """
|
| +
|
| + def __init__(self, msg=None, with_digest=False, **kwargs):
|
| + self.msg = msg
|
| +
|
| + self._dict = {}
|
| + self.with_digest = with_digest
|
| + self.jwt = None
|
| +
|
| + if kwargs:
|
| + for key in self.args:
|
| + try:
|
| + _val = kwargs[key]
|
| + except KeyError:
|
| + if key == "alg":
|
| + self._dict[key] = "none"
|
| + continue
|
| +
|
| + if key == "jwk":
|
| + if isinstance(_val, dict):
|
| + self._dict["jwk"] = keyrep(_val)
|
| + elif isinstance(_val, str):
|
| + self._dict["jwk"] = keyrep(json.loads(_val))
|
| + else:
|
| + self._dict["jwk"] = _val
|
| + elif key == "x5c" or key == "crit":
|
| + self._dict["x5c"] = _val or []
|
| + else:
|
| + self._dict[key] = _val
|
| +
|
| + def __contains__(self, item):
|
| + return item in self._dict
|
| +
|
| + def __getitem__(self, item):
|
| + return self._dict[item]
|
| +
|
| + def __setitem__(self, key, value):
|
| + self._dict[key] = value
|
| +
|
| + def __getattr__(self, item):
|
| + try:
|
| + return self._dict[item]
|
| + except KeyError:
|
| + raise AttributeError(item)
|
| +
|
| + def keys(self):
|
| + return list(self._dict.keys())
|
| +
|
| + def headers(self, extra=None):
|
| + _extra = extra or {}
|
| + _header = {}
|
| + for param in self.args:
|
| + try:
|
| + _header[param] = _extra[param]
|
| + except KeyError:
|
| + try:
|
| + if self._dict[param]:
|
| + _header[param] = self._dict[param]
|
| + except KeyError:
|
| + pass
|
| +
|
| + if "jwk" in self:
|
| + _header["jwk"] = self["jwk"].serialize()
|
| + elif "jwk" in _extra:
|
| + _header["jwk"] = extra["jwk"].serialize()
|
| +
|
| + if "kid" in self:
|
| + try:
|
| + assert isinstance(self["kid"], six.string_types)
|
| + except AssertionError:
|
| + raise HeaderError("kid of wrong value type")
|
| +
|
| + return _header
|
| +
|
| + def _get_keys(self):
|
| + logger.debug("_get_keys(): self._dict.keys={0}".format(
|
| + self._dict.keys()))
|
| +
|
| + if "jwk" in self:
|
| + return [self["jwk"]]
|
| + elif "jku" in self:
|
| + keys = KEYS()
|
| + keys.load_from_url(self["jku"])
|
| + return keys.as_dict()
|
| + elif "x5u" in self:
|
| + try:
|
| + return {"rsa": [load_x509_cert(self["x5u"], {})]}
|
| + except Exception:
|
| + # ca_chain = load_x509_cert_chain(self["x5u"])
|
| + pass
|
| +
|
| + return {}
|
| +
|
| + def alg2keytype(self, alg):
|
| + return alg2keytype(alg)
|
| +
|
| + def _pick_keys(self, keys, use="", alg=""):
|
| + """
|
| + The assumption is that upper layer has made certain you only get
|
| + keys you can use.
|
| +
|
| + :param keys: A list of KEY instances
|
| + :return: A list of KEY instances that fulfill the requirements
|
| + """
|
| + if not alg:
|
| + alg = self["alg"]
|
| +
|
| + if alg == "none":
|
| + return []
|
| +
|
| + _k = self.alg2keytype(alg)
|
| + if _k is None:
|
| + logger.error("Unknown arlgorithm '%s'" % alg)
|
| + return []
|
| +
|
| + logger.debug("Picking key by key type={0}".format(_k))
|
| + _kty = [_k.lower(), _k.upper(), _k.lower().encode("utf-8"),
|
| + _k.upper().encode("utf-8")]
|
| + _keys = [k for k in keys if k.kty in _kty]
|
| + try:
|
| + _kid = self["kid"]
|
| + except KeyError:
|
| + try:
|
| + _kid = self.jwt.headers["kid"]
|
| + except (AttributeError, KeyError):
|
| + _kid = None
|
| +
|
| + logger.debug("Picking key based on alg={0}, kid={1} and use={2}".format(
|
| + alg, _kid, use))
|
| +
|
| + pkey = []
|
| + for _key in _keys:
|
| + logger.debug("KEY: {0}".format(_key))
|
| + if _kid:
|
| + try:
|
| + assert _kid == _key.kid
|
| + except (KeyError, AttributeError):
|
| + pass
|
| + except AssertionError:
|
| + continue
|
| +
|
| + if use and _key.use and _key.use != use:
|
| + continue
|
| +
|
| + if alg and _key.alg and _key.alg != alg:
|
| + continue
|
| +
|
| + pkey.append(_key)
|
| +
|
| + return pkey
|
| +
|
| + def _decode(self, payload):
|
| + _msg = b64d(bytes(payload))
|
| + if "cty" in self:
|
| + if self["cty"] == "JWT":
|
| + _msg = json.loads(_msg)
|
| + return _msg
|
| +
|
| + def dump_header(self):
|
| + return dict([(x, self._dict[x]) for x in self.args if x in self._dict])
|
| +
|
| +
|
| +class JWS(JWx):
|
| +
|
| + def alg_keys(self, keys, use, protected=None):
|
| + try:
|
| + _alg = self["alg"]
|
| + except KeyError:
|
| + self["alg"] = _alg = "none"
|
| + else:
|
| + if not _alg:
|
| + self["alg"] = _alg = "none"
|
| +
|
| + if keys:
|
| + keys = self._pick_keys(keys, use=use, alg=_alg)
|
| + else:
|
| + keys = self._pick_keys(self._get_keys(), use=use, alg=_alg)
|
| +
|
| + xargs = protected or {}
|
| + xargs["alg"] = _alg
|
| +
|
| + if keys:
|
| + key = keys[0]
|
| + if key.kid:
|
| + xargs["kid"] = key.kid
|
| + elif not _alg or _alg.lower() == "none":
|
| + key = None
|
| + else:
|
| + if "kid" in self:
|
| + raise NoSuitableSigningKeys(
|
| + "No key for algorithm: %s and kid: %s" % (_alg,
|
| + self["kid"]))
|
| + else:
|
| + raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)
|
| +
|
| + return key, xargs, _alg
|
| +
|
| + def sign_compact(self, keys=None, protected=None):
|
| + """
|
| + Produce a JWS using the JWS Compact Serialization
|
| +
|
| + :param keys: A dictionary of keys
|
| + :param protected: The protected headers (a dictionary)
|
| + :return:
|
| + """
|
| +
|
| + key, xargs, _alg = self.alg_keys(keys, 'sig', protected)
|
| +
|
| + if "typ" in self:
|
| + xargs["typ"] = self["typ"]
|
| +
|
| + jwt = JWSig(**xargs)
|
| + if _alg == "none":
|
| + return jwt.pack(parts=[self.msg, ""])
|
| +
|
| + # All other cases
|
| + try:
|
| + _signer = SIGNER_ALGS[_alg]
|
| + except KeyError:
|
| + raise UnknownAlgorithm(_alg)
|
| +
|
| + _input = jwt.pack(parts=[self.msg])
|
| + sig = _signer.sign(_input.encode("utf-8"), key.get_key(alg=_alg, private=True))
|
| + logger.debug("Signed message using key with kid=%s" % key.kid)
|
| + return ".".join([_input, b64encode_item(sig).decode("utf-8")])
|
| +
|
| + def verify_compact(self, jws, keys=None, allow_none=False, sigalg=None):
|
| + """
|
| + Verify a JWT signature
|
| +
|
| + :param jws:
|
| + :param keys:
|
| + :param allow_none: If signature algorithm 'none' is allowed
|
| + :param sigalg: Expected sigalg
|
| + :return:
|
| + """
|
| + jwt = JWSig().unpack(jws)
|
| + self.jwt = jwt
|
| +
|
| + try:
|
| + _alg = jwt.headers["alg"]
|
| + except KeyError:
|
| + _alg = None
|
| + else:
|
| + if _alg is None or _alg.lower() == "none":
|
| + if allow_none:
|
| + self.msg = jwt.payload()
|
| + return self.msg
|
| + else:
|
| + raise SignerAlgError("none not allowed")
|
| +
|
| + if "alg" in self and _alg:
|
| + if self["alg"] != _alg:
|
| + raise SignerAlgError("Wrong signing algorithm")
|
| +
|
| + if sigalg and sigalg != _alg:
|
| + raise SignerAlgError("Expected {0} got {1}".format(
|
| + sigalg, jwt.headers["alg"]))
|
| +
|
| + self["alg"] = _alg
|
| +
|
| + if keys:
|
| + _keys = self._pick_keys(keys)
|
| + else:
|
| + _keys = self._pick_keys(self._get_keys())
|
| +
|
| + if not _keys:
|
| + if "kid" in self:
|
| + raise NoSuitableSigningKeys(
|
| + "No key with kid: %s" % (self["kid"]))
|
| + elif "kid" in self.jwt.headers:
|
| + raise NoSuitableSigningKeys(
|
| + "No key with kid: %s" % (self.jwt.headers["kid"]))
|
| + else:
|
| + raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)
|
| +
|
| + verifier = SIGNER_ALGS[_alg]
|
| +
|
| + for key in _keys:
|
| + try:
|
| + res = verifier.verify(jwt.sign_input(), jwt.signature(),
|
| + key.get_key(alg=_alg, private=False))
|
| + except BadSignature:
|
| + pass
|
| + else:
|
| + if res is True:
|
| + logger.debug(
|
| + "Verified message using key with kid=%s" % key.kid)
|
| + self.msg = jwt.payload()
|
| + return self.msg
|
| +
|
| + raise BadSignature()
|
| +
|
| + def sign_json(self, per_signature_header=None, **kwargs):
|
| + """
|
| + Produce JWS using the JWS JSON Serialization
|
| +
|
| + :param per_signature_header: Header parameter values that are to be
|
| + applied to a specific signature
|
| + :return:
|
| + """
|
| + res = {"signatures": []}
|
| +
|
| + if per_signature_header is None:
|
| + per_signature_header = [{"alg": "none"}]
|
| +
|
| + for _kwa in per_signature_header:
|
| + _kwa.update(kwargs)
|
| + _jws = JWS(self.msg, **_kwa)
|
| + header, payload, signature = _jws.sign_compact().split(".")
|
| + res["signatures"].append({"header": header,
|
| + "signature": signature})
|
| +
|
| + res["payload"] = self.msg
|
| +
|
| + return res
|
| +
|
| + def verify_json(self, jws, keys=None, allow_none=False, sigalg=None):
|
| + """
|
| +
|
| + :param jws:
|
| + :param keys:
|
| + :return:
|
| + """
|
| +
|
| + _jwss = json.load(jws)
|
| +
|
| + try:
|
| + _payload = _jwss["payload"]
|
| + except KeyError:
|
| + raise FormatError("Missing payload")
|
| +
|
| + try:
|
| + _signs = _jwss["signatures"]
|
| + except KeyError:
|
| + raise FormatError("Missing signatures")
|
| +
|
| + _claim = None
|
| + for _sign in _signs:
|
| + token = b".".join([_sign["protected"].encode(), _payload.encode(), _sign["signature"].encode()])
|
| + header = _sign.get("header", {})
|
| + self.__init__(**header)
|
| + _tmp = self.verify_compact(token, keys, allow_none, sigalg)
|
| + if _claim is None:
|
| + _claim = _tmp
|
| + else:
|
| + assert _claim == _tmp
|
| +
|
| + return _claim
|
| +
|
| + def is_jws(self, token):
|
| + """
|
| +
|
| + :param token:
|
| + :return:
|
| + """
|
| + try:
|
| + jwt = JWSig().unpack(token)
|
| + except Exception:
|
| + return False
|
| +
|
| + try:
|
| + assert "alg" in jwt.headers
|
| + except AssertionError:
|
| + return False
|
| + else:
|
| + if jwt.headers["alg"] is None:
|
| + jwt.headers["alg"] = "none"
|
| +
|
| + try:
|
| + assert jwt.headers["alg"] in SIGNER_ALGS
|
| + except AssertionError:
|
| + logger.debug("UnknownSignerAlg: %s" % jwt.headers["alg"])
|
| + return False
|
| + else:
|
| + self.jwt = jwt
|
| + return True
|
| +
|
| +
|
| +def factory(token):
|
| + _jw = JWS()
|
| + if _jw.is_jws(token):
|
| + return _jw
|
| + else:
|
| + return None
|
|
|