package session import ( "context" "crypto/md5" "fmt" "hash/crc32" "net" "net/http" "net/url" "regexp" "strconv" "strings" "time" pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie" "gorm.io/gorm" ) type Service struct { db *gorm.DB prefix string version string } type defaults struct { LanguageID int64 CurrencyID int64 ShopID int64 ShopGroupID int64 CountryISO string CookieHours int64 } func NewService(db *gorm.DB, prefix, version string) *Service { return &Service{db: db, prefix: prefix, version: version} } func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error) { if s == nil || s.db == nil { return nil, fmt.Errorf("prestashop session service is not initialized") } def, err := s.loadDefaults(ctx) if err != nil { return nil, err } guestID, err := s.insertGuest(ctx) if err != nil { return nil, err } connectionID, err := s.insertConnection(ctx, def, guestID, req) if err != nil { return nil, err } now := time.Now().UTC() values := map[string]string{ "checksum": anonymousChecksum(guestID, connectionID, def.LanguageID, def.CurrencyID, def.ShopID), "date_add": now.Format("2006-01-02 15:04:05"), "id_cart": "", "id_connections": strconv.FormatInt(connectionID, 10), "id_currency": strconv.FormatInt(def.CurrencyID, 10), "id_guest": strconv.FormatInt(guestID, 10), "id_lang": strconv.FormatInt(def.LanguageID, 10), "id_language": strconv.FormatInt(def.LanguageID, 10), "iso_code_country": def.CountryISO, } orderedKeys := []string{ "date_add", "id_lang", "id_cart", "id_language", "iso_code_country", "id_currency", "id_guest", "id_connections", "checksum", } if def.ShopID > 0 { values["id_shop"] = strconv.FormatInt(def.ShopID, 10) orderedKeys = append(orderedKeys[:6], append([]string{"id_shop"}, orderedKeys[6:]...)...) } return &pscookie.SessionContext{ CookieName: cookieName, LanguageID: int64Ptr(def.LanguageID), CurrencyID: int64Ptr(def.CurrencyID), ShopID: int64Ptr(def.ShopID), GuestID: int64Ptr(guestID), IsLoggedIn: false, ExpiresAt: cookieExpiry(now, def.CookieHours), Values: values, OrderedKeys: orderedKeys, ParseStatus: pscookie.ParseStatusAnonymous, }, nil } func (s *Service) RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error { if s == nil || session == nil { return nil } def, err := s.loadDefaults(ctx) if err != nil { return err } session.ExpiresAt = cookieExpiry(time.Now().UTC(), def.CookieHours) return nil } func (s *Service) ResolveCookieName(ctx context.Context, req *http.Request) (string, error) { if s == nil || s.db == nil { return "", fmt.Errorf("prestashop session service is not initialized") } host := requestHost(req) shop, err := s.loadCookieShopContext(ctx, req) if err != nil { return "", err } baseName := "ps-s" + strconv.FormatInt(shop.ShopID, 10) sharedDomains := []string(nil) if shop.ShareOrder { baseName = "ps-sg" + strconv.FormatInt(shop.ShopGroupID, 10) sharedDomains, err = s.loadSharedCartDomains(ctx, shop.ShopGroupID) if err != nil { return "", err } } sum := md5.Sum([]byte(s.version + baseName + prestashopCookieDomain(host, sharedDomains))) return fmt.Sprintf("PrestaShop-%x", sum), nil } func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) { def := &defaults{ LanguageID: 1, CurrencyID: 1, ShopID: 1, ShopGroupID: 1, CountryISO: "US", CookieHours: 480, } configTable := s.prefix + "configuration" shopTable := s.prefix + "shop" countryTable := s.prefix + "country" var configs []struct { Name string Value string } configQuery := fmt.Sprintf("SELECT name, value FROM %s WHERE name IN ('PS_LANG_DEFAULT', 'PS_CURRENCY_DEFAULT', 'PS_COUNTRY_DEFAULT', 'PS_COOKIE_LIFETIME_FO')", configTable) if err := s.db.WithContext(ctx).Raw(configQuery).Scan(&configs).Error; err != nil { return nil, err } countryID := int64(0) for _, cfg := range configs { switch cfg.Name { case "PS_LANG_DEFAULT": if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 { def.LanguageID = parsed } case "PS_CURRENCY_DEFAULT": if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 { def.CurrencyID = parsed } case "PS_COUNTRY_DEFAULT": if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 { countryID = parsed } case "PS_COOKIE_LIFETIME_FO": if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 { def.CookieHours = parsed } } } var shop struct { ID int64 `gorm:"column:id_shop"` GroupID int64 `gorm:"column:id_shop_group"` } shopQuery := fmt.Sprintf("SELECT id_shop, id_shop_group FROM %s ORDER BY id_shop LIMIT 1", shopTable) if err := s.db.WithContext(ctx).Raw(shopQuery).Scan(&shop).Error; err != nil { return nil, err } if shop.ID > 0 { def.ShopID = shop.ID } if shop.GroupID > 0 { def.ShopGroupID = shop.GroupID } if countryID > 0 { var country struct { ISOCode string `gorm:"column:iso_code"` } countryQuery := fmt.Sprintf("SELECT iso_code FROM %s WHERE id_country = ? LIMIT 1", countryTable) if err := s.db.WithContext(ctx).Raw(countryQuery, countryID).Scan(&country).Error; err != nil { return nil, err } if country.ISOCode != "" { def.CountryISO = country.ISOCode } } return def, nil } func (s *Service) insertGuest(ctx context.Context) (int64, error) { sqlDB, err := s.db.DB() if err != nil { return 0, fmt.Errorf("resolve sql db for guest insert: %w", err) } tableName := s.prefix + "guest" columns, values, err := s.guestInsert(ctx) if err != nil { return 0, err } query := insertQuery(tableName, columns) result, err := sqlDB.ExecContext(ctx, query, values...) if err != nil { return 0, fmt.Errorf("insert guest: %w", err) } id, err := result.LastInsertId() if err != nil { return 0, fmt.Errorf("guest last insert id: %w", err) } return id, nil } func (s *Service) insertConnection(ctx context.Context, def *defaults, guestID int64, req *http.Request) (int64, error) { sqlDB, err := s.db.DB() if err != nil { return 0, fmt.Errorf("resolve sql db for connection insert: %w", err) } tableName := s.prefix + "connections" columns, values, err := s.connectionInsert(ctx, def, guestID, req) if err != nil { return 0, err } query := insertQuery(tableName, columns) result, err := sqlDB.ExecContext(ctx, query, values...) if err != nil { return 0, fmt.Errorf("insert connection: %w", err) } id, err := result.LastInsertId() if err != nil { return 0, fmt.Errorf("connection last insert id: %w", err) } return id, nil } func (s *Service) guestInsert(ctx context.Context) ([]string, []any, error) { available, err := s.tableColumns(ctx, s.prefix+"guest") if err != nil { return nil, nil, fmt.Errorf("load guest columns: %w", err) } columns := make([]string, 0) values := make([]any, 0) addColumn := func(name string, value any) { if available[name] { columns = append(columns, name) values = append(values, value) } } addColumn("id_customer", 0) addColumn("id_operating_system", 0) addColumn("id_web_browser", 0) addColumn("javascript", 0) addColumn("screen_resolution_x", 0) addColumn("screen_resolution_y", 0) addColumn("screen_color", 0) addColumn("sun_java", 0) addColumn("adobe_flash", 0) addColumn("adobe_director", 0) addColumn("apple_quicktime", 0) addColumn("real_player", 0) addColumn("windows_media", 0) addColumn("accept_language", "") addColumn("mobile_theme", 0) return columns, values, nil } func (s *Service) connectionInsert(ctx context.Context, def *defaults, guestID int64, req *http.Request) ([]string, []any, error) { available, err := s.tableColumns(ctx, s.prefix+"connections") if err != nil { return nil, nil, fmt.Errorf("load connections columns: %w", err) } now := time.Now().UTC().Format("2006-01-02 15:04:05") columns := make([]string, 0) values := make([]any, 0) addColumn := func(name string, value any) { if available[name] { columns = append(columns, name) values = append(values, value) } } addColumn("id_guest", guestID) addColumn("id_shop", def.ShopID) addColumn("id_shop_group", def.ShopGroupID) addColumn("id_page", 0) addColumn("ip_address", ipAsUint32(req)) addColumn("date_add", now) addColumn("date_upd", now) addColumn("http_referer", referer(req)) addColumn("request_uri", requestURI(req)) return columns, values, nil } func cookieExpiry(now time.Time, lifetimeHours int64) *time.Time { if lifetimeHours <= 0 { return nil } expiresAt := now.Add(time.Duration(lifetimeHours) * time.Hour) return &expiresAt } type cookieShopContext struct { ShopID int64 `gorm:"column:id_shop"` ShopGroupID int64 `gorm:"column:id_shop_group"` ShareOrder bool `gorm:"column:share_order"` URI string `gorm:"column:uri"` Main bool `gorm:"column:main"` } func (s *Service) loadCookieShopContext(ctx context.Context, req *http.Request) (*cookieShopContext, error) { normalizedHost := requestHost(req) requestURI := requestPath(req) shopURLTable := s.prefix + "shop_url" shopTable := s.prefix + "shop" shopGroupTable := s.prefix + "shop_group" if normalizedHost != "" { query := fmt.Sprintf(` SELECT s.id_shop, s.id_shop_group, sg.share_order, CONCAT(su.physical_uri, su.virtual_uri) AS uri, su.main FROM %s s JOIN %s sg ON sg.id_shop_group = s.id_shop_group JOIN %s su ON su.id_shop = s.id_shop WHERE su.active = 1 AND s.active = 1 AND s.deleted = 0 AND (LOWER(su.domain) = ? OR LOWER(su.domain_ssl) = ?) ORDER BY LENGTH(CONCAT(su.physical_uri, su.virtual_uri)) DESC, su.main DESC, s.id_shop ASC `, shopTable, shopGroupTable, shopURLTable) var shops []cookieShopContext if err := s.db.WithContext(ctx).Raw(query, normalizedHost, normalizedHost).Scan(&shops).Error; err != nil { return nil, err } for _, shop := range shops { if uriMatchesRequest(shop.URI, requestURI) { return &shop, nil } } } fallbackQuery := fmt.Sprintf(` SELECT s.id_shop, s.id_shop_group, sg.share_order, '' AS uri, 1 AS main FROM %s s JOIN %s sg ON sg.id_shop_group = s.id_shop_group WHERE s.active = 1 AND s.deleted = 0 ORDER BY s.id_shop ASC LIMIT 1 `, shopTable, shopGroupTable) var shop cookieShopContext if err := s.db.WithContext(ctx).Raw(fallbackQuery).Scan(&shop).Error; err != nil { return nil, err } if shop.ShopID == 0 { return nil, fmt.Errorf("prestashop shop context not found") } return &shop, nil } func (s *Service) loadSharedCartDomains(ctx context.Context, shopGroupID int64) ([]string, error) { if shopGroupID == 0 { return nil, nil } type row struct { Domain string `gorm:"column:domain"` } shopURLTable := s.prefix + "shop_url" shopTable := s.prefix + "shop" query := fmt.Sprintf(` SELECT su.domain FROM %s su JOIN %s s ON s.id_shop = su.id_shop WHERE su.main = 1 AND su.active = 1 AND s.id_shop_group = ? `, shopURLTable, shopTable) var rows []row if err := s.db.WithContext(ctx).Raw(query, shopGroupID).Scan(&rows).Error; err != nil { return nil, err } domains := make([]string, 0, len(rows)) for _, row := range rows { if host := normalizeRequestHost(row.Domain); host != "" { domains = append(domains, host) } } return domains, nil } func normalizeRequestHost(input string) string { value := strings.TrimSpace(input) if value == "" { return "" } if strings.Contains(value, "://") { if parsed, err := url.Parse(value); err == nil { value = parsed.Hostname() } } if host, _, err := net.SplitHostPort(value); err == nil { value = host } return strings.ToLower(strings.TrimSpace(value)) } func requestHost(req *http.Request) string { if req == nil { return "" } host := req.Header.Get("X-Forwarded-Host") if host == "" { host = req.Host } if strings.Contains(host, ",") { host = strings.TrimSpace(strings.Split(host, ",")[0]) } return normalizeRequestHost(host) } func requestPath(req *http.Request) string { if req == nil || req.URL == nil { return "/" } path := req.URL.EscapedPath() if path == "" { path = req.URL.Path } if path == "" { return "/" } decoded, err := url.PathUnescape(path) if err == nil && decoded != "" { path = decoded } if !strings.HasPrefix(path, "/") { path = "/" + path } return path } func uriMatchesRequest(uri, requestURI string) bool { if uri == "" { uri = "/" } if requestURI == "" { requestURI = "/" } return strings.HasPrefix(strings.ToLower(requestURI), strings.ToLower(uri)) } var sharedDomainPattern = regexp.MustCompile(`^(?:.*\.)?([^.]*(?:.{2,4})?\..{2,3})$`) func prestashopCookieDomain(host string, sharedURLs []string) string { normalizedHost := normalizeRequestHost(host) if normalizedHost == "" { return "" } if net.ParseIP(normalizedHost) != nil || !strings.Contains(normalizedHost, ".") { return "" } for _, sharedURL := range sharedURLs { if normalizeRequestHost(sharedURL) != normalizedHost { continue } matches := sharedDomainPattern.FindStringSubmatch(normalizedHost) if len(matches) == 2 { return "." + matches[1] } break } return normalizedHost } func (s *Service) tableColumns(ctx context.Context, tableName string) (map[string]bool, error) { type columnRow struct { ColumnName string `gorm:"column:COLUMN_NAME"` } var rows []columnRow query := ` SELECT COLUMN_NAME FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? ` if err := s.db.WithContext(ctx).Raw(query, tableName).Scan(&rows).Error; err != nil { return nil, err } columns := make(map[string]bool, len(rows)) for _, row := range rows { columns[row.ColumnName] = true } return columns, nil } func insertQuery(tableName string, columns []string) string { if len(columns) == 0 { return fmt.Sprintf("INSERT INTO %s () VALUES ()", tableName) } return fmt.Sprintf( "INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), placeholders(len(columns)), ) } func placeholders(n int) string { parts := make([]string, n) for i := range parts { parts[i] = "?" } return strings.Join(parts, ", ") } func referer(req *http.Request) string { if req == nil { return "" } return req.Referer() } func requestURI(req *http.Request) string { if req == nil || req.URL == nil { return "" } return req.URL.RequestURI() } func ipAsUint32(req *http.Request) uint32 { if req == nil { return 0 } raw := req.Header.Get("X-Forwarded-For") if raw == "" { raw = req.RemoteAddr } if strings.Contains(raw, ",") { raw = strings.TrimSpace(strings.Split(raw, ",")[0]) } host := raw if parsedHost, _, err := net.SplitHostPort(raw); err == nil { host = parsedHost } ip := net.ParseIP(host) if ip == nil { return 0 } ip = ip.To4() if ip == nil { return 0 } return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) } func anonymousChecksum(values ...int64) string { buf := make([]byte, 0, len(values)*8) for _, v := range values { buf = strconv.AppendInt(buf, v, 10) buf = append(buf, '|') } return strconv.FormatUint(uint64(crc32.ChecksumIEEE(buf)), 10) } func int64Ptr(value int64) *int64 { if value == 0 { return nil } v := value return &v }