package main

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestCORSMiddlewareAllowsDefaultLocalhostOrigin verifies that, when both
// CORS_ALLOWED_ORIGINS and ALLOWED_ORIGINS are empty, a simple (non-preflight)
// GET from http://localhost:5173 is passed through to the wrapped handler and
// the response echoes that origin with credentials allowed.
func TestCORSMiddlewareAllowsDefaultLocalhostOrigin(t *testing.T) {
	// Both variables are set to empty (not unset). This relies on corsMiddleware
	// treating empty values as "use the default allow-list", and on that default
	// containing http://localhost:5173. NOTE(review): confirm against corsMiddleware.
	t.Setenv("CORS_ALLOWED_ORIGINS", "")
	t.Setenv("ALLOWED_ORIGINS", "")

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	req.Header.Set("Origin", "http://localhost:5173")
	rec := httptest.NewRecorder()

	nextCalled := false
	corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		nextCalled = true
		w.WriteHeader(http.StatusOK)
	})).ServeHTTP(rec, req)

	// Without this check, a middleware that set the CORS headers but never
	// forwarded the request would still satisfy the header assertions below.
	if !nextCalled {
		t.Fatal("expected non-preflight request to reach next handler")
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
		t.Fatalf("expected localhost origin to be allowed, got %q", got)
	}
	if got := rec.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
		t.Fatalf("expected credentials header for allowed origin, got %q", got)
	}
}

// TestCORSMiddlewareDoesNotReflectDisallowedOrigin verifies that an OPTIONS
// preflight from an origin outside CORS_ALLOWED_ORIGINS is answered by the
// middleware itself with 204 No Content. The wrapped handler is never invoked,
// and neither Access-Control-Allow-Origin nor Access-Control-Allow-Credentials
// is set. Reflecting an arbitrary Origin together with credentials would let
// any site make credentialed cross-origin requests.
func TestCORSMiddlewareDoesNotReflectDisallowedOrigin(t *testing.T) {
	t.Setenv("CORS_ALLOWED_ORIGINS", "https://app.example.com")
	t.Setenv("ALLOWED_ORIGINS", "")

	req := httptest.NewRequest(http.MethodOptions, "/api/v1/auth/login", nil)
	req.Header.Set("Origin", "https://evil.example.com")
	rec := httptest.NewRecorder()

	// ServeHTTP runs synchronously on the test goroutine, so calling t.Fatal
	// from inside the handler is safe here.
	corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Fatal("preflight should not call next handler")
	})).ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusNoContent {
		t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, got)
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Fatalf("expected disallowed origin not to be reflected, got %q", got)
	}
	if got := rec.Header().Get("Access-Control-Allow-Credentials"); got != "" {
		t.Fatalf("expected credentials header to be omitted for disallowed origin, got %q", got)
	}
}