homepage/server/csrf_test.go
2026-01-13 21:31:43 +01:00

134 lines
4.4 KiB
Go

package main
import (
"fmt"
"net/http"
"net/http/httptest"
"slices"
"testing"
)
// emptyHandler stands in for the application behind the CSRF middleware. It
// answers every request with a bare 200 OK and sets no headers and writes no
// body, so any status, header or cookie a test observes comes from the
// middleware wrapping it.
var emptyHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
// fetchMetadata is the set of Fetch Metadata request header values a browser
// would attach to a request. Test tables build it with positional literals
// ({site, mode, dest}), so the field order below must not change.
type fetchMetadata struct {
	site string // value for Sec-Fetch-Site
	mode string // value for Sec-Fetch-Mode
	dest string // value for Sec-Fetch-Dest

	// Sec-Fetch-User is deliberately absent: browser support for it is
	// limited, see
	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-User#browser_compatibility
}
// addFetchMetadataHeaders stamps req with the Sec-Fetch-Site, Sec-Fetch-Mode
// and Sec-Fetch-Dest request headers taken from m. Header.Set is used, so any
// values those headers already had are replaced.
func addFetchMetadataHeaders(req *http.Request, m fetchMetadata) {
	for _, kv := range [...][2]string{
		{"Sec-Fetch-Site", m.site},
		{"Sec-Fetch-Mode", m.mode},
		{"Sec-Fetch-Dest", m.dest},
	} {
		req.Header.Set(kv[0], kv[1])
	}
}
// TestLookups guards the invariant that the tables of accepted fetch metadata
// header values stay in ascending order. The tables are meant to be searched
// with binary search, which silently gives wrong answers on unsorted input.
func TestLookups(t *testing.T) {
	t.Run("sites", func(t *testing.T) {
		if !slices.IsSorted(fetchMetadataSites) {
			t.Error("fetchMetadataSites is not sorted")
		}
	})
	t.Run("modes", func(t *testing.T) {
		if !slices.IsSorted(fetchMetadataModes) {
			t.Error("fetchMetadataModes is not sorted")
		}
	})
	t.Run("dests", func(t *testing.T) {
		if !slices.IsSorted(fetchMetadataDests) {
			t.Error("fetchMetadataDests is not sorted")
		}
	})
}
// TestCsrfFetchHeaderCookies sends a plain GET carrying fetch metadata
// headers (a user-initiated top-level navigation) through ProtectCsrf. It
// asserts that the response has no headers at all. That check covers
// Set-Cookie, the header this test cares about most.
func TestCsrfFetchHeaderCookies(t *testing.T) {
	protected := ProtectCsrf(emptyHandler, []byte("test"))

	navigation := httptest.NewRequest(http.MethodGet, "/", nil)
	addFetchMetadataHeaders(navigation, fetchMetadata{site: "none", mode: "navigate", dest: "document"})

	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, navigation)

	if hdr := rec.Result().Header; len(hdr) != 0 {
		t.Errorf("Fetch metadata CSRF protection changed headers: %v", hdr)
	}
}
// TestCsrfFetchHeaderAllows checks that ProtectCsrf lets state-changing
// requests (POST, PUT, PATCH, DELETE) through to the wrapped handler with
// 200 OK when their fetch metadata headers mark them as same-origin.
// Each method/metadata combination runs exactly once as its own subtest.
func TestCsrfFetchHeaderAllows(t *testing.T) {
	handler := ProtectCsrf(emptyHandler, []byte("test"))
	run := func(t *testing.T, m string, fm fetchMetadata) {
		t.Helper()
		req := httptest.NewRequest(m, "/", nil)
		addFetchMetadataHeaders(req, fm)
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		resp := w.Result()
		if resp.StatusCode != http.StatusOK {
			// resp.Status already includes the numeric code, e.g. "403 Forbidden".
			t.Errorf("CSRF protection false block (status %s)", resp.Status)
		}
	}
	for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} {
		t.Run(method, func(t *testing.T) {
			for _, fm := range []fetchMetadata{
				{"same-origin", "navigate", "document"}, // Same origin top-level navigation / form submit
				{"same-origin", "cors", "empty"},        // Same origin fetch() request
			} {
				t.Run(fmt.Sprintf("site-%s,mode-%s,dest-%s", fm.site, fm.mode, fm.dest), func(t *testing.T) {
					run(t, method, fm)
				})
			}
		})
	}
}
// TestCsrfFetchHeaderBlocks checks that ProtectCsrf rejects state-changing
// requests (POST, PUT, PATCH, DELETE) with 403 Forbidden when their fetch
// metadata headers show they are not same-origin. The cases cover a
// user-initiated request, a cross-site request and a same-site request.
func TestCsrfFetchHeaderBlocks(t *testing.T) {
	handler := ProtectCsrf(emptyHandler, []byte("test"))
	run := func(t *testing.T, m string, fm fetchMetadata) {
		t.Helper()
		req := httptest.NewRequest(m, "/", nil)
		addFetchMetadataHeaders(req, fm)
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		resp := w.Result()
		if resp.StatusCode != http.StatusForbidden {
			// resp.Status already includes the numeric code, e.g. "200 OK".
			t.Errorf("CSRF protection false allow (status %s)", resp.Status)
		}
	}
	for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} {
		t.Run(method, func(t *testing.T) {
			for _, fm := range []fetchMetadata{
				{"none", "navigate", "document"},       // Manual top-level navigation
				{"cross-site", "navigate", "document"}, // Cross-site navigation
				{"same-site", "navigate", "document"},  // Same-site (same eTLD+1, different origin) navigation
			} {
				t.Run(fmt.Sprintf("site-%s,mode-%s,dest-%s", fm.site, fm.mode, fm.dest), func(t *testing.T) {
					run(t, method, fm)
				})
			}
		})
	}
}