diff --git a/templates/home.html b/templates/home.html
index c4d134fb..7d3e5238 100644
--- a/templates/home.html
+++ b/templates/home.html
@@ -10,14 +10,11 @@
Service Provider configurati
{% for item in sp_list %}
-
- {{ item['entityID'] }}
-
- loaded from {{ item['location'] }}
-
+ |
+ {{ item['entityID'] }}
|
{% endfor %}
-{% endblock %}
+{% endblock %}
\ No newline at end of file
diff --git a/testenv/server.py b/testenv/server.py
index 7da6f626..b2b65a59 100644
--- a/testenv/server.py
+++ b/testenv/server.py
@@ -284,7 +284,7 @@ def _handle_http_post(self, action):
def _get_certificates_by_issuer(self, issuer):
try:
- return self._registry.load(issuer).certs()
+ return self._registry.get(issuer).certs()
except KeyError:
self._raise_error(
'entity ID {} non registrato, impossibile ricavare'
@@ -357,7 +357,7 @@ def users(self):
'primary_attributes': spid_main_fields,
'secondary_attributes': spid_secondary_fields,
'users': self.user_manager.all(),
- 'sp_list': self._registry.load_all().keys(),
+ 'sp_list': self._registry.all(),
'can_add_user': can_add_user
}
)
@@ -393,9 +393,8 @@ def index(self):
**{
'sp_list': [
{
- "entityID": entity_id,
- "location": sp_metadata.location,
- } for (entity_id, sp_metadata) in self._registry.load_all().items()
+ "entityID": sp
+ } for sp in self._registry.all()
],
}
)
@@ -406,7 +405,7 @@ def get_destination(self, req, sp_id):
acs_index = getattr(req, 'assertion_consumer_service_index', None)
protocol_binding = getattr(req, 'protocol_binding', None)
if acs_index is not None:
- acss = self._registry.load(
+ acss = self._registry.get(
sp_id).assertion_consumer_service(index=acs_index)
if acss:
destination = acss[0].get('Location')
@@ -591,7 +590,7 @@ def login(self):
atcs_idx
)
)
- sp_metadata = self._registry.load(sp_id)
+ sp_metadata = self._registry.get(sp_id)
required = []
optional = []
if atcs_idx and sp_metadata:
@@ -792,7 +791,7 @@ def continue_response(self):
def _sp_single_logout_service(self, issuer_name):
_slo = None
try:
- _slo = self._registry.load(issuer_name).single_logout_services[0]
+ _slo = self._registry.get(issuer_name).single_logout_services[0]
except Exception:
pass
return _slo
diff --git a/testenv/spmetadata.py b/testenv/spmetadata.py
index 9b32c44e..950bc016 100644
--- a/testenv/spmetadata.py
+++ b/testenv/spmetadata.py
@@ -2,7 +2,6 @@
from itertools import chain
import requests
-from lxml.etree import LxmlError
from testenv import config, log
from testenv.exceptions import DeserializationError, MetadataLoadError, MetadataNotFoundError, ValidationError
@@ -32,117 +31,37 @@
class ServiceProviderMetadataRegistry:
def __init__(self):
+ self._loaders = []
+ for source_type, source_params in list(config.params.metadata.items()):
+ self._loaders.append({
+ 'local': ServiceProviderMetadataFileLoader,
+ 'remote': ServiceProviderMetadataHTTPLoader,
+ 'db': ServiceProviderMetadataDbLoader,
+ }[source_type](source_params))
self._validators = ValidatorGroup([
XMLMetadataFormatValidator(),
ServiceProviderMetadataXMLSchemaValidator(),
SpidMetadataValidator(),
])
- self._index_metadata()
-
- def load(self, entity_id):
- """
- Loads the metadata of a Service Provider.
-
- Args:
- entity_id (str): Entity id of the SP (usually a URL or a URN).
- Returns:
- A ServiceProviderMetadata instance.
-
- Raises
- MetadataNotFoundError: If there is no metadata associated to
- the entity id.
- DeserializationError: If the metadata associated to the entity id
- is not valid.
- """
+ def get(self, entity_id):
entity_id = entity_id.strip()
-
- fresh_metadata = None
-
- metadata = self._metadata.get(entity_id, None)
- if not metadata:
- # Try to reload all sources to see if the unknown entity id was added there
- # somewhere.
- logger.debug(
- "Unknown entityId '{}`, reloading all the sources.".format(entity_id)
- )
- self._index_metadata()
- else:
- # We got an known entity id, try to load its metadata the previously known
- # location.
+ for loader in self._loaders:
try:
- fresh_metadata = metadata.loader.load(metadata.location)
- if fresh_metadata.entity_id != entity_id:
- raise MetadataLoadError
- except MetadataLoadError as e:
- logger.debug(
- ("{}\n"
- "Cannot find entityId '{}` at its previous location '{}`"
- "reloading all the sources").format(e, entity_id, metadata.location)
- )
- self._index_metadata()
-
- if not fresh_metadata:
- try:
- metadata = self._metadata[entity_id]
- fresh_metadata = metadata.loader.load(metadata.location)
- except (KeyError, MetadataLoadError):
- raise MetadataNotFoundError(entity_id)
-
- if metadata.entity_id != entity_id:
- raise MetadataNotFoundError(entity_id)
- try:
- self._validators.validate(fresh_metadata.xml)
- except ValidationError as e:
- raise DeserializationError(fresh_metadata.xml, e.details)
-
- return fresh_metadata
-
- def load_all(self):
- """
- Returns a dict containing all ServerProviderMetadata loaded,
- indexed by entityId.
- """
- self._index_metadata()
-
- return self._metadata
-
- def _index_metadata(self):
- """
- Populate self._metadata with the up to date information from all the
- configured SP metadata.
- """
-
- # dict of { entity_id: ServiceProviderMetadata }
- self._metadata = {}
-
- # Possible sources of metadata, ordered by preference
- # (ie. the first source will be preferred in case of duplicate
- # entity ids).
- SOURCE_TYPES = ['local', 'db', 'remote']
-
- for source_type in reversed(SOURCE_TYPES):
- if source_type not in config.params.metadata:
+ metadata = loader.get(entity_id)
+ try:
+ self._validators.validate(metadata.xml)
+ return metadata
+ except ValidationError as e:
+ raise DeserializationError(metadata.xml, e.details)
+ except MetadataNotFoundError:
continue
- source_params = config.params.metadata[source_type]
-
- loader = {
- 'local': ServiceProviderMetadataFileLoader,
- 'remote': ServiceProviderMetadataHTTPLoader,
- 'db': ServiceProviderMetadataDbLoader,
- }[source_type](source_params)
-
- metadata = loader.load_all()
- for dup in set(metadata.keys()).intersection(set(self._metadata)):
- logger.info(
- "Discarding duplicate entity_id `{}' from '{}`.".format(
- dup,
- self._metadata[dup].location
- )
- )
+ raise MetadataNotFoundError(entity_id)
- self._metadata.update(metadata)
+ def all(self):
+ """Returns the list of entityIDs of all the known Service Providers"""
+ return [i for loader in self._loaders for i in loader.all()]
registry = None
@@ -153,126 +72,62 @@ def build_metadata_registry():
registry = ServiceProviderMetadataRegistry()
-class LoadAllMixin(object):
- def load_all(self):
- """
- Loads all the available SP metadata, skipping duplicates.
-
- Returns:
- A dict containing all local ServerProviderMetadata loaded,
- indexed by entityId.
- """
- metadata = None
- ret = {}
-
- for location in self._locations:
- try:
- metadata = self.load(location)
- except MetadataLoadError as e:
- logger.info(
- "Skipping '{}` because of a load error: {}".format(location, e)
- )
- continue
-
- if metadata.entity_id in ret:
- logger.info(
- "Discarding duplicate entity_id `{}' from '{}`.".format(
- metadata.entity_id,
- metadata.location
- )
- )
- continue
-
- ret[metadata.entity_id] = metadata
-
- return ret
-
-
-class ServiceProviderMetadataFileLoader(LoadAllMixin, object):
- """
- Loads SP metadata from a list of files.
+class ServiceProviderMetadataFileLoader:
+ """Loads metadata from the configured files
- Args:
- locations (list of str): List of paths to load. Paths can also contain
- globbing metacharacters.
+ This could be improved automatically reloading the metadata when
+ file timestamps change
"""
- def __init__(self, locations):
- files = [glob(entry) for entry in locations]
-
- self._locations = list(chain.from_iterable(files))
-
- def load(self, location):
- """
- Loads the SP metadata from file.
-
- Args:
- location (str): The path of file.
+ def __init__(self, conf):
+ self._metadata = {}
- Returns:
- A ServiceProviderMetadata instance.
+ files = [glob(entry) for entry in conf]
+ for file in list(chain.from_iterable(files)):
+ try:
+ with open(file, 'rb') as fp:
+ metadata = ServiceProviderMetadata(fp.read())
+ self._metadata[metadata.entity_id] = metadata
+ logger.debug("Loaded metadata for: " + metadata.entity_id)
+ except Exception as e:
+ raise MetadataLoadError(
+ "Impossibile leggere il file '{}': '{}'".format(file, e)
+ )
- Raises:
- MetadataLoadError: If the load fails.
- """
+ def get(self, entity_id):
try:
- with open(location, 'rb') as fp:
- metadata = ServiceProviderMetadata(fp.read(), self, location)
- except (IOError, LxmlError) as e:
- raise MetadataLoadError(
- "Failed to load '{}': '{}'".format(location, e)
- )
- logger.debug(
- "Loaded metadata for '{}` from '{}`".format(
- metadata.entity_id,
- location
- )
- )
- return metadata
-
-
-class ServiceProviderMetadataHTTPLoader(LoadAllMixin, object):
- """
- Loads SP metadata from a list of HTTP URLs.
-
- Args:
- urls (list of str): List of HTTP URLs to load.
- """
-
- def __init__(self, locations):
- self._locations = locations
+ return self._metadata[entity_id]
+ except KeyError:
+ raise MetadataNotFoundError(entity_id)
- def load(self, location):
- """
- Loads the SP metadata from HTTP.
+ def all(self):
+ return list(self._metadata.keys())
- Args:
- location (str): The URL of the metadata to load.
- Returns:
- A ServiceProviderMetadata instance.
+class ServiceProviderMetadataHTTPLoader:
+ """Loads metadata from the configured URLs"""
- Raises:
- MetadataLoadError: If the load fails.
- """
+ def __init__(self, conf):
+ self._metadata = {}
+ for url in conf:
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+ metadata = ServiceProviderMetadata(response.content)
+ self._metadata[metadata.entity_id] = metadata
+ except Exception as e:
+ raise MetadataLoadError(
+ "La richiesta all'endpoint HTTP '{}': '{}'".format(url, e)
+ )
+ def get(self, entity_id):
try:
- response = requests.get(location)
- response.raise_for_status()
- metadata = ServiceProviderMetadata(response.content, self, location)
- except Exception as e:
- raise MetadataLoadError(
- "Request to HTTP endpoint '{}': '{}'".format(location, e)
- )
-
- logger.debug(
- "Loaded metadata for '{}` from '{}`".format(
- metadata.entity_id,
- location
- )
- )
+ return self._metadata[entity_id]
+ except KeyError:
+ raise MetadataNotFoundError(entity_id)
- return metadata
+ def all(self):
+ return list(self._metadata.keys())
class ServiceProviderMetadataDbLoader:
@@ -281,38 +136,20 @@ class ServiceProviderMetadataDbLoader:
def __init__(self, conf):
self._provider = DatabaseSPProvider(conf)
- def load(self, entity_id):
+ def get(self, entity_id):
metadata = self._provider.get(entity_id)
if not metadata:
raise MetadataNotFoundError(entity_id)
- return ServiceProviderMetadata(metadata, self, 'db')
-
- def load_all(self):
- """
- Returns a dict containing all 'db' ServerProviderMetadata loaded,
- indexed by entityId."""
- return {
- entity_id: ServiceProviderMetadata(xml, self, 'db')
- for (entity_id, xml) in self._provider.all().items()
- }
+ return ServiceProviderMetadata(metadata)
+ def all(self):
+ return list(self._provider.all().keys())
-class ServiceProviderMetadata(object):
- """
- Object representing the metadata of a Service Provider.
-
- Args:
- xml (str): The metadata as XML.
- loader (instance of ServiceProviderMetadata{File,HTTP,Db}Loader): The loader the
- metadata was loaded with.
- location (str): The source the metadata was loaded from.
- It's a path for 'local' metadata, a URL for 'remote' and
- the string 'db' for 'db'.
- """
- def __init__(self, xml, loader, location):
+
+class ServiceProviderMetadata:
+
+ def __init__(self, xml):
self.xml = xml
- self.loader = loader
- self.location = location
self._metadata = saml_to_dict(xml)
@property
diff --git a/testenv/tests/test_spid_testenv.py b/testenv/tests/test_spid_testenv.py
index 11a6d0fb..d4e3e1c3 100644
--- a/testenv/tests/test_spid_testenv.py
+++ b/testenv/tests/test_spid_testenv.py
@@ -29,7 +29,7 @@
def _sp_single_logout_service(server, issuer_name, binding):
- _slo = server._registry.load(issuer_name).single_logout_service(
+ _slo = server._registry.get(issuer_name).single_logout_service(
binding=binding
)
return _slo[0]
diff --git a/testenv/tests/test_validators.py b/testenv/tests/test_validators.py
index 1ecf9ecf..543a5cbc 100644
--- a/testenv/tests/test_validators.py
+++ b/testenv/tests/test_validators.py
@@ -68,7 +68,7 @@ class FakeRegistry:
def __init__(self, metadata):
self._metadata = metadata.copy()
- def load(self, entity_id):
+ def get(self, entity_id):
return self._metadata.get(entity_id)
@property
diff --git a/testenv/validators.py b/testenv/validators.py
index 5cfd2362..c24a58f8 100644
--- a/testenv/validators.py
+++ b/testenv/validators.py
@@ -499,7 +499,7 @@ def validate(self, request):
'Issuer non presente nella {}'.format(req_type)
)
try:
- sp_metadata = self._registry.load(issuer_name)
+ sp_metadata = self._registry.get(issuer_name)
except MetadataNotFoundError:
raise UnknownEntityIDError(
'L\'entity ID "{}" indicato nell\'elemento non corrisponde a nessun Service Provider registrato in questo Identity Provider di test.'.format(issuer_name)