diff --git a/VERSION.txt b/VERSION.txt index b6ec845f38f94..a0f8d2c2b4f7c 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.80.2 +1.80.3 diff --git a/appc/appconnector.go b/appc/appconnector.go index f4857fcc612aa..89c6c9aeb9aa7 100644 --- a/appc/appconnector.go +++ b/appc/appconnector.go @@ -289,9 +289,11 @@ func (e *AppConnector) updateDomains(domains []string) { toRemove = append(toRemove, netip.PrefixFrom(a, a.BitLen())) } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err) - } + e.queue.Add(func() { + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err) + } + }) } e.logf("handling domains: %v and wildcards: %v", slicesx.MapKeys(e.domains), e.wildcards) @@ -310,11 +312,6 @@ func (e *AppConnector) updateRoutes(routes []netip.Prefix) { return } - if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { - e.logf("failed to advertise routes: %v: %v", routes, err) - return - } - var toRemove []netip.Prefix // If we're storing routes and know e.controlRoutes is a good @@ -338,9 +335,14 @@ nextRoute: } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes: %v: %v", toRemove, err) - } + e.queue.Add(func() { + if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { + e.logf("failed to advertise routes: %v: %v", routes, err) + } + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes: %v: %v", toRemove, err) + } + }) e.controlRoutes = routes if err := e.storeRoutesLocked(); err != nil { diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index fd0001224984a..c13835f39ed9a 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go 
@@ -8,6 +8,7 @@ import ( "net/netip" "reflect" "slices" + "sync/atomic" "testing" "time" @@ -86,6 +87,7 @@ func TestUpdateRoutes(t *testing.T) { routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24"), netip.MustParsePrefix("192.0.0.1/32")} a.updateRoutes(routes) + a.Wait(ctx) slices.SortFunc(rc.Routes(), prefixCompare) rc.SetRoutes(slices.Compact(rc.Routes())) @@ -105,6 +107,7 @@ func TestUpdateRoutes(t *testing.T) { } func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) { + ctx := context.Background() for _, shouldStore := range []bool{false, true} { rc := &appctest.RouteCollector{} var a *AppConnector @@ -117,6 +120,7 @@ func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) { rc.SetRoutes([]netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")}) routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")} a.updateRoutes(routes) + a.Wait(ctx) if !slices.EqualFunc(routes, rc.Routes(), prefixEqual) { t.Fatalf("got %v, want %v", rc.Routes(), routes) @@ -636,3 +640,57 @@ func TestMetricBucketsAreSorted(t *testing.T) { t.Errorf("metricStoreRoutesNBuckets must be in order") } } + +// TestUpdateRoutesDeadlock is a regression test for a deadlock in +// LocalBackend<->AppConnector interaction. When using real LocalBackend as the +// routeAdvertiser, calls to Advertise/UnadvertiseRoutes can end up calling +// back into AppConnector via authReconfig. If everything is called +// synchronously, this results in a deadlock on AppConnector.mu. +func TestUpdateRoutesDeadlock(t *testing.T) { + ctx := context.Background() + rc := &appctest.RouteCollector{} + a := NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) + + advertiseCalled := new(atomic.Bool) + unadvertiseCalled := new(atomic.Bool) + rc.AdvertiseCallback = func() { + // Call something that requires a.mu to be held. + a.DomainRoutes() + advertiseCalled.Store(true) + } + rc.UnadvertiseCallback = func() { + // Call something that requires a.mu to be held. 
+ a.DomainRoutes() + unadvertiseCalled.Store(true) + } + + a.updateDomains([]string{"example.com"}) + a.Wait(ctx) + + // Trigger rc.AdvertiseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + netip.MustParsePrefix("127.0.0.2/32"), + }, + ) + a.Wait(ctx) + // Trigger rc.UnadvertiseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + }, + ) + a.Wait(ctx) + + if !advertiseCalled.Load() { + t.Error("AdvertiseRoute was not called") + } + if !unadvertiseCalled.Load() { + t.Error("UnadvertiseRoute was not called") + } + + if want := []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")}; !slices.Equal(slices.Compact(rc.Routes()), want) { + t.Fatalf("got %v, want %v", rc.Routes(), want) + } +} diff --git a/appc/appctest/appctest.go b/appc/appctest/appctest.go index aa77bc3b41044..9726a2b97d72b 100644 --- a/appc/appctest/appctest.go +++ b/appc/appctest/appctest.go @@ -11,12 +11,22 @@ import ( // RouteCollector is a test helper that collects the list of routes advertised type RouteCollector struct { + // AdvertiseCallback (optional) is called synchronously from + // AdvertiseRoute. + AdvertiseCallback func() + // UnadvertiseCallback (optional) is called synchronously from + // UnadvertiseRoute. + UnadvertiseCallback func() + routes []netip.Prefix removedRoutes []netip.Prefix } func (rc *RouteCollector) AdvertiseRoute(pfx ...netip.Prefix) error { rc.routes = append(rc.routes, pfx...) 
+ if rc.AdvertiseCallback != nil { + rc.AdvertiseCallback() + } return nil } @@ -30,6 +40,9 @@ func (rc *RouteCollector) UnadvertiseRoute(toRemove ...netip.Prefix) error { rc.removedRoutes = append(rc.removedRoutes, r) } } + if rc.UnadvertiseCallback != nil { + rc.UnadvertiseCallback() + } return nil } diff --git a/client/web/web.go b/client/web/web.go index 3a7feea40c398..8218327817161 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -203,15 +203,25 @@ func NewServer(opts ServerOpts) (s *Server, err error) { } s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) - var metric string // clientmetric to report on startup + var metric string + s.apiHandler, metric = s.modeAPIHandler(s.mode) + s.apiHandler = s.withCSRF(s.apiHandler) - // Create handler for "/api" requests with CSRF protection. - // We don't require secure cookies, since the web client is regularly used - // on network appliances that are served on local non-https URLs. - // The client is secured by limiting the interface it listens on, - // or by authenticating requests before they reach the web client. + // Don't block startup on reporting metric. + // Report in separate goroutine with 5 second timeout. + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s.lc.IncrementCounter(ctx, metric, 1) + }() + + return s, nil +} + +func (s *Server) withCSRF(h http.Handler) http.Handler { csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) + // ref https://github.com/tailscale/tailscale/pull/14822 // signal to the CSRF middleware that the request is being served over // plaintext HTTP to skip TLS-only header checks. 
withSetPlaintext := func(h http.Handler) http.Handler { @@ -221,27 +231,24 @@ func NewServer(opts ServerOpts) (s *Server, err error) { }) } - switch s.mode { + // NB: the order of the withSetPlaintext and csrfProtect calls is important + // to ensure that we signal to the CSRF middleware that the request is being + // served over plaintext HTTP and not over TLS as it presumes by default. + return withSetPlaintext(csrfProtect(h)) +} + +func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) { + switch mode { case LoginServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_login_client_initialization" + return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization" case ReadOnlyServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_readonly_client_initialization" + return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization" case ManageServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI))) - metric = "web_client_initialization" + return http.HandlerFunc(s.serveAPI), "web_client_initialization" + default: // invalid mode + log.Fatalf("invalid mode: %v", mode) } - - // Don't block startup on reporting metric. - // Report in separate go routine with 5 second timeout. 
- go func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s.lc.IncrementCounter(ctx, metric, 1) - }() - - return s, nil + return nil, "" } func (s *Server) Shutdown() { diff --git a/client/web/web_test.go b/client/web/web_test.go index 3c5543c12014c..e579c450a1bb7 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/netip" "net/url" @@ -20,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/gorilla/csrf" "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" @@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg return nil, errors.New("unknown id") } } + +func TestCSRFProtect(t *testing.T) { + s := &Server{} + + mux := http.NewServeMux() + mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) { + token := csrf.Token(r) + _, err := io.WriteString(w, token) + if err != nil { + t.Fatal(err) + } + }) + mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, "ok") + if err != nil { + t.Fatal(err) + } + }) + h := s.withCSRF(mux) + ser := httptest.NewServer(h) + defer ser.Close() + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("unable to construct cookie jar: %v", err) + } + + client := ser.Client() + client.Jar = jar + + // make GET request to populate cookie jar + resp, err := client.Get(ser.URL + "/test/csrf-token") + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + + csrfToken := strings.TrimSpace(string(tokenBytes)) + if csrfToken == "" { + t.Fatal("empty 
csrf token") + } + + // make a POST request without the CSRF header; ensure it fails + resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("unexpected status: %v", resp.Status) + } + + // make a POST request with the CSRF header; ensure it succeeds + req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil) + if err != nil { + t.Fatalf("error building request: %v", err) + } + req.Header.Set("X-CSRF-Token", csrfToken) + resp, err = client.Do(req) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + if string(out) != "ok" { + t.Fatalf("unexpected body: %q", out) + } +}