diff --git a/src/go.mod b/src/go.mod index 3582885..5b35819 100644 --- a/src/go.mod +++ b/src/go.mod @@ -5,7 +5,7 @@ go 1.24.0 toolchain go1.24.1 require ( - git.ego.freeddns.org/egommerce/api-entities v0.3.4 + git.ego.freeddns.org/egommerce/api-entities v0.3.8 git.ego.freeddns.org/egommerce/go-api-pkg v0.4.6 github.com/go-pg/migrations/v8 v8.1.0 github.com/go-pg/pg/v10 v10.15.0 diff --git a/src/go.sum b/src/go.sum index c87f401..dc81d3c 100644 --- a/src/go.sum +++ b/src/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -git.ego.freeddns.org/egommerce/api-entities v0.3.4 h1:yzPpumgtZeANRX3c74Y8/Z3s4vR6BKYceZhPLnLNNVY= -git.ego.freeddns.org/egommerce/api-entities v0.3.4/go.mod h1:IqynARw+06GOm4eZGZuepmbi7bUxWBnOB4jd5cI7jf8= +git.ego.freeddns.org/egommerce/api-entities v0.3.8 h1:ULuyJfr04E9AIAe8QTxLjDFV81H6HTxMNXtI/b+3Ec0= +git.ego.freeddns.org/egommerce/api-entities v0.3.8/go.mod h1:IqynARw+06GOm4eZGZuepmbi7bUxWBnOB4jd5cI7jf8= git.ego.freeddns.org/egommerce/go-api-pkg v0.4.6 h1:1iZW+vkbv7fQusv/pMjtIM1QvJ+QQr3nyvuuajgHc80= git.ego.freeddns.org/egommerce/go-api-pkg v0.4.6/go.mod h1:5Ft8LCd0UXp5hHpvXRBCv9mCGikogFhL7LP2qit12JM= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= diff --git a/src/internal/server/jwt_middleware.go b/src/internal/server/jwt_middleware.go deleted file mode 100644 index ecf9cdf..0000000 --- a/src/internal/server/jwt_middleware.go +++ /dev/null @@ -1,36 +0,0 @@ -package server - -import ( - baseCnf "git.ego.freeddns.org/egommerce/go-api-pkg/config" - "github.com/gofiber/fiber/v2" - jwt "github.com/gofiber/jwt/v2" -) - -// JWTProtected func for specify routes group with JWT authentication. -// See: https://github.com/gofiber/jwt -func JWTProtected() func(*fiber.Ctx) error { - // Create config for JWT authentication middleware. - config := jwt.Config{ - SigningKey: []byte(baseCnf.GetEnv("JWT_ACCESS_TOKEN_SECRET_KEY", "FallbackAccessTokenSecret")), - ContextKey: "jwt", // used in private routes - ErrorHandler: jwtError, - } - - return jwt.New(config) -} - -func jwtError(c *fiber.Ctx, err error) error { - // Return status 400 Bad Request and failed authentication error. - if err.Error() == "Missing or malformed JWT" { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ - "error": true, - "msg": err.Error(), - }) - } - - // Return status 401 Unauthorized and failed authentication error. - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ - "error": true, - "msg": err.Error(), - }) -} diff --git a/src/internal/server/login_handler.go b/src/internal/server/login_handler.go index a06717d..fc4ec83 100644 --- a/src/internal/server/login_handler.go +++ b/src/internal/server/login_handler.go @@ -3,6 +3,7 @@ package server import ( dto "git.ego.freeddns.org/egommerce/api-entities/identity/dto" "git.ego.freeddns.org/egommerce/identity-service/internal/service" + "github.com/gofiber/fiber/v2" ) diff --git a/src/internal/server/middleware.go b/src/internal/server/middleware.go deleted file mode 100644 index 093d93e..0000000 --- a/src/internal/server/middleware.go +++ /dev/null @@ -1,37 +0,0 @@ -package server - -import ( - "log" - - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" -) - -// "github.com/gofiber/fiber/v2" -// "github.com/gofiber/fiber/v2/middleware/cors" - -func SetupMiddleware(s *Server) { - s.Use(LoggingMiddleware()) - s.Use(XRequestIDMiddleware()) -} - -func LoggingMiddleware() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Printf("Request: %s, remote: %s, via: %s", - c.Request().URI().String(), - c.Context().RemoteIP().String(), - string(c.Context().UserAgent()), - ) - - return c.Next() - } -} - -func XRequestIDMiddleware() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - requestID := uuid.New().String() - c.Set("X-Request-ID", requestID) - - return c.Next() - } -} diff --git a/src/internal/server/middlewares.go b/src/internal/server/middlewares.go new file mode 100644 index 0000000..b33c821 --- /dev/null +++ b/src/internal/server/middlewares.go @@ -0,0 +1,66 @@ +package server + +import ( + "log" + + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" +) + +// "github.com/gofiber/fiber/v2" +// "github.com/gofiber/fiber/v2/middleware/cors" + +func SetupMiddleware(s *Server) { + s.Use(LoggingMiddleware()) + s.Use(XRequestIDMiddleware()) +} + +func LoggingMiddleware() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Printf("Request: %s, remote: %s, via: %s", + c.Request().URI().String(), + c.Context().RemoteIP().String(), + string(c.Context().UserAgent()), + ) + + return c.Next() + } +} + +func XRequestIDMiddleware() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + requestID := uuid.New().String() + c.Set("X-Request-ID", requestID) + + return c.Next() + } +} + +// JWTProtected func for specify routes group with JWT authentication. +// See: https://github.com/gofiber/jwt +// func JWTProtected() func(*fiber.Ctx) error { +// // Create config for JWT authentication middleware. +// config := jwt.Config{ +// SigningKey: []byte(baseCnf.GetEnv("JWT_ACCESS_TOKEN_SECRET_KEY", "FallbackAccessTokenSecret")), +// ContextKey: "jwt", // used in private routes +// ErrorHandler: jwtError, +// } + +// return jwt.New(config) +// } + +// func jwtError(c *fiber.Ctx, err error) error { +// // Return status 400 Bad Request and failed authentication error. +// if err.Error() == "Missing or malformed JWT" { +// return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ +// "error": true, +// "msg": err.Error(), +// }) +// } + +// // Return status 401 Unauthorized and failed authentication error. +// return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ +// "error": true, +// "msg": err.Error(), +// }) +// } diff --git a/src/internal/server/refresh_handler.go b/src/internal/server/refresh_handler.go new file mode 100644 index 0000000..0d2408b --- /dev/null +++ b/src/internal/server/refresh_handler.go @@ -0,0 +1,28 @@ +package server + +import ( + dto "git.ego.freeddns.org/egommerce/api-entities/identity/dto" + "git.ego.freeddns.org/egommerce/identity-service/internal/service" + + "github.com/gofiber/fiber/v2" +) + +func (s *Server) RefreshHandlerFn(c *fiber.Ctx) error { + data := new(dto.AuthRefreshTokenRequestDTO) + if err := c.BodyParser(data); err != nil { + return s.Error(c, fiber.StatusBadRequest, "Error parsing input") + } + + authSrv := service.NewAuthService(s.GetDatabase(), s.GetCache()) + + token, err := authSrv.RefreshToken(data.AccessToken) + if err != nil { + if err == service.ErrUnableToCacheToken { + return s.Error(c, fiber.StatusInternalServerError, err.Error()) + } + + return s.Error(c, fiber.StatusBadRequest, err.Error()) + } + + return c.JSON(&dto.AuthLoginResponseDTO{Token: token}) +} diff --git a/src/internal/server/register_handler.go b/src/internal/server/register_handler.go index 4b47b4d..86cfdc4 100644 --- a/src/internal/server/register_handler.go +++ b/src/internal/server/register_handler.go @@ -3,6 +3,7 @@ package server import ( dto "git.ego.freeddns.org/egommerce/api-entities/identity/dto" "git.ego.freeddns.org/egommerce/identity-service/internal/service" + "github.com/gofiber/fiber/v2" ) diff --git a/src/internal/server/router.go b/src/internal/server/router.go index b30f509..59be803 100644 --- a/src/internal/server/router.go +++ b/src/internal/server/router.go @@ -21,5 +21,6 @@ func SetupRouter(s *Server) { s.Group("/v1"). Post("/login", s.LoginHandlerFn). + Post("/refresh", s.RefreshHandlerFn). Post("/register", s.RegisterHandlerFn) } diff --git a/src/internal/service/auth.go b/src/internal/service/auth.go index a128b7d..a3949e1 100644 --- a/src/internal/service/auth.go +++ b/src/internal/service/auth.go @@ -6,16 +6,17 @@ import ( "fmt" db "git.ego.freeddns.org/egommerce/identity-service/pkg/database" + "github.com/go-redis/redis/v8" "github.com/jackc/pgx/v5/pgxpool" ) var ( AuthService *Auth - jwtSrv *JWT ErrLoginIncorrect = errors.New("login incorrect") ErrUnableToCacheToken = errors.New("unable to save token in cache") + ErrInvalidAccessToken = errors.New("invalid access token") ) func init() { @@ -48,13 +49,30 @@ func (a *Auth) Login(login, passwd string) (string, error) { accessToken, _ := jwtSrv.CreateAccessToken(id) refreshToken, _ := jwtSrv.CreateRefreshToken(id) - if err = a.saveTokensToCache(accessToken, refreshToken, id); err != nil { + if err = a.saveTokensToCache(id, accessToken, refreshToken); err != nil { return "", ErrUnableToCacheToken } return accessToken, nil } +func (a *Auth) RefreshToken(accessToken string) (string, error) { + token, claims, err := jwtSrv.ValidateAccessToken(accessToken) + if err != nil || !token.Valid { + return "", ErrInvalidAccessToken + } + + id := claims["sub"] + + newAccessToken, _ := jwtSrv.CreateAccessToken(id.(string)) + newRefreshToken, _ := jwtSrv.CreateRefreshToken(id.(string)) + if err = a.saveTokensToCache(id.(string), newAccessToken, newRefreshToken); err != nil { + return "", ErrUnableToCacheToken + } + + return newAccessToken, nil +} + func (a *Auth) Register(email, login, passwd string) (string, error) { var id string @@ -71,7 +89,7 @@ func (a *Auth) Register(email, login, passwd string) (string, error) { return id, nil } -func (a *Auth) saveTokensToCache(accessToken, refreshToken, id string) error { +func (a *Auth) saveTokensToCache(id, accessToken, refreshToken string) error { res := a.cache.Set(context.Background(), "auth:access_token:"+id, accessToken, accessTokenExpireTime) if err := res.Err(); err != nil { fmt.Println("failed to save access token in redis: ", err.Error()) diff --git a/src/internal/service/jwt.go b/src/internal/service/jwt.go index 7f000af..f2bb4c5 100644 --- a/src/internal/service/jwt.go +++ b/src/internal/service/jwt.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "strconv" "time" @@ -10,10 +11,15 @@ import ( ) var ( + ErrorTokenExpired = errors.New("token has expired") + ErrInvalidToken = errors.New("invalid token") + accessTokenExpireTime time.Duration refreshTokenExpireTime time.Duration ) +var jwtSrv *JWT + func init() { expAccessTokenTime, _ := strconv.Atoi(baseCnf.GetEnv("JWT_ACCESS_TOKEN_EXPIRE_TIME", "5")) accessTokenExpireTime = time.Duration(int(time.Hour) * expAccessTokenTime) // hours @@ -21,12 +27,16 @@ func init() { expRefreshTokenTime, _ := strconv.Atoi(baseCnf.GetEnv("JWT_REFRESH_TOKEN_EXPIRE_TIME", "7")) refreshTokenExpireTime = time.Duration(int(time.Hour*24) * expRefreshTokenTime) // days - jwtSrv = &JWT{ + jwtSrv = NewJWTService( accessTokenExpireTime, []byte(baseCnf.GetEnv("JWT_ACCESS_TOKEN_SECRET_KEY", "FallbackAccessTokenSecret")), refreshTokenExpireTime, []byte(baseCnf.GetEnv("JWT_REFRESH_TOKEN_SECRET_KEY", "FallbackRefreshTokenSecret")), - } + ) +} + +func NewJWTService(aTokenExp time.Duration, aTokenSecret []byte, rTokenExp time.Duration, rTokenSecret []byte) *JWT { + return &JWT{aTokenExp, aTokenSecret, rTokenExp, rTokenSecret} } type JWT struct { @@ -61,7 +71,7 @@ func (s *JWT) CreateRefreshToken(id string) (string, error) { return token.SignedString(s.accessTokenSecret) } -func (s *JWT) ValidateToken(tokenStr string) error { +func (s *JWT) ValidateAccessToken(tokenStr string) (*jwt.Token, jwt.MapClaims, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { @@ -71,9 +81,19 @@ func (s *JWT) ValidateToken(tokenStr string) error { return s.accessTokenSecret, nil }) - if _, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - return nil + if err != nil { + return nil, nil, err } - return err + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + if exp, ok := claims["exp"].(float64); ok { + if int64(exp) < time.Now().Unix() { + return nil, nil, ErrorTokenExpired + } + + return token, claims, nil + } + } + + return nil, nil, ErrInvalidToken }