mirror of
https://github.com/makayabou/asg-server.git
synced 2026-05-02 17:43:36 +02:00
[sse] multiple connections support
This commit is contained in:
parent
63b93fbe01
commit
bbfa9a349d
2
go.mod
2
go.mod
@ -14,6 +14,7 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.16.0
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/gofiber/fiber/v2 v2.52.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jaevor/go-nanoid v1.3.0
|
||||
github.com/nyaruka/phonenumbers v1.4.0
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
@ -45,7 +46,6 @@ require (
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@ -39,10 +39,6 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/capcom6/go-helpers v0.3.0 h1:ae18fLfluoPubiB2V+j4cIpfZaTuK4acS2entamaDkE=
|
||||
github.com/capcom6/go-helpers v0.3.0/go.mod h1:WDqc7HZNqHxUTisArkYIBZtqUfJBVyPWeQI+FMwEzAw=
|
||||
github.com/capcom6/go-infra-fx v0.2.1 h1:8rqr2ZV+YC2R07amHMdlE1XKLUhMe5yO+ffCJ/xXlNY=
|
||||
github.com/capcom6/go-infra-fx v0.2.1/go.mod h1:klScvB8QAKgJ19FfJOnUKK5tI0o9b79Aj2RmCJHfbN0=
|
||||
github.com/capcom6/go-infra-fx v0.2.2 h1:vTxlAqHUKpYTOY5Lp9OeTwwzxM34N8wH1vekEShg7eA=
|
||||
github.com/capcom6/go-infra-fx v0.2.2/go.mod h1:KHApbB6bwF7WQNIXW6ZdC4YG+d+ciwxvsnRpbOJa/Ys=
|
||||
github.com/capcom6/go-infra-fx v0.2.3 h1:ZSlBfz8qRaNVMtTBtJ4fLN89472CNimpJwy3kfBgGf8=
|
||||
github.com/capcom6/go-infra-fx v0.2.3/go.mod h1:KHApbB6bwF7WQNIXW6ZdC4YG+d+ciwxvsnRpbOJa/Ys=
|
||||
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@ -17,21 +18,27 @@ type Service struct {
|
||||
config Config
|
||||
|
||||
mu sync.RWMutex
|
||||
connections map[string]*sseConnection
|
||||
connections map[string][]*sseConnection
|
||||
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
type sseConnection struct {
|
||||
channel chan []byte
|
||||
id string
|
||||
channel chan eventWrapper
|
||||
closeSignal chan struct{}
|
||||
}
|
||||
|
||||
type eventWrapper struct {
|
||||
name string
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewService(config Config, logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
config: config,
|
||||
|
||||
connections: make(map[string]*sseConnection),
|
||||
connections: make(map[string][]*sseConnection),
|
||||
|
||||
logger: logger,
|
||||
}
|
||||
@ -41,23 +48,31 @@ func (s *Service) Send(deviceID string, event Event) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
conn, exists := s.connections[deviceID]
|
||||
connections, exists := s.connections[deviceID]
|
||||
if !exists {
|
||||
return fmt.Errorf("no connection for device %s", deviceID)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(event)
|
||||
data, err := json.Marshal(event.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't marshal event: %w", err)
|
||||
}
|
||||
|
||||
sent := 0
|
||||
for _, conn := range connections {
|
||||
select {
|
||||
case conn.channel <- data:
|
||||
case conn.channel <- eventWrapper{string(event.Type), data}:
|
||||
// Message sent successfully
|
||||
sent++
|
||||
case <-conn.closeSignal:
|
||||
return fmt.Errorf("connection closed")
|
||||
s.logger.Warn("Connection closed while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id))
|
||||
default:
|
||||
return fmt.Errorf("connection buffer full")
|
||||
s.logger.Warn("Connection buffer full while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id))
|
||||
}
|
||||
}
|
||||
|
||||
if sent == 0 {
|
||||
return fmt.Errorf("no active connection for device %s", deviceID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -67,32 +82,29 @@ func (s *Service) Close(_ context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for id, conn := range s.connections {
|
||||
for deviceID, connections := range s.connections {
|
||||
for _, conn := range connections {
|
||||
close(conn.closeSignal)
|
||||
delete(s.connections, id)
|
||||
}
|
||||
delete(s.connections, deviceID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Handler(deviceId string, c *fiber.Ctx) error {
|
||||
func (s *Service) Handler(deviceID string, c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
|
||||
c.Status(fiber.StatusOK).Context().SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
s.registerConnection(deviceId)
|
||||
defer s.removeConnection(deviceId)
|
||||
|
||||
conn := s.getConnection(deviceId)
|
||||
if conn == nil {
|
||||
s.logger.Warn("Client not connected", zap.String("client_id", deviceId))
|
||||
return
|
||||
}
|
||||
conn := s.registerConnection(deviceID)
|
||||
defer s.removeConnection(deviceID, conn.id)
|
||||
|
||||
if err := s.writeToStream(w, ":keepalive"); err != nil {
|
||||
s.logger.Warn("Failed to write keepalive",
|
||||
zap.String("client_id", deviceId),
|
||||
zap.String("device_id", deviceID),
|
||||
zap.String("connection_id", conn.id),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
@ -102,17 +114,19 @@ func (s *Service) Handler(deviceId string, c *fiber.Ctx) error {
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-conn.channel:
|
||||
if err := s.writeToStream(w, fmt.Sprintf("data: %s", utils.UnsafeString(data))); err != nil {
|
||||
case event := <-conn.channel:
|
||||
if err := s.writeToStream(w, fmt.Sprintf("event: %s\ndata: %s", event.name, utils.UnsafeString(event.data))); err != nil {
|
||||
s.logger.Warn("Failed to write event data",
|
||||
zap.String("client_id", deviceId),
|
||||
zap.String("device_id", deviceID),
|
||||
zap.String("connection_id", conn.id),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := s.writeToStream(w, ":keepalive"); err != nil {
|
||||
s.logger.Warn("Failed to write keepalive",
|
||||
zap.String("client_id", deviceId),
|
||||
zap.String("device_id", deviceID),
|
||||
zap.String("connection_id", conn.id),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
@ -132,36 +146,45 @@ func (s *Service) writeToStream(w *bufio.Writer, data string) error {
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
func (s *Service) registerConnection(id string) {
|
||||
func (s *Service) registerConnection(deviceID string) *sseConnection {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if existingConn, ok := s.connections[id]; ok {
|
||||
s.logger.Info("Closing existing SSE connection", zap.String("client_id", id))
|
||||
close(existingConn.closeSignal)
|
||||
delete(s.connections, id)
|
||||
}
|
||||
connID := uuid.NewString()
|
||||
|
||||
s.connections[id] = &sseConnection{
|
||||
channel: make(chan []byte, 8),
|
||||
conn := &sseConnection{
|
||||
id: connID,
|
||||
channel: make(chan eventWrapper, 8),
|
||||
closeSignal: make(chan struct{}),
|
||||
}
|
||||
s.logger.Info("Registering SSE connection", zap.String("client_id", id))
|
||||
|
||||
if _, ok := s.connections[deviceID]; !ok {
|
||||
s.connections[deviceID] = []*sseConnection{}
|
||||
}
|
||||
|
||||
func (s *Service) removeConnection(id string) {
|
||||
s.connections[deviceID] = append(s.connections[deviceID], conn)
|
||||
|
||||
s.logger.Info("Registering SSE connection", zap.String("device_id", deviceID), zap.String("connection_id", connID))
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (s *Service) removeConnection(deviceID, connID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if conn, exists := s.connections[id]; exists {
|
||||
if connections, exists := s.connections[deviceID]; exists {
|
||||
for i, conn := range connections {
|
||||
if conn.id == connID {
|
||||
close(conn.closeSignal)
|
||||
delete(s.connections, id)
|
||||
s.logger.Info("Removing SSE connection", zap.String("client_id", id))
|
||||
s.connections[deviceID] = append(connections[:i], connections[i+1:]...)
|
||||
s.logger.Info("Removing SSE connection", zap.String("device_id", deviceID), zap.String("connection_id", connID))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) getConnection(id string) *sseConnection {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.connections[id]
|
||||
if len(s.connections[deviceID]) == 0 {
|
||||
delete(s.connections, deviceID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user