diff --git a/internal/sms-gateway/handlers/handler.go b/internal/sms-gateway/handlers/handler.go index ef6b2f1..bec4914 100644 --- a/internal/sms-gateway/handlers/handler.go +++ b/internal/sms-gateway/handlers/handler.go @@ -8,6 +8,10 @@ import ( "go.uber.org/zap" ) +type Validatable interface { + Validate() error +} + type Handler struct { Logger *zap.Logger Validator *validator.Validate @@ -18,11 +22,7 @@ func (h *Handler) BodyParserValidator(c *fiber.Ctx, out any) error { return fmt.Errorf("can't parse body: %w", err) } - if h.Validator == nil { - return nil - } - - return h.Validator.Struct(out) + return h.validateStruct(out) } func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { @@ -30,11 +30,7 @@ func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { return fmt.Errorf("can't parse query: %w", err) } - if h.Validator == nil { - return nil - } - - return h.Validator.Struct(out) + return h.validateStruct(out) } func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { @@ -42,9 +38,21 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { return fmt.Errorf("can't parse params: %w", err) } - if h.Validator == nil { - return nil + return h.validateStruct(out) +} + +func (h *Handler) validateStruct(out any) error { + if h.Validator != nil { + if err := h.Validator.Struct(out); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } } - return h.Validator.Struct(out) + if req, ok := out.(Validatable); ok { + if err := req.Validate(); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } + } + + return nil } diff --git a/internal/sms-gateway/handlers/handler_test.go b/internal/sms-gateway/handlers/handler_test.go new file mode 100644 index 0000000..4791ff9 --- /dev/null +++ b/internal/sms-gateway/handlers/handler_test.go @@ -0,0 +1,205 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +type TestRequestBody struct { + Name string `json:"name" validate:"required"` + Age int `json:"age" validate:"required"` +} + +type TestRequestBodyNoValidate struct { + Name string `json:"name" validate:"required"` + Age int `json:"age" validate:"required"` +} + +func (t *TestRequestBody) Validate() error { + if t.Age < 18 { + return fmt.Errorf("must be at least 18 years old") + } + return nil +} + +type TestQueryParams struct { + Page int `query:"page" validate:"required"` +} + +type TestURLParams struct { + ID string `params:"id" validate:"required,uuid"` +} + +func TestHandler_BodyParserValidator(t *testing.T) { + logger := zaptest.NewLogger(t) + validate := validator.New() + + handler := &Handler{ + Logger: logger, + Validator: validate, + } + + app := fiber.New() + app.Post("/test", func(c *fiber.Ctx) error { + var body TestRequestBody + return handler.BodyParserValidator(c, &body) + }) + app.Post("/test2", func(c *fiber.Ctx) error { + var body TestRequestBodyNoValidate + return handler.BodyParserValidator(c, &body) + }) + + tests := []struct { + description string + path string + payload any + expectedStatus int + }{ + { + description: "Valid request body", + path: "/test", + payload: &TestRequestBody{Name: "John Doe", Age: 25}, + expectedStatus: fiber.StatusOK, + }, + { + description: "Invalid request body - missing name", + path: "/test", + payload: &TestRequestBody{Age: 25}, + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Invalid request body - age too low", + path: "/test", + payload: &TestRequestBody{Name: "John Doe", Age: 17}, + expectedStatus: fiber.StatusBadRequest, + }, + { + description: "Valid request body - no validation", + path: "/test2", + payload: &TestRequestBodyNoValidate{Name: "John Doe", Age: 17}, + expectedStatus: fiber.StatusOK, + }, + { + description: "No request body", + path: "/test", + payload: nil, + expectedStatus: fiber.StatusUnprocessableEntity, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var req *http.Request + if test.payload != nil { + bodyBytes, _ := json.Marshal(test.payload) + req = httptest.NewRequest("POST", test.path, bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + } else { + req = httptest.NewRequest("POST", test.path, nil) + } + + resp, _ := app.Test(req) + if test.expectedStatus != resp.StatusCode { + t.Errorf("Expected status code %d, got %d", test.expectedStatus, resp.StatusCode) + } + }) + } +} + +func TestHandler_QueryParserValidator(t *testing.T) { + type fields struct { + Logger *zap.Logger + Validator *validator.Validate + } + type args struct { + c *fiber.Ctx + out any + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Handler{ + Logger: tt.fields.Logger, + Validator: tt.fields.Validator, + } + if err := h.QueryParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { + t.Errorf("Handler.QueryParserValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHandler_ParamsParserValidator(t *testing.T) { + type fields struct { + Logger *zap.Logger + Validator *validator.Validate + } + type args struct { + c *fiber.Ctx + out any + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Handler{ + Logger: tt.fields.Logger, + Validator: tt.fields.Validator, + } + if err := h.ParamsParserValidator(tt.args.c, tt.args.out); (err != nil) != tt.wantErr { + t.Errorf("Handler.ParamsParserValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHandler_validateStruct(t *testing.T) { + type fields struct { + Logger *zap.Logger + Validator *validator.Validate + } + type args struct { + out any + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Handler{ + Logger: tt.fields.Logger, + Validator: tt.fields.Validator, + } + if err := h.validateStruct(tt.args.out); (err != nil) != tt.wantErr { + t.Errorf("Handler.validateStruct() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}