package main import ( "fmt" "net/http" "net/http/httptest" "slices" "testing" ) var emptyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) type fetchMetadata struct { site string mode string dest string // Sec-Fetch-User is not tested here because it has limited availability: // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-User#browser_compatibility } func addFetchMetadataHeaders(req *http.Request, m fetchMetadata) { req.Header.Set("Sec-Fetch-Site", m.site) req.Header.Set("Sec-Fetch-Mode", m.mode) req.Header.Set("Sec-Fetch-Dest", m.dest) } // TestLookups makes sure that the lookup tables for valid fetch metadata header // values are sorted correctly, so they can be used in binary searches. func TestLookups(t *testing.T) { t.Run("sites", func(t *testing.T) { var s = slices.Sorted(slices.Values(fetchMetadataSites)) if slices.Compare(s, fetchMetadataSites) != 0 { t.Error("fetchMetadataSites is not sorted") } }) t.Run("modes", func(t *testing.T) { var s = slices.Sorted(slices.Values(fetchMetadataModes)) if slices.Compare(s, fetchMetadataModes) != 0 { t.Error("fetchMetadataModes is not sorted") } }) t.Run("dests", func(t *testing.T) { var s = slices.Sorted(slices.Values(fetchMetadataDests)) if slices.Compare(s, fetchMetadataDests) != 0 { t.Error("fetchMetadataDests is not sorted") } }) } // TestCsrfFetchHeaderCookies sees that the response headers and (crucially) // cookies don't change when the client supports fetch metadata request // headers. func TestCsrfFetchHeaderCookies(t *testing.T) { handler := ProtectCsrf(emptyHandler, []byte("test")) req := httptest.NewRequest(http.MethodGet, "/", nil) addFetchMetadataHeaders(req, fetchMetadata{"none", "navigate", "document"}) w := httptest.NewRecorder() handler.ServeHTTP(w, req) resp := w.Result() if len(resp.Header) != 0 { t.Errorf("Fetch metadata CSRF protection changed headers: %v", resp.Header) } } func TestCsrfFetchHeaderAllows(t *testing.T) { handler := ProtectCsrf(emptyHandler, []byte("test")) run := func(t *testing.T, m string, fm fetchMetadata) { req := httptest.NewRequest(m, "/", nil) addFetchMetadataHeaders(req, fm) w := httptest.NewRecorder() handler.ServeHTTP(w, req) resp := w.Result() if resp.StatusCode != http.StatusOK { t.Errorf("CSRF protection false block (status %d %s)", resp.StatusCode, resp.Status) } } for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { t.Run(method, func(t *testing.T) { for _, fm := range []fetchMetadata{ {"same-origin", "navigate", "document"}, // Same origin top-level navigation / form submit {"same-origin", "cors", "empty"}, // Same origin fetch() request } { t.Run(fmt.Sprintf("site-%s,mode-%s,dest-%s", fm.site, fm.mode, fm.dest), func(t *testing.T) { run(t, method, fm) }) } }) } for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { t.Run(method, func(t *testing.T) { for _, fm := range []fetchMetadata{ {"same-origin", "navigate", "document"}, // Same origin top-level navigation / form submit {"same-origin", "cors", "empty"}, // Same origin fetch() request } { t.Run(fmt.Sprintf("site-%s,mode-%s,dest-%s", fm.site, fm.mode, fm.dest), func(t *testing.T) { run(t, method, fm) }) } }) } } func TestCsrfFetchHeaderBlocks(t *testing.T) { handler := ProtectCsrf(emptyHandler, []byte("test")) run := func(t *testing.T, m string, fm fetchMetadata) { req := httptest.NewRequest(m, "/", nil) addFetchMetadataHeaders(req, fm) w := httptest.NewRecorder() handler.ServeHTTP(w, req) resp := w.Result() if resp.StatusCode != http.StatusForbidden { t.Errorf("CSRF protection false allow (status %d %s)", resp.StatusCode, resp.Status) } } for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { t.Run(method, func(t *testing.T) { for _, fm := range []fetchMetadata{ {"none", "navigate", "document"}, // Manual top-level navigation {"cross-site", "navigate", "document"}, // Cross-site navigation {"same-site", "navigate", "document"}, // eTLD-1 same-site navigation } { t.Run(fmt.Sprintf("site-%s,mode-%s,dest-%s", fm.site, fm.mode, fm.dest), func(t *testing.T) { run(t, method, fm) }) } }) } }