Skip to content

Commit

Permalink
WIP: make variant
Browse files Browse the repository at this point in the history
  • Loading branch information
matthchr committed Jan 30, 2024
1 parent d56e037 commit a334a78
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 20 deletions.
67 changes: 47 additions & 20 deletions make.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ package rapid
import (
"fmt"
"reflect"
"strings"
)

// Make creates a generator of values of type V, using reflection to infer the required structure.
func Make[V any]() *Generator[V] {
var zero V
gen := newMakeGen(reflect.TypeOf(zero))
gen := newMakeGen(reflect.TypeOf(zero), nil)
return newGenerator[V](&makeGen[V]{
gen: gen,
})
Expand All @@ -33,8 +34,8 @@ func (g *makeGen[V]) value(t *T) V {
return g.gen.value(t).(V)
}

func newMakeGen(typ reflect.Type) *Generator[any] {
gen, mayNeedCast := newMakeKindGen(typ)
func newMakeGen(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
gen, mayNeedCast := newMakeKindGen(typ, overrides)
if !mayNeedCast || typ.String() == typ.Kind().String() {
return gen // fast path with less reflect
}
Expand All @@ -55,7 +56,33 @@ func (g *castGen) value(t *T) any {
return reflect.ValueOf(v).Convert(g.typ).Interface()
}

func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) {
func makeLinkName(t reflect.Type) string {
path := t.PkgPath()
name := t.Name()
if len(path) == 0 && len(name) == 0 {
return ""
}

if len(path) == 0 {
return name
}

return fmt.Sprintf("%s.%s", t.PkgPath(), t.Name())
}

func newMakeKindGen(typ reflect.Type, overrides []*Generator[any]) (gen *Generator[any], mayNeedCast bool) {
// First, check if any overrides apply
for _, override := range overrides {
// My kingdom for https://github.com/golang/go/issues/54393
tt := reflect.TypeOf(override.impl).Elem()

// Types that are parameterized generically expose "link name" (see https://github.com/golang/go/issues/55924)
target := fmt.Sprintf("[%s]", makeLinkName(typ))
if strings.Contains(tt.Name(), target) {
return override, false // TODO: No idea if false is right here
}
}

switch typ.Kind() {
case reflect.Bool:
return Bool().AsAny(), true
Expand Down Expand Up @@ -86,25 +113,25 @@ func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) {
case reflect.Float64:
return Float64().AsAny(), true
case reflect.Array:
return genAnyArray(typ), false
return genAnyArray(typ, overrides), false
case reflect.Map:
return genAnyMap(typ), false
return genAnyMap(typ, overrides), false
case reflect.Pointer:
return Deferred(func() *Generator[any] { return genAnyPointer(typ) }), false
return Deferred(func() *Generator[any] { return genAnyPointer(typ, overrides) }), false
case reflect.Slice:
return genAnySlice(typ), false
return genAnySlice(typ, overrides), false
case reflect.String:
return String().AsAny(), true
case reflect.Struct:
return genAnyStruct(typ), false
return genAnyStruct(typ, overrides), false
default:
panic(fmt.Sprintf("unsupported type kind for Make: %v", typ.Kind()))
}
}

func genAnyPointer(typ reflect.Type) *Generator[any] {
func genAnyPointer(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
elem := typ.Elem()
elemGen := newMakeGen(elem)
elemGen := newMakeGen(elem, overrides)
const pNonNil = 0.5

return Custom[any](func(t *T) any {
Expand All @@ -119,9 +146,9 @@ func genAnyPointer(typ reflect.Type) *Generator[any] {
})
}

func genAnyArray(typ reflect.Type) *Generator[any] {
func genAnyArray(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
count := typ.Len()
elemGen := newMakeGen(typ.Elem())
elemGen := newMakeGen(typ.Elem(), overrides)

return Custom[any](func(t *T) any {
a := reflect.Indirect(reflect.New(typ))
Expand All @@ -137,8 +164,8 @@ func genAnyArray(typ reflect.Type) *Generator[any] {
})
}

func genAnySlice(typ reflect.Type) *Generator[any] {
elemGen := newMakeGen(typ.Elem())
func genAnySlice(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
elemGen := newMakeGen(typ.Elem(), overrides)

return Custom[any](func(t *T) any {
repeat := newRepeat(-1, -1, -1, elemGen.String())
Expand All @@ -151,9 +178,9 @@ func genAnySlice(typ reflect.Type) *Generator[any] {
})
}

func genAnyMap(typ reflect.Type) *Generator[any] {
keyGen := newMakeGen(typ.Key())
valGen := newMakeGen(typ.Elem())
func genAnyMap(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
keyGen := newMakeGen(typ.Key(), overrides)
valGen := newMakeGen(typ.Elem(), overrides)

return Custom[any](func(t *T) any {
label := keyGen.String() + "," + valGen.String()
Expand All @@ -172,11 +199,11 @@ func genAnyMap(typ reflect.Type) *Generator[any] {
})
}

func genAnyStruct(typ reflect.Type) *Generator[any] {
func genAnyStruct(typ reflect.Type, overrides []*Generator[any]) *Generator[any] {
numFields := typ.NumField()
fieldGens := make([]*Generator[any], numFields)
for i := 0; i < numFields; i++ {
fieldGens[i] = newMakeGen(typ.Field(i).Type)
fieldGens[i] = newMakeGen(typ.Field(i).Type, overrides)
}

return Custom[any](func(t *T) any {
Expand Down
22 changes: 22 additions & 0 deletions make_variant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2022 Gregory Petrosyan <gregory.petrosyan@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package rapid

import (
"reflect"
)

// https://stackoverflow.com/questions/73864711/get-type-parameter-from-a-generic-struct-using-reflection

// MakeVariant creates a generator of values of type V, using reflection to infer the required structure.
func MakeVariant[V any](overrides ...*Generator[any]) *Generator[V] {
var zero V
gen := newMakeGen(reflect.TypeOf(zero), overrides)
return newGenerator[V](&makeGen[V]{
gen: gen,
})
}
50 changes: 50 additions & 0 deletions make_variant_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2022 Gregory Petrosyan <gregory.petrosyan@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package rapid_test

import (
"fmt"
"testing"
"time"

"pgregory.net/rapid"
)

type S struct {
F1 string
Int int
T time.Time
TPtr *time.Time
}

func (s S) String() string {
return fmt.Sprintf("%s, %d, %s", s.F1, s.Int, s.T.String())
}

func Test(t *testing.T) {
now := time.Now()

strGen := rapid.Just("Hello")
intGen := rapid.IntRange(0, 100)
timeGen := rapid.Just(now)
sGen := rapid.MakeVariant[S](strGen.AsAny(), intGen.AsAny(), timeGen.AsAny())
s := sGen.Example(1)

if s.F1 != "Hello" {
t.Errorf("Unexpected string value")
}
if s.Int > 100 || s.Int < 0 {
t.Errorf("Unexpected int value")
}
if !s.T.Equal(now) {
t.Errorf("Unexpected time.Time value")
}
if s.TPtr != nil && !s.TPtr.Equal(now) {
t.Errorf("Unexpected time.Time ptr value")
}
fmt.Println(s.String())
}

0 comments on commit a334a78

Please sign in to comment.