diff --git a/core/core.go b/core/core.go index 6f9abf6..17c8119 100644 --- a/core/core.go +++ b/core/core.go @@ -2,9 +2,11 @@ package core import ( "context" + "errors" "fmt" "mime" "net/http" + "reflect" "sort" "sync" @@ -30,7 +32,7 @@ type Core struct { // is responsible for both: // // - decoding an HTTP request to an instance of the inputStruct; -// - and encoding an instance of the inputStruct to an HTTP request. +// - encoding an instance of the inputStruct to an HTTP request. func New(inputStruct any, opts ...Option) (*Core, error) { resolver, err := buildResolver(inputStruct) if err != nil { @@ -52,55 +54,59 @@ func New(inputStruct any, opts ...Option) (*Core, error) { for _, opt := range allOptions { if err := opt(core); err != nil { - return nil, fmt.Errorf("httpin: invalid option: %w", err) + return nil, fmt.Errorf("invalid option: %w", err) } } return core, nil } -// Decode decodes an HTTP request to a struct instance. -// The return value is a pointer to the input struct. -// For example: +// Decode decodes an HTTP request to a struct instance. The return value is a +// pointer to the input struct. For example: // -// New(&Input{}).Decode(req) -> *Input // New(Input{}).Decode(req) -> *Input func (c *Core) Decode(req *http.Request) (any, error) { - var err error - ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type")) - if ct == "multipart/form-data" { - err = req.ParseMultipartForm(c.maxMemory) + // Create the input struct instance. Used to be created by owl.Resolve(). + value := reflect.New(c.resolver.Type).Interface() + if err := c.DecodeTo(req, value); err != nil { + return nil, err } else { - err = req.ParseForm() + return value, nil } - if err != nil { - return nil, err +} + +// DecodeTo decodes an HTTP request to the given value. The value must be a pointer +// to the struct instance of the type that the Core instance holds. +func (c *Core) DecodeTo(req *http.Request, value any) (err error) { + if err = c.parseRequestForm(req); err != nil { + return fmt.Errorf("failed to parse request form: %w", err) } - rv, err := c.resolver.Resolve( + err = c.resolver.ResolveTo( + value, owl.WithNamespace(decoderNamespace), owl.WithValue(CtxRequest, req), owl.WithNestedDirectivesEnabled(c.enableNestedDirectives), ) - if err != nil { - return nil, NewInvalidFieldError(err) + if err != nil && !errors.Is(err, owl.ErrInvalidResolveTarget) { + return NewInvalidFieldError(err) } - return rv.Interface(), nil + return err } -// NewRequest wraps NewRequestWithContext using context.Background. +// NewRequest wraps NewRequestWithContext using context.Background(), see +// NewRequestWithContext. func (c *Core) NewRequest(method string, url string, input any) (*http.Request, error) { return c.NewRequestWithContext(context.Background(), method, url, input) } -// NewRequestWithContext returns a new http.Request given a method, url and an -// input struct instance. Note that the Core instance is bound to a specific -// type of struct. Which means when the given input is not the type of the -// struct that the Core instance holds, error of type mismatch will be returned. -// In order to avoid this error, you can always use httpin.NewRequest() function -// instead. Which will create a Core instance for you when needed. There's no -// performance penalty for doing so. Because there's a cache layer for all the -// Core instances. +// NewRequestWithContext turns the given input struct into an HTTP request. Note +// that the Core instance is bound to a specific type of struct. Which means +// when the given input is not the type of the struct that the Core instance +// holds, error of type mismatch will be returned. In order to avoid this error, +// you can always use httpin.NewRequest() instead. Which will create a Core +// instance for you on demand. There's no performance penalty for doing so. +// Because there's a cache layer for all the Core instances. func (c *Core) NewRequestWithContext(ctx context.Context, method string, url string, input any) (*http.Request, error) { c.prepareScanResolver() req, err := http.NewRequestWithContext(ctx, method, url, nil) @@ -168,6 +174,16 @@ func (c *Core) prepareScanResolver() { } } +func (c *Core) parseRequestForm(req *http.Request) (err error) { + ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type")) + if ct == "multipart/form-data" { + err = req.ParseMultipartForm(c.maxMemory) + } else { + err = req.ParseForm() + } + return +} + // buildResolver builds a resolver for the inputStruct. It will run normalizations // on the resolver and cache it. func buildResolver(inputStruct any) (*owl.Resolver, error) { diff --git a/core/registry.go b/core/registry.go index 8c32510..5b0765b 100644 --- a/core/registry.go +++ b/core/registry.go @@ -14,13 +14,19 @@ var ( namedStringableAdaptors = make(map[string]*NamedAnyStringableAdaptor) ) -// RegisterCoder registers a custom stringable adaptor for the given type T. -// When a field of type T is encountered, the adaptor will be used to convert -// the value to a Stringable, which will be used to convert the value from/to string. +// RegisterCoder registers a custom coder for the given type T. When a field of +// type T is encountered, this coder will be used to convert the value to a +// Stringable, which will be used to convert the value from/to string. // -// NOTE: this function is designed to override the default Stringable adaptors that -// are registered by this package. For example, if you want to override the defualt -// behaviour of converting a bool value from/to string, you can do this: +// NOTE: this function is designed to override the default Stringable adaptors +// that are registered by this package. For example, if you want to override the +// defualt behaviour of converting a bool value from/to string, you can do this: +// +// func init() { +// core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) { +// return (*YesNo)(b), nil +// }) +// } // // type YesNo bool // @@ -42,25 +48,18 @@ var ( // } // return nil // } -// -// func init() { -// core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) { -// return (*YesNo)(b), nil -// }) -// } func RegisterCoder[T any](adapt func(*T) (Stringable, error)) { customStringableAdaptors[internal.TypeOf[T]()] = internal.NewAnyStringableAdaptor[T](adapt) } -// RegisterNamedCoder works similar to RegisterType, except that it binds the adaptor to a name. -// This is useful when you only want to override the types in a specific struct. -// You will be using the "encoder" and "decoder" directives to specify the name of the adaptor. -// -// For example: +// RegisterNamedCoder works similar to RegisterCoder, except that it binds the +// coder to a name. This is useful when you only want to override the types in +// a specific struct field. You will be using the "coder" or "decoder" directive +// to specify the name of the coder to use. For example: // // type MyStruct struct { -// Bool bool // this field will be encoded/decoded using the default bool coder -// YesNo bool `in:"encoder=yesno,decoder=yesno"` // this field will be encoded/decoded using the YesNo coder +// Bool bool // use default bool coder +// YesNo bool `in:"coder=yesno"` // use YesNo coder // } // // func init() { @@ -68,6 +67,27 @@ func RegisterCoder[T any](adapt func(*T) (Stringable, error)) { // return (*YesNo)(b), nil // }) // } +// +// type YesNo bool +// +// func (yn YesNo) String() string { +// if yn { +// return "yes" +// } +// return "no" +// } +// +// func (yn *YesNo) FromString(s string) error { +// switch s { +// case "yes": +// *yn = true +// case "no": +// *yn = false +// default: +// return fmt.Errorf("invalid YesNo value: %q", s) +// } +// return nil +// } func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) { namedStringableAdaptors[name] = &NamedAnyStringableAdaptor{ Name: name, @@ -76,8 +96,9 @@ func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) } } -// RegisterFileCoder registers the given type T as a file type. T must implement the Fileable interface. -// Remember if you don't register the type explicitly, it won't be recognized as a file type. +// RegisterFileCoder registers the given type T as a file type. T must implement +// the Fileable interface. Remember if you don't register the type explicitly, +// it won't be recognized as a file type. func RegisterFileCoder[T Fileable]() error { fileTypes[internal.TypeOf[T]()] = struct{}{} return nil diff --git a/go.mod b/go.mod index a4e6eb7..b2364b0 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/ggicci/httpin go 1.20 require ( - github.com/ggicci/owl v0.8.0 + github.com/ggicci/owl v0.8.2 github.com/go-chi/chi/v5 v5.0.11 github.com/gorilla/mux v1.8.1 github.com/justinas/alice v1.2.0 diff --git a/go.sum b/go.sum index ca319ac..dee6dcd 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ggicci/owl v0.8.0 h1:PCueAADCWwuW2jv7fvp40eNjvrv3se/Rhkb+Ah6MPbM= github.com/ggicci/owl v0.8.0/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= +github.com/ggicci/owl v0.8.1 h1:vppxAqpNOYBdrPKpcq7lzLy40UmSMr8Oz+h2EsJVgew= +github.com/ggicci/owl v0.8.1/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= +github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA= +github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= diff --git a/httpin.go b/httpin.go index efabee4..55210b7 100644 --- a/httpin.go +++ b/httpin.go @@ -5,7 +5,7 @@ package httpin import ( "context" - "fmt" + "errors" "io" "net/http" "reflect" @@ -68,42 +68,54 @@ func UploadStream(r io.ReadCloser) *File { return core.UploadStream(r) } -// Decode decodes an HTTP request to the given input struct. The input must be a -// pointer to a struct instance. For example: +// DecodeTo decodes an HTTP request to the given input struct. The input must be +// a pointer (no matter how many levels) to a struct instance. For example: // // input := &InputStruct{} -// if err := Decode(req, &input); err != nil { ... } +// if err := DecodeTo(req, input); err != nil { ... } // // input is now populated with data from the request. -func Decode(req *http.Request, input any, opts ...core.Option) error { - originalType := reflect.TypeOf(input) - if originalType.Kind() != reflect.Ptr { - return fmt.Errorf("httpin: input must be a pointer") - } - co, err := New(originalType.Elem(), opts...) +func DecodeTo(req *http.Request, input any, opts ...core.Option) error { + co, err := New(internal.DereferencedType(input), opts...) if err != nil { return err } - if value, err := co.Decode(req); err != nil { - return err + return co.DecodeTo(req, input) +} + +// Decode decodes an HTTP request to a struct instance. The return value is a +// pointer to the input struct. For example: +// +// if user, err := Decode[User](req); err != nil { ... } +// // now user is a *User instance, which has been populated with data from the request. +func Decode[T any](req *http.Request, opts ...core.Option) (*T, error) { + rt := internal.TypeOf[T]() + if rt.Kind() != reflect.Struct { + return nil, errors.New("generic type T must be a struct type") + } + co, err := New(rt, opts...) + if err != nil { + return nil, err + } + if v, err := co.Decode(req); err != nil { + return nil, err } else { - if originalType.Elem().Kind() == reflect.Ptr { - reflect.ValueOf(input).Elem().Set(reflect.ValueOf(value)) - } else { - reflect.ValueOf(input).Elem().Set(reflect.ValueOf(value).Elem()) - } - return nil + return v.(*T), nil } } -// NewRequest wraps NewRequestWithContext using context.Background. +// NewRequest wraps NewRequestWithContext using context.Background(), see NewRequestWithContext. func NewRequest(method, url string, input any, opts ...core.Option) (*http.Request, error) { return NewRequestWithContext(context.Background(), method, url, input) } -// NewRequestWithContext returns a new http.Request given a method, url and an -// input struct instance. The fields of the input struct will be encoded to the -// request by resolving the "in" tags and executing the directives. +// NewRequestWithContext turns the given input struct into an HTTP request. The +// input struct with the "in" tags defines how to bind the data from the struct +// to the HTTP request. Use it as the replacement of http.NewRequest(). +// +// addUserPayload := &AddUserRequest{...} +// addUserRequest, err := NewRequestWithContext(context.Background(), "GET", "http://example.com", addUserPayload) +// http.DefaultClient.Do(addUserRequest) func NewRequestWithContext(ctx context.Context, method, url string, input any, opts ...core.Option) (*http.Request, error) { co, err := New(input, opts...) if err != nil { diff --git a/httpin_test.go b/httpin_test.go index efa48b9..8d9865a 100644 --- a/httpin_test.go +++ b/httpin_test.go @@ -20,57 +20,79 @@ type Pagination struct { PerPage int `in:"form=per_page,page_size"` } -func TestDecode(t *testing.T) { +func testcasePagination1100() (*http.Request, *Pagination) { r, _ := http.NewRequest("GET", "/", nil) r.Form = url.Values{ "page": {"1"}, "per_page": {"100"}, } - expected := &Pagination{ + return r, &Pagination{ Page: 1, PerPage: 100, } +} + +func TestDecodeTo(t *testing.T) { + r, expected := testcasePagination1100() func() { input := &Pagination{} - err := Decode(r, input) // pointer to a struct instance + err := DecodeTo(r, input) // pointer to a struct instance assert.NoError(t, err) assert.Equal(t, expected, input) }() func() { input := Pagination{} - err := Decode(r, &input) // addressable struct instance + err := DecodeTo(r, &input) // addressable struct instance assert.NoError(t, err) assert.Equal(t, expected, &input) }() func() { input := &Pagination{} - err := Decode(r, &input) // pointer to pointer of struct instance + err := DecodeTo(r, &input) // pointer to pointer of struct instance assert.NoError(t, err) assert.Equal(t, expected, input) }() func() { input := Pagination{} - err := Decode(r, input) // non-pointer struct instance should fail - assert.ErrorContains(t, err, "input must be a pointer") + err := DecodeTo(r, input) // non-pointer struct instance should fail + assert.ErrorContains(t, err, "invalid resolve target") }() } +func TestDecode(t *testing.T) { + r, expected := testcasePagination1100() + + p, err := Decode[Pagination](r) + assert.NoError(t, err) + assert.Equal(t, expected, p) +} + +func TestDecode_ErrNotAStruct(t *testing.T) { + r, _ := testcasePagination1100() + + _, err := Decode[int](r) + assert.ErrorContains(t, err, "T must be a struct type") + + _, err = Decode[*Pagination](r) + assert.ErrorContains(t, err, "T must be a struct type") +} + func TestDecode_ErrBuildResolverFailed(t *testing.T) { - r, _ := http.NewRequest("GET", "/", nil) - r.Form = url.Values{ - "page": {"1"}, - "per_page": {"100"}, - } + r, _ := testcasePagination1100() type Foo struct { Name string `in:"nonexistent=foo"` } - assert.Error(t, Decode(r, &Foo{})) + assert.Error(t, DecodeTo(r, &Foo{})) + + v, err := Decode[Foo](r) + assert.Nil(t, v) + assert.Error(t, err) } func TestDecode_ErrDecodeFailure(t *testing.T) { @@ -81,7 +103,11 @@ func TestDecode_ErrDecodeFailure(t *testing.T) { } p := &Pagination{} - assert.Error(t, Decode(r, p)) + assert.Error(t, DecodeTo(r, p)) + + v, err := Decode[Pagination](r) + assert.Nil(t, v) + assert.Error(t, err) } type EchoInput struct { diff --git a/internal/misc.go b/internal/misc.go index 83f67a0..5e6a401 100644 --- a/internal/misc.go +++ b/internal/misc.go @@ -7,7 +7,7 @@ import ( func IsNil(value reflect.Value) bool { switch value.Kind() { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice: + case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Interface, reflect.Slice: return value.IsNil() default: return false @@ -30,3 +30,12 @@ func TypeOf[T any]() reflect.Type { func Pointerize[T any](v T) *T { return &v } + +// DereferencedType returns the underlying type of a pointer. +func DereferencedType(v any) reflect.Type { + rv := reflect.ValueOf(v) + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + return rv.Type() +} diff --git a/internal/misc_test.go b/internal/misc_test.go index 164f7fc..6348622 100644 --- a/internal/misc_test.go +++ b/internal/misc_test.go @@ -27,3 +27,15 @@ func TestTypeOf(t *testing.T) { func TestPointerize(t *testing.T) { assert.Equal(t, 102, *Pointerize[int](102)) } + +func TestDereferencedType(t *testing.T) { + type Object struct{} + + var o = new(Object) + var po = &o + var ppo = &po + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(Object{})) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(o)) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(po)) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(ppo)) +}