You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
2.9 KiB
93 lines
2.9 KiB
package middleware
|
|
|
|
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"log"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// CSRFSimple is a lightweight double-submit-cookie CSRF middleware suitable for HTMX.
//
//   - Safe requests (GET, HEAD, OPTIONS, TRACE) always pass through. Before they
//     do, the middleware (re)issues a JavaScript-readable token cookie under both
//     the XSRF-TOKEN and XSRF-TOKEN-VALUE names. It reuses the current XSRF-TOKEN
//     value when that value is non-empty and contains no '|', and otherwise mints
//     a fresh random token.
//   - Every other request must echo the token back, in the X-CSRF-Token header
//     or, when that header is absent, in a "csrfToken" form field. The request is
//     rejected with 403 Forbidden unless the echoed token equals the cookie value.
func CSRFSimple(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch strings.ToUpper(r.Method) {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			csrfSimpleIssueCookies(w, r)
			next.ServeHTTP(w, r)
			return
		}

		// Prefer the header; only parse the request body when it is absent.
		headerToken := r.Header.Get("X-CSRF-Token")
		formToken := ""
		submitted := headerToken
		if submitted == "" {
			formToken = csrfSimpleFormToken(r)
			submitted = formToken
		}
		expected := csrfSimpleCookieValue(r)

		// Constant-time comparison so response timing does not reveal how much
		// of the token matched.
		if submitted == "" || expected == "" ||
			subtle.ConstantTimeCompare([]byte(submitted), []byte(expected)) != 1 {
			// Log only which inputs were present: token values are secrets and
			// must not end up in log files.
			log.Printf("CSRF validation failed (simple mode): method=%s path=%q header=%t form=%t cookie=%t",
				r.Method, r.URL.Path, headerToken != "", formToken != "", expected != "")
			http.Error(w, "Forbidden - CSRF", http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// csrfSimpleIssueCookies sets the XSRF-TOKEN and XSRF-TOKEN-VALUE cookies on w.
// It reuses the request's (trimmed) XSRF-TOKEN value when that value is non-empty
// and free of '|'. Otherwise it generates a new random token from 32 bytes. If no
// token can be generated, no cookies are set.
func csrfSimpleIssueCookies(w http.ResponseWriter, r *http.Request) {
	token := ""
	if c, err := r.Cookie("XSRF-TOKEN"); err == nil {
		if v := strings.TrimSpace(c.Value); v != "" && !strings.Contains(v, "|") {
			token = v
		}
	}
	if token == "" {
		buf := make([]byte, 32)
		if _, err := rand.Read(buf); err != nil {
			// Best effort: serve the request without a token cookie, but make the
			// failure visible instead of silently swallowing it.
			log.Printf("CSRF token generation failed (simple mode): %v", err)
			return
		}
		token = base64.RawURLEncoding.EncodeToString(buf)
	}

	// Determine if connection is effectively HTTPS (behind proxy aware).
	secure := r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
	expires := time.Now().Add(12 * time.Hour)
	// Refresh both legacy and simple-mode cookie names so the frontend reads the correct value.
	for _, name := range []string{"XSRF-TOKEN", "XSRF-TOKEN-VALUE"} {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    token,
			Path:     "/",
			HttpOnly: false, // must stay readable so the frontend can echo it back
			Secure:   secure,
			SameSite: http.SameSiteLaxMode,
			Expires:  expires,
		})
	}
}

// csrfSimpleFormToken returns the "csrfToken" form field, or "" when the body
// cannot be parsed. Multipart bodies are parsed with a 10 MiB in-memory limit
// (larger file parts spill to temporary files).
func csrfSimpleFormToken(r *http.Request) string {
	if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
		if err := r.ParseMultipartForm(10 << 20); err != nil {
			return ""
		}
		return r.PostFormValue("csrfToken")
	}
	if err := r.ParseForm(); err != nil {
		return ""
	}
	return r.PostFormValue("csrfToken")
}

// csrfSimpleCookieValue returns the cookie value a submitted token is checked
// against. It uses XSRF-TOKEN when that cookie is non-empty and free of '|', and
// otherwise XSRF-TOKEN-VALUE if it is present. If neither applies, it returns the
// raw XSRF-TOKEN value (possibly "") unchanged.
func csrfSimpleCookieValue(r *http.Request) string {
	primary, err := r.Cookie("XSRF-TOKEN")
	if err == nil && primary.Value != "" && !strings.Contains(primary.Value, "|") {
		return primary.Value
	}
	if fallback, fbErr := r.Cookie("XSRF-TOKEN-VALUE"); fbErr == nil {
		return fallback.Value
	}
	if primary != nil {
		return primary.Value
	}
	return ""
}
|
|
|