Skip to content

Commit

Permalink
Nomad tweaks to gohcl
Browse files Browse the repository at this point in the history
  • Loading branch information
angrycub committed Jul 25, 2023
1 parent 217bb57 commit 58caf00
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 33 deletions.
184 changes: 160 additions & 24 deletions gohcl/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@ import (
"github.com/zclconf/go-cty/cty/gocty"
)

// ExpressionDecoderFunc represents custom expression decoder for a specific type
type ExpressionDecoderFunc func(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics

// BodyDecoderFunc represents custom body decoder for a specific type
type BodyDecoderFunc func(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics

type Decoder struct {
exprConvertors map[reflect.Type]ExpressionDecoderFunc
bodyConvertors map[reflect.Type]BodyDecoderFunc
}

var global = &Decoder{}

// RegisterExpressionDecoder registers a custom expression decoder for a target type.
func (d *Decoder) RegisterExpressionDecoder(typ reflect.Type, fn ExpressionDecoderFunc) {
if d.exprConvertors == nil {
d.exprConvertors = map[reflect.Type]ExpressionDecoderFunc{}
}
d.exprConvertors[typ] = fn
}

// RegisterBlockDecoder registers a custom block decoder for a target type.
func (d *Decoder) RegisterBlockDecoder(typ reflect.Type, fn BodyDecoderFunc) {
if d.bodyConvertors == nil {
d.bodyConvertors = map[reflect.Type]BodyDecoderFunc{}
}
d.bodyConvertors[typ] = fn
}

// DecodeBody extracts the configuration within the given body into the given
// value. This value must be a non-nil pointer to either a struct or
// a map, where in the former case the configuration will be decoded using
Expand All @@ -31,27 +60,47 @@ import (
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func DecodeBody(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
return global.DecodeBody(body, ctx, val)
}

// DecodeBody extracts the configuration within the given body into the given
// value. This value must be a non-nil pointer to either a struct or
// a map, where in the former case the configuration will be decoded using
// struct tags and in the latter case only attributes are allowed and their
// values are decoded into the map.
//
// The given EvalContext is used to resolve any variables or functions in
// expressions encountered while decoding. This may be nil to require only
// constant values, for simple applications that do not support variables or
// functions.
//
// The returned diagnostics should be inspected with its HasErrors method to
// determine if the populated value is valid and complete. If error diagnostics
// are returned then the given value may have been partially-populated but
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func (d *Decoder) DecodeBody(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
rv := reflect.ValueOf(val)
if rv.Kind() != reflect.Ptr {
panic(fmt.Sprintf("target value must be a pointer, not %s", rv.Type().String()))
}

return decodeBodyToValue(body, ctx, rv.Elem())
return d.decodeBodyToValue(body, ctx, rv.Elem())
}

func decodeBodyToValue(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
func (d *Decoder) decodeBodyToValue(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
et := val.Type()
switch et.Kind() {
case reflect.Struct:
return decodeBodyToStruct(body, ctx, val)
return d.decodeBodyToStruct(body, ctx, val)
case reflect.Map:
return decodeBodyToMap(body, ctx, val)
return d.decodeBodyToMap(body, ctx, val)
default:
panic(fmt.Sprintf("target value must be pointer to struct or map, not %s", et.String()))
}
}

func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
func (d *Decoder) decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics {
schema, partial := ImpliedBodySchema(val.Interface())

var content *hcl.BodyContent
Expand All @@ -77,7 +126,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
fieldV.Set(reflect.ValueOf(body))

default:
diags = append(diags, decodeBodyToValue(body, ctx, fieldV)...)
diags = append(diags, d.decodeBodyToValue(body, ctx, fieldV)...)
}
}

Expand All @@ -95,7 +144,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
}
fieldV.Set(reflect.ValueOf(attrs))
default:
diags = append(diags, decodeBodyToValue(leftovers, ctx, fieldV)...)
diags = append(diags, d.decodeBodyToValue(leftovers, ctx, fieldV)...)
}
}

Expand Down Expand Up @@ -124,7 +173,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
case exprType.AssignableTo(field.Type):
fieldV.Set(reflect.ValueOf(attr.Expr))
default:
diags = append(diags, DecodeExpression(
diags = append(diags, d.DecodeExpression(
attr.Expr, ctx, fieldV.Addr().Interface(),
)...)
}
Expand All @@ -139,6 +188,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
ty := field.Type
isSlice := false
isPtr := false
isMap := false
if ty.Kind() == reflect.Slice {
isSlice = true
ty = ty.Elem()
Expand All @@ -147,8 +197,11 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
isPtr = true
ty = ty.Elem()
}
if ty.Kind() == reflect.Map {
isMap = true
}

if len(blocks) > 1 && !isSlice {
if len(blocks) > 1 && !isSlice && !(isMap && len(blocks[0].Labels) == 1) {
diags = append(diags, &hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: fmt.Sprintf("Duplicate %s block", typeName),
Expand All @@ -162,7 +215,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
}

if len(blocks) == 0 {
if isSlice || isPtr {
if isSlice || isPtr || isMap {
if val.Field(fieldIdx).IsNil() {
val.Field(fieldIdx).Set(reflect.Zero(field.Type))
}
Expand Down Expand Up @@ -198,13 +251,13 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
if v.IsNil() {
v = reflect.New(ty)
}
diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...)
diags = append(diags, d.decodeBlockToValue(block, ctx, v.Elem())...)
sli.Index(i).Set(v)
} else {
if i >= sli.Len() {
sli = reflect.Append(sli, reflect.Indirect(reflect.New(ty)))
}
diags = append(diags, decodeBlockToValue(block, ctx, sli.Index(i))...)
diags = append(diags, d.decodeBlockToValue(block, ctx, sli.Index(i))...)
}
}

Expand All @@ -213,28 +266,56 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value)
}

val.Field(fieldIdx).Set(sli)
case isMap && len(blocks[0].Labels) == 1:
v := val.Field(fieldIdx)
if v.IsNil() {
v.Set(reflect.MakeMap(ty))
}

for _, block := range blocks {
tyv := ty.Elem()
isPtr := false
if tyv.Kind() == reflect.Ptr {
isPtr = true
tyv = tyv.Elem()
}
ev := reflect.New(tyv)
diags = append(diags, d.decodeBodyToValue(block.Body, ctx, ev.Elem())...)

blockTags := getFieldTags(tyv)
lv := block.Labels[0]
lfieldIdx := blockTags.Labels[0].FieldIndex
f := ev.Elem().Field(lfieldIdx)
if f.Kind() == reflect.Ptr {
f.Set(reflect.ValueOf(&lv))
} else {
f.SetString(lv)
}

if !isPtr {
ev = ev.Elem()
}
v.SetMapIndex(reflect.ValueOf(lv), ev)
}
default:
block := blocks[0]
if isPtr {
v := val.Field(fieldIdx)
if v.IsNil() {
v = reflect.New(ty)
}
diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...)
diags = append(diags, d.decodeBlockToValue(block, ctx, v.Elem())...)
val.Field(fieldIdx).Set(v)
} else {
diags = append(diags, decodeBlockToValue(block, ctx, val.Field(fieldIdx))...)
diags = append(diags, d.decodeBlockToValue(block, ctx, val.Field(fieldIdx))...)
}

}

}

return diags
}

func decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
func (d *Decoder) decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
attrs, diags := body.JustAttributes()
if attrs == nil {
return diags
Expand All @@ -260,15 +341,38 @@ func decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.D
return diags
}

func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
diags := decodeBodyToValue(block.Body, ctx, v)
func (d *Decoder) decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics {
var diags hcl.Diagnostics

if len(block.Labels) > 0 {
blockTags := getFieldTags(v.Type())
for li, lv := range block.Labels {
lfieldIdx := blockTags.Labels[li].FieldIndex
v.Field(lfieldIdx).Set(reflect.ValueOf(lv))
ty := v.Type()

switch {
case blockType.AssignableTo(ty):
v.Elem().Set(reflect.ValueOf(block))
case bodyType.AssignableTo(ty):
v.Elem().Set(reflect.ValueOf(block.Body))
case attrsType.AssignableTo(ty):
attrs, attrsDiags := block.Body.JustAttributes()
if len(attrsDiags) > 0 {
diags = append(diags, attrsDiags...)
}
v.Elem().Set(reflect.ValueOf(attrs))
default:
diags = append(diags, d.decodeBodyToValue(block.Body, ctx, v)...)

if len(block.Labels) > 0 {
blockTags := getFieldTags(ty)
for li, lv := range block.Labels {
lfieldIdx := blockTags.Labels[li].FieldIndex
f := v.Field(lfieldIdx)
if f.Kind() == reflect.Ptr {
f.Set(reflect.ValueOf(&lv))
} else {
f.SetString(lv)
}
}
}

}

return diags
Expand All @@ -289,6 +393,28 @@ func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value)
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
return global.DecodeExpression(expr, ctx, val)
}

// DecodeExpression extracts the value of the given expression into the given
// value. This value must be something that gocty is able to decode into,
// since the final decoding is delegated to that package.
//
// The given EvalContext is used to resolve any variables or functions in
// expressions encountered while decoding. This may be nil to require only
// constant values, for simple applications that do not support variables or
// functions.
//
// The returned diagnostics should be inspected with its HasErrors method to
// determine if the populated value is valid and complete. If error diagnostics
// are returned then the given value may have been partially-populated but
// may still be accessed by a careful caller for static analysis and editor
// integration use-cases.
func (d *Decoder) DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics {
if diags, ok := d.decodeCustomExpression(expr, ctx, val); ok {
return diags
}

srcVal, diags := expr.Value(ctx)

convTy, err := gocty.ImpliedType(val)
Expand Down Expand Up @@ -321,3 +447,13 @@ func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}

return diags
}

func (d *Decoder) decodeCustomExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) (hcl.Diagnostics, bool) {
ty := reflect.TypeOf(val).Elem()
fn, ok := d.exprConvertors[ty]
if !ok {
return nil, false
}
diags := fn(expr, ctx, val)
return diags, true
}
Loading

0 comments on commit 58caf00

Please sign in to comment.