diff --git a/internal/sms-gateway/handlers/devices/3rdparty.go b/internal/sms-gateway/handlers/devices/3rdparty.go index e3a0f90..8af7288 100644 --- a/internal/sms-gateway/handlers/devices/3rdparty.go +++ b/internal/sms-gateway/handlers/devices/3rdparty.go @@ -1,6 +1,7 @@ package devices import ( + "errors" "fmt" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" @@ -68,7 +69,11 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) remove(user models.User, c *fiber.Ctx) error { id := c.Params("id") - if err := h.devicesSvc.Remove(user.ID, devices.WithID(id)); err != nil { + err := h.devicesSvc.Remove(user.ID, devices.WithID(id)) + if errors.Is(err, devices.ErrNotFound) { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + if err != nil { return fmt.Errorf("can't remove device: %w", err) } diff --git a/internal/sms-gateway/handlers/mobile.go b/internal/sms-gateway/handlers/mobile.go index a3f2b7c..e0b7666 100644 --- a/internal/sms-gateway/handlers/mobile.go +++ b/internal/sms-gateway/handlers/mobile.go @@ -13,6 +13,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/webhooks" "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/messages" "github.com/capcom6/go-helpers/anys" "github.com/go-playground/validator/v10" @@ -28,6 +29,7 @@ type mobileHandler struct { base.Handler authSvc *auth.Service + devicesSvc *devices.Service messagesSvc *messages.Service webhooksCtrl *webhooks.MobileController @@ -137,7 +139,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { return fiber.ErrForbidden } - if err := h.authSvc.UpdateDevice(req.Id, req.PushToken); err != nil { + if err := h.devicesSvc.UpdatePushToken(req.Id, req.PushToken); err != nil { return err } @@ -272,6 +274,7 @@ type mobileHandlerParams struct { Validator *validator.Validate AuthSvc *auth.Service + DevicesSvc *devices.Service MessagesSvc *messages.Service WebhooksCtrl *webhooks.MobileController @@ -283,6 +286,7 @@ func newMobileHandler(params mobileHandlerParams) *mobileHandler { return &mobileHandler{ Handler: base.Handler{Logger: params.Logger, Validator: params.Validator}, authSvc: params.AuthSvc, + devicesSvc: params.DevicesSvc, messagesSvc: params.MessagesSvc, webhooksCtrl: params.WebhooksCtrl, idGen: idGen, diff --git a/internal/sms-gateway/modules/auth/service.go b/internal/sms-gateway/modules/auth/service.go index 6721ddf..1656029 100644 --- a/internal/sms-gateway/modules/auth/service.go +++ b/internal/sms-gateway/modules/auth/service.go @@ -38,8 +38,7 @@ type Service struct { users *repository usersCache *cache.Cache[models.User] - devicesSvc *devices.Service - devicesCache *cache.Cache[models.Device] + devicesSvc *devices.Service logger *zap.Logger @@ -56,8 +55,7 @@ func New(params Params) *Service { logger: params.Logger.Named("Service"), idgen: idgen, - usersCache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}), - devicesCache: cache.New[models.Device](cache.Config{TTL: 10 * time.Minute}), + usersCache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}), } } @@ -87,10 +85,6 @@ func (s *Service) RegisterDevice(user models.User, name, pushToken *string) (mod return device, s.devicesSvc.Insert(user.ID, &device) } -func (s *Service) UpdateDevice(id, pushToken string) error { - return s.devicesSvc.UpdateToken(id, pushToken) -} - func (s *Service) IsPublic() bool { return s.config.Mode == ModePublic } @@ -108,19 +102,9 @@ func (s *Service) AuthorizeRegistration(token string) error { } func (s *Service) AuthorizeDevice(token string) (models.Device, error) { - hash := sha256.Sum256([]byte(token)) - cacheKey := hex.EncodeToString(hash[:]) - - device, err := s.devicesCache.Get(cacheKey) + device, err := s.devicesSvc.GetByToken(token) if err != nil { - device, err = s.devicesSvc.GetByToken(token) - if err != nil { - return device, fmt.Errorf("can't get device: %w", err) - } - - if err := s.devicesCache.Set(cacheKey, device); err != nil { - s.logger.Error("can't cache device", zap.Error(err)) - } + return device, err } go func(id string) { diff --git a/internal/sms-gateway/modules/devices/repository.go b/internal/sms-gateway/modules/devices/repository.go index c55500f..639ca42 100644 --- a/internal/sms-gateway/modules/devices/repository.go +++ b/internal/sms-gateway/modules/devices/repository.go @@ -51,7 +51,7 @@ func (r *repository) Insert(device *models.Device) error { return r.db.Create(device).Error } -func (r *repository) UpdateToken(id, token string) error { +func (r *repository) UpdatePushToken(id, token string) error { return r.db.Model(&models.Device{}).Where("id", id).Update("push_token", token).Error } diff --git a/internal/sms-gateway/modules/devices/service.go b/internal/sms-gateway/modules/devices/service.go index 44dcca8..eb77021 100644 --- a/internal/sms-gateway/modules/devices/service.go +++ b/internal/sms-gateway/modules/devices/service.go @@ -2,10 +2,14 @@ package devices import ( "context" + "crypto/sha256" + "encoding/hex" + "fmt" "time" "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/db" + "github.com/capcom6/go-helpers/cache" "go.uber.org/fx" "go.uber.org/zap" ) @@ -25,7 +29,8 @@ type ServiceParams struct { type Service struct { config Config - devices *repository + devices *repository + tokensCache *cache.Cache[models.Device] idGen db.IDGen @@ -62,11 +67,26 @@ func (s *Service) Get(userID string, filter ...SelectFilter) (models.Device, err // This method is used to retrieve a device by its auth token. If the device // does not exist, it returns ErrNotFound. func (s *Service) GetByToken(token string) (models.Device, error) { - return s.devices.Get(WithToken(token)) + hash := sha256.Sum256([]byte(token)) + cacheKey := hex.EncodeToString(hash[:]) + + device, err := s.tokensCache.Get(cacheKey) + if err != nil { + device, err = s.devices.Get(WithToken(token)) + if err != nil { + return device, fmt.Errorf("can't get device: %w", err) + } + + if err := s.tokensCache.Set(cacheKey, device); err != nil { + s.logger.Error("can't cache device", zap.Error(err)) + } + } + + return device, nil } -func (s *Service) UpdateToken(deviceId string, token string) error { - return s.devices.UpdateToken(deviceId, token) +func (s *Service) UpdatePushToken(deviceId string, token string) error { + return s.devices.UpdatePushToken(deviceId, token) } func (s *Service) UpdateLastSeen(deviceId string) error { @@ -78,6 +98,15 @@ func (s *Service) UpdateLastSeen(deviceId string) error { func (s *Service) Remove(userID string, filter ...SelectFilter) error { filter = append(filter, WithUserID(userID)) + device, err := s.Get(userID, filter...) + if err != nil { + return err + } + + if err := s.tokensCache.Delete(device.AuthToken); err != nil { + s.logger.Error("can't invalidate token cache", zap.Error(err)) + } + return s.devices.Remove(filter...) } @@ -90,9 +119,10 @@ func (s *Service) Clean(ctx context.Context) error { func NewService(params ServiceParams) *Service { return &Service{ - config: params.Config, - devices: params.Devices, - idGen: params.IDGen, - logger: params.Logger.Named("service"), + config: params.Config, + devices: params.Devices, + tokensCache: cache.New[models.Device](cache.Config{TTL: 10 * time.Minute}), + idGen: params.IDGen, + logger: params.Logger.Named("service"), } } diff --git a/pkg/swagger/docs/mobile.http b/pkg/swagger/docs/mobile.http index 517216c..30f4398 100644 --- a/pkg/swagger/docs/mobile.http +++ b/pkg/swagger/docs/mobile.http @@ -17,6 +17,16 @@ Content-Type: application/json "name": "Android Phone" } +### +PATCH {{baseUrl}}/device HTTP/1.1 +Authorization: Bearer {{mobileToken}} +Content-Type: application/json + +{ + "name": "Android Phone" +} + + ### GET {{baseUrl}}/message HTTP/1.1 Authorization: Bearer {{mobileToken}} diff --git a/pkg/swagger/docs/requests.http b/pkg/swagger/docs/requests.http index ae3b387..51974a5 100644 --- a/pkg/swagger/docs/requests.http +++ b/pkg/swagger/docs/requests.http @@ -60,7 +60,7 @@ GET {{baseUrl}}/3rdparty/v1/devices HTTP/1.1 Authorization: Basic {{credentials}} ### -DELETE {{baseUrl}}/3rdparty/v1/devices/MxKw03Q2ZVoomrLeDLlMO HTTP/1.1 +DELETE {{baseUrl}}/3rdparty/v1/devices/gF0jEYiaG_x9sI1YFWa7a HTTP/1.1 Authorization: Basic {{credentials}} ###