diff --git a/src/domain/repository/user_repository.go b/src/domain/repository/user_repository.go index 0016196..f767c1d 100644 --- a/src/domain/repository/user_repository.go +++ b/src/domain/repository/user_repository.go @@ -19,7 +19,29 @@ func NewUserRepository(db *pgxpool.Pool) *UserRepository { } func (r *UserRepository) GetByID(id string) (*entity.User, error) { - return &entity.User{}, nil + var user entity.User + + sql := `SELECT id, username, password, email, created_at FROM identity.users WHERE id=$1 LIMIT 1` + err := r.db.QueryRow(context.Background(), sql, id). + Scan(&user.ID, &user.Username, &user.Password, &user.Email, &user.CreatedAt) + if err != nil { + return nil, errors.New("failed to fetch user from DB: " + err.Error()) + } + + return &user, nil +} + +func (r *UserRepository) GetByUsername(login string) (*entity.User, error) { + var user entity.User + + sql := `SELECT id, username, password, email, created_at FROM identity.users WHERE username=$1 LIMIT 1` + err := r.db.QueryRow(context.Background(), sql, login). + Scan(&user.ID, &user.Username, &user.Password, &user.Email, &user.CreatedAt) + if err != nil { + return nil, errors.New("failed to fetch user from DB: " + err.Error()) + } + + return &user, nil } func (r *UserRepository) Create(user *entity.User) (string, error) { diff --git a/src/internal/server/login_handler.go b/src/internal/server/login_handler.go index 92dc107..469e68e 100644 --- a/src/internal/server/login_handler.go +++ b/src/internal/server/login_handler.go @@ -16,7 +16,7 @@ func (s *Server) LoginHandlerFn(c *fiber.Ctx) error { } repo := domain.NewUserRepository(s.GetDatabase()) - authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) + authSrv := service.NewAuthService(repo, s.GetCache()) token, err := authSrv.Login(data.Username, data.Password) if err != nil { diff --git a/src/internal/server/refresh_handler.go b/src/internal/server/refresh_handler.go index 25add8e..79bfefe 100644 --- a/src/internal/server/refresh_handler.go +++ b/src/internal/server/refresh_handler.go @@ -15,7 +15,7 @@ func (s *Server) RefreshHandlerFn(c *fiber.Ctx) error { } repo := domain.NewUserRepository(s.GetDatabase()) - authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) + authSrv := service.NewAuthService(repo, s.GetCache()) token, err := authSrv.RefreshToken(data.AccessToken) if err != nil { diff --git a/src/internal/server/register_handler.go b/src/internal/server/register_handler.go index c4bfd6e..345a49c 100644 --- a/src/internal/server/register_handler.go +++ b/src/internal/server/register_handler.go @@ -15,7 +15,7 @@ func (s *Server) RegisterHandlerFn(c *fiber.Ctx) error { } repo := domain.NewUserRepository(s.GetDatabase()) - authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) + authSrv := service.NewAuthService(repo, s.GetCache()) id, err := authSrv.Register(data.Email, data.Username, data.Password) if err != nil { diff --git a/src/internal/service/auth.go b/src/internal/service/auth.go index d93918f..4f04d8f 100644 --- a/src/internal/service/auth.go +++ b/src/internal/service/auth.go @@ -9,14 +9,13 @@ import ( domain "git.ego.freeddns.org/egommerce/identity-service/domain/repository" "github.com/go-redis/redis/v8" - "github.com/jackc/pgx/v5/pgxpool" ) var ( passSrv *PaswordService ErrLoginIncorrect = errors.New("login incorrect") - ErrUnableToCacheToken = errors.New("unable to save token in cache") + ErrUnableToCacheToken = errors.New("unable to save tokens in cache") ErrInvalidAccessToken = errors.New("invalid access token") ) @@ -25,24 +24,19 @@ func init() { } type Auth struct { - repo *domain.UserRepository - db *pgxpool.Pool - cache *redis.Client + userRepo *domain.UserRepository + cache *redis.Client } -func NewAuthService(repo *domain.UserRepository, db *pgxpool.Pool, cache *redis.Client) *Auth { +func NewAuthService(userRepo *domain.UserRepository, cache *redis.Client) *Auth { return &Auth{ - repo: repo, - db: db, - cache: cache, + userRepo: userRepo, + cache: cache, } } func (a *Auth) Login(login, passwd string) (string, error) { - var id, hashedPasswd string - - sql := `SELECT id, password FROM identity.users WHERE username=$1 LIMIT 1` - err := a.db.QueryRow(context.Background(), sql, login).Scan(&id, &hashedPasswd) + user, err := a.userRepo.GetByUsername(login) if err != nil { // if err = database.NoRowsInQuerySet(err); err != nil { // return "", errors.New("no user found") @@ -51,13 +45,13 @@ func (a *Auth) Login(login, passwd string) (string, error) { return "", ErrLoginIncorrect } - if err = passSrv.Verify(passwd, hashedPasswd); err != nil { + if err = passSrv.Verify(passwd, user.Password); err != nil { return "", ErrLoginIncorrect } - accessToken, _ := jwtSrv.CreateAccessToken(id) - refreshToken, _ := jwtSrv.CreateRefreshToken(id) - if err = a.saveTokensToCache(id, accessToken, refreshToken); err != nil { + accessToken, _ := jwtSrv.CreateAccessToken(user.ID) + refreshToken, _ := jwtSrv.CreateRefreshToken(user.ID) + if err = a.saveTokensToCache(user.ID, accessToken, refreshToken); err != nil { return "", ErrUnableToCacheToken } @@ -84,7 +78,7 @@ func (a *Auth) RefreshToken(accessToken string) (string, error) { func (a *Auth) Register(email, login, passwd string) (string, error) { passwd, _ = passSrv.Hash(passwd) - id, err := a.repo.Create(&entity.User{ + id, err := a.userRepo.Create(&entity.User{ Email: email, Username: login, Password: passwd, @@ -99,12 +93,12 @@ func (a *Auth) Register(email, login, passwd string) (string, error) { func (a *Auth) saveTokensToCache(id, accessToken, refreshToken string) error { res := a.cache.Set(context.Background(), "auth:access_token:"+id, accessToken, accessTokenExpireTime) if err := res.Err(); err != nil { - fmt.Println("failed to save access token in redis: ", err.Error()) + fmt.Println("failed to save access token in cache: ", err.Error()) } res = a.cache.Set(context.Background(), "auth:refresh_token:"+id, refreshToken, refreshTokenExpireTime) if err := res.Err(); err != nil { - fmt.Println("failed to save refresh token in redis: ", err.Error()) + fmt.Println("failed to save refresh token in cache: ", err.Error()) } return nil