diff --git a/ron.go b/ron.go index 637f8d2..878a5c9 100644 --- a/ron.go +++ b/ron.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "strings" + "sync" "time" ) @@ -53,6 +54,12 @@ type ( } ) +var rwPool = sync.Pool{ + New: func() any { + return &responseWriterWrapper{} + }, +} + const ( RequestID string = "request_id" HeaderJSON string = "application/json" @@ -78,7 +85,9 @@ func (w *responseWriterWrapper) Write(b []byte) (int, error) { } func (w *responseWriterWrapper) Flush() { - w.ResponseWriter.(http.Flusher).Flush() + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } } func defaultEngine() *Engine { @@ -117,8 +126,13 @@ func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { } handler = createStack(e.middleware...)(handler) - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler.ServeHTTP(rw, r) + + rw.Flush() + rwPool.Put(rw) } func (e *Engine) Run(addr string) error { @@ -142,29 +156,41 @@ func (e *Engine) USE(middleware Middleware) { func (e *Engine) GET(path string, handler func(*CTX)) { e.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: e, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (e *Engine) POST(path string, handler func(*CTX)) { e.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: e, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (e *Engine) PUT(path string, handler func(*CTX)) { e.mux.HandleFunc(fmt.Sprintf("PUT %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: e, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (e *Engine) DELETE(path string, handler func(*CTX)) { e.mux.HandleFunc(fmt.Sprintf("DELETE %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: e, Ctx: r.Context()}) + rwPool.Put(rw) }) } @@ -188,29 +214,41 @@ func (g *groupMux) USE(middleware Middleware) { func (g *groupMux) GET(path string, handler func(*CTX)) { g.mux.HandleFunc(fmt.Sprintf("GET %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: g.engine, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (g *groupMux) POST(path string, handler func(*CTX)) { g.mux.HandleFunc(fmt.Sprintf("POST %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: g.engine, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (g *groupMux) PUT(path string, handler func(*CTX)) { g.mux.HandleFunc(fmt.Sprintf("PUT %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: g.engine, Ctx: r.Context()}) + rwPool.Put(rw) }) } func (g *groupMux) DELETE(path string, handler func(*CTX)) { g.mux.HandleFunc(fmt.Sprintf("DELETE %s", path), func(w http.ResponseWriter, r *http.Request) { - rw := &responseWriterWrapper{ResponseWriter: w} + rw := rwPool.Get().(*responseWriterWrapper) + rw.ResponseWriter = w + rw.headerWritten = false handler(&CTX{W: rw, R: r, E: g.engine, Ctx: r.Context()}) + rwPool.Put(rw) }) } diff --git a/ron_test.go b/ron_test.go index 243d737..3afda12 100644 --- a/ron_test.go +++ b/ron_test.go @@ -568,11 +568,13 @@ func Test_newLogger(t *testing.T) { } } +var preallocatedHello = []byte("Hello") + func Benchmark_GET(b *testing.B) { engine := New() engine.GET("/hello", func(c *CTX) { - c.W.Write([]byte("Hello")) + c.W.Write(preallocatedHello) }) req := httptest.NewRequest(http.MethodGet, "/hello", nil)