endpoint to update JWT cookies

This commit is contained in:
Daniel Goc
2026-03-18 15:40:54 +01:00
parent 01c8f4333f
commit e094865fc7
7 changed files with 167 additions and 30 deletions

View File

@@ -408,9 +408,12 @@ func (h *AuthHandler) GoogleCallback(c fiber.Ctx) error {
// Redirect to the locale-prefixed charts page after successful Google login. // Redirect to the locale-prefixed charts page after successful Google login.
// The user's preferred language is stored in the auth response; fall back to "en". // The user's preferred language is stored in the auth response; fall back to "en".
lang := response.User.Lang lang, err := h.authService.GetLangISOCode(response.User.LangID)
if lang == "" { if err != nil {
lang = "en" return c.Status(responseErrors.GetErrorStatus(responseErrors.ErrBadLangID)).JSON(fiber.Map{
"error": responseErrors.GetErrorCode(c, responseErrors.ErrBadLangID),
})
} }
return c.Redirect().To(h.config.App.BaseURL + "/" + lang) return c.Redirect().To(h.config.App.BaseURL + "/" + lang)
} }

View File

@@ -1,6 +1,9 @@
package restricted package restricted
import ( import (
"strconv"
"git.ma-al.com/goc_daniel/b2b/app/service/authService"
"git.ma-al.com/goc_daniel/b2b/app/service/jwtService" "git.ma-al.com/goc_daniel/b2b/app/service/jwtService"
"git.ma-al.com/goc_daniel/b2b/app/utils/i18n" "git.ma-al.com/goc_daniel/b2b/app/utils/i18n"
"git.ma-al.com/goc_daniel/b2b/app/utils/nullable" "git.ma-al.com/goc_daniel/b2b/app/utils/nullable"
@@ -12,13 +15,16 @@ import (
// JWTCookiesHandler for updating JWT cookies. // JWTCookiesHandler for updating JWT cookies.
type JWTCookiesHandler struct { type JWTCookiesHandler struct {
jwtService *jwtService.JWTService jwtService *jwtService.JWTService
authService *authService.AuthService
} }
// NewJWTCookiesHandler creates a new JWTCookiesHandler instance // NewJWTCookiesHandler creates a new JWTCookiesHandler instance
func NewJWTCookiesHandler() *JWTCookiesHandler { func NewJWTCookiesHandler() *JWTCookiesHandler {
jwtService := jwtService.New() jwtService := jwtService.New()
authSvc := authService.NewAuthService()
return &JWTCookiesHandler{ return &JWTCookiesHandler{
jwtService: jwtService, jwtService: jwtService,
authService: authSvc,
} }
} }
@@ -53,5 +59,57 @@ func (h *JWTCookiesHandler) GetCountries(c fiber.Ctx) error {
} }
func (h *JWTCookiesHandler) UpdateChoice(c fiber.Ctx) error { func (h *JWTCookiesHandler) UpdateChoice(c fiber.Ctx) error {
return nil // Get user ID from JWT claims in context (set by auth middleware)
claims, ok := c.Locals("jwt_claims").(*authService.JWTClaims)
if !ok || claims == nil {
return c.Status(fiber.StatusUnauthorized).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrNotAuthenticated)))
}
// Parse language and country_id from query params
langIDStr := c.Query("lang_id")
countryIDStr := c.Query("country_id")
var langID uint
if langIDStr != "" {
parsedID, err := strconv.ParseUint(langIDStr, 10, 32)
if err != nil {
return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadLangID)))
}
langID = uint(parsedID)
} else {
langID = 0
}
var countryID uint
if countryIDStr != "" {
parsedID, err := strconv.ParseUint(countryIDStr, 10, 32)
if err != nil {
return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadCountryID)))
}
countryID = uint(parsedID)
} else {
countryID = 0
}
// Update choice and get new token using AuthService
newToken, err := h.authService.UpdateChoice(claims.UserID, langID, countryID)
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
// Set the new JWT cookie
cookie := new(fiber.Cookie)
cookie.Name = "jwt_token"
cookie.Value = newToken
cookie.HTTPOnly = true
cookie.Secure = true
cookie.SameSite = fiber.CookieSameSiteLaxMode
c.Cookie(cookie)
return c.JSON(response.Make(&fiber.Map{"token": newToken}, 0, i18n.T_(c, response.Message_OK)))
} }

View File

@@ -25,7 +25,8 @@ type Customer struct {
PasswordResetExpires *time.Time `json:"-"` PasswordResetExpires *time.Time `json:"-"`
LastPasswordResetRequest *time.Time `json:"-"` LastPasswordResetRequest *time.Time `json:"-"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"` LastLoginAt *time.Time `json:"last_login_at,omitempty"`
Lang string `gorm:"size:10;default:'en'" json:"lang"` // User's preferred language LangID uint `gorm:"default:2" json:"lang_id"` // User's preferred language
CountryID uint `gorm:"default:2" json:"country_id"` // User's selected country
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
@@ -76,9 +77,8 @@ type UserSession struct {
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
Role CustomerRole `json:"role"` Role CustomerRole `json:"role"`
FirstName string `json:"first_name"` LangID uint `json:"lang_id"`
LastName string `json:"last_name"` CountryID uint `json:"country_id"`
Lang string `json:"lang"`
} }
// ToSession converts User to UserSession // ToSession converts User to UserSession
@@ -87,9 +87,8 @@ func (u *Customer) ToSession() *UserSession {
UserID: u.ID, UserID: u.ID,
Email: u.Email, Email: u.Email,
Role: u.Role, Role: u.Role,
FirstName: u.FirstName, LangID: u.LangID,
LastName: u.LastName, CountryID: u.CountryID,
Lang: u.Lang,
} }
} }
@@ -107,7 +106,8 @@ type RegisterRequest struct {
ConfirmPassword string `json:"confirm_password" form:"confirm_password"` ConfirmPassword string `json:"confirm_password" form:"confirm_password"`
FirstName string `json:"first_name" form:"first_name"` FirstName string `json:"first_name" form:"first_name"`
LastName string `json:"last_name" form:"last_name"` LastName string `json:"last_name" form:"last_name"`
Lang string `form:"lang" json:"lang"` LangID uint `form:"lang_id" json:"lang_id"`
CountryID uint `form:"country_id" json:"country_id"`
} }
// CompleteRegistrationRequest represents the completion of registration with email verification // CompleteRegistrationRequest represents the completion of registration with email verification

View File

@@ -28,6 +28,7 @@ type JWTClaims struct {
Username string `json:"username"` Username string `json:"username"`
Role model.CustomerRole `json:"customer_role"` Role model.CustomerRole `json:"customer_role"`
CartsIDs []uint `json:"carts_ids"` CartsIDs []uint `json:"carts_ids"`
LangID uint `json:"lang_id"`
CountryID uint `json:"country_id"` CountryID uint `json:"country_id"`
jwt.RegisteredClaims jwt.RegisteredClaims
} }
@@ -149,7 +150,8 @@ func (s *AuthService) Register(req *model.RegisterRequest) error {
EmailVerified: false, EmailVerified: false,
EmailVerificationToken: token, EmailVerificationToken: token,
EmailVerificationExpires: &expiresAt, EmailVerificationExpires: &expiresAt,
Lang: req.Lang, LangID: req.LangID,
CountryID: req.CountryID,
} }
if err := s.db.Create(&user).Error; err != nil { if err := s.db.Create(&user).Error; err != nil {
@@ -158,10 +160,11 @@ func (s *AuthService) Register(req *model.RegisterRequest) error {
// Send verification email // Send verification email
baseURL := config.Get().App.BaseURL baseURL := config.Get().App.BaseURL
lang := req.Lang lang, err := s.GetLangISOCode(req.LangID)
if lang == "" { if err != nil {
lang = "en" // Default to English return responseErrors.ErrBadLangID
} }
if err := s.email.SendVerificationEmail(user.Email, user.EmailVerificationToken, baseURL, lang); err != nil { if err := s.email.SendVerificationEmail(user.Email, user.EmailVerificationToken, baseURL, lang); err != nil {
// Log error but don't fail registration - user can request resend // Log error but don't fail registration - user can request resend
_ = err _ = err
@@ -266,10 +269,11 @@ func (s *AuthService) RequestPasswordReset(emailAddr string) error {
// Send password reset email // Send password reset email
baseURL := config.Get().App.BaseURL baseURL := config.Get().App.BaseURL
lang := "en" lang, err := s.GetLangISOCode(user.LangID)
if user.Lang != "" { if err != nil {
lang = user.Lang return responseErrors.ErrBadLangID
} }
if err := s.email.SendPasswordResetEmail(user.Email, user.PasswordResetToken, baseURL, lang); err != nil { if err := s.email.SendPasswordResetEmail(user.Email, user.PasswordResetToken, baseURL, lang); err != nil {
_ = err _ = err
} }
@@ -477,7 +481,8 @@ func (s *AuthService) generateAccessToken(user *model.Customer) (string, error)
Username: user.Email, Username: user.Email,
Role: user.Role, Role: user.Role,
CartsIDs: []uint{}, CartsIDs: []uint{},
CountryID: 1, LangID: user.LangID,
CountryID: user.CountryID,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(s.config.JWTExpiration) * time.Second)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(s.config.JWTExpiration) * time.Second)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
@@ -488,6 +493,45 @@ func (s *AuthService) generateAccessToken(user *model.Customer) (string, error)
return token.SignedString([]byte(s.config.JWTSecret)) return token.SignedString([]byte(s.config.JWTSecret))
} }
// UpdateChoice updates the user's language and/or country choice and returns a new JWT token
func (s *AuthService) UpdateChoice(userID uint, langID uint, countryID uint) (string, error) {
var user model.Customer
// Find user by ID
if err := s.db.First(&user, userID).Error; err != nil {
return "", err
}
// Update user langID if provided
if langID == 0 {
langID = user.LangID
}
_, err := s.GetLangISOCode(langID)
if err != nil {
return "", responseErrors.ErrBadLangID
} else {
user.LangID = langID
}
if countryID == 0 {
countryID = user.CountryID
}
err = s.CheckIfCountryExists(countryID)
if err != nil {
return "", responseErrors.ErrBadCountryID
} else {
user.CountryID = countryID
}
// Save the updated user
if err := s.db.Save(&user).Error; err != nil {
return "", err
}
// Generate new JWT token with updated claims
return s.generateAccessToken(&user)
}
// generateVerificationToken generates a random verification token // generateVerificationToken generates a random verification token
func (s *AuthService) generateVerificationToken() (string, error) { func (s *AuthService) generateVerificationToken() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
@@ -507,3 +551,29 @@ func validatePassword(password string) error {
return nil return nil
} }
func (s *AuthService) GetLangISOCode(langID uint) (string, error) {
var lang string
if langID == 0 { // retrieve the default lang
err := db.DB.Table("b2b_language").Where("is_default = ?", 1).First(lang).Error
return lang, err
} else {
err := db.DB.Table("b2b_language").Where("id = ?", langID).Where("active = ?", 1).First(lang).Error
return lang, err
}
}
func (s *AuthService) CheckIfCountryExists(countryID uint) error {
var count int64
err := db.DB.Table("b2b_countries").Where("id = ?", countryID).Count(&count).Error
if err != nil {
return err
}
if count == 0 {
return responseErrors.ErrBadCountryID
}
return nil
}

View File

@@ -153,7 +153,7 @@ func (s *AuthService) findOrCreateGoogleUser(info *view.GoogleUserInfo) (*model.
Role: model.RoleUser, Role: model.RoleUser,
IsActive: true, IsActive: true,
EmailVerified: true, EmailVerified: true,
Lang: "en", LangID: 2,
} }
if err := s.db.Create(&newUser).Error; err != nil { if err := s.db.Create(&newUser).Error; err != nil {

View File

@@ -5,14 +5,16 @@ import (
"git.ma-al.com/goc_daniel/b2b/repository/jwtFieldsRepo" "git.ma-al.com/goc_daniel/b2b/repository/jwtFieldsRepo"
) )
// jwtService handles updating JWT cookies // JWTService handles retrieving JWT fields (languages and countries)
type JWTService struct { type JWTService struct {
repo jwtFieldsRepo.JWTFieldsRepo repo jwtFieldsRepo.UIJWTFieldsRepo
} }
// NewJWTService creates a new JWT service // NewJWTService creates a new JWT service
func New() *JWTService { func New() *JWTService {
return &JWTService{} return &JWTService{
repo: jwtFieldsRepo.New(),
}
} }
func (s *JWTService) GetLanguages() ([]model.Language, error) { func (s *JWTService) GetLanguages() ([]model.Language, error) {
@@ -22,7 +24,3 @@ func (s *JWTService) GetLanguages() ([]model.Language, error) {
func (s *JWTService) GetCountriesAndCurrencies() ([]model.Country, error) { func (s *JWTService) GetCountriesAndCurrencies() ([]model.Country, error) {
return s.repo.GetCountriesAndCurrencies() return s.repo.GetCountriesAndCurrencies()
} }
func (s *JWTService) UpdateChoice() error {
return nil
}

View File

@@ -25,6 +25,8 @@ var (
ErrEmailRequired = errors.New("email is required") ErrEmailRequired = errors.New("email is required")
ErrEmailPasswordRequired = errors.New("email and password are required") ErrEmailPasswordRequired = errors.New("email and password are required")
ErrRefreshTokenRequired = errors.New("refresh token is required") ErrRefreshTokenRequired = errors.New("refresh token is required")
ErrBadLangID = errors.New("bad language id")
ErrBadCountryID = errors.New("bad country id")
// Typed errors for password reset // Typed errors for password reset
ErrInvalidResetToken = errors.New("invalid reset token") ErrInvalidResetToken = errors.New("invalid reset token")
@@ -98,6 +100,10 @@ func GetErrorCode(c fiber.Ctx, err error) string {
return i18n.T_(c, "error.err_token_required") return i18n.T_(c, "error.err_token_required")
case errors.Is(err, ErrRefreshTokenRequired): case errors.Is(err, ErrRefreshTokenRequired):
return i18n.T_(c, "error.err_refresh_token_required") return i18n.T_(c, "error.err_refresh_token_required")
case errors.Is(err, ErrBadLangID):
return i18n.T_(c, "error.err_bad_lang_id")
case errors.Is(err, ErrBadCountryID):
return i18n.T_(c, "error.err_bad_country_id")
case errors.Is(err, ErrInvalidResetToken): case errors.Is(err, ErrInvalidResetToken):
return i18n.T_(c, "error.err_invalid_reset_token") return i18n.T_(c, "error.err_invalid_reset_token")
@@ -151,6 +157,8 @@ func GetErrorStatus(err error) int {
errors.Is(err, ErrEmailPasswordRequired), errors.Is(err, ErrEmailPasswordRequired),
errors.Is(err, ErrTokenRequired), errors.Is(err, ErrTokenRequired),
errors.Is(err, ErrRefreshTokenRequired), errors.Is(err, ErrRefreshTokenRequired),
errors.Is(err, ErrBadLangID),
errors.Is(err, ErrBadCountryID),
errors.Is(err, ErrPasswordsDoNotMatch), errors.Is(err, ErrPasswordsDoNotMatch),
errors.Is(err, ErrTokenPasswordRequired), errors.Is(err, ErrTokenPasswordRequired),
errors.Is(err, ErrInvalidResetToken), errors.Is(err, ErrInvalidResetToken),