homepage/server/csrf.go
2026-01-13 21:31:43 +01:00

179 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"html/template"
"log/slog"
"net/http"
"slices"
"github.com/gorilla/csrf"
)
func csrfError(w http.ResponseWriter, r *http.Request) {
slog.Debug("CSRF error", "reason", csrf.FailureReason(r))
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
type fetchMetadataKey struct{}
type FetchMetadataSite string
type FetchMetadataMode string
type FetchMetadataDest string
const (
SiteCrossSite FetchMetadataSite = "cross-site"
SiteNone FetchMetadataSite = "none"
SiteSameOrigin FetchMetadataSite = "same-origin"
SiteSameSite FetchMetadataSite = "same-site"
ModeCors FetchMetadataMode = "cors"
ModeNavigate FetchMetadataMode = "navigate"
ModeNoCors FetchMetadataMode = "no-cors"
ModeSameOrigin FetchMetadataMode = "same-origin"
ModeWebSocket FetchMetadataMode = "websocket"
DestAudio FetchMetadataDest = "audio"
DestAudioworklet FetchMetadataDest = "audioworklet"
DestDocument FetchMetadataDest = "document"
DestEmbed FetchMetadataDest = "embed"
DestEmpty FetchMetadataDest = "empty"
DestFencedframe FetchMetadataDest = "fencedframe"
DestFont FetchMetadataDest = "font"
DestFrame FetchMetadataDest = "frame"
DestIframe FetchMetadataDest = "iframe"
)
var (
fetchMetadataSites = []FetchMetadataSite{SiteCrossSite, SiteNone, SiteSameOrigin, SiteSameSite}
fetchMetadataModes = []FetchMetadataMode{ModeCors, ModeNavigate, ModeNoCors, ModeSameOrigin, ModeWebSocket}
fetchMetadataDests = []FetchMetadataDest{DestAudio, DestAudioworklet, DestDocument, DestEmbed, DestEmpty, DestFencedframe, DestFont, DestFrame, DestIframe}
)
type FetchMetadata struct {
// Site is the pre-validated value of the `Sec-Fetch-Site` header.
Site FetchMetadataSite
// Mode is the pre-validated value of the `Sec-Fetch-Mode` header.
Mode FetchMetadataMode
// Dest is the pre-validated value of the `Sec-Fetch-Dest` header.
Dest FetchMetadataDest
}
func (fm *FetchMetadata) IsSameOrigin() bool {
return fm != nil && fm.Site == "same-origin"
}
// IsFetch checks whether the request originated from a client- side `fetch()`
// (or similar) call.
func (fm *FetchMetadata) IsFetch() bool {
return fm != nil && fm.Site == "same-origin" && fm.Mode == "cors" && fm.Dest == "empty"
}
// IsLocalNavigation checks if the request originated from a top-level
// navigation, while optionally allowing frame embedding.
func (fm *FetchMetadata) IsLocalNavigation(allowFrames bool) bool {
if fm == nil || fm.Site != "same-origin" || fm.Mode != "navigate" {
return false
}
if allowFrames {
return fm.Dest == "document" || fm.Dest == "fencedframe" || fm.Dest == "frame" || fm.Dest == "iframe"
} else {
return fm.Dest == "document"
}
}
func validateFetchMetadata(r *http.Request) *FetchMetadata {
site := FetchMetadataSite(r.Header.Get("Sec-Fetch-Site"))
if site == "" {
return nil
}
if !slices.Contains(fetchMetadataSites, site) {
slog.Debug("Invalid Sec-Metadata-Site header value", "fetchMetadataSite", site)
return &FetchMetadata{Valid: false}
}
mode := r.Header.Get("Sec-Fetch-Mode")
if !slices.Contains(fetchMetadataModes, mode) {
slog.Debug("Invalid Sec-Metadata-Mode header value", "fetchMetadataMode", mode)
return &FetchMetadata{Valid: false}
}
dest := r.Header.Get("Sec-Fetch-Dest")
if _, ok := slices.BinarySearch(fetchMetadataDests, dest); !ok {
slog.Debug("Invalid Sec-Metadata-Dest header value", "fetchMetadataDest", dest)
return &FetchMetadata{Valid: false}
}
return &FetchMetadata{
Valid: true,
Site: site,
Mode: mode,
Dest: dest,
}
}
// GetFetchMetadata returns the validate fetch metadata headers of the request.
// This may be nil if there were none (for old browsers and other user agents)
// or they were invalid.
//
// Invalid headers should be ignored according to the specification:
// https://w3c.github.io/webappsec-fetch-metadata
func GetFetchMetadata(r *http.Request) *FetchMetadata {
fm, _ := r.Context().Value(fetchMetadataKey{}).(*FetchMetadata)
return fm
}
// AbortFetchMetadata terminates the request with a CSRF error indicating that
// fetch metadata headers did not comply.
func AbortFetchMetadata(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
r = r.WithContext(context.WithValue(ctx, "gorilla.csrf.Error", "CSRF metadata headers did not comply"))
csrfError(w, r)
}
// ProtectCsrf is an HTTP middleware for CSRF protection.
//
// It defaults to using Fetch Metadata Request Headers (prefixed with `Sec-Fetch`)
// to check whether requests were in fact first-party requests coming from the
// same HTTP origin. Only the site header is used other metadata must be checked
// in route implementations.
//
// Failing support for Fetch Metadata Request Headers, CSRF tokens from hidden
// form fields (see CsrfTemplateField) are used as a fallback.
func ProtectCsrf(next http.Handler, authKey []byte, middlewareOpts ...csrf.Option) http.HandlerFunc {
middleware := csrf.Protect(
authKey,
append(middlewareOpts, csrf.ErrorHandler(http.HandlerFunc(csrfError)))...,
)(next)
return func(w http.ResponseWriter, r *http.Request) {
fetchMetadata := validateFetchMetadata(r)
if fetchMetadata != nil {
/*
switch r.Method {
case "GET", "HEAD", "OPTIONS", "TRACE": // Idempotent (safe) methods [RFC7231, 4.2.2]
default:
if secFetchSiteHeader != "same-origin" {
ctx := r.Context()
r = r.WithContext(context.WithValue(ctx, "gorilla.csrf.Error", "fetch metadata headers indicate cross-origin request"))
csrfError(w, r)
return
}
}*/
r = r.WithContext(context.WithValue(r.Context(), fetchMetadataKey{}, fetchMetadata))
next.ServeHTTP(w, r)
} else {
middleware.ServeHTTP(w, r)
}
}
}
// CsrfTemplateField is a template helper for CSRF protection. Stick it into
// any form that submits to a CSRF-protected endpoint.
func CsrfTemplateField(r *http.Request) template.HTML {
if r.Header.Get("Sec-Fetch-Site") != "" {
return ""
}
return csrf.TemplateField(r)
}