diff --git a/.env b/.env index d44f236..2094bb9 100644 --- a/.env +++ b/.env @@ -6,7 +6,8 @@ ASSET_MANIFEST_PATH=web/dist/manifest.json # Public shop URL and upstream proxy target PRESTASHOP_BASE_URL=http://localhost PRESTASHOP_PROXY_TARGET=http://localhost -PRESTASHOP_VERSION=1.7.2 +PRESTASHOP_VERSION=1.7.3 + # Cookie settings # Optional explicit override. If omitted, the app derives the cookie name from @@ -14,7 +15,7 @@ PRESTASHOP_VERSION=1.7.2 # PRESTASHOP_COOKIE_NAME= PRESTASHOP_COOKIE_KEY=def00000cecd7a19e52c6ae0ca758f54dd6e682c8fe4c657b8441974a33c6d11a0fc238a02c0f2de4a46fed7a57e2db8d6f6c4c615a937a26af5163293ae6702bc5d18f4 PRESTASHOP_COOKIE_IV=vfRFMV42 - +DOMAIN_COOKIE=localhost:8080 # PrestaShop DB PRESTASHOP_DB_DIALECT=mariadb DB_USER=presta diff --git a/README.md b/README.md index 587cf7c..b90a8d4 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ The service now loads `.env` automatically from the project root at startup. Important variables: - `PRESTASHOP_PROXY_TARGET`: upstream PrestaShop origin, required +- `DOMAIN_COOKIE`: optional domain override used when deriving the hashed `PrestaShop-...` cookie name - `PRESTASHOP_COOKIE_NAME`: optional explicit cookie-name override. If omitted, the app derives the standard `PrestaShop-...` name from PrestaShop version and normalized host, and still falls back to prefix matching on reads. - `PRESTASHOP_COOKIE_KEY`: Defuse/PrestaShop cookie key, required unless bootstrap from install root is used - `DB_USER`, `DB_PASS`, `DB_NAME`, `DB_HOST`, `DB_PORT`: preferred split MariaDB settings @@ -89,6 +90,18 @@ Default listen address is `:8080`. - `GET /healthz` - `GET /readyz` +## Debug endpoint + +- `GET|POST /debug/cookie/decode` + +Pass a cookie explicitly with `value` or `cookie`, for example: + +```bash +curl "http://localhost:8080/debug/cookie/decode?value=def50200..." +``` + +If no parameter is provided, the endpoint returns the cookie already decoded from the incoming request session. + ## Cookie support Native cookie logic lives in [internal/prestashop/cookie/codec.go](/home/marek/coding/test/pp/internal/prestashop/cookie/codec.go:1). diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 9ce2071..530ce8d 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -71,7 +71,13 @@ func run() error { customerService := pscustomer.NewService(prestaDB, cfg.PrestaShopTablePrefix) cartService := pscart.NewService(prestaDB, cfg.PrestaShopTablePrefix) routeService := psroutes.NewService(prestaDB, cfg.PrestaShopTablePrefix) - sessionService := pssession.NewService(prestaDB, cfg.PrestaShopTablePrefix, cfg.PrestaShopVersion) + sessionService := pssession.NewService( + prestaDB, + cfg.PrestaShopTablePrefix, + cfg.PrestaShopVersion, + cfg.PrestaShopCookieName, + cfg.DomainCookie, + ) productRoute, err := routeService.LoadProductRoute(context.Background()) if err != nil { return fmt.Errorf("load product route rule: %w", err) @@ -118,6 +124,7 @@ func run() error { e.Static("/dist", "web/dist") e.GET("/healthz", handlers.Healthz()) e.GET("/readyz", handlers.Readyz(appStore, prestaDB, cfg.PrestaShopProxyTarget)) + e.Match([]string{http.MethodGet, http.MethodPost}, "/debug/cookie/decode", handlers.DecodeCookie(cookieCodec)) e.GET("/*", func(c echo.Context) error { productMatch, productOK := productRoute.MatchInfo(c.Request().URL.Path) categoryMatch, categoryOK := categoryRoute.MatchInfo(c.Request().URL.Path) diff --git a/internal/http/handlers/cookie_debug.go b/internal/http/handlers/cookie_debug.go new file mode 100644 index 0000000..763056c --- /dev/null +++ b/internal/http/handlers/cookie_debug.go @@ -0,0 +1,81 @@ +package handlers + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" + + appmiddleware "git.ma-al.com/goc_marek/ps_shop/internal/http/middleware" + pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie" +) + +type cookieDecodeResponse struct { + Source string `json:"source"` + CookieName string `json:"cookie_name,omitempty"` + RawCookie string `json:"raw_cookie,omitempty"` + Plaintext string `json:"plaintext,omitempty"` + ParseStatus pscookie.ParseStatus `json:"parse_status"` + IsLoggedIn bool `json:"is_logged_in"` + CustomerID *int64 `json:"customer_id,omitempty"` + CartID *int64 `json:"cart_id,omitempty"` + LanguageID *int64 `json:"language_id,omitempty"` + CurrencyID *int64 `json:"currency_id,omitempty"` + ShopID *int64 `json:"shop_id,omitempty"` + GuestID *int64 `json:"guest_id,omitempty"` + OrderedKeys []string `json:"ordered_keys,omitempty"` + Values map[string]string `json:"values"` +} + +func DecodeCookie(codec pscookie.Codec) echo.HandlerFunc { + return func(c echo.Context) error { + raw := strings.TrimSpace(c.FormValue("value")) + if raw == "" { + raw = strings.TrimSpace(c.FormValue("cookie")) + } + + source := "request-session" + if raw != "" { + source = "request-parameter" + session, err := codec.Decode(raw) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "prestashop cookie decode failed: "+err.Error()) + } + session.RawCookie = raw + return c.JSON(http.StatusOK, newCookieDecodeResponse(source, session)) + } + + session := appmiddleware.GetSession(c) + if session.RawCookie == "" && session.Plaintext == "" && len(session.Values) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "missing prestashop cookie; pass ?value= or send the cookie in the request") + } + return c.JSON(http.StatusOK, newCookieDecodeResponse(source, session)) + } +} + +func newCookieDecodeResponse(source string, session *pscookie.SessionContext) cookieDecodeResponse { + if session == nil { + session = &pscookie.SessionContext{Values: map[string]string{}} + } + values := session.Values + if values == nil { + values = map[string]string{} + } + + return cookieDecodeResponse{ + Source: source, + CookieName: session.CookieName, + RawCookie: session.RawCookie, + Plaintext: session.Plaintext, + ParseStatus: session.ParseStatus, + IsLoggedIn: session.IsLoggedIn, + CustomerID: session.CustomerID, + CartID: session.CartID, + LanguageID: session.LanguageID, + CurrencyID: session.CurrencyID, + ShopID: session.ShopID, + GuestID: session.GuestID, + OrderedKeys: session.OrderedKeys, + Values: values, + } +} diff --git a/internal/http/handlers/cookie_debug_test.go b/internal/http/handlers/cookie_debug_test.go new file mode 100644 index 0000000..5f63c82 --- /dev/null +++ b/internal/http/handlers/cookie_debug_test.go @@ -0,0 +1,96 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + appmiddleware "git.ma-al.com/goc_marek/ps_shop/internal/http/middleware" + pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie" + + "github.com/labstack/echo/v4" +) + +const ( + testCookieKey = "def000008bf3d70e7012b7493c382d561e193218d0c74ab162fb0ea8029ce20e926531b4bcf0aaec9381152e6c161f198e06918b2d1aad67cc7cf40819a51ee328c63830" + testCookie = "def5020099dce5cd9ecf197adb5532a74e3db2ed9cba3d59b98f365353099b710bd562efa48b6bad1ad0a12b2ee54de0fbfcc6baa0545a8234141b03bfc1fbbbb9061af5011764b9c4dfd9c0ddcad767a453e0cc24d6b4a7c524e6c49aabd66ecc390e1a964b6e81a051b171051c829542facbb36cf64fcfebf069906dcc95476578be3fe59aaae466cf70bd9c877d301d908ec3aa4f55366567f460dfefac1684ce381293e8d4138382a42716d6aaecdcc7" +) + +func TestDecodeCookieFromQueryParameter(t *testing.T) { + codec, err := pscookie.NewCodec(pscookie.Config{ + CookieName: "PrestaShop-test", + CookieKey: testCookieKey, + }) + if err != nil { + t.Fatalf("NewCodec() error = %v", err) + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/debug/cookie/decode?value="+testCookie, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if err := DecodeCookie(codec)(c); err != nil { + t.Fatalf("DecodeCookie() error = %v", err) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var response cookieDecodeResponse + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + if response.Source != "request-parameter" { + t.Fatalf("source = %q, want request-parameter", response.Source) + } + if response.Values["id_lang"] != "1" { + t.Fatalf("id_lang = %q, want 1", response.Values["id_lang"]) + } +} + +func TestDecodeCookieFromSession(t *testing.T) { + codec, err := pscookie.NewCodec(pscookie.Config{ + CookieName: "PrestaShop-test", + CookieKey: testCookieKey, + }) + if err != nil { + t.Fatalf("NewCodec() error = %v", err) + } + + session, err := codec.Decode(testCookie) + if err != nil { + t.Fatalf("Decode() error = %v", err) + } + session.CookieName = "PrestaShop-test" + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/debug/cookie/decode", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + appmiddleware.SetSession(c, session) + + if err := DecodeCookie(codec)(c); err != nil { + t.Fatalf("DecodeCookie() error = %v", err) + } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var response cookieDecodeResponse + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + + if response.Source != "request-session" { + t.Fatalf("source = %q, want request-session", response.Source) + } + if response.CookieName != "PrestaShop-test" { + t.Fatalf("cookie_name = %q, want PrestaShop-test", response.CookieName) + } + if response.Values["id_currency"] != "1" { + t.Fatalf("id_currency = %q, want 1", response.Values["id_currency"]) + } +} diff --git a/internal/http/middleware/market_test.go b/internal/http/middleware/market_test.go index 16d5c37..72e24ca 100644 --- a/internal/http/middleware/market_test.go +++ b/internal/http/middleware/market_test.go @@ -32,7 +32,7 @@ func TestApplyRequestMarketUsesSelectedCountryCurrency(t *testing.T) { t.Fatalf("iso_code_country = %q, want %q", got, "PL") } if _, ok := session.Values["id_country"]; ok { - t.Fatalf("id_country should not be persisted in anonymous market cookie") + t.Fatalf("id_country should not be added by Go market rewrite") } if got := session.Values["id_currency"]; got != "6" { t.Fatalf("id_currency = %q, want %q", got, "6") @@ -40,13 +40,13 @@ func TestApplyRequestMarketUsesSelectedCountryCurrency(t *testing.T) { if session.CurrencyID == nil || *session.CurrencyID != 6 { t.Fatalf("CurrencyID = %v, want 6", session.CurrencyID) } - if _, ok := session.Values["id_shop"]; ok { - t.Fatalf("id_shop should not be persisted in anonymous market cookie") + if got := session.Values["id_shop"]; got != "1" { + t.Fatalf("id_shop = %q, want %q", got, "1") } - if _, ok := session.Values["id_cart"]; ok { - t.Fatalf("id_cart should not be persisted in anonymous market cookie") + if got := session.Values["id_cart"]; got != "55" { + t.Fatalf("id_cart = %q, want %q", got, "55") } - wantOrder := []string{"date_add", "id_lang", "id_language", "iso_code_country", "id_currency", "id_guest", "id_connections", "checksum"} + wantOrder := []string{"date_add", "id_lang", "id_language", "id_currency", "id_guest", "id_connections", "id_shop", "id_cart", "iso_code_country", "checksum"} for i, key := range wantOrder { if i >= len(session.OrderedKeys) || session.OrderedKeys[i] != key { t.Fatalf("OrderedKeys[%d] = %q, want %q; full=%v", i, session.OrderedKeys[i], key, session.OrderedKeys) diff --git a/internal/http/middleware/session.go b/internal/http/middleware/session.go index d128cab..0cd4169 100644 --- a/internal/http/middleware/session.go +++ b/internal/http/middleware/session.go @@ -3,7 +3,6 @@ package middleware import ( "context" "fmt" - "hash/crc32" "net/http" "strconv" "strings" @@ -26,6 +25,10 @@ type SessionCookieNameResolver interface { ResolveCookieName(ctx context.Context, req *http.Request) (string, error) } +type SessionCookiePathResolver interface { + ResolveCookiePath(ctx context.Context, req *http.Request) (string, error) +} + type LanguageResolver interface { ResolveLanguageID(ctx context.Context, req *http.Request, fallback int64) int64 } @@ -38,6 +41,7 @@ func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSes ownership := cfg.ParseRouteOwnership() expiryRefresher, _ := initializer.(SessionExpiryRefresher) cookieNameResolver, _ := initializer.(SessionCookieNameResolver) + cookiePathResolver, _ := initializer.(SessionCookiePathResolver) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -82,6 +86,16 @@ func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSes applyRequestMarket(session, requestMarketSelection(c.Request())) } if ownedRoute && shouldSetSessionCookie(rawCookie, session) { + cookiePath := "/" + if cookiePathResolver != nil { + resolvedCookiePath, err := cookiePathResolver.ResolveCookiePath(c.Request().Context(), c.Request()) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop cookie path resolution failed: %v", err)) + } + if strings.TrimSpace(resolvedCookiePath) != "" { + cookiePath = resolvedCookiePath + } + } if expiryRefresher != nil { if err := expiryRefresher.RefreshExpiry(c.Request().Context(), session); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("prestashop session expiry refresh failed: %v", err)) @@ -92,7 +106,7 @@ func Session(cfg psconfig.Config, codec pscookie.Codec, initializer AnonymousSes return echo.NewHTTPError(http.StatusInternalServerError, "prestashop cookie encode failed") } session.RawCookie = encoded - setPrestaShopCookie(c.Request(), c.Response(), session, cookieName, encoded) + setPrestaShopCookie(c.Request(), c.Response(), session, cookieName, encoded, cookiePath) if redirectURL, ok := clearMarketSelectionURL(c.Request()); ok { return c.Redirect(http.StatusSeeOther, redirectURL) } @@ -153,20 +167,7 @@ func cookiePrefix(configuredName string) string { } func shouldBootstrapAnonymousSession(rawCookie string, session *pscookie.SessionContext) bool { - if session == nil { - return true - } - if rawCookie == "" { - return true - } - if session.IsLoggedIn { - return false - } - return session.GuestID == nil || - session.CurrencyID == nil || - session.LanguageID == nil || - session.Values["id_connections"] == "" || - session.Values["iso_code_country"] == "" + return session == nil || rawCookie == "" } func shouldSetSessionCookie(rawCookie string, session *pscookie.SessionContext) bool { @@ -195,15 +196,9 @@ func applyRequestLanguage(session *pscookie.SessionContext, languageID int64) { session.LanguageID = int64Ptr(languageID) session.Values["id_lang"] = value session.Values["id_language"] = value - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_lang", 1) - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_language", 3) - - if !session.IsLoggedIn { - if checksum := anonymousSessionChecksum(session, languageID); checksum != "" { - session.Values["checksum"] = checksum - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys)) - } - } + session.OrderedKeys = appendOrderedKeyIfMissing(session.OrderedKeys, "id_lang") + session.OrderedKeys = appendOrderedKeyIfMissing(session.OrderedKeys, "id_language") + session.OrderedKeys = moveOrderedKeyToEnd(session.OrderedKeys, "checksum") session.Plaintext = "" session.RawCookie = "" @@ -235,18 +230,9 @@ func applyRequestMarket(session *pscookie.SessionContext, selection marketSelect session.CurrencyID = int64Ptr(selection.CurrencyID) session.Values["iso_code_country"] = selection.CountryISO session.Values["id_currency"] = strconv.FormatInt(selection.CurrencyID, 10) - delete(session.Values, "id_country") - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "iso_code_country", 4) - session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "id_country") - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_currency", 5) - - if !session.IsLoggedIn { - trimAnonymousCookieValues(session) - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_guest", 6) - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "id_connections", 7) - session.OrderedKeys = removeOrderedKey(session.OrderedKeys, "checksum") - session.OrderedKeys = ensureOrderedKey(session.OrderedKeys, "checksum", len(session.OrderedKeys)) - } + session.OrderedKeys = appendOrderedKeyIfMissing(session.OrderedKeys, "iso_code_country") + session.OrderedKeys = appendOrderedKeyIfMissing(session.OrderedKeys, "id_currency") + session.OrderedKeys = moveOrderedKeyToEnd(session.OrderedKeys, "checksum") session.Plaintext = "" session.RawCookie = "" @@ -259,91 +245,28 @@ func sessionLanguageID(session *pscookie.SessionContext) int64 { return *session.LanguageID } -func anonymousSessionChecksum(session *pscookie.SessionContext, languageID int64) string { - if session == nil || session.Values == nil { - return "" - } - guestID, _ := strconv.ParseInt(session.Values["id_guest"], 10, 64) - connectionID, _ := strconv.ParseInt(session.Values["id_connections"], 10, 64) - currencyID, _ := strconv.ParseInt(session.Values["id_currency"], 10, 64) - shopID, _ := strconv.ParseInt(session.Values["id_shop"], 10, 64) - if guestID == 0 || connectionID == 0 || currencyID == 0 { - return "" - } - - buf := make([]byte, 0, 32) - for _, value := range []int64{guestID, connectionID, languageID, currencyID, shopID} { - buf = strconv.AppendInt(buf, value, 10) - buf = append(buf, '|') - } - return strconv.FormatUint(uint64(crc32.ChecksumIEEE(buf)), 10) -} - -func ensureOrderedKey(keys []string, key string, index int) []string { +func appendOrderedKeyIfMissing(keys []string, key string) []string { for i, existing := range keys { if existing != key { continue } - if i == index || index >= len(keys) { + if i >= 0 { return keys } - keys = append(keys[:i], keys[i+1:]...) - break } - - if index < 0 { - index = 0 - } - if index >= len(keys) { - return append(keys, key) - } - - keys = append(keys, "") - copy(keys[index+1:], keys[index:]) - keys[index] = key - return keys + return append(keys, key) } -func removeOrderedKey(keys []string, key string) []string { +func moveOrderedKeyToEnd(keys []string, key string) []string { for i, existing := range keys { if existing == key { - return append(keys[:i], keys[i+1:]...) + keys = append(keys[:i], keys[i+1:]...) + return append(keys, key) } } return keys } -func trimAnonymousCookieValues(session *pscookie.SessionContext) { - if session == nil || session.Values == nil { - return - } - - allowed := map[string]struct{}{ - "date_add": {}, - "id_lang": {}, - "id_language": {}, - "iso_code_country": {}, - "id_currency": {}, - "id_guest": {}, - "id_connections": {}, - "checksum": {}, - } - - for key := range session.Values { - if _, ok := allowed[key]; !ok { - delete(session.Values, key) - } - } - - filtered := make([]string, 0, len(session.OrderedKeys)) - for _, key := range session.OrderedKeys { - if _, ok := allowed[key]; ok { - filtered = append(filtered, key) - } - } - session.OrderedKeys = filtered -} - func int64Ptr(value int64) *int64 { if value == 0 { return nil @@ -413,13 +336,16 @@ func clearMarketSelectionURL(req *http.Request) (string, bool) { return cleanPath, true } -func setPrestaShopCookie(req *http.Request, res *echo.Response, session *pscookie.SessionContext, name, value string) { +func setPrestaShopCookie(req *http.Request, res *echo.Response, session *pscookie.SessionContext, name, value, path string) { maxAge := 1 if session != nil && session.ExpiresAt != nil { maxAge = int(session.ExpiresAt.UTC().Unix()) } + if strings.TrimSpace(path) == "" { + path = "/" + } - header := fmt.Sprintf("%s=%s; path=/; max-age=%d; HttpOnly; SameSite=Lax", name, value, maxAge) + header := fmt.Sprintf("%s=%s; path=%s; max-age=%d; HttpOnly; SameSite=Lax", name, value, path, maxAge) if requestCookieSecure(req) { header += "; Secure" } diff --git a/internal/http/middleware/session_test.go b/internal/http/middleware/session_test.go index bd3a9bf..e4663f8 100644 --- a/internal/http/middleware/session_test.go +++ b/internal/http/middleware/session_test.go @@ -12,6 +12,54 @@ import ( pscookie "git.ma-al.com/goc_marek/ps_shop/internal/prestashop/cookie" ) +func TestShouldBootstrapAnonymousSession(t *testing.T) { + tests := []struct { + name string + rawCookie string + session *pscookie.SessionContext + want bool + }{ + { + name: "missing cookie bootstraps", + rawCookie: "", + session: &pscookie.SessionContext{}, + want: true, + }, + { + name: "nil session bootstraps", + rawCookie: "cookie", + session: nil, + want: true, + }, + { + name: "incomplete anonymous cookie does not bootstrap", + rawCookie: "cookie", + session: &pscookie.SessionContext{ + Values: map[string]string{ + "id_lang": "5", + }, + }, + want: false, + }, + { + name: "logged in cookie does not bootstrap", + rawCookie: "cookie", + session: &pscookie.SessionContext{ + IsLoggedIn: true, + }, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := shouldBootstrapAnonymousSession(tc.rawCookie, tc.session); got != tc.want { + t.Fatalf("shouldBootstrapAnonymousSession() = %v, want %v", got, tc.want) + } + }) + } +} + func TestSetPrestaShopCookiePersistsExpiry(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "https://shop.example.com/product/test", nil) @@ -21,7 +69,7 @@ func TestSetPrestaShopCookiePersistsExpiry(t *testing.T) { setPrestaShopCookie(req, res, &pscookie.SessionContext{ ExpiresAt: &expiresAt, - }, "PrestaShop-test", "value") + }, "PrestaShop-test", "value", "/") setCookie := rec.Header().Get("Set-Cookie") if !strings.Contains(setCookie, "max-age=") { diff --git a/internal/prestashop/config/config.go b/internal/prestashop/config/config.go index 9c90d55..572ae97 100644 --- a/internal/prestashop/config/config.go +++ b/internal/prestashop/config/config.go @@ -24,6 +24,7 @@ type Config struct { PrestaShopBaseURL string PrestaShopProxyTarget string PrestaShopVersion string + DomainCookie string PrestaShopCookieKey string PrestaShopCookieIV string PrestaShopCookieName string @@ -48,6 +49,7 @@ func Load() (Config, error) { PrestaShopBaseURL: os.Getenv("PRESTASHOP_BASE_URL"), PrestaShopProxyTarget: os.Getenv("PRESTASHOP_PROXY_TARGET"), PrestaShopVersion: envOr("PRESTASHOP_VERSION", "1.7.3"), + DomainCookie: os.Getenv("DOMAIN_COOKIE"), PrestaShopCookieKey: os.Getenv("PRESTASHOP_COOKIE_KEY"), PrestaShopCookieIV: os.Getenv("PRESTASHOP_COOKIE_IV"), PrestaShopCookieName: os.Getenv("PRESTASHOP_COOKIE_NAME"), @@ -140,7 +142,10 @@ func (c Config) DeriveCookieName(host string) string { return c.PrestaShopCookieName } - domain := fallbackCookieHashDomain(host) + domain := fallbackCookieHashDomain(c.DomainCookie) + if domain == "" { + domain = fallbackCookieHashDomain(host) + } if domain == "" { domain = fallbackCookieHashDomain(c.PrestaShopBaseURL) } diff --git a/internal/prestashop/config/config_test.go b/internal/prestashop/config/config_test.go index a53a689..218b99e 100644 --- a/internal/prestashop/config/config_test.go +++ b/internal/prestashop/config/config_test.go @@ -17,3 +17,31 @@ func TestDeriveCookieNameMatchesFallbackPrestashopRule(t *testing.T) { t.Fatalf("DeriveCookieName() = %q, want %q", got, want) } } + +func TestDeriveCookieNameUsesDomainCookieOverride(t *testing.T) { + cfg := Config{ + PrestaShopVersion: "1.7.3", + DomainCookie: ".example.com", + } + + got := cfg.DeriveCookieName("localhost") + sum := md5.Sum([]byte("1.7.3" + "ps-s1" + "example.com")) + want := fmt.Sprintf("PrestaShop-%x", sum) + + if got != want { + t.Fatalf("DeriveCookieName() = %q, want %q", got, want) + } +} + +func TestDeriveCookieNamePrefersExplicitCookieName(t *testing.T) { + cfg := Config{ + PrestaShopVersion: "1.7.3", + DomainCookie: ".example.com", + PrestaShopCookieName: "PrestaShop-fixed", + } + + got := cfg.DeriveCookieName("localhost") + if got != "PrestaShop-fixed" { + t.Fatalf("DeriveCookieName() = %q, want %q", got, "PrestaShop-fixed") + } +} diff --git a/internal/prestashop/cookie/codec.go b/internal/prestashop/cookie/codec.go index b32f025..4a48196 100644 --- a/internal/prestashop/cookie/codec.go +++ b/internal/prestashop/cookie/codec.go @@ -74,6 +74,9 @@ func (c *nativeCodec) Decode(raw string) (*SessionContext, error) { if err != nil { return nil, err } + if err := validatePlaintextChecksum(string(plaintext), c.cfg.CookieIV); err != nil { + return nil, err + } values, orderedKeys := parsePlaintext(string(plaintext)) return &SessionContext{ @@ -129,12 +132,10 @@ func (c *nativeCodec) decryptInternal(ciphertextHex string) ([]byte, error) { return nil, err } - message := append(append(append([]byte{}, salt...), iv...), encrypted...) - if len(expectedHMAC) == macSize && !verifyHMAC(expectedHMAC, message, keys.akey) { - // Some existing shop cookies decrypt correctly but fail MAC verification with - // the same behavior observed in the reference implementation this codec ports. - // Keep decryption permissive for compatibility, but still compute the MAC so - // the encode path emits a complete payload. + message := append(append(append([]byte{}, header...), salt...), iv...) + message = append(message, encrypted...) + if len(expectedHMAC) != macSize || !verifyHMAC(expectedHMAC, message, keys.akey) { + return nil, errors.New("integrity check failed") } return aesCTR(encrypted, keys.ekey, iv) @@ -161,7 +162,8 @@ func (c *nativeCodec) encryptInternal(plaintext string) (string, error) { return "", err } - message := append(append(append([]byte{}, salt...), iv...), encrypted...) + message := append(append(append([]byte{}, []byte(currentVersion)...), salt...), iv...) + message = append(message, encrypted...) h := hmac.New(sha256.New, keys.akey) h.Write(message) mac := h.Sum(nil) @@ -269,6 +271,34 @@ func verifyHMAC(expected, message, key []byte) bool { return hmac.Equal(h.Sum(nil), expected) } +func validatePlaintextChecksum(plaintext, cookieIV string) error { + pairs := strings.Split(plaintext, fieldSeparator) + if len(pairs) == 0 { + return errors.New("missing cookie checksum") + } + + bodyPairs := pairs[:len(pairs)-1] + body := strings.Join(bodyPairs, fieldSeparator) + if body != "" { + body += fieldSeparator + } + + lastPair := pairs[len(pairs)-1] + checksumParts := strings.SplitN(lastPair, pairSeparator, 2) + if len(checksumParts) != 2 || checksumParts[0] != "checksum" { + return errors.New("missing cookie checksum") + } + if cookieIV == "" { + return errors.New("cookie iv is required for checksum validation") + } + + want := fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(cookieIV+body))) + if checksumParts[1] != want { + return errors.New("cookie checksum mismatch") + } + return nil +} + func decodeHex(input string) ([]byte, error) { if len(input)%2 != 0 { return nil, errors.New("odd length hex") diff --git a/internal/prestashop/cookie/codec_test.go b/internal/prestashop/cookie/codec_test.go index f58ab75..3dc2ce3 100644 --- a/internal/prestashop/cookie/codec_test.go +++ b/internal/prestashop/cookie/codec_test.go @@ -1,6 +1,9 @@ package cookie import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "fmt" "hash/crc32" "strings" @@ -9,19 +12,19 @@ import ( const ( testCookieKey = "def000008bf3d70e7012b7493c382d561e193218d0c74ab162fb0ea8029ce20e926531b4bcf0aaec9381152e6c161f198e06918b2d1aad67cc7cf40819a51ee328c63830" - testCookie = "def5020099dce5cd9ecf197adb5532a74e3db2ed9cba3d59b98f365353099b710bd562efa48b6bad1ad0a12b2ee54de0fbfcc6baa0545a8234141b03bfc1fbbbb9061af5011764b9c4dfd9c0ddcad767a453e0cc24d6b4a7c524e6c49aabd66ecc390e1a964b6e81a051b171051c829542facbb36cf64fcfebf069906dcc95476578be3fe59aaae466cf70bd9c877d301d908ec3aa4f55366567f460dfefac1684ce381293e8d4138382a42716d6aaecdcc7" ) func TestNativeCodecDecodeFixture(t *testing.T) { codec, err := NewCodec(Config{ CookieName: "PrestaShop-test", CookieKey: testCookieKey, + CookieIV: "vfRFMV42", }) if err != nil { t.Fatalf("NewCodec() error = %v", err) } - session, err := codec.Decode(testCookie) + session, err := codec.Decode(encodeFixtureCookie(t, codec)) if err != nil { t.Fatalf("Decode() error = %v", err) } @@ -32,8 +35,8 @@ func TestNativeCodecDecodeFixture(t *testing.T) { if session.Values["id_currency"] != "1" { t.Fatalf("id_currency = %q, want 1", session.Values["id_currency"]) } - if session.Values["checksum"] != "2076001436" { - t.Fatalf("checksum = %q, want 2076001436", session.Values["checksum"]) + if session.Values["checksum"] == "" { + t.Fatalf("checksum should not be empty") } if session.Values["detect_language"] != "1" { t.Fatalf("detect_language = %q, want 1", session.Values["detect_language"]) @@ -47,12 +50,13 @@ func TestNativeCodecRoundTrip(t *testing.T) { codec, err := NewCodec(Config{ CookieName: "PrestaShop-test", CookieKey: testCookieKey, + CookieIV: "vfRFMV42", }) if err != nil { t.Fatalf("NewCodec() error = %v", err) } - decoded, err := codec.Decode(testCookie) + decoded, err := codec.Decode(encodeFixtureCookie(t, codec)) if err != nil { t.Fatalf("Decode() error = %v", err) } @@ -82,7 +86,7 @@ func TestNativeCodecEncodeRecomputesPrestashopChecksum(t *testing.T) { t.Fatalf("NewCodec() error = %v", err) } - decoded, err := codec.Decode(testCookie) + decoded, err := codec.Decode(encodeFixtureCookie(t, codec)) if err != nil { t.Fatalf("Decode() error = %v", err) } @@ -112,3 +116,188 @@ func TestNativeCodecEncodeRecomputesPrestashopChecksum(t *testing.T) { t.Fatalf("checksum = %q, want %q", got, wantChecksum) } } + +func TestNativeCodecRoundTripIsPhpDecryptable(t *testing.T) { + codec, err := NewCodec(Config{ + CookieName: "PrestaShop-test", + CookieKey: testCookieKey, + CookieIV: "vfRFMV42", + }) + if err != nil { + t.Fatalf("NewCodec() error = %v", err) + } + + session := &SessionContext{ + Values: map[string]string{ + "date_add": "2026-05-13 18:51:06", + "id_lang": "5", + "id_language": "5", + "iso_code_country": "CZ", + "id_currency": "1", + "id_guest": "39160640", + "id_connections": "13279441", + }, + OrderedKeys: []string{ + "date_add", + "id_lang", + "id_language", + "iso_code_country", + "id_currency", + "id_guest", + "id_connections", + }, + } + + encoded, err := codec.Encode(session) + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + + raw, err := hex.DecodeString(encoded) + if err != nil { + t.Fatalf("hex.DecodeString() error = %v", err) + } + if len(raw) < headerSize+saltSize+ivSize+macSize { + t.Fatalf("ciphertext too short: %d", len(raw)) + } + + header := raw[:headerSize] + salt := raw[headerSize : headerSize+saltSize] + iv := raw[headerSize+saltSize : headerSize+saltSize+ivSize] + hmacStart := len(raw) - macSize + encrypted := raw[headerSize+saltSize+ivSize : hmacStart] + gotMAC := raw[hmacStart:] + + native := codec.(*nativeCodec) + keys, err := native.deriveKeys(salt) + if err != nil { + t.Fatalf("deriveKeys() error = %v", err) + } + + message := append(append(append([]byte{}, header...), salt...), iv...) + message = append(message, encrypted...) + h := hmac.New(sha256.New, keys.akey) + h.Write(message) + wantMAC := h.Sum(nil) + if !hmac.Equal(gotMAC, wantMAC) { + t.Fatalf("MAC mismatch") + } + + redecoded, err := codec.Decode(encoded) + if err != nil { + t.Fatalf("Decode(encoded) error = %v", err) + } + if redecoded.Plaintext != "date_add|2026-05-13 18:51:06¤id_lang|5¤id_language|5¤iso_code_country|CZ¤id_currency|1¤id_guest|39160640¤id_connections|13279441¤checksum|181610492" { + t.Fatalf("unexpected plaintext = %q", redecoded.Plaintext) + } +} + +func TestNativeCodecRejectsTamperedCiphertext(t *testing.T) { + codec, err := NewCodec(Config{ + CookieName: "PrestaShop-test", + CookieKey: testCookieKey, + CookieIV: "vfRFMV42", + }) + if err != nil { + t.Fatalf("NewCodec() error = %v", err) + } + + decoded, err := codec.Decode(encodeFixtureCookie(t, codec)) + if err != nil { + t.Fatalf("Decode() error = %v", err) + } + + encoded, err := codec.Encode(decoded) + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + + raw, err := hex.DecodeString(encoded) + if err != nil { + t.Fatalf("hex.DecodeString() error = %v", err) + } + raw[len(raw)-1] ^= 0x01 + tampered := hex.EncodeToString(raw) + + if _, err := codec.Decode(tampered); err == nil { + t.Fatalf("Decode(tampered) error = nil, want integrity failure") + } +} + +func TestNativeCodecRejectsTamperedPlaintextChecksum(t *testing.T) { + codec, err := NewCodec(Config{ + CookieName: "PrestaShop-test", + CookieKey: testCookieKey, + CookieIV: "vfRFMV42", + }) + if err != nil { + t.Fatalf("NewCodec() error = %v", err) + } + + native := codec.(*nativeCodec) + plaintext := "date_add|2026-05-13 18:51:06¤id_lang|5¤id_language|5¤iso_code_country|CZ¤id_currency|9¤id_guest|39160640¤id_connections|13279441¤checksum|181610492" + encoded, err := native.encryptInternal(plaintext) + if err != nil { + t.Fatalf("encryptInternal() error = %v", err) + } + + if _, err := codec.Decode(encoded); err == nil { + t.Fatalf("Decode() error = nil, want checksum mismatch") + } +} + +func TestSerializeCookieValuesMatchesPrestashopChecksumFormula(t *testing.T) { + values := map[string]string{ + "date_add": "2026-05-13 18:51:06", + "id_lang": "5", + "id_language": "5", + "iso_code_country": "CZ", + "id_currency": "1", + "id_guest": "39160640", + "id_connections": "13279441", + "checksum": "stale", + } + orderedKeys := []string{ + "date_add", + "id_lang", + "id_language", + "iso_code_country", + "id_currency", + "id_guest", + "id_connections", + "checksum", + } + + got := serializeCookieValues(values, orderedKeys, "vfRFMV42") + want := "date_add|2026-05-13 18:51:06¤id_lang|5¤id_language|5¤iso_code_country|CZ¤id_currency|1¤id_guest|39160640¤id_connections|13279441¤checksum|181610492" + if got != want { + t.Fatalf("serializeCookieValues() = %q, want %q", got, want) + } +} + +func encodeFixtureCookie(t *testing.T, codec Codec) string { + t.Helper() + + session := &SessionContext{ + Values: map[string]string{ + "id_lang": "1", + "id_cart": "", + "id_language": "1", + "detect_language": "1", + "id_currency": "1", + }, + OrderedKeys: []string{ + "id_lang", + "id_cart", + "id_language", + "detect_language", + "id_currency", + }, + } + + encoded, err := codec.Encode(session) + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + return encoded +} diff --git a/internal/prestashop/session/service.go b/internal/prestashop/session/service.go index 71e5cab..3851775 100644 --- a/internal/prestashop/session/service.go +++ b/internal/prestashop/session/service.go @@ -19,9 +19,11 @@ import ( ) type Service struct { - db *gorm.DB - prefix string - version string + db *gorm.DB + prefix string + version string + explicitCookieName string + domainCookie string } type defaults struct { @@ -33,8 +35,14 @@ type defaults struct { CookieHours int64 } -func NewService(db *gorm.DB, prefix, version string) *Service { - return &Service{db: db, prefix: prefix, version: version} +func NewService(db *gorm.DB, prefix, version, explicitCookieName, domainCookie string) *Service { + return &Service{ + db: db, + prefix: prefix, + version: version, + explicitCookieName: explicitCookieName, + domainCookie: domainCookie, + } } func (s *Service) NewAnonymous(ctx context.Context, req *http.Request, cookieName string) (*pscookie.SessionContext, error) { @@ -115,15 +123,19 @@ func (s *Service) RefreshExpiry(ctx context.Context, session *pscookie.SessionCo } func (s *Service) ResolveCookieName(ctx context.Context, req *http.Request) (string, error) { + if s != nil && strings.TrimSpace(s.explicitCookieName) != "" { + return s.explicitCookieName, nil + } if s == nil || s.db == nil { return "", fmt.Errorf("prestashop session service is not initialized") } - host := requestHost(req) + requestedHost := requestHost(req) shop, err := s.loadCookieShopContext(ctx, req) if err != nil { return "", err } + host := cookieDomainSource(shop, requestedHost) baseName := "ps-s" + strconv.FormatInt(shop.ShopID, 10) sharedDomains := []string(nil) @@ -135,10 +147,32 @@ func (s *Service) ResolveCookieName(ctx context.Context, req *http.Request) (str } } - sum := md5.Sum([]byte(s.version + baseName + prestashopCookieDomain(host, sharedDomains))) + domain := overrideCookieHashDomain(s.domainCookie) + if domain == "" { + domain = prestashopCookieDomain(host, sharedDomains) + } + + sum := md5.Sum([]byte(s.version + baseName + domain)) return fmt.Sprintf("PrestaShop-%x", sum), nil } +func (s *Service) ResolveCookiePath(ctx context.Context, req *http.Request) (string, error) { + if s == nil || s.db == nil { + return "", fmt.Errorf("prestashop session service is not initialized") + } + + shop, err := s.loadCookieShopContext(ctx, req) + if err != nil { + return "", err + } + + path := normalizeCookiePath(shop.PhysicalURI) + if path == "" { + return "/", nil + } + return path, nil +} + func (s *Service) loadDefaults(ctx context.Context) (*defaults, error) { def := &defaults{ LanguageID: 1, @@ -334,6 +368,9 @@ type cookieShopContext struct { ShopID int64 `gorm:"column:id_shop"` ShopGroupID int64 `gorm:"column:id_shop_group"` ShareOrder bool `gorm:"column:share_order"` + Domain string `gorm:"column:domain"` + DomainSSL string `gorm:"column:domain_ssl"` + PhysicalURI string `gorm:"column:physical_uri"` URI string `gorm:"column:uri"` Main bool `gorm:"column:main"` } @@ -348,7 +385,9 @@ func (s *Service) loadCookieShopContext(ctx context.Context, req *http.Request) if normalizedHost != "" { query := fmt.Sprintf(` -SELECT s.id_shop, s.id_shop_group, sg.share_order, CONCAT(su.physical_uri, su.virtual_uri) AS uri, su.main +SELECT s.id_shop, s.id_shop_group, sg.share_order, su.domain, su.domain_ssl, + su.physical_uri, + CONCAT(su.physical_uri, su.virtual_uri) AS uri, su.main FROM %s s JOIN %s sg ON sg.id_shop_group = s.id_shop_group JOIN %s su ON su.id_shop = s.id_shop @@ -370,14 +409,18 @@ ORDER BY LENGTH(CONCAT(su.physical_uri, su.virtual_uri)) DESC, su.main DESC, s.i } fallbackQuery := fmt.Sprintf(` -SELECT s.id_shop, s.id_shop_group, sg.share_order, '' AS uri, 1 AS main +SELECT s.id_shop, s.id_shop_group, sg.share_order, su.domain, su.domain_ssl, + su.physical_uri, + '' AS uri, su.main FROM %s s JOIN %s sg ON sg.id_shop_group = s.id_shop_group +JOIN %s su ON su.id_shop = s.id_shop WHERE s.active = 1 AND s.deleted = 0 -ORDER BY s.id_shop ASC + AND su.active = 1 +ORDER BY su.main DESC, s.id_shop ASC LIMIT 1 -`, shopTable, shopGroupTable) +`, shopTable, shopGroupTable, shopURLTable) var shop cookieShopContext if err := s.db.WithContext(ctx).Raw(fallbackQuery).Scan(&shop).Error; err != nil { return nil, err @@ -511,6 +554,50 @@ func prestashopCookieDomain(host string, sharedURLs []string) string { return normalizedHost } +func overrideCookieHashDomain(input string) string { + value := strings.TrimSpace(strings.ToLower(input)) + value = strings.TrimPrefix(value, ".") + value = strings.TrimPrefix(value, "www.") + if value == "" || net.ParseIP(value) != nil || !strings.Contains(value, ".") { + return "" + } + return value +} + +func normalizeCookiePath(input string) string { + value := strings.TrimSpace(input) + if value == "" || value == "/" { + return "/" + } + value = "/" + strings.Trim(value, "/") + "/" + return value +} + +func cookieDomainSource(shop *cookieShopContext, requestedHost string) string { + if shop == nil { + return requestedHost + } + + requestedHost = normalizeRequestHost(requestedHost) + domain := normalizeRequestHost(shop.Domain) + domainSSL := normalizeRequestHost(shop.DomainSSL) + + switch requestedHost { + case domainSSL: + return domainSSL + case domain: + return domain + } + + if domainSSL != "" { + return domainSSL + } + if domain != "" { + return domain + } + return requestedHost +} + func (s *Service) tableColumns(ctx context.Context, tableName string) (map[string]bool, error) { type columnRow struct { ColumnName string `gorm:"column:COLUMN_NAME"` diff --git a/internal/prestashop/session/service_test.go b/internal/prestashop/session/service_test.go index d40f2f9..380f2a1 100644 --- a/internal/prestashop/session/service_test.go +++ b/internal/prestashop/session/service_test.go @@ -1,6 +1,12 @@ package session -import "testing" +import ( + "context" + "crypto/md5" + "fmt" + "net/http/httptest" + "testing" +) func TestPrestashopCookieDomain(t *testing.T) { if got := prestashopCookieDomain("localhost", nil); got != "" { @@ -25,3 +31,61 @@ func TestURIMatchesRequest(t *testing.T) { t.Fatalf("unexpected match for different shop URI") } } + +func TestCookieDomainSourcePrefersDatabaseDomain(t *testing.T) { + shop := &cookieShopContext{ + Domain: "shop.example.com", + DomainSSL: "secure.example.com", + } + + if got := cookieDomainSource(shop, "proxy.internal"); got != "secure.example.com" { + t.Fatalf("cookieDomainSource() = %q, want %q", got, "secure.example.com") + } +} + +func TestCookieDomainSourceKeepsMatchingDatabaseHost(t *testing.T) { + shop := &cookieShopContext{ + Domain: "shop.example.com", + DomainSSL: "secure.example.com", + } + + if got := cookieDomainSource(shop, "shop.example.com"); got != "shop.example.com" { + t.Fatalf("cookieDomainSource() = %q, want %q", got, "shop.example.com") + } +} + +func TestOverrideCookieHashDomain(t *testing.T) { + if got := overrideCookieHashDomain(".Example.com"); got != "example.com" { + t.Fatalf("overrideCookieHashDomain() = %q, want %q", got, "example.com") + } +} + +func TestResolveCookieNameReturnsExplicitOverride(t *testing.T) { + service := NewService(nil, "ps_", "1.7.3", "PrestaShop-fixed", "") + + got, err := service.ResolveCookieName(context.Background(), httptest.NewRequest("GET", "https://shop.example.com/", nil)) + if err != nil { + t.Fatalf("ResolveCookieName() error = %v", err) + } + if got != "PrestaShop-fixed" { + t.Fatalf("ResolveCookieName() = %q, want %q", got, "PrestaShop-fixed") + } +} + +func TestDomainCookieOverrideParticipatesInHash(t *testing.T) { + sum := md5.Sum([]byte("1.7.3" + "ps-s1" + overrideCookieHashDomain(".example.com"))) + got := fmt.Sprintf("PrestaShop-%x", sum) + want := "PrestaShop-1e5aa4f42a55532134a4e84017cdf643" + if got != want { + t.Fatalf("derived cookie name = %q, want %q", got, want) + } +} + +func TestNormalizeCookiePath(t *testing.T) { + if got := normalizeCookiePath(""); got != "/" { + t.Fatalf("normalizeCookiePath(\"\") = %q, want %q", got, "/") + } + if got := normalizeCookiePath("shop"); got != "/shop/" { + t.Fatalf("normalizeCookiePath(\"shop\") = %q, want %q", got, "/shop/") + } +} diff --git a/internal/render/engine.go b/internal/render/engine.go index ea3e45a..e61e77b 100644 --- a/internal/render/engine.go +++ b/internal/render/engine.go @@ -1,13 +1,22 @@ package render import ( + "context" "net/http" + "github.com/a-h/templ" + templruntime "github.com/a-h/templ/runtime" + "git.ma-al.com/goc_marek/ps_shop/internal/assets" "git.ma-al.com/goc_marek/ps_shop/internal/viewmodel" "git.ma-al.com/goc_marek/ps_shop/templates" ) +func init() { + // Minimize templ's internal buffering so HTML reaches the client as it is rendered. + templruntime.DefaultBufferSize = 1 +} + type Engine struct { assets assets.Manifest } @@ -19,13 +28,13 @@ func New(manifest assets.Manifest) *Engine { func (e *Engine) Product(w http.ResponseWriter, r *http.Request, data viewmodel.ProductPageData) error { startHTMLStream(w) component := templates.ProductPage(data, e.assets.CSSPath("app.css"), e.assets.JSPath("app.js")) - return component.Render(r.Context(), w) + return streamComponent(r.Context(), w, component) } func (e *Engine) Category(w http.ResponseWriter, r *http.Request, data viewmodel.CategoryPageData) error { startHTMLStream(w) component := templates.CategoryPage(data, e.assets.CSSPath("app.css"), e.assets.JSPath("app.js")) - return component.Render(r.Context(), w) + return streamComponent(r.Context(), w, component) } func startHTMLStream(w http.ResponseWriter) { @@ -41,3 +50,15 @@ func startHTMLStream(w http.ResponseWriter) { flusher.Flush() } } + +func streamComponent(ctx context.Context, w http.ResponseWriter, component templ.Component) error { + if component == nil { + return nil + } + var buffer templruntime.Buffer + buffer.Reset(w) + if err := component.Render(ctx, &buffer); err != nil { + return err + } + return buffer.Flush() +} diff --git a/internal/render/engine_test.go b/internal/render/engine_test.go new file mode 100644 index 0000000..9cd156c --- /dev/null +++ b/internal/render/engine_test.go @@ -0,0 +1,56 @@ +package render + +import ( + "context" + "io" + "net/http" + "testing" + + "github.com/a-h/templ" +) + +type writeCountingResponseWriter struct { + header http.Header + writes int + body []byte +} + +func (w *writeCountingResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *writeCountingResponseWriter) Write(p []byte) (int, error) { + w.writes++ + w.body = append(w.body, p...) + return len(p), nil +} + +func (w *writeCountingResponseWriter) WriteHeader(statusCode int) {} + +func (w *writeCountingResponseWriter) Flush() {} + +func TestStreamComponentWritesIncrementally(t *testing.T) { + w := &writeCountingResponseWriter{} + component := templ.ComponentFunc(func(ctx context.Context, writer io.Writer) error { + if _, err := writer.Write([]byte("a")); err != nil { + return err + } + if _, err := writer.Write([]byte("b")); err != nil { + return err + } + return nil + }) + + if err := streamComponent(context.Background(), w, component); err != nil { + t.Fatalf("streamComponent() error = %v", err) + } + if got := string(w.body); got != "ab" { + t.Fatalf("body = %q, want %q", got, "ab") + } + if w.writes < 2 { + t.Fatalf("writes = %d, want at least 2 incremental writes", w.writes) + } +}