Skip to content

Commit

Permalink
Assigning rule key also for if-rules with 2-part non-ground refs (#…
Browse files Browse the repository at this point in the history
…7004)

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Sep 12, 2024
1 parent 2d63d71 commit ebe5b01
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 30 deletions.
67 changes: 39 additions & 28 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4093,113 +4093,124 @@ func TestCompilerResolveAllRefs(t *testing.T) {
c.Modules = getCompilerTestModules()
c.Modules["head"] = MustParseModule(`package head
import rego.v1
import data.doc1 as bar
import input.x.y.foo
import input.qux as baz
p[foo[bar[i]]] = {"baz": baz} { true }`)
p[foo[bar[i]]] := {"baz": baz} if { true }`)

c.Modules["elsekw"] = MustParseModule(`package elsekw
import rego.v1
import input.x.y.foo
import data.doc1 as bar
import input.baz
p {
p if {
false
} else = foo {
} else = foo if {
bar
} else = baz {
} else = baz if {
true
}
`)

c.Modules["nestedexprs"] = MustParseModule(`package nestedexprs
import rego.v1
x = 1
p {
p if {
f(g(x))
}`)

c.Modules["assign"] = MustParseModule(`package assign
import rego.v1
x = 1
y = 1
p {
p if {
x := y
[true | x := y]
}`)

c.Modules["someinassign"] = MustParseModule(`package someinassign
import future.keywords.in
import rego.v1
x = 1
y = 1
p[x] {
p[x] if {
some x in [1, 2, y]
}`)

c.Modules["someinassignwithkey"] = MustParseModule(`package someinassignwithkey
import future.keywords.in
import rego.v1
x = 1
y = 1
p[x] {
p[x] if {
some k, v in [1, 2, y]
}`)

c.Modules["donotresolve"] = MustParseModule(`package donotresolve
import rego.v1
x = 1
f(x) {
f(x) if {
x = 2
}
`)

c.Modules["indirectrefs"] = MustParseModule(`package indirectrefs
import rego.v1
f(x) = [x] {true}
f(x) = [x] if {true}
p {
p if {
f(1)[0]
}
`)

c.Modules["comprehensions"] = MustParseModule(`package comprehensions
import rego.v1
nums = [1, 2, 3]
f(x) = [x] {true}
f(x) = [x] if {true}
p[[1]] {true}
p[[1]] if {true}
q {
q if {
p[[x | x = nums[_]]]
}
r = [y | y = f(1)[0]]
`)

c.Modules["everykw"] = MustParseModuleWithOpts(`package everykw
c.Modules["everykw"] = MustParseModule(`package everykw
import rego.v1
nums = {1, 2, 3}
f(_) = true
x = 100
xs = [1, 2, 3]
p {
every x in xs {
nums[x]
x > 10
}
}`, ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}})
nums = {1, 2, 3}
f(_) = true
x = 100
xs = [1, 2, 3]
p if {
every x in xs {
nums[x]
x > 10
}
}`)

c.Modules["heads_with_dots"] = MustParseModule(`package heads_with_dots
import rego.v1
this_is_not = true
this.is.dotted { this_is_not }
this.is.dotted if { this_is_not }
`)

compileStages(c, c.resolveAllRefs)
Expand Down
4 changes: 4 additions & 0 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,10 @@ func (p *Parser) parseRules() []*Rule {

// p[x] if ... becomes a single-value rule p[x]
if hasIf && !usesContains && len(rule.Head.Ref()) == 2 {
if !rule.Head.Ref()[1].IsGround() && len(rule.Head.Args) == 0 {
rule.Head.Key = rule.Head.Ref()[1]
}

if rule.Head.Value == nil {
rule.Head.generatedValue = true
rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location)
Expand Down
4 changes: 4 additions & 0 deletions ast/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,7 @@ func TestRuleIf(t *testing.T) {
Head: &Head{
Name: Var("p"),
Reference: MustParseRef("p[x]"),
Key: VarTerm("x"),
Value: VarTerm("y"),
},
Body: MustParseBody(`x := "foo"; y := "bar"`),
Expand Down Expand Up @@ -2687,6 +2688,7 @@ func TestRuleIf(t *testing.T) {
exp: &Rule{
Head: &Head{
Reference: MustParseRef("p[x]"),
Key: VarTerm("x"),
Value: BooleanTerm(true),
},
Body: MustParseBody(`x := 1`),
Expand All @@ -2698,6 +2700,7 @@ func TestRuleIf(t *testing.T) {
exp: &Rule{
Head: &Head{
Reference: MustParseRef("p[x]"),
Key: VarTerm("x"),
Value: BooleanTerm(true),
},
Body: MustParseBody(`x := 1`),
Expand Down Expand Up @@ -2783,6 +2786,7 @@ func TestRuleRefHeads(t *testing.T) {
Head: &Head{
Name: Var("p"),
Reference: MustParseRef("p[x]"),
Key: VarTerm("x"),
Value: IntNumberTerm(1),
},
Body: MustParseBody("x := 2"),
Expand Down
3 changes: 1 addition & 2 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm
// * a.b.c.d -> a.b.c.d := true
isRegoV1RefConst := o.regoV1 && isExpandedConst && head.Key == nil && len(head.Args) == 0

if len(head.Args) > 0 &&
head.Location == head.Value.Location &&
if head.Location == head.Value.Location &&
head.Name != "else" &&
ast.Compare(head.Value, ast.BooleanTerm(true)) == 0 &&
!isRegoV1RefConst {
Expand Down

0 comments on commit ebe5b01

Please sign in to comment.