diff --git a/templates/home.html b/templates/home.html index c4d134fb..7d3e5238 100644 --- a/templates/home.html +++ b/templates/home.html @@ -10,14 +10,11 @@

Service Provider configurati {% for item in sp_list %} - - {{ item['entityID'] }} - - loaded from {{ item['location'] }} - + + {{ item['entityID'] }} {% endfor %} -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/testenv/server.py b/testenv/server.py index 7da6f626..b2b65a59 100644 --- a/testenv/server.py +++ b/testenv/server.py @@ -284,7 +284,7 @@ def _handle_http_post(self, action): def _get_certificates_by_issuer(self, issuer): try: - return self._registry.load(issuer).certs() + return self._registry.get(issuer).certs() except KeyError: self._raise_error( 'entity ID {} non registrato, impossibile ricavare' @@ -357,7 +357,7 @@ def users(self): 'primary_attributes': spid_main_fields, 'secondary_attributes': spid_secondary_fields, 'users': self.user_manager.all(), - 'sp_list': self._registry.load_all().keys(), + 'sp_list': self._registry.all(), 'can_add_user': can_add_user } ) @@ -393,9 +393,8 @@ def index(self): **{ 'sp_list': [ { - "entityID": entity_id, - "location": sp_metadata.location, - } for (entity_id, sp_metadata) in self._registry.load_all().items() + "entityID": sp + } for sp in self._registry.all() ], } ) @@ -406,7 +405,7 @@ def get_destination(self, req, sp_id): acs_index = getattr(req, 'assertion_consumer_service_index', None) protocol_binding = getattr(req, 'protocol_binding', None) if acs_index is not None: - acss = self._registry.load( + acss = self._registry.get( sp_id).assertion_consumer_service(index=acs_index) if acss: destination = acss[0].get('Location') @@ -591,7 +590,7 @@ def login(self): atcs_idx ) ) - sp_metadata = self._registry.load(sp_id) + sp_metadata = self._registry.get(sp_id) required = [] optional = [] if atcs_idx and sp_metadata: @@ -792,7 +791,7 @@ def continue_response(self): def _sp_single_logout_service(self, issuer_name): _slo = None try: - _slo = self._registry.load(issuer_name).single_logout_services[0] + _slo = self._registry.get(issuer_name).single_logout_services[0] except Exception: pass return _slo diff --git a/testenv/spmetadata.py b/testenv/spmetadata.py index 9b32c44e..950bc016 100644 --- a/testenv/spmetadata.py +++ b/testenv/spmetadata.py @@ -2,7 +2,6 @@ from itertools import chain import requests -from lxml.etree import LxmlError from testenv import config, log from testenv.exceptions import DeserializationError, MetadataLoadError, MetadataNotFoundError, ValidationError @@ -32,117 +31,37 @@ class ServiceProviderMetadataRegistry: def __init__(self): + self._loaders = [] + for source_type, source_params in list(config.params.metadata.items()): + self._loaders.append({ + 'local': ServiceProviderMetadataFileLoader, + 'remote': ServiceProviderMetadataHTTPLoader, + 'db': ServiceProviderMetadataDbLoader, + }[source_type](source_params)) self._validators = ValidatorGroup([ XMLMetadataFormatValidator(), ServiceProviderMetadataXMLSchemaValidator(), SpidMetadataValidator(), ]) - self._index_metadata() - - def load(self, entity_id): - """ - Loads the metadata of a Service Provider. - - Args: - entity_id (str): Entity id of the SP (usually a URL or a URN). - Returns: - A ServiceProviderMetadata instance. - - Raises - MetadataNotFoundError: If there is no metadata associated to - the entity id. - DeserializationError: If the metadata associated to the entity id - is not valid. - """ + def get(self, entity_id): entity_id = entity_id.strip() - - fresh_metadata = None - - metadata = self._metadata.get(entity_id, None) - if not metadata: - # Try to reload all sources to see if the unknown entity id was added there - # somewhere. - logger.debug( - "Unknown entityId '{}`, reloading all the sources.".format(entity_id) - ) - self._index_metadata() - else: - # We got an known entity id, try to load its metadata the previously known - # location. + for loader in self._loaders: try: - fresh_metadata = metadata.loader.load(metadata.location) - if fresh_metadata.entity_id != entity_id: - raise MetadataLoadError - except MetadataLoadError as e: - logger.debug( - ("{}\n" - "Cannot find entityId '{}` at its previous location '{}`" - "reloading all the sources").format(e, entity_id, metadata.location) - ) - self._index_metadata() - - if not fresh_metadata: - try: - metadata = self._metadata[entity_id] - fresh_metadata = metadata.loader.load(metadata.location) - except (KeyError, MetadataLoadError): - raise MetadataNotFoundError(entity_id) - - if metadata.entity_id != entity_id: - raise MetadataNotFoundError(entity_id) - try: - self._validators.validate(fresh_metadata.xml) - except ValidationError as e: - raise DeserializationError(fresh_metadata.xml, e.details) - - return fresh_metadata - - def load_all(self): - """ - Returns a dict containing all ServerProviderMetadata loaded, - indexed by entityId. - """ - self._index_metadata() - - return self._metadata - - def _index_metadata(self): - """ - Populate self._metadata with the up to date information from all the - configured SP metadata. - """ - - # dict of { entity_id: ServiceProviderMetadata } - self._metadata = {} - - # Possible sources of metadata, ordered by preference - # (ie. the first source will be preferred in case of duplicate - # entity ids). - SOURCE_TYPES = ['local', 'db', 'remote'] - - for source_type in reversed(SOURCE_TYPES): - if source_type not in config.params.metadata: + metadata = loader.get(entity_id) + try: + self._validators.validate(metadata.xml) + return metadata + except ValidationError as e: + raise DeserializationError(metadata.xml, e.details) + except MetadataNotFoundError: continue - source_params = config.params.metadata[source_type] - - loader = { - 'local': ServiceProviderMetadataFileLoader, - 'remote': ServiceProviderMetadataHTTPLoader, - 'db': ServiceProviderMetadataDbLoader, - }[source_type](source_params) - - metadata = loader.load_all() - for dup in set(metadata.keys()).intersection(set(self._metadata)): - logger.info( - "Discarding duplicate entity_id `{}' from '{}`.".format( - dup, - self._metadata[dup].location - ) - ) + raise MetadataNotFoundError(entity_id) - self._metadata.update(metadata) + def all(self): + """Returns the list of entityIDs of all the known Service Providers""" + return [i for loader in self._loaders for i in loader.all()] registry = None @@ -153,126 +72,62 @@ def build_metadata_registry(): registry = ServiceProviderMetadataRegistry() -class LoadAllMixin(object): - def load_all(self): - """ - Loads all the available SP metadata, skipping duplicates. - - Returns: - A dict containing all local ServerProviderMetadata loaded, - indexed by entityId. - """ - metadata = None - ret = {} - - for location in self._locations: - try: - metadata = self.load(location) - except MetadataLoadError as e: - logger.info( - "Skipping '{}` because of a load error: {}".format(location, e) - ) - continue - - if metadata.entity_id in ret: - logger.info( - "Discarding duplicate entity_id `{}' from '{}`.".format( - metadata.entity_id, - metadata.location - ) - ) - continue - - ret[metadata.entity_id] = metadata - - return ret - - -class ServiceProviderMetadataFileLoader(LoadAllMixin, object): - """ - Loads SP metadata from a list of files. +class ServiceProviderMetadataFileLoader: + """Loads metadata from the configured files - Args: - locations (list of str): List of paths to load. Paths can also contain - globbing metacharacters. + This could be improved automatically reloading the metadata when + file timestamps change """ - def __init__(self, locations): - files = [glob(entry) for entry in locations] - - self._locations = list(chain.from_iterable(files)) - - def load(self, location): - """ - Loads the SP metadata from file. - - Args: - location (str): The path of file. + def __init__(self, conf): + self._metadata = {} - Returns: - A ServiceProviderMetadata instance. + files = [glob(entry) for entry in conf] + for file in list(chain.from_iterable(files)): + try: + with open(file, 'rb') as fp: + metadata = ServiceProviderMetadata(fp.read()) + self._metadata[metadata.entity_id] = metadata + logger.debug("Loaded metadata for: " + metadata.entity_id) + except Exception as e: + raise MetadataLoadError( + "Impossibile leggere il file '{}': '{}'".format(file, e) + ) - Raises: - MetadataLoadError: If the load fails. - """ + def get(self, entity_id): try: - with open(location, 'rb') as fp: - metadata = ServiceProviderMetadata(fp.read(), self, location) - except (IOError, LxmlError) as e: - raise MetadataLoadError( - "Failed to load '{}': '{}'".format(location, e) - ) - logger.debug( - "Loaded metadata for '{}` from '{}`".format( - metadata.entity_id, - location - ) - ) - return metadata - - -class ServiceProviderMetadataHTTPLoader(LoadAllMixin, object): - """ - Loads SP metadata from a list of HTTP URLs. - - Args: - urls (list of str): List of HTTP URLs to load. - """ - - def __init__(self, locations): - self._locations = locations + return self._metadata[entity_id] + except KeyError: + raise MetadataNotFoundError(entity_id) - def load(self, location): - """ - Loads the SP metadata from HTTP. + def all(self): + return list(self._metadata.keys()) - Args: - location (str): The URL of the metadata to load. - Returns: - A ServiceProviderMetadata instance. +class ServiceProviderMetadataHTTPLoader: + """Loads metadata from the configured URLs""" - Raises: - MetadataLoadError: If the load fails. - """ + def __init__(self, conf): + self._metadata = {} + for url in conf: + try: + response = requests.get(url) + response.raise_for_status() + metadata = ServiceProviderMetadata(response.content) + self._metadata[metadata.entity_id] = metadata + except Exception as e: + raise MetadataLoadError( + "La richiesta all'endpoint HTTP '{}': '{}'".format(url, e) + ) + def get(self, entity_id): try: - response = requests.get(location) - response.raise_for_status() - metadata = ServiceProviderMetadata(response.content, self, location) - except Exception as e: - raise MetadataLoadError( - "Request to HTTP endpoint '{}': '{}'".format(location, e) - ) - - logger.debug( - "Loaded metadata for '{}` from '{}`".format( - metadata.entity_id, - location - ) - ) + return self._metadata[entity_id] + except KeyError: + raise MetadataNotFoundError(entity_id) - return metadata + def all(self): + return list(self._metadata.keys()) class ServiceProviderMetadataDbLoader: @@ -281,38 +136,20 @@ class ServiceProviderMetadataDbLoader: def __init__(self, conf): self._provider = DatabaseSPProvider(conf) - def load(self, entity_id): + def get(self, entity_id): metadata = self._provider.get(entity_id) if not metadata: raise MetadataNotFoundError(entity_id) - return ServiceProviderMetadata(metadata, self, 'db') - - def load_all(self): - """ - Returns a dict containing all 'db' ServerProviderMetadata loaded, - indexed by entityId.""" - return { - entity_id: ServiceProviderMetadata(xml, self, 'db') - for (entity_id, xml) in self._provider.all().items() - } + return ServiceProviderMetadata(metadata) + def all(self): + return list(self._provider.all().keys()) -class ServiceProviderMetadata(object): - """ - Object representing the metadata of a Service Provider. - - Args: - xml (str): The metadata as XML. - loader (instance of ServiceProviderMetadata{File,HTTP,Db}Loader): The loader the - metadata was loaded with. - location (str): The source the metadata was loaded from. - It's a path for 'local' metadata, a URL for 'remote' and - the string 'db' for 'db'. - """ - def __init__(self, xml, loader, location): + +class ServiceProviderMetadata: + + def __init__(self, xml): self.xml = xml - self.loader = loader - self.location = location self._metadata = saml_to_dict(xml) @property diff --git a/testenv/tests/test_spid_testenv.py b/testenv/tests/test_spid_testenv.py index 11a6d0fb..d4e3e1c3 100644 --- a/testenv/tests/test_spid_testenv.py +++ b/testenv/tests/test_spid_testenv.py @@ -29,7 +29,7 @@ def _sp_single_logout_service(server, issuer_name, binding): - _slo = server._registry.load(issuer_name).single_logout_service( + _slo = server._registry.get(issuer_name).single_logout_service( binding=binding ) return _slo[0] diff --git a/testenv/tests/test_validators.py b/testenv/tests/test_validators.py index 1ecf9ecf..543a5cbc 100644 --- a/testenv/tests/test_validators.py +++ b/testenv/tests/test_validators.py @@ -68,7 +68,7 @@ class FakeRegistry: def __init__(self, metadata): self._metadata = metadata.copy() - def load(self, entity_id): + def get(self, entity_id): return self._metadata.get(entity_id) @property diff --git a/testenv/validators.py b/testenv/validators.py index 5cfd2362..c24a58f8 100644 --- a/testenv/validators.py +++ b/testenv/validators.py @@ -499,7 +499,7 @@ def validate(self, request): 'Issuer non presente nella {}'.format(req_type) ) try: - sp_metadata = self._registry.load(issuer_name) + sp_metadata = self._registry.get(issuer_name) except MetadataNotFoundError: raise UnknownEntityIDError( 'L\'entity ID "{}" indicato nell\'elemento non corrisponde a nessun Service Provider registrato in questo Identity Provider di test.'.format(issuer_name)