diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cb1de9c..b853617 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,7 +38,6 @@ jobs: python-version: - "3.11" - "3.12" - steps: - uses: actions/checkout@v4 - name: Install uv @@ -49,6 +48,27 @@ jobs: run: | uv venv uv pip install tox-uv tox-gh-actions - - name: Test with tox + - name: Test building with tox run: uv run tox r - \ No newline at end of file + + unit: + name: Unit tests + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: + - "3.11" + - "3.12" + steps: + - uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v2 + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + - name: Install dependencies + run: | + uv venv + uv pip install -U -e ".[test]" + - name: Run unittest + run: uv run pytest \ No newline at end of file diff --git a/mreg_cli/tokenfile.py b/mreg_cli/tokenfile.py index 0a1308b..95556a9 100644 --- a/mreg_cli/tokenfile.py +++ b/mreg_cli/tokenfile.py @@ -5,9 +5,9 @@ import json import os import sys -from typing import Optional +from typing import Any, Optional, Self -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter, ValidationError # The contents of the token file is: @@ -28,47 +28,61 @@ class Token(BaseModel): username: str +TokenList = TypeAdapter(list[Token]) + + class TokenFile: """A class for managing tokens in a JSON file.""" tokens_path: str = os.path.join(os.getenv("HOME", ""), ".mreg-cli_auth_token.json") - def __init__(self, tokens: Optional[list[dict[str, str]]] = None): + def __init__(self, tokens: Any = None): """Initialize the TokenFile instance.""" - self.tokens = [Token(**token) for token in tokens] if tokens else [] - - @classmethod - def _load_tokens(cls) -> "TokenFile": - """Load tokens from a JSON file, returning a new instance of TokenFile.""" - try: - with open(cls.tokens_path, "r") as file: 
- data = json.load(file) - return TokenFile(tokens=data["tokens"]) - except (FileNotFoundError, KeyError): - return TokenFile(tokens=[]) - - @classmethod - def _set_file_permissions(cls, mode: int) -> None: + self.tokens = self._validate_tokens(tokens) + + def _validate_tokens(self, tokens: Any) -> list[Token]: + """Convert deserialized JSON to list of Token objects.""" + if tokens: + try: + return TokenList.validate_python(tokens) + except ValidationError as e: + print( + f"Failed to validate tokens from token file {self.tokens_path}: {e}", + file=sys.stderr, + ) + return [] + + def _set_file_permissions(self, mode: int) -> None: """Set the file permissions for the token file.""" try: - os.chmod(cls.tokens_path, mode) + os.chmod(self.tokens_path, mode) except PermissionError: - print("Failed to set permissions on " + cls.tokens_path, file=sys.stderr) + print(f"Failed to set permissions on {self.tokens_path}", file=sys.stderr) except FileNotFoundError: pass - @classmethod - def _save_tokens(cls, tokens: "TokenFile") -> None: + def save(self) -> None: """Save tokens to a JSON file.""" - with open(cls.tokens_path, "w") as file: - json.dump({"tokens": [token.model_dump() for token in tokens.tokens]}, file, indent=4) + with open(self.tokens_path, "w") as file: + json.dump({"tokens": [token.model_dump() for token in self.tokens]}, file, indent=4) + self._set_file_permissions(0o600) - cls._set_file_permissions(0o600) + @classmethod + def load(cls) -> Self: + """Load tokens from a JSON file, returning a new instance of TokenFile.""" + try: + with open(cls.tokens_path, "r") as file: + data = json.load(file) + return cls(tokens=data.get("tokens")) + except (FileNotFoundError, KeyError, json.JSONDecodeError) as e: + if isinstance(e, json.JSONDecodeError): + print(f"Failed to decode JSON in tokens file {cls.tokens_path}", file=sys.stderr) + return cls(tokens=[]) @classmethod def get_entry(cls, username: str, url: str) -> Optional[Token]: """Retrieve a token by username and 
URL.""" - tokens_file = cls._load_tokens() + tokens_file = cls.load() for token in tokens_file.tokens: if token.url == url and token.username == username: return token @@ -77,13 +91,11 @@ def get_entry(cls, username: str, url: str) -> Optional[Token]: @classmethod def set_entry(cls, username: str, url: str, new_token: str) -> None: """Update or add a token based on the URL and username.""" - tokens_file = cls._load_tokens() + tokens_file = cls.load() for token in tokens_file.tokens: if token.url == url and token.username == username: token.token = new_token - cls._save_tokens(tokens_file) - return - + return tokens_file.save() # If not found, add a new token tokens_file.tokens.append(Token(token=new_token, url=url, username=username)) - cls._save_tokens(tokens_file) + tokens_file.save() diff --git a/pyproject.toml b/pyproject.toml index 7b7f651..4251c27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,17 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -dev = ["ruff", "tox-uv", "rich", "setuptools", "setuptools-scm", "build"] +test = ["pytest", "inline-snapshot", "pytest-httpserver"] +dev = [ + "mreg-cli[test]", + "ruff", + "tox-uv", + "rich", + "setuptools", + "setuptools-scm", + "build", + "pyinstaller", +] [project.urls] Repository = 'https://github.com/unioslo/mreg-cli/' diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1d6d64a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import os +from typing import Iterator + +import pytest +from pytest_httpserver import HTTPServer + +from mreg_cli.config import MregCliConfig + + +@pytest.fixture(autouse=True) +def set_url_env(httpserver: HTTPServer) -> Iterator[None]: + """Set the config URL to the test HTTP server URL.""" + conf = MregCliConfig() + pre_override_conf = conf._config_cmd.copy() # 
pyright: ignore[reportPrivateUsage]
+    conf._config_cmd["url"] = httpserver.url_for("/")  # pyright: ignore[reportPrivateUsage]
+    yield
+    conf._config_cmd = pre_override_conf  # pyright: ignore[reportPrivateUsage]
+
+
+@pytest.fixture(autouse=True if os.environ.get("PYTEST_HTTPSERVER_STRICT") else False)
+def check_assertions(httpserver: HTTPServer) -> Iterator[None]:
+    """Ensure all HTTP server assertions are checked after the test."""
+    # If the HTTP server raises errors or has failed assertions in its handlers
+    # themselves, we want to raise an exception to fail the test.
+    #
+    # The `check_assertions` method will raise an exception if there are
+    # any tests with HTTP test server errors.
+    # See: https://pytest-httpserver.readthedocs.io/en/latest/tutorial.html#handling-test-errors
+    # https://pytest-httpserver.readthedocs.io/en/latest/howto.html#using-custom-request-handler
+    #
+    # If a test has an assertion or handler error that is expected, it should
+    # call `httpserver.clear_assertions()` and/or `httpserver.clear_handler_errors()` as needed.
+ yield + httpserver.check_assertions() + httpserver.check_handler_errors() diff --git a/tests/test_errorbuilder.py b/tests/test_errorbuilder.py new file mode 100644 index 0000000..0033271 --- /dev/null +++ b/tests/test_errorbuilder.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from mreg_cli.errorbuilder import ( + ErrorBuilder, + FallbackErrorBuilder, + FilterErrorBuilder, + build_error_message, + get_builder, +) + + +@pytest.mark.parametrize( + "command, exc_or_str, expected", + [ + ( + r"permission label_add 192.168.0.0/24 oracle-group ^(db|cman)ora.*\.example\.com$ oracle", + "failed to compile regex", + FilterErrorBuilder, + ), + ( + r"permission label_add other_error", + "Other error message", + FallbackErrorBuilder, + ), + ], +) +def test_get_builder(command: str, exc_or_str: str, expected: type[ErrorBuilder]) -> None: + builder = get_builder(command, exc_or_str) + assert builder.__class__ == expected + assert builder.get_underline(0, 0) == "" + assert builder.get_underline(0, 10) == "^^^^^^^^^^" + assert builder.get_underline(5, 10) == " ^^^^^" + + +@pytest.mark.parametrize( + "command, exc_or_str, expected", + [ + ( + r"permission label_add 192.168.0.0/24 oracle-group ^(db|cman)ora.*\.example\.com$ oracle", + r"Unable to compile regex 'cman)ora.*\.example\.com$ oracle'", + r"""Unable to compile regex 'cman)ora.*\.example\.com$ oracle' +permission label_add 192.168.0.0/24 oracle-group ^(db|cman)ora.*\.example\.com$ oracle + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + └ Consider enclosing this part in quotes.""", + ), + ( + r"permission label_add other_error", + "Other error message", + "Other error message", + ), + ], +) +def test_build_error_message(command: str, exc_or_str: str, expected: str) -> None: + assert build_error_message(command, exc_or_str) == expected diff --git a/tests/test_tokenfile.py b/tests/test_tokenfile.py new file mode 100644 index 0000000..539be7a --- /dev/null +++ b/tests/test_tokenfile.py @@ -0,0 +1,209 @@ 
+"""Tests for token file handling. + +TODO: Add tests for the following scenarios: + - Error cases for token read/write operations + - Class methods behavior when called on the class itself + vs an instance of the class +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Iterator + +import pytest +from inline_snapshot import snapshot + +from mreg_cli.tokenfile import TokenFile + +TOKENS_PATH_ORIGINAL = TokenFile.tokens_path + + +TOKEN_FILE_SINGLE = """ +{ + "tokens": [ + { + "token": "exampletoken123", + "url": "https://example.com", + "username": "exampleuser" + } + ] +} +""" + + +TOKEN_FILE_MULTIPLE = """ +{ + "tokens": [ + { + "token": "exampletoken123", + "url": "https://example.com", + "username": "exampleuser" + }, + { + "token": "footoken456", + "url": "https://foo.com", + "username": "foouser" + }, + { + "token": "bartoken789", + "url": "https://bar.com", + "username": "baruser" + } + ] +} +""" + + +@pytest.fixture(autouse=True) +def reset_token_file_path() -> Iterator[None]: + """Reset the token file path after each test.""" + yield + TokenFile.tokens_path = TOKENS_PATH_ORIGINAL + + +def test_load_file_nonexistent(tmp_path: Path) -> None: + """Load from a nonexistent tokens file.""" + tokens_path = tmp_path / "does_not_exist.json" + assert not tokens_path.exists() + TokenFile.tokens_path = str(tokens_path) + tokenfile = TokenFile.load() + assert tokenfile.tokens == [] + + +def test_load_file_empty(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + """Load from an empty tokens file.""" + tokens_path = tmp_path / "empty.json" + tokens_path.touch() + assert tokens_path.read_text() == "" + TokenFile.tokens_path = str(tokens_path) + tokenfile = TokenFile.load() + assert tokenfile.tokens == [] + assert "Failed to decode JSON" in capsys.readouterr().err + + +def test_load_file_invalid_json(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + """Load from a tokens file with invalid JSON.""" 
+ tokens_path = tmp_path / "invalid.json" + tokens_path.write_text("not json") + assert tokens_path.read_text() == snapshot("not json") + TokenFile.tokens_path = str(tokens_path) + tokenfile = TokenFile.load() + assert tokenfile.tokens == [] + assert "Failed to decode JSON" in capsys.readouterr().err + + +def test_load_file_invalid_tokenfile(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + """Load from a tokens file that is not in the correct format.""" + tokens_path = tmp_path / "invalid_format.json" + + # Contents is valid JSON, but contents are not in the correct format + tokens_path.write_text( + json.dumps( + { + "tokens": [ + { + "invalid_key": 123, + "url": "https://example.com", + # missing token and username keys + }, + { + "token": "exampletoken123", + "url": "https://example.com", + "username": "exampleuser", + }, + ] + } + ) + ) + TokenFile.tokens_path = str(tokens_path) + tokenfile = TokenFile.load() + assert tokenfile.tokens == [] + assert "Failed to validate tokens from token file" in capsys.readouterr().err + + +def test_load_file_single(tmp_path: Path) -> None: + """Load from a tokens file with a single token.""" + tokens_path = tmp_path / "single.json" + tokens_path.write_text(TOKEN_FILE_SINGLE) + TokenFile.tokens_path = str(tokens_path) + tokenfile = TokenFile.load() + tokens = tokenfile.tokens + assert len(tokens) == 1 + assert tokens[0].token == snapshot("exampletoken123") + assert tokens[0].url == snapshot("https://example.com") + assert tokens[0].username == snapshot("exampleuser") + + +def test_load_file_multiple(tmp_path: Path) -> None: + """Load from a tokens file with multiple tokens.""" + tokens_path = tmp_path / "multiple.json" + tokens_path.write_text(TOKEN_FILE_MULTIPLE) + TokenFile.tokens_path = str(tokens_path) + assert len(TokenFile.load().tokens) == snapshot(3) + + +def test_get_entry(tmp_path: Path) -> None: + """Get a token from the token file.""" + tokens_path = tmp_path / "get_token.json" + 
tokens_path.write_text(TOKEN_FILE_MULTIPLE) + TokenFile.tokens_path = str(tokens_path) + + token = TokenFile.get_entry("exampleuser", "https://example.com") + assert token is not None + assert token.token == snapshot("exampletoken123") + assert token.url == snapshot("https://example.com") + assert token.username == snapshot("exampleuser") + + token = TokenFile.get_entry("foouser", "https://foo.com") + assert token is not None + assert token.token == snapshot("footoken456") + assert token.url == snapshot("https://foo.com") + assert token.username == snapshot("foouser") + + token = TokenFile.get_entry("baruser", "https://bar.com") + assert token is not None + assert token.token == snapshot("bartoken789") + assert token.url == snapshot("https://bar.com") + assert token.username == snapshot("baruser") + + token = TokenFile.get_entry("nonexistent", "https://example.com") + assert token is None + + +def test_set_entry_existing(tmp_path: Path) -> None: + """Set a token in the token file that already exists.""" + tokens_path = tmp_path / "set_existing.json" + tokens_path.write_text(TOKEN_FILE_MULTIPLE) + TokenFile.tokens_path = str(tokens_path) + + assert len(TokenFile.load().tokens) == snapshot(3) + TokenFile.set_entry("newuser", "https://new.com", "newtoken123") + assert len(TokenFile.load().tokens) == snapshot(4) + token = TokenFile.get_entry("newuser", "https://new.com") + assert token is not None + assert token.token == snapshot("newtoken123") + + +@pytest.mark.parametrize("create_before", [True, False], ids=["create_before", "create_after"]) +def test_set_entry_new(tmp_path: Path, create_before: bool) -> None: + """Set a token in the token file that does not already exist.""" + tokens_path = tmp_path / "set_new.json" + if create_before: + tokens_path.touch() # empty file + TokenFile.tokens_path = str(tokens_path) + + # Write tokens to the empty file + assert len(TokenFile.load().tokens) == snapshot(0) + TokenFile.set_entry("newuser", "https://new.com", "newtoken123") 
+ assert len(TokenFile.load().tokens) == snapshot(1) + token = TokenFile.get_entry("newuser", "https://new.com") + assert token is not None + assert token.token == snapshot("newtoken123") + + # Try to load the tokens from the file again + assert len(TokenFile.load().tokens) == snapshot(1) + token = TokenFile.get_entry("newuser", "https://new.com") + assert token is not None + assert token.token == snapshot("newtoken123") diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utilities/test_api.py b/tests/utilities/test_api.py new file mode 100644 index 0000000..e9cd9ed --- /dev/null +++ b/tests/utilities/test_api.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from inline_snapshot import snapshot +from pytest_httpserver import HTTPServer +from werkzeug import Response + +from mreg_cli.exceptions import MultipleEntititesFound, ValidationError +from mreg_cli.utilities.api import _strip_none, get_list, get_list_unique # type: ignore + + +@pytest.mark.parametrize( + "inp,expect", + [ + # Empty dict + ({}, {}), + # Mixed values + ({"foo": "a", "bar": None}, {"foo": "a"}), + # Multiple keys with None values + ({"foo": None, "bar": None}, {}), + # Nested dicts + ({"foo": {"bar": {"baz": None}}}, {}), + ( + {"foo": {"bar": {"baz": None}}, "qux": {}, "quux": ["a", "b", "c"]}, + {"quux": ["a", "b", "c"]}, + ), + ], +) +def test_strip_none(inp: dict[str, Any], expect: dict[str, Any]) -> None: + assert _strip_none(inp) == expect + + +def test_get_list_paginated(httpserver: HTTPServer) -> None: + httpserver.expect_oneshot_request("/foobar").respond_with_json( + { + "results": [{"foo": "bar"}], + "count": 1, + "next": None, + "previous": None, + } + ) + resp = get_list("/foobar") + assert resp == snapshot([{"foo": "bar"}]) + + +def test_get_list_paginated_empty(httpserver: HTTPServer) -> None: + 
httpserver.expect_oneshot_request("/foobar").respond_with_json( + { + "results": [], + "count": 0, + "next": None, + "previous": None, + } + ) + resp = get_list("/foobar") + assert resp == snapshot([]) + + +def test_get_list_paginated_multiple_pages(httpserver: HTTPServer) -> None: + httpserver.expect_oneshot_request("/foobar").respond_with_json( + { + "results": [{"foo": "bar"}], + "count": 1, + "next": "/foobar?page=2", + "previous": None, + } + ) + httpserver.expect_oneshot_request("/foobar", query_string="page=2").respond_with_json( + { + "results": [{"baz": "qux"}], + "count": 1, + "next": None, + "previous": "/foobar", + } + ) + resp = get_list("/foobar") + assert resp == snapshot([{"foo": "bar"}, {"baz": "qux"}]) + + +def test_get_list_paginated_multiple_pages_ok404(httpserver: HTTPServer) -> None: + """Paginated response with 404 on next page is ignored when `ok404=True`.""" + httpserver.expect_oneshot_request("/foobar").respond_with_json( + { + "results": [{"foo": "bar"}], + "count": 1, + "next": "/foobar?page=2", + "previous": None, + } + ) + httpserver.expect_oneshot_request("/foobar", query_string="page=2").respond_with_response( + Response(status=404) + ) + assert get_list("/foobar", ok404=True) == snapshot([{"foo": "bar"}]) + + +def test_get_list_paginated_multiple_pages_inconsistent_count(httpserver: HTTPServer) -> None: + """Inconsistent count in paginated response is ignored.""" + httpserver.expect_oneshot_request("/foobar").respond_with_json( + { + "results": [{"foo": "bar"}, {"baz": "qux"}], + "count": 1, # wrong count + "next": "/foobar?page=2", + "previous": None, + } + ) + httpserver.expect_oneshot_request("/foobar", query_string="page=2").respond_with_json( + { + "results": [{"quux": "spam"}], + "count": 2, # wrong count + "next": None, + "previous": "/foobar", + } + ) + resp = get_list("/foobar") + assert resp == snapshot([{"foo": "bar"}, {"baz": "qux"}, {"quux": "spam"}]) + + +@pytest.mark.parametrize( + "results", + [ + '"foo"', # Not a 
list
+        "42",  # Not a list
+        '{"foo": "bar"}',  # Not a list
+        "{'foo': 'bar'}",  # Invalid JSON + not a list
+        "[{'foo': 'bar'}]",  # Invalid JSON
+    ],
+)
+def test_get_list_paginated_invalid(httpserver: HTTPServer, results: Any) -> None:
+    """Invalid JSON or non-array response is an error."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_data(
+        f"""{{
+        "results": {results},
+        "count": 1,
+        "next": None,
+        "previous": None,
+    }}"""
+    )
+    with pytest.raises(ValidationError) as exc_info:
+        get_list("/foobar")
+    assert "did not return valid paginated JSON" in exc_info.exconly()
+
+
+def test_get_list_non_paginated(httpserver: HTTPServer) -> None:
+    """Non-paginated array response is returned as-is."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json(
+        [
+            "foo",
+            "bar",
+            {"baz": "qux"},
+        ]
+    )
+    resp = get_list("/foobar")
+    assert resp == snapshot(["foo", "bar", {"baz": "qux"}])
+
+
+def test_get_list_non_paginated_empty(httpserver: HTTPServer) -> None:
+    """Non-paginated empty array response yields an empty list."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json([])
+    resp = get_list("/foobar")
+    assert resp == snapshot([])
+
+
+def test_get_list_non_paginated_non_array(httpserver: HTTPServer) -> None:
+    """Non-paginated non-array response is an error."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json(
+        {
+            "not": "an array",
+        }
+    )
+    with pytest.raises(ValidationError) as exc_info:
+        get_list("/foobar")
+    assert "did not return a valid JSON" in exc_info.exconly()
+
+
+def test_get_list_non_paginated_invalid_json(httpserver: HTTPServer) -> None:
+    """Non-paginated response with invalid JSON is an error."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_data(
+        "[{'key': 'value'}, 'foo',]",  # strings must be double quoted
+        content_type="application/json",
+    )
+    with pytest.raises(ValidationError) as exc_info:
+        get_list("/foobar")
+    assert "did not return a valid JSON" in
exc_info.exconly()
+
+
+def test_get_list_unique_paginated(httpserver: HTTPServer) -> None:
+    """Unique lookup returns the single result from a paginated response."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json(
+        {
+            "results": [{"foo": "bar"}],
+            "count": 1,
+            "next": None,
+            "previous": None,
+        }
+    )
+    resp = get_list_unique("/foobar", params={})
+    assert resp == snapshot({"foo": "bar"})
+
+
+def test_get_list_unique_paginated_too_many_results(httpserver: HTTPServer) -> None:
+    """Unique lookup with more than one result is an error."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json(
+        {
+            "results": [{"foo": "bar"}],
+            "count": 1,
+            "next": "/foobar?page=2",
+            "previous": None,
+        }
+    )
+    httpserver.expect_oneshot_request("/foobar", query_string="page=2").respond_with_json(
+        {
+            "results": [{"baz": "qux"}],
+            "count": 1,
+            "next": None,
+            "previous": "/foobar",
+        }
+    )
+    with pytest.raises(MultipleEntititesFound) as exc_info:
+        get_list_unique("/foobar", params={})
+    assert "Expected exactly one result, got 2" in exc_info.exconly()
+
+
+def test_get_list_unique_paginated_no_result(httpserver: HTTPServer) -> None:
+    """No result is None."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json(
+        {
+            "results": [],
+            "count": 0,
+            "next": None,
+            "previous": None,
+        }
+    )
+    resp = get_list_unique("/foobar", params={})
+    assert resp is None
+
+
+def test_get_list_unique_non_paginated_no_result(httpserver: HTTPServer) -> None:
+    """No result is None."""
+    httpserver.expect_oneshot_request("/foobar").respond_with_json([])
+    resp = get_list_unique("/foobar", params={})
+    assert resp is None