fix cookie -- not working

This commit is contained in:
2026-05-12 11:25:32 +02:00
parent 669d24c6a3
commit 8c4e664ca8
23 changed files with 836 additions and 166 deletions
+23
View File
@@ -0,0 +1,23 @@
package handlers
import (
"net/http/httptest"
"testing"
pscatalog "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/catalog"
)
func TestAssignMarketSwitchLinksUsesCountryCurrencyID(t *testing.T) {
req := httptest.NewRequest("GET", "https://shop.example.com/pl/product/test", nil)
locale := pscatalog.HeaderLocaleData{
Countries: []pscatalog.LocaleOption{
{ID: 36, Code: "PL", CurrencyID: 6, Label: "Polska PLN"},
},
}
assignMarketSwitchLinks(req, &locale)
if got := locale.Countries[0].URL; got != "/pl/product/test?market=36%3APL%3A6" {
t.Fatalf("market url = %q, want %q", got, "/pl/product/test?market=36%3APL%3A6")
}
}
+55
View File
@@ -0,0 +1,55 @@
package middleware
import (
"testing"
pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie"
)
func TestApplyRequestMarketUsesSelectedCountryCurrency(t *testing.T) {
session := &pscookie.SessionContext{
Values: map[string]string{
"date_add": "2026-05-12 10:28:57",
"id_guest": "10",
"id_connections": "11",
"id_lang": "2",
"id_language": "2",
"id_currency": "1",
"id_shop": "1",
"id_cart": "55",
"checksum": "old",
},
OrderedKeys: []string{"date_add", "id_lang", "id_language", "id_currency", "id_guest", "id_connections", "id_shop", "id_cart", "checksum"},
}
applyRequestMarket(session, marketSelection{
CountryID: 36,
CountryISO: "PL",
CurrencyID: 6,
})
if got := session.Values["iso_code_country"]; got != "PL" {
t.Fatalf("iso_code_country = %q, want %q", got, "PL")
}
if _, ok := session.Values["id_country"]; ok {
t.Fatalf("id_country should not be persisted in anonymous market cookie")
}
if got := session.Values["id_currency"]; got != "6" {
t.Fatalf("id_currency = %q, want %q", got, "6")
}
if session.CurrencyID == nil || *session.CurrencyID != 6 {
t.Fatalf("CurrencyID = %v, want 6", session.CurrencyID)
}
if _, ok := session.Values["id_shop"]; ok {
t.Fatalf("id_shop should not be persisted in anonymous market cookie")
}
if _, ok := session.Values["id_cart"]; ok {
t.Fatalf("id_cart should not be persisted in anonymous market cookie")
}
wantOrder := []string{"date_add", "id_lang", "id_language", "iso_code_country", "id_currency", "id_guest", "id_connections", "checksum"}
for i, key := range wantOrder {
if i >= len(session.OrderedKeys) || session.OrderedKeys[i] != key {
t.Fatalf("OrderedKeys[%d] = %q, want %q; full=%v", i, session.OrderedKeys[i], key, session.OrderedKeys)
}
}
}
+81 -57
View File
@@ -4,10 +4,7 @@ import (
"context"
"fmt"
"hash/crc32"
"net"
"net/http"
"net/url"
"path"
"strconv"
"strings"
@@ -21,6 +18,14 @@ type AnonymousSessionInitializer interface {
NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error)
}
type SessionExpiryRefresher interface {
RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error
}
type SessionCookieNameResolver interface {
ResolveCookieName(ctx context.Context, req *http.Request) (string, error)
}
type LanguageResolver interface {
ResolveLanguageID(ctx context.Context, req *http.Request, fallback int64) int64
}
@@ -31,11 +36,22 @@ type ProductRouteMatcher interface {
func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSessionInitializer, languageResolver LanguageResolver, matcher ProductRouteMatcher) echo.MiddlewareFunc {
ownership := cfg.ParseRouteOwnership()
expiryRefresher, _ := initializer.(SessionExpiryRefresher)
cookieNameResolver, _ := initializer.(SessionCookieNameResolver)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ownedRoute := ownsProductRoute(ownership.ProductPrefixes, c.Request().URL.Path, matcher)
configuredCookieName := cfg.DeriveCookieName(requestCookieHost(c.Request()))
if cookieNameResolver != nil {
resolvedCookieName, err := cookieNameResolver.ResolveCookieName(c.Request().Context(), c.Request())
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop cookie name resolution failed: %v", err))
}
if strings.TrimSpace(resolvedCookieName) != "" {
configuredCookieName = resolvedCookieName
}
}
cookieName, rawCookie := findPrestaShopCookie(c.Request(), configuredCookieName)
if cookieName == "" {
cookieName = configuredCookieName
@@ -66,12 +82,17 @@ func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSes
applyRequestMarket(session, requestMarketSelection(c.Request()))
}
if ownedRoute && shouldSetSessionCookie(rawCookie, session) {
if expiryRefresher != nil {
if err := expiryRefresher.RefreshExpiry(c.Request().Context(), session); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session expiry refresh failed: %v", err))
}
}
encoded, err := codec.Encode(session)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie encode failed")
}
session.RawCookie = encoded
setPrestaShopCookie(c.Request(), c.Response(), ownership.ProductPrefixes, cookieName, encoded)
setPrestaShopCookie(c.Request(), c.Response(), session, cookieName, encoded)
if redirectURL, ok := clearMarketSelectionURL(c.Request()); ok {
return c.Redirect(http.StatusSeeOther, redirectURL)
}
@@ -213,19 +234,18 @@ func applyRequestMarket(session *pscookie.SessionContext, selection marketSelect
session.CurrencyID = int64Ptr(selection.CurrencyID)
session.Values["iso_code_country"] = selection.CountryISO
if selection.CountryID > 0 {
session.Values["id_country"] = strconv.FormatInt(selection.CountryID, 10)
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_country", 5)
}
session.Values["id_currency"] = strconv.FormatInt(selection.CurrencyID, 10)
delete(session.Values, "id_country")
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "iso_code_country", 4)
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_currency", 6)
session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "id_country")
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_currency", 5)
if !session.IsLoggedIn {
if checksum := anonymousSessionChecksum(session, sessionLanguageID(session)); checksum != "" {
session.Values["checksum"] = checksum
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
}
trimAnonymousCookieValues(session)
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_guest", 6)
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_connections", 7)
session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "checksum")
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
}
session.Plaintext = ""
@@ -284,6 +304,46 @@ func ensureOrderedKey(keys []string, key string, index int) []string {
return keys
}
func removeOrderedKey(keys []string, key string) []string {
for i, existing := range keys {
if existing == key {
return append(keys[:i], keys[i+1:]...)
}
}
return keys
}
func trimAnonymousCookieValues(session *pscookie.SessionContext) {
if session == nil || session.Values == nil {
return
}
allowed := map[string]struct{}{
"date_add": {},
"id_lang": {},
"id_language": {},
"iso_code_country": {},
"id_currency": {},
"id_guest": {},
"id_connections": {},
"checksum": {},
}
for key := range session.Values {
if _, ok := allowed[key]; !ok {
delete(session.Values, key)
}
}
filtered := make([]string, 0, len(session.OrderedKeys))
for _, key := range session.OrderedKeys {
if _, ok := allowed[key]; ok {
filtered = append(filtered, key)
}
}
session.OrderedKeys = filtered
}
func int64Ptr(value int64) *int64 {
if value == 0 {
return nil
@@ -353,53 +413,17 @@ func clearMarketSelectionURL(req *http.Request) (string, bool) {
return cleanPath, true
}
func setPrestaShopCookie(req *http.Request, res *echo.Response, ownedPrefixes []string, name, value string) {
http.SetCookie(res.Writer, &http.Cookie{
Name: name,
Value: value,
Path: requestCookiePath(req.URL.Path, ownedPrefixes),
Domain: requestCookieDomain(req),
Secure: requestCookieSecure(req),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
}
func requestCookiePath(requestPath string, ownedPrefixes []string) string {
for _, prefix := range ownedPrefixes {
if prefix == "" || !strings.HasPrefix(requestPath, prefix) {
continue
}
base := path.Clean(strings.TrimSuffix(prefix, "/"))
if base == "." || base == "/" {
return "/"
}
parent := path.Dir(base)
if parent == "." {
return "/"
}
if !strings.HasSuffix(parent, "/") {
parent += "/"
}
return parent
func setPrestaShopCookie(req *http.Request, res *echo.Response, session *pscookie.SessionContext, name, value string) {
maxAge := 1
if session != nil && session.ExpiresAt != nil {
maxAge = int(session.ExpiresAt.UTC().Unix())
}
return "/"
}
func requestCookieDomain(req *http.Request) string {
host := requestCookieHost(req)
if host == "" {
return ""
header := fmt.Sprintf("%s=%s; path=/; max-age=%d; HttpOnly; SameSite=Lax", name, value, maxAge)
if requestCookieSecure(req) {
header += "; Secure"
}
if parsed, err := url.Parse("http://" + host); err == nil {
host = parsed.Hostname()
}
host = strings.TrimSpace(strings.TrimPrefix(host, "."))
if host == "" || strings.EqualFold(host, "localhost") || net.ParseIP(host) != nil {
return ""
}
return host
res.Header().Add(echo.HeaderSetCookie, header)
}
func requestCookieHost(req *http.Request) string {
+36
View File
@@ -0,0 +1,36 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/labstack/echo/v4"
pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie"
)
func TestSetPrestaShopCookiePersistsExpiry(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "https://shop.example.com/product/test", nil)
rec := httptest.NewRecorder()
res := e.NewContext(req, rec).Response()
expiresAt := time.Now().UTC().Add(4 * time.Hour).Truncate(time.Second)
setPrestaShopCookie(req, res, &pscookie.SessionContext{
ExpiresAt: &expiresAt,
}, "PrestaShop-test", "value")
setCookie := rec.Header().Get("Set-Cookie")
if !strings.Contains(setCookie, "max-age=") {
t.Fatalf("Set-Cookie missing max-age: %q", setCookie)
}
if strings.Contains(setCookie, "Expires=") {
t.Fatalf("Set-Cookie should not include Expires: %q", setCookie)
}
if !strings.Contains(setCookie, "path=/") {
t.Fatalf("Set-Cookie missing path=/: %q", setCookie)
}
}
+17
View File
@@ -391,6 +391,9 @@ ORDER BY cl.name ASC
if err := s.db.WithContext(ctx).Raw(strings.TrimSpace(countryQuery), countryArgs...).Scan(&locale.Countries).Error; err != nil {
return HeaderLocaleData{}, err
}
for i := range locale.Countries {
locale.Countries[i] = formatCountryLocaleOption(locale.Countries[i])
}
locale.CurrentLanguage = pickLocaleOptionByID(locale.Languages, languageID)
locale.CurrentCountry = pickLocaleOptionByCode(locale.Countries, countryISO)
@@ -499,3 +502,17 @@ func pickLocaleOptionByCode(options []LocaleOption, code string) LocaleOption {
}
return LocaleOption{Code: code, Label: code, Meta: code}
}
func formatCountryLocaleOption(option LocaleOption) LocaleOption {
label := strings.TrimSpace(option.Label)
currencyCode := strings.TrimSpace(option.Meta)
if idx := strings.IndexByte(currencyCode, ' '); idx >= 0 {
currencyCode = currencyCode[:idx]
}
currencyCode = strings.TrimSpace(currencyCode)
if label == "" || currencyCode == "" {
return option
}
option.Label = label + " " + currencyCode
return option
}
@@ -0,0 +1,26 @@
package catalog
import "testing"
func TestFormatCountryLocaleOptionAddsCurrencyCodeToLabel(t *testing.T) {
option := LocaleOption{
Label: "Polska",
Meta: "PLN zl",
}
got := formatCountryLocaleOption(option)
if got.Label != "Polska PLN" {
t.Fatalf("formatCountryLocaleOption().Label = %q, want %q", got.Label, "Polska PLN")
}
}
func TestFormatCountryLocaleOptionFallsBackWithoutMeta(t *testing.T) {
option := LocaleOption{
Label: "Polska",
}
got := formatCountryLocaleOption(option)
if got.Label != "Polska" {
t.Fatalf("formatCountryLocaleOption().Label = %q, want %q", got.Label, "Polska")
}
}
+15 -4
View File
@@ -140,15 +140,15 @@ func (c Config) DeriveCookieName(host string) string {
return c.PrestaShopCookieName
}
domain := normalizedCookieDomain(host)
domain := fallbackCookieHashDomain(host)
if domain == "" {
domain = normalizedCookieDomain(c.PrestaShopBaseURL)
domain = fallbackCookieHashDomain(c.PrestaShopBaseURL)
}
if domain == "" {
domain = normalizedCookieDomain(c.PrestaShopProxyTarget)
domain = fallbackCookieHashDomain(c.PrestaShopProxyTarget)
}
sum := md5.Sum([]byte(c.PrestaShopVersion + "PrestaShop" + domain))
sum := md5.Sum([]byte(c.PrestaShopVersion + "ps-s1" + domain))
return fmt.Sprintf("PrestaShop-%x", sum)
}
@@ -256,6 +256,17 @@ func normalizedCookieDomain(input string) string {
return value
}
func fallbackCookieHashDomain(input string) string {
value := normalizedCookieDomain(input)
if value == "" {
return ""
}
if net.ParseIP(value) != nil || !strings.Contains(value, ".") {
return ""
}
return value
}
func dbDSNFromEnv() string {
return mysqlDSN(
firstNonEmpty(os.Getenv("PRESTASHOP_DB_HOST"), os.Getenv("DB_HOST")),
+19
View File
@@ -0,0 +1,19 @@
package config
import (
"crypto/md5"
"fmt"
"testing"
)
func TestDeriveCookieNameMatchesFallbackPrestashopRule(t *testing.T) {
cfg := Config{PrestaShopVersion: "1.7.3"}
got := cfg.DeriveCookieName("localhost")
sum := md5.Sum([]byte("1.7.3" + "ps-s1"))
want := fmt.Sprintf("PrestaShop-%x", sum)
if got != want {
t.Fatalf("DeriveCookieName() = %q, want %q", got, want)
}
}
+63 -1
View File
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"hash"
"hash/crc32"
"sort"
"strings"
)
@@ -99,7 +100,7 @@ func (c *nativeCodec) Encode(session *SessionContext) (string, error) {
plaintext := session.Plaintext
if plaintext == "" {
plaintext = serializeValues(session.Values, session.OrderedKeys)
plaintext = serializeCookieValues(session.Values, session.OrderedKeys, c.cfg.CookieIV)
}
return c.encryptInternal(plaintext)
}
@@ -321,6 +322,67 @@ func serializeValues(values map[string]string, orderedKeys []string) string {
return strings.Join(pairs, fieldSeparator)
}
func serializeCookieValues(values map[string]string, orderedKeys []string, cookieIV string) string {
if len(values) == 0 {
return ""
}
keys := orderedValueKeys(values, orderedKeys, "checksum")
if len(keys) == 0 {
return ""
}
var builder strings.Builder
for _, key := range keys {
builder.WriteString(key)
builder.WriteString(pairSeparator)
builder.WriteString(values[key])
builder.WriteString(fieldSeparator)
}
checksum := crc32.ChecksumIEEE([]byte(cookieIV + builder.String()))
builder.WriteString("checksum")
builder.WriteString(pairSeparator)
builder.WriteString(fmt.Sprintf("%d", checksum))
return builder.String()
}
func orderedValueKeys(values map[string]string, orderedKeys []string, excluded ...string) []string {
if len(values) == 0 {
return nil
}
excludeSet := make(map[string]struct{}, len(excluded))
for _, key := range excluded {
excludeSet[key] = struct{}{}
}
keys := make([]string, 0, len(values))
seen := map[string]struct{}{}
for _, key := range orderedKeys {
if _, skip := excludeSet[key]; skip {
continue
}
if _, ok := values[key]; ok {
keys = append(keys, key)
seen[key] = struct{}{}
}
}
extra := make([]string, 0)
for key := range values {
if _, skip := excludeSet[key]; skip {
continue
}
if _, ok := seen[key]; !ok {
extra = append(extra, key)
}
}
sort.Strings(extra)
keys = append(keys, extra...)
return keys
}
func int64Ptr(value string) *int64 {
if value == "" {
return nil
+47 -1
View File
@@ -1,6 +1,11 @@
package cookie
import "testing"
import (
"fmt"
"hash/crc32"
"strings"
"testing"
)
const (
testCookieKey = "def000008bf3d70e7012b7493c382d561e193218d0c74ab162fb0ea8029ce20e926531b4bcf0aaec9381152e6c161f198e06918b2d1aad67cc7cf40819a51ee328c63830"
@@ -66,3 +71,44 @@ func TestNativeCodecRoundTrip(t *testing.T) {
t.Fatalf("plaintext mismatch after roundtrip\n got: %s\nwant: %s", redecoded.Plaintext, decoded.Plaintext)
}
}
func TestNativeCodecEncodeRecomputesPrestashopChecksum(t *testing.T) {
codec, err := NewCodec(Config{
CookieName: "PrestaShop-test",
CookieKey: testCookieKey,
CookieIV: "vfRFMV42",
})
if err != nil {
t.Fatalf("NewCodec() error = %v", err)
}
decoded, err := codec.Decode(testCookie)
if err != nil {
t.Fatalf("Decode() error = %v", err)
}
decoded.Values["iso_code_country"] = "PL"
decoded.Values["id_currency"] = "6"
decoded.Values["checksum"] = "stale"
decoded.Plaintext = ""
encoded, err := codec.Encode(decoded)
if err != nil {
t.Fatalf("Encode() error = %v", err)
}
redecoded, err := codec.Decode(encoded)
if err != nil {
t.Fatalf("Decode(encoded) error = %v", err)
}
pairs := strings.Split(redecoded.Plaintext, fieldSeparator)
if len(pairs) < 2 {
t.Fatalf("plaintext too short: %q", redecoded.Plaintext)
}
body := strings.Join(pairs[:len(pairs)-1], fieldSeparator) + fieldSeparator
wantChecksum := fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte("vfRFMV42"+body)))
if got := redecoded.Values["checksum"]; got != wantChecksum {
t.Fatalf("checksum = %q, want %q", got, wantChecksum)
}
}
+244 -5
View File
@@ -2,10 +2,13 @@ package session
import (
"context"
"crypto/md5"
"fmt"
"hash/crc32"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
@@ -16,8 +19,9 @@ import (
)
type Service struct {
db *gorm.DB
prefix string
db *gorm.DB
prefix string
version string
}
type defaults struct {
@@ -26,10 +30,11 @@ type defaults struct {
ShopID int64
ShopGroupID int64
CountryISO string
CookieHours int64
}
func NewService(db *gorm.DB, prefix string) *Service {
return &Service{db: db, prefix: prefix}
func NewService(db *gorm.DB, prefix, version string) *Service {
return &Service{db: db, prefix: prefix, version: version}
}
func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error) {
@@ -88,12 +93,52 @@ func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieNam
ShopID: int64Ptr(def.ShopID),
GuestID: int64Ptr(guestID),
IsLoggedIn: false,
ExpiresAt: cookieExpiry(now, def.CookieHours),
Values: values,
OrderedKeys: orderedKeys,
ParseStatus: pscookie.ParseStatusAnonymous,
}, nil
}
func (s *Service) RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error {
if s == nil || session == nil {
return nil
}
def, err := s.loadDefaults(ctx)
if err != nil {
return err
}
session.ExpiresAt = cookieExpiry(time.Now().UTC(), def.CookieHours)
return nil
}
func (s *Service) ResolveCookieName(ctx context.Context, req *http.Request) (string, error) {
if s == nil || s.db == nil {
return "", fmt.Errorf("prestashop session service is not initialized")
}
host := requestHost(req)
shop, err := s.loadCookieShopContext(ctx, req)
if err != nil {
return "", err
}
baseName := "ps-s" + strconv.FormatInt(shop.ShopID, 10)
sharedDomains := []string(nil)
if shop.ShareOrder {
baseName = "ps-sg" + strconv.FormatInt(shop.ShopGroupID, 10)
sharedDomains, err = s.loadSharedCartDomains(ctx, shop.ShopGroupID)
if err != nil {
return "", err
}
}
sum := md5.Sum([]byte(s.version + baseName + prestashopCookieDomain(host, sharedDomains)))
return fmt.Sprintf("PrestaShop-%x", sum), nil
}
func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
def := &defaults{
LanguageID: 1,
@@ -101,6 +146,7 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
ShopID: 1,
ShopGroupID: 1,
CountryISO: "US",
CookieHours: 480,
}
configTable := s.prefix + "configuration"
@@ -111,7 +157,7 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
Name string
Value string
}
configQuery := fmt.Sprintf("SELECT name, value FROM %s WHERE name IN ('PS_LANG_DEFAULT', 'PS_CURRENCY_DEFAULT', 'PS_COUNTRY_DEFAULT')", configTable)
configQuery := fmt.Sprintf("SELECT name, value FROM %s WHERE name IN ('PS_LANG_DEFAULT', 'PS_CURRENCY_DEFAULT', 'PS_COUNTRY_DEFAULT', 'PS_COOKIE_LIFETIME_FO')", configTable)
if err := s.db.WithContext(ctx).Raw(configQuery).Scan(&configs).Error; err != nil {
return nil, err
}
@@ -131,6 +177,10 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 {
countryID = parsed
}
case "PS_COOKIE_LIFETIME_FO":
if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 {
def.CookieHours = parsed
}
}
}
@@ -272,6 +322,195 @@ func (s *Service) connectionInsert(ctx context.Context, def *defaults, guestID i
return columns, values, nil
}
func cookieExpiry(now time.Time, lifetimeHours int64) *time.Time {
if lifetimeHours <= 0 {
return nil
}
expiresAt := now.Add(time.Duration(lifetimeHours) * time.Hour)
return &expiresAt
}
type cookieShopContext struct {
ShopID int64 `gorm:"column:id_shop"`
ShopGroupID int64 `gorm:"column:id_shop_group"`
ShareOrder bool `gorm:"column:share_order"`
URI string `gorm:"column:uri"`
Main bool `gorm:"column:main"`
}
func (s *Service) loadCookieShopContext(ctx context.Context, req *http.Request) (*cookieShopContext, error) {
normalizedHost := requestHost(req)
requestURI := requestPath(req)
shopURLTable := s.prefix + "shop_url"
shopTable := s.prefix + "shop"
shopGroupTable := s.prefix + "shop_group"
if normalizedHost != "" {
query := fmt.Sprintf(`
SELECT s.id_shop, s.id_shop_group, sg.share_order, CONCAT(su.physical_uri, su.virtual_uri) AS uri, su.main
FROM %s s
JOIN %s sg ON sg.id_shop_group = s.id_shop_group
JOIN %s su ON su.id_shop = s.id_shop
WHERE su.active = 1
AND s.active = 1
AND s.deleted = 0
AND (LOWER(su.domain) = ? OR LOWER(su.domain_ssl) = ?)
ORDER BY LENGTH(CONCAT(su.physical_uri, su.virtual_uri)) DESC, su.main DESC, s.id_shop ASC
`, shopTable, shopGroupTable, shopURLTable)
var shops []cookieShopContext
if err := s.db.WithContext(ctx).Raw(query, normalizedHost, normalizedHost).Scan(&shops).Error; err != nil {
return nil, err
}
for _, shop := range shops {
if uriMatchesRequest(shop.URI, requestURI) {
return &shop, nil
}
}
}
fallbackQuery := fmt.Sprintf(`
SELECT s.id_shop, s.id_shop_group, sg.share_order, '' AS uri, 1 AS main
FROM %s s
JOIN %s sg ON sg.id_shop_group = s.id_shop_group
WHERE s.active = 1
AND s.deleted = 0
ORDER BY s.id_shop ASC
LIMIT 1
`, shopTable, shopGroupTable)
var shop cookieShopContext
if err := s.db.WithContext(ctx).Raw(fallbackQuery).Scan(&shop).Error; err != nil {
return nil, err
}
if shop.ShopID == 0 {
return nil, fmt.Errorf("prestashop shop context not found")
}
return &shop, nil
}
func (s *Service) loadSharedCartDomains(ctx context.Context, shopGroupID int64) ([]string, error) {
if shopGroupID == 0 {
return nil, nil
}
type row struct {
Domain string `gorm:"column:domain"`
}
shopURLTable := s.prefix + "shop_url"
shopTable := s.prefix + "shop"
query := fmt.Sprintf(`
SELECT su.domain
FROM %s su
JOIN %s s ON s.id_shop = su.id_shop
WHERE su.main = 1
AND su.active = 1
AND s.id_shop_group = ?
`, shopURLTable, shopTable)
var rows []row
if err := s.db.WithContext(ctx).Raw(query, shopGroupID).Scan(&rows).Error; err != nil {
return nil, err
}
domains := make([]string, 0, len(rows))
for _, row := range rows {
if host := normalizeRequestHost(row.Domain); host != "" {
domains = append(domains, host)
}
}
return domains, nil
}
func normalizeRequestHost(input string) string {
value := strings.TrimSpace(input)
if value == "" {
return ""
}
if strings.Contains(value, "://") {
if parsed, err := url.Parse(value); err == nil {
value = parsed.Hostname()
}
}
if host, _, err := net.SplitHostPort(value); err == nil {
value = host
}
return strings.ToLower(strings.TrimSpace(value))
}
func requestHost(req *http.Request) string {
if req == nil {
return ""
}
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
if strings.Contains(host, ",") {
host = strings.TrimSpace(strings.Split(host, ",")[0])
}
return normalizeRequestHost(host)
}
func requestPath(req *http.Request) string {
if req == nil || req.URL == nil {
return "/"
}
path := req.URL.EscapedPath()
if path == "" {
path = req.URL.Path
}
if path == "" {
return "/"
}
decoded, err := url.PathUnescape(path)
if err == nil && decoded != "" {
path = decoded
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return path
}
func uriMatchesRequest(uri, requestURI string) bool {
if uri == "" {
uri = "/"
}
if requestURI == "" {
requestURI = "/"
}
return strings.HasPrefix(strings.ToLower(requestURI), strings.ToLower(uri))
}
var sharedDomainPattern = regexp.MustCompile(`^(?:.*\.)?([^.]*(?:.{2,4})?\..{2,3})$`)
func prestashopCookieDomain(host string, sharedURLs []string) string {
normalizedHost := normalizeRequestHost(host)
if normalizedHost == "" {
return ""
}
if net.ParseIP(normalizedHost) != nil || !strings.Contains(normalizedHost, ".") {
return ""
}
for _, sharedURL := range sharedURLs {
if normalizeRequestHost(sharedURL) != normalizedHost {
continue
}
matches := sharedDomainPattern.FindStringSubmatch(normalizedHost)
if len(matches) == 2 {
return "." + matches[1]
}
break
}
return normalizedHost
}
func (s *Service) tableColumns(ctx context.Context, tableName string) (map[string]bool, error) {
type columnRow struct {
ColumnName string `gorm:"column:COLUMN_NAME"`
@@ -0,0 +1,27 @@
package session
import "testing"
func TestPrestashopCookieDomain(t *testing.T) {
if got := prestashopCookieDomain("localhost", nil); got != "" {
t.Fatalf("prestashopCookieDomain(localhost) = %q, want empty", got)
}
if got := prestashopCookieDomain("shop.example.com", []string{"shop.example.com"}); got != ".example.com" {
t.Fatalf("prestashopCookieDomain(shared) = %q, want %q", got, ".example.com")
}
if got := prestashopCookieDomain("shop.example.com", nil); got != "shop.example.com" {
t.Fatalf("prestashopCookieDomain(single) = %q, want %q", got, "shop.example.com")
}
}
func TestURIMatchesRequest(t *testing.T) {
if !uriMatchesRequest("/shop/fr/", "/shop/fr/product/test") {
t.Fatalf("expected nested shop URI to match request path")
}
if uriMatchesRequest("/shop/fr/", "/shop/en/product/test") {
t.Fatalf("unexpected match for different shop URI")
}
}