diff --git a/predicate/solver/ast.py b/predicate/solver/ast.py index 3e0d632..79aee12 100644 --- a/predicate/solver/ast.py +++ b/predicate/solver/ast.py @@ -426,8 +426,11 @@ def walk(self, fn): def __str__(self): return "int({})".format(self.name) +class LtDuration: + """ + LtDuration is a duration that only allows < inequalities. + """ -class Duration: def __init__(self, name: str): self.name = name self.val = z3.Int(self.name) @@ -439,7 +442,16 @@ def walk(self, fn): fn(self) def __str__(self): - return "duration({})".format(self.name) + return "ltduration({})".format(self.name) + + def __lt__(self, val: DurationLiteral): + return Lt(self, val) + + +class Duration(LtDuration): + """ + Duration is a duration that allows <, >, == and != inequalities. + """ @staticmethod def new( @@ -459,35 +471,17 @@ def new( + nanoseconds * NANOSECOND ) - def __eq__(self, val): - if isinstance(val, (Duration, DurationLiteral)): - return Eq(self, val) - raise TypeError( - "unsupported type {}, supported integers only".format(type(val)) - ) + def __str__(self): + return "duration({})".format(self.name) - def __ne__(self, val): - if isinstance(val, (Duration, DurationLiteral)): - return Not(Eq(self, val)) - raise TypeError( - "unsupported type {}, supported integers only".format(type(val)) - ) + def __eq__(self, val: DurationLiteral): + return Eq(self, val) - def __lt__(self, val): - if isinstance(val, (Duration, DurationLiteral)): - return Lt(self, val) - raise TypeError( - "unsupported type {}, supported integers only".format(type(val)) - ) + def __ne__(self, val: DurationLiteral): + return Not(Eq(self, val)) - def __gt__(self, val): - if isinstance(val, (Duration, DurationLiteral)): - return Gt(self, val) - raise TypeError( - "unsupported type {}, supported duration and duration literals only".format( - type(val) - ) - ) + def __gt__(self, val: DurationLiteral): + return Gt(self, val) class Bool: @@ -501,7 +495,7 @@ def __eq__(self, val): if isinstance(val, (Bool,)): return Eq(self, val) raise TypeError( - "unsupported type {}, supported integers only".format(type(val)) + "unsupported type {}, supported booleans only".format(type(val)) ) def __ne__(self, val): @@ -510,7 +504,7 @@ def __ne__(self, val): if isinstance(val, (Bool,)): return Not(Eq(self, val)) raise TypeError( - "unsupported type {}, supported integers only".format(type(val)) + "unsupported type {}, supported booleans only".format(type(val)) ) def traverse(self): @@ -1986,14 +1980,14 @@ def traverse(self): def collect_symbols(s, expr): - if isinstance(expr, (String, Int, Duration, Bool, StringEnum)): + if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)): s.add(expr.name) if isinstance(expr, MapIndex): s.add(expr.m.name + "." + expr.key) def collect_names(s, expr): - if isinstance(expr, (String, Int, Duration, Bool, StringEnum)): + if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)): s.add(expr.name) if isinstance(expr, MapIndex): s.add(expr.m.name) diff --git a/predicate/solver/teleport.py b/predicate/solver/teleport.py index 2fcb493..1828c4c 100644 --- a/predicate/solver/teleport.py +++ b/predicate/solver/teleport.py @@ -35,7 +35,7 @@ class Options(ast.Predicate): Options apply to some allow rules if they match """ - max_session_ttl = ast.Duration("options.max_session_ttl") + max_session_ttl = ast.LtDuration("options.max_session_ttl") pin_source_ip = ast.Bool("options.pin_source_ip") @@ -277,6 +277,7 @@ def t_expr(predicate): predicate, ( ast.String, + ast.LtDuration, ast.Duration, ast.StringList, ast.StringEnum, diff --git a/predicate/solver/test/test_teleport.py b/predicate/solver/test/test_teleport.py index d907bbd..1856ebe 100644 --- a/predicate/solver/test/test_teleport.py +++ b/predicate/solver/test/test_teleport.py @@ -100,13 +100,33 @@ def test_deny_policy_set(self): ) assert ret is True, "non-denied part of allow is OK" + def test_valid_options(self): + # check that only < inequalities are possible + _ = Options(Options.max_session_ttl < Duration.new(hours=3)) + _ = Options(Duration.new(hours=3) > Options.max_session_ttl) + + with pytest.raises(Exception): + _ = Options(Options.max_session_ttl > Duration.new(hours=3)) + with pytest.raises(Exception): + _ = Options(Duration.new(hours=3) < Options.max_session_ttl) + + with pytest.raises(Exception): + _ = Options(Options.max_session_ttl == Duration.new(hours=3)) + with pytest.raises(Exception): + _ = Options(Duration.new(hours=3) == Options.max_session_ttl) + + with pytest.raises(Exception): + _ = Options(Options.max_session_ttl != Duration.new(hours=3)) + with pytest.raises(Exception): + _ = Options(Duration.new(hours=3) != Options.max_session_ttl) + def test_options(self): p = Policy( name="b", options=OptionsSet( Options( (Options.max_session_ttl < Duration.new(hours=10)) - & (Options.max_session_ttl > Duration.new(seconds=10)), + & (Options.max_session_ttl < Duration.new(seconds=10)), ) ), allow=Rules( @@ -120,22 +140,13 @@ def test_options(self): & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options(Options.max_session_ttl == Duration.new(hours=3)) + # Since we only have <, it's impossible to specify a `session_ttl` that would not be valid. + # This means that the predicate will always match the policy. + & Options(Options.max_session_ttl < Duration.new(hours=3)) ) assert ret is True, "options and core predicate matches" - ret, _ = p.check( - AccessNode( - (AccessNode.login == "root") - & (Node.labels["env"] == "prod") - & (Node.labels["os"] == "Linux") - ) - & Options(Options.max_session_ttl == Duration.new(hours=50)) - ) - - assert ret is False, "options expression fails the entire predicate" - def test_options_extra(self): """ Tests that predicate works when options expression is superset @@ -145,7 +156,6 @@ def test_options_extra(self): options=OptionsSet( Options( (Options.max_session_ttl < Duration.new(hours=10)) - & (Options.max_session_ttl > Duration.new(seconds=10)) ), Options(Options.pin_source_ip == True), ), @@ -161,7 +171,11 @@ def test_options_extra(self): & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options((Options.max_session_ttl == Duration.new(hours=3))) + & Options( + (Options.max_session_ttl < Duration.new(hours=3)) + # TODO: `check` doesn't require that `pin_source_ip` is defined here, but should!?. + & (Options.pin_source_ip == True) + ) ) assert ret is True, "options and core predicate matches" @@ -173,11 +187,10 @@ def test_options_extra(self): & (Node.labels["os"] == "Linux") ) & Options( - (Options.max_session_ttl == Duration.new(hours=3)) + (Options.max_session_ttl < Duration.new(hours=3)) & (Options.pin_source_ip == False) ) ) - assert ret is False, "options fails restriction when contradiction is specified" def test_options_policy_set(self): @@ -186,7 +199,6 @@ def test_options_policy_set(self): options=OptionsSet( Options( (Options.max_session_ttl < Duration.new(hours=10)) - & (Options.max_session_ttl > Duration.new(seconds=10)) ), Options(Options.pin_source_ip == True), ), @@ -210,7 +222,10 @@ def test_options_policy_set(self): & (Node.labels["env"] == "prod") & (Node.labels["os"] == "Linux") ) - & Options((Options.max_session_ttl == Duration.new(hours=3))) + & Options( + (Options.max_session_ttl < Duration.new(hours=3)) + & (Options.pin_source_ip == True) + ) ) assert ret is True, "options and core predicate matches" @@ -222,7 +237,7 @@ def test_options_policy_set(self): & (Node.labels["os"] == "Linux") ) & Options( - (Options.max_session_ttl == Duration.new(hours=3)) + (Options.max_session_ttl < Duration.new(hours=3)) & (Options.pin_source_ip == False) ) ) @@ -446,6 +461,7 @@ def test_full_cycle(self): & (traits["login"] == ("alice-wonderland.local",)) ) ret, _ = p.solve() + assert ret is True, "match and replace works in login rules" s = PolicyMap( Select( @@ -462,13 +478,13 @@ def test_full_cycle(self): (s == ("ext-test", "ext-prod")) & (external["groups"] == ("admin-test", "admin-prod")) ).solve() - assert ret is True, "match and replace works" + assert ret is True, "match and replace works in policy maps" ret, _ = Predicate( (s == ("dev-test", "dev-prod")) & (external["groups"] == ("dev-test", "dev-prod")) ).solve() - assert ret is True, "match and replace works default value" + assert ret is True, "match and replace works in policy maps (default value)" # dev policy allows access to stage, and denies access to root dev = Policy(