From 89afb5683662f1a6dfc51bf294422ddbe2b74737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20P=C3=A9rez?= Date: Thu, 21 Nov 2024 21:29:59 +0100 Subject: [PATCH] refactor template cache, integrate responseWriterWrapper and integrate timeout --- binding_test.go | 4 +-- ron.go | 79 +++++++++++++++++++++++++++++++++++++++++++------ ron_test.go | 10 +++---- template.go | 32 +++++++++++++------- 4 files changed, 98 insertions(+), 27 deletions(-) diff --git a/binding_test.go b/binding_test.go index 5246550..255cf08 100644 --- a/binding_test.go +++ b/binding_test.go @@ -56,7 +56,7 @@ func Test_BindJSON(t *testing.T) { rr := httptest.NewRecorder() c := &CTX{ - W: rr, + W: &responseWriterWrapper{ResponseWriter: rr}, R: req, } @@ -90,7 +90,7 @@ func Test_BindForm(t *testing.T) { rr := httptest.NewRecorder() c := &CTX{ - W: rr, + W: &responseWriterWrapper{ResponseWriter: rr}, R: req, } diff --git a/ron.go b/ron.go index 2ff1924..b2f65f1 100644 --- a/ron.go +++ b/ron.go @@ -3,6 +3,7 @@ package ron import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -19,17 +20,27 @@ type ( Middleware func(http.Handler) http.Handler + responseWriterWrapper struct { + http.ResponseWriter + headerWritten bool + } + CTX struct { - W http.ResponseWriter + W *responseWriterWrapper R *http.Request E *Engine } + Config struct { + Timeout time.Duration + LogLevel slog.Level + } + Engine struct { mux *http.ServeMux middleware []Middleware groupMux map[string]*groupMux - LogLevel slog.Level + Config *Config Render *Render } @@ -49,11 +60,29 @@ const ( HeaderPlain_UTF8 string = "text/plain; charset=utf-8" ) +func (w *responseWriterWrapper) WriteHeader(code int) { + if !w.headerWritten { + w.headerWritten = true + w.ResponseWriter.WriteHeader(code) + } +} + +func (w *responseWriterWrapper) Write(b []byte) (int, error) { + if !w.headerWritten { + w.headerWritten = true + w.ResponseWriter.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.Write(b) +} + func defaultEngine() *Engine { return &Engine{ mux: http.NewServeMux(), groupMux: make(map[string]*groupMux), - LogLevel: slog.LevelInfo, + Config: &Config{ + Timeout: time.Second * 30, + LogLevel: slog.LevelDebug, + }, } } @@ -81,12 +110,14 @@ func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + e.middleware = append(e.middleware, e.timeOutMiddleware()) handler = createStack(e.middleware...)(handler) - handler.ServeHTTP(w, r) + rw := &responseWriterWrapper{ResponseWriter: w} + handler.ServeHTTP(rw, r) } func (e *Engine) Run(addr string) error { - newLogger(e.LogLevel) + newLogger(e.Config.LogLevel) return http.ListenAndServe(addr, e) } @@ -100,19 +131,47 @@ func createStack(xs ...Middleware) Middleware { } } +func (e *Engine) timeOutMiddleware() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), e.Config.Timeout) + defer cancel() + + r = r.WithContext(ctx) + done := make(chan struct{}) + + go func() { + next.ServeHTTP(w, r) + close(done) + }() + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + slog.Debug("timeout reached") + http.Error(w, "Request timed out", http.StatusGatewayTimeout) + } + case <-done: + } + }) + } +} + func (e *Engine) USE(middleware Middleware) { e.middleware = append(e.middleware, middleware) } func (e *Engine) GET(path string, handler func(*CTX, context.Context)) { e.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { - handler(&CTX{W: w, R: r, E: e}, r.Context()) + rw := &responseWriterWrapper{ResponseWriter: w} + handler(&CTX{W: rw, R: r, E: e}, r.Context()) }) } func (e *Engine) POST(path string, handler func(*CTX, context.Context)) { e.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { - handler(&CTX{W: w, R: r, E: e}, r.Context()) + rw := &responseWriterWrapper{ResponseWriter: w} + handler(&CTX{W: rw, R: r, E: e}, r.Context()) }) } @@ -136,13 +195,15 @@ func (g *groupMux) USE(middleware Middleware) { func (g *groupMux) GET(path string, handler func(*CTX, context.Context)) { g.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { - handler(&CTX{W: w, R: r, E: g.engine}, r.Context()) + rw := &responseWriterWrapper{ResponseWriter: w} + handler(&CTX{W: rw, R: r, E: g.engine}, r.Context()) }) } func (g *groupMux) POST(path string, handler func(*CTX, context.Context)) { g.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { - handler(&CTX{W: w, R: r, E: g.engine}, r.Context()) + rw := &responseWriterWrapper{ResponseWriter: w} + handler(&CTX{W: rw, R: r, E: g.engine}, r.Context()) }) } diff --git a/ron_test.go b/ron_test.go index 1545c8f..aa6ca0d 100644 --- a/ron_test.go +++ b/ron_test.go @@ -70,13 +70,13 @@ func Test_New(t *testing.T) { func Test_applyEngineConfig(t *testing.T) { e := New(func(e *Engine) { e.Render = NewHTMLRender() - e.LogLevel = 1 + e.Config.LogLevel = slog.LevelInfo }) if e.Render == nil { t.Error("Expected Renderer, Actual: nil") } - if e.LogLevel != 1 { - t.Errorf("Expected LogLevel: 1, Actual: %d", e.LogLevel) + if e.Config.LogLevel != slog.LevelInfo { + t.Errorf("Expected LogLevel: 1, Actual: %d", e.Config.LogLevel) } } @@ -405,7 +405,7 @@ func Test_JSON(t *testing.T) { t.Parallel() rr := httptest.NewRecorder() c := &CTX{ - W: rr, + W: &responseWriterWrapper{ResponseWriter: rr}, } c.JSON(tt.givenCode, tt.givenData) @@ -449,7 +449,7 @@ func Test_HTML(t *testing.T) { t.Parallel() rr := httptest.NewRecorder() c := &CTX{ - W: rr, + W: &responseWriterWrapper{ResponseWriter: rr}, E: &Engine{ Render: NewHTMLRender(), }, diff --git a/template.go b/template.go index c7020ab..4314a11 100644 --- a/template.go +++ b/template.go @@ -5,6 +5,7 @@ import ( "errors" "html/template" "io/fs" + "log/slog" "net/http" "path/filepath" "reflect" @@ -63,13 +64,9 @@ func (re *Render) Template(w http.ResponseWriter, tmpl string, td *TemplateData) td = &TemplateData{} } - if re.EnableCache { - tc = re.templateCache - } else { - tc, err = re.createTemplateCache() - if err != nil { - return err - } + tc, err = re.getTemplateCache() + if err != nil { + return err } t, ok := tc[tmpl] @@ -78,19 +75,32 @@ func (re *Render) Template(w http.ResponseWriter, tmpl string, td *TemplateData) } buf := new(bytes.Buffer) - err = t.Execute(buf, td) - if err != nil { + if err = t.Execute(buf, td); err != nil { return err } - _, err = buf.WriteTo(w) - if err != nil { + if _, err = buf.WriteTo(w); err != nil { return err } return nil } +func (re *Render) getTemplateCache() (templateCache, error) { + slog.Debug("template cache", "tc status", re.EnableCache, "tc", len(re.templateCache)) + if len(re.templateCache) == 0 { + cachedTemplates, err := re.createTemplateCache() + if err != nil { + return nil, err + } + re.templateCache = cachedTemplates + } + if re.EnableCache { + return re.templateCache, nil + } + return re.createTemplateCache() +} + func (re *Render) findHTMLFiles() ([]string, error) { var files []string