package middleware

import (
	"context"
	"fmt"
	"hash/crc32"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/labstack/echo/v4"

	psconfig "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/config"
	pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie"
)

// AnonymousSessionInitializer creates a fresh anonymous session. Session calls
// NewAnonymous on owned routes when the incoming cookie is missing or lacks the
// values an anonymous visitor needs (see shouldBootstrapAnonymousSession).
//
// An implementation may additionally satisfy SessionExpiryRefresher and/or
// SessionCookieNameResolver; Session detects those via type assertion.
type AnonymousSessionInitializer interface {
	NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error)
}

// SessionExpiryRefresher is an optional capability of an
// AnonymousSessionInitializer. When present, Session calls RefreshExpiry right
// before re-encoding a session whose cookie has to be (re)issued.
type SessionExpiryRefresher interface {
	RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error
}

// SessionCookieNameResolver is an optional capability of an
// AnonymousSessionInitializer. A non-blank result overrides the cookie name
// derived from the configuration for the request host; an error fails the
// request with 500.
type SessionCookieNameResolver interface {
	ResolveCookieName(ctx context.Context, req *http.Request) (string, error)
}

// LanguageResolver picks the language ID for a request. fallback is the
// language currently stored in the session (0 when unset). A result of 0
// leaves the session language unchanged.
type LanguageResolver interface {
	ResolveLanguageID(ctx context.Context, req *http.Request, fallback int64) int64
}

// ProductRouteMatcher reports whether a request path is owned by this
// application. When passed to Session it takes precedence over the configured
// product route prefixes.
type ProductRouteMatcher interface {
	Owns(path string) bool
}

// Session returns middleware that decodes the PrestaShop session cookie of each
// request and stores the result with SetSession before calling the next handler.
//
// On routes not owned by this application the decoded session is passed
// through untouched. A cookie that fails to decode is stored as a session with
// ParseStatusInvalid instead of failing the request.
//
// On owned routes the middleware additionally:
//   - fails with 500 if the cookie cannot be decoded,
//   - bootstraps an anonymous session through initializer when needed,
//   - applies the request language (via languageResolver) and any ?market=
//     selection,
//   - re-encodes and re-issues the cookie when its content changed, then
//     redirects (303) to strip the market parameter if one was present.
func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSessionInitializer, languageResolver LanguageResolver, matcher ProductRouteMatcher) echo.MiddlewareFunc {
	ownership := cfg.ParseRouteOwnership()
	// Optional capabilities of the initializer; a nil initializer yields nil for both.
	expiryRefresher, _ := initializer.(SessionExpiryRefresher)
	cookieNameResolver, _ := initializer.(SessionCookieNameResolver)

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()
			ctx := req.Context()
			ownedRoute := ownsProductRoute(ownership.ProductPrefixes, req.URL.Path, matcher)

			configuredCookieName := cfg.DeriveCookieName(requestCookieHost(req))
			if cookieNameResolver != nil {
				resolvedCookieName, err := cookieNameResolver.ResolveCookieName(ctx, req)
				if err != nil {
					return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop cookie name resolution failed: %v", err))
				}
				if strings.TrimSpace(resolvedCookieName) != "" {
					configuredCookieName = resolvedCookieName
				}
			}

			cookieName, rawCookie := findPrestaShopCookie(req, configuredCookieName)
			if cookieName == "" {
				cookieName = configuredCookieName
			}

			session, err := codec.Decode(rawCookie)
			if err != nil {
				if ownedRoute {
					return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie decode failed")
				}
				// Not our route: keep the raw cookie visible but mark it unparsed.
				SetSession(c, &pscookie.SessionContext{
					CookieName:  cookieName,
					RawCookie:   rawCookie,
					Values:      map[string]string{},
					ParseStatus: pscookie.ParseStatusInvalid,
				})
				return next(c)
			}
			session.CookieName = cookieName

			// Everything below only applies to routes owned by this application.
			if !ownedRoute {
				SetSession(c, session)
				return next(c)
			}

			if initializer != nil && shouldBootstrapAnonymousSession(rawCookie, session) {
				session, err = initializer.NewAnonymous(ctx, req, cookieName)
				if err != nil {
					return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session bootstrap failed: %v", err))
				}
			}

			applyRequestLanguage(session, resolveRequestLanguageID(ctx, req, session, languageResolver))
			applyRequestMarket(session, requestMarketSelection(req))

			if shouldSetSessionCookie(rawCookie, session) {
				if expiryRefresher != nil {
					if err := expiryRefresher.RefreshExpiry(ctx, session); err != nil {
						return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session expiry refresh failed: %v", err))
					}
				}
				encoded, err := codec.Encode(session)
				if err != nil {
					return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie encode failed")
				}
				session.RawCookie = encoded
				setPrestaShopCookie(req, c.Response(), session, cookieName, encoded)
				// Once the market selection is persisted in the cookie, drop it
				// from the URL so reloads/bookmarks do not re-apply it.
				if redirectURL, ok := clearMarketSelectionURL(req); ok {
					return c.Redirect(http.StatusSeeOther, redirectURL)
				}
			}

			SetSession(c, session)
			return next(c)
		}
	}
}

// marketSelection is a country/currency choice parsed from the ?market= query
// parameter. CountryID is only populated by the three-part form.
type marketSelection struct {
	CountryID  int64
	CountryISO string
	CurrencyID int64
}

// resolveRequestLanguageID asks resolver for the request language, passing the
// session's current language as fallback. It returns 0 when resolver is nil.
func resolveRequestLanguageID(ctx context.Context, req *http.Request, session *pscookie.SessionContext, resolver LanguageResolver) int64 {
	if resolver == nil {
		return 0
	}
	return resolver.ResolveLanguageID(ctx, req, sessionLanguageID(session))
}

// findPrestaShopCookie returns the request cookie named configuredName. If it
// is absent, it falls back to the first cookie sharing the name's prefix (see
// cookiePrefix). Both results are empty when nothing matches.
func findPrestaShopCookie(req *http.Request, configuredName string) (name, value string) {
	cookies := req.Cookies()
	for _, cookie := range cookies {
		if cookie.Name == configuredName {
			return cookie.Name, cookie.Value
		}
	}

	prefix := cookiePrefix(configuredName)
	if prefix == "" {
		return "", ""
	}
	for _, cookie := range cookies {
		if strings.HasPrefix(cookie.Name, prefix) {
			return cookie.Name, cookie.Value
		}
	}
	return "", ""
}

// cookiePrefix returns configuredName up to and including its first '-'
// (e.g. "PrestaShop-abc" -> "PrestaShop-"), or "" when there is no '-'.
func cookiePrefix(configuredName string) string {
	if idx := strings.Index(configuredName, "-"); idx >= 0 {
		return configuredName[:idx+1]
	}
	return ""
}

// shouldBootstrapAnonymousSession reports whether a new anonymous session must
// be created. This is the case when:
//   - there is no session or no raw cookie, or
//   - a guest (not logged-in) session is missing its guest, currency or
//     language IDs, id_connections, or iso_code_country.
func shouldBootstrapAnonymousSession(rawCookie string, session *pscookie.SessionContext) bool {
	if session == nil || rawCookie == "" {
		return true
	}
	if session.IsLoggedIn {
		return false
	}
	return session.GuestID == nil ||
		session.CurrencyID == nil ||
		session.LanguageID == nil ||
		session.Values["id_connections"] == "" ||
		session.Values["iso_code_country"] == ""
}

// shouldSetSessionCookie reports whether the cookie must be (re)issued: when
// the request carried none, or when the session's RawCookie no longer matches
// it. The apply* helpers clear RawCookie to force this.
func shouldSetSessionCookie(rawCookie string, session *pscookie.SessionContext) bool {
	if session == nil {
		return false
	}
	if rawCookie == "" {
		return true
	}
	return rawCookie != session.RawCookie
}

// applyRequestLanguage stores languageID in the session (id_lang and
// id_language) when it is non-zero and differs from the current language.
// For anonymous sessions it also recomputes the checksum value. It clears
// Plaintext and RawCookie so the session gets re-encoded.
func applyRequestLanguage(session *pscookie.SessionContext, languageID int64) {
	if session == nil || languageID == 0 {
		return
	}
	if sessionLanguageID(session) == languageID {
		return
	}
	if session.Values == nil {
		session.Values = map[string]string{}
	}

	value := strconv.FormatInt(languageID, 10)
	session.LanguageID = int64Ptr(languageID)
	session.Values["id_lang"] = value
	session.Values["id_language"] = value
	session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_lang", 1)
	session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_language", 3)

	if !session.IsLoggedIn {
		if checksum := anonymousSessionChecksum(session, languageID); checksum != "" {
			session.Values["checksum"] = checksum
			session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
		}
	}

	session.Plaintext = ""
	session.RawCookie = ""
}

// applyRequestMarket applies a parsed ?market= selection to the session.
//
// It sets iso_code_country and id_currency and removes id_country. It is a
// no-op when the selection is incomplete or the session already matches it.
// For anonymous sessions, values outside anonymousCookieKeys are discarded and
// the key order is normalized with checksum last. It clears Plaintext and
// RawCookie so the session gets re-encoded.
//
// NOTE(review): selection.CountryID takes part in the "already applied"
// comparison but is never written back (id_country is always deleted). Also,
// the anonymous checksum is not recomputed after id_currency changes. Confirm
// both are intended.
func applyRequestMarket(session *pscookie.SessionContext, selection marketSelection) {
	if session == nil || selection.CountryISO == "" || selection.CurrencyID == 0 {
		return
	}

	currentCountry := ""
	currentCurrency := int64(0)
	currentCountryID := int64(0)
	if session.Values != nil {
		currentCountry = strings.ToUpper(strings.TrimSpace(session.Values["iso_code_country"]))
		if session.CurrencyID != nil {
			currentCurrency = *session.CurrencyID
		}
		// A missing/malformed id_country simply compares as 0.
		currentCountryID, _ = strconv.ParseInt(session.Values["id_country"], 10, 64)
	}
	if currentCountry == selection.CountryISO && currentCurrency == selection.CurrencyID && currentCountryID == selection.CountryID {
		return
	}

	if session.Values == nil {
		session.Values = map[string]string{}
	}
	session.CurrencyID = int64Ptr(selection.CurrencyID)
	session.Values["iso_code_country"] = selection.CountryISO
	session.Values["id_currency"] = strconv.FormatInt(selection.CurrencyID, 10)
	delete(session.Values, "id_country")
	session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "iso_code_country", 4)
	session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "id_country")
	session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_currency", 5)

	if !session.IsLoggedIn {
		trimAnonymousCookieValues(session)
		session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_guest", 6)
		session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_connections", 7)
		// Move checksum to the very end.
		session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "checksum")
		session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
	}

	session.Plaintext = ""
	session.RawCookie = ""
}

// sessionLanguageID returns the session's language ID, or 0 when unset.
func sessionLanguageID(session *pscookie.SessionContext) int64 {
	if session == nil || session.LanguageID == nil {
		return 0
	}
	return *session.LanguageID
}

// anonymousSessionChecksum returns the decimal CRC-32 (IEEE) of
// "id_guest|id_connections|languageID|id_currency|id_shop|". Missing or
// malformed values count as 0. It returns "" when guest, connection or
// currency ID is 0.
func anonymousSessionChecksum(session *pscookie.SessionContext, languageID int64) string {
	if session == nil || session.Values == nil {
		return ""
	}
	guestID, _ := strconv.ParseInt(session.Values["id_guest"], 10, 64)
	connectionID, _ := strconv.ParseInt(session.Values["id_connections"], 10, 64)
	currencyID, _ := strconv.ParseInt(session.Values["id_currency"], 10, 64)
	shopID, _ := strconv.ParseInt(session.Values["id_shop"], 10, 64)
	if guestID == 0 || connectionID == 0 || currencyID == 0 {
		return ""
	}

	buf := make([]byte, 0, 32)
	for _, value := range []int64{guestID, connectionID, languageID, currencyID, shopID} {
		buf = strconv.AppendInt(buf, value, 10)
		buf = append(buf, '|')
	}
	return strconv.FormatUint(uint64(crc32.ChecksumIEEE(buf)), 10)
}

// ensureOrderedKey makes sure key is present in keys, placing it at index:
//   - An existing key already at index is left alone.
//   - An existing key is also left alone when index is at or past the end of
//     the slice.
//   - Otherwise the key is moved or inserted at index. Negative indexes mean
//     0; indexes past the end append.
//
// The backing array of keys may be modified; use the returned slice.
func ensureOrderedKey(keys []string, key string, index int) []string {
	for i, existing := range keys {
		if existing != key {
			continue
		}
		if i == index || index >= len(keys) {
			return keys
		}
		keys = append(keys[:i], keys[i+1:]...)
		break
	}

	if index < 0 {
		index = 0
	}
	if index >= len(keys) {
		return append(keys, key)
	}
	keys = append(keys, "")
	copy(keys[index+1:], keys[index:])
	keys[index] = key
	return keys
}

// removeOrderedKey removes the first occurrence of key from keys. The backing
// array may be modified; use the returned slice.
func removeOrderedKey(keys []string, key string) []string {
	for i, existing := range keys {
		if existing == key {
			return append(keys[:i], keys[i+1:]...)
		}
	}
	return keys
}

// anonymousCookieKeys is the set of session values an anonymous cookie keeps
// once trimAnonymousCookieValues has run.
var anonymousCookieKeys = map[string]struct{}{
	"date_add":         {},
	"id_lang":          {},
	"id_language":      {},
	"iso_code_country": {},
	"id_currency":      {},
	"id_guest":         {},
	"id_connections":   {},
	"checksum":         {},
}

// trimAnonymousCookieValues drops every value and ordered key not listed in
// anonymousCookieKeys.
func trimAnonymousCookieValues(session *pscookie.SessionContext) {
	if session == nil || session.Values == nil {
		return
	}
	for key := range session.Values {
		if _, ok := anonymousCookieKeys[key]; !ok {
			delete(session.Values, key)
		}
	}

	filtered := make([]string, 0, len(session.OrderedKeys))
	for _, key := range session.OrderedKeys {
		if _, ok := anonymousCookieKeys[key]; ok {
			filtered = append(filtered, key)
		}
	}
	session.OrderedKeys = filtered
}

// int64Ptr returns a pointer to a copy of value, or nil for 0 (0 is treated
// as "unset" for session IDs).
func int64Ptr(value int64) *int64 {
	if value == 0 {
		return nil
	}
	v := value
	return &v
}

// requestMarketSelection parses the ?market= query parameter. Two forms are
// accepted:
//   - "ISO:currencyID"
//   - "countryID:ISO:currencyID"
//
// IDs must be strictly positive decimal integers. The ISO code is upper-cased
// and must be 2–5 characters of A–Z, 0–9, '-' or '_'. Any malformed input
// yields the zero marketSelection, which applyRequestMarket ignores.
func requestMarketSelection(req *http.Request) marketSelection {
	if req == nil || req.URL == nil {
		return marketSelection{}
	}
	raw := strings.TrimSpace(req.URL.Query().Get("market"))
	if raw == "" {
		return marketSelection{}
	}

	var (
		selection    marketSelection
		isoPart      string
		currencyPart string
	)
	parts := strings.Split(raw, ":")
	switch len(parts) {
	case 2:
		isoPart, currencyPart = parts[0], parts[1]
	case 3:
		countryID, ok := parseMarketID(parts[0])
		if !ok {
			return marketSelection{}
		}
		selection.CountryID = countryID
		isoPart, currencyPart = parts[1], parts[2]
	default:
		return marketSelection{}
	}

	currencyID, ok := parseMarketID(currencyPart)
	if !ok {
		return marketSelection{}
	}
	countryISO := strings.ToUpper(strings.TrimSpace(isoPart))
	if !isMarketCountryISO(countryISO) {
		return marketSelection{}
	}

	selection.CountryISO = countryISO
	selection.CurrencyID = currencyID
	return selection
}

// parseMarketID parses a decimal ID from the market parameter, accepting only
// strictly positive values (negative IDs are never valid database IDs).
func parseMarketID(raw string) (int64, bool) {
	id, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
	if err != nil || id <= 0 {
		return 0, false
	}
	return id, true
}

// isMarketCountryISO reports whether iso (already upper-cased) is 2–5 bytes of
// A–Z, 0–9, '-' or '_'. The value comes straight from the query string and
// ends up in the encoded session cookie. Restricting the alphabet keeps
// arbitrary bytes (separators, control characters, non-ASCII) out of it.
func isMarketCountryISO(iso string) bool {
	if len(iso) < 2 || len(iso) > 5 {
		return false
	}
	for i := 0; i < len(iso); i++ {
		switch c := iso[i]; {
		case c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '-', c == '_':
		default:
			return false
		}
	}
	return true
}

// clearMarketSelectionURL returns the request URL (path + query) without the
// market parameter, and true, when that parameter is present and non-empty.
//
// The path is taken in escaped form so the result is a valid Location value.
// Leading '/' and '\' are collapsed to a single '/': a path such as
// "//evil.example" would otherwise be followed by browsers as a
// protocol-relative URL to another host (open redirect).
func clearMarketSelectionURL(req *http.Request) (string, bool) {
	if req == nil || req.URL == nil {
		return "", false
	}
	query := req.URL.Query()
	if query.Get("market") == "" {
		return "", false
	}
	query.Del("market")

	cleanPath := "/" + strings.TrimLeft(req.URL.EscapedPath(), "/\\")
	if encoded := query.Encode(); encoded != "" {
		return cleanPath + "?" + encoded, true
	}
	return cleanPath, true
}

// setPrestaShopCookie appends a Set-Cookie header for the session cookie
// (path=/, HttpOnly, SameSite=Lax, plus Secure on HTTPS requests).
//
// Max-Age is a relative lifetime in seconds (RFC 6265 §5.2.2), so the absolute
// ExpiresAt instant is converted to the seconds remaining from now. Without an
// ExpiresAt, or when it is (almost) past, Max-Age falls back to 1, the value
// this function always used as its default.
func setPrestaShopCookie(req *http.Request, res *echo.Response, session *pscookie.SessionContext, name, value string) {
	maxAge := int64(1)
	if session != nil && session.ExpiresAt != nil {
		if remaining := session.ExpiresAt.Unix() - time.Now().Unix(); remaining > maxAge {
			maxAge = remaining
		}
	}

	header := fmt.Sprintf("%s=%s; path=/; max-age=%d; HttpOnly; SameSite=Lax", name, value, maxAge)
	if requestCookieSecure(req) {
		header += "; Secure"
	}
	res.Header().Add(echo.HeaderSetCookie, header)
}

// requestCookieHost returns the host used to derive the cookie name: the first
// X-Forwarded-Host entry if present, otherwise req.Host. It returns "" for a
// nil request.
func requestCookieHost(req *http.Request) string {
	if req == nil {
		return ""
	}
	host := req.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = req.Host
	}
	if idx := strings.IndexByte(host, ','); idx >= 0 {
		host = strings.TrimSpace(host[:idx])
	}
	return host
}

// requestCookieSecure reports whether the request arrived over HTTPS, either
// directly (TLS) or per the first X-Forwarded-Proto entry. It returns false
// for a nil request.
func requestCookieSecure(req *http.Request) bool {
	if req == nil {
		return false
	}
	if req.TLS != nil {
		return true
	}
	forwarded := req.Header.Get("X-Forwarded-Proto")
	if idx := strings.IndexByte(forwarded, ','); idx >= 0 {
		forwarded = forwarded[:idx]
	}
	return strings.EqualFold(strings.TrimSpace(forwarded), "https")
}

// ownsProductRoute reports whether path is owned by this application. If
// matcher is non-nil its answer wins. Otherwise the path must start with one
// of the non-empty prefixes.
func ownsProductRoute(prefixes []string, path string, matcher ProductRouteMatcher) bool {
	if matcher != nil {
		return matcher.Owns(path)
	}
	for _, prefix := range prefixes {
		if prefix != "" && strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}