package main import ( "context" "html/template" "log/slog" "net/http" "slices" "github.com/gorilla/csrf" ) func csrfError(w http.ResponseWriter, r *http.Request) { slog.Debug("CSRF error", "reason", csrf.FailureReason(r)) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) } type fetchMetadataKey struct{} type FetchMetadataSite string type FetchMetadataMode string type FetchMetadataDest string const ( SiteCrossSite FetchMetadataSite = "cross-site" SiteNone FetchMetadataSite = "none" SiteSameOrigin FetchMetadataSite = "same-origin" SiteSameSite FetchMetadataSite = "same-site" ModeCors FetchMetadataMode = "cors" ModeNavigate FetchMetadataMode = "navigate" ModeNoCors FetchMetadataMode = "no-cors" ModeSameOrigin FetchMetadataMode = "same-origin" ModeWebSocket FetchMetadataMode = "websocket" DestAudio FetchMetadataDest = "audio" DestAudioworklet FetchMetadataDest = "audioworklet" DestDocument FetchMetadataDest = "document" DestEmbed FetchMetadataDest = "embed" DestEmpty FetchMetadataDest = "empty" DestFencedframe FetchMetadataDest = "fencedframe" DestFont FetchMetadataDest = "font" DestFrame FetchMetadataDest = "frame" DestIframe FetchMetadataDest = "iframe" ) var ( fetchMetadataSites = []FetchMetadataSite{SiteCrossSite, SiteNone, SiteSameOrigin, SiteSameSite} fetchMetadataModes = []FetchMetadataMode{ModeCors, ModeNavigate, ModeNoCors, ModeSameOrigin, ModeWebSocket} fetchMetadataDests = []FetchMetadataDest{DestAudio, DestAudioworklet, DestDocument, DestEmbed, DestEmpty, DestFencedframe, DestFont, DestFrame, DestIframe} ) type FetchMetadata struct { // Site is the pre-validated value of the `Sec-Fetch-Site` header. Site FetchMetadataSite // Mode is the pre-validated value of the `Sec-Fetch-Mode` header. Mode FetchMetadataMode // Dest is the pre-validated value of the `Sec-Fetch-Dest` header. Dest FetchMetadataDest } func (fm *FetchMetadata) IsSameOrigin() bool { return fm != nil && fm.Site == "same-origin" } // IsFetch checks whether the request originated from a client- side `fetch()` // (or similar) call. func (fm *FetchMetadata) IsFetch() bool { return fm != nil && fm.Site == "same-origin" && fm.Mode == "cors" && fm.Dest == "empty" } // IsLocalNavigation checks if the request originated from a top-level // navigation, while optionally allowing frame embedding. func (fm *FetchMetadata) IsLocalNavigation(allowFrames bool) bool { if fm == nil || fm.Site != "same-origin" || fm.Mode != "navigate" { return false } if allowFrames { return fm.Dest == "document" || fm.Dest == "fencedframe" || fm.Dest == "frame" || fm.Dest == "iframe" } else { return fm.Dest == "document" } } func validateFetchMetadata(r *http.Request) *FetchMetadata { site := FetchMetadataSite(r.Header.Get("Sec-Fetch-Site")) if site == "" { return nil } if !slices.Contains(fetchMetadataSites, site) { slog.Debug("Invalid Sec-Metadata-Site header value", "fetchMetadataSite", site) return &FetchMetadata{Valid: false} } mode := r.Header.Get("Sec-Fetch-Mode") if !slices.Contains(fetchMetadataModes, mode) { slog.Debug("Invalid Sec-Metadata-Mode header value", "fetchMetadataMode", mode) return &FetchMetadata{Valid: false} } dest := r.Header.Get("Sec-Fetch-Dest") if _, ok := slices.BinarySearch(fetchMetadataDests, dest); !ok { slog.Debug("Invalid Sec-Metadata-Dest header value", "fetchMetadataDest", dest) return &FetchMetadata{Valid: false} } return &FetchMetadata{ Valid: true, Site: site, Mode: mode, Dest: dest, } } // GetFetchMetadata returns the validate fetch metadata headers of the request. // This may be nil if there were none (for old browsers and other user agents) // or they were invalid. // // Invalid headers should be ignored according to the specification: // https://w3c.github.io/webappsec-fetch-metadata func GetFetchMetadata(r *http.Request) *FetchMetadata { fm, _ := r.Context().Value(fetchMetadataKey{}).(*FetchMetadata) return fm } // AbortFetchMetadata terminates the request with a CSRF error indicating that // fetch metadata headers did not comply. func AbortFetchMetadata(w http.ResponseWriter, r *http.Request) { ctx := r.Context() r = r.WithContext(context.WithValue(ctx, "gorilla.csrf.Error", "CSRF metadata headers did not comply")) csrfError(w, r) } // ProtectCsrf is an HTTP middleware for CSRF protection. // // It defaults to using Fetch Metadata Request Headers (prefixed with `Sec-Fetch`) // to check whether requests were in fact first-party requests coming from the // same HTTP origin. Only the site header is used – other metadata must be checked // in route implementations. // // Failing support for Fetch Metadata Request Headers, CSRF tokens from hidden // form fields (see CsrfTemplateField) are used as a fallback. func ProtectCsrf(next http.Handler, authKey []byte, middlewareOpts ...csrf.Option) http.HandlerFunc { middleware := csrf.Protect( authKey, append(middlewareOpts, csrf.ErrorHandler(http.HandlerFunc(csrfError)))..., )(next) return func(w http.ResponseWriter, r *http.Request) { fetchMetadata := validateFetchMetadata(r) if fetchMetadata != nil { /* switch r.Method { case "GET", "HEAD", "OPTIONS", "TRACE": // Idempotent (safe) methods [RFC7231, 4.2.2] default: if secFetchSiteHeader != "same-origin" { ctx := r.Context() r = r.WithContext(context.WithValue(ctx, "gorilla.csrf.Error", "fetch metadata headers indicate cross-origin request")) csrfError(w, r) return } }*/ r = r.WithContext(context.WithValue(r.Context(), fetchMetadataKey{}, fetchMetadata)) next.ServeHTTP(w, r) } else { middleware.ServeHTTP(w, r) } } } // CsrfTemplateField is a template helper for CSRF protection. Stick it into // any form that submits to a CSRF-protected endpoint. func CsrfTemplateField(r *http.Request) template.HTML { if r.Header.Get("Sec-Fetch-Site") != "" { return "" } return csrf.TemplateField(r) }