mirror of
https://github.com/makayabou/asg-server.git
synced 2026-05-02 17:43:36 +02:00
[handlers] use Validate.Var for slice validation
This commit is contained in:
parent
f50b85bdba
commit
0a71b45122
@ -43,7 +43,7 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error {
|
||||
|
||||
func (h *Handler) ValidateStruct(out any) error {
|
||||
if h.Validator != nil {
|
||||
if err := h.Validator.Struct(out); err != nil {
|
||||
if err := h.Validator.Var(out, "required,dive"); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
@ -25,6 +24,16 @@ type testRequestBodyNoValidate struct {
|
||||
Age int `json:"age" validate:"required"`
|
||||
}
|
||||
|
||||
type testRequestQuery struct {
|
||||
Name string `query:"name" validate:"required"`
|
||||
Age int `query:"age" validate:"required"`
|
||||
}
|
||||
|
||||
type testRequestParams struct {
|
||||
ID string `params:"id" validate:"required"`
|
||||
Name string `params:"name" validate:"required"`
|
||||
}
|
||||
|
||||
func (t *testRequestBody) Validate() error {
|
||||
if t.Age < 18 {
|
||||
return fmt.Errorf("must be at least 18 years old")
|
||||
@ -32,6 +41,20 @@ func (t *testRequestBody) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testRequestQuery) Validate() error {
|
||||
if t.Age < 18 {
|
||||
return fmt.Errorf("must be at least 18 years old")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testRequestParams) Validate() error {
|
||||
if t.ID == "invalid" {
|
||||
return fmt.Errorf("invalid ID")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandler_BodyParserValidator(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
validate := validator.New()
|
||||
@ -100,7 +123,10 @@ func TestHandler_BodyParserValidator(t *testing.T) {
|
||||
req = httptest.NewRequest("POST", test.path, nil)
|
||||
}
|
||||
|
||||
resp, _ := app.Test(req)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test failed: %v", err)
|
||||
}
|
||||
if test.expectedStatus != resp.StatusCode {
|
||||
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
@ -109,89 +135,200 @@ func TestHandler_BodyParserValidator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandler_QueryParserValidator(t *testing.T) {
|
||||
type fields struct {
|
||||
Logger *zap.Logger
|
||||
Validator *validator.Validate
|
||||
}
|
||||
type args struct {
|
||||
c *fiber.Ctx
|
||||
out any
|
||||
logger := zaptest.NewLogger(t)
|
||||
validate := validator.New()
|
||||
|
||||
handler := &base.Handler{
|
||||
Logger: logger,
|
||||
Validator: validate,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
var query testRequestQuery
|
||||
return handler.QueryParserValidator(c, &query)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
description string
|
||||
path string
|
||||
expectedStatus int
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
description: "Invalid query parameters - non-integer age",
|
||||
path: "/test?name=John&age=abc",
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Valid query parameters",
|
||||
path: "/test?name=John&age=25",
|
||||
expectedStatus: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
description: "Invalid query parameters - missing name",
|
||||
path: "/test?age=25",
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Invalid query parameters - age too low",
|
||||
path: "/test?name=John&age=17",
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Invalid query parameters - missing age",
|
||||
path: "/test?name=John",
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &base.Handler{
|
||||
Logger: tt.fields.Logger,
|
||||
Validator: tt.fields.Validator,
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", test.path, nil)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test failed: %v", err)
|
||||
}
|
||||
if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if test.expectedStatus != resp.StatusCode {
|
||||
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_ParamsParserValidator(t *testing.T) {
|
||||
type fields struct {
|
||||
Logger *zap.Logger
|
||||
Validator *validator.Validate
|
||||
}
|
||||
type args struct {
|
||||
c *fiber.Ctx
|
||||
out any
|
||||
logger := zaptest.NewLogger(t)
|
||||
validate := validator.New()
|
||||
|
||||
handler := &base.Handler{
|
||||
Logger: logger,
|
||||
Validator: validate,
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test/:id/:name", func(c *fiber.Ctx) error {
|
||||
var params testRequestParams
|
||||
return handler.ParamsParserValidator(c, ¶ms)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
description string
|
||||
path string
|
||||
expectedStatus int
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
description: "Valid path parameters",
|
||||
path: "/test/123/John",
|
||||
expectedStatus: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
description: "Invalid path parameters - missing id",
|
||||
path: "/test//John",
|
||||
expectedStatus: fiber.StatusNotFound,
|
||||
},
|
||||
{
|
||||
description: "Invalid path parameters - missing name",
|
||||
path: "/test/123/",
|
||||
expectedStatus: fiber.StatusNotFound,
|
||||
},
|
||||
{
|
||||
description: "Invalid path parameters - invalid ID",
|
||||
path: "/test/invalid/John",
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &base.Handler{
|
||||
Logger: tt.fields.Logger,
|
||||
Validator: tt.fields.Validator,
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", test.path, nil)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test failed: %v", err)
|
||||
}
|
||||
if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if test.expectedStatus != resp.StatusCode {
|
||||
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_validateStruct(t *testing.T) {
|
||||
type fields struct {
|
||||
Logger *zap.Logger
|
||||
Validator *validator.Validate
|
||||
func TestHandler_ValidateStruct(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
validate := validator.New()
|
||||
|
||||
// Test with validator
|
||||
handlerWithValidator := &base.Handler{
|
||||
Logger: logger,
|
||||
Validator: validate,
|
||||
}
|
||||
type args struct {
|
||||
out any
|
||||
|
||||
// Test without validator
|
||||
handlerWithoutValidator := &base.Handler{
|
||||
Logger: logger,
|
||||
Validator: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
description string
|
||||
handler *base.Handler
|
||||
input any
|
||||
expectedStatus int
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
description: "Valid struct with validator",
|
||||
handler: handlerWithValidator,
|
||||
input: &testRequestBody{Name: "John Doe", Age: 25},
|
||||
expectedStatus: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
description: "Invalid struct with validator - missing required field",
|
||||
handler: handlerWithValidator,
|
||||
input: &testRequestBody{Age: 25},
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Invalid struct with validator - custom validation fails",
|
||||
handler: handlerWithValidator,
|
||||
input: &testRequestBody{Name: "John Doe", Age: 17},
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Valid struct without validator",
|
||||
handler: handlerWithoutValidator,
|
||||
input: &testRequestBody{Name: "John Doe", Age: 25},
|
||||
expectedStatus: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
description: "Invalid struct without validator - custom validation fails",
|
||||
handler: handlerWithoutValidator,
|
||||
input: &testRequestBody{Name: "John Doe", Age: 17},
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
description: "Valid struct with Validatable interface",
|
||||
handler: handlerWithValidator,
|
||||
input: &testRequestQuery{Name: "John", Age: 25},
|
||||
expectedStatus: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
description: "Invalid struct with Validatable interface",
|
||||
handler: handlerWithValidator,
|
||||
input: &testRequestQuery{Name: "John", Age: 17},
|
||||
expectedStatus: fiber.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &base.Handler{
|
||||
Logger: tt.fields.Logger,
|
||||
Validator: tt.fields.Validator,
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
err := test.handler.ValidateStruct(test.input)
|
||||
|
||||
if test.expectedStatus == fiber.StatusOK && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if err := h.ValidateStruct(tt.args.out); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
if test.expectedStatus == fiber.StatusBadRequest && err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -62,12 +62,12 @@ type ThirdPartyController struct {
|
||||
func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
|
||||
var params thirdPartyPostQueryParams
|
||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
var req smsgateway.Message
|
||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
var device models.Device
|
||||
@ -192,7 +192,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
|
||||
func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error {
|
||||
params := thirdPartyGetQueryParams{}
|
||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions())
|
||||
@ -252,7 +252,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error {
|
||||
func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error {
|
||||
req := smsgateway.MessagesExportRequest{}
|
||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID))
|
||||
|
||||
@ -49,7 +49,7 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
|
||||
// Get and validate order parameter
|
||||
params := mobileGetQueryParams{}
|
||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault())
|
||||
@ -81,9 +81,9 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
|
||||
//
|
||||
// Update message state
|
||||
func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error {
|
||||
var req smsgateway.MobilePatchMessageRequest
|
||||
req := smsgateway.MobilePatchMessageRequest{}
|
||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
for _, v := range req {
|
||||
|
||||
@ -96,7 +96,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) {
|
||||
req := smsgateway.MobileRegisterRequest{}
|
||||
|
||||
if err = h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
@ -150,7 +150,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error {
|
||||
req := smsgateway.MobileUpdateRequest{}
|
||||
|
||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if req.Id != device.ID {
|
||||
@ -205,7 +205,7 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error
|
||||
req := smsgateway.MobileChangePasswordRequest{}
|
||||
|
||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user