mirror of
https://github.com/makayabou/asg-server.git
synced 2026-05-02 17:43:36 +02:00
[handlers] use Validate.Var for slice validation
This commit is contained in:
parent
f50b85bdba
commit
0a71b45122
@ -43,7 +43,7 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error {
|
|||||||
|
|
||||||
func (h *Handler) ValidateStruct(out any) error {
|
func (h *Handler) ValidateStruct(out any) error {
|
||||||
if h.Validator != nil {
|
if h.Validator != nil {
|
||||||
if err := h.Validator.Struct(out); err != nil {
|
if err := h.Validator.Var(out, "required,dive"); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base"
|
"github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base"
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"go.uber.org/zap"
|
|
||||||
"go.uber.org/zap/zaptest"
|
"go.uber.org/zap/zaptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,6 +24,16 @@ type testRequestBodyNoValidate struct {
|
|||||||
Age int `json:"age" validate:"required"`
|
Age int `json:"age" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testRequestQuery struct {
|
||||||
|
Name string `query:"name" validate:"required"`
|
||||||
|
Age int `query:"age" validate:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRequestParams struct {
|
||||||
|
ID string `params:"id" validate:"required"`
|
||||||
|
Name string `params:"name" validate:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
func (t *testRequestBody) Validate() error {
|
func (t *testRequestBody) Validate() error {
|
||||||
if t.Age < 18 {
|
if t.Age < 18 {
|
||||||
return fmt.Errorf("must be at least 18 years old")
|
return fmt.Errorf("must be at least 18 years old")
|
||||||
@ -32,6 +41,20 @@ func (t *testRequestBody) Validate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testRequestQuery) Validate() error {
|
||||||
|
if t.Age < 18 {
|
||||||
|
return fmt.Errorf("must be at least 18 years old")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRequestParams) Validate() error {
|
||||||
|
if t.ID == "invalid" {
|
||||||
|
return fmt.Errorf("invalid ID")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandler_BodyParserValidator(t *testing.T) {
|
func TestHandler_BodyParserValidator(t *testing.T) {
|
||||||
logger := zaptest.NewLogger(t)
|
logger := zaptest.NewLogger(t)
|
||||||
validate := validator.New()
|
validate := validator.New()
|
||||||
@ -100,7 +123,10 @@ func TestHandler_BodyParserValidator(t *testing.T) {
|
|||||||
req = httptest.NewRequest("POST", test.path, nil)
|
req = httptest.NewRequest("POST", test.path, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, _ := app.Test(req)
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test failed: %v", err)
|
||||||
|
}
|
||||||
if test.expectedStatus != resp.StatusCode {
|
if test.expectedStatus != resp.StatusCode {
|
||||||
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||||
}
|
}
|
||||||
@ -109,89 +135,200 @@ func TestHandler_BodyParserValidator(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_QueryParserValidator(t *testing.T) {
|
func TestHandler_QueryParserValidator(t *testing.T) {
|
||||||
type fields struct {
|
logger := zaptest.NewLogger(t)
|
||||||
Logger *zap.Logger
|
validate := validator.New()
|
||||||
Validator *validator.Validate
|
|
||||||
}
|
handler := &base.Handler{
|
||||||
type args struct {
|
Logger: logger,
|
||||||
c *fiber.Ctx
|
Validator: validate,
|
||||||
out any
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
app.Get("/test", func(c *fiber.Ctx) error {
|
||||||
|
var query testRequestQuery
|
||||||
|
return handler.QueryParserValidator(c, &query)
|
||||||
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
description string
|
||||||
fields fields
|
path string
|
||||||
args args
|
expectedStatus int
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
{
|
||||||
|
description: "Invalid query parameters - non-integer age",
|
||||||
|
path: "/test?name=John&age=abc",
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Valid query parameters",
|
||||||
|
path: "/test?name=John&age=25",
|
||||||
|
expectedStatus: fiber.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid query parameters - missing name",
|
||||||
|
path: "/test?age=25",
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid query parameters - age too low",
|
||||||
|
path: "/test?name=John&age=17",
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid query parameters - missing age",
|
||||||
|
path: "/test?name=John",
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
for _, test := range tests {
|
||||||
h := &base.Handler{
|
t.Run(test.description, func(t *testing.T) {
|
||||||
Logger: tt.fields.Logger,
|
req := httptest.NewRequest("GET", test.path, nil)
|
||||||
Validator: tt.fields.Validator,
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test failed: %v", err)
|
||||||
}
|
}
|
||||||
if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
|
if test.expectedStatus != resp.StatusCode {
|
||||||
t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_ParamsParserValidator(t *testing.T) {
|
func TestHandler_ParamsParserValidator(t *testing.T) {
|
||||||
type fields struct {
|
logger := zaptest.NewLogger(t)
|
||||||
Logger *zap.Logger
|
validate := validator.New()
|
||||||
Validator *validator.Validate
|
|
||||||
}
|
handler := &base.Handler{
|
||||||
type args struct {
|
Logger: logger,
|
||||||
c *fiber.Ctx
|
Validator: validate,
|
||||||
out any
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
app.Get("/test/:id/:name", func(c *fiber.Ctx) error {
|
||||||
|
var params testRequestParams
|
||||||
|
return handler.ParamsParserValidator(c, ¶ms)
|
||||||
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
description string
|
||||||
fields fields
|
path string
|
||||||
args args
|
expectedStatus int
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
{
|
||||||
|
description: "Valid path parameters",
|
||||||
|
path: "/test/123/John",
|
||||||
|
expectedStatus: fiber.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid path parameters - missing id",
|
||||||
|
path: "/test//John",
|
||||||
|
expectedStatus: fiber.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid path parameters - missing name",
|
||||||
|
path: "/test/123/",
|
||||||
|
expectedStatus: fiber.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid path parameters - invalid ID",
|
||||||
|
path: "/test/invalid/John",
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
for _, test := range tests {
|
||||||
h := &base.Handler{
|
t.Run(test.description, func(t *testing.T) {
|
||||||
Logger: tt.fields.Logger,
|
req := httptest.NewRequest("GET", test.path, nil)
|
||||||
Validator: tt.fields.Validator,
|
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test failed: %v", err)
|
||||||
}
|
}
|
||||||
if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr {
|
if test.expectedStatus != resp.StatusCode {
|
||||||
t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_validateStruct(t *testing.T) {
|
func TestHandler_ValidateStruct(t *testing.T) {
|
||||||
type fields struct {
|
logger := zaptest.NewLogger(t)
|
||||||
Logger *zap.Logger
|
validate := validator.New()
|
||||||
Validator *validator.Validate
|
|
||||||
|
// Test with validator
|
||||||
|
handlerWithValidator := &base.Handler{
|
||||||
|
Logger: logger,
|
||||||
|
Validator: validate,
|
||||||
}
|
}
|
||||||
type args struct {
|
|
||||||
out any
|
// Test without validator
|
||||||
|
handlerWithoutValidator := &base.Handler{
|
||||||
|
Logger: logger,
|
||||||
|
Validator: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
description string
|
||||||
fields fields
|
handler *base.Handler
|
||||||
args args
|
input any
|
||||||
wantErr bool
|
expectedStatus int
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
{
|
||||||
|
description: "Valid struct with validator",
|
||||||
|
handler: handlerWithValidator,
|
||||||
|
input: &testRequestBody{Name: "John Doe", Age: 25},
|
||||||
|
expectedStatus: fiber.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid struct with validator - missing required field",
|
||||||
|
handler: handlerWithValidator,
|
||||||
|
input: &testRequestBody{Age: 25},
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid struct with validator - custom validation fails",
|
||||||
|
handler: handlerWithValidator,
|
||||||
|
input: &testRequestBody{Name: "John Doe", Age: 17},
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Valid struct without validator",
|
||||||
|
handler: handlerWithoutValidator,
|
||||||
|
input: &testRequestBody{Name: "John Doe", Age: 25},
|
||||||
|
expectedStatus: fiber.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid struct without validator - custom validation fails",
|
||||||
|
handler: handlerWithoutValidator,
|
||||||
|
input: &testRequestBody{Name: "John Doe", Age: 17},
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Valid struct with Validatable interface",
|
||||||
|
handler: handlerWithValidator,
|
||||||
|
input: &testRequestQuery{Name: "John", Age: 25},
|
||||||
|
expectedStatus: fiber.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid struct with Validatable interface",
|
||||||
|
handler: handlerWithValidator,
|
||||||
|
input: &testRequestQuery{Name: "John", Age: 17},
|
||||||
|
expectedStatus: fiber.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
for _, test := range tests {
|
||||||
h := &base.Handler{
|
t.Run(test.description, func(t *testing.T) {
|
||||||
Logger: tt.fields.Logger,
|
err := test.handler.ValidateStruct(test.input)
|
||||||
Validator: tt.fields.Validator,
|
|
||||||
|
if test.expectedStatus == fiber.StatusOK && err != nil {
|
||||||
|
t.Errorf("Expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
if err := h.ValidateStruct(tt.args.out); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr)
|
if test.expectedStatus == fiber.StatusBadRequest && err == nil {
|
||||||
|
t.Errorf("Expected error, got nil")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -62,12 +62,12 @@ type ThirdPartyController struct {
|
|||||||
func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
|
func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
|
||||||
var params thirdPartyPostQueryParams
|
var params thirdPartyPostQueryParams
|
||||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var req smsgateway.Message
|
var req smsgateway.Message
|
||||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var device models.Device
|
var device models.Device
|
||||||
@ -192,7 +192,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error {
|
|||||||
func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error {
|
func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error {
|
||||||
params := thirdPartyGetQueryParams{}
|
params := thirdPartyGetQueryParams{}
|
||||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions())
|
messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions())
|
||||||
@ -252,7 +252,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error {
|
|||||||
func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error {
|
func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error {
|
||||||
req := smsgateway.MessagesExportRequest{}
|
req := smsgateway.MessagesExportRequest{}
|
||||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID))
|
device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID))
|
||||||
|
|||||||
@ -49,7 +49,7 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
|
|||||||
// Get and validate order parameter
|
// Get and validate order parameter
|
||||||
params := mobileGetQueryParams{}
|
params := mobileGetQueryParams{}
|
||||||
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
if err := h.QueryParserValidator(c, ¶ms); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault())
|
msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault())
|
||||||
@ -81,9 +81,9 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error {
|
|||||||
//
|
//
|
||||||
// Update message state
|
// Update message state
|
||||||
func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error {
|
func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error {
|
||||||
var req smsgateway.MobilePatchMessageRequest
|
req := smsgateway.MobilePatchMessageRequest{}
|
||||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range req {
|
for _, v := range req {
|
||||||
|
|||||||
@ -96,7 +96,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) {
|
|||||||
req := smsgateway.MobileRegisterRequest{}
|
req := smsgateway.MobileRegisterRequest{}
|
||||||
|
|
||||||
if err = h.BodyParserValidator(c, &req); err != nil {
|
if err = h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -150,7 +150,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error {
|
|||||||
req := smsgateway.MobileUpdateRequest{}
|
req := smsgateway.MobileUpdateRequest{}
|
||||||
|
|
||||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Id != device.ID {
|
if req.Id != device.ID {
|
||||||
@ -205,7 +205,7 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error
|
|||||||
req := smsgateway.MobileChangePasswordRequest{}
|
req := smsgateway.MobileChangePasswordRequest{}
|
||||||
|
|
||||||
if err := h.BodyParserValidator(c, &req); err != nil {
|
if err := h.BodyParserValidator(c, &req); err != nil {
|
||||||
return fiber.NewError(fiber.StatusBadRequest, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {
|
if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user