diff --git a/gohcl/decode.go b/gohcl/decode.go index 2d1776a3..53e654f6 100644 --- a/gohcl/decode.go +++ b/gohcl/decode.go @@ -14,6 +14,35 @@ import ( "github.com/zclconf/go-cty/cty/gocty" ) +// ExpressionDecoderFunc represents custom expression decoder for a specific type +type ExpressionDecoderFunc func(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics + +// BodyDecoderFunc represents custom body decoder for a specific type +type BodyDecoderFunc func(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics + +type Decoder struct { + exprConvertors map[reflect.Type]ExpressionDecoderFunc + bodyConvertors map[reflect.Type]BodyDecoderFunc +} + +var global = &Decoder{} + +// RegisterExpressionDecoder registers a custom expression decoder for a target type. +func (d *Decoder) RegisterExpressionDecoder(typ reflect.Type, fn ExpressionDecoderFunc) { + if d.exprConvertors == nil { + d.exprConvertors = map[reflect.Type]ExpressionDecoderFunc{} + } + d.exprConvertors[typ] = fn +} + +// RegisterBlockDecoder registers a custom block decoder for a target type. +func (d *Decoder) RegisterBlockDecoder(typ reflect.Type, fn BodyDecoderFunc) { + if d.bodyConvertors == nil { + d.bodyConvertors = map[reflect.Type]BodyDecoderFunc{} + } + d.bodyConvertors[typ] = fn +} + // DecodeBody extracts the configuration within the given body into the given // value. This value must be a non-nil pointer to either a struct or // a map, where in the former case the configuration will be decoded using @@ -31,27 +60,47 @@ import ( // may still be accessed by a careful caller for static analysis and editor // integration use-cases. func DecodeBody(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics { + return global.DecodeBody(body, ctx, val) +} + +// DecodeBody extracts the configuration within the given body into the given +// value. This value must be a non-nil pointer to either a struct or +// a map, where in the former case the configuration will be decoded using +// struct tags and in the latter case only attributes are allowed and their +// values are decoded into the map. +// +// The given EvalContext is used to resolve any variables or functions in +// expressions encountered while decoding. This may be nil to require only +// constant values, for simple applications that do not support variables or +// functions. +// +// The returned diagnostics should be inspected with its HasErrors method to +// determine if the populated value is valid and complete. If error diagnostics +// are returned then the given value may have been partially-populated but +// may still be accessed by a careful caller for static analysis and editor +// integration use-cases. +func (d *Decoder) DecodeBody(body hcl.Body, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics { rv := reflect.ValueOf(val) if rv.Kind() != reflect.Ptr { panic(fmt.Sprintf("target value must be a pointer, not %s", rv.Type().String())) } - return decodeBodyToValue(body, ctx, rv.Elem()) + return d.decodeBodyToValue(body, ctx, rv.Elem()) } -func decodeBodyToValue(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics { +func (d *Decoder) decodeBodyToValue(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics { et := val.Type() switch et.Kind() { case reflect.Struct: - return decodeBodyToStruct(body, ctx, val) + return d.decodeBodyToStruct(body, ctx, val) case reflect.Map: - return decodeBodyToMap(body, ctx, val) + return d.decodeBodyToMap(body, ctx, val) default: panic(fmt.Sprintf("target value must be pointer to struct or map, not %s", et.String())) } } -func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics { +func (d *Decoder) decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) hcl.Diagnostics { schema, partial := ImpliedBodySchema(val.Interface()) var content *hcl.BodyContent @@ -77,7 +126,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) fieldV.Set(reflect.ValueOf(body)) default: - diags = append(diags, decodeBodyToValue(body, ctx, fieldV)...) + diags = append(diags, d.decodeBodyToValue(body, ctx, fieldV)...) } } @@ -95,7 +144,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) } fieldV.Set(reflect.ValueOf(attrs)) default: - diags = append(diags, decodeBodyToValue(leftovers, ctx, fieldV)...) + diags = append(diags, d.decodeBodyToValue(leftovers, ctx, fieldV)...) } } @@ -124,7 +173,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) case exprType.AssignableTo(field.Type): fieldV.Set(reflect.ValueOf(attr.Expr)) default: - diags = append(diags, DecodeExpression( + diags = append(diags, d.DecodeExpression( attr.Expr, ctx, fieldV.Addr().Interface(), )...) } @@ -139,6 +188,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) ty := field.Type isSlice := false isPtr := false + isMap := false if ty.Kind() == reflect.Slice { isSlice = true ty = ty.Elem() @@ -147,8 +197,11 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) isPtr = true ty = ty.Elem() } + if ty.Kind() == reflect.Map { + isMap = true + } - if len(blocks) > 1 && !isSlice { + if len(blocks) > 1 && !isSlice && !(isMap && len(blocks[0].Labels) == 1) { diags = append(diags, &hcl.Diagnostic{ Severity: hcl.DiagError, Summary: fmt.Sprintf("Duplicate %s block", typeName), @@ -162,7 +215,7 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) } if len(blocks) == 0 { - if isSlice || isPtr { + if isSlice || isPtr || isMap { if val.Field(fieldIdx).IsNil() { val.Field(fieldIdx).Set(reflect.Zero(field.Type)) } @@ -198,13 +251,13 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) if v.IsNil() { v = reflect.New(ty) } - diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...) + diags = append(diags, d.decodeBlockToValue(block, ctx, v.Elem())...) sli.Index(i).Set(v) } else { if i >= sli.Len() { sli = reflect.Append(sli, reflect.Indirect(reflect.New(ty))) } - diags = append(diags, decodeBlockToValue(block, ctx, sli.Index(i))...) + diags = append(diags, d.decodeBlockToValue(block, ctx, sli.Index(i))...) } } @@ -213,7 +266,37 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) } val.Field(fieldIdx).Set(sli) + case isMap && len(blocks[0].Labels) == 1: + v := val.Field(fieldIdx) + if v.IsNil() { + v.Set(reflect.MakeMap(ty)) + } + + for _, block := range blocks { + tyv := ty.Elem() + isPtr := false + if tyv.Kind() == reflect.Ptr { + isPtr = true + tyv = tyv.Elem() + } + ev := reflect.New(tyv) + diags = append(diags, d.decodeBodyToValue(block.Body, ctx, ev.Elem())...) + + blockTags := getFieldTags(tyv) + lv := block.Labels[0] + lfieldIdx := blockTags.Labels[0].FieldIndex + f := ev.Elem().Field(lfieldIdx) + if f.Kind() == reflect.Ptr { + f.Set(reflect.ValueOf(&lv)) + } else { + f.SetString(lv) + } + if !isPtr { + ev = ev.Elem() + } + v.SetMapIndex(reflect.ValueOf(lv), ev) + } default: block := blocks[0] if isPtr { @@ -221,20 +304,18 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) if v.IsNil() { v = reflect.New(ty) } - diags = append(diags, decodeBlockToValue(block, ctx, v.Elem())...) + diags = append(diags, d.decodeBlockToValue(block, ctx, v.Elem())...) val.Field(fieldIdx).Set(v) } else { - diags = append(diags, decodeBlockToValue(block, ctx, val.Field(fieldIdx))...) + diags = append(diags, d.decodeBlockToValue(block, ctx, val.Field(fieldIdx))...) } - } - } return diags } -func decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics { +func (d *Decoder) decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics { attrs, diags := body.JustAttributes() if attrs == nil { return diags @@ -260,15 +341,38 @@ func decodeBodyToMap(body hcl.Body, ctx *hcl.EvalContext, v reflect.Value) hcl.D return diags } -func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics { - diags := decodeBodyToValue(block.Body, ctx, v) +func (d *Decoder) decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) hcl.Diagnostics { + var diags hcl.Diagnostics - if len(block.Labels) > 0 { - blockTags := getFieldTags(v.Type()) - for li, lv := range block.Labels { - lfieldIdx := blockTags.Labels[li].FieldIndex - v.Field(lfieldIdx).Set(reflect.ValueOf(lv)) + ty := v.Type() + + switch { + case blockType.AssignableTo(ty): + v.Elem().Set(reflect.ValueOf(block)) + case bodyType.AssignableTo(ty): + v.Elem().Set(reflect.ValueOf(block.Body)) + case attrsType.AssignableTo(ty): + attrs, attrsDiags := block.Body.JustAttributes() + if len(attrsDiags) > 0 { + diags = append(diags, attrsDiags...) } + v.Elem().Set(reflect.ValueOf(attrs)) + default: + diags = append(diags, d.decodeBodyToValue(block.Body, ctx, v)...) + + if len(block.Labels) > 0 { + blockTags := getFieldTags(ty) + for li, lv := range block.Labels { + lfieldIdx := blockTags.Labels[li].FieldIndex + f := v.Field(lfieldIdx) + if f.Kind() == reflect.Ptr { + f.Set(reflect.ValueOf(&lv)) + } else { + f.SetString(lv) + } + } + } + } return diags @@ -289,6 +393,28 @@ func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) // may still be accessed by a careful caller for static analysis and editor // integration use-cases. func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics { + return global.DecodeExpression(expr, ctx, val) +} + +// DecodeExpression extracts the value of the given expression into the given +// value. This value must be something that gocty is able to decode into, +// since the final decoding is delegated to that package. +// +// The given EvalContext is used to resolve any variables or functions in +// expressions encountered while decoding. This may be nil to require only +// constant values, for simple applications that do not support variables or +// functions. +// +// The returned diagnostics should be inspected with its HasErrors method to +// determine if the populated value is valid and complete. If error diagnostics +// are returned then the given value may have been partially-populated but +// may still be accessed by a careful caller for static analysis and editor +// integration use-cases. +func (d *Decoder) DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics { + if diags, ok := d.decodeCustomExpression(expr, ctx, val); ok { + return diags + } + srcVal, diags := expr.Value(ctx) convTy, err := gocty.ImpliedType(val) @@ -321,3 +447,13 @@ func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{} return diags } + +func (d *Decoder) decodeCustomExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) (hcl.Diagnostics, bool) { + ty := reflect.TypeOf(val).Elem() + fn, ok := d.exprConvertors[ty] + if !ok { + return nil, false + } + diags := fn(expr, ctx, val) + return diags, true +} diff --git a/gohcl/decode_test.go b/gohcl/decode_test.go index 1ac5d049..787fb9a5 100644 --- a/gohcl/decode_test.go +++ b/gohcl/decode_test.go @@ -429,6 +429,27 @@ func TestDecodeBody(t *testing.T) { }, 2, }, + { + map[string]interface{}{ + "noodle": map[string]interface{}{ + "sample_label": map[string]interface{}{}, + }, + }, + makeInstantiateType(struct { + Noodle struct { + Name *string `hcl:"name,label"` + } `hcl:"noodle,block"` + }{}), + func(gotI interface{}) bool { + noodle := gotI.(struct { + Noodle struct { + Name *string `hcl:"name,label"` + } `hcl:"noodle,block"` + }).Noodle + return noodle.Name != nil && *noodle.Name == "sample_label" + }, + 0, + }, { map[string]interface{}{ "noodle": map[string]interface{}{ @@ -490,6 +511,68 @@ func TestDecodeBody(t *testing.T) { }, 0, }, + { + map[string]interface{}{ + "noodle": map[string]interface{}{ + "foo_foo": map[string]interface{}{}, + "bar_baz": map[string]interface{}{}, + }, + }, + makeInstantiateType(struct { + Noodles map[string]struct { + Name string `hcl:"name,label"` + } `hcl:"noodle,block"` + }{}), + func(gotI interface{}) bool { + noodles := gotI.(struct { + Noodles map[string]struct { + Name string `hcl:"name,label"` + } `hcl:"noodle,block"` + }).Noodles + if len(noodles) != 2 { + return false + } + if _, ok := noodles["foo_foo"]; !ok { + return false + } + if _, ok := noodles["bar_baz"]; !ok { + return false + } + return true + }, + 0, + }, + { + map[string]interface{}{ + "noodle": map[string]interface{}{ + "foo_foo": map[string]interface{}{}, + "bar_baz": map[string]interface{}{}, + }, + }, + makeInstantiateType(struct { + Noodles map[string]*struct { + Name string `hcl:"name,label"` + } `hcl:"noodle,block"` + }{}), + func(gotI interface{}) bool { + noodles := gotI.(struct { + Noodles map[string]*struct { + Name string `hcl:"name,label"` + } `hcl:"noodle,block"` + }).Noodles + if len(noodles) != 2 { + return false + } + if _, ok := noodles["foo_foo"]; !ok { + return false + } + if _, ok := noodles["bar_baz"]; !ok { + return false + } + return true + }, + 0, + }, { map[string]interface{}{ "noodle": map[string]interface{}{ diff --git a/gohcl/schema.go b/gohcl/schema.go index 0cdca271..2c970cf2 100644 --- a/gohcl/schema.go +++ b/gohcl/schema.go @@ -83,19 +83,40 @@ func ImpliedBodySchema(val interface{}) (schema *hcl.BodySchema, partial bool) { if fty.Kind() == reflect.Ptr { fty = fty.Elem() } - if fty.Kind() != reflect.Struct { + + var labelNames []string + + switch fty.Kind() { + case reflect.Struct: + ftags := getFieldTags(fty) + if len(ftags.Labels) > 0 { + labelNames = make([]string, len(ftags.Labels)) + for i, l := range ftags.Labels { + labelNames[i] = l.Name + } + } + case reflect.Map: + fme := fty.Elem() + if fme.Kind() == reflect.Slice { + fme = fme.Elem() + } + if fme.Kind() == reflect.Ptr { + fme = fme.Elem() + } + if fme.Kind() == reflect.Struct { + ftags := getFieldTags(fme) + if len(ftags.Labels) > 0 { + labelNames = make([]string, len(ftags.Labels)) + for i, l := range ftags.Labels { + labelNames[i] = l.Name + } + } + } + default: panic(fmt.Sprintf( "hcl 'block' tag kind cannot be applied to %s field %s: struct required", field.Type.String(), field.Name, )) } - ftags := getFieldTags(fty) - var labelNames []string - if len(ftags.Labels) > 0 { - labelNames = make([]string, len(ftags.Labels)) - for i, l := range ftags.Labels { - labelNames[i] = l.Name - } - } blockSchemas = append(blockSchemas, hcl.BlockHeaderSchema{ Type: n,