diff --git a/ast/compile_test.go b/ast/compile_test.go index 98c9e5850d..f586ceb331 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -4093,113 +4093,124 @@ func TestCompilerResolveAllRefs(t *testing.T) { c.Modules = getCompilerTestModules() c.Modules["head"] = MustParseModule(`package head +import rego.v1 import data.doc1 as bar import input.x.y.foo import input.qux as baz -p[foo[bar[i]]] = {"baz": baz} { true }`) +p[foo[bar[i]]] := {"baz": baz} if { true }`) c.Modules["elsekw"] = MustParseModule(`package elsekw + import rego.v1 import input.x.y.foo import data.doc1 as bar import input.baz - p { + p if { false - } else = foo { + } else = foo if { bar - } else = baz { + } else = baz if { true } `) c.Modules["nestedexprs"] = MustParseModule(`package nestedexprs + import rego.v1 x = 1 - p { + p if { f(g(x)) }`) c.Modules["assign"] = MustParseModule(`package assign + import rego.v1 x = 1 y = 1 - p { + p if { x := y [true | x := y] }`) c.Modules["someinassign"] = MustParseModule(`package someinassign - import future.keywords.in + import rego.v1 + x = 1 y = 1 - p[x] { + p[x] if { some x in [1, 2, y] }`) c.Modules["someinassignwithkey"] = MustParseModule(`package someinassignwithkey - import future.keywords.in + import rego.v1 + x = 1 y = 1 - p[x] { + p[x] if { some k, v in [1, 2, y] }`) c.Modules["donotresolve"] = MustParseModule(`package donotresolve + import rego.v1 x = 1 - f(x) { + f(x) if { x = 2 } `) c.Modules["indirectrefs"] = MustParseModule(`package indirectrefs + import rego.v1 - f(x) = [x] {true} + f(x) = [x] if {true} - p { + p if { f(1)[0] } `) c.Modules["comprehensions"] = MustParseModule(`package comprehensions + import rego.v1 nums = [1, 2, 3] - f(x) = [x] {true} + f(x) = [x] if {true} - p[[1]] {true} + p[[1]] if {true} - q { + q if { p[[x | x = nums[_]]] } r = [y | y = f(1)[0]] `) - c.Modules["everykw"] = MustParseModuleWithOpts(`package everykw + c.Modules["everykw"] = MustParseModule(`package everykw + import rego.v1 - nums = {1, 2, 3} - f(_) = true - x = 100 - xs = [1, 2, 3] - p { - every x in xs { - nums[x] - x > 10 - } - }`, ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}}) + nums = {1, 2, 3} + f(_) = true + x = 100 + xs = [1, 2, 3] + p if { + every x in xs { + nums[x] + x > 10 + } + }`) c.Modules["heads_with_dots"] = MustParseModule(`package heads_with_dots + import rego.v1 this_is_not = true - this.is.dotted { this_is_not } + this.is.dotted if { this_is_not } `) compileStages(c, c.resolveAllRefs) diff --git a/ast/parser.go b/ast/parser.go index 1d56eedd68..3037e1e855 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -713,6 +713,10 @@ func (p *Parser) parseRules() []*Rule { // p[x] if ... becomes a single-value rule p[x] if hasIf && !usesContains && len(rule.Head.Ref()) == 2 { + if !rule.Head.Ref()[1].IsGround() && len(rule.Head.Args) == 0 { + rule.Head.Key = rule.Head.Ref()[1] + } + if rule.Head.Value == nil { rule.Head.generatedValue = true rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location) diff --git a/ast/parser_test.go b/ast/parser_test.go index c54efabf06..eff9d4585f 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -2652,6 +2652,7 @@ func TestRuleIf(t *testing.T) { Head: &Head{ Name: Var("p"), Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), Value: VarTerm("y"), }, Body: MustParseBody(`x := "foo"; y := "bar"`), @@ -2687,6 +2688,7 @@ func TestRuleIf(t *testing.T) { exp: &Rule{ Head: &Head{ Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), Value: BooleanTerm(true), }, Body: MustParseBody(`x := 1`), @@ -2698,6 +2700,7 @@ func TestRuleIf(t *testing.T) { exp: &Rule{ Head: &Head{ Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), Value: BooleanTerm(true), }, Body: MustParseBody(`x := 1`), @@ -2783,6 +2786,7 @@ func TestRuleRefHeads(t *testing.T) { Head: &Head{ Name: Var("p"), Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), Value: IntNumberTerm(1), }, Body: MustParseBody("x := 2"), diff --git a/format/format.go b/format/format.go index dcdd6ee7d0..820c645d6e 100644 --- a/format/format.go +++ b/format/format.go @@ -568,8 +568,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst bool, o fm // * a.b.c.d -> a.b.c.d := true isRegoV1RefConst := o.regoV1 && isExpandedConst && head.Key == nil && len(head.Args) == 0 - if len(head.Args) > 0 && - head.Location == head.Value.Location && + if head.Location == head.Value.Location && head.Name != "else" && ast.Compare(head.Value, ast.BooleanTerm(true)) == 0 && !isRegoV1RefConst {