[handlers] use Validate.Var for slice validation

This commit is contained in:
Aleksandr Soloshenko 2025-08-14 16:28:54 +07:00 committed by Aleksandr
parent f50b85bdba
commit 0a71b45122
5 changed files with 206 additions and 69 deletions

View File

@ -43,7 +43,7 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error {
func (h *Handler) ValidateStruct(out any) error { func (h *Handler) ValidateStruct(out any) error {
if h.Validator != nil { if h.Validator != nil {
if err := h.Validator.Struct(out); err != nil { if err := h.Validator.Var(out, "required,dive"); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return fiber.NewError(fiber.StatusBadRequest, err.Error())
} }
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
) )
@ -25,6 +24,16 @@ type testRequestBodyNoValidate struct {
Age int `json:"age" validate:"required"` Age int `json:"age" validate:"required"`
} }
type testRequestQuery struct {
Name string `query:"name" validate:"required"`
Age int `query:"age" validate:"required"`
}
type testRequestParams struct {
ID string `params:"id" validate:"required"`
Name string `params:"name" validate:"required"`
}
func (t *testRequestBody) Validate() error { func (t *testRequestBody) Validate() error {
if t.Age < 18 { if t.Age < 18 {
return fmt.Errorf("must be at least 18 years old") return fmt.Errorf("must be at least 18 years old")
@ -32,6 +41,20 @@ func (t *testRequestBody) Validate() error {
return nil return nil
} }
func (t *testRequestQuery) Validate() error {
if t.Age < 18 {
return fmt.Errorf("must be at least 18 years old")
}
return nil
}
func (t *testRequestParams) Validate() error {
if t.ID == "invalid" {
return fmt.Errorf("invalid ID")
}
return nil
}
func TestHandler_BodyParserValidator(t *testing.T) { func TestHandler_BodyParserValidator(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
validate := validator.New() validate := validator.New()
@ -100,7 +123,10 @@ func TestHandler_BodyParserValidator(t *testing.T) {
req = httptest.NewRequest("POST", test.path, nil) req = httptest.NewRequest("POST", test.path, nil)
} }
resp, _ := app.Test(req) resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if test.expectedStatus != resp.StatusCode { if test.expectedStatus != resp.StatusCode {
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode) t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
} }
@ -109,89 +135,200 @@ func TestHandler_BodyParserValidator(t *testing.T) {
} }
func TestHandler_QueryParserValidator(t *testing.T) { func TestHandler_QueryParserValidator(t *testing.T) {
type fields struct { logger := zaptest.NewLogger(t)
Logger *zap.Logger validate := validator.New()
Validator *validator.Validate
} handler := &base.Handler{
type args struct { Logger: logger,
c *fiber.Ctx Validator: validate,
out any
} }
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
var query testRequestQuery
return handler.QueryParserValidator(c, &query)
})
tests := []struct { tests := []struct {
name string description string
fields fields path string
args args expectedStatus int
wantErr bool
}{ }{
// TODO: Add test cases. {
description: "Invalid query parameters - non-integer age",
path: "/test?name=John&age=abc",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid query parameters",
path: "/test?name=John&age=25",
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid query parameters - missing name",
path: "/test?age=25",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid query parameters - age too low",
path: "/test?name=John&age=17",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid query parameters - missing age",
path: "/test?name=John",
expectedStatus: fiber.StatusBadRequest,
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { for _, test := range tests {
h := &base.Handler{ t.Run(test.description, func(t *testing.T) {
Logger: tt.fields.Logger, req := httptest.NewRequest("GET", test.path, nil)
Validator: tt.fields.Validator,
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
} }
if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { if test.expectedStatus != resp.StatusCode {
t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
} }
}) })
} }
} }
func TestHandler_ParamsParserValidator(t *testing.T) { func TestHandler_ParamsParserValidator(t *testing.T) {
type fields struct { logger := zaptest.NewLogger(t)
Logger *zap.Logger validate := validator.New()
Validator *validator.Validate
} handler := &base.Handler{
type args struct { Logger: logger,
c *fiber.Ctx Validator: validate,
out any
} }
app := fiber.New()
app.Get("/test/:id/:name", func(c *fiber.Ctx) error {
var params testRequestParams
return handler.ParamsParserValidator(c, &params)
})
tests := []struct { tests := []struct {
name string description string
fields fields path string
args args expectedStatus int
wantErr bool
}{ }{
// TODO: Add test cases. {
description: "Valid path parameters",
path: "/test/123/John",
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid path parameters - missing id",
path: "/test//John",
expectedStatus: fiber.StatusNotFound,
},
{
description: "Invalid path parameters - missing name",
path: "/test/123/",
expectedStatus: fiber.StatusNotFound,
},
{
description: "Invalid path parameters - invalid ID",
path: "/test/invalid/John",
expectedStatus: fiber.StatusBadRequest,
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { for _, test := range tests {
h := &base.Handler{ t.Run(test.description, func(t *testing.T) {
Logger: tt.fields.Logger, req := httptest.NewRequest("GET", test.path, nil)
Validator: tt.fields.Validator,
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
} }
if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { if test.expectedStatus != resp.StatusCode {
t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
} }
}) })
} }
} }
func TestHandler_validateStruct(t *testing.T) { func TestHandler_ValidateStruct(t *testing.T) {
type fields struct { logger := zaptest.NewLogger(t)
Logger *zap.Logger validate := validator.New()
Validator *validator.Validate
// Test with validator
handlerWithValidator := &base.Handler{
Logger: logger,
Validator: validate,
} }
type args struct {
out any // Test without validator
handlerWithoutValidator := &base.Handler{
Logger: logger,
Validator: nil,
} }
tests := []struct { tests := []struct {
name string description string
fields fields handler *base.Handler
args args input any
wantErr bool expectedStatus int
}{ }{
// TODO: Add test cases. {
description: "Valid struct with validator",
handler: handlerWithValidator,
input: &testRequestBody{Name: "John Doe", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct with validator - missing required field",
handler: handlerWithValidator,
input: &testRequestBody{Age: 25},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid struct with validator - custom validation fails",
handler: handlerWithValidator,
input: &testRequestBody{Name: "John Doe", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid struct without validator",
handler: handlerWithoutValidator,
input: &testRequestBody{Name: "John Doe", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct without validator - custom validation fails",
handler: handlerWithoutValidator,
input: &testRequestBody{Name: "John Doe", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid struct with Validatable interface",
handler: handlerWithValidator,
input: &testRequestQuery{Name: "John", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct with Validatable interface",
handler: handlerWithValidator,
input: &testRequestQuery{Name: "John", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { for _, test := range tests {
h := &base.Handler{ t.Run(test.description, func(t *testing.T) {
Logger: tt.fields.Logger, err := test.handler.ValidateStruct(test.input)
Validator: tt.fields.Validator,
if test.expectedStatus == fiber.StatusOK && err != nil {
t.Errorf("Expected no error, got %v", err)
} }
if err := h.ValidateStruct(tt.args.out); (err != nil) != tt.wantErr {
t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr) if test.expectedStatus == fiber.StatusBadRequest && err == nil {
t.Errorf("Expected error, got nil")
} }
}) })
} }

View File

@ -62,12 +62,12 @@ type ThirdPartyController struct {
func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
var params thirdPartyPostQueryParams var params thirdPartyPostQueryParams
if err := h.QueryParserValidator(c, &params); err != nil { if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
var req smsgateway.Message var req smsgateway.Message
if err := h.BodyParserValidator(c, &req); err != nil { if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
var device models.Device var device models.Device
@ -192,7 +192,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error {
params := thirdPartyGetQueryParams{} params := thirdPartyGetQueryParams{}
if err := h.QueryParserValidator(c, &params); err != nil { if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions()) messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions())
@ -252,7 +252,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error {
func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error { func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error {
req := smsgateway.MessagesExportRequest{} req := smsgateway.MessagesExportRequest{}
if err := h.BodyParserValidator(c, &req); err != nil { if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID)) device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID))

View File

@ -49,7 +49,7 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
// Get and validate order parameter // Get and validate order parameter
params := mobileGetQueryParams{} params := mobileGetQueryParams{}
if err := h.QueryParserValidator(c, &params); err != nil { if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault()) msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault())
@ -81,9 +81,9 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
// //
// Update message state // Update message state
func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error { func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error {
var req smsgateway.MobilePatchMessageRequest req := smsgateway.MobilePatchMessageRequest{}
if err := h.BodyParserValidator(c, &req); err != nil { if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
for _, v := range req { for _, v := range req {

View File

@ -96,7 +96,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) {
req := smsgateway.MobileRegisterRequest{} req := smsgateway.MobileRegisterRequest{}
if err = h.BodyParserValidator(c, &req); err != nil { if err = h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
var ( var (
@ -150,7 +150,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error {
req := smsgateway.MobileUpdateRequest{} req := smsgateway.MobileUpdateRequest{}
if err := h.BodyParserValidator(c, &req); err != nil { if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
if req.Id != device.ID { if req.Id != device.ID {
@ -205,7 +205,7 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error
req := smsgateway.MobileChangePasswordRequest{} req := smsgateway.MobileChangePasswordRequest{}
if err := h.BodyParserValidator(c, &req); err != nil { if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return err
} }
if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil { if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {