refactor template cache, integrate responseWriterWrapper and integrate timeout

This commit is contained in:
Pedro Pérez 2024-11-21 21:29:59 +01:00
parent 4c8c6121b1
commit 89afb56836
4 changed files with 98 additions and 27 deletions

View File

@ -56,7 +56,7 @@ func Test_BindJSON(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
c := &CTX{ c := &CTX{
W: rr, W: &responseWriterWrapper{ResponseWriter: rr},
R: req, R: req,
} }
@ -90,7 +90,7 @@ func Test_BindForm(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
c := &CTX{ c := &CTX{
W: rr, W: &responseWriterWrapper{ResponseWriter: rr},
R: req, R: req,
} }

79
ron.go
View File

@ -3,6 +3,7 @@ package ron
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@ -19,17 +20,27 @@ type (
Middleware func(http.Handler) http.Handler Middleware func(http.Handler) http.Handler
responseWriterWrapper struct {
http.ResponseWriter
headerWritten bool
}
CTX struct { CTX struct {
W http.ResponseWriter W *responseWriterWrapper
R *http.Request R *http.Request
E *Engine E *Engine
} }
Config struct {
Timeout time.Duration
LogLevel slog.Level
}
Engine struct { Engine struct {
mux *http.ServeMux mux *http.ServeMux
middleware []Middleware middleware []Middleware
groupMux map[string]*groupMux groupMux map[string]*groupMux
LogLevel slog.Level Config *Config
Render *Render Render *Render
} }
@ -49,11 +60,29 @@ const (
HeaderPlain_UTF8 string = "text/plain; charset=utf-8" HeaderPlain_UTF8 string = "text/plain; charset=utf-8"
) )
func (w *responseWriterWrapper) WriteHeader(code int) {
if !w.headerWritten {
w.headerWritten = true
w.ResponseWriter.WriteHeader(code)
}
}
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
if !w.headerWritten {
w.headerWritten = true
w.ResponseWriter.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
func defaultEngine() *Engine { func defaultEngine() *Engine {
return &Engine{ return &Engine{
mux: http.NewServeMux(), mux: http.NewServeMux(),
groupMux: make(map[string]*groupMux), groupMux: make(map[string]*groupMux),
LogLevel: slog.LevelInfo, Config: &Config{
Timeout: time.Second * 30,
LogLevel: slog.LevelDebug,
},
} }
} }
@ -81,12 +110,14 @@ func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
e.middleware = append(e.middleware, e.timeOutMiddleware())
handler = createStack(e.middleware...)(handler) handler = createStack(e.middleware...)(handler)
handler.ServeHTTP(w, r) rw := &responseWriterWrapper{ResponseWriter: w}
handler.ServeHTTP(rw, r)
} }
func (e *Engine) Run(addr string) error { func (e *Engine) Run(addr string) error {
newLogger(e.LogLevel) newLogger(e.Config.LogLevel)
return http.ListenAndServe(addr, e) return http.ListenAndServe(addr, e)
} }
@ -100,19 +131,47 @@ func createStack(xs ...Middleware) Middleware {
} }
} }
func (e *Engine) timeOutMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), e.Config.Timeout)
defer cancel()
r = r.WithContext(ctx)
done := make(chan struct{})
go func() {
next.ServeHTTP(w, r)
close(done)
}()
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
slog.Debug("timeout reached")
http.Error(w, "Request timed out", http.StatusGatewayTimeout)
}
case <-done:
}
})
}
}
func (e *Engine) USE(middleware Middleware) { func (e *Engine) USE(middleware Middleware) {
e.middleware = append(e.middleware, middleware) e.middleware = append(e.middleware, middleware)
} }
func (e *Engine) GET(path string, handler func(*CTX, context.Context)) { func (e *Engine) GET(path string, handler func(*CTX, context.Context)) {
e.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { e.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) {
handler(&CTX{W: w, R: r, E: e}, r.Context()) rw := &responseWriterWrapper{ResponseWriter: w}
handler(&CTX{W: rw, R: r, E: e}, r.Context())
}) })
} }
func (e *Engine) POST(path string, handler func(*CTX, context.Context)) { func (e *Engine) POST(path string, handler func(*CTX, context.Context)) {
e.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { e.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) {
handler(&CTX{W: w, R: r, E: e}, r.Context()) rw := &responseWriterWrapper{ResponseWriter: w}
handler(&CTX{W: rw, R: r, E: e}, r.Context())
}) })
} }
@ -136,13 +195,15 @@ func (g *groupMux) USE(middleware Middleware) {
func (g *groupMux) GET(path string, handler func(*CTX, context.Context)) { func (g *groupMux) GET(path string, handler func(*CTX, context.Context)) {
g.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { g.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) {
handler(&CTX{W: w, R: r, E: g.engine}, r.Context()) rw := &responseWriterWrapper{ResponseWriter: w}
handler(&CTX{W: rw, R: r, E: g.engine}, r.Context())
}) })
} }
func (g *groupMux) POST(path string, handler func(*CTX, context.Context)) { func (g *groupMux) POST(path string, handler func(*CTX, context.Context)) {
g.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { g.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) {
handler(&CTX{W: w, R: r, E: g.engine}, r.Context()) rw := &responseWriterWrapper{ResponseWriter: w}
handler(&CTX{W: rw, R: r, E: g.engine}, r.Context())
}) })
} }

View File

@ -70,13 +70,13 @@ func Test_New(t *testing.T) {
func Test_applyEngineConfig(t *testing.T) { func Test_applyEngineConfig(t *testing.T) {
e := New(func(e *Engine) { e := New(func(e *Engine) {
e.Render = NewHTMLRender() e.Render = NewHTMLRender()
e.LogLevel = 1 e.Config.LogLevel = slog.LevelInfo
}) })
if e.Render == nil { if e.Render == nil {
t.Error("Expected Renderer, Actual: nil") t.Error("Expected Renderer, Actual: nil")
} }
if e.LogLevel != 1 { if e.Config.LogLevel != slog.LevelInfo {
t.Errorf("Expected LogLevel: 1, Actual: %d", e.LogLevel) t.Errorf("Expected LogLevel: 1, Actual: %d", e.Config.LogLevel)
} }
} }
@ -405,7 +405,7 @@ func Test_JSON(t *testing.T) {
t.Parallel() t.Parallel()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
c := &CTX{ c := &CTX{
W: rr, W: &responseWriterWrapper{ResponseWriter: rr},
} }
c.JSON(tt.givenCode, tt.givenData) c.JSON(tt.givenCode, tt.givenData)
@ -449,7 +449,7 @@ func Test_HTML(t *testing.T) {
t.Parallel() t.Parallel()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
c := &CTX{ c := &CTX{
W: rr, W: &responseWriterWrapper{ResponseWriter: rr},
E: &Engine{ E: &Engine{
Render: NewHTMLRender(), Render: NewHTMLRender(),
}, },

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"html/template" "html/template"
"io/fs" "io/fs"
"log/slog"
"net/http" "net/http"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -63,13 +64,9 @@ func (re *Render) Template(w http.ResponseWriter, tmpl string, td *TemplateData)
td = &TemplateData{} td = &TemplateData{}
} }
if re.EnableCache { tc, err = re.getTemplateCache()
tc = re.templateCache if err != nil {
} else { return err
tc, err = re.createTemplateCache()
if err != nil {
return err
}
} }
t, ok := tc[tmpl] t, ok := tc[tmpl]
@ -78,19 +75,32 @@ func (re *Render) Template(w http.ResponseWriter, tmpl string, td *TemplateData)
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = t.Execute(buf, td) if err = t.Execute(buf, td); err != nil {
if err != nil {
return err return err
} }
_, err = buf.WriteTo(w) if _, err = buf.WriteTo(w); err != nil {
if err != nil {
return err return err
} }
return nil return nil
} }
func (re *Render) getTemplateCache() (templateCache, error) {
slog.Debug("template cache", "tc status", re.EnableCache, "tc", len(re.templateCache))
if len(re.templateCache) == 0 {
cachedTemplates, err := re.createTemplateCache()
if err != nil {
return nil, err
}
re.templateCache = cachedTemplates
}
if re.EnableCache {
return re.templateCache, nil
}
return re.createTemplateCache()
}
func (re *Render) findHTMLFiles() ([]string, error) { func (re *Render) findHTMLFiles() ([]string, error) {
var files []string var files []string