This commit is contained in:
Daniel Goc
2026-03-18 16:09:03 +01:00
parent e094865fc7
commit 52c17d7017
8 changed files with 170 additions and 173 deletions

View File

@@ -40,6 +40,7 @@ func AuthHandlerRoutes(r fiber.Router) fiber.Router {
r.Post("/reset-password", handler.ResetPassword) r.Post("/reset-password", handler.ResetPassword)
r.Post("/logout", handler.Logout) r.Post("/logout", handler.Logout)
r.Post("/refresh", handler.RefreshToken) r.Post("/refresh", handler.RefreshToken)
r.Post("/update-choice", handler.UpdateJWTToken)
// Google OAuth2 // Google OAuth2
r.Get("/google", handler.GoogleLogin) r.Get("/google", handler.GoogleLogin)
@@ -344,6 +345,11 @@ func (h *AuthHandler) CompleteRegistration(c fiber.Ctx) error {
return c.Status(fiber.StatusCreated).JSON(response) return c.Status(fiber.StatusCreated).JSON(response)
} }
// CompleteRegistration handles completion of registration with password
func (h *AuthHandler) UpdateJWTToken(c fiber.Ctx) error {
return h.UpdateJWTToken(c)
}
// GoogleLogin redirects the user to Google's OAuth2 consent page // GoogleLogin redirects the user to Google's OAuth2 consent page
func (h *AuthHandler) GoogleLogin(c fiber.Ctx) error { func (h *AuthHandler) GoogleLogin(c fiber.Ctx) error {
// Generate a random state token and store it in a short-lived cookie // Generate a random state token and store it in a short-lived cookie

View File

@@ -1,115 +0,0 @@
package restricted
import (
"strconv"
"git.ma-al.com/goc_daniel/b2b/app/service/authService"
"git.ma-al.com/goc_daniel/b2b/app/service/jwtService"
"git.ma-al.com/goc_daniel/b2b/app/utils/i18n"
"git.ma-al.com/goc_daniel/b2b/app/utils/nullable"
"git.ma-al.com/goc_daniel/b2b/app/utils/response"
"git.ma-al.com/goc_daniel/b2b/app/utils/responseErrors"
"github.com/gofiber/fiber/v3"
)
// JWTCookiesHandler for updating JWT cookies.
type JWTCookiesHandler struct {
jwtService *jwtService.JWTService
authService *authService.AuthService
}
// NewJWTCookiesHandler creates a new JWTCookiesHandler instance
func NewJWTCookiesHandler() *JWTCookiesHandler {
jwtService := jwtService.New()
authSvc := authService.NewAuthService()
return &JWTCookiesHandler{
jwtService: jwtService,
authService: authSvc,
}
}
func JWTCookiesHandlerRoutes(r fiber.Router) fiber.Router {
handler := NewJWTCookiesHandler()
r.Get("/get-languages", handler.GetLanguages)
r.Get("/get-countries", handler.GetCountries)
r.Get("/update-choice", handler.UpdateChoice)
return r
}
func (h *JWTCookiesHandler) GetLanguages(c fiber.Ctx) error {
languages, err := h.jwtService.GetLanguages()
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
return c.JSON(response.Make(&languages, 0, i18n.T_(c, response.Message_OK)))
}
func (h *JWTCookiesHandler) GetCountries(c fiber.Ctx) error {
countries, err := h.jwtService.GetCountriesAndCurrencies()
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
return c.JSON(response.Make(&countries, 0, i18n.T_(c, response.Message_OK)))
}
func (h *JWTCookiesHandler) UpdateChoice(c fiber.Ctx) error {
// Get user ID from JWT claims in context (set by auth middleware)
claims, ok := c.Locals("jwt_claims").(*authService.JWTClaims)
if !ok || claims == nil {
return c.Status(fiber.StatusUnauthorized).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrNotAuthenticated)))
}
// Parse language and country_id from query params
langIDStr := c.Query("lang_id")
countryIDStr := c.Query("country_id")
var langID uint
if langIDStr != "" {
parsedID, err := strconv.ParseUint(langIDStr, 10, 32)
if err != nil {
return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadLangID)))
}
langID = uint(parsedID)
} else {
langID = 0
}
var countryID uint
if countryIDStr != "" {
parsedID, err := strconv.ParseUint(countryIDStr, 10, 32)
if err != nil {
return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadCountryID)))
}
countryID = uint(parsedID)
} else {
countryID = 0
}
// Update choice and get new token using AuthService
newToken, err := h.authService.UpdateChoice(claims.UserID, langID, countryID)
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
// Set the new JWT cookie
cookie := new(fiber.Cookie)
cookie.Name = "jwt_token"
cookie.Value = newToken
cookie.HTTPOnly = true
cookie.Secure = true
cookie.SameSite = fiber.CookieSameSiteLaxMode
c.Cookie(cookie)
return c.JSON(response.Make(&fiber.Map{"token": newToken}, 0, i18n.T_(c, response.Message_OK)))
}

View File

@@ -0,0 +1,52 @@
package restricted
import (
"git.ma-al.com/goc_daniel/b2b/app/service/langsAndCountriesService"
"git.ma-al.com/goc_daniel/b2b/app/utils/i18n"
"git.ma-al.com/goc_daniel/b2b/app/utils/nullable"
"git.ma-al.com/goc_daniel/b2b/app/utils/response"
"git.ma-al.com/goc_daniel/b2b/app/utils/responseErrors"
"github.com/gofiber/fiber/v3"
)
// LangsAndCountriesHandler for getting languages and countries data
type LangsAndCountriesHandler struct {
langsAndCountriesService *langsAndCountriesService.LangsAndCountriesService
}
// NewLangsAndCountriesHandler creates a new LangsAndCountriesHandler instance
func NewLangsAndCountriesHandler() *LangsAndCountriesHandler {
langsAndCountriesService := langsAndCountriesService.New()
return &LangsAndCountriesHandler{
langsAndCountriesService: langsAndCountriesService,
}
}
func LangsAndCountriesHandlerRoutes(r fiber.Router) fiber.Router {
handler := NewLangsAndCountriesHandler()
r.Get("/get-languages", handler.GetLanguages)
r.Get("/get-countries", handler.GetCountries)
return r
}
func (h *LangsAndCountriesHandler) GetLanguages(c fiber.Ctx) error {
languages, err := h.langsAndCountriesService.GetLanguages()
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
return c.JSON(response.Make(&languages, 0, i18n.T_(c, response.Message_OK)))
}
func (h *LangsAndCountriesHandler) GetCountries(c fiber.Ctx) error {
countries, err := h.langsAndCountriesService.GetCountriesAndCurrencies()
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
return c.JSON(response.Make(&countries, 0, i18n.T_(c, response.Message_OK)))
}

View File

@@ -99,8 +99,8 @@ func (s *Server) Setup() error {
// changing the JWT cookies routes (restricted) // changing the JWT cookies routes (restricted)
// in reality it just handles changing user's country and language // in reality it just handles changing user's country and language
jwtUpdates := s.restricted.Group("/jwt-updates") langsAndCountries := s.restricted.Group("/langs-and-countries")
restricted.JWTCookiesHandlerRoutes(jwtUpdates) restricted.LangsAndCountriesHandlerRoutes(langsAndCountries)
// // Restricted routes example // // Restricted routes example
// restricted := s.api.Group("/restricted") // restricted := s.api.Group("/restricted")

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"strconv"
"time" "time"
"git.ma-al.com/goc_daniel/b2b/app/config" "git.ma-al.com/goc_daniel/b2b/app/config"
@@ -13,9 +14,13 @@ import (
"git.ma-al.com/goc_daniel/b2b/app/model" "git.ma-al.com/goc_daniel/b2b/app/model"
"git.ma-al.com/goc_daniel/b2b/app/service/emailService" "git.ma-al.com/goc_daniel/b2b/app/service/emailService"
constdata "git.ma-al.com/goc_daniel/b2b/app/utils/const_data" constdata "git.ma-al.com/goc_daniel/b2b/app/utils/const_data"
"git.ma-al.com/goc_daniel/b2b/app/utils/i18n"
"git.ma-al.com/goc_daniel/b2b/app/utils/nullable"
"git.ma-al.com/goc_daniel/b2b/app/utils/response"
"git.ma-al.com/goc_daniel/b2b/app/utils/responseErrors" "git.ma-al.com/goc_daniel/b2b/app/utils/responseErrors"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"github.com/gofiber/fiber/v3"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
@@ -475,6 +480,16 @@ func hashToken(raw string) string {
// generateAccessToken generates a short-lived JWT access token // generateAccessToken generates a short-lived JWT access token
func (s *AuthService) generateAccessToken(user *model.Customer) (string, error) { func (s *AuthService) generateAccessToken(user *model.Customer) (string, error) {
_, err := s.GetLangISOCode(user.LangID)
if err != nil {
return "", responseErrors.ErrBadLangID
}
err = s.CheckIfCountryExists(user.CountryID)
if err != nil {
return "", responseErrors.ErrBadCountryID
}
claims := JWTClaims{ claims := JWTClaims{
UserID: user.ID, UserID: user.ID,
Email: user.Email, Email: user.Email,
@@ -493,43 +508,82 @@ func (s *AuthService) generateAccessToken(user *model.Customer) (string, error)
return token.SignedString([]byte(s.config.JWTSecret)) return token.SignedString([]byte(s.config.JWTSecret))
} }
// UpdateChoice updates the user's language and/or country choice and returns a new JWT token func (s *AuthService) UpdateJWTToken(c fiber.Ctx) error {
func (s *AuthService) UpdateChoice(userID uint, langID uint, countryID uint) (string, error) { // Get user ID from JWT claims in context (set by auth middleware)
claims, ok := c.Locals("jwt_claims").(*JWTClaims)
if !ok || claims == nil {
return c.Status(fiber.StatusUnauthorized).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrNotAuthenticated)))
}
var user model.Customer var user model.Customer
// Find user by ID // Find user by ID
if err := s.db.First(&user, userID).Error; err != nil { if err := s.db.First(&user, claims.UserID).Error; err != nil {
return "", err return err
} }
// Update user langID if provided // Parse language and country_id from query params
if langID == 0 { langIDStr := c.Query("lang_id")
langID = user.LangID
} var langID uint
_, err := s.GetLangISOCode(langID) if langIDStr != "" {
parsedID, err := strconv.ParseUint(langIDStr, 10, 32)
if err != nil { if err != nil {
return "", responseErrors.ErrBadLangID return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadLangID)))
}
langID = uint(parsedID)
_, err = s.GetLangISOCode(langID)
if err != nil {
return responseErrors.ErrBadLangID
} else { } else {
user.LangID = langID user.LangID = langID
} }
if countryID == 0 {
countryID = user.CountryID
} }
countryIDStr := c.Query("country_id")
var countryID uint
if countryIDStr != "" {
parsedID, err := strconv.ParseUint(countryIDStr, 10, 32)
if err != nil {
return c.Status(fiber.StatusBadRequest).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, responseErrors.ErrBadCountryID)))
}
countryID = uint(parsedID)
err = s.CheckIfCountryExists(countryID) err = s.CheckIfCountryExists(countryID)
if err != nil { if err != nil {
return "", responseErrors.ErrBadCountryID return responseErrors.ErrBadCountryID
} else { } else {
user.CountryID = countryID user.CountryID = countryID
} }
}
// Update choice and get new token using AuthService
newToken, err := s.generateAccessToken(&user)
if err != nil {
return c.Status(responseErrors.GetErrorStatus(err)).
JSON(response.Make(nullable.GetNil(""), 0, responseErrors.GetErrorCode(c, err)))
}
// Save the updated user // Save the updated user
if err := s.db.Save(&user).Error; err != nil { if err := s.db.Save(&user).Error; err != nil {
return "", err return fmt.Errorf("database error: %w", err)
} }
// Generate new JWT token with updated claims // Set the new JWT cookie
return s.generateAccessToken(&user) cookie := new(fiber.Cookie)
cookie.Name = "jwt_token"
cookie.Value = newToken
cookie.HTTPOnly = true
cookie.Secure = true
cookie.SameSite = fiber.CookieSameSiteLaxMode
c.Cookie(cookie)
return c.JSON(response.Make(&fiber.Map{"token": newToken}, 0, i18n.T_(c, response.Message_OK)))
} }
// generateVerificationToken generates a random verification token // generateVerificationToken generates a random verification token

View File

@@ -1,26 +0,0 @@
package jwtService
import (
"git.ma-al.com/goc_daniel/b2b/app/model"
"git.ma-al.com/goc_daniel/b2b/repository/jwtFieldsRepo"
)
// JWTService handles retrieving JWT fields (languages and countries)
type JWTService struct {
repo jwtFieldsRepo.UIJWTFieldsRepo
}
// NewJWTService creates a new JWT service
func New() *JWTService {
return &JWTService{
repo: jwtFieldsRepo.New(),
}
}
func (s *JWTService) GetLanguages() ([]model.Language, error) {
return s.repo.GetLanguages()
}
func (s *JWTService) GetCountriesAndCurrencies() ([]model.Country, error) {
return s.repo.GetCountriesAndCurrencies()
}

View File

@@ -0,0 +1,26 @@
package langsAndCountriesService
import (
"git.ma-al.com/goc_daniel/b2b/app/model"
"git.ma-al.com/goc_daniel/b2b/repository/langsAndCountriesRepo"
)
// LangsAndCountriesService literally sends back language and countries information.
type LangsAndCountriesService struct {
repo langsAndCountriesRepo.UILangsAndCountriesRepo
}
// NewLangsAndCountriesService creates a new LangsAndCountries service
func New() *LangsAndCountriesService {
return &LangsAndCountriesService{
repo: langsAndCountriesRepo.New(),
}
}
func (s *LangsAndCountriesService) GetLanguages() ([]model.Language, error) {
return s.repo.GetLanguages()
}
func (s *LangsAndCountriesService) GetCountriesAndCurrencies() ([]model.Country, error) {
return s.repo.GetCountriesAndCurrencies()
}

View File

@@ -1,22 +1,22 @@
package jwtFieldsRepo package langsAndCountriesRepo
import ( import (
"git.ma-al.com/goc_daniel/b2b/app/db" "git.ma-al.com/goc_daniel/b2b/app/db"
"git.ma-al.com/goc_daniel/b2b/app/model" "git.ma-al.com/goc_daniel/b2b/app/model"
) )
type UIJWTFieldsRepo interface { type UILangsAndCountriesRepo interface {
GetLanguages() ([]model.Language, error) GetLanguages() ([]model.Language, error)
GetCountriesAndCurrencies() ([]model.Country, error) GetCountriesAndCurrencies() ([]model.Country, error)
} }
type JWTFieldsRepo struct{} type LangsAndCountriesRepo struct{}
func New() UIJWTFieldsRepo { func New() UILangsAndCountriesRepo {
return &JWTFieldsRepo{} return &LangsAndCountriesRepo{}
} }
func (repo *JWTFieldsRepo) GetLanguages() ([]model.Language, error) { func (repo *LangsAndCountriesRepo) GetLanguages() ([]model.Language, error) {
var languages []model.Language var languages []model.Language
err := db.DB.Table("b2b_language").Scan(&languages).Error err := db.DB.Table("b2b_language").Scan(&languages).Error
@@ -24,7 +24,7 @@ func (repo *JWTFieldsRepo) GetLanguages() ([]model.Language, error) {
return languages, err return languages, err
} }
func (repo *JWTFieldsRepo) GetCountriesAndCurrencies() ([]model.Country, error) { func (repo *LangsAndCountriesRepo) GetCountriesAndCurrencies() ([]model.Country, error) {
var countries []model.Country var countries []model.Country
err := db.DB.Table("b2b_countries"). err := db.DB.Table("b2b_countries").