[handlers] use Validate.Var for slice validation

This commit is contained in:
Aleksandr Soloshenko 2025-08-14 16:28:54 +07:00 committed by Aleksandr
parent f50b85bdba
commit 0a71b45122
5 changed files with 206 additions and 69 deletions

View File

@ -43,7 +43,7 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error {
func (h *Handler) ValidateStruct(out any) error {
if h.Validator != nil {
if err := h.Validator.Struct(out); err != nil {
if err := h.Validator.Var(out, "required,dive"); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
}

View File

@ -11,7 +11,6 @@ import (
"github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
@ -25,6 +24,16 @@ type testRequestBodyNoValidate struct {
Age int `json:"age" validate:"required"`
}
type testRequestQuery struct {
Name string `query:"name" validate:"required"`
Age int `query:"age" validate:"required"`
}
type testRequestParams struct {
ID string `params:"id" validate:"required"`
Name string `params:"name" validate:"required"`
}
func (t *testRequestBody) Validate() error {
if t.Age < 18 {
return fmt.Errorf("must be at least 18 years old")
@ -32,6 +41,20 @@ func (t *testRequestBody) Validate() error {
return nil
}
func (t *testRequestQuery) Validate() error {
if t.Age < 18 {
return fmt.Errorf("must be at least 18 years old")
}
return nil
}
func (t *testRequestParams) Validate() error {
if t.ID == "invalid" {
return fmt.Errorf("invalid ID")
}
return nil
}
func TestHandler_BodyParserValidator(t *testing.T) {
logger := zaptest.NewLogger(t)
validate := validator.New()
@ -100,7 +123,10 @@ func TestHandler_BodyParserValidator(t *testing.T) {
req = httptest.NewRequest("POST", test.path, nil)
}
resp, _ := app.Test(req)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if test.expectedStatus != resp.StatusCode {
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
}
@ -109,89 +135,200 @@ func TestHandler_BodyParserValidator(t *testing.T) {
}
func TestHandler_QueryParserValidator(t *testing.T) {
type fields struct {
Logger *zap.Logger
Validator *validator.Validate
}
type args struct {
c *fiber.Ctx
out any
logger := zaptest.NewLogger(t)
validate := validator.New()
handler := &base.Handler{
Logger: logger,
Validator: validate,
}
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
var query testRequestQuery
return handler.QueryParserValidator(c, &query)
})
tests := []struct {
name string
fields fields
args args
wantErr bool
description string
path string
expectedStatus int
}{
// TODO: Add test cases.
{
description: "Invalid query parameters - non-integer age",
path: "/test?name=John&age=abc",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid query parameters",
path: "/test?name=John&age=25",
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid query parameters - missing name",
path: "/test?age=25",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid query parameters - age too low",
path: "/test?name=John&age=17",
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid query parameters - missing age",
path: "/test?name=John",
expectedStatus: fiber.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &base.Handler{
Logger: tt.fields.Logger,
Validator: tt.fields.Validator,
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
req := httptest.NewRequest("GET", test.path, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr)
if test.expectedStatus != resp.StatusCode {
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
}
})
}
}
func TestHandler_ParamsParserValidator(t *testing.T) {
type fields struct {
Logger *zap.Logger
Validator *validator.Validate
}
type args struct {
c *fiber.Ctx
out any
logger := zaptest.NewLogger(t)
validate := validator.New()
handler := &base.Handler{
Logger: logger,
Validator: validate,
}
app := fiber.New()
app.Get("/test/:id/:name", func(c *fiber.Ctx) error {
var params testRequestParams
return handler.ParamsParserValidator(c, &params)
})
tests := []struct {
name string
fields fields
args args
wantErr bool
description string
path string
expectedStatus int
}{
// TODO: Add test cases.
{
description: "Valid path parameters",
path: "/test/123/John",
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid path parameters - missing id",
path: "/test//John",
expectedStatus: fiber.StatusNotFound,
},
{
description: "Invalid path parameters - missing name",
path: "/test/123/",
expectedStatus: fiber.StatusNotFound,
},
{
description: "Invalid path parameters - invalid ID",
path: "/test/invalid/John",
expectedStatus: fiber.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &base.Handler{
Logger: tt.fields.Logger,
Validator: tt.fields.Validator,
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
req := httptest.NewRequest("GET", test.path, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr)
if test.expectedStatus != resp.StatusCode {
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
}
})
}
}
func TestHandler_validateStruct(t *testing.T) {
type fields struct {
Logger *zap.Logger
Validator *validator.Validate
func TestHandler_ValidateStruct(t *testing.T) {
logger := zaptest.NewLogger(t)
validate := validator.New()
// Test with validator
handlerWithValidator := &base.Handler{
Logger: logger,
Validator: validate,
}
type args struct {
out any
// Test without validator
handlerWithoutValidator := &base.Handler{
Logger: logger,
Validator: nil,
}
tests := []struct {
name string
fields fields
args args
wantErr bool
description string
handler *base.Handler
input any
expectedStatus int
}{
// TODO: Add test cases.
{
description: "Valid struct with validator",
handler: handlerWithValidator,
input: &testRequestBody{Name: "John Doe", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct with validator - missing required field",
handler: handlerWithValidator,
input: &testRequestBody{Age: 25},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Invalid struct with validator - custom validation fails",
handler: handlerWithValidator,
input: &testRequestBody{Name: "John Doe", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid struct without validator",
handler: handlerWithoutValidator,
input: &testRequestBody{Name: "John Doe", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct without validator - custom validation fails",
handler: handlerWithoutValidator,
input: &testRequestBody{Name: "John Doe", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
{
description: "Valid struct with Validatable interface",
handler: handlerWithValidator,
input: &testRequestQuery{Name: "John", Age: 25},
expectedStatus: fiber.StatusOK,
},
{
description: "Invalid struct with Validatable interface",
handler: handlerWithValidator,
input: &testRequestQuery{Name: "John", Age: 17},
expectedStatus: fiber.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &base.Handler{
Logger: tt.fields.Logger,
Validator: tt.fields.Validator,
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
err := test.handler.ValidateStruct(test.input)
if test.expectedStatus == fiber.StatusOK && err != nil {
t.Errorf("Expected no error, got %v", err)
}
if err := h.ValidateStruct(tt.args.out); (err != nil) != tt.wantErr {
t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr)
if test.expectedStatus == fiber.StatusBadRequest && err == nil {
t.Errorf("Expected error, got nil")
}
})
}

View File

@ -62,12 +62,12 @@ type ThirdPartyController struct {
func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
var params thirdPartyPostQueryParams
if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
var req smsgateway.Message
if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
var device models.Device
@ -192,7 +192,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error {
params := thirdPartyGetQueryParams{}
if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions())
@ -252,7 +252,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error {
func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error {
req := smsgateway.MessagesExportRequest{}
if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID))

View File

@ -49,7 +49,7 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
// Get and validate order parameter
params := mobileGetQueryParams{}
if err := h.QueryParserValidator(c, &params); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault())
@ -81,9 +81,9 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
//
// Update message state
func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error {
var req smsgateway.MobilePatchMessageRequest
req := smsgateway.MobilePatchMessageRequest{}
if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
for _, v := range req {

View File

@ -96,7 +96,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) {
req := smsgateway.MobileRegisterRequest{}
if err = h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
var (
@ -150,7 +150,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error {
req := smsgateway.MobileUpdateRequest{}
if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
if req.Id != device.ID {
@ -205,7 +205,7 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error
req := smsgateway.MobileChangePasswordRequest{}
if err := h.BodyParserValidator(c, &req); err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
return err
}
if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {