Skip to content

Commit

Permalink
Mapping (#23)
Browse files Browse the repository at this point in the history
* always return a key id when accessing `Jwk.kid`
* use UserDict instead of dict as base class for Jwk, JwkSet and JwsJson. Accept `Mapping` everywhere instead of `dict`
  • Loading branch information
guillp committed Jan 22, 2024
1 parent 6a73468 commit f1111d6
Show file tree
Hide file tree
Showing 20 changed files with 496 additions and 445 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ repos:
hooks:
- id: blacken-docs
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
rev: v0.1.13
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
Expand All @@ -44,6 +44,6 @@ repos:
additional_dependencies:
- types-cryptography==3.3.23.2
- pytest-mypy==0.10.3
- binapy==0.7.0
- binapy==0.8.0
- freezegun==1.2.2
- jwcrypto==1.5.0
656 changes: 329 additions & 327 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion jwskate/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class KeyManagementAlgs:
A128GCMKW = "A128GCMKW"
A192GCMKW = "A192GCMKW"
A256GCMKW = "A256GCMKW"
dir = "dir" # noqa: A003
dir = "dir"

PBES2_HS256_A128KW = "PBES2-HS256+A128KW"
PBES2_HS384_A192KW = "PBES2-HS384+A192KW"
Expand Down
2 changes: 1 addition & 1 deletion jwskate/jwa/encryption/aesgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def encrypt(
if not isinstance(plaintext, bytes):
plaintext = bytes(plaintext)
ciphertext_with_tag = BinaPy(aead.AESGCM(self.key).encrypt(iv, plaintext, aad))
ciphertext, tag = ciphertext_with_tag.cut_at(-self.tag_size)
ciphertext, tag = ciphertext_with_tag.split_at(-self.tag_size)
return ciphertext, tag

def decrypt(
Expand Down
4 changes: 3 additions & 1 deletion jwskate/jwa/signature/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def sign(self, data: bytes | SupportsBytes) -> BinaPy:
with self.private_key_required() as key:
dss_sig = key.sign(data, ec.ECDSA(self.hashing_alg))
r, s = asymmetric.utils.decode_dss_signature(dss_sig)
return BinaPy.from_int(r, self.curve.coordinate_size) + BinaPy.from_int(s, self.curve.coordinate_size)
return BinaPy.from_int(r, length=self.curve.coordinate_size) + BinaPy.from_int(
s, length=self.curve.coordinate_size
)

@override
def verify(self, data: bytes | SupportsBytes, signature: bytes | SupportsBytes) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions jwskate/jwe/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ def enc(self) -> str:
def encrypt(
cls,
plaintext: bytes | SupportsBytes,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
enc: str,
alg: str | None = None,
extra_headers: dict[str, Any] | None = None,
extra_headers: Mapping[str, Any] | None = None,
cek: bytes | None = None,
iv: bytes | None = None,
epk: Jwk | None = None,
Expand Down Expand Up @@ -188,7 +188,7 @@ def encrypt(

def unwrap_cek(
self,
key_or_password: Jwk | dict[str, Any] | bytes | str,
key_or_password: Jwk | Mapping[str, Any] | bytes | str,
alg: str | None = None,
algs: Iterable[str] | None = None,
) -> Jwk:
Expand Down Expand Up @@ -220,7 +220,7 @@ def unwrap_cek(

def decrypt(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand Down Expand Up @@ -249,7 +249,7 @@ def decrypt(

def decrypt_jwt(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand Down
24 changes: 16 additions & 8 deletions jwskate/jwk/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import warnings
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Mapping, SupportsBytes

Expand Down Expand Up @@ -154,7 +155,7 @@ def generate_for_kty(cls, kty: str, **kwargs: Any) -> Jwk:
"shake256": "shake256",
}

def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
def __new__(cls, key: Jwk | Mapping[str, Any] | Any, **kwargs: Any) -> Jwk:
"""Overridden `__new__` to make the Jwk constructor smarter.
The `Jwk` constructor will accept:
Expand All @@ -171,7 +172,7 @@ def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
if cls == Jwk:
if isinstance(key, Jwk):
return cls.from_cryptography_key(key.cryptography_key, **kwargs)
if isinstance(key, dict):
if isinstance(key, Mapping):
kty: str | None = key.get("kty")
if kty is None:
msg = "A Json Web Key must have a Key Type (kty)"
Expand All @@ -188,9 +189,9 @@ def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
return cls.from_json(key)
else:
return cls.from_cryptography_key(key, **kwargs)
return super().__new__(cls, key, **kwargs)
return super().__new__(cls)

def __init__(self, params: dict[str, Any] | Any, *, include_kid_thumbprint: bool = False):
def __init__(self, params: Mapping[str, Any] | Any, *, include_kid_thumbprint: bool = False):
if isinstance(params, dict): # this is to avoid double init due to the __new__ above
super().__init__({key: val for key, val in params.items() if val is not None})
self._validate()
Expand Down Expand Up @@ -275,7 +276,8 @@ def __setitem__(self, key: str, value: Any) -> None:
RuntimeError: when trying to modify cryptographic attributes
"""
if key in self.PARAMS:
# don't allow modifying private attributes after the key has been initialized
if key in self.PARAMS and hasattr(self, "cryptography_key"):
msg = "JWK key attributes cannot be modified."
raise RuntimeError(msg)
super().__setitem__(key, value)
Expand Down Expand Up @@ -305,12 +307,18 @@ def alg(self) -> str | None:
return alg

@property
def kid(self) -> str | None:
"""Return the JWK key ID (kid), if present."""
def kid(self) -> str:
"""Return the JWK key ID (kid).
If the kid is not explicitly set, the RFC7638 key thumbprint is returned.
"""
kid = self.get("kid")
if kid is not None and not isinstance(kid, str): # pragma: no branch
msg = f"invalid kid type {type(kid)}"
raise TypeError(msg, kid)
if kid is None:
return self.thumbprint()
return kid

@property
Expand Down Expand Up @@ -1220,7 +1228,7 @@ def copy(self) -> Jwk:
a copy of this key, with the same value
"""
return Jwk(super().copy())
return Jwk(copy(self.data))

def with_kid_thumbprint(self, *, force: bool = False) -> Jwk:
"""Include the JWK thumbprint as `kid`.
Expand Down
12 changes: 6 additions & 6 deletions jwskate/jwk/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def private(cls, *, crv: str, x: int, y: int, d: int, **params: Any) -> ECJwk:
dict(
kty=cls.KTY,
crv=crv,
x=BinaPy.from_int(x, coord_size).to("b64u").ascii(),
y=BinaPy.from_int(y, coord_size).to("b64u").ascii(),
d=BinaPy.from_int(d, coord_size).to("b64u").ascii(),
x=BinaPy.from_int(x, length=coord_size).to("b64u").ascii(),
y=BinaPy.from_int(y, length=coord_size).to("b64u").ascii(),
d=BinaPy.from_int(d, length=coord_size).to("b64u").ascii(),
**{k: v for k, v in params.items() if v is not None},
)
)
Expand Down Expand Up @@ -218,12 +218,12 @@ def from_cryptography_key(cls, cryptography_key: Any, **kwargs: Any) -> ECJwk:
msg = f"Unsupported Curve {cryptography_key.curve.name}"
raise NotImplementedError(msg)

x = BinaPy.from_int(public_numbers.x, crv.coordinate_size).to("b64u").ascii()
y = BinaPy.from_int(public_numbers.y, crv.coordinate_size).to("b64u").ascii()
x = BinaPy.from_int(public_numbers.x, length=crv.coordinate_size).to("b64u").ascii()
y = BinaPy.from_int(public_numbers.y, length=crv.coordinate_size).to("b64u").ascii()
parameters = {"kty": KeyTypes.EC, "crv": crv.name, "x": x, "y": y}
if isinstance(cryptography_key, ec.EllipticCurvePrivateKey):
pn = cryptography_key.private_numbers() # type: ignore[attr-defined]
d = BinaPy.from_int(pn.private_value, crv.coordinate_size).to("b64u").ascii()
d = BinaPy.from_int(pn.private_value, length=crv.coordinate_size).to("b64u").ascii()
parameters["d"] = d

return cls(parameters)
Expand Down
64 changes: 28 additions & 36 deletions jwskate/jwk/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

from typing import Any, Iterable
from typing import Any, Iterable, Mapping

from typing_extensions import override

from jwskate.token import BaseJsonDict

Expand All @@ -17,8 +19,8 @@ class JwkSet(BaseJsonDict):
methods to get the keys, add or remove keys, and verify signatures
using keys from this set.
- a `dict` from the parsed JSON object representing this JwkSet (in paramter `jwks`)
- a list of `Jwk` (in parameter `keys`
- a `dict` from the parsed JSON object representing this JwkSet (in parameter `jwks`)
- a list of `Jwk` (in parameter `keys`)
- nothing, to initialize an empty JwkSet
Args:
Expand All @@ -29,21 +31,23 @@ class JwkSet(BaseJsonDict):

def __init__(
self,
jwks: dict[str, Any] | None = None,
keys: Iterable[Jwk | dict[str, Any]] | None = None,
jwks: Mapping[str, Any] | None = None,
keys: Iterable[Jwk | Mapping[str, Any]] | None = None,
):
if jwks is None and keys is None:
keys = []

if jwks is not None:
keys = jwks.pop("keys", [])
super().__init__(jwks) # init the dict with all the dict content that is not keys
super().__init__({k: v for k, v in jwks.items() if k != "keys"} if jwks else {})
if keys is None and jwks is not None and "keys" in jwks:
keys = jwks.get("keys")
if keys:
for key in keys:
self.add_jwk(key)

@override
def __setitem__(self, name: str, value: Any) -> None:
if name == "keys":
for key in value:
self.add_jwk(key)
else:
super().__init__()

if keys is not None:
for jwk in keys:
self.add_jwk(jwk)
super().__setitem__(name, value)

@property
def jwks(self) -> list[Jwk]:
Expand All @@ -53,7 +57,7 @@ def jwks(self) -> list[Jwk]:
a list of `Jwk`
"""
return self.get("keys", []) # type: ignore[no-any-return]
return self.get("keys", [])

def get_jwk_by_kid(self, kid: str) -> Jwk:
"""Return a Jwk from this JwkSet, based on its kid.
Expand Down Expand Up @@ -84,35 +88,23 @@ def __len__(self) -> int:

def add_jwk(
self,
key: Jwk | dict[str, Any] | Any,
kid: str | None = None,
use: str | None = None,
key: Jwk | Mapping[str, Any] | Any,
) -> str:
"""Add a Jwk in this JwkSet.
Args:
key: the Jwk to add (either a `Jwk` instance, or a dict containing the Jwk parameters)
kid: the kid to use, if `jwk` doesn't contain one
use: the defined use for the added Jwk
Returns:
the kid from the added Jwk (it may be generated if no kid is provided)
the key ID. It will be generated if missing from the given Jwk.
"""
key = to_jwk(key)

self.setdefault("keys", [])
key = to_jwk(key).with_kid_thumbprint()

kid = key.get("kid", kid)
if not kid:
kid = key.thumbprint()
key["kid"] = kid
use = key.get("use", use)
if use:
key["use"] = use
self.jwks.append(key)
self.data.setdefault("keys", [])
self.data["keys"].append(key)

return kid
return key.kid

def remove_jwk(self, kid: str) -> None:
"""Remove a Jwk from this JwkSet, based on a `kid`.
Expand Down Expand Up @@ -198,7 +190,7 @@ def verify(
jwk = self.get_jwk_by_kid(kid)
return jwk.verify(data, signature, alg=alg, algs=algs)

# otherwise, try all keys which support the given alg(s)
# otherwise, try all keys that support the given alg(s)
if algs is None:
if alg is not None:
algs = (alg,)
Expand Down
10 changes: 5 additions & 5 deletions jwskate/jws/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, SupportsBytes
from typing import TYPE_CHECKING, Any, Iterable, Mapping, SupportsBytes

from binapy import BinaPy
from typing_extensions import Self
Expand Down Expand Up @@ -62,9 +62,9 @@ def __init__(self, value: bytes | str, max_size: int = 16 * 1024):
def sign(
cls,
payload: bytes | SupportsBytes,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
alg: str | None = None,
extra_headers: dict[str, Any] | None = None,
extra_headers: Mapping[str, Any] | None = None,
) -> JwsCompact:
"""Sign a payload and returns the resulting JwsCompact.
Expand Down Expand Up @@ -132,7 +132,7 @@ def signed_part(self) -> bytes:

def verify_signature(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand All @@ -153,7 +153,7 @@ def verify_signature(

def verify(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand Down
8 changes: 4 additions & 4 deletions jwskate/jws/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def jws_signature(self) -> JwsSignature:
def sign(
cls,
payload: bytes,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
alg: str | None = None,
extra_protected_headers: Mapping[str, Any] | None = None,
header: Any | None = None,
Expand Down Expand Up @@ -115,7 +115,7 @@ def compact(self) -> JwsCompact:

def verify_signature(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand Down Expand Up @@ -218,7 +218,7 @@ def signatures(self) -> list[JwsSignature]:

def add_signature(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
alg: str | None = None,
extra_protected_headers: Mapping[str, Any] | None = None,
header: Mapping[str, Any] | None = None,
Expand Down Expand Up @@ -306,7 +306,7 @@ def flatten(

def verify_signature(
self,
key: Jwk | dict[str, Any] | Any,
key: Jwk | Mapping[str, Any] | Any,
*,
alg: str | None = None,
algs: Iterable[str] | None = None,
Expand Down
Loading

0 comments on commit f1111d6

Please sign in to comment.