diff --git a/src/rapids_dependency_file_generator/cli.py b/src/rapids_dependency_file_generator/cli.py index cf58994..450397f 100644 --- a/src/rapids_dependency_file_generator/cli.py +++ b/src/rapids_dependency_file_generator/cli.py @@ -1,10 +1,14 @@ import argparse +import os import yaml from ._version import __version__ as version from .constants import OutputTypes, default_dependency_file_path -from .rapids_dependency_file_generator import make_dependency_files +from .rapids_dependency_file_generator import ( + delete_existing_files, + make_dependency_files, +) from .rapids_dependency_file_validator import validate_dependencies @@ -17,6 +21,17 @@ def validate_args(argv): default=default_dependency_file_path, help="Path to YAML config file", ) + parser.add_argument( + "--clean", + nargs="?", + default=None, + const="", + help=( + "Delete any files previously created by dfg before running. An optional " + "path to clean may be provided, otherwise the current working directory " + "is used as the root from which to clean." + ), + ) codependent_args = parser.add_argument_group("optional, but codependent") codependent_args.add_argument( @@ -45,6 +60,11 @@ def validate_args(argv): + "".join([f"\n --{x}" for x in dependent_arg_keys]) ) + # If --clean was passed without arguments, default to cleaning from the root of the + # tree where the config file is. + if args.clean == "": + args.clean = os.path.dirname(os.path.abspath(args.config)) + return args @@ -79,4 +99,6 @@ def main(argv=None): } } + if args.clean: + delete_existing_files(args.clean) make_dependency_files(parsed_config, args.config, to_stdout) diff --git a/src/rapids_dependency_file_generator/rapids_dependency_file_generator.py b/src/rapids_dependency_file_generator/rapids_dependency_file_generator.py index 5d2c05b..b58ab7d 100644 --- a/src/rapids_dependency_file_generator/rapids_dependency_file_generator.py +++ b/src/rapids_dependency_file_generator/rapids_dependency_file_generator.py @@ -15,6 +15,27 @@ OUTPUT_ENUM_VALUES = [str(x) for x in OutputTypes] NON_NONE_OUTPUT_ENUM_VALUES = [str(x) for x in OutputTypes if not x == OutputTypes.NONE] +HEADER = f"# This file is generated by `{cli_name}`." + + +def delete_existing_files(root="."): + """Delete any files generated by this generator. + + This function can be used to clean up a directory tree before generating a new set + of files from scratch. + + Parameters + ---------- + root : str + The path to the root of the directory tree to search for files to delete. + """ + for dirpath, _, filenames in os.walk(root): + for fn in filter( + lambda fn: fn.endswith(".txt") or fn.endswith(".yaml"), filenames + ): + with open(file_path := os.path.join(dirpath, fn)) as f: + if HEADER in f.read(): + os.remove(file_path) def dedupe(dependencies): @@ -94,7 +115,7 @@ def make_dependency_file( relative_path_to_config_file = os.path.relpath(config_file, output_dir) file_contents = textwrap.dedent( f"""\ - # This file is generated by `{cli_name}`. + {HEADER} # To make changes, edit {relative_path_to_config_file} and run `{cli_name}`. """ ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 9a41b1c..fe76b90 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,7 +1,5 @@ import glob -import os import pathlib -import shutil import jsonschema import pytest @@ -20,13 +18,6 @@ INVALID_EXAMPLE_FILES = list(CURRENT_DIR.glob("examples/invalid/*/dependencies.yaml")) -@pytest.fixture(scope="session", autouse=True) -def clean_actual_files(): - for root, _, _ in os.walk("tests"): - if pathlib.Path(root).name == "actual": - shutil.rmtree(root) - - def make_file_set(file_dir): return { pathlib.Path(f).relative_to(file_dir) @@ -56,7 +47,14 @@ def test_examples(example_dir): actual_dir = example_dir.joinpath("output", "actual") dep_file_path = example_dir.joinpath("dependencies.yaml") - main(["--config", str(dep_file_path)]) + main( + [ + "--config", + str(dep_file_path), + "--clean", + str(example_dir.joinpath("output", "actual")), + ] + ) expected_file_set = make_file_set(expected_dir) actual_file_set = make_file_set(actual_dir) @@ -75,7 +73,14 @@ def test_error_examples(test_name): dep_file_path = test_dir.joinpath("dependencies.yaml") with pytest.raises(ValueError): - main(["--config", str(dep_file_path)]) + main( + [ + "--config", + str(dep_file_path), + "--clean", + str(test_dir.joinpath("output", "actual")), + ] + ) def test_examples_are_valid(schema, example_dir):