Skip to content

Commit

Permalink
DFAST-121 Implement suggested factory
Browse files Browse the repository at this point in the history
  • Loading branch information
MRVermeulenDeltares committed Feb 28, 2024
1 parent c713b04 commit 71257da
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 66 deletions.
47 changes: 20 additions & 27 deletions dfastmi/batch/FileNameRetriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
This file is part of D-FAST Morphological Impact: https://github.com/Deltares/D-FAST_Morphological_Impact
"""

from typing import Optional, Union, Dict, Any, Tuple
from typing import Callable, Optional, Union, Dict, Any, Tuple
import configparser
from abc import ABC, abstractmethod
from packaging import version

class AFileNameRetriever(ABC):
"""
Expand Down Expand Up @@ -83,33 +82,27 @@ def _cfg_get(self, config: configparser.ConfigParser, chap: str, key: str) -> st
except:
raise Exception(f'Keyword "{key}" is not specified in group "{chap}" of analysis configuration file.')

def get_filename_retriever(imode : int, config : configparser.ConfigParser, need_tides : bool) -> AFileNameRetriever:
class FileNameRetrieverFactory:
"""
Retrieves the expected file name retriever based on the given values.
Arguments
---------
imode : int
Specification of run mode (0 = WAQUA, 1 = D-Flow FM).
config : Optional[configparser.ConfigParser]
The variable containing the configuration (may be None for imode = 0).
needs_tide : bool
Specifies whether the tidal boundary is needed.
Returns
-------
filenames : AFileNameRetriever
returns the AFileNameRetriever which should be used.
Class is used to register and get creation of AFileNameRetriever Objects
"""
if imode == 0 or config is None:
return FileNameRetrieverUnsupported()

if version.parse(config["General"]["Version"]) == version.parse("1"):
return FileNameRetrieverLegacy()

return FileNameRetriever(need_tides)
_creators = {}
"""Contains the AFileNameRetriever Objects creators to be used"""

def __init__(self):
self._creators = {}

def register_creator(self, version: str, creator: Callable[[bool], AFileNameRetriever]):
"""Register creator function to create a AFileNameRetriever object."""
self._creators[version] = creator

def generate(self, version: str, needs_tide: bool) -> AFileNameRetriever:
"""Call the Constructor function to generate AFileNameRetriever object."""
constructor = self._creators.get(version)
if constructor:
return constructor(needs_tide)
else:
return FileNameRetrieverUnsupported()

class FileNameRetrieverUnsupported(AFileNameRetriever):
"""
Expand Down
16 changes: 15 additions & 1 deletion dfastmi/batch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ def batch_get_times(Q: Vector, q_fit: Tuple[float, float], q_stagnant: float, q_

return T, time_mi

def _initialize_file_name_retriever_factory() -> FileNameRetriever.FileNameRetrieverFactory:
factory = FileNameRetriever.FileNameRetrieverFactory()
factory.register_creator("1.0", lambda needs_tide: FileNameRetriever.FileNameRetrieverLegacy())
factory.register_creator("2.0", lambda needs_tide: FileNameRetriever.FileNameRetriever(needs_tide))
return factory

def get_filenames(
imode: int,
needs_tide: bool,
Expand Down Expand Up @@ -512,7 +518,15 @@ def get_filenames(
can be the discharge index, discharge value or a tuple of forcing
conditions, such as a Discharge and Tide forcing tuple.
"""
file_name_retriever = FileNameRetriever.get_filename_retriever(imode, config, needs_tide)

if imode != 0:
general_version = config.get("General", "Version", fallback= None)
else:
general_version = None

factory = _initialize_file_name_retriever_factory()
file_name_retriever = factory.generate(general_version, needs_tide)

return file_name_retriever.get_file_names(config)

def analyse_and_report(
Expand Down
77 changes: 39 additions & 38 deletions tests/batch/test_FileNameRetriever.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,53 @@
import pytest
import dfastmi.batch.FileNameRetriever
from dfastmi.batch import FileNameRetriever
from configparser import ConfigParser

class Test_FileNameRetriever():
@pytest.mark.parametrize("imode, config", [
(0, ConfigParser()),
(1, None),
])
def given_values_for_unsupported_retriever_when_get_filename_retriever_then_return_expected_retriever_unsupported(self, imode, config):
file_name_retriever = dfastmi.batch.FileNameRetriever.get_filename_retriever(imode, config, False)
class Test_FileNameRetrieverFactory():
@pytest.fixture
def factory(self) -> FileNameRetriever.FileNameRetrieverFactory:
factory = FileNameRetriever.FileNameRetrieverFactory()
factory.register_creator("1.0", lambda needs_tide: FileNameRetriever.FileNameRetrieverLegacy())
factory.register_creator("2.0", lambda needs_tide: FileNameRetriever.FileNameRetriever(needs_tide))
return factory

assert isinstance(file_name_retriever, dfastmi.batch.FileNameRetriever.FileNameRetrieverUnsupported)
def given_version_1_with_varying_needs_tide_when_generate_then_return_FileNameRetrieverLegacy(self, factory : FileNameRetriever.FileNameRetrieverFactory):
version = "1.0"
needs_tide = True
file_name_retriever = factory.generate(version, needs_tide)
assert isinstance(file_name_retriever, FileNameRetriever.FileNameRetrieverLegacy)

def given_values_for_legacy_retriever_when_get_filename_retriever_then_return_expected_retriever_legacy(self):
imode = 1
config = ConfigParser()
config.add_section("General")
config.set("General", "Version", "1")

file_name_retriever = dfastmi.batch.FileNameRetriever.get_filename_retriever(imode, config, False)

assert isinstance(file_name_retriever, dfastmi.batch.FileNameRetriever.FileNameRetrieverLegacy)

@pytest.mark.parametrize("use_tide", [
True,
False
])
def given_values_for_retriever_when_get_filename_retriever_then_return_expected_retriever(self, use_tide):
imode = 1
config = ConfigParser()
config.add_section("General")
config.set("General", "Version", "2")
file_name_retriever = dfastmi.batch.FileNameRetriever.get_filename_retriever(imode, config, use_tide)
@pytest.mark.parametrize("needs_tide", [
True,
False
])
def given_version_2_with_varying_needs_tide_when_generate_then_return_FileNameRetriever(self, factory : FileNameRetriever.FileNameRetrieverFactory, needs_tide : bool):
version = "2.0"
file_name_retriever = factory.generate(version, needs_tide)
assert isinstance(file_name_retriever, FileNameRetriever.FileNameRetriever)

assert isinstance(file_name_retriever, dfastmi.batch.FileNameRetriever.FileNameRetriever)
@pytest.mark.parametrize("version", [
"0.0",
"999.0",
None
])
def given_unsupported_version_when_generate_then_return_FileNameRetrieverUnsupported(self, factory : FileNameRetriever.FileNameRetrieverFactory, version):
needs_tide = True
file_name_retriever = factory.generate(version, needs_tide)
assert isinstance(file_name_retriever, FileNameRetriever.FileNameRetrieverUnsupported)

class Test_FileNameRetriever_Unsupported():
def given_config_parser_when_get_file_names_unsupported_then_return_no_file_names(self):
config = ConfigParser()
fnrvu = dfastmi.batch.FileNameRetriever.FileNameRetrieverUnsupported()
fnrvu = FileNameRetriever.FileNameRetrieverUnsupported()

filenames = fnrvu.get_file_names(config)

assert len(filenames) == 0

class Test_FileNameRetriever_legacy():
def given_partial_setup_config_parser_when_get_file_names_legacy_then_throw_exception_with_expected_message(self):
fnr_legacy = dfastmi.batch.FileNameRetriever.FileNameRetrieverLegacy()
fnr_legacy = FileNameRetriever.FileNameRetrieverLegacy()

key = "Reference"
chap = "Q1"
Expand All @@ -61,15 +62,15 @@ def given_partial_setup_config_parser_when_get_file_names_legacy_then_throw_exce

def given_empty_config_parser_when_get_file_names_legacy_then_return_no_file_names(self):
config = ConfigParser()
fnr_legacy = dfastmi.batch.FileNameRetriever.FileNameRetrieverLegacy()
fnr_legacy = FileNameRetriever.FileNameRetrieverLegacy()

filenames = fnr_legacy.get_file_names(config)

assert len(filenames) == 0

def given_config_parser_with_data_when_get_file_names_legacy_then_return_expected_file_names(self):
config = ConfigParser()
fnr_legacy = dfastmi.batch.FileNameRetriever.FileNameRetrieverLegacy()
fnr_legacy = FileNameRetriever.FileNameRetrieverLegacy()

q1_expected_filename= self.get_expected_q_filename_and_update_config(config, "Q1")
q2_expected_filename= self.get_expected_q_filename_and_update_config(config, "Q2")
Expand Down Expand Up @@ -99,7 +100,7 @@ class Test_FileNameRetriever():
False
])
def given_partial_setup_config_parser_when_get_file_names_then_throw_exception_with_expected_message(self, use_tide):
file_name_retriever = dfastmi.batch.FileNameRetriever.FileNameRetriever(use_tide)
file_name_retriever = FileNameRetriever.FileNameRetriever(use_tide)

key = "Discharge"
chap = "C1"
Expand All @@ -119,7 +120,7 @@ def given_partial_setup_config_parser_when_get_file_names_then_throw_exception_w
"--=+"
])
def given_config_parser_with_not_a_float_for_discharge_when_get_file_names_then_throw_type_error_with_expected_message(self, not_a_float_string):
file_name_retriever = dfastmi.batch.FileNameRetriever.FileNameRetriever(False)
file_name_retriever = FileNameRetriever.FileNameRetriever(False)

key = "Discharge"
chap = "C1"
Expand All @@ -139,15 +140,15 @@ def given_config_parser_with_not_a_float_for_discharge_when_get_file_names_then_
])
def given_empty_config_parser_when_get_file_names_then_return_no_file_names(self, use_tide):
config = ConfigParser()
file_name_retriever = dfastmi.batch.FileNameRetriever.FileNameRetriever(use_tide)
file_name_retriever = FileNameRetriever.FileNameRetriever(use_tide)

filenames = file_name_retriever.get_file_names(config)

assert len(filenames) == 0

def given_config_parser_with_data_when_get_file_names_then_return_expected_file_names(self):
config = ConfigParser()
file_name_retriever = dfastmi.batch.FileNameRetriever.FileNameRetriever(False)
file_name_retriever = FileNameRetriever.FileNameRetriever(False)

q1_expected_filename= self.get_expected_q_filename_and_update_config(config, "C1", "1.0")
q2_expected_filename= self.get_expected_q_filename_and_update_config(config, "C2", "2.0")
Expand All @@ -161,7 +162,7 @@ def given_config_parser_with_data_when_get_file_names_then_return_expected_file_

def given_config_parser_with_data_and_need_of_tide_enable_when_get_file_names_then_return_expected_file_names_including_tide(self):
config = ConfigParser()
file_name_retriever = dfastmi.batch.FileNameRetriever.FileNameRetriever(True)
file_name_retriever = FileNameRetriever.FileNameRetriever(True)

q1_expected_filename= self.get_expected_q_filename_and_update_config(config, "C1", "1.0", True)
q2_expected_filename= self.get_expected_q_filename_and_update_config(config, "C2", "2.0", True)
Expand Down

0 comments on commit 71257da

Please sign in to comment.