diff --git a/binding.go b/binding.go new file mode 100644 index 0000000..7a3ca44 --- /dev/null +++ b/binding.go @@ -0,0 +1,129 @@ +package ron + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "strconv" + "time" +) + +func (c *Context) BindJSON(v any) error { + if c.R.Header.Get("Content-Type") != "application/json" { + return http.ErrNotSupported + } + decoder := json.NewDecoder(c.R.Body) + return decoder.Decode(v) +} + +func (c *Context) BindForm(v interface{}) error { + if err := c.R.ParseForm(); err != nil { + return err + } + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.Elem().Kind() != reflect.Struct { + return errors.New("v must be a pointer to a struct") + } + + return mapForm(v, c.R.Form) +} + +func mapForm(ptr interface{}, form map[string][]string) error { + val := reflect.ValueOf(ptr).Elem() + typ := val.Type() + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + structField := typ.Field(i) + + if !field.CanSet() { + continue + } + + tag := structField.Tag.Get("form") + if tag == "" { + tag = structField.Name + } + + if field.Kind() == reflect.Struct && structField.Anonymous { + if err := mapForm(field.Addr().Interface(), form); err != nil { + return err + } + continue + } + + if values, ok := form[tag]; ok && len(values) > 0 { + if field.Kind() == reflect.Slice { + elemType := field.Type().Elem() + slice := reflect.MakeSlice(field.Type(), len(values), len(values)) + for i, v := range values { + elem := slice.Index(i) + if elem.Kind() == reflect.Ptr { + elem.Set(reflect.New(elemType.Elem())) + elem = elem.Elem() + } + if err := setField(elem, v); err != nil { + return err + } + } + field.Set(slice) + } else { + if err := setField(field, values[0]); err != nil { + return err + } + } + } + } + return nil +} + +func setField(field reflect.Value, value string) error { + if !field.CanSet() { + return nil + } + + if field.Type() == reflect.TypeOf(time.Time{}) { + t, err := time.Parse(time.RFC3339, value) + if err != nil { + return err + } + field.Set(reflect.ValueOf(t)) + return nil + } + + kind := field.Kind() + switch kind { + case reflect.String: + field.SetString(value) + case reflect.Int: + intValue, err := strconv.ParseInt(value, 10, field.Type().Bits()) + if err != nil { + return err + } + field.SetInt(intValue) + case reflect.Uint: + uintValue, err := strconv.ParseUint(value, 10, field.Type().Bits()) + if err != nil { + return err + } + field.SetUint(uintValue) + case reflect.Float64: + floatValue, err := strconv.ParseFloat(value, field.Type().Bits()) + if err != nil { + return err + } + field.SetFloat(floatValue) + case reflect.Bool: + boolValue, err := strconv.ParseBool(value) + if err != nil { + return err + } + field.SetBool(boolValue) + default: + return fmt.Errorf("unsupported type: %s", kind) + } + return nil +} diff --git a/binding_test.go b/binding_test.go new file mode 100644 index 0000000..19d2d06 --- /dev/null +++ b/binding_test.go @@ -0,0 +1,106 @@ +package ron + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +type FooBody struct { + FString string `json:"fstring" form:"fstring"` + FInt int `json:"fint" form:"fint"` + FUint uint `json:"fuint" form:"fuint"` + FFloat64 float64 `json:"ffloat64" form:"ffloat64"` + FBool bool `json:"fbool" form:"fbool"` + FTime time.Time `json:"ftime" form:"ftime"` + + FStringSlice []string `json:"fstring_slice" form:"fstring_slice"` + FIntSlice []int `json:"fint_slice" form:"fint_slice"` + FBoolSlice []bool `json:"fbool_slice" form:"fbool_slice"` + + FNone string + + fPrivate string +} + +func expectedStruct() (FooBody, time.Time) { + actualTime := time.Now().Round(time.Second).UTC() + fooBody := FooBody{ + FString: "string", + FInt: -30, + FUint: 30, + FFloat64: 3.14, + FBool: true, + FTime: actualTime, + + FStringSlice: []string{"string1", "string2"}, + FIntSlice: []int{1, 2}, + FBoolSlice: []bool{true, false}, + } + + return fooBody, actualTime +} + +func Test_BindJSON(t *testing.T) { + expected, _ := expectedStruct() + body, err := json.Marshal(expected) + if err != nil { + t.Fatalf("Marshal() failed: %v", err) + } + + req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + c := &Context{ + W: rr, + R: req, + } + + var foo FooBody + err = c.BindJSON(&foo) + if err != nil { + t.Errorf("BindJSON() failed: %v", err) + } + + if reflect.DeepEqual(foo, expected) == false { + t.Errorf("Expected: %v, Actual: %v", expected, foo) + } +} + +func Test_BindForm(t *testing.T) { + expected, actualTime := expectedStruct() + req := httptest.NewRequest("POST", "/", nil) + req.Form = map[string][]string{ + "fstring": {"string"}, + "fint": {"-30"}, + "fuint": {"30"}, + "ffloat64": {"3.14"}, + "fbool": {"true"}, + "ftime": {actualTime.Format(time.RFC3339)}, + "fstring_slice": {"string1", "string2"}, + "fint_slice": {"1", "2"}, + "fbool_slice": {"true", "false"}, + "fnone": {"none"}, + } + + rr := httptest.NewRecorder() + + c := &Context{ + W: rr, + R: req, + } + + var foo FooBody + err := c.BindForm(&foo) + if err != nil { + t.Errorf("BindForm() failed: %v", err) + } + + if reflect.DeepEqual(foo, expected) == false { + t.Errorf("Expected: %v, Actual: %v", expected, foo) + } +} diff --git a/ron_test.go b/ron_test.go index 6adc7ab..e2f5eb8 100644 --- a/ron_test.go +++ b/ron_test.go @@ -9,7 +9,7 @@ import ( type Foo struct { Bar string `json:"bar"` - Taz int `json:"taz"` + Taz int `json:"something"` Car *string `json:"car"` } @@ -19,7 +19,7 @@ func Test_JSON(t *testing.T) { W: rr, } - expected := `{"bar":"bar","taz":30,"car":null}` + expected := `{"bar":"bar","something":30,"car":null}` c.JSON(http.StatusOK, Foo{ Bar: "bar",