diff --git a/internal/config/config.go b/internal/config/config.go index 8461b11..8830b2b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,7 @@ type Config struct { Database Database `yaml:"database"` // database config FCM FCMConfig `yaml:"fcm"` // firebase cloud messaging config Tasks Tasks `yaml:"tasks"` // tasks config + SSE SSE `yaml:"sse"` // server-sent events config } type Gateway struct { @@ -53,6 +54,10 @@ type HashingTask struct { IntervalSeconds uint16 `yaml:"interval_seconds" envconfig:"TASKS__HASHING__INTERVAL_SECONDS"` // hashing interval in seconds } +type SSE struct { + KeepAlivePeriodSeconds uint16 `yaml:"keep_alive_period_seconds" envconfig:"SSE__KEEP_ALIVE_PERIOD_SECONDS"` // keep alive period in seconds, 0 for no keep alive +} + var defaultConfig = Config{ Gateway: Gateway{Mode: GatewayModePublic}, HTTP: HTTP{ @@ -75,4 +80,7 @@ var defaultConfig = Config{ IntervalSeconds: uint16(15 * 60), }, }, + SSE: SSE{ + KeepAlivePeriodSeconds: 15, + }, } diff --git a/internal/config/module.go b/internal/config/module.go index 0495e88..c72a366 100644 --- a/internal/config/module.go +++ b/internal/config/module.go @@ -93,6 +93,8 @@ var Module = fx.Module( } }), fx.Provide(func(cfg Config) sse.Config { - return sse.NewConfig() + return sse.NewConfig( + sse.WithKeepAlivePeriod(time.Duration(cfg.SSE.KeepAlivePeriodSeconds) * time.Second), + ) }), ) diff --git a/internal/sms-gateway/modules/sse/config.go b/internal/sms-gateway/modules/sse/config.go index fa6d630..3b13443 100644 --- a/internal/sms-gateway/modules/sse/config.go +++ b/internal/sms-gateway/modules/sse/config.go @@ -2,28 +2,38 @@ package sse import "time" -const defaultKeepAlivePeriod = 15 * time.Second +type configOption func(*Config) type Config struct { keepAlivePeriod time.Duration } -func NewConfig() Config { - return Config{ - keepAlivePeriod: defaultKeepAlivePeriod, +const defaultKeepAlivePeriod = 15 * time.Second + +var defaultConfig = Config{ + keepAlivePeriod: defaultKeepAlivePeriod, +} + +func NewConfig(opts ...configOption) Config { + c := defaultConfig + + for _, opt := range opts { + opt(&c) } + + return c } func (c *Config) KeepAlivePeriod() time.Duration { return c.keepAlivePeriod } -func (c *Config) SetKeepAlivePeriod(d time.Duration) *Config { - if d <= 0 { +func WithKeepAlivePeriod(d time.Duration) configOption { + if d < 0 { d = defaultKeepAlivePeriod } - c.keepAlivePeriod = d - - return c + return func(c *Config) { + c.keepAlivePeriod = d + } } diff --git a/internal/sms-gateway/modules/sse/service.go b/internal/sms-gateway/modules/sse/service.go index b0164a5..4c58574 100644 --- a/internal/sms-gateway/modules/sse/service.go +++ b/internal/sms-gateway/modules/sse/service.go @@ -101,17 +101,13 @@ func (s *Service) Handler(deviceID string, c *fiber.Ctx) error { conn := s.registerConnection(deviceID) defer s.removeConnection(deviceID, conn.id) - if err := s.writeToStream(w, ":keepalive"); err != nil { - s.logger.Warn("Failed to write keepalive", - zap.String("device_id", deviceID), - zap.String("connection_id", conn.id), - zap.Error(err)) - return + // Conditionally create ticker + var ticker *time.Ticker + if s.config.keepAlivePeriod > 0 { + ticker = time.NewTicker(s.config.keepAlivePeriod) + defer ticker.Stop() } - ticker := time.NewTicker(s.config.keepAlivePeriod) - defer ticker.Stop() - for { select { case event := <-conn.channel: @@ -122,7 +118,14 @@ func (s *Service) Handler(deviceID string, c *fiber.Ctx) error { zap.Error(err)) return } - case <-ticker.C: + // Conditionally handle ticker events + case <-func() <-chan time.Time { + if ticker != nil { + return ticker.C + } + // Return nil channel that never fires when disabled + return make(chan time.Time) + }(): if err := s.writeToStream(w, ":keepalive"); err != nil { s.logger.Warn("Failed to write keepalive", zap.String("device_id", deviceID),