diff --git a/src/internal/server/jwt_middleware.go b/src/internal/server/jwt_middleware.go new file mode 100644 index 0000000..d00af85 --- /dev/null +++ b/src/internal/server/jwt_middleware.go @@ -0,0 +1,37 @@ +package server + +import ( + "os" + + "github.com/gofiber/fiber/v2" + jwt "github.com/gofiber/jwt/v2" +) + +// JWTProtected func for specify routes group with JWT authentication. +// See: https://github.com/gofiber/jwt +func JWTProtected() func(*fiber.Ctx) error { + // Create config for JWT authentication middleware. + config := jwt.Config{ + SigningKey: []byte(os.Getenv("JWT_SECRET_KEY")), + ContextKey: "jwt", // used in private routes + ErrorHandler: jwtError, + } + + return jwt.New(config) +} + +func jwtError(c *fiber.Ctx, err error) error { + // Return status 400 Bad Request and failed authentication error. + if err.Error() == "Missing or malformed JWT" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": true, + "msg": err.Error(), + }) + } + + // Return status 401 Unauthorized and failed authentication error. + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": true, + "msg": err.Error(), + }) +} diff --git a/src/internal/server/middleware.go b/src/internal/server/middleware.go index bf4281d..093d93e 100644 --- a/src/internal/server/middleware.go +++ b/src/internal/server/middleware.go @@ -1,30 +1,36 @@ package server import ( - "github.com/gofiber/fiber/v2" + "log" - "git.ego.freeddns.org/egommerce/go-api-pkg/fluentd" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" ) // "github.com/gofiber/fiber/v2" // "github.com/gofiber/fiber/v2/middleware/cors" func SetupMiddleware(s *Server) { - s.Use(defaultCORS) - s.Use(LoggingMiddleware(s.GetLogger())) + s.Use(LoggingMiddleware()) + s.Use(XRequestIDMiddleware()) } -func LoggingMiddleware(log *fluentd.Logger) func(c *fiber.Ctx) error { +func LoggingMiddleware() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - // path := string(c.Request().URI().Path()) - // if strings.Contains(path, "/health") { - // return c.Next() - // } - - log.Log("Request: %s, remote: %s, via: %s", + log.Printf("Request: %s, remote: %s, via: %s", c.Request().URI().String(), c.Context().RemoteIP().String(), - string(c.Context().UserAgent())) + string(c.Context().UserAgent()), + ) + + return c.Next() + } +} + +func XRequestIDMiddleware() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + requestID := uuid.New().String() + c.Set("X-Request-ID", requestID) return c.Next() }