From 413298462f7c08457fdf4371c0f85f2376e768cf Mon Sep 17 00:00:00 2001 From: nic Date: Thu, 16 Oct 2025 18:26:21 -0400 Subject: [PATCH] chore: implementing CSRF --- apps/web/main.go | 13 +++++-- internal/middleware/csrf_simple.go | 52 +++++++++++++++++++++++++ internal/middleware/security_headers.go | 18 +++++++++ 3 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/csrf_simple.go create mode 100644 internal/middleware/security_headers.go diff --git a/apps/web/main.go b/apps/web/main.go index 9d9ccc4..ff13d10 100644 --- a/apps/web/main.go +++ b/apps/web/main.go @@ -106,9 +106,16 @@ func main() { if disableCSRF { log.Println("WARNING: CSRF is DISABLED by environment (CSRF_DISABLE=true) - DO NOT USE IN PRODUCTION") } else { - // Use standard gorilla/csrf; if origin checks are problematic in certain environments, - // fallback to simple double-submit cookie middleware via CSRF_MODE=simple - if os.Getenv("CSRF_MODE") == "simple" { + // Environment-guarded CSRF mode selection + appEnv := os.Getenv("APP_ENV") // expected: production, staging, dev, local + useSimple := os.Getenv("CSRF_MODE") == "simple" + + if appEnv == "production" && useSimple { + log.Fatal("CSRF_MODE=simple is not allowed in production") + } + + if useSimple && appEnv != "production" { + log.Println("INFO: Using simple CSRF mode (double-submit cookie) for non-production environment") protected.Use(middleware.CSRFSimple) } else { protected.Use(csrfMw) diff --git a/internal/middleware/csrf_simple.go b/internal/middleware/csrf_simple.go new file mode 100644 index 0000000..013920b --- /dev/null +++ b/internal/middleware/csrf_simple.go @@ -0,0 +1,52 @@ +package middleware + +import ( + "crypto/rand" + "encoding/base64" + "net/http" + "strings" + "time" +) + +// CSRFSimple is a lightweight double-submit-cookie CSRF middleware suitable for HTMX +// - Sets a readable XSRF-TOKEN cookie on safe requests if missing +// - Requires unsafe requests to include X-CSRF-Token header matching the cookie +func CSRFSimple(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + method := strings.ToUpper(r.Method) + // Determine if connection is effectively HTTPS (behind proxy aware) + isHTTPS := r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") + + // Ensure token cookie exists for safe methods + if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions || method == http.MethodTrace { + if _, err := r.Cookie("XSRF-TOKEN"); err != nil { + // Generate a random token + buf := make([]byte, 32) + if _, err := rand.Read(buf); err == nil { + token := base64.RawURLEncoding.EncodeToString(buf) + http.SetCookie(w, &http.Cookie{ + Name: "XSRF-TOKEN", + Value: token, + Path: "/", + HttpOnly: false, // must be readable by client script + Secure: isHTTPS, + SameSite: http.SameSiteLaxMode, + Expires: time.Now().Add(12 * time.Hour), + }) + } + } + next.ServeHTTP(w, r) + return + } + + // For unsafe methods, require header matches cookie + headerToken := r.Header.Get("X-CSRF-Token") + cookie, err := r.Cookie("XSRF-TOKEN") + if err != nil || headerToken == "" || cookie == nil || cookie.Value == "" || cookie.Value != headerToken { + http.Error(w, "Forbidden - CSRF", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/middleware/security_headers.go b/internal/middleware/security_headers.go new file mode 100644 index 0000000..80adeb7 --- /dev/null +++ b/internal/middleware/security_headers.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" +) + +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' https://unpkg.com 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self'; object-src 'none'; base-uri 'self'; frame-ancestors 'none'") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + if r.TLS != nil { + w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload") + } + next.ServeHTTP(w, r) + }) +}