package main
import (
	"context"
	"errors"
	"html/template"
	"log/slog"
	"net/http"
	"slices"

	"github.com/gorilla/csrf"
)
// csrfError rejects a request with 403 Forbidden. Before responding it logs,
// at debug level, whatever failure reason gorilla/csrf has attached to the
// request (csrf.FailureReason). ProtectCsrf installs it as the csrf error handler.
func csrfError(w http.ResponseWriter, r *http.Request) {
	reason := csrf.FailureReason(r)
	slog.Debug("CSRF error", "reason", reason)

	const status = http.StatusForbidden
	http.Error(w, http.StatusText(status), status)
}
// fetchMetadataKey is the context key under which ProtectCsrf stores the
// validated *FetchMetadata of a request (read back by GetFetchMetadata).
type fetchMetadataKey struct{}

// FetchMetadataSite is a value of the `Sec-Fetch-Site` request header.
type FetchMetadataSite string

// FetchMetadataMode is a value of the `Sec-Fetch-Mode` request header.
type FetchMetadataMode string

// FetchMetadataDest is a value of the `Sec-Fetch-Dest` request header.
type FetchMetadataDest string

// Known values of the `Sec-Fetch-Site` header.
const (
	SiteCrossSite  FetchMetadataSite = "cross-site"
	SiteNone       FetchMetadataSite = "none"
	SiteSameOrigin FetchMetadataSite = "same-origin"
	SiteSameSite   FetchMetadataSite = "same-site"
)

// Known values of the `Sec-Fetch-Mode` header.
const (
	ModeCors       FetchMetadataMode = "cors"
	ModeNavigate   FetchMetadataMode = "navigate"
	ModeNoCors     FetchMetadataMode = "no-cors"
	ModeSameOrigin FetchMetadataMode = "same-origin"
	ModeWebSocket  FetchMetadataMode = "websocket"
)

// Known values of the `Sec-Fetch-Dest` header ("empty" stands for the Fetch
// standard's empty-string destination, e.g. fetch() and XHR).
const (
	DestAudio         FetchMetadataDest = "audio"
	DestAudioworklet  FetchMetadataDest = "audioworklet"
	DestDocument      FetchMetadataDest = "document"
	DestEmbed         FetchMetadataDest = "embed"
	DestEmpty         FetchMetadataDest = "empty"
	DestFencedframe   FetchMetadataDest = "fencedframe"
	DestFont          FetchMetadataDest = "font"
	DestFrame         FetchMetadataDest = "frame"
	DestIframe        FetchMetadataDest = "iframe"
	DestImage         FetchMetadataDest = "image"
	DestJson          FetchMetadataDest = "json"
	DestManifest      FetchMetadataDest = "manifest"
	DestObject        FetchMetadataDest = "object"
	DestPaintworklet  FetchMetadataDest = "paintworklet"
	DestReport        FetchMetadataDest = "report"
	DestScript        FetchMetadataDest = "script"
	DestServiceworker FetchMetadataDest = "serviceworker"
	DestSharedworker  FetchMetadataDest = "sharedworker"
	DestStyle         FetchMetadataDest = "style"
	DestTrack         FetchMetadataDest = "track"
	DestVideo         FetchMetadataDest = "video"
	DestWebidentity   FetchMetadataDest = "webidentity"
	DestWorker        FetchMetadataDest = "worker"
	DestXslt          FetchMetadataDest = "xslt"
)

// Allowed header values. A header whose value is not listed here makes the
// whole set of fetch metadata invalid. Each list is kept in alphabetical order.
var (
	fetchMetadataSites = []FetchMetadataSite{SiteCrossSite, SiteNone, SiteSameOrigin, SiteSameSite}
	fetchMetadataModes = []FetchMetadataMode{ModeCors, ModeNavigate, ModeNoCors, ModeSameOrigin, ModeWebSocket}
	fetchMetadataDests = []FetchMetadataDest{
		DestAudio, DestAudioworklet, DestDocument, DestEmbed, DestEmpty,
		DestFencedframe, DestFont, DestFrame, DestIframe, DestImage, DestJson,
		DestManifest, DestObject, DestPaintworklet, DestReport, DestScript,
		DestServiceworker, DestSharedworker, DestStyle, DestTrack, DestVideo,
		DestWebidentity, DestWorker, DestXslt,
	}
)

// FetchMetadata holds the Fetch Metadata Request Headers of a request after
// validation by validateFetchMetadata. Retrieve it with GetFetchMetadata.
//
// All methods accept a nil receiver, meaning "no fetch metadata", and then
// report false.
type FetchMetadata struct {
	// Site is the pre-validated value of the `Sec-Fetch-Site` header.
	Site FetchMetadataSite
	// Mode is the pre-validated value of the `Sec-Fetch-Mode` header.
	Mode FetchMetadataMode
	// Dest is the pre-validated value of the `Sec-Fetch-Dest` header.
	Dest FetchMetadataDest
}

// IsSameOrigin reports whether `Sec-Fetch-Site` is `same-origin`, i.e. the
// request was initiated by a page of the same origin.
func (fm *FetchMetadata) IsSameOrigin() bool {
	return fm != nil && fm.Site == SiteSameOrigin
}

// IsFetch checks whether the request originated from a same-origin client-side
// `fetch()` (or similar) call: site `same-origin`, mode `cors`, dest `empty`.
func (fm *FetchMetadata) IsFetch() bool {
	return fm.IsSameOrigin() && fm.Mode == ModeCors && fm.Dest == DestEmpty
}

// IsLocalNavigation checks if the request originated from a same-origin
// top-level navigation (dest `document`). When allowFrames is true, same-origin
// navigations of embedded frames (`frame`, `iframe`, `fencedframe`) qualify too.
func (fm *FetchMetadata) IsLocalNavigation(allowFrames bool) bool {
	// IsSameOrigin is false for a nil receiver, so fm is non-nil past this guard.
	if !fm.IsSameOrigin() || fm.Mode != ModeNavigate {
		return false
	}
	switch fm.Dest {
	case DestDocument:
		return true
	case DestFencedframe, DestFrame, DestIframe:
		return allowFrames
	default:
		return false
	}
}
// validateFetchMetadata parses and validates the Fetch Metadata Request
// Headers (`Sec-Fetch-Site`, `Sec-Fetch-Mode`, `Sec-Fetch-Dest`) of r.
//
// It returns nil when the request has no `Sec-Fetch-Site` header (old browsers
// and other user agents). It also returns nil when any of the three headers
// holds a value missing from fetchMetadataSites, fetchMetadataModes or
// fetchMetadataDests. Invalid headers are treated like absent ones because the
// specification says they should be ignored
// (https://w3c.github.io/webappsec-fetch-metadata). ProtectCsrf then falls
// back to CSRF tokens.
func validateFetchMetadata(r *http.Request) *FetchMetadata {
	rawSite := r.Header.Get("Sec-Fetch-Site")
	if rawSite == "" {
		return nil
	}
	site, ok := parseFetchMetadataValue(fetchMetadataSites, rawSite)
	if !ok {
		slog.Debug("Invalid Sec-Fetch-Site header value", "fetchMetadataSite", rawSite)
		return nil
	}

	rawMode := r.Header.Get("Sec-Fetch-Mode")
	mode, ok := parseFetchMetadataValue(fetchMetadataModes, rawMode)
	if !ok {
		slog.Debug("Invalid Sec-Fetch-Mode header value", "fetchMetadataMode", rawMode)
		return nil
	}

	rawDest := r.Header.Get("Sec-Fetch-Dest")
	dest, ok := parseFetchMetadataValue(fetchMetadataDests, rawDest)
	if !ok {
		slog.Debug("Invalid Sec-Fetch-Dest header value", "fetchMetadataDest", rawDest)
		return nil
	}

	return &FetchMetadata{Site: site, Mode: mode, Dest: dest}
}

// parseFetchMetadataValue converts raw to the header's value type T and reports
// whether it is one of the allowed values. Using slices.Contains rather than a
// binary search keeps this correct regardless of the order of allowed.
func parseFetchMetadataValue[T ~string](allowed []T, raw string) (T, bool) {
	v := T(raw)
	return v, slices.Contains(allowed, v)
}
// GetFetchMetadata returns the validated fetch metadata headers of the request,
// as recorded in its context by ProtectCsrf.
//
// The result is nil when no metadata was recorded. This is the case for
// requests without the headers (old browsers and other user agents), for
// requests whose headers validateFetchMetadata rejected, and for requests that
// never passed through ProtectCsrf.
//
// Invalid headers should be ignored according to the specification:
// https://w3c.github.io/webappsec-fetch-metadata
func GetFetchMetadata(r *http.Request) *FetchMetadata {
	if stored, ok := r.Context().Value(fetchMetadataKey{}).(*FetchMetadata); ok {
		return stored
	}
	return nil
}
// AbortFetchMetadata terminates the request with a CSRF error indicating that
// fetch metadata headers did not comply. The response is written by csrfError
// (403 Forbidden).
//
// The reason is attached as an error value under the "gorilla.csrf.Error"
// context key so that csrf.FailureReason, which csrfError logs, can report it.
// NOTE(review): gorilla/csrf's FailureReason appears to only return values of
// type error, and the plain-string key mirrors gorilla's internal key (hence
// no custom key type, despite staticcheck SA1029). Confirm both against the
// gorilla/csrf version in use.
func AbortFetchMetadata(w http.ResponseWriter, r *http.Request) {
	reason := errors.New("CSRF metadata headers did not comply")
	r = r.WithContext(context.WithValue(r.Context(), "gorilla.csrf.Error", reason))
	csrfError(w, r)
}
// ProtectCsrf is an HTTP middleware for CSRF protection.
//
// It defaults to using Fetch Metadata Request Headers (prefixed with `Sec-Fetch`)
// to check whether requests were in fact first-party requests coming from the
// same HTTP origin. Suppose validateFetchMetadata accepts the headers. Then any
// request whose method is not GET, HEAD, OPTIONS or TRACE is rejected through
// csrfError unless `Sec-Fetch-Site` is `same-origin`. Only the site header is
// used – other metadata must be checked in route implementations (see
// GetFetchMetadata). Accepted requests are passed to next directly, with the
// metadata stored in the request context.
//
// Failing support for Fetch Metadata Request Headers, CSRF tokens from hidden
// form fields (see CsrfTemplateField) are used as a fallback. Such requests are
// handled by csrf.Protect(authKey, middlewareOpts...), with csrfError appended
// as the error handler.
func ProtectCsrf(next http.Handler, authKey []byte, middlewareOpts ...csrf.Option) http.HandlerFunc {
	// Clip the caller's slice so append allocates a fresh backing array instead
	// of writing into spare capacity of a slice the caller passed via opts...
	opts := append(slices.Clip(middlewareOpts), csrf.ErrorHandler(http.HandlerFunc(csrfError)))
	tokenProtected := csrf.Protect(authKey, opts...)(next)

	return func(w http.ResponseWriter, r *http.Request) {
		fetchMetadata := validateFetchMetadata(r)
		if fetchMetadata == nil {
			// No usable fetch metadata: fall back to token-based protection.
			tokenProtected.ServeHTTP(w, r)
			return
		}

		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Safe methods (RFC 9110, section 9.2.1) are allowed from any site.
		default:
			if !fetchMetadata.IsSameOrigin() {
				// The plain-string key mirrors gorilla/csrf's internal error key so
				// csrf.FailureReason (logged by csrfError) can report the cause.
				// NOTE(review): relies on gorilla/csrf internals — confirm for the
				// version in use.
				reason := errors.New("fetch metadata headers indicate cross-origin request")
				r = r.WithContext(context.WithValue(r.Context(), "gorilla.csrf.Error", reason))
				csrfError(w, r)
				return
			}
		}

		r = r.WithContext(context.WithValue(r.Context(), fetchMetadataKey{}, fetchMetadata))
		next.ServeHTTP(w, r)
	}
}
// CsrfTemplateField is a template helper for CSRF protection. Stick it into
// any form that submits to a CSRF-protected endpoint.
func CsrfTemplateField(r *http.Request) template.HTML {
if r.Header.Get("Sec-Fetch-Site") != "" {
return ""
}
return csrf.TemplateField(r)
}