diff --git a/binding/binding.go b/binding/binding.go new file mode 100644 index 0000000..744a082 --- /dev/null +++ b/binding/binding.go @@ -0,0 +1,83 @@ +package binding + +import ( + "errors" + "net/http" + "reflect" + "strconv" +) + +type FormBinding struct{} + +func (FormBinding) Bind(r *http.Request, obj any) error { + if r == nil { + return errors.New("request is nil") + } + + if err := r.ParseForm(); err != nil { + return err + } + + if r.Form == nil { + return errors.New("form is nil") + } + + return mapForm(obj, r.Form) +} + +func mapForm(obj any, form map[string][]string) error { + val := reflect.ValueOf(obj) + if val.Kind() != reflect.Ptr || val.IsNil() { + return errors.New("obj must be a non-nil pointer") + } + val = val.Elem() + if val.Kind() != reflect.Struct { + return errors.New("obj must be a pointer to a struct") + } + + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + formTag := fieldType.Tag.Get("form") + + if formTag == "" { + formTag = fieldType.Name + } + + if values, ok := form[formTag]; ok && len(values) > 0 { + value := values[0] + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + intValue, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + field.SetInt(intValue) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uintValue, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + field.SetUint(uintValue) + case reflect.Float32, reflect.Float64: + floatValue, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + field.SetFloat(floatValue) + case reflect.Bool: + boolValue, err := strconv.ParseBool(value) + if err != nil { + return err + } + field.SetBool(boolValue) + default: + return errors.New("unsupported field type") + } + } + } + return nil +} diff --git a/binding/binding_test.go b/binding/binding_test.go new file mode 100644 index 0000000..aca7f00 --- /dev/null +++ b/binding/binding_test.go @@ -0,0 +1,64 @@ +package binding + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func Test_mapForm(t *testing.T) { + var someStruct struct { + StringType string `form:"stringtype"` + IntType int `form:"inttype"` + Int8Type int8 `form:"int8type"` + Int16Type int16 `form:"int16type"` + Int32Type int32 `form:"int32type"` + Int64Type int64 `form:"int64type"` + UintType uint `form:"uinttype"` + Uint8Type uint8 `form:"uint8type"` + Uint16Type uint16 `form:"uint16type"` + Uint32Type uint32 `form:"uint32type"` + Uint64Type uint64 `form:"uint64type"` + Float32Type float32 `form:"float32type"` + Float64Type float64 `form:"float64type"` + BoolType bool `form:"booltype"` + } + + formData := map[string][]string{ + "stringtype": {"stringType"}, + "inttype": {"-2147483647"}, + "int8type": {"-127"}, + "int16type": {"-32767"}, + "int32type": {"-2147483647"}, + "int64type": {"-9223372036854775807"}, + "uinttype": {"4294967295"}, + "uint8type": {"255"}, + "uint16type": {"65535"}, + "uint32type": {"4294967295"}, + "uint64type": {"18446744073709551615"}, + "float32type": { + "3.1415927", + }, + "float64type": { + "3.141592653589793", + }, + "booltype": {"true"}, + } + + err := mapForm(&someStruct, formData) + require.NoError(t, err) + require.Equal(t, "stringType", someStruct.StringType) + require.Equal(t, int(-2147483647), someStruct.IntType) + require.Equal(t, int8(-127), someStruct.Int8Type) + require.Equal(t, int16(-32767), someStruct.Int16Type) + require.Equal(t, int32(-2147483647), someStruct.Int32Type) + require.Equal(t, int64(-9223372036854775807), someStruct.Int64Type) + require.Equal(t, uint(4294967295), someStruct.UintType) + require.Equal(t, uint8(255), someStruct.Uint8Type) + require.Equal(t, uint16(65535), someStruct.Uint16Type) + require.Equal(t, uint32(4294967295), someStruct.Uint32Type) + require.Equal(t, uint64(18446744073709551615), someStruct.Uint64Type) + require.Equal(t, float32(3.1415927), someStruct.Float32Type) + require.Equal(t, float64(3.141592653589793), someStruct.Float64Type) + require.Equal(t, true, someStruct.BoolType) + t.Log(someStruct) +}