From c4983ee36d30fcecb1367a0f59cbd28258c9237f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20P=C3=A9rez?= Date: Tue, 19 Nov 2024 22:30:18 +0100 Subject: [PATCH] add middlewares and grouping routes --- .gitignore | 3 +- ron.go | 85 ++++++++++--- ron_test.go | 355 ++++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 373 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 5292519..f30e625 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -logs/ \ No newline at end of file +logs/ +.idea/ \ No newline at end of file diff --git a/ron.go b/ron.go index aeea964..0d0817f 100644 --- a/ron.go +++ b/ron.go @@ -17,6 +17,8 @@ type ( Data map[string]any + Middleware func(http.Handler) http.Handler + Context struct { C context.Context W http.ResponseWriter @@ -25,15 +27,25 @@ type ( } Engine struct { - mux *http.ServeMux - LogLevel slog.Level - Render *Render + mux *http.ServeMux + middleware []Middleware + groupMux map[string]*groupMux + LogLevel slog.Level + Render *Render + } + + groupMux struct { + prefix string + mux *http.ServeMux + middleware []Middleware + engine *Engine } ) func defaultEngine() *Engine { return &Engine{ mux: http.NewServeMux(), + groupMux: make(map[string]*groupMux), LogLevel: slog.LevelInfo, } } @@ -54,7 +66,16 @@ func (e *Engine) apply(opts ...EngineOptions) *Engine { } func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { - e.handleRequest(w, r) + var handler http.Handler = e.mux + for prefix, group := range e.groupMux { + if strings.HasPrefix(r.URL.Path, prefix) { + handler = createStack(group.middleware...)(handler) + break + } + } + + handler = createStack(e.middleware...)(handler) + handler.ServeHTTP(w, r) } func (e *Engine) Run(addr string) error { @@ -62,30 +83,62 @@ func (e *Engine) Run(addr string) error { return http.ListenAndServe(addr, e) } -func (e *Engine) handleRequest(w http.ResponseWriter, r *http.Request) { - e.mux.ServeHTTP(w, r) +func createStack(xs ...Middleware) Middleware { + return func(next http.Handler) http.Handler { + for i := len(xs) - 1; i >= 0; i-- { + x := xs[i] + next = x(next) + } + return next + } +} + +func (e *Engine) USE(middleware Middleware) { + e.middleware = append(e.middleware, middleware) } func (e *Engine) GET(path string, handler func(*Context)) { - e.mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return - } + e.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { handler(&Context{W: w, R: r, E: e}) }) } func (e *Engine) POST(path string, handler func(*Context)) { - e.mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return - } + e.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { handler(&Context{W: w, R: r, E: e}) }) } +func (e *Engine) GROUP(prefix string) *groupMux { + if _, ok := e.groupMux[prefix]; !ok { + e.groupMux[prefix] = &groupMux{ + prefix: prefix, + mux: http.NewServeMux(), + engine: e, + } + + e.mux.Handle(prefix+"/", http.StripPrefix(prefix, e.groupMux[prefix].mux)) + } + + return e.groupMux[prefix] +} + +func (g *groupMux) USE(middleware Middleware) { + g.middleware = append(g.middleware, middleware) +} + +func (g *groupMux) GET(path string, handler func(*Context)) { + g.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { + handler(&Context{W: w, R: r, E: g.engine}) + }) +} + +func (g *groupMux) POST(path string, handler func(*Context)) { + g.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { + handler(&Context{W: w, R: r, E: g.engine}) + }) +} + // Static serves static files from a specified directory, accessible through a defined URL path. // // The `path` parameter represents the URL prefix to access the static files. diff --git a/ron_test.go b/ron_test.go index 05593f1..0f61a0a 100644 --- a/ron_test.go +++ b/ron_test.go @@ -1,6 +1,7 @@ package ron import ( + "log/slog" "net/http" "net/http/httptest" "os" @@ -39,32 +40,126 @@ func Test_applyEngineConfig(t *testing.T) { func Test_ServeHTTP(t *testing.T) { e := New() - rr := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) - e.ServeHTTP(rr, req) - - if status := rr.Code; status != http.StatusNotFound { - t.Errorf("Expected status code: %d, Actual: %d", http.StatusNotFound, status) - } -} - -func Test_GET(t *testing.T) { - e := New() - e.GET("/", func(c *Context) { + api := e.GROUP("/api") + api.GET("/index", func(c *Context) { c.W.WriteHeader(http.StatusOK) - c.W.Write([]byte("GET")) + c.W.Write([]byte("GET API")) }) rr := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/api/index", nil) e.ServeHTTP(rr, req) if status := rr.Code; status != http.StatusOK { t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) } +} - if rr.Body.String() != "GET" { - t.Errorf("Expected: GET, Actual: %s", rr.Body.String()) +func Test_RUN(t *testing.T) { + e := New() + go func() { + e.Run(":8080") + }() +} + +func Test_createStack(t *testing.T) { + m1 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 1")) + next.ServeHTTP(w, r) + }) + } + m2 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 2")) + next.ServeHTTP(w, r) + }) + } + + stack := createStack(m1, m2) + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + stack(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Handler")) + })).ServeHTTP(rr, req) + + if rr.Body.String() != "Middleware 1Middleware 2Handler" { + t.Errorf("Expected: Middleware 1Middleware 2Handler, Actual: %s", rr.Body.String()) + } +} + +func Test_USE(t *testing.T) { + e := New() + m1 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 1")) + next.ServeHTTP(w, r) + }) + } + m2 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 2")) + next.ServeHTTP(w, r) + }) + } + + e.USE(m1) + e.USE(m2) + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + e.ServeHTTP(rr, req) + + if rr.Body.String() != "Middleware 1Middleware 2404 page not found\n" { + t.Errorf("Expected: Middleware 1Middleware 2Handler, Actual: %s", rr.Body.String()) + } +} + +func Test_GET(t *testing.T) { + e := New() + tests := []struct { + name string + method string + path string + expectedStatus int + expectedBody string + }{ + {"root endpoint", "GET", "/", http.StatusOK, "GET Root"}, + {"api endpoint", "GET", "/api", http.StatusOK, "GET API"}, + {"api endpoint with version", "GET", "/api/v1", http.StatusOK, "GET API v1"}, + {"resource with param", "GET", "/api/v1/resource/1", http.StatusOK, "GET Resource"}, + } + + e.GET("/", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET Root")) + }) + e.GET("/api", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET API")) + }) + e.GET("/api/v1", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET API v1")) + }) + e.GET("/api/v1/resource/{id}", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET Resource")) + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + req, _ := http.NewRequest(tt.method, tt.path, nil) + e.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status code: %d, Actual: %d", tt.expectedStatus, rr.Code) + } + + if rr.Body.String() != tt.expectedBody { + t.Errorf("Expected body: %q, Actual: %q", tt.expectedBody, rr.Body.String()) + } + }) } } @@ -88,6 +183,86 @@ func Test_POST(t *testing.T) { } } +func Test_GROUP(t *testing.T) { + e := New() + api := e.GROUP("/api") + api.GET("/index", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET API")) + }) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/index", nil) + e.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) + } + + if rr.Body.String() != "GET API" { + t.Errorf("Expected: GET API, Actual: %s", rr.Body.String()) + } +} + +func Test_GROUPWithMiddleware(t *testing.T) { + e := New() + e.GET("/index", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET Root")) + }) + e.USE(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 1")) + next.ServeHTTP(w, r) + }) + }) + + api := e.GROUP("/api") + api.USE(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Middleware 2")) + next.ServeHTTP(w, r) + }) + }) + api.GET("/index", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("GET API")) + }) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/index", nil) + e.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) + } + + if rr.Body.String() != "Middleware 1Middleware 2GET API" { + t.Errorf("Expected: Middleware 1Middleware 2GET API, Actual: %s", rr.Body.String()) + } +} + +func Test_GROUPPOST(t *testing.T) { + e := New() + api := e.GROUP("/api") + api.POST("/index", func(c *Context) { + c.W.WriteHeader(http.StatusOK) + c.W.Write([]byte("POST API")) + }) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/index", nil) + e.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) + } + + if rr.Body.String() != "POST API" { + t.Errorf("Expected: POST API, Actual: %s", rr.Body.String()) + } +} + func Test_Static(t *testing.T) { os.Mkdir("assets", os.ModePerm) f, _ := os.Create("assets/style.css") @@ -114,29 +289,53 @@ type Foo struct { } func Test_JSON(t *testing.T) { - rr := httptest.NewRecorder() - c := &Context{ - W: rr, + tests := []struct { + name string + code int + data any + expectedStatus int + expectedBody string + expectedHeader string + }{ + { + name: "valid JSON", + code: http.StatusOK, + data: Foo{Bar: "bar", Taz: 30, Car: nil}, + expectedStatus: http.StatusOK, + expectedBody: `{"bar":"bar","something":30,"car":null}` + "\n", + expectedHeader: "application/json", + }, + { + name: "invalid JSON", + code: http.StatusOK, + data: make(chan int), + expectedStatus: http.StatusInternalServerError, + expectedBody: "json: unsupported type: chan int\n", + expectedHeader: "application/json", + }, } - expected := `{"bar":"bar","something":30,"car":null}` + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + c := &Context{ + W: rr, + } - c.JSON(http.StatusOK, Foo{ - Bar: "bar", - Taz: 30, - Car: nil, - }) + c.JSON(tt.code, tt.data) - if status := rr.Code; status != http.StatusOK { - t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) - } + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("Expected status code: %d, Actual: %d", tt.expectedStatus, status) + } - if rr.Header().Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type: application/json, Actual: %s", c.W.Header().Get("Content-Type")) - } + if rr.Header().Get("Content-Type") != tt.expectedHeader { + t.Errorf("Expected Content-Type: %s, Actual: %s", tt.expectedHeader, rr.Header().Get("Content-Type")) + } - if rr.Body.String() != string(expected)+"\n" { - t.Errorf("Expected: %s, Actual: %s", string(expected), rr.Body.String()) + if rr.Body.String() != tt.expectedBody { + t.Errorf("Expected body: %q, Actual: %q", tt.expectedBody, rr.Body.String()) + } + }) } } @@ -147,32 +346,82 @@ func Test_HTML(t *testing.T) { f.Close() defer os.RemoveAll("templates") - rr := httptest.NewRecorder() - c := &Context{ - W: rr, - E: &Engine{ - Render: NewHTMLRender(), + tests := []struct { + name string + code int + templateName string + templateData *TemplateData + expectedStatus int + expectedBody string + expectedHeader string + }{ + { + name: "valid HTML", + code: http.StatusOK, + templateName: "page.index.gohtml", + templateData: &TemplateData{Data: Data{"heading1": "foo", "heading2": "bar"}}, + expectedStatus: http.StatusOK, + expectedBody: `

foo

bar

`, + expectedHeader: "text/html; charset=utf-8", + }, + { + name: "template not found", + code: http.StatusOK, + templateName: "nonexistent.gohtml", + templateData: &TemplateData{Data: Data{"heading1": "foo", "heading2": "bar"}}, + expectedStatus: http.StatusInternalServerError, + expectedBody: "open templates/nonexistent.gohtml: no such file or directory\n", + expectedHeader: "text/html; charset=utf-8", }, } - expected := `

foo

bar

` + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + c := &Context{ + W: rr, + E: &Engine{ + Render: NewHTMLRender(), + }, + } - c.HTML(http.StatusOK, "page.index.gohtml", &TemplateData{ - Data: Data{ - "heading1": "foo", - "heading2": "bar", - }, - }) + c.HTML(tt.code, tt.templateName, tt.templateData) - if status := rr.Code; status != http.StatusOK { - t.Errorf("Expected status code: %d, Actual: %d", http.StatusOK, status) - } + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("Expected status code: %d, Actual: %d", tt.expectedStatus, status) + } - if rr.Header().Get("Content-Type") != "text/html; charset=utf-8" { - t.Errorf("Expected Content-Type: text/html; charset=utf-8, Actual: %s", c.W.Header().Get("Content-Type")) - } + if rr.Header().Get("Content-Type") != tt.expectedHeader { + t.Errorf("Expected Content-Type: %s, Actual: %s", tt.expectedHeader, rr.Header().Get("Content-Type")) + } - if rr.Body.String() != string(expected) { - t.Errorf("Expected: %s, Actual: %s", string(expected), rr.Body.String()) + if rr.Body.String() != tt.expectedBody { + t.Errorf("Expected body: %q, Actual: %q", tt.expectedBody, rr.Body.String()) + } + }) + } +} + +func Test_newLogger(t *testing.T) { + tests := []struct { + name string + level slog.Level + wantErr bool + }{ + {"valid level", 1, false}, + {"invalid level", -1, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantErr { + t.Errorf("newLogger() panicked: %v", r) + } + } + }() + newLogger(tt.level) + }) } }