Skip to content

Commit

Permalink
merge sasha/python
Browse files Browse the repository at this point in the history
  • Loading branch information
Vitor Enes committed Nov 15, 2022
2 parents 10a5771 + 5bfb6c0 commit 5081623
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 55 deletions.
58 changes: 26 additions & 32 deletions predicate/solver/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,11 @@ def walk(self, fn):
def __str__(self):
return "int({})".format(self.name)

class LtDuration:
"""
LtDuration is a duration that only allows < inequalities.
"""

class Duration:
def __init__(self, name: str):
self.name = name
self.val = z3.Int(self.name)
Expand All @@ -439,7 +442,16 @@ def walk(self, fn):
fn(self)

def __str__(self):
return "duration({})".format(self.name)
return "ltduration({})".format(self.name)

def __lt__(self, val: DurationLiteral):
return Lt(self, val)


class Duration(LtDuration):
"""
Duration is a duration that allows <, >, == and != inequalities.
"""

@staticmethod
def new(
Expand All @@ -459,35 +471,17 @@ def new(
+ nanoseconds * NANOSECOND
)

def __eq__(self, val):
if isinstance(val, (Duration, DurationLiteral)):
return Eq(self, val)
raise TypeError(
"unsupported type {}, supported integers only".format(type(val))
)
def __str__(self):
return "duration({})".format(self.name)

def __ne__(self, val):
if isinstance(val, (Duration, DurationLiteral)):
return Not(Eq(self, val))
raise TypeError(
"unsupported type {}, supported integers only".format(type(val))
)
def __eq__(self, val: DurationLiteral):
return Eq(self, val)

def __lt__(self, val):
if isinstance(val, (Duration, DurationLiteral)):
return Lt(self, val)
raise TypeError(
"unsupported type {}, supported integers only".format(type(val))
)
def __ne__(self, val: DurationLiteral):
return Not(Eq(self, val))

def __gt__(self, val):
if isinstance(val, (Duration, DurationLiteral)):
return Gt(self, val)
raise TypeError(
"unsupported type {}, supported duration and duration literals only".format(
type(val)
)
)
def __gt__(self, val: DurationLiteral):
return Gt(self, val)


class Bool:
Expand All @@ -501,7 +495,7 @@ def __eq__(self, val):
if isinstance(val, (Bool,)):
return Eq(self, val)
raise TypeError(
"unsupported type {}, supported integers only".format(type(val))
"unsupported type {}, supported booleans only".format(type(val))
)

def __ne__(self, val):
Expand All @@ -510,7 +504,7 @@ def __ne__(self, val):
if isinstance(val, (Bool,)):
return Not(Eq(self, val))
raise TypeError(
"unsupported type {}, supported integers only".format(type(val))
"unsupported type {}, supported booleans only".format(type(val))
)

def traverse(self):
Expand Down Expand Up @@ -1986,14 +1980,14 @@ def traverse(self):


def collect_symbols(s, expr):
if isinstance(expr, (String, Int, Duration, Bool, StringEnum)):
if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)):
s.add(expr.name)
if isinstance(expr, MapIndex):
s.add(expr.m.name + "." + expr.key)


def collect_names(s, expr):
if isinstance(expr, (String, Int, Duration, Bool, StringEnum)):
if isinstance(expr, (String, Int, LtDuration, Duration, Bool, StringEnum)):
s.add(expr.name)
if isinstance(expr, MapIndex):
s.add(expr.m.name)
Expand Down
3 changes: 2 additions & 1 deletion predicate/solver/teleport.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Options(ast.Predicate):
Options apply to some allow rules if they match
"""

max_session_ttl = ast.Duration("options.max_session_ttl")
max_session_ttl = ast.LtDuration("options.max_session_ttl")

pin_source_ip = ast.Bool("options.pin_source_ip")

Expand Down Expand Up @@ -277,6 +277,7 @@ def t_expr(predicate):
predicate,
(
ast.String,
ast.LtDuration,
ast.Duration,
ast.StringList,
ast.StringEnum,
Expand Down
60 changes: 38 additions & 22 deletions predicate/solver/test/test_teleport.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,33 @@ def test_deny_policy_set(self):
)
assert ret is True, "non-denied part of allow is OK"

def test_valid_options(self):
# check that only < inequalities are possible
_ = Options(Options.max_session_ttl < Duration.new(hours=3))
_ = Options(Duration.new(hours=3) > Options.max_session_ttl)

with pytest.raises(Exception):
_ = Options(Options.max_session_ttl > Duration.new(hours=3))
with pytest.raises(Exception):
_ = Options(Duration.new(hours=3) < Options.max_session_ttl)

with pytest.raises(Exception):
_ = Options(Options.max_session_ttl == Duration.new(hours=3))
with pytest.raises(Exception):
_ = Options(Duration.new(hours=3) == Options.max_session_ttl)

with pytest.raises(Exception):
_ = Options(Options.max_session_ttl != Duration.new(hours=3))
with pytest.raises(Exception):
_ = Options(Duration.new(hours=3) != Options.max_session_ttl)

def test_options(self):
p = Policy(
name="b",
options=OptionsSet(
Options(
(Options.max_session_ttl < Duration.new(hours=10))
& (Options.max_session_ttl > Duration.new(seconds=10)),
& (Options.max_session_ttl < Duration.new(seconds=10)),
)
),
allow=Rules(
Expand All @@ -120,22 +140,13 @@ def test_options(self):
& (Node.labels["env"] == "prod")
& (Node.labels["os"] == "Linux")
)
& Options(Options.max_session_ttl == Duration.new(hours=3))
# Since we only have <, it's impossible to specify a `session_ttl` that would not be valid.
# This means that the predicate will always match the policy.
& Options(Options.max_session_ttl < Duration.new(hours=3))
)

assert ret is True, "options and core predicate matches"

ret, _ = p.check(
AccessNode(
(AccessNode.login == "root")
& (Node.labels["env"] == "prod")
& (Node.labels["os"] == "Linux")
)
& Options(Options.max_session_ttl == Duration.new(hours=50))
)

assert ret is False, "options expression fails the entire predicate"

def test_options_extra(self):
"""
Tests that predicate works when options expression is superset
Expand All @@ -145,7 +156,6 @@ def test_options_extra(self):
options=OptionsSet(
Options(
(Options.max_session_ttl < Duration.new(hours=10))
& (Options.max_session_ttl > Duration.new(seconds=10))
),
Options(Options.pin_source_ip == True),
),
Expand All @@ -161,7 +171,11 @@ def test_options_extra(self):
& (Node.labels["env"] == "prod")
& (Node.labels["os"] == "Linux")
)
& Options((Options.max_session_ttl == Duration.new(hours=3)))
& Options(
(Options.max_session_ttl < Duration.new(hours=3))
# TODO: `check` doesn't require that `pin_source_ip` is defined here, but should!?.
& (Options.pin_source_ip == True)
)
)

assert ret is True, "options and core predicate matches"
Expand All @@ -173,11 +187,10 @@ def test_options_extra(self):
& (Node.labels["os"] == "Linux")
)
& Options(
(Options.max_session_ttl == Duration.new(hours=3))
(Options.max_session_ttl < Duration.new(hours=3))
& (Options.pin_source_ip == False)
)
)

assert ret is False, "options fails restriction when contradiction is specified"

def test_options_policy_set(self):
Expand All @@ -186,7 +199,6 @@ def test_options_policy_set(self):
options=OptionsSet(
Options(
(Options.max_session_ttl < Duration.new(hours=10))
& (Options.max_session_ttl > Duration.new(seconds=10))
),
Options(Options.pin_source_ip == True),
),
Expand All @@ -210,7 +222,10 @@ def test_options_policy_set(self):
& (Node.labels["env"] == "prod")
& (Node.labels["os"] == "Linux")
)
& Options((Options.max_session_ttl == Duration.new(hours=3)))
& Options(
(Options.max_session_ttl < Duration.new(hours=3))
& (Options.pin_source_ip == True)
)
)

assert ret is True, "options and core predicate matches"
Expand All @@ -222,7 +237,7 @@ def test_options_policy_set(self):
& (Node.labels["os"] == "Linux")
)
& Options(
(Options.max_session_ttl == Duration.new(hours=3))
(Options.max_session_ttl < Duration.new(hours=3))
& (Options.pin_source_ip == False)
)
)
Expand Down Expand Up @@ -446,6 +461,7 @@ def test_full_cycle(self):
& (traits["login"] == ("alice-wonderland.local",))
)
ret, _ = p.solve()
assert ret is True, "match and replace works in login rules"

s = PolicyMap(
Select(
Expand All @@ -462,13 +478,13 @@ def test_full_cycle(self):
(s == ("ext-test", "ext-prod"))
& (external["groups"] == ("admin-test", "admin-prod"))
).solve()
assert ret is True, "match and replace works"
assert ret is True, "match and replace works in policy maps"

ret, _ = Predicate(
(s == ("dev-test", "dev-prod"))
& (external["groups"] == ("dev-test", "dev-prod"))
).solve()
assert ret is True, "match and replace works default value"
assert ret is True, "match and replace works in policy maps (default value)"

# dev policy allows access to stage, and denies access to root
dev = Policy(
Expand Down

0 comments on commit 5081623

Please sign in to comment.