fix cookie -- not working
This commit is contained in:
@@ -2,10 +2,13 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,8 +19,9 @@ import (
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
prefix string
|
||||
db *gorm.DB
|
||||
prefix string
|
||||
version string
|
||||
}
|
||||
|
||||
type defaults struct {
|
||||
@@ -26,10 +30,11 @@ type defaults struct {
|
||||
ShopID int64
|
||||
ShopGroupID int64
|
||||
CountryISO string
|
||||
CookieHours int64
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, prefix string) *Service {
|
||||
return &Service{db: db, prefix: prefix}
|
||||
func NewService(db *gorm.DB, prefix, version string) *Service {
|
||||
return &Service{db: db, prefix: prefix, version: version}
|
||||
}
|
||||
|
||||
func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error) {
|
||||
@@ -88,12 +93,52 @@ func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieNam
|
||||
ShopID: int64Ptr(def.ShopID),
|
||||
GuestID: int64Ptr(guestID),
|
||||
IsLoggedIn: false,
|
||||
ExpiresAt: cookieExpiry(now, def.CookieHours),
|
||||
Values: values,
|
||||
OrderedKeys: orderedKeys,
|
||||
ParseStatus: pscookie.ParseStatusAnonymous,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error {
|
||||
if s == nil || session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
def, err := s.loadDefaults(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.ExpiresAt = cookieExpiry(time.Now().UTC(), def.CookieHours)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ResolveCookieName(ctx context.Context, req *http.Request) (string, error) {
|
||||
if s == nil || s.db == nil {
|
||||
return "", fmt.Errorf("prestashop session service is not initialized")
|
||||
}
|
||||
|
||||
host := requestHost(req)
|
||||
shop, err := s.loadCookieShopContext(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
baseName := "ps-s" + strconv.FormatInt(shop.ShopID, 10)
|
||||
sharedDomains := []string(nil)
|
||||
if shop.ShareOrder {
|
||||
baseName = "ps-sg" + strconv.FormatInt(shop.ShopGroupID, 10)
|
||||
sharedDomains, err = s.loadSharedCartDomains(ctx, shop.ShopGroupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
sum := md5.Sum([]byte(s.version + baseName + prestashopCookieDomain(host, sharedDomains)))
|
||||
return fmt.Sprintf("PrestaShop-%x", sum), nil
|
||||
}
|
||||
|
||||
func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
|
||||
def := &defaults{
|
||||
LanguageID: 1,
|
||||
@@ -101,6 +146,7 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
|
||||
ShopID: 1,
|
||||
ShopGroupID: 1,
|
||||
CountryISO: "US",
|
||||
CookieHours: 480,
|
||||
}
|
||||
|
||||
configTable := s.prefix + "configuration"
|
||||
@@ -111,7 +157,7 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
configQuery := fmt.Sprintf("SELECT name, value FROM %s WHERE name IN ('PS_LANG_DEFAULT', 'PS_CURRENCY_DEFAULT', 'PS_COUNTRY_DEFAULT')", configTable)
|
||||
configQuery := fmt.Sprintf("SELECT name, value FROM %s WHERE name IN ('PS_LANG_DEFAULT', 'PS_CURRENCY_DEFAULT', 'PS_COUNTRY_DEFAULT', 'PS_COOKIE_LIFETIME_FO')", configTable)
|
||||
if err := s.db.WithContext(ctx).Raw(configQuery).Scan(&configs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,6 +177,10 @@ func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) {
|
||||
if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 {
|
||||
countryID = parsed
|
||||
}
|
||||
case "PS_COOKIE_LIFETIME_FO":
|
||||
if parsed, err := strconv.ParseInt(cfg.Value, 10, 64); err == nil && parsed > 0 {
|
||||
def.CookieHours = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,6 +322,195 @@ func (s *Service) connectionInsert(ctx context.Context, def *defaults, guestID i
|
||||
return columns, values, nil
|
||||
}
|
||||
|
||||
func cookieExpiry(now time.Time, lifetimeHours int64) *time.Time {
|
||||
if lifetimeHours <= 0 {
|
||||
return nil
|
||||
}
|
||||
expiresAt := now.Add(time.Duration(lifetimeHours) * time.Hour)
|
||||
return &expiresAt
|
||||
}
|
||||
|
||||
type cookieShopContext struct {
|
||||
ShopID int64 `gorm:"column:id_shop"`
|
||||
ShopGroupID int64 `gorm:"column:id_shop_group"`
|
||||
ShareOrder bool `gorm:"column:share_order"`
|
||||
URI string `gorm:"column:uri"`
|
||||
Main bool `gorm:"column:main"`
|
||||
}
|
||||
|
||||
func (s *Service) loadCookieShopContext(ctx context.Context, req *http.Request) (*cookieShopContext, error) {
|
||||
normalizedHost := requestHost(req)
|
||||
requestURI := requestPath(req)
|
||||
|
||||
shopURLTable := s.prefix + "shop_url"
|
||||
shopTable := s.prefix + "shop"
|
||||
shopGroupTable := s.prefix + "shop_group"
|
||||
|
||||
if normalizedHost != "" {
|
||||
query := fmt.Sprintf(`
|
||||
SELECT s.id_shop, s.id_shop_group, sg.share_order, CONCAT(su.physical_uri, su.virtual_uri) AS uri, su.main
|
||||
FROM %s s
|
||||
JOIN %s sg ON sg.id_shop_group = s.id_shop_group
|
||||
JOIN %s su ON su.id_shop = s.id_shop
|
||||
WHERE su.active = 1
|
||||
AND s.active = 1
|
||||
AND s.deleted = 0
|
||||
AND (LOWER(su.domain) = ? OR LOWER(su.domain_ssl) = ?)
|
||||
ORDER BY LENGTH(CONCAT(su.physical_uri, su.virtual_uri)) DESC, su.main DESC, s.id_shop ASC
|
||||
`, shopTable, shopGroupTable, shopURLTable)
|
||||
var shops []cookieShopContext
|
||||
if err := s.db.WithContext(ctx).Raw(query, normalizedHost, normalizedHost).Scan(&shops).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, shop := range shops {
|
||||
if uriMatchesRequest(shop.URI, requestURI) {
|
||||
return &shop, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallbackQuery := fmt.Sprintf(`
|
||||
SELECT s.id_shop, s.id_shop_group, sg.share_order, '' AS uri, 1 AS main
|
||||
FROM %s s
|
||||
JOIN %s sg ON sg.id_shop_group = s.id_shop_group
|
||||
WHERE s.active = 1
|
||||
AND s.deleted = 0
|
||||
ORDER BY s.id_shop ASC
|
||||
LIMIT 1
|
||||
`, shopTable, shopGroupTable)
|
||||
var shop cookieShopContext
|
||||
if err := s.db.WithContext(ctx).Raw(fallbackQuery).Scan(&shop).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if shop.ShopID == 0 {
|
||||
return nil, fmt.Errorf("prestashop shop context not found")
|
||||
}
|
||||
return &shop, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadSharedCartDomains(ctx context.Context, shopGroupID int64) ([]string, error) {
|
||||
if shopGroupID == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
Domain string `gorm:"column:domain"`
|
||||
}
|
||||
|
||||
shopURLTable := s.prefix + "shop_url"
|
||||
shopTable := s.prefix + "shop"
|
||||
query := fmt.Sprintf(`
|
||||
SELECT su.domain
|
||||
FROM %s su
|
||||
JOIN %s s ON s.id_shop = su.id_shop
|
||||
WHERE su.main = 1
|
||||
AND su.active = 1
|
||||
AND s.id_shop_group = ?
|
||||
`, shopURLTable, shopTable)
|
||||
|
||||
var rows []row
|
||||
if err := s.db.WithContext(ctx).Raw(query, shopGroupID).Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
domains := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
if host := normalizeRequestHost(row.Domain); host != "" {
|
||||
domains = append(domains, host)
|
||||
}
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func normalizeRequestHost(input string) string {
|
||||
value := strings.TrimSpace(input)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(value, "://") {
|
||||
if parsed, err := url.Parse(value); err == nil {
|
||||
value = parsed.Hostname()
|
||||
}
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(value); err == nil {
|
||||
value = host
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func requestHost(req *http.Request) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
if strings.Contains(host, ",") {
|
||||
host = strings.TrimSpace(strings.Split(host, ",")[0])
|
||||
}
|
||||
|
||||
return normalizeRequestHost(host)
|
||||
}
|
||||
|
||||
func requestPath(req *http.Request) string {
|
||||
if req == nil || req.URL == nil {
|
||||
return "/"
|
||||
}
|
||||
path := req.URL.EscapedPath()
|
||||
if path == "" {
|
||||
path = req.URL.Path
|
||||
}
|
||||
if path == "" {
|
||||
return "/"
|
||||
}
|
||||
decoded, err := url.PathUnescape(path)
|
||||
if err == nil && decoded != "" {
|
||||
path = decoded
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func uriMatchesRequest(uri, requestURI string) bool {
|
||||
if uri == "" {
|
||||
uri = "/"
|
||||
}
|
||||
if requestURI == "" {
|
||||
requestURI = "/"
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(requestURI), strings.ToLower(uri))
|
||||
}
|
||||
|
||||
var sharedDomainPattern = regexp.MustCompile(`^(?:.*\.)?([^.]*(?:.{2,4})?\..{2,3})$`)
|
||||
|
||||
func prestashopCookieDomain(host string, sharedURLs []string) string {
|
||||
normalizedHost := normalizeRequestHost(host)
|
||||
if normalizedHost == "" {
|
||||
return ""
|
||||
}
|
||||
if net.ParseIP(normalizedHost) != nil || !strings.Contains(normalizedHost, ".") {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, sharedURL := range sharedURLs {
|
||||
if normalizeRequestHost(sharedURL) != normalizedHost {
|
||||
continue
|
||||
}
|
||||
matches := sharedDomainPattern.FindStringSubmatch(normalizedHost)
|
||||
if len(matches) == 2 {
|
||||
return "." + matches[1]
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return normalizedHost
|
||||
}
|
||||
|
||||
func (s *Service) tableColumns(ctx context.Context, tableName string) (map[string]bool, error) {
|
||||
type columnRow struct {
|
||||
ColumnName string `gorm:"column:COLUMN_NAME"`
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package session
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPrestashopCookieDomain(t *testing.T) {
|
||||
if got := prestashopCookieDomain("localhost", nil); got != "" {
|
||||
t.Fatalf("prestashopCookieDomain(localhost) = %q, want empty", got)
|
||||
}
|
||||
|
||||
if got := prestashopCookieDomain("shop.example.com", []string{"shop.example.com"}); got != ".example.com" {
|
||||
t.Fatalf("prestashopCookieDomain(shared) = %q, want %q", got, ".example.com")
|
||||
}
|
||||
|
||||
if got := prestashopCookieDomain("shop.example.com", nil); got != "shop.example.com" {
|
||||
t.Fatalf("prestashopCookieDomain(single) = %q, want %q", got, "shop.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURIMatchesRequest(t *testing.T) {
|
||||
if !uriMatchesRequest("/shop/fr/", "/shop/fr/product/test") {
|
||||
t.Fatalf("expected nested shop URI to match request path")
|
||||
}
|
||||
|
||||
if uriMatchesRequest("/shop/fr/", "/shop/en/product/test") {
|
||||
t.Fatalf("unexpected match for different shop URI")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user