Skip to content

Commit

Permalink
Support query parameter addressing
Browse files Browse the repository at this point in the history
  • Loading branch information
kislyuk committed Apr 12, 2022
1 parent c3acaa4 commit cccaa66
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
34 changes: 20 additions & 14 deletions http_message_signatures/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import urllib

import http_sfv

from .exceptions import HTTPMessageSignaturesException
from .structures import CaseInsensitiveDict

Expand All @@ -25,18 +27,17 @@ def __init__(self, message):
if hasattr(message, "status_code"):
self.message_type = "response"
self.url = message.url
# TODO: check header key and value transforms are applied per 2.1
self.headers = CaseInsensitiveDict(message.headers)

def resolve(self, component_id):
if component_id.startswith("@"): # derived component
if component_id not in self.derived_component_names:
raise HTTPMessageSignaturesException(f'Unknown covered derived component name "{component_id}"')
resolver = getattr(self, "get_" + component_id[1:].replace("-", "_"))
return resolver()
if component_id not in self.headers:
def resolve(self, component_id: http_sfv.Item):
if component_id.value.startswith("@"): # derived component
if component_id.value not in self.derived_component_names:
raise HTTPMessageSignaturesException(f'Unknown covered derived component name {component_id.value}')
resolver = getattr(self, "get_" + component_id.value[1:].replace("-", "_"))
return resolver(**component_id.params)
if component_id.value not in self.headers:
raise HTTPMessageSignaturesException(f'Covered header field "{component_id}" not found in the message')
return self.headers[component_id]
return self.headers[component_id.value]

def get_method(self):
if self.message_type == "response":
Expand All @@ -61,17 +62,22 @@ def get_path(self):
def get_query(self):
return "?" + urllib.parse.urlsplit(self.url).query

def get_query_params(self):
# need to parse component id as a structured field
# urllib.parse.parse_qs(urllib.parse.urlsplit(request.url).query, keep_blank_values=True)
raise NotImplementedError()
def get_query_params(self, *, name: str):
query = urllib.parse.parse_qs(urllib.parse.urlsplit(self.url).query, keep_blank_values=True)
if name not in query:
raise HTTPMessageSignaturesException(f'Query parameter "{name}" not found in the message URL')
if len(query[name]) != 1:
raise HTTPMessageSignaturesException('Query parameters with multiple values are not supported.')
return query[name][0]

def get_status(self):
if self.message_type != "response":
raise HTTPMessageSignaturesException('Unexpected "@status" component in a request signature')
return str(self.message.status_code)

def get_request_response(self):
def get_request_response(self, *, key: str):
# See 2.2.11 Request-Response Signature Binding
# self.message.request.headers["Signature"][key]
raise NotImplementedError()


Expand Down
36 changes: 25 additions & 11 deletions http_message_signatures/signatures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import datetime
import logging
from typing import List, Dict

import http_sfv
Expand All @@ -8,6 +9,8 @@
from .algorithms import HTTPSignatureAlgorithm, signature_algorithms
from .exceptions import HTTPMessageSignaturesException, InvalidSignature

logger = logging.getLogger(__name__)


class HTTPSignatureHandler:
signature_metadata_parameters = {
Expand All @@ -29,19 +32,17 @@ def __init__(self, *,
self.component_resolver_class = component_resolver_class

def build_signature_base(self, message, *,
covered_component_ids: List[str],
covered_component_ids: List[http_sfv.Item],
signature_params: Dict[str, str]):
assert "@signature-params" not in covered_component_ids
sig_elements = collections.OrderedDict()
component_resolver = self.component_resolver_class(message)
for component_id in covered_component_ids:
component_name_node = http_sfv.Item(component_id)
component_key = str(http_sfv.List([component_name_node]))
component_key = str(http_sfv.List([component_id]))
# TODO: model situations when header occurs multiple times
# TODO: 2.1.2 parameterized keys
component_value = component_resolver.resolve(component_id)
if component_key.lower() != component_key:
raise HTTPMessageSignaturesException(f'Component ID "{component_key}" is not all lowercase')
if component_id.value.lower() != component_id.value:
raise HTTPMessageSignaturesException(f'Component ID "{component_id.value}" is not all lowercase')
if "\n" in component_key:
raise HTTPMessageSignaturesException(f'Component ID "{component_key}" contains newline character')
if component_key in sig_elements:
Expand All @@ -58,6 +59,17 @@ def build_signature_base(self, message, *,
class HTTPMessageSigner(HTTPSignatureHandler):
DEFAULT_SIGNATURE_LABEL = "pyhms"

def parse_covered_component_ids(self, covered_component_ids):
covered_component_nodes = []
for component_id in covered_component_ids:
component_name_node = http_sfv.Item()
if component_id.startswith('"'):
component_name_node.parse(component_id.encode())
else:
component_name_node.value = component_id
covered_component_nodes.append(component_name_node)
return covered_component_nodes

def sign(self, message, *,
key_id: str,
created: datetime.datetime = None,
Expand All @@ -80,9 +92,12 @@ def sign(self, message, *,
signature_params["nonce"] = nonce
if include_alg:
signature_params["alg"] = self.signature_algorithm.algorithm_id
sig_base, sig_params_node, _ = self.build_signature_base(message,
covered_component_ids=covered_component_ids,
signature_params=signature_params)
covered_component_nodes = self.parse_covered_component_ids(covered_component_ids)
sig_base, sig_params_node, _ = self.build_signature_base(
message,
covered_component_ids=covered_component_nodes,
signature_params=signature_params
)
signer = self.signature_algorithm(private_key=key)
signature = signer.sign(sig_base.encode())
sig_label = self.DEFAULT_SIGNATURE_LABEL
Expand Down Expand Up @@ -125,14 +140,13 @@ def verify(self, message):
if sig_input.params["alg"] != self.signature_algorithm.algorithm_id:
raise InvalidSignature("Unexpected algorithm specified in the signature")
key = self.key_resolver.resolve_public_key(sig_input.params["keyid"])
covered_component_ids = [i.value for i in sig_input]
for param in sig_input.params:
if param not in self.signature_metadata_parameters:
raise InvalidSignature(f'Unexpected signature metadata parameter "{param}"')
try:
sig_base, sig_params_node, sig_elements = self.build_signature_base(
message,
covered_component_ids=covered_component_ids,
covered_component_ids=list(sig_input),
signature_params=sig_input.params
)
except Exception as e:
Expand Down
14 changes: 14 additions & 0 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ def test_http_message_signatures_B26(self):
with self.assertRaises(InvalidSignature):
verifier.verify(self.test_request)

def test_query_parameters(self):
signer = HTTPMessageSigner(signature_algorithm=HMAC_SHA256, key_resolver=self.key_resolver)
signer.sign(self.test_request,
key_id="test-shared-secret",
covered_component_ids=("date", "@authority", "content-type", '"@query-params";name="Pet"'),
created=datetime.fromtimestamp(1618884473))
self.assertEqual(self.test_request.headers["Signature-Input"],
('pyhms=("date" "@authority" "content-type" "@query-params";name="Pet");'
'created=1618884473;keyid="test-shared-secret";alg="hmac-sha256"'))
self.assertEqual(self.test_request.headers["Signature"],
'pyhms=:LOYhEJpBn34v3KohQBFl5qSy93haFd3+Ka9wwOmKeN0=:')
verifier = HTTPMessageVerifier(signature_algorithm=HMAC_SHA256, key_resolver=self.key_resolver)
verifier.verify(self.test_request)


if __name__ == '__main__':
unittest.main()

0 comments on commit cccaa66

Please sign in to comment.