Refactor replace db with UserRepository

This commit is contained in:
PB
2025-10-21 20:39:34 +02:00
parent 9ceea35b08
commit 89b665c3d9
5 changed files with 40 additions and 24 deletions

View File

@@ -19,7 +19,29 @@ func NewUserRepository(db *pgxpool.Pool) *UserRepository {
} }
func (r *UserRepository) GetByID(id string) (*entity.User, error) { func (r *UserRepository) GetByID(id string) (*entity.User, error) {
return &entity.User{}, nil var user entity.User
sql := `SELECT id, username, password, email, created_at FROM identity.users WHERE id=$1 LIMIT 1`
err := r.db.QueryRow(context.Background(), sql, id).
Scan(&user.ID, &user.Username, &user.Password, &user.Email, &user.CreatedAt)
if err != nil {
return nil, errors.New("failed to fetch user from DB: " + err.Error())
}
return &user, nil
}
func (r *UserRepository) GetByUsername(login string) (*entity.User, error) {
var user entity.User
sql := `SELECT id, username, password, email, created_at FROM identity.users WHERE username=$1 LIMIT 1`
err := r.db.QueryRow(context.Background(), sql, login).
Scan(&user.ID, &user.Username, &user.Password, &user.Email, &user.CreatedAt)
if err != nil {
return nil, errors.New("failed to fetch user from DB: " + err.Error())
}
return &user, nil
} }
func (r *UserRepository) Create(user *entity.User) (string, error) { func (r *UserRepository) Create(user *entity.User) (string, error) {

View File

@@ -16,7 +16,7 @@ func (s *Server) LoginHandlerFn(c *fiber.Ctx) error {
} }
repo := domain.NewUserRepository(s.GetDatabase()) repo := domain.NewUserRepository(s.GetDatabase())
authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) authSrv := service.NewAuthService(repo, s.GetCache())
token, err := authSrv.Login(data.Username, data.Password) token, err := authSrv.Login(data.Username, data.Password)
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Server) RefreshHandlerFn(c *fiber.Ctx) error {
} }
repo := domain.NewUserRepository(s.GetDatabase()) repo := domain.NewUserRepository(s.GetDatabase())
authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) authSrv := service.NewAuthService(repo, s.GetCache())
token, err := authSrv.RefreshToken(data.AccessToken) token, err := authSrv.RefreshToken(data.AccessToken)
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Server) RegisterHandlerFn(c *fiber.Ctx) error {
} }
repo := domain.NewUserRepository(s.GetDatabase()) repo := domain.NewUserRepository(s.GetDatabase())
authSrv := service.NewAuthService(repo, s.GetDatabase(), s.GetCache()) authSrv := service.NewAuthService(repo, s.GetCache())
id, err := authSrv.Register(data.Email, data.Username, data.Password) id, err := authSrv.Register(data.Email, data.Username, data.Password)
if err != nil { if err != nil {

View File

@@ -9,14 +9,13 @@ import (
domain "git.ego.freeddns.org/egommerce/identity-service/domain/repository" domain "git.ego.freeddns.org/egommerce/identity-service/domain/repository"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/jackc/pgx/v5/pgxpool"
) )
var ( var (
passSrv *PaswordService passSrv *PaswordService
ErrLoginIncorrect = errors.New("login incorrect") ErrLoginIncorrect = errors.New("login incorrect")
ErrUnableToCacheToken = errors.New("unable to save token in cache") ErrUnableToCacheToken = errors.New("unable to save tokens in cache")
ErrInvalidAccessToken = errors.New("invalid access token") ErrInvalidAccessToken = errors.New("invalid access token")
) )
@@ -25,24 +24,19 @@ func init() {
} }
type Auth struct { type Auth struct {
repo *domain.UserRepository userRepo *domain.UserRepository
db *pgxpool.Pool cache *redis.Client
cache *redis.Client
} }
func NewAuthService(repo *domain.UserRepository, db *pgxpool.Pool, cache *redis.Client) *Auth { func NewAuthService(userRepo *domain.UserRepository, cache *redis.Client) *Auth {
return &Auth{ return &Auth{
repo: repo, userRepo: userRepo,
db: db, cache: cache,
cache: cache,
} }
} }
func (a *Auth) Login(login, passwd string) (string, error) { func (a *Auth) Login(login, passwd string) (string, error) {
var id, hashedPasswd string user, err := a.userRepo.GetByUsername(login)
sql := `SELECT id, password FROM identity.users WHERE username=$1 LIMIT 1`
err := a.db.QueryRow(context.Background(), sql, login).Scan(&id, &hashedPasswd)
if err != nil { if err != nil {
// if err = database.NoRowsInQuerySet(err); err != nil { // if err = database.NoRowsInQuerySet(err); err != nil {
// return "", errors.New("no user found") // return "", errors.New("no user found")
@@ -51,13 +45,13 @@ func (a *Auth) Login(login, passwd string) (string, error) {
return "", ErrLoginIncorrect return "", ErrLoginIncorrect
} }
if err = passSrv.Verify(passwd, hashedPasswd); err != nil { if err = passSrv.Verify(passwd, user.Password); err != nil {
return "", ErrLoginIncorrect return "", ErrLoginIncorrect
} }
accessToken, _ := jwtSrv.CreateAccessToken(id) accessToken, _ := jwtSrv.CreateAccessToken(user.ID)
refreshToken, _ := jwtSrv.CreateRefreshToken(id) refreshToken, _ := jwtSrv.CreateRefreshToken(user.ID)
if err = a.saveTokensToCache(id, accessToken, refreshToken); err != nil { if err = a.saveTokensToCache(user.ID, accessToken, refreshToken); err != nil {
return "", ErrUnableToCacheToken return "", ErrUnableToCacheToken
} }
@@ -84,7 +78,7 @@ func (a *Auth) RefreshToken(accessToken string) (string, error) {
func (a *Auth) Register(email, login, passwd string) (string, error) { func (a *Auth) Register(email, login, passwd string) (string, error) {
passwd, _ = passSrv.Hash(passwd) passwd, _ = passSrv.Hash(passwd)
id, err := a.repo.Create(&entity.User{ id, err := a.userRepo.Create(&entity.User{
Email: email, Email: email,
Username: login, Username: login,
Password: passwd, Password: passwd,
@@ -99,12 +93,12 @@ func (a *Auth) Register(email, login, passwd string) (string, error) {
func (a *Auth) saveTokensToCache(id, accessToken, refreshToken string) error { func (a *Auth) saveTokensToCache(id, accessToken, refreshToken string) error {
res := a.cache.Set(context.Background(), "auth:access_token:"+id, accessToken, accessTokenExpireTime) res := a.cache.Set(context.Background(), "auth:access_token:"+id, accessToken, accessTokenExpireTime)
if err := res.Err(); err != nil { if err := res.Err(); err != nil {
fmt.Println("failed to save access token in redis: ", err.Error()) fmt.Println("failed to save access token in cache: ", err.Error())
} }
res = a.cache.Set(context.Background(), "auth:refresh_token:"+id, refreshToken, refreshTokenExpireTime) res = a.cache.Set(context.Background(), "auth:refresh_token:"+id, refreshToken, refreshTokenExpireTime)
if err := res.Err(); err != nil { if err := res.Err(); err != nil {
fmt.Println("failed to save refresh token in redis: ", err.Error()) fmt.Println("failed to save refresh token in cache: ", err.Error())
} }
return nil return nil