467 lines
13 KiB
Go
467 lines
13 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"fmt"
	"hash/crc32"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/labstack/echo/v4"

	psconfig "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/config"
	pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie"
)
|
|
|
|
type AnonymousSessionInitializer interface {
|
|
NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error)
|
|
}
|
|
|
|
type SessionExpiryRefresher interface {
|
|
RefreshExpiry(ctx context.Context, session *pscookie.SessionContext) error
|
|
}
|
|
|
|
type SessionCookieNameResolver interface {
|
|
ResolveCookieName(ctx context.Context, req *http.Request) (string, error)
|
|
}
|
|
|
|
type LanguageResolver interface {
|
|
ResolveLanguageID(ctx context.Context, req *http.Request, fallback int64) int64
|
|
}
|
|
|
|
type ProductRouteMatcher interface {
|
|
Owns(path string) bool
|
|
}
|
|
|
|
func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSessionInitializer, languageResolver LanguageResolver, matcher ProductRouteMatcher) echo.MiddlewareFunc {
|
|
ownership := cfg.ParseRouteOwnership()
|
|
expiryRefresher, _ := initializer.(SessionExpiryRefresher)
|
|
cookieNameResolver, _ := initializer.(SessionCookieNameResolver)
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
ownedRoute := ownsProductRoute(ownership.ProductPrefixes, c.Request().URL.Path, matcher)
|
|
configuredCookieName := cfg.DeriveCookieName(requestCookieHost(c.Request()))
|
|
if cookieNameResolver != nil {
|
|
resolvedCookieName, err := cookieNameResolver.ResolveCookieName(c.Request().Context(), c.Request())
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop cookie name resolution failed: %v", err))
|
|
}
|
|
if strings.TrimSpace(resolvedCookieName) != "" {
|
|
configuredCookieName = resolvedCookieName
|
|
}
|
|
}
|
|
cookieName, rawCookie := findPrestaShopCookie(c.Request(), configuredCookieName)
|
|
if cookieName == "" {
|
|
cookieName = configuredCookieName
|
|
}
|
|
|
|
session, err := codec.Decode(rawCookie)
|
|
if err != nil {
|
|
if ownedRoute {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie decode failed")
|
|
}
|
|
SetSession(c, &pscookie.SessionContext{
|
|
CookieName: cookieName,
|
|
RawCookie: rawCookie,
|
|
Values: map[string]string{},
|
|
ParseStatus: pscookie.ParseStatusInvalid,
|
|
})
|
|
return next(c)
|
|
}
|
|
session.CookieName = cookieName
|
|
if ownedRoute && initializer != nil && shouldBootstrapAnonymousSession(rawCookie, session) {
|
|
session, err = initializer.NewAnonymous(c.Request().Context(), c.Request(), cookieName)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session bootstrap failed: %v", err))
|
|
}
|
|
}
|
|
if ownedRoute {
|
|
applyRequestLanguage(session, resolveRequestLanguageID(c.Request().Context(), c.Request(), session, languageResolver))
|
|
applyRequestMarket(session, requestMarketSelection(c.Request()))
|
|
}
|
|
if ownedRoute && shouldSetSessionCookie(rawCookie, session) {
|
|
if expiryRefresher != nil {
|
|
if err := expiryRefresher.RefreshExpiry(c.Request().Context(), session); err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session expiry refresh failed: %v", err))
|
|
}
|
|
}
|
|
encoded, err := codec.Encode(session)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie encode failed")
|
|
}
|
|
session.RawCookie = encoded
|
|
setPrestaShopCookie(c.Request(), c.Response(), session, cookieName, encoded)
|
|
if redirectURL, ok := clearMarketSelectionURL(c.Request()); ok {
|
|
return c.Redirect(http.StatusSeeOther, redirectURL)
|
|
}
|
|
}
|
|
|
|
SetSession(c, session)
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// marketSelection is the parsed form of the "market" query parameter, as
// produced by requestMarketSelection. The zero value means "no valid
// selection".
type marketSelection struct {
	// CountryID is only filled by the three-part "countryID:ISO:currencyID"
	// form; it stays 0 for the two-part "ISO:currencyID" form.
	CountryID int64
	// CountryISO is the upper-cased country code.
	CountryISO string
	// CurrencyID is the selected currency; non-zero in any valid selection.
	CurrencyID int64
}
|
|
|
|
func resolveRequestLanguageID(ctx context.Context, req *http.Request, session *pscookie.SessionContext, resolver LanguageResolver) int64 {
|
|
if resolver == nil {
|
|
return 0
|
|
}
|
|
return resolver.ResolveLanguageID(ctx, req, sessionLanguageID(session))
|
|
}
|
|
|
|
func findPrestaShopCookie(req *http.Request, configuredName string) (name, value string) {
|
|
cookies := req.Cookies()
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == configuredName {
|
|
return cookie.Name, cookie.Value
|
|
}
|
|
}
|
|
|
|
prefix := cookiePrefix(configuredName)
|
|
if prefix == "" {
|
|
return "", ""
|
|
}
|
|
|
|
for _, cookie := range cookies {
|
|
if strings.HasPrefix(cookie.Name, prefix) {
|
|
return cookie.Name, cookie.Value
|
|
}
|
|
}
|
|
|
|
return "", ""
|
|
}
|
|
|
|
// cookiePrefix returns configuredName up to and including its first "-"
// (for example "PrestaShop-" for "PrestaShop-<hash>"). It returns "" when
// the name is empty or contains no "-".
func cookiePrefix(configuredName string) string {
	dash := strings.IndexByte(configuredName, '-')
	if dash < 0 {
		return ""
	}
	return configuredName[:dash+1]
}
|
|
|
|
func shouldBootstrapAnonymousSession(rawCookie string, session *pscookie.SessionContext) bool {
|
|
if session == nil {
|
|
return true
|
|
}
|
|
if rawCookie == "" {
|
|
return true
|
|
}
|
|
if session.IsLoggedIn {
|
|
return false
|
|
}
|
|
return session.GuestID == nil ||
|
|
session.CurrencyID == nil ||
|
|
session.LanguageID == nil ||
|
|
session.Values["id_connections"] == "" ||
|
|
session.Values["iso_code_country"] == ""
|
|
}
|
|
|
|
func shouldSetSessionCookie(rawCookie string, session *pscookie.SessionContext) bool {
|
|
if session == nil {
|
|
return false
|
|
}
|
|
if rawCookie == "" {
|
|
return true
|
|
}
|
|
return rawCookie != session.RawCookie
|
|
}
|
|
|
|
func applyRequestLanguage(session *pscookie.SessionContext, languageID int64) {
|
|
if session == nil || languageID == 0 {
|
|
return
|
|
}
|
|
if current := sessionLanguageID(session); current == languageID {
|
|
return
|
|
}
|
|
|
|
if session.Values == nil {
|
|
session.Values = map[string]string{}
|
|
}
|
|
|
|
value := strconv.FormatInt(languageID, 10)
|
|
session.LanguageID = int64Ptr(languageID)
|
|
session.Values["id_lang"] = value
|
|
session.Values["id_language"] = value
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_lang", 1)
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_language", 3)
|
|
|
|
if !session.IsLoggedIn {
|
|
if checksum := anonymousSessionChecksum(session, languageID); checksum != "" {
|
|
session.Values["checksum"] = checksum
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
|
|
}
|
|
}
|
|
|
|
session.Plaintext = ""
|
|
session.RawCookie = ""
|
|
}
|
|
|
|
func applyRequestMarket(session *pscookie.SessionContext, selection marketSelection) {
|
|
if session == nil || selection.CountryISO == "" || selection.CurrencyID == 0 {
|
|
return
|
|
}
|
|
|
|
currentCountry := ""
|
|
currentCurrency := int64(0)
|
|
currentCountryID := int64(0)
|
|
if session.Values != nil {
|
|
currentCountry = strings.ToUpper(strings.TrimSpace(session.Values["iso_code_country"]))
|
|
if session.CurrencyID != nil {
|
|
currentCurrency = *session.CurrencyID
|
|
}
|
|
currentCountryID, _ = strconv.ParseInt(session.Values["id_country"], 10, 64)
|
|
}
|
|
if currentCountry == selection.CountryISO && currentCurrency == selection.CurrencyID && currentCountryID == selection.CountryID {
|
|
return
|
|
}
|
|
|
|
if session.Values == nil {
|
|
session.Values = map[string]string{}
|
|
}
|
|
|
|
session.CurrencyID = int64Ptr(selection.CurrencyID)
|
|
session.Values["iso_code_country"] = selection.CountryISO
|
|
session.Values["id_currency"] = strconv.FormatInt(selection.CurrencyID, 10)
|
|
delete(session.Values, "id_country")
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "iso_code_country", 4)
|
|
session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "id_country")
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_currency", 5)
|
|
|
|
if !session.IsLoggedIn {
|
|
trimAnonymousCookieValues(session)
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_guest", 6)
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_connections", 7)
|
|
session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "checksum")
|
|
session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys))
|
|
}
|
|
|
|
session.Plaintext = ""
|
|
session.RawCookie = ""
|
|
}
|
|
|
|
func sessionLanguageID(session *pscookie.SessionContext) int64 {
|
|
if session == nil || session.LanguageID == nil {
|
|
return 0
|
|
}
|
|
return *session.LanguageID
|
|
}
|
|
|
|
func anonymousSessionChecksum(session *pscookie.SessionContext, languageID int64) string {
|
|
if session == nil || session.Values == nil {
|
|
return ""
|
|
}
|
|
guestID, _ := strconv.ParseInt(session.Values["id_guest"], 10, 64)
|
|
connectionID, _ := strconv.ParseInt(session.Values["id_connections"], 10, 64)
|
|
currencyID, _ := strconv.ParseInt(session.Values["id_currency"], 10, 64)
|
|
shopID, _ := strconv.ParseInt(session.Values["id_shop"], 10, 64)
|
|
if guestID == 0 || connectionID == 0 || currencyID == 0 {
|
|
return ""
|
|
}
|
|
|
|
buf := make([]byte, 0, 32)
|
|
for _, value := range []int64{guestID, connectionID, languageID, currencyID, shopID} {
|
|
buf = strconv.AppendInt(buf, value, 10)
|
|
buf = append(buf, '|')
|
|
}
|
|
return strconv.FormatUint(uint64(crc32.ChecksumIEEE(buf)), 10)
|
|
}
|
|
|
|
// ensureOrderedKey makes sure key appears in keys, placed at index when
// possible, and returns the updated slice.
//
//   - If key is already at index, or already present and index is past the
//     end, keys is returned unchanged.
//   - Otherwise an existing occurrence is removed and key is inserted at
//     index (clamped to 0), or appended when index is past the end.
//
// The backing array of keys may be reused, so callers must use the result.
func ensureOrderedKey(keys []string, key string, index int) []string {
	pos := -1
	for i := range keys {
		if keys[i] == key {
			pos = i
			break
		}
	}

	if pos >= 0 {
		if pos == index || index >= len(keys) {
			return keys
		}
		keys = append(keys[:pos], keys[pos+1:]...)
	}

	if index < 0 {
		index = 0
	}
	if index >= len(keys) {
		return append(keys, key)
	}

	// Open a slot at index by shifting the tail right by one.
	keys = append(keys, "")
	copy(keys[index+1:], keys[index:])
	keys[index] = key
	return keys
}
|
|
|
|
// removeOrderedKey removes the first occurrence of key from keys and returns
// the result; keys is returned unchanged when key is absent. The backing
// array may be reused, so callers must use the result.
func removeOrderedKey(keys []string, key string) []string {
	for i := 0; i < len(keys); i++ {
		if keys[i] != key {
			continue
		}
		return append(keys[:i], keys[i+1:]...)
	}
	return keys
}
|
|
|
|
func trimAnonymousCookieValues(session *pscookie.SessionContext) {
|
|
if session == nil || session.Values == nil {
|
|
return
|
|
}
|
|
|
|
allowed := map[string]struct{}{
|
|
"date_add": {},
|
|
"id_lang": {},
|
|
"id_language": {},
|
|
"iso_code_country": {},
|
|
"id_currency": {},
|
|
"id_guest": {},
|
|
"id_connections": {},
|
|
"checksum": {},
|
|
}
|
|
|
|
for key := range session.Values {
|
|
if _, ok := allowed[key]; !ok {
|
|
delete(session.Values, key)
|
|
}
|
|
}
|
|
|
|
filtered := make([]string, 0, len(session.OrderedKeys))
|
|
for _, key := range session.OrderedKeys {
|
|
if _, ok := allowed[key]; ok {
|
|
filtered = append(filtered, key)
|
|
}
|
|
}
|
|
session.OrderedKeys = filtered
|
|
}
|
|
|
|
// int64Ptr returns a pointer to a fresh copy of value, or nil when value is
// 0, so that a zero ID is represented as "unset".
func int64Ptr(value int64) *int64 {
	if value != 0 {
		return &value
	}
	return nil
}
|
|
|
|
func requestMarketSelection(req *http.Request) marketSelection {
|
|
if req == nil || req.URL == nil {
|
|
return marketSelection{}
|
|
}
|
|
raw := strings.TrimSpace(req.URL.Query().Get("market"))
|
|
if raw == "" {
|
|
return marketSelection{}
|
|
}
|
|
parts := strings.Split(raw, ":")
|
|
if len(parts) != 2 && len(parts) != 3 {
|
|
return marketSelection{}
|
|
}
|
|
|
|
selection := marketSelection{}
|
|
var countryISO string
|
|
var currencyValue string
|
|
if len(parts) == 3 {
|
|
countryID, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64)
|
|
if err != nil || countryID == 0 {
|
|
return marketSelection{}
|
|
}
|
|
selection.CountryID = countryID
|
|
countryISO = strings.ToUpper(strings.TrimSpace(parts[1]))
|
|
currencyValue = parts[2]
|
|
} else {
|
|
countryISO = strings.ToUpper(strings.TrimSpace(parts[0]))
|
|
currencyValue = parts[1]
|
|
}
|
|
|
|
currencyID, err := strconv.ParseInt(strings.TrimSpace(currencyValue), 10, 64)
|
|
if err != nil || currencyID == 0 {
|
|
return marketSelection{}
|
|
}
|
|
if len(countryISO) < 2 || len(countryISO) > 5 {
|
|
return marketSelection{}
|
|
}
|
|
|
|
selection.CountryISO = countryISO
|
|
selection.CurrencyID = currencyID
|
|
return selection
|
|
}
|
|
|
|
// clearMarketSelectionURL builds the redirect target used after a market
// selection has been stored in the cookie. The target is the request's path
// and query with the "market" parameter removed. ok is false (and no
// redirect is needed) when the request has no non-empty "market" parameter.
//
// The target is always a path-absolute reference on the same host:
//   - The path is emitted in its escaped form, so an encoded "%3F" inside a
//     segment is not turned into a query delimiter.
//   - Leading slash runs are collapsed, so a request path such as
//     "//evil.example/" cannot become a protocol-relative (open) redirect.
func clearMarketSelectionURL(req *http.Request) (string, bool) {
	if req == nil || req.URL == nil {
		return "", false
	}
	query := req.URL.Query()
	if query.Get("market") == "" {
		return "", false
	}
	query.Del("market")

	cleanPath := req.URL.EscapedPath()
	if cleanPath == "" {
		cleanPath = "/"
	}
	if strings.HasPrefix(cleanPath, "//") {
		cleanPath = "/" + strings.TrimLeft(cleanPath, "/")
	}

	if encoded := query.Encode(); encoded != "" {
		return cleanPath + "?" + encoded, true
	}
	return cleanPath, true
}
|
|
|
|
func setPrestaShopCookie(req *http.Request, res *echo.Response, session *pscookie.SessionContext, name, value string) {
|
|
maxAge := 1
|
|
if session != nil && session.ExpiresAt != nil {
|
|
maxAge = int(session.ExpiresAt.UTC().Unix())
|
|
}
|
|
|
|
header := fmt.Sprintf("%s=%s; path=/; max-age=%d; HttpOnly; SameSite=Lax", name, value, maxAge)
|
|
if requestCookieSecure(req) {
|
|
header += "; Secure"
|
|
}
|
|
res.Header().Add(echo.HeaderSetCookie, header)
|
|
}
|
|
|
|
// requestCookieHost returns the host the client addressed, preferring
// X-Forwarded-Host over req.Host. When the value is a comma-separated list
// (as appended by chained proxies), only the first, trimmed entry is kept.
// It returns "" for a nil request.
// NOTE(review): X-Forwarded-Host is trusted as-is; this is only safe behind
// a proxy that sets or strips the header.
func requestCookieHost(req *http.Request) string {
	if req == nil {
		return ""
	}

	host := req.Host
	if forwarded := req.Header.Get("X-Forwarded-Host"); forwarded != "" {
		host = forwarded
	}
	if comma := strings.IndexByte(host, ','); comma >= 0 {
		host = strings.TrimSpace(host[:comma])
	}
	return host
}
|
|
|
|
// requestCookieSecure reports whether the session cookie should carry the
// Secure attribute. That is the case when the request was served over TLS
// directly, or when the first entry of X-Forwarded-Proto is "https"
// (case-insensitive, surrounding whitespace ignored). A nil request is
// treated as insecure, matching the nil handling of requestCookieHost.
func requestCookieSecure(req *http.Request) bool {
	if req == nil {
		return false
	}
	if req.TLS != nil {
		return true
	}

	proto := req.Header.Get("X-Forwarded-Proto")
	if comma := strings.IndexByte(proto, ','); comma >= 0 {
		proto = proto[:comma]
	}
	return strings.EqualFold(strings.TrimSpace(proto), "https")
}
|
|
|
|
func ownsProductRoute(prefixes []string, path string, matcher ProductRouteMatcher) bool {
|
|
if matcher != nil {
|
|
return matcher.Owns(path)
|
|
}
|
|
for _, prefix := range prefixes {
|
|
if prefix != "" && len(path) >= len(prefix) && path[:len(prefix)] == prefix {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|