From bbfa9a349d9d1de17913d06a957228712dbb12d6 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Mon, 28 Jul 2025 18:35:17 +0700 Subject: [PATCH] [sse] multiple connections support --- go.mod | 2 +- go.sum | 4 - internal/sms-gateway/modules/sse/service.go | 121 ++++++++++++-------- 3 files changed, 73 insertions(+), 54 deletions(-) diff --git a/go.mod b/go.mod index 127eb4c..5c3bf10 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/go-playground/validator/v10 v10.16.0 github.com/go-sql-driver/mysql v1.7.1 github.com/gofiber/fiber/v2 v2.52.5 + github.com/google/uuid v1.6.0 github.com/jaevor/go-nanoid v1.3.0 github.com/nyaruka/phonenumbers v1.4.0 github.com/prometheus/client_golang v1.19.1 @@ -45,7 +46,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/s2a-go v0.1.7 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/go.sum b/go.sum index e9b38e6..eea65f3 100644 --- a/go.sum +++ b/go.sum @@ -39,10 +39,6 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/capcom6/go-helpers v0.3.0 h1:ae18fLfluoPubiB2V+j4cIpfZaTuK4acS2entamaDkE= github.com/capcom6/go-helpers v0.3.0/go.mod h1:WDqc7HZNqHxUTisArkYIBZtqUfJBVyPWeQI+FMwEzAw= -github.com/capcom6/go-infra-fx v0.2.1 h1:8rqr2ZV+YC2R07amHMdlE1XKLUhMe5yO+ffCJ/xXlNY= -github.com/capcom6/go-infra-fx v0.2.1/go.mod h1:klScvB8QAKgJ19FfJOnUKK5tI0o9b79Aj2RmCJHfbN0= -github.com/capcom6/go-infra-fx v0.2.2 h1:vTxlAqHUKpYTOY5Lp9OeTwwzxM34N8wH1vekEShg7eA= -github.com/capcom6/go-infra-fx v0.2.2/go.mod h1:KHApbB6bwF7WQNIXW6ZdC4YG+d+ciwxvsnRpbOJa/Ys= github.com/capcom6/go-infra-fx v0.2.3 h1:ZSlBfz8qRaNVMtTBtJ4fLN89472CNimpJwy3kfBgGf8= github.com/capcom6/go-infra-fx v0.2.3/go.mod h1:KHApbB6bwF7WQNIXW6ZdC4YG+d+ciwxvsnRpbOJa/Ys= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= diff --git a/internal/sms-gateway/modules/sse/service.go b/internal/sms-gateway/modules/sse/service.go index 7110f9b..b0164a5 100644 --- a/internal/sms-gateway/modules/sse/service.go +++ b/internal/sms-gateway/modules/sse/service.go @@ -10,6 +10,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" + "github.com/google/uuid" "go.uber.org/zap" ) @@ -17,21 +18,27 @@ type Service struct { config Config mu sync.RWMutex - connections map[string]*sseConnection + connections map[string][]*sseConnection logger *zap.Logger } type sseConnection struct { - channel chan []byte + id string + channel chan eventWrapper closeSignal chan struct{} } +type eventWrapper struct { + name string + data []byte +} + func NewService(config Config, logger *zap.Logger) *Service { return &Service{ config: config, - connections: make(map[string]*sseConnection), + connections: make(map[string][]*sseConnection), logger: logger, } @@ -41,23 +48,31 @@ func (s *Service) Send(deviceID string, event Event) error { s.mu.RLock() defer s.mu.RUnlock() - conn, exists := s.connections[deviceID] + connections, exists := s.connections[deviceID] if !exists { return fmt.Errorf("no connection for device %s", deviceID) } - data, err := json.Marshal(event) + data, err := json.Marshal(event.Data) if err != nil { return fmt.Errorf("can't marshal event: %w", err) } - select { - case conn.channel <- data: - // Message sent successfully - case <-conn.closeSignal: - return fmt.Errorf("connection closed") - default: - return fmt.Errorf("connection buffer full") + sent := 0 + for _, conn := range connections { + select { + case conn.channel <- eventWrapper{string(event.Type), data}: + // Message sent successfully + sent++ + case <-conn.closeSignal: + s.logger.Warn("Connection closed while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id)) + default: + s.logger.Warn("Connection buffer full while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id)) + } + } + + if sent == 0 { + return fmt.Errorf("no active connection for device %s", deviceID) } return nil @@ -67,32 +82,29 @@ func (s *Service) Close(_ context.Context) error { s.mu.Lock() defer s.mu.Unlock() - for id, conn := range s.connections { - close(conn.closeSignal) - delete(s.connections, id) + for deviceID, connections := range s.connections { + for _, conn := range connections { + close(conn.closeSignal) + } + delete(s.connections, deviceID) } return nil } -func (s *Service) Handler(deviceId string, c *fiber.Ctx) error { +func (s *Service) Handler(deviceID string, c *fiber.Ctx) error { c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") c.Status(fiber.StatusOK).Context().SetBodyStreamWriter(func(w *bufio.Writer) { - s.registerConnection(deviceId) - defer s.removeConnection(deviceId) - - conn := s.getConnection(deviceId) - if conn == nil { - s.logger.Warn("Client not connected", zap.String("client_id", deviceId)) - return - } + conn := s.registerConnection(deviceID) + defer s.removeConnection(deviceID, conn.id) if err := s.writeToStream(w, ":keepalive"); err != nil { s.logger.Warn("Failed to write keepalive", - zap.String("client_id", deviceId), + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), zap.Error(err)) return } @@ -102,17 +114,19 @@ func (s *Service) Handler(deviceId string, c *fiber.Ctx) error { for { select { - case data := <-conn.channel: - if err := s.writeToStream(w, fmt.Sprintf("data: %s", utils.UnsafeString(data))); err != nil { + case event := <-conn.channel: + if err := s.writeToStream(w, fmt.Sprintf("event: %s\ndata: %s", event.name, utils.UnsafeString(event.data))); err != nil { s.logger.Warn("Failed to write event data", - zap.String("client_id", deviceId), + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), zap.Error(err)) return } case <-ticker.C: if err := s.writeToStream(w, ":keepalive"); err != nil { s.logger.Warn("Failed to write keepalive", - zap.String("client_id", deviceId), + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), zap.Error(err)) return } @@ -132,36 +146,45 @@ func (s *Service) writeToStream(w *bufio.Writer, data string) error { return w.Flush() } -func (s *Service) registerConnection(id string) { +func (s *Service) registerConnection(deviceID string) *sseConnection { s.mu.Lock() defer s.mu.Unlock() - if existingConn, ok := s.connections[id]; ok { - s.logger.Info("Closing existing SSE connection", zap.String("client_id", id)) - close(existingConn.closeSignal) - delete(s.connections, id) - } + connID := uuid.NewString() - s.connections[id] = &sseConnection{ - channel: make(chan []byte, 8), + conn := &sseConnection{ + id: connID, + channel: make(chan eventWrapper, 8), closeSignal: make(chan struct{}), } - s.logger.Info("Registering SSE connection", zap.String("client_id", id)) + + if _, ok := s.connections[deviceID]; !ok { + s.connections[deviceID] = []*sseConnection{} + } + + s.connections[deviceID] = append(s.connections[deviceID], conn) + + s.logger.Info("Registering SSE connection", zap.String("device_id", deviceID), zap.String("connection_id", connID)) + + return conn } -func (s *Service) removeConnection(id string) { +func (s *Service) removeConnection(deviceID, connID string) { s.mu.Lock() defer s.mu.Unlock() - if conn, exists := s.connections[id]; exists { - close(conn.closeSignal) - delete(s.connections, id) - s.logger.Info("Removing SSE connection", zap.String("client_id", id)) + if connections, exists := s.connections[deviceID]; exists { + for i, conn := range connections { + if conn.id == connID { + close(conn.closeSignal) + s.connections[deviceID] = append(connections[:i], connections[i+1:]...) + s.logger.Info("Removing SSE connection", zap.String("device_id", deviceID), zap.String("connection_id", connID)) + break + } + } + + if len(s.connections[deviceID]) == 0 { + delete(s.connections, deviceID) + } } } - -func (s *Service) getConnection(id string) *sseConnection { - s.mu.RLock() - defer s.mu.RUnlock() - return s.connections[id] -}