From 0f2b2153cd7fa9dd88752a4557d5409c145018c9 Mon Sep 17 00:00:00 2001 From: Vitor Enes Date: Wed, 16 Nov 2022 15:12:29 +0000 Subject: [PATCH] Combine options outside of Z3 --- .github/workflows/actions.yml | 4 +- predicate/Makefile | 7 +- predicate/README.md | 73 +-- predicate/cli/__main__.py | 8 +- predicate/examples/access.py | 29 +- predicate/examples/join_session.py | 65 +-- predicate/mypy.ini | 1 - predicate/poetry.lock | 122 +++-- predicate/pyproject.toml | 3 +- predicate/solver/ast.py | 139 +++--- predicate/solver/teleport.py | 253 ++++++---- predicate/solver/test/test_ast.py | 420 ++++++++-------- predicate/solver/test/test_aws.py | 16 +- predicate/solver/test/test_teleport.py | 467 ++++++++++-------- .../test/test_teleport_access_requests.py | 58 +-- .../solver/test/test_teleport_get_started.py | 26 +- 16 files changed, 913 insertions(+), 778 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 8b9db27..c8bd4a1 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -17,4 +17,6 @@ jobs: pip install poetry poetry install - name: Run tests - run: 'poetry run make test' + run: poetry run make test + - name: Run linter + run: poetry run make lint-check diff --git a/predicate/Makefile b/predicate/Makefile index 4dc90a3..75ca534 100644 --- a/predicate/Makefile +++ b/predicate/Makefile @@ -1,16 +1,15 @@ -all: +all: lint test .PHONY: test-% test-%: - mypy solver/test/test_$*.py pytest --pyargs solver/test/test_$*.py .PHONY: lint lint: lint-python -.PHONY: check -check: +.PHONY: lint-check +lint-check: lint-python --check .PHONY: test diff --git a/predicate/README.md b/predicate/README.md index aa70655..9353a1a 100644 --- a/predicate/README.md +++ b/predicate/README.md @@ -10,83 +10,16 @@ Alternately, `poetry shell` can also be used to run `predicate`. ## Working with policies -### Example policy - -```py -# access.py - -from solver.ast import Duration -from solver.teleport import AccessNode, Node, Options, OptionsSet, Policy, Rules, User - - -class Teleport: - p = Policy( - name="access", - loud=False, - allow=Rules( - AccessNode( - ((AccessNode.login == User.name) & (User.name != "root")) - | (User.traits["team"] == ("admins",)) - ), - ), - options=OptionsSet(Options((Options.max_session_ttl < Duration.new(hours=10)))), - deny=Rules( - AccessNode( - (AccessNode.login == "mike") - | (AccessNode.login == "jester") - | (Node.labels["env"] == "prod") - ), - ), - ) - - def test_access(self): - # Alice will be able to login to any machine as herself - ret, _ = self.p.check( - AccessNode( - (AccessNode.login == "alice") - & (User.name == "alice") - & (Node.labels["env"] == "dev") - ) - ) - assert ret is True, "Alice can login with her user to any node" - - # No one is permitted to login as mike - ret, _ = self.p.query(AccessNode((AccessNode.login == "mike"))) - assert ret is False, "This role does not allow access as mike" - - # No one is permitted to login as jester - ret, _ = self.p.query(AccessNode((AccessNode.login == "jester"))) - assert ret is False, "This role does not allow access as jester" -``` +See example policies in the [examples/](examples/) folder. ### Testing a policy ```bash -predicate test access.py -``` - -```bash -Running 1 tests: - - test_access: ok +predicate test examples/access.py ``` ### Exporting a policy ```bash -predicate export access.py +predicate export examples/access.py ``` - -```yaml -kind: policy -metadata: - name: access -spec: - allow: - access_node: (((access_node.login == user.name) && (!(user.name == "root"))) || - equals(user.traits["team"], ["admins"])) - deny: - access_node: (((access_node.login == "mike") || (access_node.login == "jester")) - || (node.labels["env"] == "prod")) - options: (options.max_session_ttl < 36000000000000) -version: v1 -``` \ No newline at end of file diff --git a/predicate/cli/__main__.py b/predicate/cli/__main__.py index 824fa6a..4e6a771 100644 --- a/predicate/cli/__main__.py +++ b/predicate/cli/__main__.py @@ -5,6 +5,8 @@ import click import yaml +from solver.teleport import Policy + @click.group() def main(): @@ -17,7 +19,7 @@ def export(policy_file): module = run_path(policy_file) # Grabs the class and directly reads the policy since it's a static member. - policy = module["Teleport"].p + policy: Policy = module["Teleport"].p # Dump the policy into a Teleport resource and write it to the terminal. obj = policy.export() @@ -31,10 +33,12 @@ def export(policy_file): def deploy(policy_file, sudo): click.echo("parsing policy...") module = run_path(policy_file) - policy = module["Teleport"].p + policy: Policy = module["Teleport"].p + click.echo("translating policy...") obj = policy.export() serialized = yaml.dump(obj) + click.echo("deploying policy...") args = ["tctl", "create", "-f"] if sudo: diff --git a/predicate/examples/access.py b/predicate/examples/access.py index d5047e5..0b61092 100644 --- a/predicate/examples/access.py +++ b/predicate/examples/access.py @@ -1,18 +1,31 @@ from solver.ast import Duration -from solver.teleport import AccessNode, Node, Options, OptionsSet, Policy, Rules, User +from solver.teleport import ( + AccessNode, + Node, + Options, + Policy, + RecordingMode, + Rules, + SourceIp, + User, +) class Teleport: p = Policy( name="access", loud=False, + options=Options( + max_session_ttl=Duration.new(hours=10), + recording_mode=RecordingMode.STRICT, + source_ip=SourceIp.PINNED, + ), allow=Rules( AccessNode( ((AccessNode.login == User.name) & (User.name != "root")) | (User.traits["team"] == ("admins",)) ), ), - options=OptionsSet(Options((Options.max_session_ttl < Duration.new(hours=10)))), deny=Rules( AccessNode( (AccessNode.login == "mike") @@ -24,19 +37,19 @@ class Teleport: def test_access(self): # Alice will be able to login to any machine as herself - ret, _ = self.p.check( + ret = self.p.check( AccessNode( (AccessNode.login == "alice") & (User.name == "alice") & (Node.labels["env"] == "dev") ) ) - assert ret is True, "Alice can login with her user to any node" + assert ret.solves is True, "Alice can login with her user to any node" # No one is permitted to login as mike - ret, _ = self.p.query(AccessNode((AccessNode.login == "mike"))) - assert ret is False, "This role does not allow access as mike" + ret = self.p.query(AccessNode((AccessNode.login == "mike"))) + assert ret.solves is False, "This role does not allow access as mike" # No one is permitted to login as jester - ret, _ = self.p.query(AccessNode((AccessNode.login == "jester"))) - assert ret is False, "This role does not allow access as jester" + ret = self.p.query(AccessNode((AccessNode.login == "jester"))) + assert ret.solves is False, "This role does not allow access as jester" diff --git a/predicate/examples/join_session.py b/predicate/examples/join_session.py index d416040..ed17da1 100644 --- a/predicate/examples/join_session.py +++ b/predicate/examples/join_session.py @@ -1,4 +1,4 @@ -from solver.teleport import JoinSession, Session, Policy, Rules, User, AccessNode +from solver.teleport import JoinSession, Policy, Rules, Session, User class Teleport: @@ -9,51 +9,58 @@ class Teleport: # Equivalent to `join_sessions`: # https://goteleport.com/docs/access-controls/guides/moderated-sessions/#join_sessions JoinSession( - (User.traits["team"].contains("dev")) & - ((JoinSession.mode == "observer") | (JoinSession.mode == "peer")) & - ((Session.owner.traits["team"].contains("dev")) | (Session.owner.traits["team"].contains("intern"))) + (User.traits["team"].contains("dev")) + & ((JoinSession.mode == "observer") | (JoinSession.mode == "peer")) + & ( + (Session.owner.traits["team"].contains("dev")) + | (Session.owner.traits["team"].contains("intern")) + ) ), ), - deny=Rules( - JoinSession( - User.traits["team"].contains("intern") - ) - ) + deny=Rules(JoinSession(User.traits["team"].contains("intern"))), ) def test_access(self): - ret, _ = self.p.check( + ret = self.p.check( JoinSession( - (User.traits["team"] == ("dev",)) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev",)) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is True, "a dev user can join a session from an intern user as an observer" + assert ( + ret.solves is True + ), "a dev user can join a session from an intern user as an observer" - ret, _ = self.p.check( + ret = self.p.check( JoinSession( - (User.traits["team"] == ("marketing",)) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("marketing",)) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a marketing user cannot join a session from an intern user as an observer" + assert ( + ret.solves is False + ), "a marketing user cannot join a session from an intern user as an observer" - ret, _ = self.p.check( + ret = self.p.check( JoinSession( - (User.traits["team"] == ("dev",)) & - (JoinSession.mode == "moderator") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev",)) + & (JoinSession.mode == "moderator") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a dev user cannot join a session from an intern user as a moderator" + assert ( + ret.solves is False + ), "a dev user cannot join a session from an intern user as a moderator" - ret, _ = self.p.check( + ret = self.p.check( JoinSession( - (User.traits["team"] == ("dev", "intern")) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev", "intern")) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a dev intern user cannot join a session from an intern user as an observer" + assert ( + ret.solves is False + ), "a dev intern user cannot join a session from an intern user as an observer" diff --git a/predicate/mypy.ini b/predicate/mypy.ini index c4f5208..ebcf395 100644 --- a/predicate/mypy.ini +++ b/predicate/mypy.ini @@ -1,4 +1,3 @@ [mypy] ignore_missing_imports = True -exclude = examples diff --git a/predicate/poetry.lock b/predicate/poetry.lock index 5020c31..1e1ac50 100644 --- a/predicate/poetry.lock +++ b/predicate/poetry.lock @@ -10,7 +10,7 @@ python-versions = ">=3.5" dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] [[package]] name = "black" @@ -46,11 +46,22 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "colorama" -version = "0.4.5" +version = "0.4.6" description = "Cross-platform colored terminal text." category = "main" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" + +[[package]] +name = "exceptiongroup" +version = "1.0.4" +description = "Backport of PEP 654 (exception groups)" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +test = ["pytest (>=6)"] [[package]] name = "flake8" @@ -83,9 +94,9 @@ python-versions = ">=3.6.1,<4.0" [package.extras] colors = ["colorama (>=0.4.3,<0.5.0)"] -pipfile_deprecated_finder = ["pipreqs", "requirementslib"] +pipfile-deprecated-finder = ["pipreqs", "requirementslib"] plugins = ["setuptools"] -requirements_deprecated_finder = ["pip-api", "pipreqs"] +requirements-deprecated-finder = ["pip-api", "pipreqs"] [[package]] name = "lint-python" @@ -146,7 +157,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pathspec" -version = "0.10.1" +version = "0.10.2" description = "Utility library for gitignore style pattern matching of file paths." category = "main" optional = false @@ -154,15 +165,15 @@ python-versions = ">=3.7" [[package]] name = "platformdirs" -version = "2.5.2" -description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +version = "2.5.4" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." category = "main" optional = false python-versions = ">=3.7" [package.extras] -docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"] -test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] +docs = ["furo (>=2022.9.29)", "proselint (>=0.13)", "sphinx (>=5.3)", "sphinx-autodoc-typehints (>=1.19.4)"] +test = ["appdirs (==1.4.4)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] [[package]] name = "pluggy" @@ -176,14 +187,6 @@ python-versions = ">=3.6" dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] -[[package]] -name = "py" -version = "1.11.0" -description = "library with cross-python path, ini-parsing, io, code, log facilities" -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" - [[package]] name = "pycodestyle" version = "2.9.1" @@ -213,7 +216,7 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "7.1.3" +version = "7.2.0" description = "pytest: simple powerful testing with Python" category = "main" optional = false @@ -222,17 +225,17 @@ python-versions = ">=3.7" [package.dependencies] attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" -py = ">=1.8.2" -tomli = ">=1.0.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] [[package]] -name = "PyYAML" +name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" category = "main" @@ -241,7 +244,7 @@ python-versions = ">=3.6" [[package]] name = "setuptools" -version = "65.4.1" +version = "65.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "main" optional = false @@ -249,7 +252,7 @@ python-versions = ">=3.7" [package.extras] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -268,6 +271,33 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "types-pyyaml" +version = "6.0.12.2" +description = "Typing stubs for PyYAML" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "types-requests" +version = "2.28.11.4" +description = "Typing stubs for requests" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +types-urllib3 = "<1.27" + +[[package]] +name = "types-urllib3" +version = "1.26.25.4" +description = "Typing stubs for urllib3" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "4.4.0" @@ -287,7 +317,7 @@ python-versions = "*" [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "3133493c42e82a69223f30162d4c726dbf3e0d02d32edae4fa0e0cf539b224e2" +content-hash = "fe6585a5c18e50ae7324dd72d155959a4e5fa3e23cf310c15877b5745de22283" [metadata.files] attrs = [ @@ -322,8 +352,12 @@ click = [ {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] colorama = [ - {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, - {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] +exceptiongroup = [ + {file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"}, + {file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"}, ] flake8 = [ {file = "flake8-5.0.4-py2.py3-none-any.whl", hash = "sha256:7a1cf6b73744f5806ab95e526f6f0d8c01c66d7bbe349562d22dfca20610b248"}, @@ -380,21 +414,17 @@ packaging = [ {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] pathspec = [ - {file = "pathspec-0.10.1-py3-none-any.whl", hash = "sha256:46846318467efc4556ccfd27816e004270a9eeeeb4d062ce5e6fc7a87c573f93"}, - {file = "pathspec-0.10.1.tar.gz", hash = "sha256:7ace6161b621d31e7902eb6b5ae148d12cfd23f4a249b9ffb6b9fee12084323d"}, + {file = "pathspec-0.10.2-py3-none-any.whl", hash = "sha256:88c2606f2c1e818b978540f73ecc908e13999c6c3a383daf3705652ae79807a5"}, + {file = "pathspec-0.10.2.tar.gz", hash = "sha256:8f6bf73e5758fd365ef5d58ce09ac7c27d2833a8d7da51712eac6e27e35141b0"}, ] platformdirs = [ - {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"}, - {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"}, + {file = "platformdirs-2.5.4-py3-none-any.whl", hash = "sha256:af0276409f9a02373d540bf8480021a048711d572745aef4b7842dad245eba10"}, + {file = "platformdirs-2.5.4.tar.gz", hash = "sha256:1006647646d80f16130f052404c6b901e80ee4ed6bef6792e1f238a8969106f7"}, ] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] -py = [ - {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, - {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, -] pycodestyle = [ {file = "pycodestyle-2.9.1-py2.py3-none-any.whl", hash = "sha256:d1735fc58b418fd7c5f658d28d943854f8a849b01a5d0a1e6f3f3fdd0166804b"}, {file = "pycodestyle-2.9.1.tar.gz", hash = "sha256:2c9607871d58c76354b697b42f5d57e1ada7d261c261efac224b664affdc5785"}, @@ -408,10 +438,10 @@ pyparsing = [ {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, ] pytest = [ - {file = "pytest-7.1.3-py3-none-any.whl", hash = "sha256:1377bda3466d70b55e3f5cecfa55bb7cfcf219c7964629b967c37cf0bda818b7"}, - {file = "pytest-7.1.3.tar.gz", hash = "sha256:4f365fec2dff9c1162f834d9f18af1ba13062db0c708bf7b946f8a5c76180c39"}, + {file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"}, + {file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"}, ] -PyYAML = [ +pyyaml = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, @@ -454,8 +484,8 @@ PyYAML = [ {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] setuptools = [ - {file = "setuptools-65.4.1-py3-none-any.whl", hash = "sha256:1b6bdc6161661409c5f21508763dc63ab20a9ac2f8ba20029aaaa7fdb9118012"}, - {file = "setuptools-65.4.1.tar.gz", hash = "sha256:3050e338e5871e70c72983072fe34f6032ae1cdeeeb67338199c2f74e083a80e"}, + {file = "setuptools-65.5.1-py3-none-any.whl", hash = "sha256:d0b9a8433464d5800cbe05094acf5c6d52a91bfac9b52bcfc4d41382be5d5d31"}, + {file = "setuptools-65.5.1.tar.gz", hash = "sha256:e197a19aa8ec9722928f2206f8de752def0e4c9fc6953527360d1c36d94ddb2f"}, ] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, @@ -465,6 +495,18 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +types-pyyaml = [ + {file = "types-PyYAML-6.0.12.2.tar.gz", hash = "sha256:6840819871c92deebe6a2067fb800c11b8a063632eb4e3e755914e7ab3604e83"}, + {file = "types_PyYAML-6.0.12.2-py3-none-any.whl", hash = "sha256:1e94e80aafee07a7e798addb2a320e32956a373f376655128ae20637adb2655b"}, +] +types-requests = [ + {file = "types-requests-2.28.11.4.tar.gz", hash = "sha256:d4f342b0df432262e9e326d17638eeae96a5881e78e7a6aae46d33870d73952e"}, + {file = "types_requests-2.28.11.4-py3-none-any.whl", hash = "sha256:bdb1f9811e53d0642c8347b09137363eb25e1a516819e190da187c29595a1df3"}, +] +types-urllib3 = [ + {file = "types-urllib3-1.26.25.4.tar.gz", hash = "sha256:eec5556428eec862b1ac578fb69aab3877995a99ffec9e5a12cf7fbd0cc9daee"}, + {file = "types_urllib3-1.26.25.4-py3-none-any.whl", hash = "sha256:ed6b9e8a8be488796f72306889a06a3fc3cb1aa99af02ab8afb50144d7317e49"}, +] typing-extensions = [ {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, diff --git a/predicate/pyproject.toml b/predicate/pyproject.toml index 4356934..5f6b484 100644 --- a/predicate/pyproject.toml +++ b/predicate/pyproject.toml @@ -24,11 +24,12 @@ click = "^8.1.3" setuptools = "^65.3.0" PyYAML = "^6.0" mypy = "^0.982" +types-requests = "^2.28.11.4" +types-pyyaml = "^6.0.12.2" [tool.lint-python] lint-version = "2" source = "." -extra-requirements = "types-requests" use-mypy = true [tool.isort] diff --git a/predicate/solver/ast.py b/predicate/solver/ast.py index 79aee12..6dcdc17 100644 --- a/predicate/solver/ast.py +++ b/predicate/solver/ast.py @@ -20,6 +20,10 @@ # * book https://theory.stanford.edu/~nikolaj/programmingz3.html # * reference https://z3prover.github.io/api/html/namespacez3py.html +# To allow using class type inside class: +# https://github.com/python/mypy/issues/3661#issuecomment-647945042 +from __future__ import annotations + import functools import operator import sre_constants @@ -350,6 +354,11 @@ def walk(self, fn): def __str__(self): return "`{}`".format(self.V) + def __eq__(self, other): + if isinstance(other, DurationLiteral): + return self.V == other.V + return False + class BoolLiteral: """ @@ -410,7 +419,7 @@ def __gt__(self, val): class Int(IntMixin): """ - Int is integer variable, e.g. count = Int('count') + Int is an integer variable, e.g. count = Int('count') """ def __init__(self, name: str): @@ -426,9 +435,10 @@ def walk(self, fn): def __str__(self): return "int({})".format(self.name) -class LtDuration: + +class Duration: """ - LtDuration is a duration that only allows < inequalities. + Duration is a duration variable, e.g. ttl = Duration('ttl') """ def __init__(self, name: str): @@ -441,18 +451,6 @@ def traverse(self): def walk(self, fn): fn(self) - def __str__(self): - return "ltduration({})".format(self.name) - - def __lt__(self, val: DurationLiteral): - return Lt(self, val) - - -class Duration(LtDuration): - """ - Duration is a duration that allows <, >, == and != inequalities. - """ - @staticmethod def new( hours: int = 0, @@ -474,15 +472,29 @@ def new( def __str__(self): return "duration({})".format(self.name) - def __eq__(self, val: DurationLiteral): + def __eq__(self, val): + self.check_value_is_valid(val) return Eq(self, val) - def __ne__(self, val: DurationLiteral): + def __ne__(self, val): + self.check_value_is_valid(val) return Not(Eq(self, val)) - def __gt__(self, val: DurationLiteral): + def __lt__(self, val): + self.check_value_is_valid(val) + return Lt(self, val) + + def __gt__(self, val): + self.check_value_is_valid(val) return Gt(self, val) + def check_value_is_valid(self, val): + if isinstance(val, DurationLiteral): + return + raise TypeError( + "unsupported type {}, supported duration literals only".format(type(val)) + ) + class Bool: def __init__(self, name: str): @@ -861,9 +873,7 @@ def check_value_is_valid(self, val): # raise type error if `val` is not one of the enum values raise TypeError( - "value {} is not one of: {}".format( - val, [v for v in self.values] - ) + "value {} is not one of: {}".format(val, [v for v in self.values]) ) def __str__(self): @@ -1206,7 +1216,7 @@ def __init__(self, name): self.name = name self.fn_map = z3.Function(self.name, z3.StringSort(), z3.StringSort()) - def __getitem__(self, key: String): + def __getitem__(self, key: str | String): """ getitem used to build an expression, for example m[key] == "val" """ @@ -1221,7 +1231,7 @@ def walk(self, fn): class MapIndex(LogicMixin): - def __init__(self, m: StringMap, key: String): + def __init__(self, m: StringMap, key: str | String): self.m = m self.key = key @@ -1612,17 +1622,17 @@ def iff(fn_key, iterator): self.fn_map, [arg_key], iff(arg_key, iter(values.items())) ) - def __getitem__(self, key: String): + def __getitem__(self, key: str | String): """ getitem used to build an expression, for example m[key].contains("val") """ # Map Index should impact function definition, aggregate it return StringSetMapIndex(self, key) - def add_value(self, key: String, val: String): + def add_value(self, key: str | String, val: str | String): return StringSetMapAddValue(self, key, val) - def remove_keys(self, *keys: String): + def remove_keys(self, *keys): return StringSetMapRemoveKeys(self, keys) def overwrite(self, values: typing.Dict): @@ -1681,19 +1691,19 @@ def __str__(self): def traverse(self): return self.fn_map - def __getitem__(self, key: String): + def __getitem__(self, key: str | String): """ getitem used to build an expression, for example m[key].contains("val") """ # Map Index should impact function definition, aggregate it return StringSetMapIndex(self, key) - def add_value(self, key: String, val: String): + def add_value(self, key: str | String, val: str | String): return StringSetMapAddValue(self, key, val) class StringSetMapAddValue(StringSetMap): - def __init__(self, m: StringSetMap, key, val): + def __init__(self, m: StringSetMap, key: str | String, val: str | String): self.m = m self.name = m.name + "_add_value" self.K = key @@ -1728,14 +1738,14 @@ def __str__(self): def traverse(self): return self.fn_map - def __getitem__(self, key: String): + def __getitem__(self, key: str | String): """ getitem used to build an expression, for example m[key].contains("val") """ # Map Index should impact function definition, aggregate it return StringSetMapIndex(self, key) - def add_value(self, key: String, val: String): + def add_value(self, key: str | String, val: str | String): return StringSetMapAddValue(self, key, val) @@ -1771,19 +1781,19 @@ def __str__(self): def traverse(self): return self.fn_map - def __getitem__(self, key: String): + def __getitem__(self, key: str | String): """ getitem used to build an expression, for example m[key].contains("val") """ # Map Index should impact function definition, aggregate it return StringSetMapIndex(self, key) - def add_value(self, key: String, val: String): + def add_value(self, key: str | String, val: str | String): return StringSetMapAddValue(self, key, val) class StringSetMapIndex: - def __init__(self, m: StringSetMap, key: String): + def __init__(self, m: StringSetMap, key: str | String): self.m = m self.key = key @@ -1979,18 +1989,17 @@ def traverse(self): return self.E.m.fn_map(z3.StringVal(self.E.key)) == self.V.traverse() -def collect_symbols(s, expr): - if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)): +def collect_symbols(s: set[str], expr): + if isinstance(expr, (String, Int, Duration, Bool, StringEnum)): s.add(expr.name) if isinstance(expr, MapIndex): s.add(expr.m.name + "." + expr.key) -def collect_names(s, expr): - if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)): - s.add(expr.name) - if isinstance(expr, MapIndex): - s.add(expr.m.name) +class SolverResult: + def __init__(self, solves, model=None): + self.solves = solves + self.model = model class Predicate: @@ -2006,19 +2015,14 @@ def __str__(self): def walk(self, fn): self.expr.walk(fn) - def verify(self): - solver = z3.Solver() - solver.add(self.expr.traverse()) - if solver.check() == z3.unsat: - raise ParameterError("our own predicate is unsolvable") - - def check(self, other): + def check(self, other: Predicate) -> SolverResult: """ check checks the predicate against conditions specified in another predicate. Both predicates should define """ # sanity check - to check two predicates, they should # define the same sets of symbols + # TODO: this is not checking for equality if not self.symbols.issubset(other.symbols): diff = self.symbols.difference(other.symbols) raise ParameterError( @@ -2029,23 +2033,24 @@ def check(self, other): return self.solves_with(other) - def query(self, other): + def query(self, other: Predicate) -> SolverResult: """ Query can only succeed if symbols in the query are a strict subset of all symbols used in the predicate being queried Query behaves like SQL, e.g. select * from users where name like 'a%'; """ if not other.symbols.issubset(self.symbols): - diff = self.symbols.difference(other.symbols) - return ( - False, - """check can not resolve ambiguity, query uses symbols %s that are not present in predicate %s, diff: %s, - query must be a subset of symbols of the predicate""" - % (self.symbols, other.symbols, diff), - ) + _ = self.symbols.difference(other.symbols) + # TODO: this should raise a ParameterError, like in check + # raise ParameterError( + # """query can not resolve ambiguity, query uses symbols %s that are not present in predicate %s, diff: %s, + # query must be a subset of symbols of the predicate""" + # % (self.symbols, other.symbols, diff), + # ) + return SolverResult(False) return self.solves_with(other) - def solve(self): + def solve(self) -> SolverResult: """ Solve solves predicate against itself """ @@ -2053,13 +2058,13 @@ def solve(self): e = self.expr.traverse() if self.loud: print("OUR EXPR: {}".format(e)) - solver.add(self.expr.traverse()) + solver.add(e) if solver.check() == z3.unsat: raise ParameterError("our own predicate is unsolvable") - return (True, solver.model()) + return SolverResult(True, model=solver.model()) - def solves_with(self, other): + def solves_with(self, other: Predicate) -> SolverResult: """ solves_with returns true if the predicate can be true with another predicate being true at the same time. @@ -2077,15 +2082,11 @@ def solves_with(self, other): print("THEIR EXPR: {}".format(o)) solver.add(o) - # TODO do a second pass to build a key checking function - # for both predicates! - self.expr.walk(functools.partial(collect_symbols, self.symbols)) - if solver.check() == z3.unsat: - return (False, "predicate is unsolvable against %s" % (other.expr,)) - return (True, solver.model()) + return SolverResult(False) + return SolverResult(True, model=solver.model()) - def equivalent(self, other): + def equivalent(self, other: Predicate): solver = z3.Solver() solver.add(z3.Distinct(self.expr.traverse(), other.expr.traverse())) result = solver.check() @@ -2106,8 +2107,8 @@ def simplify(self): then it removes the redundant one. """ - def split(vals, expr): - if type(expr) == And or type(expr) == Or: + def split(vals: list, expr): + if isinstance(expr, (And, Or)): vals.append(expr.L) vals.append(expr.R) diff --git a/predicate/solver/teleport.py b/predicate/solver/teleport.py index 1828c4c..c7a604b 100644 --- a/predicate/solver/teleport.py +++ b/predicate/solver/teleport.py @@ -14,51 +14,126 @@ limitations under the License. """ +# To allow using class type inside class: +# https://github.com/python/mypy/issues/3661#issuecomment-647945042 +from __future__ import annotations + import functools import operator +import re +import typing from collections.abc import Iterable +from enum import Enum import z3 -import re from . import ast from .errors import ParameterError def scoped(cls): - cls.scope = re.sub(r"([a-z])([A-Z])", r'\1_\2', cls.__name__).lower() + cls.scope = re.sub(r"([a-z])([A-Z])", r"\1_\2", cls.__name__).lower() return cls -class Options(ast.Predicate): +class SourceIp(Enum): """ - Options apply to some allow rules if they match + SourceIp defines the possible values for the source_ip option. + The values are ordered in increasing permissiveness. """ - max_session_ttl = ast.LtDuration("options.max_session_ttl") + PINNED = (0, "pinned") + UNPINNED = (1, "unpinned") - pin_source_ip = ast.Bool("options.pin_source_ip") + def __lt__(self, other): + if isinstance(other, SourceIp): + return self.value[0] < other.value[0] + raise TypeError( + "unsupported type {}, supported source ip only".format(type(other)) + ) - recording_mode = ast.StringEnum( - "options.recording_mode", set([(0, "best_effort"), (1, "strict")]) - ) - def __init__(self, expr): - ast.Predicate.__init__(self, expr) +class RecordingMode(Enum): + """ + RecordingMode defines the possible values for the recording_mode option. + The values are ordered in increasing permissiveness. + """ + + STRICT = (0, "strict") + BEST_EFFORT = (1, "best_effort") + + def __lt__(self, other): + if isinstance(other, RecordingMode): + return self.value[0] < other.value[0] + raise TypeError( + "unsupported type {}, supported recording mode only".format(type(other)) + ) -class OptionsSet: +class Options: """ - OptionsSet is a set of option expressions + Options specifies a list of Teleport options. """ - def __init__(self, *options: Options): - self.options = options + # TODO: ensure user cannot set invalid option values + # type checker is not doing its job ATM + def __init__( + self, + max_session_ttl: typing.Optional[ast.DurationLiteral] = None, + source_ip: typing.Optional[SourceIp] = None, + recording_mode: typing.Optional[RecordingMode] = None, + ): + self.max_session_ttl = max_session_ttl + self.source_ip = source_ip + self.recording_mode = recording_mode - def collect_like(self, other: ast.Predicate): - return [ - o for o in self.options if len(o.symbols.intersection(other.symbols)) > 0 - ] + def __str__(self): + return "options(max_session_ttl={}, source_ip={}, recording_mode={})".format( + self.max_session_ttl, self.source_ip, self.recording_mode + ) + + def empty(self) -> bool: + """ + empty returns true if all options are set to None. + """ + return functools.reduce( + operator.__and__, + map( + lambda o: not o, + [self.max_session_ttl, self.source_ip, self.recording_mode], + ), + ) + + @staticmethod + def combine(left: Options, right: Options) -> Options: + """ + combines combines two sets of options. + If the two sets are conflicting (e.g. one option has `ttl=10` and the other had `ttl=3`), + the least permissive option is picked (`ttl=3` in this case). + """ + + def combine_fun(left, right, fun): + if left and right: + return fun(left, right) + if left: + return left + if right: + return right + else: + return None + + def min_duration( + left: ast.DurationLiteral, right: ast.DurationLiteral + ) -> ast.DurationLiteral: + return ast.DurationLiteral(min(left.V, right.V)) + + return Options( + max_session_ttl=combine_fun( + left.max_session_ttl, right.max_session_ttl, min_duration + ), + source_ip=combine_fun(left.source_ip, right.source_ip, min), + recording_mode=combine_fun(left.recording_mode, right.recording_mode, min), + ) @scoped @@ -73,14 +148,6 @@ def __init__(self, expr): ast.Predicate.__init__(self, expr) # TODO check that the predicate is complete, has listed logins - def __and__(self, options: Options): - """ - This is a somewhat special case, options define max session TTL, - so this operator constructs a node predicate that contains options - that are relevant to node. - """ - return AccessNode(self.expr & options.expr) - class Node: """ @@ -121,8 +188,8 @@ def PolicyMap(expr): def try_login(policy_map_expr, traits_expr): p = ast.Predicate(policy_map_expr != ast.StringTuple(())) - ret, model = p.check(ast.Predicate(traits_expr)) - if not ret: + ret = p.check(ast.Predicate(traits_expr)) + if not ret.solves: return () out = [] @@ -131,7 +198,7 @@ def first(depth): for i in range(depth): vals = ast.StringListSort.cdr(vals) expr = ast.fn_string_list_first(vals) - return model.eval(expr).as_string() + return ret.model.eval(expr).as_string() depth = 0 while True: @@ -203,6 +270,7 @@ class JoinSession(ast.Predicate): def __init__(self, expr): ast.Predicate.__init__(self, expr) + class Session: """ Session is a Teleport session. @@ -246,8 +314,14 @@ class Rules: Rules are allow or deny rules """ - def __init__(self, *rules): - self.rules = rules or [] + def __init__(self, *rules: ast.Predicate): + self.rules = rules + + def empty(self) -> bool: + """ + empty returns true if the set of rules is empty. + """ + return len(self.rules) == 0 def collect_like(self, other: ast.Predicate): return [r for r in self.rules if r.__class__ == other.__class__] @@ -277,7 +351,6 @@ def t_expr(predicate): predicate, ( ast.String, - ast.LtDuration, ast.Duration, ast.StringList, ast.StringEnum, @@ -352,29 +425,47 @@ def t_expr(predicate): raise Exception(f"unknown predicate type: {type(predicate)}") +class TeleportSolverResult: + """ + TeleportSolverResult contains the following fields: + - solves: indicates whether the solver was able to solve the predicate(s) + - model: contains the Z3 model in case `solves == True` + - options: contains all options (from a policy set) combined + """ + + # TODO: can this be done with class inheritance? + def __init__(self, solver_result: ast.SolverResult, options: Options): + self.solves = solver_result.solves + self.model = solver_result.model + self.options = options + + class Policy: def __init__( self, name: str, - options: OptionsSet = None, - allow: Rules = None, - deny: Rules = None, + options: Options = Options(), + allow: Rules = Rules(), + deny: Rules = Rules(), loud: bool = True, ): - self.name = name if name == "": - raise ast.ParameterError("supply a valid name") - if allow is None and deny is None: - raise ast.ParameterError("provide either allow or deny") - self.allow = allow or Rules() - self.deny = deny or Rules() - self.options = options or OptionsSet() + raise ast.ParameterError("policy name cannot be empty") + if options.empty() and allow.empty() and deny.empty(): + raise ast.ParameterError( + "policy must contain either options, allow or deny rules" + ) + + self.name = name + self.allow = allow + self.deny = deny + self.options = options self.loud = loud - def check(self, other: ast.Predicate): + def check(self, other: ast.Predicate) -> TeleportSolverResult: return PolicySet([self], self.loud).check(other) - def query(self, other: ast.Predicate): + def query(self, other: ast.Predicate) -> TeleportSolverResult: return PolicySet([self], self.loud).query(other) def build_predicate(self, other: ast.Predicate): @@ -391,7 +482,7 @@ def export(self): } def group_rules(operator, rules): - scopes = {} + scopes: dict[str, list] = {} for rule in rules: if rule.scope not in scopes: scopes[rule.scope] = [] @@ -404,9 +495,15 @@ def group_rules(operator, rules): return scopes - if self.options.options: - options_rules = functools.reduce(operator.and_, self.options.options) - out["spec"]["options"] = t_expr(options_rules) + if self.options: + options = {} + if self.options.max_session_ttl: + options["max_session_ttl"] = self.options.max_session_ttl.V + if self.options.source_ip: + options["source_ip"] = self.options.source_ip.value[1] + if self.options.recording_mode: + options["recording_mode"] = self.options.recording_mode.value[1] + out["spec"]["options"] = options if self.allow.rules: out["spec"]["allow"] = group_rules(operator.or_, self.allow.rules) @@ -430,53 +527,31 @@ def __init__(self, policies: Iterable[Policy], loud: bool = True): def build_predicate(self, other: ast.Predicate) -> ast.Predicate: allow = [] deny = [] - options = [] for p in self.policies: allow.extend([e.expr for e in p.allow.collect_like(other)]) - # here we collect options from our policies that are mentioned in the predicate - # we are checking against, so our options are "sticky" - options.extend([o.expr for o in p.options.collect_like(other)]) deny.extend([ast.Not(e.expr) for e in p.deny.collect_like(other)]) - # all options should match - # TODO: how to deal with Teleport options logic that returns min out of two? - # probably < equation will solve this problem - allow_expr = None - options_expr = None - # if option predicates are present, apply them as mandatory - # to the allow expression, so allow is matching only if options - # match as well. - if options: - options_expr = functools.reduce(operator.and_, options) - if allow: + if allow and deny: allow_expr = functools.reduce(operator.or_, allow) - if options: - allow_expr = allow_expr & options_expr - if deny: deny_expr = functools.reduce(operator.and_, deny) - - if not allow and not deny: - raise ast.ParameterError("policy set is empty {}") - pr = None - if not deny: - pr = ast.Predicate(allow_expr, self.loud) - elif not allow_expr: - pr = ast.Predicate(deny_expr, self.loud) + return ast.Predicate(allow_expr & deny_expr, self.loud) + elif allow: + allow_expr = functools.reduce(operator.or_, allow) + return ast.Predicate(allow_expr, self.loud) + elif deny: + deny_expr = functools.reduce(operator.and_, deny) + return ast.Predicate(deny_expr, self.loud) else: - pr = ast.Predicate(allow_expr & deny_expr, self.loud) - return pr + raise ast.ParameterError("policy set is empty") - def check(self, other: ast.Predicate): - return self.build_predicate(other).check(other) + def combine_options(self) -> Options: + options = [p.options for p in self.policies] + return functools.reduce(Options.combine, options, Options()) - def query(self, other: ast.Predicate): - return self.build_predicate(other).query(other) + def check(self, other: ast.Predicate) -> TeleportSolverResult: + ret = self.build_predicate(other).check(other) + return TeleportSolverResult(ret, self.combine_options()) - def names(self): - """ - Names returns names in the policy set - """ - s = set() - for p in self.policies: - s.add(p.name) - return s + def query(self, other: ast.Predicate) -> TeleportSolverResult: + ret = self.build_predicate(other).query(other) + return TeleportSolverResult(ret, self.combine_options()) diff --git a/predicate/solver/test/test_ast.py b/predicate/solver/test/test_ast.py index 2913fba..4d57967 100644 --- a/predicate/solver/test/test_ast.py +++ b/predicate/solver/test/test_ast.py @@ -70,9 +70,8 @@ def test_check_equiv(self): p = Predicate(User.team == "stage") # This predicate is unsolvable, contradicts our main predicate - ret, msg = p.check(Predicate(User.team != "stage")) - assert ret is False - assert "unsolvable" in msg + ret = p.check(Predicate(User.team != "stage")) + assert ret.solves is False # Two predicates are equivalent, if they return the same results, # equivalency is not equality, it's more broad. @@ -92,8 +91,8 @@ def test_two_symbols(self): """ p = Predicate(Server.env == User.team) - ret, _ = p.check(Predicate((Server.env == "prod") & (User.team == "prod"))) - assert ret is True, "this predicate holds when both values match" + ret = p.check(Predicate((Server.env == "prod") & (User.team == "prod"))) + assert ret.solves is True, "this predicate holds when both values match" # user is not defined in the other predicate the check should fail # as we haven't defined all symbols @@ -119,7 +118,7 @@ def test_queries(self): p = Predicate((Server.env == User.team) & (Server.login == User.name)) # Bob can access server with label prod with his name - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.team == "prod") @@ -127,26 +126,26 @@ def test_queries(self): & (Server.login == "bob") ) ) - assert ret is True + assert ret.solves is True # Query helps to ask more broad questions, e.g. can bob access servers labeled as "prod"? - ret, _ = p.query( + ret = p.query( Predicate( (Server.env == "prod") & (User.team == "prod") & (User.name == "bob") ) ) - assert ret is True, "Bob can access servers labeled as prod" + assert ret.solves is True, "Bob can access servers labeled as prod" # Can bob access servers labeled as stage? - ret, _ = p.query( + ret = p.query( Predicate( (Server.env == "stage") & (User.team == "prod") & (User.name == "bob") ) ) - assert ret is False, "Bob can't access servers labeled as stage" + assert ret.solves is False, "Bob can't access servers labeled as stage" # Bob can't access server prod with someone else's name - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.team == "prod") @@ -154,10 +153,10 @@ def test_queries(self): & (Server.login == "jim") ) ) - assert ret is False, "Bob can't access prod with someone else's username" + assert ret.solves is False, "Bob can't access prod with someone else's username" # Bob can't access server prod if Bob's team is not valid - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.team == "stage") @@ -165,7 +164,7 @@ def test_queries(self): & (Server.login == "bob") ) ) - assert ret is False, "Bob can't access servers of not his team" + assert ret.solves is False, "Bob can't access servers of not his team" def test_invariants(self): """ @@ -183,7 +182,7 @@ def test_invariants(self): p = Predicate(root | general) # Alice can access prod as root - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.name == "alice") @@ -191,10 +190,10 @@ def test_invariants(self): & (User.team == "noop") ) ) - assert ret is True, "Alice can access prod as root" + assert ret.solves is True, "Alice can access prod as root" # Bob can access stage as his name - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "stage") & (User.name == "bob") @@ -202,10 +201,10 @@ def test_invariants(self): & (User.team == "stage") ) ) - assert ret is True, "Bob can access stage with his name" + assert ret.solves is True, "Bob can access stage with his name" # Bob can't access prod as root - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.name == "bob") @@ -213,20 +212,20 @@ def test_invariants(self): & (User.team == "prod") ) ) - assert ret is False, "Bob can't access prod as root" + assert ret.solves is False, "Bob can't access prod as root" # Queries: - ret, _ = p.query(Predicate((Server.env == "prod") & (Server.login == "root"))) - assert ret is True, "Is it possible for someone access prod as root" + ret = p.query(Predicate((Server.env == "prod") & (Server.login == "root"))) + assert ret.solves is True, "Is it possible for someone access prod as root" # Is it possible for bob to access prod as root? # this is invariant we can verify with call to query - ret, _ = p.query( + ret = p.query( Predicate( (Server.env == "prod") & (Server.login == "root") & (User.name == "bob") ) ) - assert ret is False, "Bob can't access prod as root" + assert ret.solves is False, "Bob can't access prod as root" # This is a more broad, and more strict invariant: # @@ -235,7 +234,7 @@ def test_invariants(self): # This could be a linter checker to make sure that whatever rules # people define, they can't access as root. # - ret, _ = p.query( + ret = p.query( Predicate( (Server.env == "prod") & (Server.login == "root") @@ -243,7 +242,7 @@ def test_invariants(self): ) ) assert ( - ret is False + ret.solves is False ), "Is it possible for anyone who is not alice to access prod as root?" # Let's try the case that contradicts the predicate @@ -256,7 +255,7 @@ def test_invariants(self): p = Predicate(root | violation) # Jim can access prod as root - ret, _ = p.check( + ret = p.check( Predicate( (Server.env == "prod") & (User.name == "jim") @@ -264,38 +263,38 @@ def test_invariants(self): & (User.team == "noop") ) ) - assert ret is False, "Jim can't access prod as root" + assert ret.solves is False, "Jim can't access prod as root" def test_regex(self): p = Predicate(parse_regex("stage-.*").matches(User.team)) - ret, _ = p.check(Predicate(User.team == "stage-test")) - assert ret is True, "prefix patterns match" + ret = p.check(Predicate(User.team == "stage-test")) + assert ret.solves is True, "prefix patterns match" - ret, _ = p.check(Predicate(User.team == "stage-other")) - assert ret is True, "prefix patterns match" + ret = p.check(Predicate(User.team == "stage-other")) + assert ret.solves is True, "prefix patterns match" - ret, _ = p.check(Predicate(User.team == "prod-test")) - assert ret is False, "prefix pattern mismatch" + ret = p.check(Predicate(User.team == "prod-test")) + assert ret.solves is False, "prefix pattern mismatch" def test_concat(self): p = Predicate(Server.login == User.name + "-login") - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "alice-login") & (User.name == "alice")) ) - assert ret is True, "pattern matches suffix" + assert ret.solves is True, "pattern matches suffix" p = Predicate(Server.login == "login-" + User.name) - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "login-alice") & (User.name == "alice")) ) - assert ret is True, "pattern matches prefix" + assert ret.solves is True, "pattern matches prefix" p = Predicate(Server.login == "login-" + User.name + "-user") - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "login-alice-user") & (User.name == "alice")) ) - assert ret is True, "pattern matches suffix and prefix" + assert ret.solves is True, "pattern matches suffix and prefix" # TODOs: # https://github.com/Z3Prover/z3/blob/9f9543ef698adc77252ed366e6d85cc71e4b8c89/src/ast/rewriter/seq_axioms.cpp#L1044 @@ -306,69 +305,69 @@ def test_delimiter(self): Test splitting at delimiter. """ p = Predicate(Server.login == User.name.before_delimiter("@")) - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "alice") & (User.name == "alice@example.com")) ) - assert ret is True, "splitting before delimiter works" + assert ret.solves is True, "splitting before delimiter works" - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "") & (User.name == "alice-example.com")) ) - assert ret is True, "delimiter not present, string renders to empty" + assert ret.solves is True, "delimiter not present, string renders to empty" p = Predicate(Server.login == User.name.after_delimiter("@")) - ret, _ = p.check( + ret = p.check( Predicate( (Server.login == "example.com") & (User.name == "alice@example.com") ) ) - assert ret is True, "splitting after delimiter works" + assert ret.solves is True, "splitting after delimiter works" - ret, _ = p.check( + ret = p.check( Predicate((Server.login == "") & (User.name == "alice-example.com")) ) - assert ret is True, "delimiter not present, string renders to empty" + assert ret.solves is True, "delimiter not present, string renders to empty" def test_replace(self): """ Test replace string characters. """ p = Predicate(Server.login == User.name.replace("@", "-")) - ret, z = p.check( + ret = p.check( Predicate( (Server.login == "alice-example.com") & (User.name == "alice@example.com") ) ) - assert ret is True, "replacing works" + assert ret.solves is True, "replacing works" - ret, _ = p.check( + ret = p.check( Predicate( (Server.login == "alice+example.com") & (User.name == "alice+example.com") ) ) - assert ret is True, "character not present, no effect" + assert ret.solves is True, "character not present, no effect" def test_upper_lower(self): """ Test upper lower case. """ p = Predicate(Server.login == User.name.upper()) - ret, z = p.check(Predicate((Server.login == "ALICE") & (User.name == "AlicE"))) - assert ret is True, "uppercase works" + ret = p.check(Predicate((Server.login == "ALICE") & (User.name == "AlicE"))) + assert ret.solves is True, "uppercase works" p = Predicate(Server.login == User.name.lower()) - ret, z = p.check(Predicate((Server.login == "alice") & (User.name == "AlicE"))) - assert ret is True, "lowercase works" + ret = p.check(Predicate((Server.login == "alice") & (User.name == "AlicE"))) + assert ret.solves is True, "lowercase works" def test_string_list_init(self): # filter example - we only keep fruits by copying them over fruits = StringList("fruits", ("apple", "strawberry", "banana")) p = Predicate((fruits == ("apple", "strawberry", "banana"))) - ret, _ = p.solve() - assert ret is True, "values match" + ret = p.solve() + assert ret.solves is True, "values match" def test_string_list_with_if(self): basket = StringList("basket") @@ -382,48 +381,44 @@ def test_string_list_with_if(self): ), ) - p = Predicate( + ret = Predicate( (basket == ("strawberry", "apple", "banana")) & (fruits == ("blueberry", "strawberry", "apple", "banana")) - ) - ret, _ = p.solve() - assert ret is True, "blueberry was added" + ).solve() + assert ret.solves is True, "blueberry was added" - p = Predicate( + ret = Predicate( (basket == ("apple", "strawberry")) & (fruits == ("apple", "strawberry")) - ) - p.solve() - assert ret is True, "blueberry was not added" + ).solve() + assert ret.solves is True, "blueberry was not added" def test_string_list_transform_value(self): external = StringList("external") # transform example - we reference another variable and transform traits = StringList("traits", external.replace("@", "-")) - p = Predicate( + ret = Predicate( (external == ("alice@wonderland.local",)) & (traits == ("alice-wonderland.local",)) - ) - ret, _ = p.solve() - assert ret is True, "transformation has been applied" + ).solve() + assert ret.solves is True, "transformation has been applied" - p = Predicate((external == ()) & (traits == ())) - ret, _ = p.solve() - assert ret is True, "transformation on empty list is empty" + ret = Predicate((external == ()) & (traits == ())).solve() + assert ret.solves is True, "transformation on empty list is empty" def test_string_set_map_contains(self): traits = StringSetMap("mymap") p = Predicate(traits["key"].contains("potato")) - ret, _ = p.check( + ret = p.check( Predicate( (traits["key"] == ("apple", "potato", "banana")) | (traits["key"] == ("strawberry",)) ) ) - assert ret is True, "values match" + assert ret.solves is True, "values match" - ret, _ = p.check(Predicate(traits["key"] == ("apple", "banana"))) - assert ret is False, "values don't match" + ret = p.check(Predicate(traits["key"] == ("apple", "banana"))) + assert ret.solves is False, "values don't match" def test_string_set_map_contains_regex(self): traits = StringSetMap( @@ -432,9 +427,8 @@ def test_string_set_map_contains_regex(self): "groups": ("fruit-apple", "veggie-potato", "fruit-banana"), }, ) - p = Predicate(traits["groups"].contains_regex("fruit-.*")) - ret, _ = p.solve() - assert ret is True, "values match regular expression" + ret = Predicate(traits["groups"].contains_regex("fruit-.*")).solve() + assert ret.solves is True, "values match regular expression" traits = StringSetMap( "mymap", @@ -442,10 +436,8 @@ def test_string_set_map_contains_regex(self): "groups": ("apple", "potato", "banana"), }, ) - with pytest.raises(ParameterError) as exc: - p = Predicate(traits["groups"].contains_regex("fruit-apple")) - ret, _ = p.solve() - assert "unsolvable" in str(exc.value) + with pytest.raises(ParameterError, match="unsolvable"): + _ = Predicate(traits["groups"].contains_regex("fruit-apple")).solve() def test_string_set_map_len(self): traits = StringSetMap( @@ -454,32 +446,29 @@ def test_string_set_map_len(self): "groups": ("fruit-apple", "veggie-potato", "fruit-banana"), }, ) - p = Predicate(traits["groups"].len() == 3) - ret, _ = p.solve() - assert ret is True, "counts properly" + ret = Predicate(traits["groups"].len() == 3).solve() + assert ret.solves is True, "counts properly" - p = Predicate(traits["missing"].len() == 0) - ret, _ = p.solve() - assert ret is True, "missing key results in empty count" + ret = Predicate(traits["missing"].len() == 0).solve() + assert ret.solves is True, "missing key results in empty count" with pytest.raises(ParameterError): - p = Predicate(traits["missing"].len() == 1) - ret, _ = p.solve() + _ = Predicate(traits["missing"].len() == 1).solve() def test_string_set_map_len_undefined(self): approvals = StringSetMap("mymap") p = Predicate(approvals["my-policy"].len() > 2) - ret, _ = p.solve() - assert ret is True, "predicate solves" + ret = p.solve() + assert ret.solves is True, "predicate solves" - ret, _ = p.check(Predicate(approvals["my-policy"].len() == 3)) - assert ret is True, "predicate solves" + ret = p.check(Predicate(approvals["my-policy"].len() == 3)) + assert ret.solves is True, "predicate solves" - ret, _ = p.check(Predicate(approvals["my-policy"].len() == 2)) - assert ret is False, "predicate fails to solve" + ret = p.check(Predicate(approvals["my-policy"].len() == 2)) + assert ret.solves is False, "predicate fails to solve" - ret, _ = p.query(Predicate(approvals["my-policy"].len() == 2)) - assert ret is False, "predicate fails to solve with query as well" + ret = p.query(Predicate(approvals["my-policy"].len() == 2)) + assert ret.solves is False, "predicate fails to solve with query as well" def test_string_set_map_first(self): traits = StringSetMap( @@ -488,9 +477,8 @@ def test_string_set_map_first(self): "groups": ("fruit-apple", "veggie-potato", "fruit-banana"), }, ) - p = Predicate(traits["groups"].first() == "fruit-apple") - ret, _ = p.solve() - assert ret is True, "gets first value" + ret = Predicate(traits["groups"].first() == "fruit-apple").solve() + assert ret.solves is True, "gets first value" traits = StringSetMap( "mymap", @@ -499,9 +487,8 @@ def test_string_set_map_first(self): }, ) - p = Predicate(traits["groups"].first() == "potato") - ret, _ = p.solve() - assert ret is True, "gets first non-empty value" + ret = Predicate(traits["groups"].first() == "potato").solve() + assert ret.solves is True, "gets first non-empty value" traits = StringSetMap( "mymap", @@ -510,9 +497,8 @@ def test_string_set_map_first(self): }, ) - p = Predicate(traits["groups"].first() == "") - ret, _ = p.solve() - assert ret is True, "returns empty string if no value" + ret = Predicate(traits["groups"].first() == "").solve() + assert ret.solves is True, "returns empty string if no value" def test_string_set_map_add_value(self): traits = StringSetMap("mymap") @@ -520,13 +506,11 @@ def test_string_set_map_add_value(self): # this predicate is always true, we always add strawberry traits.add_value("fruits", "strawberry")["fruits"].contains("strawberry") ) - ret, _ = p.check( - Predicate(traits["fruits"] == ("strawberry", "apple", "banana")) - ) - assert ret is True, "values match with strawberry" + ret = p.check(Predicate(traits["fruits"] == ("strawberry", "apple", "banana"))) + assert ret.solves is True, "values match with strawberry" - ret, _ = p.check(Predicate(traits["fruits"] == ("apple", "banana"))) - assert ret is True, "values match even if strawberry is missing" + ret = p.check(Predicate(traits["fruits"] == ("apple", "banana"))) + assert ret.solves is True, "values match even if strawberry is missing" def test_string_set_map_with_values(self): external = StringSetMap("external") @@ -537,21 +521,18 @@ def test_string_set_map_with_values(self): "our-fruits": external["fruits"], }, ) - p = Predicate( + ret = Predicate( (external["fruits"] == ("strawberry", "apple", "banana")) & (traits["our-fruits"].contains("strawberry")) - ) - ret, _ = p.solve() - assert ret is True, "values match with strawberry" + ).solve() + assert ret.solves is True, "values match with strawberry" - with pytest.raises(ParameterError) as exc: + with pytest.raises(ParameterError, match="unsolvable"): # this predicate is unsolvable - p = Predicate( + _ = Predicate( (external["fruits"] == ("apple", "banana")) & (traits["our-fruits"].contains("strawberry")) - ) - p.solve() - assert "unsolvable" in str(exc.value) + ).solve() def test_string_set_map_list_init(self): # filter example - we only keep fruits by copying them over @@ -562,12 +543,11 @@ def test_string_set_map_list_init(self): "veggies": ("potato",), }, ) - p = Predicate( + ret = Predicate( (traits["fruits"] == ("apple", "strawberry", "banana")) & (traits["veggies"] == ("potato",)) - ) - ret, _ = p.solve() - assert ret is True, "values match with strawberry" + ).solve() + assert ret.solves is True, "values match with strawberry" def test_string_set_map_remove_keys(self): # filter example - we only keep fruits by copying them over @@ -578,12 +558,11 @@ def test_string_set_map_remove_keys(self): "veggies": ("potato",), }, ).remove_keys("veggies") - p = Predicate( + ret = Predicate( (traits["fruits"] == ("apple", "strawberry", "banana")) & (traits["veggies"] == ()) - ) - ret, _ = p.solve() - assert ret is True, "veggies is empty" + ).solve() + assert ret.solves is True, "veggies is empty" def test_string_set_map_overwrite(self): traits = StringSetMap( @@ -592,12 +571,11 @@ def test_string_set_map_overwrite(self): "fruits": ("apple", "strawberry", "banana"), }, ).overwrite({"veggies": ("potato",)}) - p = Predicate( + ret = Predicate( (traits["fruits"] == ("apple", "strawberry", "banana")) & (traits["veggies"] == ("potato",)) - ) - ret, _ = p.solve() - assert ret is True, "overwritten with added potato values" + ).solve() + assert ret.solves is True, "overwritten with added potato values" def test_string_set_map_with_if(self): external = StringSetMap("external") @@ -612,19 +590,17 @@ def test_string_set_map_with_if(self): ) }, ) - p = Predicate( + ret = Predicate( (external["fruits"] == ("strawberry", "apple", "banana")) & (traits["our-fruits"] == ("blueberry", "strawberry", "apple", "banana")) - ) - ret, _ = p.solve() - assert ret is True, "blueberry was added" + ).solve() + assert ret.solves is True, "blueberry was added" - p = Predicate( + ret = Predicate( (external["fruits"] == ("apple", "strawberry")) & (traits["our-fruits"] == ("apple", "strawberry")) - ) - p.solve() - assert ret is True, "blueberry was not added" + ).solve() + assert ret.solves is True, "blueberry was not added" def test_string_set_map_transform_value(self): external = StringSetMap("external") @@ -635,16 +611,14 @@ def test_string_set_map_transform_value(self): "login": external["email"].replace("@", "-"), }, ) - p = Predicate( + ret = Predicate( (external["email"] == ("alice@wonderland.local",)) & (traits["login"] == ("alice-wonderland.local",)) - ) - ret, _ = p.solve() - assert ret is True, "transformation has been applied" + ).solve() + assert ret.solves is True, "transformation has been applied" - p = Predicate((external["email"] == ()) & (traits["login"] == ())) - ret, _ = p.solve() - assert ret is True, "transformation on empty list is empty" + ret = Predicate((external["email"] == ()) & (traits["login"] == ())).solve() + assert ret.solves is True, "transformation on empty list is empty" def test_string_map(self): """ @@ -655,58 +629,60 @@ def test_string_map(self): # maps could be part of the predicate p = Predicate(m["key"] == "val") - ret, _ = p.query(Predicate(m["key"] == "val")) - assert ret is True, "values match" + ret = p.query(Predicate(m["key"] == "val")) + assert ret.solves is True, "values match" - ret, _ = p.query(Predicate(m["key"] != "val")) - assert ret is False, "values don't match" + ret = p.query(Predicate(m["key"] != "val")) + assert ret.solves is False, "values don't match" # check raises parameter error because key 'key' is missing with pytest.raises(ParameterError): - ret, _ = p.check(Predicate(m["missing"] == "potato")) + ret = p.check(Predicate(m["missing"] == "potato")) - ret, _ = p.query(Predicate(m["missing"] == "potato")) + ret = p.query(Predicate(m["missing"] == "potato")) assert ( - ret is False + ret.solves is False ), "query that does not have matching keys should fail, otherwise you would get query match both predicates with keys missing and exact matches" - ret, _ = p.solves_with(Predicate(m["missing"] == "potato")) + ret = p.solves_with(Predicate(m["missing"] == "potato")) assert ( - ret is True + ret.solves is True ), "contrary to query statement above, this predicate does not contradict the statement m['key'] == 'val', both can be true at the same time" # multiple key-value checks m = StringMap("mymap") p = Predicate((m["key"] == "val") & (m["key-2"] == "val-2")) - ret, _ = p.query(Predicate(m["key"] == "val")) - assert ret is True, "query on subset of keys is successful" + ret = p.query(Predicate(m["key"] == "val")) + assert ret.solves is True, "query on subset of keys is successful" # check will raise error when there is ambiguity because # not all keys have been specified with pytest.raises(ParameterError): - ret, _ = p.check(Predicate(m["key"] == "val")) + _ = p.check(Predicate(m["key"] == "val")) - ret, _ = p.check(Predicate((m["key"] == "val") & (m["key-2"] == "val-2"))) - assert ret is True, "check, all keys and values, match" + ret = p.check(Predicate((m["key"] == "val") & (m["key-2"] == "val-2"))) + assert ret.solves is True, "check, all keys and values, match" - ret, _ = p.check( + ret = p.check( Predicate( (m["key"] == "val") & (m["key-2"] == "val-2") & (m["key-3"] == "val") ) ) - assert ret is True, "check is OK when the right side has superset of keys" + assert ( + ret.solves is True + ), "check is OK when the right side has superset of keys" - ret, _ = p.check(Predicate((m["key"] == "val") & (m["key-2"] == "wrong"))) - assert ret is False, "check fails when some values don't match" + ret = p.check(Predicate((m["key"] == "val") & (m["key-2"] == "wrong"))) + assert ret.solves is False, "check fails when some values don't match" - ret, _ = p.check( + ret = p.check( Predicate( (m["key"] == "val") & (m["key-2"] == "wrong") & (m["key-3"] == "val") ) ) assert ( - ret is False + ret.solves is False ), "check fails when the right side has superset of keys, but values don't match" def test_string_enum(self): @@ -716,17 +692,15 @@ def test_string_enum(self): e = StringEnum("fruits", set(["banana", "apple", "strawberry"])) # enums could be part of the predicate - p = Predicate( - (e == 'apple') | (e == 'banana') - ) - ret, _ = p.query(Predicate(e == "banana")) - assert ret is True, "value can be banana" + p = Predicate((e == "apple") | (e == "banana")) + ret = p.query(Predicate(e == "banana")) + assert ret.solves is True, "value can be banana" - ret, _ = p.query(Predicate((e == "strawberry"))) - assert ret is False, "value cannot be strawberry" + ret = p.query(Predicate((e == "strawberry"))) + assert ret.solves is False, "value cannot be strawberry" - ret, _ = p.query(Predicate((e != "apple") & (e != "banana"))) - assert ret is False, "value must be one of apple and banana" + ret = p.query(Predicate((e != "apple") & (e != "banana"))) + assert ret.solves is False, "value must be one of apple and banana" def test_string_enum_comparison(self): """ @@ -737,17 +711,17 @@ def test_string_enum_comparison(self): # enums could be part of the predicate and can provide constraints p = Predicate((e > "apple") | (e == "apple")) - ret, _ = p.query(Predicate(e == "apple")) - assert ret is True, "the value can be apple" + ret = p.query(Predicate(e == "apple")) + assert ret.solves is True, "the value can be apple" - ret, _ = p.query(Predicate(e == "watermelon")) - assert ret is True, "the value can be watermelon" + ret = p.query(Predicate(e == "watermelon")) + assert ret.solves is True, "the value can be watermelon" - ret, _ = p.query(Predicate(e == "strawberry")) - assert ret is False, "the value cannot be strawberry" + ret = p.query(Predicate(e == "strawberry")) + assert ret.solves is False, "the value cannot be strawberry" - ret, _ = p.query(Predicate((e != "apple") & (e != "watermelon"))) - assert ret is False, "value must be one of apple and watermelon" + ret = p.query(Predicate((e != "apple") & (e != "watermelon"))) + assert ret.solves is False, "value must be one of apple and watermelon" # ensure that all inequalities can only be specified using valid enum values with pytest.raises(TypeError, match=r"is not one of"): @@ -769,8 +743,8 @@ def test_string_map_regex(self): # maps could be part of the predicate p = Predicate(parse_regex("env-.*").matches(m["key"])) - ret, _ = p.query(Predicate(m["key"] == "env-prod")) - assert ret is True + ret = p.query(Predicate(m["key"] == "env-prod")) + assert ret.solves is True def test_string_tuple(self): """ @@ -778,8 +752,8 @@ def test_string_tuple(self): """ t = StringTuple(["banana", "potato", "apple"]) p = Predicate(t.contains("banana")) - ret, _ = p.query(Predicate(t.contains("apple"))) - assert ret is True + ret = p.query(Predicate(t.contains("apple"))) + assert ret.solves is True def test_regex_tuple(self): """ @@ -787,8 +761,8 @@ def test_regex_tuple(self): """ t = regex_tuple(["banana-.*", "potato-.*", "apple-.*"]) p = Predicate(t.matches("banana-smoothie")) - ret, _ = p.query(Predicate(t.matches("apple-smoothie"))) - assert ret is True + ret = p.query(Predicate(t.matches("apple-smoothie"))) + assert ret.solves is True def test_int(self): """ @@ -796,16 +770,16 @@ def test_int(self): """ p = Predicate(Request.approve == 1) - ret, _ = p.check(Predicate(Request.approve == 1)) - assert ret is True, "solves with simple equality check" + ret = p.check(Predicate(Request.approve == 1)) + assert ret.solves is True, "solves with simple equality check" p = Predicate((Request.approve > 1) & (Request.approve < 3)) - ret, _ = p.check(Predicate(Request.approve == 2)) - assert ret is True, "solves with simple boundary check" + ret = p.check(Predicate(Request.approve == 2)) + assert ret.solves is True, "solves with simple boundary check" - ret, _ = p.check(Predicate(Request.approve == 5)) - assert ret is False, "solves with simple boundary check" + ret = p.check(Predicate(Request.approve == 5)) + assert ret.solves is False, "solves with simple boundary check" def test_duration(self): """ @@ -815,19 +789,19 @@ def test_duration(self): Options.ttl == Duration.new(hours=5), ) - ret, _ = p.check(Predicate(Options.ttl == Duration.new(hours=5))) - assert ret is True, "solves with simple equality check" + ret = p.check(Predicate(Options.ttl == Duration.new(hours=5))) + assert ret.solves is True, "solves with simple equality check" p = Predicate( (Options.ttl > Duration.new(seconds=10)) & (Options.ttl < Duration.new(hours=5)) ) - ret, _ = p.check(Predicate(Options.ttl == Duration.new(hours=3))) - assert ret is True, "solves with simple boundary check" + ret = p.check(Predicate(Options.ttl == Duration.new(hours=3))) + assert ret.solves is True, "solves with simple boundary check" - ret, _ = p.check(Predicate(Options.ttl == Duration.new(hours=6))) - assert ret is False, "solves with simple boundary check" + ret = p.check(Predicate(Options.ttl == Duration.new(hours=6))) + assert ret.solves is False, "solves with simple boundary check" def test_bool(self): """ @@ -837,11 +811,11 @@ def test_bool(self): Options.pin_source_ip == True, ) - ret, _ = p.check(Predicate(Options.pin_source_ip == True)) - assert ret is True, "solves with simple equality check" + ret = p.check(Predicate(Options.pin_source_ip == True)) + assert ret.solves is True, "solves with simple equality check" - ret, _ = p.check(Predicate(Options.pin_source_ip == False)) - assert ret is False, "solves with simple boundary check" + ret = p.check(Predicate(Options.pin_source_ip == False)) + assert ret.solves is False, "solves with simple boundary check" def test_select(self): external = StringSetMap("external") @@ -854,15 +828,15 @@ def test_select(self): # Case(external['groups'].matches_regexp('test-.*'), (external['groups'].replace('admin-', 'role-').add('bob'))) ) - ret, _ = Predicate( + ret = Predicate( (s == ("admin",)) & (external["groups"] == ("admin", "other")) ).solve() - assert ret is True, "simple match works" + assert ret.solves is True, "simple match works" - ret, _ = Predicate( + ret = Predicate( (s == ()) & (external["groups"] == ("nomatch", "other")) ).solve() - assert ret is True, "no match results in default" + assert ret.solves is True, "no match results in default" # once you have a select defined, you can specify set of roles and policies @@ -875,15 +849,15 @@ def test_select_regex(self): Default(()), ) - ret, _ = Predicate( + ret = Predicate( (s == ("admin",)) & (external["groups"] == ("admin-test", "other")) ).solve() - assert ret is True, "simple match works" + assert ret.solves is True, "simple match works" - ret, _ = Predicate( + ret = Predicate( (s == ()) & (external["groups"] == ("nomatch", "other")) ).solve() - assert ret is True, "no match results in default" + assert ret.solves is True, "no match results in default" def test_select_regex_replace(self): external = StringSetMap("external") @@ -897,14 +871,14 @@ def test_select_regex_replace(self): Default(external["groups"]), ) - ret, _ = Predicate( + ret = Predicate( (s == ("ext-test", "ext-prod")) & (external["groups"] == ("admin-test", "admin-prod")) ).solve() - assert ret is True, "match and replace works" + assert ret.solves is True, "match and replace works" - ret, _ = Predicate( + ret = Predicate( (s == ("dev-test", "dev-prod")) & (external["groups"] == ("dev-test", "dev-prod")) ).solve() - assert ret is True, "match and replace works default value" + assert ret.solves is True, "match and replace works default value" diff --git a/predicate/solver/test/test_aws.py b/predicate/solver/test/test_aws.py index 5210c19..9b68a7c 100644 --- a/predicate/solver/test/test_aws.py +++ b/predicate/solver/test/test_aws.py @@ -7,43 +7,43 @@ class TestAWS: def test_aws_allow_policy(self, mixed_statement_policy): p = Predicate(aws.policy(mixed_statement_policy)) - ret, _ = p.check( + ret = p.check( Predicate( (aws.Action.resource == "arn:aws:s3:::example_bucket") & (aws.Action.action == "s3:ListBucket") ) ) - assert ret is True + assert ret.solves is True def test_aws_policy(self, s3_policy): p = Predicate(aws.policy(s3_policy)) # get bucket location on any bucket works - ret, d = p.check( + ret = p.check( Predicate( (aws.Action.resource == "arn:aws:s3:::example_bucket") & (aws.Action.action == "s3:GetBucketLocation") ) ) - assert ret is True + assert ret.solves is True # listing bucket logs is not allowed - ret, d = p.check( + ret = p.check( Predicate( (aws.Action.resource == "arn:aws:s3:::example_bucket/logs") & (aws.Action.action == "s3:GetObject") ) ) - assert ret is False + assert ret.solves is False # can get a random doc from a bucket - ret, d = p.check( + ret = p.check( Predicate( (aws.Action.resource == "arn:aws:s3:::carlossalazar/document") & (aws.Action.action == "s3:GetObject") ) ) - assert ret is True + assert ret.solves is True @pytest.fixture diff --git a/predicate/solver/test/test_teleport.py b/predicate/solver/test/test_teleport.py index 1856ebe..f44576a 100644 --- a/predicate/solver/test/test_teleport.py +++ b/predicate/solver/test/test_teleport.py @@ -11,18 +11,19 @@ StringSetMap, ) from ..teleport import ( + AccessNode, + JoinSession, LoginRule, - User, Node, - AccessNode, Options, - OptionsSet, Policy, PolicyMap, PolicySet, + RecordingMode, Rules, - JoinSession, Session, + SourceIp, + User, map_policies, try_login, ) @@ -33,45 +34,55 @@ def test_node(self): p = Policy( name="test", allow=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "root") & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) ) - assert ret is True, "check works on a superset" + assert ret.solves is True, "check works on a superset" def test_allow_policy_set(self): a = Policy( name="a", allow=Rules( - AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "ubuntu") & (Node.labels["env"] == "prod") + ), ), ) b = Policy( name="b", allow=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "stage")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "stage") + ), ), ) s = PolicySet([a, b]) - ret, _ = s.check( + ret = s.check( AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "prod")) ) - assert ret is True, "check works on a subset" + assert ret.solves is True, "check works on a subset" - ret, _ = s.check(AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "stage"))) - assert ret is True, "check works on a subset" + ret = s.check( + AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "stage")) + ) + assert ret.solves is True, "check works on a subset" - ret, _ = s.check(AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod"))) - assert ret is False, "rejects the merge" + ret = s.check( + AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")) + ) + assert ret.solves is False, "rejects the merge" def test_deny_policy_set(self): a = Policy( @@ -87,65 +98,169 @@ def test_deny_policy_set(self): b = Policy( name="b", deny=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) s = PolicySet([a, b]) - ret, _ = s.check(AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod"))) - assert ret is False, "deny in a set overrides allow" + ret = s.check( + AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")) + ) + assert ret.solves is False, "deny in a set overrides allow" - ret, _ = s.check( + ret = s.check( AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "prod")) ) - assert ret is True, "non-denied part of allow is OK" + assert ret.solves is True, "non-denied part of allow is OK" + + def test_empty_policy(self): + with pytest.raises(ParameterError, match="policy name cannot be empty"): + _ = Policy(name="") + + with pytest.raises( + ParameterError, + match="policy must contain either options, allow or deny rules", + ): + _ = Policy(name="a") + + with pytest.raises( + ParameterError, + match="policy must contain either options, allow or deny rules", + ): + _ = Policy(name="a", options=Options(), allow=Rules(), deny=Rules()) + + # policy only with options is valid + _ = Policy(name="a", options=Options(max_session_ttl=Duration.new(hours=10))) + + # policy only with allow rules is valid + _ = Policy(name="a", allow=Rules(AccessNode(AccessNode.login == "root"))) - def test_valid_options(self): - # check that only < inequalities are possible - _ = Options(Options.max_session_ttl < Duration.new(hours=3)) - _ = Options(Duration.new(hours=3) > Options.max_session_ttl) + # policy only with deny rules is valid + _ = Policy(name="a", deny=Rules(AccessNode(AccessNode.login == "root"))) - with pytest.raises(Exception): - _ = Options(Options.max_session_ttl > Duration.new(hours=3)) - with pytest.raises(Exception): - _ = Options(Duration.new(hours=3) < Options.max_session_ttl) + def test_empty_rules(self): + assert Rules().empty() is True, "empty rules are empty" + assert ( + Rules(AccessNode(AccessNode.login == "root")).empty() is False + ), "non empty rules are non empty" + + def test_empty_options(self): + assert Options().empty() is True, "empty options are empty" + assert ( + Options(max_session_ttl=Duration.new(hours=10)).empty() is False + ), "non empty options are non empty" - with pytest.raises(Exception): - _ = Options(Options.max_session_ttl == Duration.new(hours=3)) - with pytest.raises(Exception): - _ = Options(Duration.new(hours=3) == Options.max_session_ttl) + assert Rules().empty() is True, "empty rules are empty" + assert ( + Rules(AccessNode(AccessNode.login == "root")).empty() is False + ), "non empty rules are non empty" - with pytest.raises(Exception): - _ = Options(Options.max_session_ttl != Duration.new(hours=3)) - with pytest.raises(Exception): - _ = Options(Duration.new(hours=3) != Options.max_session_ttl) + def test_options_init(self): + # TODO add tests ensuring that users cannot set invalid option values + # for example, the next test should fail and doesn't + _ = Options( + max_session_ttl=Duration.new(hours=10), + recording_mode=Duration.new(hours=10), + source_ip=Node.labels["env"] == "prod", + ) + + def test_options_combine(self): + result = Options.combine( + Options(), + Options(), + ) + assert result.max_session_ttl is None + assert result.source_ip is None + assert result.recording_mode is None + + result = Options.combine( + Options(max_session_ttl=Duration.new(hours=10)), + Options(), + ) + assert result.max_session_ttl == Duration.new(hours=10) + assert result.source_ip is None + assert result.recording_mode is None + + result = Options.combine( + Options(), + Options(max_session_ttl=Duration.new(hours=3)), + ) + assert result.max_session_ttl == Duration.new(hours=3) + assert result.source_ip is None + assert result.recording_mode is None + + result = Options.combine( + Options(max_session_ttl=Duration.new(hours=10)), + Options(max_session_ttl=Duration.new(hours=3)), + ) + assert result.max_session_ttl == Duration.new(hours=3) + assert result.source_ip is None + assert result.recording_mode is None + + result = Options.combine( + Options(max_session_ttl=Duration.new(hours=3)), + Options(max_session_ttl=Duration.new(hours=10)), + ) + assert result.max_session_ttl == Duration.new(hours=3) + assert result.source_ip is None + assert result.recording_mode is None + + result = Options.combine( + Options(source_ip=SourceIp.PINNED), + Options(source_ip=SourceIp.UNPINNED), + ) + assert result.max_session_ttl is None + assert result.source_ip == SourceIp.PINNED + assert result.recording_mode is None + + result = Options.combine( + Options(source_ip=SourceIp.UNPINNED), + Options(source_ip=SourceIp.PINNED), + ) + assert result.max_session_ttl is None + assert result.source_ip == SourceIp.PINNED + assert result.recording_mode is None + + result = Options.combine( + Options(recording_mode=RecordingMode.BEST_EFFORT), + Options(recording_mode=RecordingMode.STRICT), + ) + assert result.max_session_ttl is None + assert result.source_ip is None + assert result.recording_mode == RecordingMode.STRICT + + result = Options.combine( + Options(recording_mode=RecordingMode.STRICT), + Options(recording_mode=RecordingMode.BEST_EFFORT), + ) + assert result.max_session_ttl is None + assert result.source_ip is None + assert result.recording_mode == RecordingMode.STRICT def test_options(self): p = Policy( name="b", - options=OptionsSet( - Options( - (Options.max_session_ttl < Duration.new(hours=10)) - & (Options.max_session_ttl < Duration.new(seconds=10)), - ) + options=Options( + max_session_ttl=Duration.new(hours=10), ), allow=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "root") & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - # Since we only have <, it's impossible to specify a `session_ttl` that would not be valid. - # This means that the predicate will always match the policy. - & Options(Options.max_session_ttl < Duration.new(hours=3)) ) - - assert ret is True, "options and core predicate matches" + assert ret.solves is True, "options and core predicate matches" + assert ret.options.max_session_ttl == Duration.new(hours=10) def test_options_extra(self): """ @@ -153,225 +268,184 @@ def test_options_extra(self): """ p = Policy( name="p", - options=OptionsSet( - Options( - (Options.max_session_ttl < Duration.new(hours=10)) - ), - Options(Options.pin_source_ip == True), + options=Options( + max_session_ttl=Duration.new(hours=10), + source_ip=SourceIp.PINNED, ), allow=Rules( # unrelated rules are with comma, related rules are part of the predicate - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "root") & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options( - (Options.max_session_ttl < Duration.new(hours=3)) - # TODO: `check` doesn't require that `pin_source_ip` is defined here, but should!?. - & (Options.pin_source_ip == True) - ) ) - assert ret is True, "options and core predicate matches" - - ret, _ = p.check( - AccessNode( - (AccessNode.login == "root") - & (Node.labels["env"] == "prod") - & (Node.labels["os"] == "Linux") - ) - & Options( - (Options.max_session_ttl < Duration.new(hours=3)) - & (Options.pin_source_ip == False) - ) - ) - assert ret is False, "options fails restriction when contradiction is specified" + assert ret.solves is True + assert ret.options.max_session_ttl == Duration.new(hours=10) + assert ret.options.source_ip == SourceIp.PINNED def test_options_policy_set(self): a = Policy( name="a", - options=OptionsSet( - Options( - (Options.max_session_ttl < Duration.new(hours=10)) - ), - Options(Options.pin_source_ip == True), + options=Options( + max_session_ttl=Duration.new(hours=10), + source_ip=SourceIp.PINNED, + recording_mode=RecordingMode.BEST_EFFORT, ), allow=Rules( - AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage")), + AccessNode( + (AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage") + ) ), ) b = Policy( name="b", + options=Options( + max_session_ttl=Duration.new(hours=3), + source_ip=SourceIp.UNPINNED, + recording_mode=RecordingMode.STRICT, + ), allow=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) p = PolicySet([a, b]) - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "root") & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options( - (Options.max_session_ttl < Duration.new(hours=3)) - & (Options.pin_source_ip == True) - ) ) - assert ret is True, "options and core predicate matches" - - ret, _ = p.check( - AccessNode( - (AccessNode.login == "root") - & (Node.labels["env"] == "prod") - & (Node.labels["os"] == "Linux") - ) - & Options( - (Options.max_session_ttl < Duration.new(hours=3)) - & (Options.pin_source_ip == False) - ) - ) - - assert ret is False, "options fails restriction" + assert ret.solves is True, "options and core predicate matches" + assert ret.options.max_session_ttl == Duration.new(hours=3) + assert ret.options.source_ip == SourceIp.PINNED + assert ret.options.recording_mode == RecordingMode.STRICT def test_options_policy_set_enum(self): # policy a requires best effort a = Policy( name="a", - options=OptionsSet( - Options( - (Options.recording_mode > "best_effort") - | (Options.recording_mode == "best_effort") - ), + options=Options( + recording_mode=RecordingMode.BEST_EFFORT, ), allow=Rules( - AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage")), + AccessNode( + (AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage") + ), ), ) # policy b requires strict recording mode b = Policy( name="b", - options=OptionsSet( - Options(Options.recording_mode == "strict"), + options=Options( + recording_mode=RecordingMode.STRICT, ), allow=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) p = PolicySet([a, b]) - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "root") & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options(Options.recording_mode == "strict") ) - assert ret is True, "options and core predicate matches" + assert ret.solves is True, "options and core predicate matches" + assert ret.options.recording_mode == RecordingMode.STRICT - ret, _ = p.check( - AccessNode( - (AccessNode.login == "root") - & (Node.labels["env"] == "prod") - & (Node.labels["os"] == "Linux") - ) - & Options(Options.recording_mode == "best_effort") - ) - - assert ( - ret is False - ), "options fails recording mode restriction from the policy b" - - ret, _ = p.check( - AccessNode( - (AccessNode.login == "ubuntu") - & (Node.labels["env"] == "stage") - & (Node.labels["os"] == "Linux") - ) - & Options(Options.recording_mode == "strict") - ) - - assert ret is True, "options and core predicate matches" - - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage") & (Node.labels["os"] == "Linux") ) - & Options(Options.recording_mode == "best_effort") ) - assert ( - ret is False - ), "strict is enforced for all modes of access across all policies in the set" + assert ret.solves is True, "options and core predicate matches" + assert ret.options.recording_mode == RecordingMode.STRICT def test_join_session(self): p = Policy( name="join_session", allow=Rules( JoinSession( - (User.traits["team"].contains("dev")) & - ((JoinSession.mode == "observer") | (JoinSession.mode == "peer")) & - ((Session.owner.traits["team"].contains("dev")) | (Session.owner.traits["team"].contains("intern"))) + (User.traits["team"].contains("dev")) + & ((JoinSession.mode == "observer") | (JoinSession.mode == "peer")) + & ( + (Session.owner.traits["team"].contains("dev")) + | (Session.owner.traits["team"].contains("intern")) + ) ), ), - deny=Rules( - JoinSession( - User.traits["team"].contains("intern") - ) - ) + deny=Rules(JoinSession(User.traits["team"].contains("intern"))), ) - ret, _ = p.check( + ret = p.check( JoinSession( - (User.traits["team"] == ("dev",)) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev",)) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is True, "a dev user can join a session from an intern user as an observer" + assert ( + ret.solves is True + ), "a dev user can join a session from an intern user as an observer" - ret, _ = p.check( + ret = p.check( JoinSession( - (User.traits["team"] == ("marketing",)) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("marketing",)) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a marketing user cannot join a session from an intern user as an observer" + assert ( + ret.solves is False + ), "a marketing user cannot join a session from an intern user as an observer" - ret, _ = p.check( + ret = p.check( JoinSession( - (User.traits["team"] == ("dev",)) & - (JoinSession.mode == "moderator") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev",)) + & (JoinSession.mode == "moderator") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a dev user cannot join a session from an intern user as a moderator" + assert ( + ret.solves is False + ), "a dev user cannot join a session from an intern user as a moderator" - ret, _ = p.check( + ret = p.check( JoinSession( - (User.traits["team"] == ("dev", "intern")) & - (JoinSession.mode == "observer") & - (Session.owner.traits["team"] == ("intern",)) + (User.traits["team"] == ("dev", "intern")) + & (JoinSession.mode == "observer") + & (Session.owner.traits["team"] == ("intern",)) ) ) - assert ret is False, "a dev intern user cannot join a session from an intern user as an observer" + assert ( + ret.solves is False + ), "a dev intern user cannot join a session from an intern user as an observer" def test_login_rules(self): """ @@ -388,23 +462,22 @@ def test_login_rules(self): (external["email"] == ("alice@wonderland.local",)) & (traits["login"] == ("alice-wonderland.local",)) ) - ret, _ = p.solve() - assert ret is True, "transformation has been applied" + ret = p.solve() + assert ret.solves is True, "transformation has been applied" def test_policy_wrong_expr(self): """ Test that policy mapping always returns the right value """ - with pytest.raises(ParameterError) as exc: + with pytest.raises(ParameterError, match="should eval to string list"): PolicyMap( Select( # Default is necessary to specify default empty sequence or type Default(StringLiteral("test")), ) ) - assert "should eval to string list" in str(exc.value) - with pytest.raises(ParameterError) as exc: + with pytest.raises(ParameterError): external = StringSetMap("external") PolicyMap( Select( @@ -434,17 +507,17 @@ def test_policy_mapping(self): ) ) - ret, _ = Predicate( + ret = Predicate( (s == ("ext-test", "ext-prod")) & (external["groups"] == ("admin-test", "admin-prod")) ).solve() - assert ret is True, "match and replace works" + assert ret.solves is True, "match and replace works" - ret, _ = Predicate( + ret = Predicate( (s == ("dev-test", "dev-prod")) & (external["groups"] == ("dev-test", "dev-prod")) ).solve() - assert ret is True, "match and replace works default value" + assert ret.solves is True, "match and replace works default value" def test_full_cycle(self): external = StringSetMap("external") @@ -460,8 +533,8 @@ def test_full_cycle(self): (external["email"] == ("alice@wonderland.local",)) & (traits["login"] == ("alice-wonderland.local",)) ) - ret, _ = p.solve() - assert ret is True, "match and replace works in login rules" + ret = p.solve() + assert ret.solves is True, "match and replace works in login rules" s = PolicyMap( Select( @@ -474,26 +547,32 @@ def test_full_cycle(self): ) ) - ret, _ = Predicate( + ret = Predicate( (s == ("ext-test", "ext-prod")) & (external["groups"] == ("admin-test", "admin-prod")) ).solve() - assert ret is True, "match and replace works in policy maps" + assert ret.solves is True, "match and replace works in policy maps" - ret, _ = Predicate( + ret = Predicate( (s == ("dev-test", "dev-prod")) & (external["groups"] == ("dev-test", "dev-prod")) ).solve() - assert ret is True, "match and replace works in policy maps (default value)" + assert ( + ret.solves is True + ), "match and replace works in policy maps (default value)" # dev policy allows access to stage, and denies access to root dev = Policy( name="dev-stage", allow=Rules( - AccessNode((AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage")), + AccessNode( + (AccessNode.login == "ubuntu") & (Node.labels["env"] == "stage") + ), ), deny=Rules( - AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")), + AccessNode( + (AccessNode.login == "root") & (Node.labels["env"] == "prod") + ), ), ) @@ -501,8 +580,8 @@ def test_full_cycle(self): # but requires strict recording mode ext = Policy( name="ext-stage", - options=OptionsSet( - Options(Options.recording_mode == "strict"), + options=Options( + recording_mode=RecordingMode.STRICT, ), allow=Rules( AccessNode( @@ -515,8 +594,10 @@ def test_full_cycle(self): p = PolicySet([dev, ext]) # make sure that policy set will never allow access to prod - ret, _ = p.check(AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod"))) - assert ret is False + ret = p.check( + AccessNode((AccessNode.login == "root") & (Node.labels["env"] == "prod")) + ) + assert ret.solves is False policy_names = try_login( s, @@ -528,13 +609,13 @@ def test_full_cycle(self): # policy set will allow Alice to connect to prod if her # email is alice@wonderland.local - ret, _ = p.check( + ret = p.check( AccessNode( (AccessNode.login == "alice-wonderland.local") & (Node.labels["env"] == "prod") ) ) - assert ret is True + assert ret.solves is True # TODO: How to simplify testing and make it shorter? # TODO: How to connect policy mappings and diff --git a/predicate/solver/test/test_teleport_access_requests.py b/predicate/solver/test/test_teleport_access_requests.py index 35b7aa5..32541a2 100644 --- a/predicate/solver/test/test_teleport_access_requests.py +++ b/predicate/solver/test/test_teleport_access_requests.py @@ -31,52 +31,52 @@ def test_access_requests(self): ) # Can devs request access to stage? - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names == ("access-stage",)) & (RequestPolicy.approvals["access-stage"].len() > 0) ) ) - assert ret is True, "Devs can request access to stage" + assert ret.solves is True, "Devs can request access to stage" - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names == ("access-prod",)) & (RequestPolicy.approvals["access-prod"].len() > 0) ) ) - assert ret is False, "Devs can't request access to prod" + assert ret.solves is False, "Devs can't request access to prod" # Can devs review access to prod? - ret, _ = devs.query( + ret = devs.query( Review( RequestPolicy.names.contains( "access-stage", ) ) ) - assert ret is True, "Devs can review other folks access to stage" + assert ret.solves is True, "Devs can review other folks access to stage" # Can user with these policies review a role? - ret, _ = devs.query( + ret = devs.query( Review( RequestPolicy.names.contains( "access-prod", ) ) ) - assert ret is False, "can't review role that is not listed in the policy" + assert ret.solves is False, "can't review role that is not listed in the policy" # TODO: how to bind to roles? - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names == ("access-stage",)) & (RequestPolicy.approvals["access-stage"] == ("alice", "bob")) ) ) - assert ret is True, "two folks have approved the request" + assert ret.solves is True, "two folks have approved the request" - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names == ("access-stage",)) & (RequestPolicy.approvals["access-stage"] == ("alice", "bob")) @@ -84,7 +84,7 @@ def test_access_requests(self): ) ) assert ( - ret is False + ret.solves is False ), "two folks have approved the request, but one person denied it" def test_access_requests_review_expression(self): @@ -114,7 +114,7 @@ def test_access_requests_review_expression(self): ) request = RequestPolicy.names == ("access-stage",) - ret, _ = devs.query( + ret = devs.query( Request( request & ( @@ -123,10 +123,10 @@ def test_access_requests_review_expression(self): ) ) ) - assert ret is True, "one person have approved the request" + assert ret.solves is True, "one person have approved the request" request = RequestPolicy.names == ("access-stage",) - ret, _ = devs.query( + ret = devs.query( Request( request & ( @@ -135,7 +135,7 @@ def test_access_requests_review_expression(self): ), ) ) - assert ret is False, "one person has denied the request" + assert ret.solves is False, "one person has denied the request" def test_access_requests_review_limits(self): """ @@ -165,7 +165,7 @@ def test_access_requests_review_limits(self): ) # Can devs review access to stage? - ret, _ = devs.check( + ret = devs.check( Review( (RequestPolicy.names == ("access-stage",)) & (RequestPolicy.approvals["access-stage"].len() > 0) @@ -173,10 +173,12 @@ def test_access_requests_review_limits(self): & (Node.labels["env"] == "sre") ) ) - assert ret is True, "Devs can request access to stage for nodes in their env" + assert ( + ret.solves is True + ), "Devs can request access to stage for nodes in their env" # Can devs review access to stage? - ret, _ = devs.check( + ret = devs.check( Review( (RequestPolicy.names == ("access-stage",)) & (RequestPolicy.approvals["access-stage"].len() > 0) @@ -185,7 +187,7 @@ def test_access_requests_review_limits(self): ) ) assert ( - ret is False + ret.solves is False ), "Devs can't request access to stage for nodes not in their env" def test_access_requests_multi(self): @@ -229,26 +231,26 @@ def test_access_requests_multi(self): ) # Can devs request access to stage? - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names.contains("access-stage")) & (RequestPolicy.approvals["access-stage"].len() > 0) ) ) - assert ret is True, "Devs can request access to stage" + assert ret.solves is True, "Devs can request access to stage" # Can devs request access to stage and prod at the same time? - ret, _ = devs.query( + ret = devs.query( Request( (RequestPolicy.names == ("access-stage", "access-prod")) & (RequestPolicy.approvals["access-stage"].len() > 0) ) ) - assert ret is True, "Devs can request access to stage and prod" + assert ret.solves is True, "Devs can request access to stage and prod" # With multi-roles, both roles have to be requested request = RequestPolicy.names == ("access-stage",) - ret, _ = devs.check( + ret = devs.check( Request( request & ( @@ -258,11 +260,11 @@ def test_access_requests_multi(self): ) ) assert ( - ret is False + ret.solves is False ), "one person have approved the request, but the request fails because both roles have to be requested" # Request for two policies got approved - ret, model = devs.check( + ret = devs.check( Request( (RequestPolicy.names == ("access-stage", "access-prod")) & ( @@ -283,7 +285,7 @@ def test_access_requests_multi(self): ) ) assert ( - ret is True + ret.solves is True ), "request is approved with two approvals for prod and one for stage" def test_access_requests_todo(self): diff --git a/predicate/solver/test/test_teleport_get_started.py b/predicate/solver/test/test_teleport_get_started.py index f147c72..607c7e4 100644 --- a/predicate/solver/test/test_teleport_get_started.py +++ b/predicate/solver/test/test_teleport_get_started.py @@ -18,8 +18,8 @@ def test_node_access(self): ) # Check if alice can access nodes as root - ret, _ = p.check(AccessNode((AccessNode.login == "root") & (User.name == "alice"))) - assert ret is True, "everyone can access as root, including alice" + ret = p.check(AccessNode((AccessNode.login == "root") & (User.name == "alice"))) + assert ret.solves is True, "everyone can access as root, including alice" # This is not a very useful policy, because it gives everyone # access as root. Let's narrow down this policy to let users @@ -32,16 +32,18 @@ def test_node_access(self): ) # Alice will be able to login to any machine as herself - ret, _ = p.check(AccessNode((AccessNode.login == "alice") & (User.name == "alice"))) - assert ret is True, "Alice can login with her user to any node" + ret = p.check( + AccessNode((AccessNode.login == "alice") & (User.name == "alice")) + ) + assert ret.solves is True, "Alice can login with her user to any node" # We can verify that a strong invariant holds: # Unless a username is root, a user can not access a server as # root. This creates a problem though, can we deny access as root # altogether? - ret, _ = p.check(AccessNode((AccessNode.login == "root") & (User.name != "root"))) + ret = p.check(AccessNode((AccessNode.login == "root") & (User.name != "root"))) assert ( - ret is False + ret.solves is False ), "This role does not allow access as root unless a user name is root" # Let's prohibit root access altogether. Deny rules always take @@ -58,8 +60,8 @@ def test_node_access(self): # partial conditions, our predicate requires user to be specified, # while this query does not specify any user. Checks require all # parameters of the predicate, while queries do not. - ret, _ = p.query(AccessNode((AccessNode.login == "root"))) - assert ret is False, "This role does not allow access as root to anyone" + ret = p.query(AccessNode((AccessNode.login == "root"))) + assert ret.solves is False, "This role does not allow access as root to anyone" def test_node_access_multiple_teams(self): """ @@ -82,7 +84,7 @@ def test_node_access_multiple_teams(self): ) # Check if alice can access nodes as root - ret, _ = devs_and_admins.check( + ret = devs_and_admins.check( AccessNode( (User.name == "alice") & (User.traits["team"] == ("dev",)) @@ -91,11 +93,11 @@ def test_node_access_multiple_teams(self): ) ) assert ( - ret is True + ret.solves is True ), "Policy lets Alice to access nodes as her username if nodes are labeled with dev" # Check if bob can access nodes as root - ret, _ = devs_and_admins.check( + ret = devs_and_admins.check( AccessNode( (User.name == "bob") & (User.traits["team"] == ("db-admins",)) @@ -104,7 +106,7 @@ def test_node_access_multiple_teams(self): ) ) assert ( - ret is True + ret.solves is True ), "Policy lets Bob to access nodes as her username if nodes are labeled with dev" # The policy