package middleware import ( "crypto/rand" "encoding/base64" "log" "net/http" "strings" "time" ) // CSRFSimple is a lightweight double-submit-cookie CSRF middleware suitable for HTMX // - Sets a readable XSRF-TOKEN cookie on safe requests if missing // - Requires unsafe requests to include X-CSRF-Token header matching the cookie func CSRFSimple(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { method := strings.ToUpper(r.Method) // Determine if connection is effectively HTTPS (behind proxy aware) isHTTPS := r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") // Ensure token cookie exists for safe methods if method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions || method == http.MethodTrace { if _, err := r.Cookie("XSRF-TOKEN"); err != nil { // Generate a random token buf := make([]byte, 32) if _, err := rand.Read(buf); err == nil { token := base64.RawURLEncoding.EncodeToString(buf) http.SetCookie(w, &http.Cookie{ Name: "XSRF-TOKEN", Value: token, Path: "/", HttpOnly: false, // must be readable by client script Secure: isHTTPS, SameSite: http.SameSiteLaxMode, Expires: time.Now().Add(12 * time.Hour), }) } } next.ServeHTTP(w, r) return } // For unsafe methods, require header matches cookie token := r.Header.Get("X-CSRF-Token") if token == "" { contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "multipart/form-data") { if err := r.ParseMultipartForm(10 << 20); err == nil { token = r.PostFormValue("csrfToken") } } else if err := r.ParseForm(); err == nil { token = r.PostFormValue("csrfToken") } } cookie, err := r.Cookie("XSRF-TOKEN") if err != nil || token == "" || cookie == nil || cookie.Value == "" || cookie.Value != token { if token == "" { if headerToken := r.Header.Get("X-CSRF-Token"); headerToken != "" { token = headerToken } } log.Printf("CSRF validation failed (simple mode): header=%q form=%q cookie=%q err=%v", r.Header.Get("X-CSRF-Token"), token, func() string { if cookie != nil { return cookie.Value } return "" }(), err) http.Error(w, "Forbidden - CSRF", http.StatusForbidden) return } next.ServeHTTP(w, r) }) }