From 0a71b45122721d79ee69865e6997b680c1804656 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Thu, 14 Aug 2025 16:28:54 +0700 Subject: [PATCH] [handlers] use `Validate.Var` for slice validation --- internal/sms-gateway/handlers/base/handler.go | 2 +- .../sms-gateway/handlers/base/handler_test.go | 253 ++++++++++++++---- .../sms-gateway/handlers/messages/3rdparty.go | 8 +- .../sms-gateway/handlers/messages/mobile.go | 6 +- internal/sms-gateway/handlers/mobile.go | 6 +- 5 files changed, 206 insertions(+), 69 deletions(-) diff --git a/internal/sms-gateway/handlers/base/handler.go b/internal/sms-gateway/handlers/base/handler.go index d79521d..aaa8b1a 100644 --- a/internal/sms-gateway/handlers/base/handler.go +++ b/internal/sms-gateway/handlers/base/handler.go @@ -43,7 +43,7 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) ValidateStruct(out any) error { if h.Validator != nil { - if err := h.Validator.Struct(out); err != nil { + if err := h.Validator.Var(out, "required,dive"); err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } } diff --git a/internal/sms-gateway/handlers/base/handler_test.go b/internal/sms-gateway/handlers/base/handler_test.go index 9fb4450..b1f8f5a 100644 --- a/internal/sms-gateway/handlers/base/handler_test.go +++ b/internal/sms-gateway/handlers/base/handler_test.go @@ -11,7 +11,6 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" - "go.uber.org/zap" "go.uber.org/zap/zaptest" ) @@ -25,6 +24,16 @@ type testRequestBodyNoValidate struct { Age int `json:"age" validate:"required"` } +type testRequestQuery struct { + Name string `query:"name" validate:"required"` + Age int `query:"age" validate:"required"` +} + +type testRequestParams struct { + ID string `params:"id" validate:"required"` + Name string `params:"name" validate:"required"` +} + func (t *testRequestBody) Validate() error { if t.Age < 18 { return fmt.Errorf("must be at least 18 years old") @@ -32,6 +41,20 @@ func (t *testRequestBody) Validate() error { return nil } +func (t *testRequestQuery) Validate() error { + if t.Age < 18 { + return fmt.Errorf("must be at least 18 years old") + } + return nil +} + +func (t *testRequestParams) Validate() error { + if t.ID == "invalid" { + return fmt.Errorf("invalid ID") + } + return nil +} + func TestHandler_BodyParserValidator(t *testing.T) { logger := zaptest.NewLogger(t) validate := validator.New() @@ -100,7 +123,10 @@ func TestHandler_BodyParserValidator(t *testing.T) { req = httptest.NewRequest("POST", test.path, nil) } - resp, _ := app.Test(req) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("app.Test failed: %v", err) + } if test.expectedStatus != resp.StatusCode { t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode) } @@ -109,89 +135,200 @@ func TestHandler_BodyParserValidator(t *testing.T) { } func TestHandler_QueryParserValidator(t *testing.T) { - type fields struct { - Logger *zap.Logger - Validator *validator.Validate - } - type args struct { - c *fiber.Ctx - out any + logger := zaptest.NewLogger(t) + validate := validator.New() + + handler := &base.Handler{ + Logger: logger, + Validator: validate, } + + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + var query testRequestQuery + return handler.QueryParserValidator(c, &query) + }) + tests := []struct { - name string - fields fields - args args - wantErr bool + description string + path string + expectedStatus int }{ - // TODO: Add test cases. + { + description: "Invalid query parameters - non-integer age", + path: "/test?name=John&age=abc", + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Valid query parameters", + path: "/test?name=John&age=25", + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid query parameters - missing name", + path: "/test?age=25", + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Invalid query parameters - age too low", + path: "/test?name=John&age=17", + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Invalid query parameters - missing age", + path: "/test?name=John", + expectedStatus: fiber.StatusBadRequest, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - h := &base.Handler{ - Logger: tt.fields.Logger, - Validator: tt.fields.Validator, + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + req := httptest.NewRequest("GET", test.path, nil) + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("app.Test failed: %v", err) } - if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { - t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr) + if test.expectedStatus != resp.StatusCode { + t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode) } }) } } func TestHandler_ParamsParserValidator(t *testing.T) { - type fields struct { - Logger *zap.Logger - Validator *validator.Validate - } - type args struct { - c *fiber.Ctx - out any + logger := zaptest.NewLogger(t) + validate := validator.New() + + handler := &base.Handler{ + Logger: logger, + Validator: validate, } + + app := fiber.New() + app.Get("/test/:id/:name", func(c *fiber.Ctx) error { + var params testRequestParams + return handler.ParamsParserValidator(c, ¶ms) + }) + tests := []struct { - name string - fields fields - args args - wantErr bool + description string + path string + expectedStatus int }{ - // TODO: Add test cases. + { + description: "Valid path parameters", + path: "/test/123/John", + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid path parameters - missing id", + path: "/test//John", + expectedStatus: fiber.StatusNotFound, + }, + { + description: "Invalid path parameters - missing name", + path: "/test/123/", + expectedStatus: fiber.StatusNotFound, + }, + { + description: "Invalid path parameters - invalid ID", + path: "/test/invalid/John", + expectedStatus: fiber.StatusBadRequest, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - h := &base.Handler{ - Logger: tt.fields.Logger, - Validator: tt.fields.Validator, + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + req := httptest.NewRequest("GET", test.path, nil) + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("app.Test failed: %v", err) } - if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { - t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr) + if test.expectedStatus != resp.StatusCode { + t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode) } }) } } -func TestHandler_validateStruct(t *testing.T) { - type fields struct { - Logger *zap.Logger - Validator *validator.Validate +func TestHandler_ValidateStruct(t *testing.T) { + logger := zaptest.NewLogger(t) + validate := validator.New() + + // Test with validator + handlerWithValidator := &base.Handler{ + Logger: logger, + Validator: validate, } - type args struct { - out any + + // Test without validator + handlerWithoutValidator := &base.Handler{ + Logger: logger, + Validator: nil, } + tests := []struct { - name string - fields fields - args args - wantErr bool + description string + handler *base.Handler + input any + expectedStatus int }{ - // TODO: Add test cases. + { + description: "Valid struct with validator", + handler: handlerWithValidator, + input: &testRequestBody{Name: "John Doe", Age: 25}, + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid struct with validator - missing required field", + handler: handlerWithValidator, + input: &testRequestBody{Age: 25}, + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Invalid struct with validator - custom validation fails", + handler: handlerWithValidator, + input: &testRequestBody{Name: "John Doe", Age: 17}, + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Valid struct without validator", + handler: handlerWithoutValidator, + input: &testRequestBody{Name: "John Doe", Age: 25}, + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid struct without validator - custom validation fails", + handler: handlerWithoutValidator, + input: &testRequestBody{Name: "John Doe", Age: 17}, + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Valid struct with Validatable interface", + handler: handlerWithValidator, + input: &testRequestQuery{Name: "John", Age: 25}, + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid struct with Validatable interface", + handler: handlerWithValidator, + input: &testRequestQuery{Name: "John", Age: 17}, + expectedStatus: fiber.StatusBadRequest, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - h := &base.Handler{ - Logger: tt.fields.Logger, - Validator: tt.fields.Validator, + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + err := test.handler.ValidateStruct(test.input) + + if test.expectedStatus == fiber.StatusOK && err != nil { + t.Errorf("Expected no error, got %v", err) } - if err := h.ValidateStruct(tt.args.out); (err != nil) != tt.wantErr { - t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr) + + if test.expectedStatus == fiber.StatusBadRequest && err == nil { + t.Errorf("Expected error, got nil") } }) } diff --git a/internal/sms-gateway/handlers/messages/3rdparty.go b/internal/sms-gateway/handlers/messages/3rdparty.go index d14df87..631b726 100644 --- a/internal/sms-gateway/handlers/messages/3rdparty.go +++ b/internal/sms-gateway/handlers/messages/3rdparty.go @@ -62,12 +62,12 @@ type ThirdPartyController struct { func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { var params thirdPartyPostQueryParams if err := h.QueryParserValidator(c, ¶ms); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } var req smsgateway.Message if err := h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } var device models.Device @@ -192,7 +192,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { params := thirdPartyGetQueryParams{} if err := h.QueryParserValidator(c, ¶ms); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions()) @@ -252,7 +252,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error { req := smsgateway.MessagesExportRequest{} if err := h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID)) diff --git a/internal/sms-gateway/handlers/messages/mobile.go b/internal/sms-gateway/handlers/messages/mobile.go index 6814546..5c4d8bb 100644 --- a/internal/sms-gateway/handlers/messages/mobile.go +++ b/internal/sms-gateway/handlers/messages/mobile.go @@ -49,7 +49,7 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error { // Get and validate order parameter params := mobileGetQueryParams{} if err := h.QueryParserValidator(c, ¶ms); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault()) @@ -81,9 +81,9 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error { // // Update message state func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error { - var req smsgateway.MobilePatchMessageRequest + req := smsgateway.MobilePatchMessageRequest{} if err := h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } for _, v := range req { diff --git a/internal/sms-gateway/handlers/mobile.go b/internal/sms-gateway/handlers/mobile.go index 9c9914f..d9f04fc 100644 --- a/internal/sms-gateway/handlers/mobile.go +++ b/internal/sms-gateway/handlers/mobile.go @@ -96,7 +96,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) { req := smsgateway.MobileRegisterRequest{} if err = h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } var ( @@ -150,7 +150,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { req := smsgateway.MobileUpdateRequest{} if err := h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } if req.Id != device.ID { @@ -205,7 +205,7 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error req := smsgateway.MobileChangePasswordRequest{} if err := h.BodyParserValidator(c, &req); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return err } if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {