diff --git a/app.go b/app.go index a483e74..2b43b10 100644 --- a/app.go +++ b/app.go @@ -64,6 +64,9 @@ type Config struct { // default map[string]DatabaseConfig{} Databases map[string]DatabaseConfig + // default false + CreateSSEBroker bool + // default false CreateTemplates bool @@ -76,6 +79,8 @@ type Config struct { type App struct { config Config + Router *Router + SSEBroker *SSEBroker Templates *Render Session *scs.SessionManager Mailer Mailer @@ -96,6 +101,7 @@ func NewApp(config ...Config) *App { Timezone: "UTC", Paseto: nil, Databases: make(map[string]DatabaseConfig), + CreateSSEBroker: false, CreateSession: false, CreateMailer: false, CreateTemplates: false, @@ -189,6 +195,7 @@ func NewApp(config ...Config) *App { app := &App{ config: cfg, + Router: newRouter(), } slog.Info( @@ -201,6 +208,7 @@ func NewApp(config ...Config) *App { "paseto_public_key", cfg.Paseto.PublicKey.ExportHex(), "paseto_duration", cfg.Paseto.Duration.String(), "databases", cfg.Databases, + "create_sse_broker", cfg.CreateSSEBroker, "create_templates", cfg.CreateTemplates, "create_session", cfg.CreateSession, "create_mailer", cfg.CreateMailer, @@ -210,6 +218,11 @@ func NewApp(config ...Config) *App { slog.Debug("paseto_assymetric_key", "key", cfg.Paseto.AsymmetricKey.ExportHex()) } + if cfg.CreateSSEBroker { + slog.Debug("creating sse broker") + app.SSEBroker = newSSEBroker() + } + if cfg.CreateTemplates { slog.Debug("creating templates") app.Templates = NewHTMLRender() diff --git a/broker.go b/broker.go new file mode 100644 index 0000000..fd57aa0 --- /dev/null +++ b/broker.go @@ -0,0 +1,166 @@ +package goblocks + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/google/uuid" +) + +type Message struct { + Event string + Data any +} + +type Client struct { + ID string + Send chan Message + Close chan struct{} + IsActive bool +} + +type SSEBroker struct { + clients map[string]*Client + mu sync.RWMutex + ticker *time.Ticker + done chan struct{} +} + +func newSSEBroker() *SSEBroker { + s := &SSEBroker{ + clients: make(map[string]*Client), + ticker: time.NewTicker(time.Minute), + done: make(chan struct{}), + } + go s.run() + return s +} + +func (s *SSEBroker) run() { + for { + select { + case <-s.ticker.C: + s.Broadcast("ticker") + case <-s.done: + s.ticker.Stop() + return + } + } +} + +func (s *SSEBroker) registerClient(id string) *Client { + s.mu.Lock() + defer s.mu.Unlock() + + client := &Client{ + ID: id, + Send: make(chan Message, 100), + Close: make(chan struct{}), + IsActive: true, + } + s.clients[id] = client + return client +} + +func (s *SSEBroker) unregisterClient(id string) { + s.mu.Lock() + defer s.mu.Unlock() + + if client, exists := s.clients[id]; exists { + close(client.Close) + delete(s.clients, id) + } +} + +func (s *SSEBroker) Broadcast(event string, data ...any) { + s.mu.RLock() + defer s.mu.RUnlock() + + var dataValue any + if len(data) > 0 { + dataValue = data[0] + } + + message := Message{ + Event: event, + Data: dataValue, + } + + for _, client := range s.clients { + if client.IsActive { + select { + case client.Send <- message: + default: + slog.Warn("client channel full, skipping message", "client_id", client.ID) + } + } + } +} + +func (s *SSEBroker) HandleSSE(w http.ResponseWriter, r *http.Request) { + slog.Debug("SSE called") + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Transfer-Encoding", "chunked") + + clientID := uuid.New().String() + client := s.registerClient(clientID) + defer s.unregisterClient(clientID) + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + slog.Info("SSE connection established", "client_id", clientID) + + fmt.Fprintf(w, "event: connection\ndata: %s\n\n", clientID) + flusher.Flush() + + clientGone := r.Context().Done() + + for { + select { + case message := <-client.Send: + var data string + switch v := message.Data.(type) { + case string: + data = v + case map[string]any, []any: + jsonData, err := json.Marshal(v) + if err != nil { + slog.Error("error marshaling message", "error", err) + continue + } + data = string(jsonData) + default: + jsonData, err := json.Marshal(v) + if err != nil { + slog.Error("error marshaling message", "error", err) + continue + } + data = string(jsonData) + } + fmt.Fprintf(w, "event: %s\ndata: %s\n\n", message.Event, data) + case <-client.Close: + slog.Info("client closed", "client_id", clientID) + return + case <-clientGone: + slog.Info("client gone", "client_id", clientID) + return + } + + flusher.Flush() + } +} + +func (s *SSEBroker) Shutdown() { + close(s.done) +} diff --git a/go.mod b/go.mod index 96ad00a..866d6e7 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alexedwards/scs/v2 v2.8.0 github.com/go-sql-driver/mysql v1.9.2 github.com/golang-migrate/migrate/v4 v4.18.3 + github.com/google/uuid v1.6.0 github.com/jackc/pgconn v1.14.3 github.com/jackc/pgx/v5 v5.7.4 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index e5edc3e..5c6d05b 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-migrate/migrate/v4 v4.18.3 h1:EYGkoOsvgHHfm5U/naS1RP/6PL/Xv3S4B/swMiAmDLs= github.com/golang-migrate/migrate/v4 v4.18.3/go.mod h1:99BKpIi6ruaaXRM1A77eqZ+FWPQ3cfRa+ZVy5bmWMaY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/router.go b/router.go index 891253a..943d86f 100644 --- a/router.go +++ b/router.go @@ -5,10 +5,8 @@ import ( "slices" ) -// Middleware is a function that wraps an http.Handler. type Middleware func(http.Handler) http.Handler -// Router wraps http.ServeMux and provides grouping and middleware support. type Router struct { globalChain []Middleware routeChain []Middleware @@ -16,7 +14,7 @@ type Router struct { *http.ServeMux } -func NewRouter() *Router { +func newRouter() *Router { return &Router{ServeMux: http.NewServeMux()} }