diff --git a/middlewares.go b/middlewares.go new file mode 100644 index 0000000..522cdef --- /dev/null +++ b/middlewares.go @@ -0,0 +1,50 @@ +package ron + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "time" +) + +func (e *Engine) TimeOutMiddleware() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), e.Config.Timeout) + defer cancel() + + r = r.WithContext(ctx) + done := make(chan struct{}) + + go func() { + next.ServeHTTP(w, r) + close(done) + }() + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + slog.Debug("timeout reached") + http.Error(w, "Request timed out", http.StatusGatewayTimeout) + } + case <-done: + } + }) + } +} + +func (e *Engine) RequestIdMiddleware() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.Header.Get("X-Request-ID") + if id == "" { + id = fmt.Sprintf("%d", time.Now().UnixNano()) + } + ctx = context.WithValue(ctx, RequestID, id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/ron.go b/ron.go index 2f9e846..37eb9c7 100644 --- a/ron.go +++ b/ron.go @@ -3,7 +3,6 @@ package ron import ( "context" "encoding/json" - "errors" "fmt" "io" "log/slog" @@ -116,8 +115,6 @@ func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - e.middleware = append(e.middleware, e.timeOutMiddleware()) - e.middleware = append(e.middleware, e.requestIdMiddleware()) handler = createStack(e.middleware...)(handler) rw := &responseWriterWrapper{ResponseWriter: w} handler.ServeHTTP(rw, r) @@ -138,46 +135,6 @@ func createStack(xs ...Middleware) Middleware { } } -func (e *Engine) timeOutMiddleware() Middleware { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), e.Config.Timeout) - defer cancel() - - r = r.WithContext(ctx) - done := make(chan struct{}) - - go func() { - next.ServeHTTP(w, r) - close(done) - }() - - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - slog.Debug("timeout reached") - http.Error(w, "Request timed out", http.StatusGatewayTimeout) - } - case <-done: - } - }) - } -} - -func (e *Engine) requestIdMiddleware() Middleware { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - id := r.Header.Get("X-Request-ID") - if id == "" { - id = fmt.Sprintf("%d", time.Now().UnixNano()) - } - ctx = context.WithValue(ctx, RequestID, id) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - func (e *Engine) USE(middleware Middleware) { e.middleware = append(e.middleware, middleware) }