From acb4a22dcc1d535ee85c76bb25767eb76d83edd5 Mon Sep 17 00:00:00 2001 From: Jonathan Nobels Date: Thu, 10 Oct 2024 14:34:14 -0400 Subject: [PATCH 001/179] VERSION.txt: this is v1.77.0 (#13779) --- VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION.txt b/VERSION.txt index 7c7053aa2388a..79e15fd49370a 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.75.0 +1.77.0 From 33029d4486d71714bfed29c84c5f6f0da1626ec2 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 10 Oct 2024 15:52:47 -0700 Subject: [PATCH 002/179] net/netcheck: fix netcheck cli-triggered nil pointer deref (#13782) Updates #13780 Signed-off-by: Jordan Whited --- net/netcheck/netcheck.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index dbb85cf9c0945..bebf4c9b05461 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -940,7 +940,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe } } if len(need) > 0 { - if !opts.OnlyTCP443 { + if opts == nil || !opts.OnlyTCP443 { // Kick off ICMP in parallel to HTTPS checks; we don't // reuse the same WaitGroup for those probes because we // need to close the underlying Pinger after a timeout From f9949cde8bba1156aaccc189e1632bf9a1478444 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Fri, 11 Oct 2024 08:06:53 -0500 Subject: [PATCH 003/179] client/tailscale,cmd/{cli,get-authkey,k8s-operator}: set distinct User-Agents This helps better distinguish what is generating activity to the Tailscale public API. Updates tailscale/corp#23838 Signed-off-by: Percy Wegmann --- client/tailscale/tailscale.go | 17 ++++++++++------- cmd/get-authkey/main.go | 1 + cmd/k8s-operator/operator.go | 1 + cmd/tailscale/cli/up.go | 1 + tsnet/tsnet.go | 1 + 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/client/tailscale/tailscale.go b/client/tailscale/tailscale.go index 8945619653c5d..8533b47129e01 100644 --- a/client/tailscale/tailscale.go +++ b/client/tailscale/tailscale.go @@ -51,6 +51,9 @@ type Client struct { // HTTPClient optionally specifies an alternate HTTP client to use. // If nil, http.DefaultClient is used. HTTPClient *http.Client + + // UserAgent optionally specifies an alternate User-Agent header + UserAgent string } func (c *Client) httpClient() *http.Client { @@ -97,8 +100,9 @@ func (c *Client) setAuth(r *http.Request) { // and can be changed manually by the user. func NewClient(tailnet string, auth AuthMethod) *Client { return &Client{ - tailnet: tailnet, - auth: auth, + tailnet: tailnet, + auth: auth, + UserAgent: "tailscale-client-oss", } } @@ -110,17 +114,16 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") } c.setAuth(req) + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } return c.httpClient().Do(req) } // sendRequest add the authentication key to the request and sends it. It // receives the response and reads up to 10MB of it. func (c *Client) sendRequest(req *http.Request) ([]byte, *http.Response, error) { - if !I_Acknowledge_This_API_Is_Unstable { - return nil, nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") - } - c.setAuth(req) - resp, err := c.httpClient().Do(req) + resp, err := c.Do(req) if err != nil { return nil, resp, err } diff --git a/cmd/get-authkey/main.go b/cmd/get-authkey/main.go index d8030252c0081..777258d64b21b 100644 --- a/cmd/get-authkey/main.go +++ b/cmd/get-authkey/main.go @@ -51,6 +51,7 @@ func main() { ctx := context.Background() tsClient := tailscale.NewClient("-", nil) + tsClient.UserAgent = "tailscale-get-authkey" tsClient.HTTPClient = credentials.Client(ctx) tsClient.BaseURL = baseURL diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index bd9c0f7bcd5b0..d8dd403cc6097 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -143,6 +143,7 @@ func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, *tailscale.Client) { TokenURL: "https://login.tailscale.com/api/v2/oauth/token", } tsClient := tailscale.NewClient("-", nil) + tsClient.UserAgent = "tailscale-k8s-operator" tsClient.HTTPClient = credentials.Client(context.Background()) s := &tsnet.Server{ diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index e1b828105b8dd..bf6a9af773f60 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -1152,6 +1152,7 @@ func resolveAuthKey(ctx context.Context, v, tags string) (string, error) { } tsClient := tailscale.NewClient("-", nil) + tsClient.UserAgent = "tailscale-cli" tsClient.HTTPClient = credentials.Client(ctx) tsClient.BaseURL = baseURL diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 0be33ba8a5d37..6751e0bb03cbe 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -903,6 +903,7 @@ func (s *Server) APIClient() (*tailscale.Client, error) { } c := tailscale.NewClient("-", nil) + c.UserAgent = "tailscale-tsnet" c.HTTPClient = &http.Client{Transport: s.lb.KeyProvingNoiseRoundTripper()} return c, nil } From 17335d21049c724e365d4e9879286cd2fdb9aba5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Oct 2024 12:23:34 -0500 Subject: [PATCH 004/179] net/dns/resolver: forward SERVFAIL responses over PeerDNS As per the docstring, (*forwarder).forwardWithDestChan should either send to responseChan and returns nil, or returns a non-nil error (without sending to the channel). However, this does not hold when all upstream DNS servers replied with an error. We've been handling this special error path in (*Resolver).Query but not in (*Resolver).HandlePeerDNSQuery. As a result, SERVFAIL responses from upstream servers were being converted into HTTP 503 responses, instead of being properly forwarded as SERVFAIL within a successful HTTP response, as per RFC 8484, section 4.2.1: A successful HTTP response with a 2xx status code (see Section 6.3 of [RFC7231]) is used for any valid DNS response, regardless of the DNS response code. For example, a successful 2xx HTTP status code is used even with a DNS message whose DNS response code indicates failure, such as SERVFAIL or NXDOMAIN. In this PR we fix (*forwarder).forwardWithDestChan to no longer return an error when it sends a response to responseChan, and remove the special handling in (*Resolver).Query, as it is no longer necessary. Updates #13571 Signed-off-by: Nick Hill --- net/dns/resolver/forwarder.go | 1 + net/dns/resolver/tsdns.go | 10 +--------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 846ca3d5e4fb5..5920b7f29ec67 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -1053,6 +1053,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo if verboseDNSForward() { f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } + return nil } } return firstErr diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index d196ad4d6c1f0..43ba0acf194f2 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -321,15 +321,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net defer cancel() err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses) if err != nil { - select { - // Best effort: use any error response sent by forwardWithDestChan. - // This is present in some errors paths, such as when all upstream - // DNS servers replied with an error. - case resp := <-responses: - return resp.bs, err - default: - return nil, err - } + return nil, err } return (<-responses).bs, nil } From e7545f2eac48ae9f35ba4a080d6e0b6ecfd054a4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Oct 2024 12:34:41 -0500 Subject: [PATCH 005/179] net/dns/resolver: translate 5xx DoH server errors into SERVFAIL DNS responses If a DoH server returns an HTTP server error, rather than a SERVFAIL within a successful HTTP response, we should handle it in the same way as SERVFAIL. Updates #13571 Signed-off-by: Nick Hill --- net/dns/resolver/forwarder.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 5920b7f29ec67..0bf9040704237 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -487,6 +487,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, defer hres.Body.Close() if hres.StatusCode != 200 { metricDNSFwdDoHErrorStatus.Add(1) + if hres.StatusCode/100 == 5 { + // Translate 5xx HTTP server errors into SERVFAIL DNS responses. + return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status) + } return nil, errors.New(hres.Status) } if ct := hres.Header.Get("Content-Type"); ct != dohType { From c2144c44a33a373174624ede9f4f6ffe8334cf05 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Oct 2024 15:11:46 -0500 Subject: [PATCH 006/179] net/dns/resolver: update (*forwarder).forwardWithDestChan to always return an error unless it sends a response to responseChan We currently have two executions paths where (*forwarder).forwardWithDestChan returns nil, rather than an error, without sending a DNS response to responseChan. These paths are accompanied by a comment that reads: // Returning an error will cause an internal retry, there is // nothing we can do if parsing failed. Just drop the packet. But it is not (or no longer longer) accurate: returning an error from forwardWithDestChan does not currently cause a retry. Moreover, although these paths are currently unreachable due to implementation details, if (*forwarder).forwardWithDestChan were to return nil without sending a response to responseChan, it would cause a deadlock at one call site and a panic at another. Therefore, we update (*forwarder).forwardWithDestChan to return errors in those two paths and remove comments that were no longer accurate and misleading. Updates #cleanup Updates #13571 Signed-off-by: Nick Hill --- net/dns/resolver/forwarder.go | 10 ++-------- net/dns/resolver/forwarder_test.go | 17 +++++++++++------ net/dns/resolver/tsdns_test.go | 4 ++-- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 0bf9040704237..c00dea1aea8c4 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -920,10 +920,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo metricDNSFwdDropBonjour.Add(1) res, err := nxDomainResponse(query) if err != nil { - f.logf("error parsing bonjour query: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): @@ -955,10 +952,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo res, err := servfailResponse(query) if err != nil { - f.logf("building servfail response: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 09d8109018156..9c0964e93a756 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "encoding/binary" - "errors" "flag" "fmt" "io" @@ -657,14 +656,20 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - _, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, port, request, nil) if !sawRequest.Load() { t.Error("did not see DNS request") } - if err == nil { - t.Error("wanted error, got nil") - } else if !errors.Is(err, errServerFailure) { - t.Errorf("wanted errServerFailure, got: %v", err) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + var parser dns.Parser + respHeader, err := parser.Start(resp) + if err != nil { + t.Fatalf("parser.Start() failed: %v", err) + } + if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want { + t.Errorf("wanted %v, got %v", want, got) } } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index e2c4750b5c1a3..d7b9fb360eaf0 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1503,8 +1503,8 @@ func TestServfail(t *testing.T) { r.SetConfig(cfg) pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) - if !errors.Is(err, errServerFailure) { - t.Errorf("err = %v, want %v", err, errServerFailure) + if err != nil { + t.Fatalf("err = %v, want nil", err) } wantPkt := []byte{ From f07ff47922c11377374ffe91a8dbe0fa12fb1b56 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 7 Oct 2024 17:08:22 -0500 Subject: [PATCH 007/179] net/dns/resolver: add tests for using a forwarder with multiple upstream resolvers If multiple upstream DNS servers are available, quad-100 sends requests to all of them and forwards the first successful response, if any. If no successful responses are received, it propagates the first failure from any of them. This PR adds some test coverage for these scenarios. Updates #13571 Signed-off-by: Nick Khyl --- net/dns/resolver/forwarder_test.go | 235 +++++++++++++++++++++++------ 1 file changed, 190 insertions(+), 45 deletions(-) diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 9c0964e93a756..e341186ecf45e 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -449,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) return } -func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { +func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { netMon, err := netmon.New(tb.Logf) if err != nil { tb.Fatal(err) @@ -463,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa modify(fwd) } - rr := resolverAndDelay{ - name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + resolvers := make([]resolverAndDelay, len(ports)) + for i, port := range ports { + resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)} } rpkt := packet{ @@ -476,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa rchan := make(chan packet, 1) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) tb.Cleanup(cancel) - err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) + err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...) select { case res := <-rchan: return res.bs, err @@ -485,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa } } -func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { - resp, err := runTestQuery(tb, port, request, modify) +// makeTestRequest returns a new TypeA request for the given domain. +func makeTestRequest(tb testing.TB, domain string) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return request +} + +// makeTestResponse returns a new Type A response for the given domain, +// with the specified status code and zero or more addresses. +func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: code, + }) + builder.StartQuestions() + q := dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + } + builder.Question(q) + if len(addrs) > 0 { + builder.StartAnswers() + for _, addr := range addrs { + builder.AResource(dns.ResourceHeader{ + Name: q.Name, + Class: q.Class, + TTL: 120, + }, dns.AResource{ + A: addr.As4(), + }) + } + } + response, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return response +} + +func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte { + resp, err := runTestQuery(tb, request, modify, ports...) if err != nil { tb.Fatalf("error making request: %v", err) } @@ -515,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -553,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -584,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + resp := mustRunTestQuery(t, request, func(fwd *forwarder) { // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) - }) + }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) @@ -612,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) { const domain = "error-response.tailscale.com." // Our response is a SERVFAIL - response := func() []byte { - name := dns.MustNewName(domain) - - builder := dns.NewBuilder(nil, dns.Header{ - Response: true, - RCode: dns.RCodeServerFailure, - }) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - }) - response, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return response - }() + response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -656,7 +680,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - resp, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, request, nil, port) if !sawRequest.Load() { t.Error("did not see DNS request") } @@ -673,6 +697,127 @@ func TestForwarderTCPFallbackError(t *testing.T) { } } +// Test to ensure that if we have more than one resolver, and at least one of them +// returns a successful response, we propagate it. +func TestForwarderWithManyResolvers(t *testing.T) { + enableDebug(t) + + const domain = "example.com." + request := makeTestRequest(t, domain) + + tests := []struct { + name string + responses [][]byte // upstream responses + wantResponses [][]byte // we should receive one of these from the forwarder + }{ + { + name: "Success", + responses: [][]byte{ // All upstream servers returned successful, but different, response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + wantResponses: [][]byte{ // We may forward whichever response is received first. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + }, + { + name: "ServFail", + responses: [][]byte{ // All upstream servers returned a SERVFAIL. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, + { + name: "ServFail+Success", + responses: [][]byte{ // All upstream servers fail except for one. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ // We should forward the successful response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "NXDomain", + responses: [][]byte{ // All upstream servers returned NXDOMAIN. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeNameError), + }, + }, + { + name: "NXDomain+Success", + responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "Refused", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "MixFail", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports := make([]uint16, len(tt.responses)) + for i := range tt.responses { + ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) + } + gotResponse, err := runTestQuery(t, request, nil, ports...) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool { + return slices.Equal(gotResponse, wantResponse) + }) + if !responseOk { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0]) + } + }) + } +} + // mdnsResponder at minimum has an expectation that NXDOMAIN must include the // question, otherwise it will penalize our server (#13511). func TestNXDOMAINIncludesQuestion(t *testing.T) { @@ -718,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, port, request, nil) + res, err := runTestQuery(t, request, nil, port) if err != nil { t.Fatal(err) } From ecc8035f73f62424298d2a36dc2d747601fb04c8 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Fri, 11 Oct 2024 13:12:18 -0700 Subject: [PATCH 008/179] types/bools: add Compare to compare boolean values (#13792) The bools.Compare function compares boolean values by reporting -1, 0, +1 for ordering so that it can be easily used with slices.SortFunc. Updates #cleanup Updates tailscale/corp#11038 Signed-off-by: Joe Tsai --- types/bools/compare.go | 17 +++++++++++++++++ types/bools/compare_test.go | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 types/bools/compare.go create mode 100644 types/bools/compare_test.go diff --git a/types/bools/compare.go b/types/bools/compare.go new file mode 100644 index 0000000000000..ac433b240755a --- /dev/null +++ b/types/bools/compare.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package bools contains the bools.Compare function. +package bools + +// Compare compares two boolean values as if false is ordered before true. +func Compare[T ~bool](x, y T) int { + switch { + case x == false && y == true: + return -1 + case x == true && y == false: + return +1 + default: + return 0 + } +} diff --git a/types/bools/compare_test.go b/types/bools/compare_test.go new file mode 100644 index 0000000000000..280294621e719 --- /dev/null +++ b/types/bools/compare_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package bools + +import "testing" + +func TestCompare(t *testing.T) { + if got := Compare(false, false); got != 0 { + t.Errorf("Compare(false, false) = %v, want 0", got) + } + if got := Compare(false, true); got != -1 { + t.Errorf("Compare(false, true) = %v, want -1", got) + } + if got := Compare(true, false); got != +1 { + t.Errorf("Compare(true, false) = %v, want +1", got) + } + if got := Compare(true, true); got != 0 { + t.Errorf("Compare(true, true) = %v, want 0", got) + } +} From 12e6094d9c7e8f856d5117235d18ad86d0812d32 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Fri, 11 Oct 2024 14:59:47 -0500 Subject: [PATCH 009/179] ssh/tailssh: calculate passthrough environment at latest possible stage This allows passing through any environment variables that we set ourselves, for example DBUS_SESSION_BUS_ADDRESS. Updates #11175 Co-authored-by: Mario Minardi Signed-off-by: Percy Wegmann --- ssh/tailssh/incubator.go | 52 ++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 7748376b2548b..3ff676d519898 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -210,8 +210,6 @@ type incubatorArgs struct { debugTest bool isSELinuxEnforcing bool encodedEnv string - allowListEnvKeys string - forwardedEnviron []string } func parseIncubatorArgs(args []string) (incubatorArgs, error) { @@ -246,31 +244,35 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { ia.gids = append(ia.gids, gid) } - ia.forwardedEnviron = os.Environ() + return ia, nil +} + +func (ia incubatorArgs) forwadedEnviron() ([]string, string, error) { + environ := os.Environ() // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding - ia.allowListEnvKeys = "SSH_AUTH_SOCK" + allowListKeys := "SSH_AUTH_SOCK" if ia.encodedEnv != "" { unquoted, err := strconv.Unquote(ia.encodedEnv) if err != nil { - return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } var extraEnviron []string err = json.Unmarshal([]byte(unquoted), &extraEnviron) if err != nil { - return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } - ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...) + environ = append(environ, extraEnviron...) for _, v := range extraEnviron { - ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0]) + allowListKeys = fmt.Sprintf("%s,%s", allowListKeys, strings.Split(v, "=")[0]) } } - return ia, nil + return environ, allowListKeys, nil } // beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand. @@ -450,8 +452,13 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { loginArgs := ia.loginArgs(loginCmdPath) dlogf("logging in with %+v", loginArgs) + environ, _, err := ia.forwadedEnviron() + if err != nil { + return err + } + // If Exec works, the Go code will not proceed past this: - err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron) + err = unix.Exec(loginCmdPath, loginArgs, environ) // If we made it here, Exec failed. return err @@ -484,9 +491,14 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { defer sessionCloser() } + environ, allowListEnvKeys, err := ia.forwadedEnviron() + if err != nil { + return false, err + } + loginArgs := []string{ su, - "-w", ia.allowListEnvKeys, + "-w", allowListEnvKeys, "-l", ia.localUser, } @@ -498,7 +510,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { dlogf("logging in with %+v", loginArgs) // If Exec works, the Go code will not proceed past this: - err = unix.Exec(su, loginArgs, ia.forwardedEnviron) + err = unix.Exec(su, loginArgs, environ) // If we made it here, Exec failed. return true, err @@ -527,11 +539,16 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string { return "" } + _, allowListEnvKeys, err := ia.forwadedEnviron() + if err != nil { + return "" + } + // First try to execute su -w -l -c true // to make sure su supports the necessary arguments. err = exec.Command( su, - "-w", ia.allowListEnvKeys, + "-w", allowListEnvKeys, "-l", ia.localUser, "-c", "true", @@ -558,10 +575,15 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { return err } + environ, _, err := ia.forwadedEnviron() + if err != nil { + return err + } + args := shellArgs(ia.isShell, ia.cmd) dlogf("running %s %q", ia.loginShell, args) - cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args) - err := cmd.Run() + cmd := newCommand(ia.hasTTY, ia.loginShell, environ, args) + err = cmd.Run() if ee, ok := err.(*exec.ExitError); ok { ps := ee.ProcessState code := ps.ExitCode() From adc83689649f4f7e5c576ed6a697bf8c0d4bef8c Mon Sep 17 00:00:00 2001 From: Paul Scott <408401+icio@users.noreply.github.com> Date: Mon, 14 Oct 2024 10:02:04 +0100 Subject: [PATCH 010/179] tstest: avoid Fatal in ResourceCheck to show panic (#13790) Fixes #13789 Signed-off-by: Paul Scott --- tstest/resource.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tstest/resource.go b/tstest/resource.go index a3c292094fac6..b094c7911014f 100644 --- a/tstest/resource.go +++ b/tstest/resource.go @@ -29,7 +29,8 @@ func ResourceCheck(tb testing.TB) { startN, startStacks := goroutines() tb.Cleanup(func() { if tb.Failed() { - // Something else went wrong. + // Test has failed - but this doesn't catch panics due to + // https://github.com/golang/go/issues/49929. return } // Goroutines might be still exiting. @@ -44,7 +45,10 @@ func ResourceCheck(tb testing.TB) { return } tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks)) - tb.Fatalf("goroutine count: expected %d, got %d\n", startN, endN) + + // tb.Failed() above won't report on panics, so we shouldn't call Fatal + // here or we risk suppressing reporting of the panic. + tb.Errorf("goroutine count: expected %d, got %d\n", startN, endN) }) } From 40c991f6b85b6a5ff1a4b440650750e95c755f61 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 25 Sep 2024 17:20:56 +0200 Subject: [PATCH 011/179] wgengine: instrument with usermetrics Updates tailscale/corp#22075 Signed-off-by: Kristoffer Dalby --- tsnet/tsnet_test.go | 146 ++++++++++++++++++++- util/clientmetric/clientmetric.go | 51 ++++++++ util/clientmetric/clientmetric_test.go | 49 ++++++++ wgengine/magicsock/derp.go | 6 +- wgengine/magicsock/endpoint.go | 27 ++-- wgengine/magicsock/magicsock.go | 167 +++++++++++++++++++++++-- wgengine/magicsock/magicsock_test.go | 86 +++++++++++++ 7 files changed, 509 insertions(+), 23 deletions(-) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 255baf618c0b3..98c1fd4ab3462 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -36,6 +36,7 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "golang.org/x/net/proxy" + "tailscale.com/client/tailscale" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/health" "tailscale.com/ipn" @@ -874,6 +875,78 @@ func promMetricLabelsStr(labels []*dto.LabelPair) string { return b.String() } +// sendData sends a given amount of bytes from s1 to s2. +func sendData(logf func(format string, args ...any), ctx context.Context, bytesCount int, s1, s2 *Server, s1ip, s2ip netip.Addr) error { + l := must.Get(s1.Listen("tcp", fmt.Sprintf("%s:8081", s1ip))) + defer l.Close() + + // Dial to s1 from s2 + w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) + if err != nil { + return err + } + defer w.Close() + + stopReceive := make(chan struct{}) + defer close(stopReceive) + allReceived := make(chan error) + defer close(allReceived) + + go func() { + conn, err := l.Accept() + if err != nil { + allReceived <- err + return + } + conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + + total := 0 + recvStart := time.Now() + for { + got := make([]byte, bytesCount) + n, err := conn.Read(got) + if n != bytesCount { + logf("read %d bytes, want %d", n, bytesCount) + } + + select { + case <-stopReceive: + return + default: + } + + if err != nil { + allReceived <- fmt.Errorf("failed reading packet, %s", err) + return + } + + total += n + logf("received %d/%d bytes, %.2f %%", total, bytesCount, (float64(total) / (float64(bytesCount)) * 100)) + if total == bytesCount { + break + } + } + + logf("all received, took: %s", time.Since(recvStart).String()) + allReceived <- nil + }() + + sendStart := time.Now() + w.SetWriteDeadline(time.Now().Add(30 * time.Second)) + if _, err := w.Write(bytes.Repeat([]byte("A"), bytesCount)); err != nil { + stopReceive <- struct{}{} + return err + } + + logf("all sent (%s), waiting for all packets (%d) to be received", time.Since(sendStart).String(), bytesCount) + err, _ = <-allReceived + if err != nil { + return err + } + + return nil +} + func TestUserMetrics(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") tstest.ResourceCheck(t) @@ -882,7 +955,7 @@ func TestUserMetrics(t *testing.T) { controlURL, c := startControl(t) s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") - s2, _, _ := startServer(t, ctx, controlURL, "s2") + s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") s1.lb.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ @@ -951,6 +1024,20 @@ func TestUserMetrics(t *testing.T) { return status1.Self.PrimaryRoutes != nil && status1.Self.PrimaryRoutes.Len() == int(wantRoutes)+1 }) + mustDirect(t, t.Logf, lc1, lc2) + + // 10 megabytes + bytesToSend := 10 * 1024 * 1024 + + // This asserts generates some traffic, it is factored out + // of TestUDPConn. + start := time.Now() + err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) + if err != nil { + t.Fatalf("Failed to send packets: %v", err) + } + t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) + ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) defer cancelLc() metrics1, err := lc1.UserMetrics(ctxLc) @@ -968,6 +1055,9 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } + // Allow the metrics for the bytes sent to be off by 15%. + bytesSentTolerance := 1.15 + t.Logf("Metrics1:\n%s\n", metrics1) // The node is advertising 4 routes: @@ -997,6 +1087,18 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want) } + // Verify that the amount of data recorded in bytes is higher or equal to the + // 10 megabytes sent. + inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] + if inboundBytes1 < float64(bytesToSend) { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) + } + + // But ensure that it is not too much higher than the 10 megabytes sent. + if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) + } + metrics2, err := lc2.UserMetrics(ctx) if err != nil { t.Fatal(err) @@ -1033,6 +1135,18 @@ func TestUserMetrics(t *testing.T) { if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want { t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want) } + + // Verify that the amount of data recorded in bytes is higher or equal than the + // 10 megabytes sent. + outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] + if outboundBytes2 < float64(bytesToSend) { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) + } + + // But ensure that it is not too much higher than the 10 megabytes sent. + if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) + } } func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() bool) { @@ -1044,3 +1158,33 @@ func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() } t.Fatalf("waiting for condition: %s", msg) } + +// mustDirect ensures there is a direct connection between LocalClient 1 and 2 +func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *tailscale.LocalClient) { + t.Helper() + lastLog := time.Now().Add(-time.Minute) + // See https://github.com/tailscale/tailscale/issues/654 + // and https://github.com/tailscale/tailscale/issues/3247 for discussions of this deadline. + for deadline := time.Now().Add(30 * time.Second); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + status1, err := lc1.Status(ctx) + if err != nil { + continue + } + status2, err := lc2.Status(ctx) + if err != nil { + continue + } + pst := status1.Peer[status2.Self.PublicKey] + if pst.CurAddr != "" { + logf("direct link %s->%s found with addr %s", status1.Self.HostName, status2.Self.HostName, pst.CurAddr) + return + } + if now := time.Now(); now.Sub(lastLog) > time.Second { + logf("no direct path %s->%s yet, addrs %v", status1.Self.HostName, status2.Self.HostName, pst.Addrs) + lastLog = now + } + } + t.Error("magicsock did not find a direct path from lc1 to lc2") +} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index b2d356b60fcc3..584a24f73dca8 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -9,6 +9,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "expvar" "fmt" "io" "sort" @@ -16,6 +17,8 @@ import ( "sync" "sync/atomic" "time" + + "tailscale.com/util/set" ) var ( @@ -223,6 +226,54 @@ func NewGaugeFunc(name string, f func() int64) *Metric { return m } +// AggregateCounter returns a sum of expvar counters registered with it. +type AggregateCounter struct { + mu sync.RWMutex + counters set.Set[*expvar.Int] +} + +func (c *AggregateCounter) Value() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + var sum int64 + for cnt := range c.counters { + sum += cnt.Value() + } + return sum +} + +// Register registers provided expvar counter. +// When a counter is added to the counter, it will be reset +// to start counting from 0. This is to avoid incrementing the +// counter with an unexpectedly large value. +func (c *AggregateCounter) Register(counter *expvar.Int) { + c.mu.Lock() + defer c.mu.Unlock() + // No need to do anything if it's already registered. + if c.counters.Contains(counter) { + return + } + counter.Set(0) + c.counters.Add(counter) +} + +// UnregisterAll unregisters all counters resulting in it +// starting back down at zero. This is to ensure monotonicity +// and respect the semantics of the counter. +func (c *AggregateCounter) UnregisterAll() { + c.mu.Lock() + defer c.mu.Unlock() + c.counters = set.Set[*expvar.Int]{} +} + +// NewAggregateCounter returns a new aggregate counter that returns +// a sum of expvar variables registered with it. +func NewAggregateCounter(name string) *AggregateCounter { + c := &AggregateCounter{counters: set.Set[*expvar.Int]{}} + NewGaugeFunc(name, c.Value) + return c +} + // WritePrometheusExpositionFormat writes all client metrics to w in // the Prometheus text-based exposition format. // diff --git a/util/clientmetric/clientmetric_test.go b/util/clientmetric/clientmetric_test.go index ab6c4335afb41..555d7a71170a4 100644 --- a/util/clientmetric/clientmetric_test.go +++ b/util/clientmetric/clientmetric_test.go @@ -4,8 +4,11 @@ package clientmetric import ( + "expvar" "testing" "time" + + qt "github.com/frankban/quicktest" ) func TestDeltaEncBuf(t *testing.T) { @@ -107,3 +110,49 @@ func TestWithFunc(t *testing.T) { t.Errorf("second = %q; want %q", got, want) } } + +func TestAggregateCounter(t *testing.T) { + clearMetrics() + + c := qt.New(t) + + expv1 := &expvar.Int{} + expv2 := &expvar.Int{} + expv3 := &expvar.Int{} + + aggCounter := NewAggregateCounter("agg_counter") + + aggCounter.Register(expv1) + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + expv1.Add(1) + c.Assert(aggCounter.Value(), qt.Equals, int64(1)) + + aggCounter.Register(expv2) + c.Assert(aggCounter.Value(), qt.Equals, int64(1)) + + expv1.Add(1) + expv2.Add(1) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + // Adding a new expvar should not change the value + // and any value the counter already had is reset + expv3.Set(5) + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + // Registering the same expvar multiple times should not change the value + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + aggCounter.UnregisterAll() + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + // Start over + expv3.Set(5) + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + expv3.Set(5) + c.Assert(aggCounter.Value(), qt.Equals, int64(5)) +} diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 69c5cbc90a5f4..281447ac229ae 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -669,7 +669,8 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) metricSendDERPError.Add(1) } else { - metricSendDERP.Add(1) + c.metrics.outboundPacketsDERPTotal.Add(1) + c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b))) } } } @@ -690,7 +691,8 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) // No data read occurred. Wait for another packet. continue } - metricRecvDataDERP.Add(1) + c.metrics.inboundPacketsDERPTotal.Add(1) + c.metrics.inboundBytesDERPTotal.Add(int64(n)) sizes[0] = n eps[0] = ep return 1, nil diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 53ecb84de833b..78b9ee92a06f5 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -960,26 +960,39 @@ func (de *endpoint) send(buffs [][]byte) error { de.noteBadEndpoint(udpAddr) } + var txBytes int + for _, b := range buffs { + txBytes += len(b) + } + + switch { + case udpAddr.Addr().Is4(): + de.c.metrics.outboundPacketsIPv4Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesIPv4Total.Add(int64(txBytes)) + case udpAddr.Addr().Is6(): + de.c.metrics.outboundPacketsIPv6Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesIPv6Total.Add(int64(txBytes)) + } + // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. if stats := de.c.stats.Load(); err == nil && stats != nil { - var txBytes int - for _, b := range buffs { - txBytes += len(b) - } stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) } } if derpAddr.IsValid() { allOk := true + var txBytes int for _, buff := range buffs { ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) - if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buff)) - } + txBytes += len(buff) if !ok { allOk = false } } + + if stats := de.c.stats.Load(); stats != nil { + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, txBytes) + } if allOk { return nil } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 08aff842d77aa..2d4944baf6fd0 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "errors" + "expvar" "fmt" "io" "net" @@ -80,6 +81,54 @@ const ( socketBufferSize = 7 << 20 ) +// Path is a label indicating the type of path a packet took. +type Path string + +const ( + PathDirectIPv4 Path = "direct_ipv4" + PathDirectIPv6 Path = "direct_ipv6" + PathDERP Path = "derp" +) + +type pathLabel struct { + // Path indicates the path that the packet took: + // - direct_ipv4 + // - direct_ipv6 + // - derp + Path Path +} + +// metrics in wgengine contains the usermetrics counters for magicsock, it +// is however a bit special. All them metrics are labeled, but looking up +// the metric everytime we need to record it has an overhead, and includes +// a lock in MultiLabelMap. The metrics are therefore instead created with +// wgengine and the underlying expvar.Int is stored to be used directly. +type metrics struct { + // inboundPacketsTotal is the total number of inbound packets received, + // labeled by the path the packet took. + inboundPacketsIPv4Total expvar.Int + inboundPacketsIPv6Total expvar.Int + inboundPacketsDERPTotal expvar.Int + + // inboundBytesTotal is the total number of inbound bytes received, + // labeled by the path the packet took. + inboundBytesIPv4Total expvar.Int + inboundBytesIPv6Total expvar.Int + inboundBytesDERPTotal expvar.Int + + // outboundPacketsTotal is the total number of outbound packets sent, + // labeled by the path the packet took. + outboundPacketsIPv4Total expvar.Int + outboundPacketsIPv6Total expvar.Int + outboundPacketsDERPTotal expvar.Int + + // outboundBytesTotal is the total number of outbound bytes sent, + // labeled by the path the packet took. + outboundBytesIPv4Total expvar.Int + outboundBytesIPv6Total expvar.Int + outboundBytesDERPTotal expvar.Int +} + // A Conn routes UDP packets and actively manages a list of its endpoints. type Conn struct { // This block mirrors the contents and field order of the Options @@ -321,6 +370,9 @@ type Conn struct { // responsibility to ensure that traffic from these endpoints is routed // to the node. staticEndpoints views.Slice[netip.AddrPort] + + // metrics contains the metrics for the magicsock instance. + metrics *metrics } // SetDebugLoggingEnabled controls whether spammy debug logging is enabled. @@ -503,6 +555,8 @@ func NewConn(opts Options) (*Conn, error) { UseDNSCache: true, } + c.metrics = registerMetrics(opts.Metrics) + if d4, err := c.listenRawDisco("ip4"); err == nil { c.logf("[v1] using BPF disco receiver for IPv4") c.closeDisco4 = d4 @@ -520,6 +574,76 @@ func NewConn(opts Options) (*Conn, error) { return c, nil } +// registerMetrics wires up the metrics for wgengine, instead of +// registering the label metric directly, the underlying expvar is exposed. +// See metrics for more info. +func registerMetrics(reg *usermetric.Registry) *metrics { + pathDirectV4 := pathLabel{Path: PathDirectIPv4} + pathDirectV6 := pathLabel{Path: PathDirectIPv6} + pathDERP := pathLabel{Path: PathDERP} + inboundPacketsTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_inbound_packets_total", + "counter", + "Counts the number of packets received from other peers", + ) + inboundBytesTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_inbound_bytes_total", + "counter", + "Counts the number of bytes received from other peers", + ) + outboundPacketsTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_outbound_packets_total", + "counter", + "Counts the number of packets sent to other peers", + ) + outboundBytesTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_outbound_bytes_total", + "counter", + "Counts the number of bytes sent to other peers", + ) + m := new(metrics) + + // Map clientmetrics to the usermetric counters. + metricRecvDataPacketsIPv4.Register(&m.inboundPacketsIPv4Total) + metricRecvDataPacketsIPv6.Register(&m.inboundPacketsIPv6Total) + metricRecvDataPacketsDERP.Register(&m.inboundPacketsDERPTotal) + metricSendUDP.Register(&m.outboundPacketsIPv4Total) + metricSendUDP.Register(&m.outboundPacketsIPv6Total) + metricSendDERP.Register(&m.outboundPacketsDERPTotal) + + inboundPacketsTotal.Set(pathDirectV4, &m.inboundPacketsIPv4Total) + inboundPacketsTotal.Set(pathDirectV6, &m.inboundPacketsIPv6Total) + inboundPacketsTotal.Set(pathDERP, &m.inboundPacketsDERPTotal) + + inboundBytesTotal.Set(pathDirectV4, &m.inboundBytesIPv4Total) + inboundBytesTotal.Set(pathDirectV6, &m.inboundBytesIPv6Total) + inboundBytesTotal.Set(pathDERP, &m.inboundBytesDERPTotal) + + outboundPacketsTotal.Set(pathDirectV4, &m.outboundPacketsIPv4Total) + outboundPacketsTotal.Set(pathDirectV6, &m.outboundPacketsIPv6Total) + outboundPacketsTotal.Set(pathDERP, &m.outboundPacketsDERPTotal) + + outboundBytesTotal.Set(pathDirectV4, &m.outboundBytesIPv4Total) + outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total) + outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal) + + return m +} + +// deregisterMetrics unregisters the underlying usermetrics expvar counters +// from clientmetrics. +func deregisterMetrics(m *metrics) { + metricRecvDataPacketsIPv4.UnregisterAll() + metricRecvDataPacketsIPv6.UnregisterAll() + metricRecvDataPacketsDERP.UnregisterAll() + metricSendUDP.UnregisterAll() + metricSendDERP.UnregisterAll() +} + // InstallCaptureHook installs a callback which is called to // log debug information into the pcap stream. This function // can be called with a nil argument to uninstall the capture @@ -1140,7 +1264,14 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { _ = c.maybeRebindOnError(runtime.GOOS, err) } else { if sent { - metricSendUDP.Add(1) + switch { + case ipp.Addr().Is4(): + c.metrics.outboundPacketsIPv4Total.Add(1) + c.metrics.outboundBytesIPv4Total.Add(int64(len(b))) + case ipp.Addr().Is6(): + c.metrics.outboundPacketsIPv6Total.Add(1) + c.metrics.outboundBytesIPv6Total.Add(int64(len(b))) + } } } return @@ -1278,19 +1409,24 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) { c.receiveBatchPool.Put(batch) } -// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4. func (c *Conn) receiveIPv4() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn4, c.health.ReceiveFuncStats(health.ReceiveIPv4), metricRecvDataIPv4) + return c.mkReceiveFunc(&c.pconn4, c.health.ReceiveFuncStats(health.ReceiveIPv4), + &c.metrics.inboundPacketsIPv4Total, + &c.metrics.inboundBytesIPv4Total, + ) } // receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6. func (c *Conn) receiveIPv6() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn6, c.health.ReceiveFuncStats(health.ReceiveIPv6), metricRecvDataIPv6) + return c.mkReceiveFunc(&c.pconn6, c.health.ReceiveFuncStats(health.ReceiveIPv6), + &c.metrics.inboundPacketsIPv6Total, + &c.metrics.inboundBytesIPv6Total, + ) } // mkReceiveFunc creates a ReceiveFunc reading from ruc. -// The provided healthItem and metric are updated if non-nil. -func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc { +// The provided healthItem and metrics are updated if non-nil. +func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, packetMetric, bytesMetric *expvar.Int) conn.ReceiveFunc { // epCache caches an IPPort->endpoint for hot flows. var epCache ippEndpointCache @@ -1327,8 +1463,11 @@ func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFu } ipp := msg.Addr.(*net.UDPAddr).AddrPort() if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { - if metric != nil { - metric.Add(1) + if packetMetric != nil { + packetMetric.Add(1) + } + if bytesMetric != nil { + bytesMetric.Add(int64(msg.N)) } eps[i] = ep sizes[i] = msg.N @@ -2377,6 +2516,8 @@ func (c *Conn) Close() error { pinger.Close() } + deregisterMetrics(c.metrics) + return nil } @@ -2930,17 +3071,17 @@ var ( metricSendDERPErrorChan = clientmetric.NewCounter("magicsock_send_derp_error_chan") metricSendDERPErrorClosed = clientmetric.NewCounter("magicsock_send_derp_error_closed") metricSendDERPErrorQueue = clientmetric.NewCounter("magicsock_send_derp_error_queue") - metricSendUDP = clientmetric.NewCounter("magicsock_send_udp") + metricSendUDP = clientmetric.NewAggregateCounter("magicsock_send_udp") metricSendUDPError = clientmetric.NewCounter("magicsock_send_udp_error") - metricSendDERP = clientmetric.NewCounter("magicsock_send_derp") + metricSendDERP = clientmetric.NewAggregateCounter("magicsock_send_derp") metricSendDERPError = clientmetric.NewCounter("magicsock_send_derp_error") // Data packets (non-disco) metricSendData = clientmetric.NewCounter("magicsock_send_data") metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down") - metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp") - metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4") - metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6") + metricRecvDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_derp") + metricRecvDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv4") + metricRecvDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv6") // Disco packets metricSendDiscoUDP = clientmetric.NewCounter("magicsock_disco_send_udp") diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 6b2d961b9b6fd..c1b8eef223257 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -28,6 +28,7 @@ import ( "time" "unsafe" + qt "github.com/frankban/quicktest" wgconn "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun/tuntest" @@ -1188,6 +1189,91 @@ func testTwoDevicePing(t *testing.T, d *devices) { checkStats(t, m1, m1Conns) checkStats(t, m2, m2Conns) }) + t.Run("compare-metrics-stats", func(t *testing.T) { + setT(t) + defer setT(outerT) + m1.conn.resetMetricsForTest() + m1.stats.TestExtract() + m2.conn.resetMetricsForTest() + m2.stats.TestExtract() + t.Logf("Metrics before: %s\n", m1.metrics.String()) + ping1(t) + ping2(t) + assertConnStatsAndUserMetricsEqual(t, m1) + assertConnStatsAndUserMetricsEqual(t, m2) + t.Logf("Metrics after: %s\n", m1.metrics.String()) + }) +} + +func (c *Conn) resetMetricsForTest() { + c.metrics.inboundBytesIPv4Total.Set(0) + c.metrics.inboundPacketsIPv4Total.Set(0) + c.metrics.outboundBytesIPv4Total.Set(0) + c.metrics.outboundPacketsIPv4Total.Set(0) + c.metrics.inboundBytesIPv6Total.Set(0) + c.metrics.inboundPacketsIPv6Total.Set(0) + c.metrics.outboundBytesIPv6Total.Set(0) + c.metrics.outboundPacketsIPv6Total.Set(0) + c.metrics.inboundBytesDERPTotal.Set(0) + c.metrics.inboundPacketsDERPTotal.Set(0) + c.metrics.outboundBytesDERPTotal.Set(0) + c.metrics.outboundPacketsDERPTotal.Set(0) +} + +func assertConnStatsAndUserMetricsEqual(t *testing.T, ms *magicStack) { + _, phys := ms.stats.TestExtract() + + physIPv4RxBytes := int64(0) + physIPv4TxBytes := int64(0) + physDERPRxBytes := int64(0) + physDERPTxBytes := int64(0) + physIPv4RxPackets := int64(0) + physIPv4TxPackets := int64(0) + physDERPRxPackets := int64(0) + physDERPTxPackets := int64(0) + for conn, count := range phys { + t.Logf("physconn src: %s, dst: %s", conn.Src.String(), conn.Dst.String()) + if conn.Dst.String() == "127.3.3.40:1" { + physDERPRxBytes += int64(count.RxBytes) + physDERPTxBytes += int64(count.TxBytes) + physDERPRxPackets += int64(count.RxPackets) + physDERPTxPackets += int64(count.TxPackets) + } else { + physIPv4RxBytes += int64(count.RxBytes) + physIPv4TxBytes += int64(count.TxBytes) + physIPv4RxPackets += int64(count.RxPackets) + physIPv4TxPackets += int64(count.TxPackets) + } + } + + metricIPv4RxBytes := ms.conn.metrics.inboundBytesIPv4Total.Value() + metricIPv4RxPackets := ms.conn.metrics.inboundPacketsIPv4Total.Value() + metricIPv4TxBytes := ms.conn.metrics.outboundBytesIPv4Total.Value() + metricIPv4TxPackets := ms.conn.metrics.outboundPacketsIPv4Total.Value() + + metricDERPRxBytes := ms.conn.metrics.inboundBytesDERPTotal.Value() + metricDERPRxPackets := ms.conn.metrics.inboundPacketsDERPTotal.Value() + metricDERPTxBytes := ms.conn.metrics.outboundBytesDERPTotal.Value() + metricDERPTxPackets := ms.conn.metrics.outboundPacketsDERPTotal.Value() + + c := qt.New(t) + c.Assert(physDERPRxBytes, qt.Equals, metricDERPRxBytes) + c.Assert(physDERPTxBytes, qt.Equals, metricDERPTxBytes) + c.Assert(physIPv4RxBytes, qt.Equals, metricIPv4RxBytes) + c.Assert(physIPv4TxBytes, qt.Equals, metricIPv4TxBytes) + c.Assert(physDERPRxPackets, qt.Equals, metricDERPRxPackets) + c.Assert(physDERPTxPackets, qt.Equals, metricDERPTxPackets) + c.Assert(physIPv4RxPackets, qt.Equals, metricIPv4RxPackets) + c.Assert(physIPv4TxPackets, qt.Equals, metricIPv4TxPackets) + + // Validate that the usermetrics and clientmetrics are in sync + // Note: the clientmetrics are global, this means that when they are registering with the + // wgengine, multiple in-process nodes used by this test will be updating the same metrics. This is why we need to multiply + // the metrics by 2 to get the expected value. + // TODO(kradalby): https://github.com/tailscale/tailscale/issues/13420 + c.Assert(metricSendUDP.Value(), qt.Equals, metricIPv4TxPackets*2) + c.Assert(metricRecvDataPacketsIPv4.Value(), qt.Equals, metricIPv4RxPackets*2) + c.Assert(metricRecvDataPacketsDERP.Value(), qt.Equals, metricDERPRxPackets*2) } func TestDiscoMessage(t *testing.T) { From e0d711c478e335e99302a5320b43538337fa298b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 23 Sep 2024 17:07:38 +0200 Subject: [PATCH 012/179] {net/connstats,wgengine/magicsock}: fix packet counting in connstats connstats currently increments the packet counter whenever it is called to store a length of data, however when udp batch sending was introduced we pass the length for a series of packages, and it is only incremented ones, making it count wrongly if we are on a platform supporting udp batches. Updates tailscale/corp#22075 Signed-off-by: Kristoffer Dalby --- net/connstats/stats.go | 22 +++++++++++----------- wgengine/magicsock/derp.go | 2 +- wgengine/magicsock/endpoint.go | 4 ++-- wgengine/magicsock/magicsock.go | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/net/connstats/stats.go b/net/connstats/stats.go index dbcd946b82d9a..4e6d8e109aaad 100644 --- a/net/connstats/stats.go +++ b/net/connstats/stats.go @@ -131,23 +131,23 @@ func (s *Statistics) updateVirtual(b []byte, receive bool) { s.virtual[conn] = cnts } -// UpdateTxPhysical updates the counters for a transmitted wireguard packet +// UpdateTxPhysical updates the counters for zero or more transmitted wireguard packets. // The src is always a Tailscale IP address, representing some remote peer. // The dst is a remote IP address and port that corresponds // with some physical peer backing the Tailscale IP address. -func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, n int) { - s.updatePhysical(src, dst, n, false) +func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) { + s.updatePhysical(src, dst, packets, bytes, false) } -// UpdateRxPhysical updates the counters for a received wireguard packet. +// UpdateRxPhysical updates the counters for zero or more received wireguard packets. // The src is always a Tailscale IP address, representing some remote peer. // The dst is a remote IP address and port that corresponds // with some physical peer backing the Tailscale IP address. -func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, n int) { - s.updatePhysical(src, dst, n, true) +func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) { + s.updatePhysical(src, dst, packets, bytes, true) } -func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, receive bool) { +func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int, receive bool) { conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst} s.mu.Lock() @@ -157,11 +157,11 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r return } if receive { - cnts.RxPackets++ - cnts.RxBytes += uint64(n) + cnts.RxPackets += uint64(packets) + cnts.RxBytes += uint64(bytes) } else { - cnts.TxPackets++ - cnts.TxBytes += uint64(n) + cnts.TxPackets += uint64(packets) + cnts.TxBytes += uint64(bytes) } s.physical[conn] = cnts } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 281447ac229ae..bfee02f6e87da 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -730,7 +730,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en ep.noteRecvActivity(ipp, mono.Now()) if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) + stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, dm.n) } return n, ep } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 78b9ee92a06f5..ab9f3d47dd033 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -976,7 +976,7 @@ func (de *endpoint) send(buffs [][]byte) error { // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. if stats := de.c.stats.Load(); err == nil && stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) + stats.UpdateTxPhysical(de.nodeAddr, udpAddr, len(buffs), txBytes) } } if derpAddr.IsValid() { @@ -991,7 +991,7 @@ func (de *endpoint) send(buffs [][]byte) error { } if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, txBytes) + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, 1, txBytes) } if allOk { return nil diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 2d4944baf6fd0..72e59a2e72c62 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1523,7 +1523,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) ep.lastRecvUDPAny.StoreAtomic(now) ep.noteRecvActivity(ipp, now) if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b)) + stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, len(b)) } return ep, true } From a8f9c0d6e40a99e091e06572cd5e9b20db7baa21 Mon Sep 17 00:00:00 2001 From: License Updater Date: Mon, 14 Oct 2024 15:03:08 +0000 Subject: [PATCH 013/179] licenses: update license notices Signed-off-by: License Updater --- licenses/android.md | 2 -- licenses/apple.md | 3 +-- licenses/tailscale.md | 1 - licenses/windows.md | 2 +- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/licenses/android.md b/licenses/android.md index ef53117e8ceb7..94aeb3fc0615f 100644 --- a/licenses/android.md +++ b/licenses/android.md @@ -36,7 +36,6 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) @@ -57,7 +56,6 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - [github.com/tailscale/tailscale-android/libtailscale](https://pkg.go.dev/github.com/tailscale/tailscale-android/libtailscale) ([BSD-3-Clause](https://github.com/tailscale/tailscale-android/blob/HEAD/LICENSE)) diff --git a/licenses/apple.md b/licenses/apple.md index 4cb100c625942..751082d5b220f 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -63,7 +63,6 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) @@ -77,7 +76,7 @@ See also the dependencies in the [Tailscale CLI][]. - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) diff --git a/licenses/tailscale.md b/licenses/tailscale.md index 544aa91cecab1..b1303d2a6dd8e 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -80,7 +80,6 @@ Some packages may only be included on certain architectures or operating systems - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/d3fa0460f47e/LICENSE.md)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE)) diff --git a/licenses/windows.md b/licenses/windows.md index e7f7f6f13ca08..2a8e4e621a4a6 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -70,7 +70,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE)) - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE)) - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) From 5f22f726365851acfb189bfedf436ac34ef42782 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Tue, 15 Oct 2024 19:38:11 +0100 Subject: [PATCH 014/179] hostinfo,build_docker.sh,tailcfg: more reliably detect being in a container (#13826) Our existing container-detection tricks did not work on Kubernetes, where Docker is no longer used as a container runtime. Extends the existing go build tags for containers to the other container packages and uses that to reliably detect builds that were created by Tailscale for use in a container. Unfortunately this doesn't necessarily improve detection for users' custom builds, but that's a separate issue. Updates #13825 Signed-off-by: Tom Proctor --- build_docker.sh | 2 ++ hostinfo/hostinfo.go | 13 +++++++++++-- hostinfo/hostinfo_container_linux_test.go | 16 ++++++++++++++++ hostinfo/hostinfo_linux_test.go | 8 +++++++- tailcfg/tailcfg.go | 2 +- 5 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 hostinfo/hostinfo_container_linux_test.go diff --git a/build_docker.sh b/build_docker.sh index 1cbdc4b9ef8e8..e8b1c8f28f450 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -56,6 +56,7 @@ case "$TARGET" in -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ --base="${BASE}" \ --tags="${TAGS}" \ + --gotags="ts_kube,ts_package_container" \ --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ @@ -72,6 +73,7 @@ case "$TARGET" in -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ --base="${BASE}" \ --tags="${TAGS}" \ + --gotags="ts_kube,ts_package_container" \ --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index 1f9037829d82d..3233a422dd6c3 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -280,13 +280,22 @@ func getEnvType() EnvType { return "" } -// inContainer reports whether we're running in a container. +// inContainer reports whether we're running in a container. Best-effort only, +// there's no foolproof way to detect this, but the build tag should catch all +// official builds from 1.78.0. func inContainer() opt.Bool { if runtime.GOOS != "linux" { return "" } var ret opt.Bool ret.Set(false) + if packageType != nil && packageType() == "container" { + // Go build tag ts_package_container was set during build. + ret.Set(true) + return ret + } + // Only set if using docker's container runtime. Not guaranteed by + // documentation, but it's been in place for a long time. if _, err := os.Stat("/.dockerenv"); err == nil { ret.Set(true) return ret @@ -362,7 +371,7 @@ func inFlyDotIo() bool { } func inReplit() bool { - // https://docs.replit.com/programming-ide/getting-repl-metadata + // https://docs.replit.com/replit-workspace/configuring-repl#environment-variables if os.Getenv("REPL_OWNER") != "" && os.Getenv("REPL_SLUG") != "" { return true } diff --git a/hostinfo/hostinfo_container_linux_test.go b/hostinfo/hostinfo_container_linux_test.go new file mode 100644 index 0000000000000..594a5f5120a6a --- /dev/null +++ b/hostinfo/hostinfo_container_linux_test.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && ts_package_container + +package hostinfo + +import ( + "testing" +) + +func TestInContainer(t *testing.T) { + if got := inContainer(); !got.EqualBool(true) { + t.Errorf("inContainer = %v; want true due to ts_package_container build tag", got) + } +} diff --git a/hostinfo/hostinfo_linux_test.go b/hostinfo/hostinfo_linux_test.go index 4859167a270ec..c8bd2abbeb230 100644 --- a/hostinfo/hostinfo_linux_test.go +++ b/hostinfo/hostinfo_linux_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux && !android +//go:build linux && !android && !ts_package_container package hostinfo @@ -34,3 +34,9 @@ remotes/origin/QTSFW_5.0.0` t.Errorf("got %q; want %q", got, want) } } + +func TestInContainer(t *testing.T) { + if got := inContainer(); !got.EqualBool(false) { + t.Errorf("inContainer = %v; want false due to absence of ts_package_container build tag", got) + } +} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index df50a860311d1..92bf2cd95da15 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -771,7 +771,7 @@ type Hostinfo struct { // "5.10.0-17-amd64". OSVersion string `json:",omitempty"` - Container opt.Bool `json:",omitempty"` // whether the client is running in a container + Container opt.Bool `json:",omitempty"` // best-effort whether the client is running in a container Env string `json:",omitempty"` // a hostinfo.EnvType in string form Distro string `json:",omitempty"` // "debian", "ubuntu", "nixos", ... DistroVersion string `json:",omitempty"` // "20.04", ... From 2aa9125ac438ffa902158b5bedf9791c93117b9b Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Tue, 15 Oct 2024 16:18:04 -0400 Subject: [PATCH 015/179] cmd/derpprobe: add /healthz endpoint For a customer that wants to run their own DERP prober, let's add a /healthz endpoint that can be used to monitor derpprobe itself. Updates #6526 Signed-off-by: Andrew Dunham Change-Id: Iba315c999fc0b1a93d8c503c07cc733b4c8d5b6b --- cmd/derpprobe/derpprobe.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 1d0ec32c3c064..5b7b77091de7f 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -75,6 +75,11 @@ func main() { prober.WithPageLink("Prober metrics", "/debug/varz"), prober.WithProbeLink("Run Probe", "/debug/probe-run?name={{.Name}}"), ), tsweb.HandlerOptions{Logf: log.Printf})) + mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok\n")) + })) log.Printf("Listening on %s", *listen) log.Fatal(http.ListenAndServe(*listen, mux)) } From ff5f233c3a43fa20a61ba4a76a2f3f5a75f8d437 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 7 Oct 2024 21:18:45 -0500 Subject: [PATCH 016/179] util/syspolicy: add rsop package that provides access to the resultant policy In this PR we add syspolicy/rsop package that facilitates policy source registration and provides access to the resultant policy merged from all registered sources for a given scope. Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/internal/internal.go | 3 + util/syspolicy/rsop/change_callbacks.go | 107 ++ util/syspolicy/rsop/resultant_policy.go | 449 +++++++++ util/syspolicy/rsop/resultant_policy_test.go | 986 +++++++++++++++++++ util/syspolicy/rsop/rsop.go | 174 ++++ util/syspolicy/rsop/store_registration.go | 94 ++ util/syspolicy/setting/policy_scope.go | 3 + util/syspolicy/setting/setting.go | 3 + util/syspolicy/source/test_store.go | 33 +- 9 files changed, 1834 insertions(+), 18 deletions(-) create mode 100644 util/syspolicy/rsop/change_callbacks.go create mode 100644 util/syspolicy/rsop/resultant_policy.go create mode 100644 util/syspolicy/rsop/resultant_policy_test.go create mode 100644 util/syspolicy/rsop/rsop.go create mode 100644 util/syspolicy/rsop/store_registration.go diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go index 4c3e28d3914bb..8f28896259abf 100644 --- a/util/syspolicy/internal/internal.go +++ b/util/syspolicy/internal/internal.go @@ -13,6 +13,9 @@ import ( "tailscale.com/version" ) +// Init facilitates deferred invocation of initializers. +var Init lazy.DeferredInit + // OSForTesting is the operating system override used for testing. // It follows the same naming convention as [version.OS]. var OSForTesting lazy.SyncValue[string] diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go new file mode 100644 index 0000000000000..b962f30c008c1 --- /dev/null +++ b/util/syspolicy/rsop/change_callbacks.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "reflect" + "slices" + "sync" + "time" + + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" +) + +// Change represents a change from the Old to the New value of type T. +type Change[T any] struct { + New, Old T +} + +// PolicyChangeCallback is a function called whenever a policy changes. +type PolicyChangeCallback func(*PolicyChange) + +// PolicyChange describes a policy change. +type PolicyChange struct { + snapshots Change[*setting.Snapshot] +} + +// New returns the [setting.Snapshot] after the change. +func (c PolicyChange) New() *setting.Snapshot { + return c.snapshots.New +} + +// Old returns the [setting.Snapshot] before the change. +func (c PolicyChange) Old() *setting.Snapshot { + return c.snapshots.Old +} + +// HasChanged reports whether a policy setting with the specified [setting.Key], has changed. +func (c PolicyChange) HasChanged(key setting.Key) bool { + new, newErr := c.snapshots.New.GetErr(key) + old, oldErr := c.snapshots.Old.GetErr(key) + if newErr != nil && oldErr != nil { + return false + } + if newErr != nil || oldErr != nil { + return true + } + switch newVal := new.(type) { + case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration: + return newVal != old + case []string: + oldVal, ok := old.([]string) + return !ok || !slices.Equal(newVal, oldVal) + default: + loggerx.Errorf("[unexpected] %q has an unsupported value type: %T", key, newVal) + return !reflect.DeepEqual(new, old) + } +} + +// policyChangeCallbacks are the callbacks to invoke when the effective policy changes. +// It is safe for concurrent use. +type policyChangeCallbacks struct { + mu sync.Mutex + cbs set.HandleSet[PolicyChangeCallback] +} + +// Register adds the specified callback to be invoked whenever the policy changes. +func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) { + c.mu.Lock() + handle := c.cbs.Add(callback) + c.mu.Unlock() + return func() { + c.mu.Lock() + delete(c.cbs, handle) + c.mu.Unlock() + } +} + +// Invoke calls the registered callback functions with the specified policy change info. +func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) { + var wg sync.WaitGroup + defer wg.Wait() + + c.mu.Lock() + defer c.mu.Unlock() + + wg.Add(len(c.cbs)) + change := &PolicyChange{snapshots: snapshots} + for _, cb := range c.cbs { + go func() { + defer wg.Done() + cb(change) + }() + } +} + +// Close awaits the completion of active callbacks and prevents any further invocations. +func (c *policyChangeCallbacks) Close() { + c.mu.Lock() + defer c.mu.Unlock() + if c.cbs != nil { + clear(c.cbs) + c.cbs = nil + } +} diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go new file mode 100644 index 0000000000000..019b8f602f86d --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy.go @@ -0,0 +1,449 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "fmt" + "slices" + "sync" + "sync/atomic" + "time" + + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +// ErrPolicyClosed is returned by [Policy.Reload], [Policy.addSource], +// [Policy.removeSource] and [Policy.replaceSource] if the policy has been closed. +var ErrPolicyClosed = errors.New("effective policy closed") + +// The minimum and maximum wait times after detecting a policy change +// before reloading the policy. This only affects policy reloads triggered +// by a change in the underlying [source.Store] and does not impact +// synchronous, caller-initiated reloads, such as when [Policy.Reload] is called. +// +// Policy changes occurring within [policyReloadMinDelay] of each other +// will be batched together, resulting in a single policy reload +// no later than [policyReloadMaxDelay] after the first detected change. +// In other words, the effective policy will be reloaded no more often than once +// every 5 seconds, but at most 15 seconds after an underlying [source.Store] +// has issued a policy change callback. +// +// See [Policy.watchReload]. +var ( + policyReloadMinDelay = 5 * time.Second + policyReloadMaxDelay = 15 * time.Second +) + +// Policy provides access to the current effective [setting.Snapshot] for a given +// scope and allows to reload it from the underlying [source.Store] list. It also allows to +// subscribe and receive a callback whenever the effective [setting.Snapshot] is changed. +// +// It is safe for concurrent use. +type Policy struct { + scope setting.PolicyScope + + reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required + closeCh chan struct{} // closed to signal that the Policy is being closed + doneCh chan struct{} // closed by [Policy.closeInternal] + + // effective is the most recent version of the [setting.Snapshot] + // containing policy settings merged from all applicable sources. + effective atomic.Pointer[setting.Snapshot] + + changeCallbacks policyChangeCallbacks + + mu sync.Mutex + watcherStarted bool // whether [Policy.watchReload] was started + sources source.ReadableSources + closing bool // whether [Policy.Close] was called (even if we're still closing) +} + +// newPolicy returns a new [Policy] for the specified [setting.PolicyScope] +// that tracks changes and merges policy settings read from the specified sources. +func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (_ *Policy, err error) { + readableSources := make(source.ReadableSources, 0, len(sources)) + defer func() { + if err != nil { + readableSources.Close() + } + }() + for _, s := range sources { + reader, err := s.Reader() + if err != nil { + return nil, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return nil, fmt.Errorf("failed to open a reading session: %w", err) + } + readableSources = append(readableSources, source.ReadableSource{Source: s, ReadingSession: session}) + } + + // Sort policy sources by their precedence from lower to higher. + // For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}. + readableSources.StableSort() + + p := &Policy{ + scope: scope, + sources: readableSources, + reloadCh: make(chan reloadRequest, 1), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + if _, err := p.reloadNow(false); err != nil { + p.Close() + return nil, err + } + p.startWatchReloadIfNeeded() + return p, nil +} + +// IsValid reports whether p is in a valid state and has not been closed. +// +// Since p's state can be changed by other goroutines at any time, this should +// only be used as an optimization. +func (p *Policy) IsValid() bool { + select { + case <-p.closeCh: + return false + default: + return true + } +} + +// Scope returns the [setting.PolicyScope] that this policy applies to. +func (p *Policy) Scope() setting.PolicyScope { + return p.scope +} + +// Get returns the effective [setting.Snapshot]. +func (p *Policy) Get() *setting.Snapshot { + return p.effective.Load() +} + +// RegisterChangeCallback adds a function to be called whenever the effective +// policy changes. The returned function can be used to unregister the callback. +func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) { + return p.changeCallbacks.Register(callback) +} + +// Reload synchronously re-reads policy settings from the underlying list of policy sources, +// constructing a new merged [setting.Snapshot] even if the policy remains unchanged. +// In most scenarios, there's no need to re-read the policy manually. +// Instead, it is recommended to register a policy change callback, or to use +// the most recent [setting.Snapshot] returned by the [Policy.Get] method. +// +// It must not be called with p.mu held. +func (p *Policy) Reload() (*setting.Snapshot, error) { + return p.reload(true) +} + +// reload is like Reload, but allows to specify whether to re-read policy settings +// from unchanged policy sources. +// +// It must not be called with p.mu held. +func (p *Policy) reload(force bool) (*setting.Snapshot, error) { + if !p.startWatchReloadIfNeeded() { + return p.Get(), nil + } + + respCh := make(chan reloadResponse, 1) + select { + case p.reloadCh <- reloadRequest{force: force, respCh: respCh}: + // continue + case <-p.closeCh: + return nil, ErrPolicyClosed + } + select { + case resp := <-respCh: + return resp.policy, resp.err + case <-p.closeCh: + return nil, ErrPolicyClosed + } +} + +// reloadAsync requests an asynchronous background policy reload. +// The policy will be reloaded no later than in [policyReloadMaxDelay]. +// +// It must not be called with p.mu held. +func (p *Policy) reloadAsync() { + if !p.startWatchReloadIfNeeded() { + return + } + select { + case p.reloadCh <- reloadRequest{}: + // Sent. + default: + // A reload request is already en route. + } +} + +// reloadNow loads and merges policies from all sources, updating the effective policy. +// If the force parameter is true, it forcibly reloads policies +// from the underlying policy store, even if no policy changes were detected. +// +// Except for the initial policy reload during the [Policy] creation, +// this method should only be called from the [Policy.watchReload] goroutine. +func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) { + new, err := p.readAndMerge(force) + if err != nil { + return nil, err + } + old := p.effective.Swap(new) + // A nil old value indicates the initial policy load rather than a policy change. + // Additionally, we should not invoke the policy change callbacks unless the + // policy items have actually changed. + if old != nil && !old.EqualItems(new) { + snapshots := Change[*setting.Snapshot]{New: new, Old: old} + p.changeCallbacks.Invoke(snapshots) + } + return new, nil +} + +// Done returns a channel that is closed when the [Policy] is closed. +func (p *Policy) Done() <-chan struct{} { + return p.doneCh +} + +// readAndMerge reads and merges policy settings from all applicable sources, +// returning a [setting.Snapshot] with the merged result. +// If the force parameter is true, it re-reads policy settings from each source +// even if no policy change was observed, and returns an error if the read +// operation fails. +func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) { + p.mu.Lock() + defer p.mu.Unlock() + // Start with an empty policy in the target scope. + effective := setting.NewSnapshot(nil, setting.SummaryWith(p.scope)) + // Then merge policy settings from all sources. + // Policy sources with the highest precedence (e.g., the device policy) are merged last, + // overriding any conflicting policy settings with lower precedence. + for _, s := range p.sources { + var policy *setting.Snapshot + if force { + var err error + if policy, err = s.ReadSettings(); err != nil { + return nil, err + } + } else { + policy = s.GetSettings() + } + effective = setting.MergeSnapshots(effective, policy) + } + return effective, nil +} + +// addSource adds the specified source to the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) addSource(source *source.Source) error { + return p.applySourcesChange(source, nil) +} + +// removeSource removes the specified source from the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error if the +// effective policy is being closed, or if policy refresh fails with an error. +func (p *Policy) removeSource(source *source.Source) error { + return p.applySourcesChange(nil, source) +} + +// replaceSource replaces the old source with the new source atomically, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) replaceSource(old, new *source.Source) error { + return p.applySourcesChange(new, old) +} + +func (p *Policy) applySourcesChange(toAdd, toRemove *source.Source) error { + if toAdd == toRemove { + return nil + } + if toAdd != nil && !toAdd.Scope().Contains(p.scope) { + return errors.New("scope mismatch") + } + + changed, err := func() (changed bool, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if toAdd != nil && !p.sources.Contains(toAdd) { + reader, err := toAdd.Reader() + if err != nil { + return false, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return false, fmt.Errorf("failed to open a reading session: %w", err) + } + + addAt := p.sources.InsertionIndexOf(toAdd) + toAdd := source.ReadableSource{ + Source: toAdd, + ReadingSession: session, + } + p.sources = slices.Insert(p.sources, addAt, toAdd) + go p.watchPolicyChanges(toAdd) + changed = true + } + if toRemove != nil { + if deleteAt := p.sources.IndexOf(toRemove); deleteAt != -1 { + p.sources.DeleteAt(deleteAt) + changed = true + } + } + return changed, nil + }() + if changed { + _, err = p.reload(false) + } + return err // may be nil or non-nil +} + +func (p *Policy) watchPolicyChanges(s source.ReadableSource) { + for { + select { + case _, ok := <-s.ReadingSession.PolicyChanged(): + if !ok { + p.mu.Lock() + abruptlyClosed := slices.Contains(p.sources, s) + p.mu.Unlock() + if abruptlyClosed { + // The underlying [source.Source] was closed abruptly without + // being properly removed or replaced by another policy source. + // We can't keep this [Policy] up to date, so we should close it. + p.Close() + } + return + } + // The PolicyChanged channel was signaled. + // Request an asynchronous policy reload. + p.reloadAsync() + case <-p.closeCh: + // The [Policy] is being closed. + return + } + } +} + +// startWatchReloadIfNeeded starts [Policy.watchReload] in a new goroutine +// if the list of policy sources is not empty, it hasn't been started yet, +// and the [Policy] is not being closed. +// It reports whether [Policy.watchReload] has ever been started. +// +// It must not be called with p.mu held. +func (p *Policy) startWatchReloadIfNeeded() bool { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.sources) != 0 && !p.watcherStarted && !p.closing { + go p.watchReload() + for i := range p.sources { + go p.watchPolicyChanges(p.sources[i]) + } + p.watcherStarted = true + } + return p.watcherStarted +} + +// reloadRequest describes a policy reload request. +type reloadRequest struct { + // force policy reload regardless of whether a policy change was detected. + force bool + // respCh is an optional channel. If non-nil, it makes the reload request + // synchronous and receives the result. + respCh chan<- reloadResponse +} + +// reloadResponse is a result of a synchronous policy reload. +type reloadResponse struct { + policy *setting.Snapshot + err error +} + +// watchReload processes incoming synchronous and asynchronous policy reload requests. +// +// Synchronous requests (with a non-nil respCh) are served immediately. +// +// Asynchronous requests are debounced and throttled: they are executed at least +// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay] +// after the first request in a batch. +func (p *Policy) watchReload() { + defer p.closeInternal() + + force := false // whether a forced refresh was requested + var delayCh, timeoutCh <-chan time.Time + reload := func(respCh chan<- reloadResponse) { + delayCh, timeoutCh = nil, nil + policy, err := p.reloadNow(force) + if err != nil { + loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err) + } + if respCh != nil { + respCh <- reloadResponse{policy: policy, err: err} + } + force = false + } + +loop: + for { + select { + case req := <-p.reloadCh: + if req.force { + force = true + } + if req.respCh != nil { + reload(req.respCh) + continue + } + if delayCh == nil { + timeoutCh = time.After(policyReloadMinDelay) + } + delayCh = time.After(policyReloadMaxDelay) + case <-delayCh: + reload(nil) + case <-timeoutCh: + reload(nil) + case <-p.closeCh: + break loop + } + } +} + +func (p *Policy) closeInternal() { + p.mu.Lock() + defer p.mu.Unlock() + p.sources.Close() + p.changeCallbacks.Close() + close(p.doneCh) + deletePolicy(p) +} + +// Close initiates the closing of the policy. +// The [Policy.Done] channel is closed to signal that the operation has been completed. +func (p *Policy) Close() { + p.mu.Lock() + alreadyClosing := p.closing + watcherStarted := p.watcherStarted + p.closing = true + p.mu.Unlock() + + if alreadyClosing { + return + } + + close(p.closeCh) + if !watcherStarted { + // Normally, closing p.closeCh signals [Policy.watchReload] to exit, + // and [Policy.closeInternal] performs the actual closing when + // [Policy.watchReload] returns. However, if the watcher was never + // started, we need to call [Policy.closeInternal] manually. + go p.closeInternal() + } +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go new file mode 100644 index 0000000000000..b2408c7f71519 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -0,0 +1,986 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "slices" + "sort" + "strconv" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstest" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +func TestGetEffectivePolicyNoSource(t *testing.T) { + tests := []struct { + name string + scope setting.PolicyScope + }{ + { + name: "DevicePolicy", + scope: setting.DeviceScope, + }, + { + name: "CurrentProfilePolicy", + scope: setting.CurrentProfileScope, + }, + { + name: "CurrentUserPolicy", + scope: setting.CurrentUserScope, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var policy *Policy + t.Cleanup(func() { + if policy != nil { + policy.Close() + <-policy.Done() + } + }) + + // Make sure we don't create any goroutines. + // We intentionally call ResourceCheck after t.Cleanup, so that when the test exits, + // the resource check runs before the test cleanup closes the policy. + // This helps to report any unexpectedly created goroutines. + // The goal is to ensure that using the syspolicy package, and particularly + // the rsop sub-package, is not wasteful and does not create unnecessary goroutines + // on platforms without registered policy sources. + tstest.ResourceCheck(t) + + policy, err := PolicyFor(tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + if got := policy.Get(); got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + + if got, err := policy.Reload(); err != nil { + t.Errorf("Reload failed: %v", err) + } else if got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + }) + } +} + +func TestRegisterSourceAndGetEffectivePolicy(t *testing.T) { + type sourceConfig struct { + name string + scope setting.PolicyScope + settingKey setting.Key + settingValue string + wantEffective bool + } + tests := []struct { + name string + scope setting.PolicyScope + initialSources []sourceConfig + additionalSources []sourceConfig + wantSnapshot *setting.Snapshot + }{ + { + name: "DevicePolicy/NoSources", + scope: setting.DeviceScope, + wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope), + }, + { + name: "UserScope/NoSources", + scope: setting.CurrentUserScope, + wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/OneInitialSource", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/OneAdditionalSource", + scope: setting.DeviceScope, + additionalSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/ManyInitialSources/NoConflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/ManyInitialSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/MixedSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceD", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueD", + wantEffective: true, + }, + { + name: "TestSourceE", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueE", + wantEffective: true, + }, + { + name: "TestSourceF", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueF", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "UserScope/Init-DeviceSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)), + }, setting.CurrentUserScope), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceProfile", + scope: setting.CurrentProfileScope, + settingKey: "TestKeyB", + settingValue: "ProfileValue", + wantEffective: true, + }, + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)), + }, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/User-Source-does-not-apply", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyA", + settingValue: "UserValue", + wantEffective: false, // Registering a user source should have no impact on the device policy. + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Register all settings that we use in this test. + var definitions []*setting.Definition + for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) { + definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue)) + } + if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Add the initial policy sources. + var wantSources []*source.Source + for _, s := range tt.initialSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + // Retrieve the effective policy. + policy, err := policyForTest(t, tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + checkPolicySources(t, policy, wantSources) + + // Add additional setting sources. + for _, s := range tt.additionalSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register additional policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + checkPolicySources(t, policy, wantSources) + + // Verify the final effective settings snapshots. + if got := policy.Get(); !got.Equal(tt.wantSnapshot) { + t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot) + } + }) + } +} + +func TestPolicyFor(t *testing.T) { + tests := []struct { + name string + scopeA, scopeB setting.PolicyScope + closePolicy bool // indicates whether to close policyA before retrieving policyB + wantSame bool // specifies whether policyA and policyB should reference the same [Policy] instance + }{ + { + name: "Device/Device", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + wantSame: true, + }, + { + name: "Device/CurrentProfile", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentProfileScope, + wantSame: false, + }, + { + name: "Device/CurrentUser", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentProfile/CurrentProfile", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentProfileScope, + wantSame: true, + }, + { + name: "CurrentProfile/CurrentUser", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentUser/CurrentUser", + scopeA: setting.CurrentUserScope, + scopeB: setting.CurrentUserScope, + wantSame: true, + }, + { + name: "UserA/UserA", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserA"), + wantSame: true, + }, + { + name: "UserA/UserB", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserB"), + wantSame: false, + }, + { + name: "New-after-close", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + closePolicy: true, + wantSame: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policyA, err := policyForTest(t, tt.scopeA) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeA, err) + } + + if tt.closePolicy { + policyA.Close() + } + + policyB, err := policyForTest(t, tt.scopeB) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeB, err) + } + + if gotSame := policyA == policyB; gotSame != tt.wantSame { + t.Fatalf("Got same: %v; want same %v", gotSame, tt.wantSame) + } + }) + } +} + +func TestPolicyChangeHasChanged(t *testing.T) { + tests := []struct { + name string + old, new map[setting.Key]setting.RawItem + wantChanged []setting.Key + wantUnchanged []setting.Key + }{ + { + name: "String-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("Old"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("New"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "UInt64-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(0)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(1)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "StringSlice-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"Chicago"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"New York"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "Int8-Settings", // We don't have actual int8 settings, but this should still work. + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(0)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(1)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := setting.NewSnapshot(tt.old) + new := setting.NewSnapshot(tt.new) + change := PolicyChange{Change[*setting.Snapshot]{old, new}} + for _, wantChanged := range tt.wantChanged { + if !change.HasChanged(wantChanged) { + t.Errorf("%q changed: got false; want true", wantChanged) + } + } + for _, wantUnchanged := range tt.wantUnchanged { + if change.HasChanged(wantUnchanged) { + t.Errorf("%q unchanged: got true; want false", wantUnchanged) + } + } + }) + } +} + +func TestChangePolicySetting(t *testing.T) { + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register a test policy store and create a effective policy that reads the policy settings from it. + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // The policy setting is not configured yet. + if _, ok := policy.Get().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + + // Subscribe to the policy change callback... + policyChanged := make(chan *PolicyChange) + unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // ...make the change, and measure the time between initiating the change + // and receiving the callback. + start := time.Now() + const wantValueA = "TestValueA" + store.SetStrings(source.TestSettingOf(settingA.Key(), wantValueA)) + change := <-policyChanged + gotDelay := time.Since(start) + + // Ensure there is at least a [policyReloadMinDelay] delay between + // a change and the policy reload along with the callback invocation. + // This prevents reloading policy settings too frequently + // when multiple settings change within a short period of time. + if gotDelay < policyReloadMinDelay { + t.Errorf("Delay: got %v; want >= %v", gotDelay, policyReloadMinDelay) + } + + // Verify that the [PolicyChange] passed to the policy change callback + // contains the correct information regarding the policy setting changes. + if !change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q has not changed", settingA.Key()) + } + if change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingB.Key()) + } + if _, ok := change.Old().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + if gotValue := change.New().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // And also verify that the current (most recent) [setting.Snapshot] + // includes the change we just made. + if gotValue := policy.Get().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // Now, let's change another policy setting value N times. + const N = 10 + wantValueB := strconv.Itoa(N) + start = time.Now() + for i := range N { + store.SetStrings(source.TestSettingOf(settingB.Key(), strconv.Itoa(i+1))) + } + + // The callback should be invoked only once, even though the policy setting + // has changed N times. + change = <-policyChanged + gotDelay = time.Since(start) + gotCallbacks := 1 +drain: + for { + select { + case <-policyChanged: + gotCallbacks++ + case <-time.After(policyReloadMaxDelay): + break drain + } + } + if wantCallbacks := 1; gotCallbacks > wantCallbacks { + t.Errorf("Callbacks: got %d; want %d", gotCallbacks, wantCallbacks) + } + + // Additionally, the policy change callback should be received no sooner + // than [policyReloadMinDelay] and no later than [policyReloadMaxDelay]. + if gotDelay < policyReloadMinDelay || gotDelay > policyReloadMaxDelay { + t.Errorf("Delay: got %v; want >= %v && <= %v", gotDelay, policyReloadMinDelay, policyReloadMaxDelay) + } + + // Verify that the [PolicyChange] received via the callback + // contains the final policy setting value. + if !change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q has not changed", settingB.Key()) + } + if change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingA.Key()) + } + if _, ok := change.Old().GetSetting(settingB.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingB.Key()) + } + if gotValue := change.New().Get(settingB.Key()); gotValue != wantValueB { + t.Errorf("Policy setting %q: got %q; want %q", settingB.Key(), gotValue, wantValueB) + } + + // Lastly, if a policy store issues a change notification, but the effective policy + // remains unchanged, the [Policy] should ignore it without invoking the change callbacks. + store.NotifyPolicyChanged() + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + case <-time.After(policyReloadMaxDelay): + } +} + +func TestClosePolicySource(t *testing.T) { + testSetting := setting.NewDefinition("TestSetting", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + wantSettingValue := "TestValue" + store := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), wantSettingValue)) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + initialSnapshot, err := policy.Reload() + if err != nil { + t.Fatalf("Failed to reload policy: %v", err) + } + if gotSettingValue, err := initialSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + store.Close() + + // Closing a policy source abruptly without removing it first should invalidate and close the policy. + <-policy.Done() + if policy.IsValid() { + t.Fatal("The policy was not properly closed") + } + + // The resulting policy snapshot should remain valid and unchanged. + finalSnapshot := policy.Get() + if !finalSnapshot.Equal(initialSnapshot) { + t.Fatal("Policy snapshot has changed") + } + if gotSettingValue, err := finalSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get final %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + // However, any further requests to reload the policy should fail. + if _, err := policy.Reload(); err == nil || !errors.Is(err, ErrPolicyClosed) { + t.Fatalf("Reload: gotErr: %v; wantErr: %v", err, ErrPolicyClosed) + } +} + +func TestRemovePolicySource(t *testing.T) { + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register two policy stores. + storeA := source.NewTestStoreOf(t, source.TestSettingOf(settingA.Key(), "A")) + storeRegA, err := RegisterStoreForTest(t, "TestSourceA", setting.DeviceScope, storeA) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + storeB := source.NewTestStoreOf(t, source.TestSettingOf(settingB.Key(), "B")) + storeRegB, err := RegisterStoreForTest(t, "TestSourceB", setting.DeviceScope, storeB) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + + // Create a effective [Policy] that reads policy settings from the two stores. + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the [Policy] uses both stores and includes policy settings from each. + if gotSources, wantSources := len(policy.sources), 2; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), "A"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store A and verify that the effective policy remains valid. + // It should no longer use the removed store or include any policy settings from it. + if err := storeRegA.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store A: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 1; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), any(nil); got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store B and verify that the effective policy is still valid. + // However, it should be empty since there are no associated policy sources. + if err := storeRegB.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store B: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 0; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got := policy.Get(); got.Len() != 0 { + t.Fatalf("Settings: got %v; want {Empty}", got) + } +} + +func TestReplacePolicySource(t *testing.T) { + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + // Register policy settings used in this test. + testSetting := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Create two policy stores. + initialStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "InitialValue")) + newStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + unchangedStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + + // Register the initial store and create a effective [Policy] that reads policy settings from it. + reg, err := RegisterStoreForTest(t, "TestStore", setting.DeviceScope, initialStore) + if err != nil { + t.Fatalf("Failed to register the initial store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the test setting has its initial value. + if got, want := policy.Get().Get(testSetting.Key()), "InitialValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Subscribe to the policy change callback. + policyChanged := make(chan *PolicyChange, 1) + unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // Now, let's replace the initial store with the new store. + reg, err = reg.ReplaceStore(newStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + // We should receive a policy change notification as the setting value has changed. + <-policyChanged + + // Verify that the test setting has the new value. + if got, want := policy.Get().Get(testSetting.Key()), "NewValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Replacing a policy store with an identical one containing the same + // values for the same settings should not be considered a policy change. + reg, err = reg.ReplaceStore(unchangedStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + default: + <-time.After(policyReloadMaxDelay) + } +} + +func TestAddClosedPolicySource(t *testing.T) { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + store.Close() + + _, err := policyForTest(t, setting.DeviceScope) + if err == nil || !errors.Is(err, source.ErrStoreClosed) { + t.Fatalf("got: %v; want: %v", err, source.ErrStoreClosed) + } +} + +func TestClosePolicyMoreThanOnce(t *testing.T) { + tests := []struct { + name string + numSources int + }{ + { + name: "NoSources", + numSources: 0, + }, + { + name: "OneSource", + numSources: 1, + }, + { + name: "ManySources", + numSources: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := range tt.numSources { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource #"+strconv.Itoa(i), setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + } + + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("failed to get effective policy: %v", err) + } + + const N = 10000 + var wg sync.WaitGroup + for range N { + wg.Add(1) + go func() { + wg.Done() + policy.Close() + <-policy.Done() + }() + } + wg.Wait() + }) + } +} + +func checkPolicySources(tb testing.TB, gotPolicy *Policy, wantSources []*source.Source) { + tb.Helper() + sort.SliceStable(wantSources, func(i, j int) bool { + return wantSources[i].Compare(wantSources[j]) < 0 + }) + gotSources := make([]*source.Source, len(gotPolicy.sources)) + for i := range gotPolicy.sources { + gotSources[i] = gotPolicy.sources[i].Source + } + type sourceSummary struct{ Name, Scope string } + toSourceSummary := cmp.Transformer("source", func(s *source.Source) sourceSummary { return sourceSummary{s.Name(), s.Scope().String()} }) + if diff := cmp.Diff(wantSources, gotSources, toSourceSummary, cmpopts.EquateEmpty()); diff != "" { + tb.Errorf("Policy Sources mismatch: %v", diff) + } +} + +// policyForTest is like [PolicyFor], but it deletes the policy +// when tb and all its subtests complete. +func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { + tb.Helper() + + policy, err := PolicyFor(target) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + policy.Close() + <-policy.Done() + deletePolicy(policy) + }) + return policy, nil +} + +func setForTest[T any](tb testing.TB, target *T, newValue T) { + oldValue := *target + tb.Cleanup(func() { *target = oldValue }) + *target = newValue +} diff --git a/util/syspolicy/rsop/rsop.go b/util/syspolicy/rsop/rsop.go new file mode 100644 index 0000000000000..429b9b10121b3 --- /dev/null +++ b/util/syspolicy/rsop/rsop.go @@ -0,0 +1,174 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rsop facilitates [source.Store] registration via [RegisterStore] +// and provides access to the effective policy merged from all registered sources +// via [PolicyFor]. +package rsop + +import ( + "errors" + "fmt" + "slices" + "sync" + + "tailscale.com/syncs" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" +) + +var ( + policyMu sync.Mutex // protects [policySources] and [effectivePolicies] + policySources []*source.Source // all registered policy sources + effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor] + + // effectivePolicyLRU is an LRU cache of [Policy] by [setting.Scope]. + // Although there could be multiple [setting.PolicyScope] instances with the same [setting.Scope], + // such as two user scopes for different users, there is only one [setting.DeviceScope], only one + // [setting.CurrentProfileScope], and in most cases, only one active user scope. + // Therefore, cache misses that require falling back to [effectivePolicies] are extremely rare. + // It's a fixed-size array of atomic values and can be accessed without [policyMu] held. + effectivePolicyLRU [setting.NumScopes]syncs.AtomicValue[*Policy] +) + +// PolicyFor returns the [Policy] for the specified scope, +// creating it from the registered [source.Store]s if it doesn't already exist. +func PolicyFor(scope setting.PolicyScope) (*Policy, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } + policy := effectivePolicyLRU[scope.Kind()].Load() + if policy != nil && policy.Scope() == scope && policy.IsValid() { + return policy, nil + } + return policyForSlow(scope) +} + +func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) { + defer func() { + // Always update the LRU cache on exit if we found (or created) + // a policy for the specified scope. + if policy != nil { + effectivePolicyLRU[scope.Kind()].Store(policy) + } + }() + + policyMu.Lock() + defer policyMu.Unlock() + if policy, ok := findPolicyByScopeLocked(scope); ok { + return policy, nil + } + + // If there is no existing effective policy for the specified scope, + // we need to create one using the policy sources registered for that scope. + sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool { + return source.Scope().Contains(scope) + }) + policy, err = newPolicy(scope, sources...) + if err != nil { + return nil, err + } + effectivePolicies = append(effectivePolicies, policy) + return policy, nil +} + +// findPolicyByScopeLocked returns a policy with the specified scope and true if +// one exists in the [effectivePolicies] list, otherwise it returns nil, false. +// [policyMu] must be held. +func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) { + for _, policy := range effectivePolicies { + if policy.Scope() == target && policy.IsValid() { + return policy, true + } + } + return nil, false +} + +// deletePolicy deletes the specified effective policy from [effectivePolicies] +// and [effectivePolicyLRU]. +func deletePolicy(policy *Policy) { + policyMu.Lock() + defer policyMu.Unlock() + if i := slices.Index(effectivePolicies, policy); i != -1 { + effectivePolicies = slices.Delete(effectivePolicies, i, i+1) + } + effectivePolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil) +} + +// registerSource registers the specified [source.Source] to be used by the package. +// It updates existing [Policy]s returned by [PolicyFor] to use this source if +// they are within the source's [setting.PolicyScope]. +func registerSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + if slices.Contains(policySources, source) { + // already registered + return nil + } + policySources = append(policySources, source) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + // Policy settings in the specified source do not apply + // to the scope of this effective policy. + // For example, a user policy source is being registered + // while the effective policy is for the device (or another user). + return nil + } + return policy.addSource(source) + }) +} + +// replaceSource is like [unregisterSource](old) followed by [registerSource](new), +// but performed atomically: the effective policy will contain settings +// either from the old source or the new source, never both and never neither. +func replaceSource(old, new *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + oldIndex := slices.Index(policySources, old) + if oldIndex == -1 { + return fmt.Errorf("the source is not registered: %v", old) + } + policySources[oldIndex] = new + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !old.Scope().Contains(policy.Scope()) || !new.Scope().Contains(policy.Scope()) { + return nil + } + return policy.replaceSource(old, new) + }) +} + +// unregisterSource unregisters the specified [source.Source], +// so that it won't be used by any new or existing [Policy]. +func unregisterSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + index := slices.Index(policySources, source) + if index == -1 { + return nil + } + policySources = slices.Delete(policySources, index, index+1) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + return nil + } + return policy.removeSource(source) + }) +} + +// forEachEffectivePolicyLocked calls fn for every non-closed [Policy] in [effectivePolicies]. +// It accumulates the returned errors and returns an error that wraps all errors returned by fn. +// The [policyMu] mutex must be held while this function is executed. +func forEachEffectivePolicyLocked(fn func(p *Policy) error) error { + var errs []error + for _, policy := range effectivePolicies { + if policy.IsValid() { + err := fn(policy) + if err != nil && !errors.Is(err, ErrPolicyClosed) { + errs = append(errs, err) + } + } + } + return errors.Join(errs...) +} diff --git a/util/syspolicy/rsop/store_registration.go b/util/syspolicy/rsop/store_registration.go new file mode 100644 index 0000000000000..09c83e98804ca --- /dev/null +++ b/util/syspolicy/rsop/store_registration.go @@ -0,0 +1,94 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "sync" + "sync/atomic" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" +) + +// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore] +// or [StoreRegistration.Unregister] is called more than once. +var ErrAlreadyConsumed = errors.New("the store registration is no longer valid") + +// StoreRegistration is a [source.Store] registered for use in the specified scope. +// It can be used to unregister the store, or replace it with another one. +type StoreRegistration struct { + source *source.Source + m sync.Mutex // protects the [StoreRegistration.consumeSlow] path + consumed atomic.Bool // can be read without holding m, but must be written with m held +} + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + return newStoreRegistration(name, scope, store) +} + +// RegisterStoreForTest is like [RegisterStore], but unregisters the store when +// tb and all its subtests complete. +func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + reg, err := RegisterStore(name, scope, store) + if err == nil { + tb.Cleanup(func() { + if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) { + tb.Fatalf("Unregister failed: %v", err) + } + }) + } + return reg, err // may be nil or non-nil +} + +func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + source := source.NewSource(name, scope, store) + if err := registerSource(source); err != nil { + return nil, err + } + return &StoreRegistration{source: source}, nil +} + +// ReplaceStore replaces the registered store with the new one, +// returning a new [StoreRegistration] or an error. +func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) { + var res *StoreRegistration + err := r.consume(func() error { + newSource := source.NewSource(r.source.Name(), r.source.Scope(), new) + if err := replaceSource(r.source, newSource); err != nil { + return err + } + res = &StoreRegistration{source: newSource} + return nil + }) + return res, err +} + +// Unregister reverts the registration. +func (r *StoreRegistration) Unregister() error { + return r.consume(func() error { return unregisterSource(r.source) }) +} + +// consume invokes fn, consuming r if no error is returned. +// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call. +func (r *StoreRegistration) consume(fn func() error) (err error) { + if r.consumed.Load() { + return ErrAlreadyConsumed + } + return r.consumeSlow(fn) +} + +func (r *StoreRegistration) consumeSlow(fn func() error) (err error) { + r.m.Lock() + defer r.m.Unlock() + if r.consumed.Load() { + return ErrAlreadyConsumed + } + if err = fn(); err == nil { + r.consumed.Store(true) + } + return err // may be nil or non-nil +} diff --git a/util/syspolicy/setting/policy_scope.go b/util/syspolicy/setting/policy_scope.go index 55fa339e7e813..c2039fdda15b8 100644 --- a/util/syspolicy/setting/policy_scope.go +++ b/util/syspolicy/setting/policy_scope.go @@ -8,6 +8,7 @@ import ( "strings" "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" ) var ( @@ -35,6 +36,8 @@ type PolicyScope struct { // when querying policy settings. // It returns [DeviceScope], unless explicitly changed with [SetDefaultScope]. func DefaultScope() PolicyScope { + // Allow deferred package init functions to override the default scope. + internal.Init.Do() return lazyDefaultScope.Get(func() PolicyScope { return DeviceScope }) } diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go index 93be287b11e86..70fb0a931e250 100644 --- a/util/syspolicy/setting/setting.go +++ b/util/syspolicy/setting/setting.go @@ -243,6 +243,9 @@ func registerLocked(d *Definition) { func settingDefinitions() (DefinitionMap, error) { return definitions.GetErr(func() (DefinitionMap, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } definitionsMu.Lock() defer definitionsMu.Unlock() definitionsUsed = true diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go index bb8e164fb414a..1f19bbb4386b9 100644 --- a/util/syspolicy/source/test_store.go +++ b/util/syspolicy/source/test_store.go @@ -89,6 +89,7 @@ type TestStore struct { suspendCount int // change callback are suspended if > 0 mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. cbs set.HandleSet[func()] + closed bool readsMu sync.Mutex reads map[testReadOperation]int // how many times a policy setting was read @@ -98,24 +99,20 @@ type TestStore struct { // The tb will be used to report coding errors detected by the [TestStore]. func NewTestStore(tb internal.TB) *TestStore { m := make(map[setting.Key]any) - return &TestStore{ + store := &TestStore{ tb: tb, done: make(chan struct{}), mr: m, mw: m, } + tb.Cleanup(store.Close) + return store } // NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], // [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { - m := make(map[setting.Key]any) - store := &TestStore{ - tb: tb, - done: make(chan struct{}), - mr: m, - mw: m, - } + store := NewTestStore(tb) switch settings := any(settings).(type) { case []TestSetting[bool]: store.SetBooleans(settings...) @@ -308,7 +305,7 @@ func (s *TestStore) Resume() { s.mr = s.mw s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() case s.suspendCount < 0: s.tb.Fatal("negative suspendCount") default: @@ -333,7 +330,7 @@ func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetUInt64s sets the specified integer settings in s. @@ -352,7 +349,7 @@ func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string settings in s. @@ -371,7 +368,7 @@ func (s *TestStore) SetStrings(settings ...TestSetting[string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string list settings in s. @@ -390,7 +387,7 @@ func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Delete deletes the specified settings from s. @@ -402,7 +399,7 @@ func (s *TestStore) Delete(keys ...setting.Key) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Clear deletes all settings from s. @@ -412,10 +409,10 @@ func (s *TestStore) Clear() { clear(s.mw) s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } -func (s *TestStore) notifyPolicyChanged() { +func (s *TestStore) NotifyPolicyChanged() { s.mu.RLock() if s.suspendCount != 0 { s.mu.RUnlock() @@ -439,9 +436,9 @@ func (s *TestStore) notifyPolicyChanged() { func (s *TestStore) Close() { s.mu.Lock() defer s.mu.Unlock() - if s.done != nil { + if !s.closed { close(s.done) - s.done = nil + s.closed = true } } From 74dd24ce7173fc593f67692538a78d175b3b37c1 Mon Sep 17 00:00:00 2001 From: Christian Date: Mon, 14 Oct 2024 15:52:03 -0700 Subject: [PATCH 017/179] cmd/tsconnect, logpolicy: fixes for wasm_js.go * updates to LocalBackend require metrics to be passed in which are now initialized * os.MkdirTemp isn't supported in wasm/js so we simply return empty string for logger * adds a UDP dialer which was missing and led to the dialer being incompletely initialized Fixes #10454 and #8272 Signed-off-by: Christian --- cmd/tsconnect/wasm/wasm_js.go | 4 ++++ logpolicy/logpolicy.go | 3 +++ 2 files changed, 7 insertions(+) diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 8291ac9b4735f..c35d543aabeae 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -108,6 +108,7 @@ func newIPN(jsConfig js.Value) map[string]any { SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), HealthTracker: sys.HealthTracker(), + Metrics: sys.UserMetricsRegistry(), }) if err != nil { log.Fatal(err) @@ -128,6 +129,9 @@ func newIPN(jsConfig js.Value) map[string]any { dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return ns.DialContextTCP(ctx, dst) } + dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + return ns.DialContextUDP(ctx, dst) + } sys.NetstackRouter.Set(true) sys.Tun.Get().Start() diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 0d2af77f2d703..d657c4e9352f3 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -230,6 +230,9 @@ func LogsDir(logf logger.Logf) string { logf("logpolicy: using $STATE_DIRECTORY, %q", systemdStateDir) return systemdStateDir } + case "js": + logf("logpolicy: no logs directory in the browser") + return "" } // Default to e.g. /var/lib/tailscale or /var/db/tailscale on Unix. From 6a885dbc36edb4b2395c4df3d901f42b722d7ced Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 16 Oct 2024 09:33:21 -0700 Subject: [PATCH 018/179] wgengine/magicsock: fix CI-only test warning of missing health tracker While looking at deflaking TestTwoDevicePing/ping_1.0.0.2_via_SendPacket, there were a bunch of distracting: WARNING: (non-fatal) nil health.Tracker (being strict in CI): ... This pacifies those so it's easier to work on actually deflaking the test. Updates #11762 Updates #11874 Change-Id: I08dcb44511d4996b68d5f1ce5a2619b555a2a773 Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/magicsock_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index c1b8eef223257..7e48e1daa2604 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -176,6 +176,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen if err != nil { t.Fatalf("netmon.New: %v", err) } + ht := new(health.Tracker) var reg usermetric.Registry epCh := make(chan []tailcfg.Endpoint, 100) // arbitrary @@ -183,6 +184,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen NetMon: netMon, Metrics: ®, Logf: logf, + HealthTracker: ht, DisablePortMapper: true, TestOnlyPacketListener: l, EndpointsFunc: func(eps []tailcfg.Endpoint) { From d32d742af0632445b71befecd75b7fcbf5c68865 Mon Sep 17 00:00:00 2001 From: Mario Minardi Date: Wed, 16 Oct 2024 14:09:53 -0600 Subject: [PATCH 019/179] ipn/ipnlocal: error when trying to use exit node on unsupported platform (#13726) Adds logic to `checkExitNodePrefsLocked` to return an error when attempting to use exit nodes on a platform where this is not supported. This mirrors logic that was added to error out when trying to use `ssh` on an unsupported platform, and has very similar semantics. Fixes https://github.com/tailscale/tailscale/issues/13724 Signed-off-by: Mario Minardi --- client/web/web.go | 26 ++---------- cmd/k8s-operator/depaware.txt | 1 + cmd/tailscale/depaware.txt | 1 + cmd/tailscaled/depaware.txt | 1 + envknob/featureknob/featureknob.go | 68 ++++++++++++++++++++++++++++++ envknob/features.go | 39 ----------------- ipn/ipnlocal/local.go | 7 ++- 7 files changed, 80 insertions(+), 63 deletions(-) create mode 100644 envknob/featureknob/featureknob.go delete mode 100644 envknob/features.go diff --git a/client/web/web.go b/client/web/web.go index 04ba2d086334a..56c5c92e808bb 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -26,6 +26,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/clientupdate" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -960,37 +961,16 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { } func availableFeatures() map[string]bool { - env := hostinfo.GetEnvType() features := map[string]bool{ "advertise-exit-node": true, // available on all platforms "advertise-routes": true, // available on all platforms - "use-exit-node": canUseExitNode(env) == nil, - "ssh": envknob.CanRunTailscaleSSH() == nil, + "use-exit-node": featureknob.CanUseExitNode() == nil, + "ssh": featureknob.CanRunTailscaleSSH() == nil, "auto-update": version.IsUnstableBuild() && clientupdate.CanAutoUpdate(), } - if env == hostinfo.HomeAssistantAddOn { - // Setting SSH on Home Assistant causes trouble on startup - // (since the flag is not being passed to `tailscale up`). - // Although Tailscale SSH does work here, - // it's not terribly useful since it's running in a separate container. - features["ssh"] = false - } return features } -func canUseExitNode(env hostinfo.EnvType) error { - switch dist := distro.Get(); dist { - case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 - distro.QNAP, - distro.Unraid: - return fmt.Errorf("Tailscale exit nodes cannot be used on %s.", dist) - } - if env == hostinfo.HomeAssistantAddOn { - return errors.New("Tailscale exit nodes cannot be used on Home Assistant.") - } - return nil -} - // aclsAllowAccess returns whether tailnet ACLs (as expressed in the provided filter rules) // permit any devices to access the local web client. // This does not currently check whether a specific device can connect, just any device. diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index b77ea22ef5297..66c2c8baef5fb 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -668,6 +668,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal tailscale.com/drive from tailscale.com/client/tailscale+ tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 2c644d1be7d79..73aedc9e5e695 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -92,6 +92,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/disco from tailscale.com/derp tailscale.com/drive from tailscale.com/client/tailscale+ tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/envknob/featureknob from tailscale.com/client/web tailscale.com/health from tailscale.com/net/tlsdial+ tailscale.com/health/healthmsg from tailscale.com/cmd/tailscale/cli tailscale.com/hostinfo from tailscale.com/client/web+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 6f71a88a93217..10df37d797f4b 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -263,6 +263,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/drive/driveimpl/dirfs from tailscale.com/drive/driveimpl+ tailscale.com/drive/driveimpl/shared from tailscale.com/drive/driveimpl+ tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ diff --git a/envknob/featureknob/featureknob.go b/envknob/featureknob/featureknob.go new file mode 100644 index 0000000000000..d7af80d239782 --- /dev/null +++ b/envknob/featureknob/featureknob.go @@ -0,0 +1,68 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package featureknob provides a facility to control whether features +// can run based on either an envknob or running OS / distro. +package featureknob + +import ( + "errors" + "runtime" + + "tailscale.com/envknob" + "tailscale.com/hostinfo" + "tailscale.com/version" + "tailscale.com/version/distro" +) + +// CanRunTailscaleSSH reports whether serving a Tailscale SSH server is +// supported for the current os/distro. +func CanRunTailscaleSSH() error { + switch runtime.GOOS { + case "linux": + if distro.Get() == distro.Synology && !envknob.UseWIPCode() { + return errors.New("The Tailscale SSH server does not run on Synology.") + } + if distro.Get() == distro.QNAP && !envknob.UseWIPCode() { + return errors.New("The Tailscale SSH server does not run on QNAP.") + } + + // Setting SSH on Home Assistant causes trouble on startup + // (since the flag is not being passed to `tailscale up`). + // Although Tailscale SSH does work here, + // it's not terribly useful since it's running in a separate container. + if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn { + return errors.New("The Tailscale SSH server does not run on HomeAssistant.") + } + // otherwise okay + case "darwin": + // okay only in tailscaled mode for now. + if version.IsSandboxedMacOS() { + return errors.New("The Tailscale SSH server does not run in sandboxed Tailscale GUI builds.") + } + case "freebsd", "openbsd": + default: + return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS) + } + if !envknob.CanSSHD() { + return errors.New("The Tailscale SSH server has been administratively disabled.") + } + return nil +} + +// CanUseExitNode reports whether using an exit node is supported for the +// current os/distro. +func CanUseExitNode() error { + switch dist := distro.Get(); dist { + case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 + distro.QNAP, + distro.Unraid: + return errors.New("Tailscale exit nodes cannot be used on " + string(dist)) + } + + if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn { + return errors.New("Tailscale exit nodes cannot be used on HomeAssistant.") + } + + return nil +} diff --git a/envknob/features.go b/envknob/features.go deleted file mode 100644 index 9e5909de309f0..0000000000000 --- a/envknob/features.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package envknob - -import ( - "errors" - "runtime" - - "tailscale.com/version" - "tailscale.com/version/distro" -) - -// CanRunTailscaleSSH reports whether serving a Tailscale SSH server is -// supported for the current os/distro. -func CanRunTailscaleSSH() error { - switch runtime.GOOS { - case "linux": - if distro.Get() == distro.Synology && !UseWIPCode() { - return errors.New("The Tailscale SSH server does not run on Synology.") - } - if distro.Get() == distro.QNAP && !UseWIPCode() { - return errors.New("The Tailscale SSH server does not run on QNAP.") - } - // otherwise okay - case "darwin": - // okay only in tailscaled mode for now. - if version.IsSandboxedMacOS() { - return errors.New("The Tailscale SSH server does not run in sandboxed Tailscale GUI builds.") - } - case "freebsd", "openbsd": - default: - return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS) - } - if !CanSSHD() { - return errors.New("The Tailscale SSH server has been administratively disabled.") - } - return nil -} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 06dd84831254c..c7df4333b89ec 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -51,6 +51,7 @@ import ( "tailscale.com/doctor/routetable" "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" "tailscale.com/health" "tailscale.com/health/healthmsg" "tailscale.com/hostinfo" @@ -3484,7 +3485,7 @@ func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error { if !p.RunSSH { return nil } - if err := envknob.CanRunTailscaleSSH(); err != nil { + if err := featureknob.CanRunTailscaleSSH(); err != nil { return err } if runtime.GOOS == "linux" { @@ -3565,6 +3566,10 @@ func updateExitNodeUsageWarning(p ipn.PrefsView, state *netmon.State, healthTrac } func (b *LocalBackend) checkExitNodePrefsLocked(p *ipn.Prefs) error { + if err := featureknob.CanUseExitNode(); err != nil { + return err + } + if (p.ExitNodeIP.IsValid() || p.ExitNodeID != "") && p.AdvertisesExitNode() { return errors.New("Cannot advertise an exit node and use an exit node at the same time.") } From 22c89fcb19ea36159e232c45b4f5e91c73b9e486 Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Wed, 16 Oct 2024 19:08:06 -0400 Subject: [PATCH 020/179] cmd/tailscale,ipn,tailcfg: add `tailscale advertise` subcommand behind envknob (#13734) Signed-off-by: Naman Sood --- cmd/tailscale/cli/advertise.go | 78 ++++++++++++++++++++++++++++++++++ cmd/tailscale/cli/cli.go | 4 +- cmd/tailscale/cli/cli_test.go | 4 ++ cmd/tailscale/cli/up.go | 3 ++ ipn/ipn_clone.go | 2 + ipn/ipn_view.go | 4 ++ ipn/prefs.go | 11 +++++ ipn/prefs_test.go | 11 +++++ tailcfg/tailcfg.go | 15 +++++++ 9 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 cmd/tailscale/cli/advertise.go diff --git a/cmd/tailscale/cli/advertise.go b/cmd/tailscale/cli/advertise.go new file mode 100644 index 0000000000000..c9474c4274dd2 --- /dev/null +++ b/cmd/tailscale/cli/advertise.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "flag" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +var advertiseArgs struct { + services string // comma-separated list of services to advertise +} + +// TODO(naman): This flag may move to set.go or serve_v2.go after the WIPCode +// envknob is not needed. +var advertiseCmd = &ffcli.Command{ + Name: "advertise", + ShortUsage: "tailscale advertise --services=", + ShortHelp: "Advertise this node as a destination for a service", + Exec: runAdvertise, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("advertise") + fs.StringVar(&advertiseArgs.services, "services", "", "comma-separated services to advertise; each must start with \"svc:\" (e.g. \"svc:idp,svc:nas,svc:database\")") + return fs + })(), +} + +func maybeAdvertiseCmd() []*ffcli.Command { + if !envknob.UseWIPCode() { + return nil + } + return []*ffcli.Command{advertiseCmd} +} + +func runAdvertise(ctx context.Context, args []string) error { + if len(args) > 0 { + return flag.ErrHelp + } + + services, err := parseServiceNames(advertiseArgs.services) + if err != nil { + return err + } + + _, err = localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: services, + }, + }) + return err +} + +// parseServiceNames takes a comma-separated list of service names +// (eg. "svc:hello,svc:webserver,svc:catphotos"), splits them into +// a list and validates each service name. If valid, it returns +// the service names in a slice of strings. +func parseServiceNames(servicesArg string) ([]string, error) { + var services []string + if servicesArg != "" { + services = strings.Split(servicesArg, ",") + for _, svc := range services { + err := tailcfg.CheckServiceName(svc) + if err != nil { + return nil, fmt.Errorf("service %q: %s", svc, err) + } + } + } + return services, nil +} diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 864cf6903a6d0..de6bc2a4e5e41 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -177,7 +177,7 @@ For help on subcommands, add --help after: "tailscale status --help". This CLI is still under active development. Commands and flags will change in the future. `), - Subcommands: []*ffcli.Command{ + Subcommands: append([]*ffcli.Command{ upCmd, downCmd, setCmd, @@ -207,7 +207,7 @@ change in the future. debugCmd, driveCmd, idTokenCmd, - }, + }, maybeAdvertiseCmd()...), FlagSet: rootfs, Exec: func(ctx context.Context, args []string) error { if len(args) > 0 { diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index d103c8f7e9f5c..4b75486715731 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -946,6 +946,10 @@ func TestPrefFlagMapping(t *testing.T) { // Handled by the tailscale share subcommand, we don't want a CLI // flag for this. continue + case "AdvertiseServices": + // Handled by the tailscale advertise subcommand, we don't want a + // CLI flag for this. + continue case "InternalExitNodePrior": // Used internally by LocalBackend as part of exit node usage toggling. // No CLI flag for this. diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index bf6a9af773f60..782df407deb18 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -164,6 +164,9 @@ func defaultNetfilterMode() string { return "on" } +// upArgsT is the type of upArgs, the argument struct for `tailscale up`. +// As of 2024-10-08, upArgsT is frozen and no new arguments should be +// added to it. Add new arguments to setArgsT instead. type upArgsT struct { qr bool reset bool diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index de35b60a7927d..0e9698faf4488 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -27,6 +27,7 @@ func (src *Prefs) Clone() *Prefs { *dst = *src dst.AdvertiseTags = append(src.AdvertiseTags[:0:0], src.AdvertiseTags...) dst.AdvertiseRoutes = append(src.AdvertiseRoutes[:0:0], src.AdvertiseRoutes...) + dst.AdvertiseServices = append(src.AdvertiseServices[:0:0], src.AdvertiseServices...) if src.DriveShares != nil { dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) for i := range dst.DriveShares { @@ -61,6 +62,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index ff48b9c8975f9..83a7aebb1de43 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -85,6 +85,9 @@ func (v PrefsView) Egg() bool { return v.ж.Eg func (v PrefsView) AdvertiseRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.AdvertiseRoutes) } +func (v PrefsView) AdvertiseServices() views.Slice[string] { + return views.SliceOf(v.ж.AdvertiseServices) +} func (v PrefsView) NoSNAT() bool { return v.ж.NoSNAT } func (v PrefsView) NoStatefulFiltering() opt.Bool { return v.ж.NoStatefulFiltering } func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.ж.NetfilterMode } @@ -120,6 +123,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode diff --git a/ipn/prefs.go b/ipn/prefs.go index 5d61f0119cd23..f5406f3b732e0 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -179,6 +179,12 @@ type Prefs struct { // node. AdvertiseRoutes []netip.Prefix + // AdvertiseServices specifies the list of services that this + // node can serve as a destination for. Note that an advertised + // service must still go through the approval process from the + // control server. + AdvertiseServices []string + // NoSNAT specifies whether to source NAT traffic going to // destinations in AdvertiseRoutes. The default is to apply source // NAT, which makes the traffic appear to come from the router @@ -319,6 +325,7 @@ type MaskedPrefs struct { ForceDaemonSet bool `json:",omitempty"` EggSet bool `json:",omitempty"` AdvertiseRoutesSet bool `json:",omitempty"` + AdvertiseServicesSet bool `json:",omitempty"` NoSNATSet bool `json:",omitempty"` NoStatefulFilteringSet bool `json:",omitempty"` NetfilterModeSet bool `json:",omitempty"` @@ -527,6 +534,9 @@ func (p *Prefs) pretty(goos string) string { if len(p.AdvertiseTags) > 0 { fmt.Fprintf(&sb, "tags=%s ", strings.Join(p.AdvertiseTags, ",")) } + if len(p.AdvertiseServices) > 0 { + fmt.Fprintf(&sb, "services=%s ", strings.Join(p.AdvertiseServices, ",")) + } if goos == "linux" { fmt.Fprintf(&sb, "nf=%v ", p.NetfilterMode) } @@ -598,6 +608,7 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.ForceDaemon == p2.ForceDaemon && compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) && compareStrings(p.AdvertiseTags, p2.AdvertiseTags) && + compareStrings(p.AdvertiseServices, p2.AdvertiseServices) && p.Persist.Equals(p2.Persist) && p.ProfileName == p2.ProfileName && p.AutoUpdate.Equals(p2.AutoUpdate) && diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index dcb999ef56a64..31671c0f8e4ef 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -54,6 +54,7 @@ func TestPrefsEqual(t *testing.T) { "ForceDaemon", "Egg", "AdvertiseRoutes", + "AdvertiseServices", "NoSNAT", "NoStatefulFiltering", "NetfilterMode", @@ -330,6 +331,16 @@ func TestPrefsEqual(t *testing.T) { &Prefs{NetfilterKind: ""}, false, }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + true, + }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:amelie"}}, + false, + }, } for i, tt := range tests { got := tt.a.Equals(tt.b) diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 92bf2cd95da15..0e1b1d4aef9bc 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -651,6 +651,21 @@ func CheckTag(tag string) error { return nil } +// CheckServiceName validates svc for use as a service name. +// We only allow valid DNS labels, since the expectation is that these will be +// used as parts of domain names. +func CheckServiceName(svc string) error { + var ok bool + svc, ok = strings.CutPrefix(svc, "svc:") + if !ok { + return errors.New("services must start with 'svc:'") + } + if svc == "" { + return errors.New("service names must not be empty") + } + return dnsname.ValidLabel(svc) +} + // CheckRequestTags checks that all of h.RequestTags are valid. func (h *Hostinfo) CheckRequestTags() error { if h == nil { From fa95318a47a96acd9dafd9829bd0c8c5332ad4c4 Mon Sep 17 00:00:00 2001 From: Andrea Gottardo Date: Thu, 17 Oct 2024 15:37:10 -0700 Subject: [PATCH 021/179] tool/gocross: add support for tvOS Simulator (#13847) Updates ENG-5321 Allow gocross to build a static library for the Apple TV Simulator. Signed-off-by: Andrea Gottardo --- tool/gocross/autoflags.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tool/gocross/autoflags.go b/tool/gocross/autoflags.go index c66cab55a6770..020b19fa58446 100644 --- a/tool/gocross/autoflags.go +++ b/tool/gocross/autoflags.go @@ -146,7 +146,11 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ case env.IsSet("MACOSX_DEPLOYMENT_TARGET"): xcodeFlags = append(xcodeFlags, "-mmacosx-version-min="+env.Get("MACOSX_DEPLOYMENT_TARGET", "")) case env.IsSet("TVOS_DEPLOYMENT_TARGET"): - xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + if env.Get("TARGET_DEVICE_PLATFORM_NAME", "") == "appletvsimulator" { + xcodeFlags = append(xcodeFlags, "-mtvos-simulator-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } else { + xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } default: return nil, nil, fmt.Errorf("invoked by Xcode but couldn't figure out deployment target. Did Xcode change its envvars again?") } From c0a9895748a7d7f39577ca56b2dd25b9c0d4678e Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Thu, 17 Oct 2024 14:12:31 -0400 Subject: [PATCH 022/179] scripts/installer.sh: support DNF5 This fixes the installation on newer Fedora versions that use dnf5 as the 'dnf' binary. Updates #13828 Signed-off-by: Andrew Dunham Change-Id: I39513243c81640fab244a32b7dbb3f32071e9fce --- scripts/installer.sh | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/scripts/installer.sh b/scripts/installer.sh index 19911ee23c8a7..55315c0ce20f7 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -488,9 +488,41 @@ main() { set +x ;; dnf) + # DNF 5 has a different argument format; determine which one we have. + DNF_VERSION="3" + if dnf --version | grep -q '^dnf5 version'; then + DNF_VERSION="5" + fi + + # The 'config-manager' plugin wasn't implemented when + # DNF5 was released; detect that and use the old + # version if necessary. + if [ "$DNF_VERSION" = "5" ]; then + set -x + $SUDO dnf install -y 'dnf-command(config-manager)' && DNF_HAVE_CONFIG_MANAGER=1 || DNF_HAVE_CONFIG_MANAGER=0 + set +x + + if [ "$DNF_HAVE_CONFIG_MANAGER" != "1" ]; then + if type dnf-3 >/dev/null; then + DNF_VERSION="3" + else + echo "dnf 5 detected, but 'dnf-command(config-manager)' not available and dnf-3 not found" + exit 1 + fi + fi + fi + set -x - $SUDO dnf install -y 'dnf-command(config-manager)' - $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + if [ "$DNF_VERSION" = "3" ]; then + $SUDO dnf install -y 'dnf-command(config-manager)' + $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + elif [ "$DNF_VERSION" = "5" ]; then + # Already installed config-manager, above. + $SUDO dnf config-manager addrepo --from-repofile="https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + else + echo "unexpected: unknown dnf version $DNF_VERSION" + exit 1 + fi $SUDO dnf install -y tailscale $SUDO systemctl enable --now tailscaled set +x From 18fc093c0df7a04b9d0a396ad3b635e9f859ffa5 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 18 Oct 2024 07:47:05 -0700 Subject: [PATCH 023/179] derp: give trusted mesh peers longer write timeouts Updates tailscale/corp#24014 Change-Id: I700872be48ab337dce8e11cabef7f82b97f0422a Signed-off-by: Brad Fitzpatrick --- derp/derp_server.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/derp/derp_server.go b/derp/derp_server.go index 8c5d6e890567b..94d2263f4bd05 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -74,6 +74,7 @@ func init() { const ( perClientSendQueueDepth = 32 // packets buffered for sending writeTimeout = 2 * time.Second + privilegedWriteTimeout = 30 * time.Second // for clients with the mesh key ) // dupPolicy is a temporary (2021-08-30) mechanism to change the policy @@ -1721,7 +1722,19 @@ func (c *sclient) sendLoop(ctx context.Context) error { } func (c *sclient) setWriteDeadline() { - c.nc.SetWriteDeadline(time.Now().Add(writeTimeout)) + d := writeTimeout + if c.canMesh { + // Trusted peers get more tolerance. + // + // The "canMesh" is a bit of a misnomer; mesh peers typically run over a + // different interface for a per-region private VPC and are not + // throttled. But monitoring software elsewhere over the internet also + // use the private mesh key to subscribe to connect/disconnect events + // and might hit throttling and need more time to get the initial dump + // of connected peers. + d = privilegedWriteTimeout + } + c.nc.SetWriteDeadline(time.Now().Add(d)) } // sendKeepAlive sends a keep-alive frame, without flushing. From bb60da276468a18b5159598f09649289ad5471c3 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 18 Oct 2024 10:53:49 -0700 Subject: [PATCH 024/179] derp: add sclient write deadline timeout metric (#13831) Write timeouts can be indicative of stalled TCP streams. Understanding changes in the rate of such events can be helpful in an ops context. Updates tailscale/corp#23668 Signed-off-by: Jordan Whited --- derp/derp_server.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/derp/derp_server.go b/derp/derp_server.go index 94d2263f4bd05..2a0f1aa2a38b1 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -26,6 +26,7 @@ import ( "net" "net/http" "net/netip" + "os" "os/exec" "runtime" "strconv" @@ -142,6 +143,7 @@ type Server struct { multiForwarderCreated expvar.Int multiForwarderDeleted expvar.Int removePktForwardOther expvar.Int + sclientWriteTimeouts expvar.Int avgQueueDuration *uint64 // In milliseconds; accessed atomically tcpRtt metrics.LabelMap // histogram meshUpdateBatchSize *metrics.Histogram @@ -882,6 +884,9 @@ func (c *sclient) run(ctx context.Context) error { if errors.Is(err, context.Canceled) { c.debugLogf("sender canceled by reader exiting") } else { + if errors.Is(err, os.ErrDeadlineExceeded) { + c.s.sclientWriteTimeouts.Add(1) + } c.logf("sender failed: %v", err) } } @@ -2073,6 +2078,7 @@ func (s *Server) ExpVar() expvar.Var { m.Set("multiforwarder_created", &s.multiForwarderCreated) m.Set("multiforwarder_deleted", &s.multiForwarderDeleted) m.Set("packet_forwarder_delete_other_value", &s.removePktForwardOther) + m.Set("sclient_write_timeouts", &s.sclientWriteTimeouts) m.Set("average_queue_duration_ms", expvar.Func(func() any { return math.Float64frombits(atomic.LoadUint64(s.avgQueueDuration)) })) From 874db2173b26894b6b48de95fcb462a8c006f7e4 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Sun, 13 Oct 2024 11:36:46 -0500 Subject: [PATCH 025/179] ipn/{ipnauth,ipnlocal,ipnserver}: send the auth URL to the user who started interactive login We add the ClientID() method to the ipnauth.Actor interface and updated ipnserver.actor to implement it. This method returns a unique ID of the connected client if the actor represents one. It helps link a series of interactions initiated by the client, such as when a notification needs to be sent back to a specific session, rather than all active sessions, in response to a certain request. We also add LocalBackend.WatchNotificationsAs and LocalBackend.StartLoginInteractiveAs methods, which are like WatchNotifications and StartLoginInteractive but accept an additional parameter specifying an ipnauth.Actor who initiates the operation. We store these actor identities in watchSession.owner and LocalBackend.authActor, respectively,and implement LocalBackend.sendTo and related helper methods to enable sending notifications to watchSessions associated with actors (or, more broadly, identifiable recipients). We then use the above to change who receives the BrowseToURL notifications: - For user-initiated, interactive logins, the notification is delivered only to the user who initiated the process. If the initiating actor represents a specific connected client, the URL notification is sent back to the same LocalAPI client that called StartLoginInteractive. Otherwise, the notification is sent to all clients connected as that user. Currently, we only differentiate between users on Windows, as it is inherently a multi-user OS. - In all other cases (e.g., node key expiration), we send the notification to all connected users. Updates tailscale/corp#18342 Signed-off-by: Nick Khyl --- ipn/ipnauth/actor.go | 31 ++ ipn/ipnauth/ipnauth_notwindows.go | 4 +- ipn/ipnauth/test_actor.go | 36 ++ ipn/ipnlocal/local.go | 158 +++++++-- ipn/ipnlocal/local_test.go | 540 ++++++++++++++++++++++++++++++ ipn/ipnserver/actor.go | 23 +- ipn/localapi/localapi.go | 4 +- ipn/localapi/localapi_test.go | 19 +- 8 files changed, 762 insertions(+), 53 deletions(-) create mode 100644 ipn/ipnauth/test_actor.go diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go index db3192c9100ad..1070172688a84 100644 --- a/ipn/ipnauth/actor.go +++ b/ipn/ipnauth/actor.go @@ -4,6 +4,8 @@ package ipnauth import ( + "fmt" + "tailscale.com/ipn" ) @@ -20,6 +22,9 @@ type Actor interface { // Username returns the user name associated with the receiver, // or "" if the actor does not represent a specific user. Username() (string, error) + // ClientID returns a non-zero ClientID and true if the actor represents + // a connected LocalAPI client. Otherwise, it returns a zero value and false. + ClientID() (_ ClientID, ok bool) // IsLocalSystem reports whether the actor is the Windows' Local System account. // @@ -45,3 +50,29 @@ type ActorCloser interface { // Close releases resources associated with the receiver. Close() error } + +// ClientID is an opaque, comparable value used to identify a connected LocalAPI +// client, such as a connected Tailscale GUI or CLI. It does not necessarily +// correspond to the same [net.Conn] or any physical session. +// +// Its zero value is valid, but does not represent a specific connected client. +type ClientID struct { + v any +} + +// NoClientID is the zero value of [ClientID]. +var NoClientID ClientID + +// ClientIDFrom returns a new [ClientID] derived from the specified value. +// ClientIDs derived from equal values are equal. +func ClientIDFrom[T comparable](v T) ClientID { + return ClientID{v} +} + +// String implements [fmt.Stringer]. +func (id ClientID) String() string { + if id.v == nil { + return "(none)" + } + return fmt.Sprint(id.v) +} diff --git a/ipn/ipnauth/ipnauth_notwindows.go b/ipn/ipnauth/ipnauth_notwindows.go index 3dad8233a2198..d9d11bd0a17a1 100644 --- a/ipn/ipnauth/ipnauth_notwindows.go +++ b/ipn/ipnauth/ipnauth_notwindows.go @@ -18,7 +18,9 @@ import ( func GetConnIdentity(_ logger.Logf, c net.Conn) (ci *ConnIdentity, err error) { ci = &ConnIdentity{conn: c, notWindows: true} _, ci.isUnixSock = c.(*net.UnixConn) - ci.creds, _ = peercred.Get(c) + if ci.creds, _ = peercred.Get(c); ci.creds != nil { + ci.pid, _ = ci.creds.PID() + } return ci, nil } diff --git a/ipn/ipnauth/test_actor.go b/ipn/ipnauth/test_actor.go new file mode 100644 index 0000000000000..d38aa21968bb2 --- /dev/null +++ b/ipn/ipnauth/test_actor.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "tailscale.com/ipn" +) + +var _ Actor = (*TestActor)(nil) + +// TestActor is an [Actor] used exclusively for testing purposes. +type TestActor struct { + UID ipn.WindowsUserID // OS-specific UID of the user, if the actor represents a local Windows user + Name string // username associated with the actor, or "" + NameErr error // error to be returned by [TestActor.Username] + CID ClientID // non-zero if the actor represents a connected LocalAPI client + LocalSystem bool // whether the actor represents the special Local System account on Windows + LocalAdmin bool // whether the actor has local admin access + +} + +// UserID implements [Actor]. +func (a *TestActor) UserID() ipn.WindowsUserID { return a.UID } + +// Username implements [Actor]. +func (a *TestActor) Username() (string, error) { return a.Name, a.NameErr } + +// ClientID implements [Actor]. +func (a *TestActor) ClientID() (_ ClientID, ok bool) { return a.CID, a.CID != NoClientID } + +// IsLocalSystem implements [Actor]. +func (a *TestActor) IsLocalSystem() bool { return a.LocalSystem } + +// IsLocalAdmin implements [Actor]. +func (a *TestActor) IsLocalAdmin(operatorUID string) bool { return a.LocalAdmin } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index c7df4333b89ec..b01f3a0c0f16a 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -155,10 +155,12 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } -// watchSession represents a WatchNotifications channel +// watchSession represents a WatchNotifications channel, +// an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI), // and sessionID as required to close targeted buses. type watchSession struct { ch chan *ipn.Notify + owner ipnauth.Actor // or nil sessionID string cancel func() // call to signal that the session must be terminated } @@ -265,9 +267,9 @@ type LocalBackend struct { endpoints []tailcfg.Endpoint blocked bool keyExpired bool - authURL string // non-empty if not Running - authURLTime time.Time // when the authURL was received from the control server - interact bool // indicates whether a user requested interactive login + authURL string // non-empty if not Running + authURLTime time.Time // when the authURL was received from the control server + authActor ipnauth.Actor // an actor who called [LocalBackend.StartLoginInteractive] last, or nil egg bool prevIfState *netmon.State peerAPIServer *peerAPIServer // or nil @@ -2129,10 +2131,10 @@ func (b *LocalBackend) Start(opts ipn.Options) error { blid := b.backendLogID.String() b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) - b.sendLocked(ipn.Notify{ + b.sendToLocked(ipn.Notify{ BackendLogID: &blid, Prefs: &prefs, - }) + }, allClients) if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) { // If we know that we're either logged in or meant to be @@ -2657,10 +2659,15 @@ func applyConfigToHostinfo(hi *tailcfg.Hostinfo, c *conffile.Config) { // notifications. There is currently (2022-11-22) no mechanism provided to // detect when a message has been dropped. func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { - ch := make(chan *ipn.Notify, 128) + b.WatchNotificationsAs(ctx, nil, mask, onWatchAdded, fn) +} +// WatchNotificationsAs is like WatchNotifications but takes an [ipnauth.Actor] +// as an additional parameter. If non-nil, the specified callback is invoked +// only for notifications relevant to this actor. +func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.Actor, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { + ch := make(chan *ipn.Notify, 128) sessionID := rands.HexString(16) - origFn := fn if mask&ipn.NotifyNoPrivateKeys != 0 { fn = func(n *ipn.Notify) bool { @@ -2712,6 +2719,7 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa session := &watchSession{ ch: ch, + owner: actor, sessionID: sessionID, cancel: cancel, } @@ -2834,13 +2842,71 @@ func (b *LocalBackend) DebugPickNewDERP() error { // // b.mu must not be held. func (b *LocalBackend) send(n ipn.Notify) { + b.sendTo(n, allClients) +} + +// notificationTarget describes a notification recipient. +// A zero value is valid and indicate that the notification +// should be broadcast to all active [watchSession]s. +type notificationTarget struct { + // userID is the OS-specific UID of the target user. + // If empty, the notification is not user-specific and + // will be broadcast to all connected users. + // TODO(nickkhyl): make this field cross-platform rather + // than Windows-specific. + userID ipn.WindowsUserID + // clientID identifies a client that should be the exclusive recipient + // of the notification. A zero value indicates that notification should + // be sent to all sessions of the specified user. + clientID ipnauth.ClientID +} + +var allClients = notificationTarget{} // broadcast to all connected clients + +// toNotificationTarget returns a [notificationTarget] that matches only actors +// representing the same user as the specified actor. If the actor represents +// a specific connected client, the [ipnauth.ClientID] must also match. +// If the actor is nil, the [notificationTarget] matches all actors. +func toNotificationTarget(actor ipnauth.Actor) notificationTarget { + t := notificationTarget{} + if actor != nil { + t.userID = actor.UserID() + t.clientID, _ = actor.ClientID() + } + return t +} + +// match reports whether the specified actor should receive notifications +// targeting t. If the actor is nil, it should only receive notifications +// intended for all users. +func (t notificationTarget) match(actor ipnauth.Actor) bool { + if t == allClients { + return true + } + if actor == nil { + return false + } + if t.userID != "" && t.userID != actor.UserID() { + return false + } + if t.clientID != ipnauth.NoClientID { + clientID, ok := actor.ClientID() + if !ok || clientID != t.clientID { + return false + } + } + return true +} + +// sendTo is like [LocalBackend.send] but allows specifying a recipient. +func (b *LocalBackend) sendTo(n ipn.Notify, recipient notificationTarget) { b.mu.Lock() defer b.mu.Unlock() - b.sendLocked(n) + b.sendToLocked(n, recipient) } -// sendLocked is like send, but assumes b.mu is already held. -func (b *LocalBackend) sendLocked(n ipn.Notify) { +// sendToLocked is like [LocalBackend.sendTo], but assumes b.mu is already held. +func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) { if n.Prefs != nil { n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs)) } @@ -2854,10 +2920,12 @@ func (b *LocalBackend) sendLocked(n ipn.Notify) { } for _, sess := range b.notifyWatchers { - select { - case sess.ch <- &n: - default: - // Drop the notification if the channel is full. + if recipient.match(sess.owner) { + select { + case sess.ch <- &n: + default: + // Drop the notification if the channel is full. + } } } } @@ -2892,15 +2960,18 @@ func (b *LocalBackend) sendFileNotify() { // This method is called when a new authURL is received from the control plane, meaning that either a user // has started a new interactive login (e.g., by running `tailscale login` or clicking Login in the GUI), // or the control plane was unable to authenticate this node non-interactively (e.g., due to key expiration). -// b.interact indicates whether an interactive login is in progress. +// A non-nil b.authActor indicates that an interactive login is in progress and was initiated by the specified actor. // If url is "", it is equivalent to calling [LocalBackend.resetAuthURLLocked] with b.mu held. func (b *LocalBackend) setAuthURL(url string) { var popBrowser, keyExpired bool + var recipient ipnauth.Actor b.mu.Lock() switch { case url == "": b.resetAuthURLLocked() + b.mu.Unlock() + return case b.authURL != url: b.authURL = url b.authURLTime = b.clock.Now() @@ -2909,26 +2980,27 @@ func (b *LocalBackend) setAuthURL(url string) { popBrowser = true default: // Otherwise, only open it if the user explicitly requests interactive login. - popBrowser = b.interact + popBrowser = b.authActor != nil } keyExpired = b.keyExpired + recipient = b.authActor // or nil // Consume the StartLoginInteractive call, if any, that caused the control // plane to send us this URL. - b.interact = false + b.authActor = nil b.mu.Unlock() if popBrowser { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNow(url, keyExpired, recipient) } } -// popBrowserAuthNow shuts down the data plane and sends an auth URL -// to the connected frontend, if any. +// popBrowserAuthNow shuts down the data plane and sends the URL to the recipient's +// [watchSession]s if the recipient is non-nil; otherwise, it sends the URL to all watchSessions. // keyExpired is the value of b.keyExpired upon entry and indicates // whether the node's key has expired. // It must not be called with b.mu held. -func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { - b.logf("popBrowserAuthNow: url=%v, key-expired=%v, seamless-key-renewal=%v", url != "", keyExpired, b.seamlessRenewalEnabled()) +func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool, recipient ipnauth.Actor) { + b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) // Deconfigure the local network data plane if: // - seamless key renewal is not enabled; @@ -2937,7 +3009,7 @@ func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { b.blockEngineUpdates(true) b.stopEngineAndWait() } - b.tellClientToBrowseToURL(url) + b.tellRecipientToBrowseToURL(url, toNotificationTarget(recipient)) if b.State() == ipn.Running { b.enterState(ipn.Starting) } @@ -2978,8 +3050,13 @@ func (b *LocalBackend) validPopBrowserURL(urlStr string) bool { } func (b *LocalBackend) tellClientToBrowseToURL(url string) { + b.tellRecipientToBrowseToURL(url, allClients) +} + +// tellRecipientToBrowseToURL is like tellClientToBrowseToURL but allows specifying a recipient. +func (b *LocalBackend) tellRecipientToBrowseToURL(url string, recipient notificationTarget) { if b.validPopBrowserURL(url) { - b.send(ipn.Notify{BrowseToURL: &url}) + b.sendTo(ipn.Notify{BrowseToURL: &url}, recipient) } } @@ -3251,6 +3328,15 @@ func (b *LocalBackend) tryLookupUserName(uid string) string { // StartLoginInteractive attempts to pick up the in-progress flow where it left // off. func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { + return b.StartLoginInteractiveAs(ctx, nil) +} + +// StartLoginInteractiveAs is like StartLoginInteractive but takes an [ipnauth.Actor] +// as an additional parameter. If non-nil, the specified user is expected to complete +// the interactive login, and therefore will receive the BrowseToURL notification once +// the control plane sends us one. Otherwise, the notification will be delivered to all +// active [watchSession]s. +func (b *LocalBackend) StartLoginInteractiveAs(ctx context.Context, user ipnauth.Actor) error { b.mu.Lock() if b.cc == nil { panic("LocalBackend.assertClient: b.cc == nil") @@ -3264,17 +3350,17 @@ func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { hasValidURL := url != "" && timeSinceAuthURLCreated < ((7*24*time.Hour)-(1*time.Hour)) if !hasValidURL { // A user wants to log in interactively, but we don't have a valid authURL. - // Set a flag to indicate that interactive login is in progress, forcing - // a BrowseToURL notification once the authURL becomes available. - b.interact = true + // Remember the user who initiated the login, so that we can notify them + // once the authURL is available. + b.authActor = user } cc := b.cc b.mu.Unlock() - b.logf("StartLoginInteractive: url=%v", hasValidURL) + b.logf("StartLoginInteractiveAs(%q): url=%v", maybeUsernameOf(user), hasValidURL) if hasValidURL { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNow(url, keyExpired, user) } else { cc.Login(b.loginFlags | controlclient.LoginInteractive) } @@ -5124,7 +5210,7 @@ func (b *LocalBackend) resetControlClientLocked() controlclient.Client { func (b *LocalBackend) resetAuthURLLocked() { b.authURL = "" b.authURLTime = time.Time{} - b.interact = false + b.authActor = nil } // ResetForClientDisconnect resets the backend for GUI clients running @@ -7369,3 +7455,13 @@ func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCa } return n.HasCap(cap) } + +// maybeUsernameOf returns the actor's username if the actor +// is non-nil and its username can be resolved. +func maybeUsernameOf(actor ipnauth.Actor) string { + var username string + if actor != nil { + username, _ = actor.Username() + } + return username +} diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index b0e12d5005431..9a8fa5e02df4f 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -15,6 +15,7 @@ import ( "os" "reflect" "slices" + "strings" "sync" "testing" "time" @@ -31,6 +32,7 @@ import ( "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/store/mem" "tailscale.com/net/netcheck" "tailscale.com/net/netmon" @@ -3998,3 +4000,541 @@ func TestFillAllowedSuggestions(t *testing.T) { }) } } + +func TestNotificationTargetMatch(t *testing.T) { + tests := []struct { + name string + target notificationTarget + actor ipnauth.Actor + wantMatch bool + }{ + { + name: "AllClients/Nil", + target: allClients, + actor: nil, + wantMatch: true, + }, + { + name: "AllClients/NoUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{}, + wantMatch: true, + }, + { + name: "AllClients/WithUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.NoClientID}, + wantMatch: true, + }, + { + name: "AllClients/NoUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "AllClients/WithUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByUID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByUID/NoUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: true, + }, + { + name: "FilterByUID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, + }, + { + name: "FilterByUID/SameUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID/DifferentUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByCID/Nil", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByCID/NoUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByCID/NoUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByCID/NoUID/DifferentCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByCID/WithUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, + }, + { + name: "FilterByCID/WithUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByCID/WithUID/DifferentCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID+CID/SameUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotMatch := tt.target.match(tt.actor) + if gotMatch != tt.wantMatch { + t.Errorf("match: got %v; want %v", gotMatch, tt.wantMatch) + } + }) + } +} + +type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client + +func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend { + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + sys := new(tsd.System) + store := new(mem.Store) + sys.Set(store) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + + b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } + b.DisablePortMapperForTest() + + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + return newControl(t, opts), nil + }) + return b +} + +// notificationHandler is any function that can process (e.g., check) a notification. +// It returns whether the notification has been handled or should be passed to the next handler. +// The handler may be called from any goroutine, so it must avoid calling functions +// that are restricted to the goroutine running the test or benchmark function, +// such as [testing.common.FailNow] and [testing.common.Fatalf]. +type notificationHandler func(testing.TB, ipnauth.Actor, *ipn.Notify) bool + +// wantedNotification names a [notificationHandler] that processes a notification +// the test expects and wants to receive. The name is used to report notifications +// that haven't been received within the expected timeout. +type wantedNotification struct { + name string + cond notificationHandler +} + +// notificationWatcher observes [LocalBackend] notifications as the specified actor, +// reporting missing but expected notifications using [testing.common.Error], +// and delegating the handling of unexpected notifications to the [notificationHandler]s. +type notificationWatcher struct { + tb testing.TB + lb *LocalBackend + actor ipnauth.Actor + + mu sync.Mutex + mask ipn.NotifyWatchOpt + want []wantedNotification // notifications we want to receive + unexpected []notificationHandler // funcs that are called to check any other notifications + ctxCancel context.CancelFunc // cancels the outstanding [LocalBackend.WatchNotificationsAs] call + got []*ipn.Notify // all notifications, both wanted and unexpected, we've received so far + gotWanted []*ipn.Notify // only the expected notifications; holds nil for any notification that hasn't been received + gotWantedCh chan struct{} // closed when we have received the last wanted notification + doneCh chan struct{} // closed when [LocalBackend.WatchNotificationsAs] returns +} + +func newNotificationWatcher(tb testing.TB, lb *LocalBackend, actor ipnauth.Actor) *notificationWatcher { + return ¬ificationWatcher{tb: tb, lb: lb, actor: actor} +} + +func (w *notificationWatcher) watch(mask ipn.NotifyWatchOpt, wanted []wantedNotification, unexpected ...notificationHandler) { + w.tb.Helper() + + // Cancel any outstanding [LocalBackend.WatchNotificationsAs] calls. + w.mu.Lock() + ctxCancel := w.ctxCancel + doneCh := w.doneCh + w.mu.Unlock() + if doneCh != nil { + ctxCancel() + <-doneCh + } + + doneCh = make(chan struct{}) + gotWantedCh := make(chan struct{}) + ctx, ctxCancel := context.WithCancel(context.Background()) + w.tb.Cleanup(func() { + ctxCancel() + <-doneCh + }) + + w.mu.Lock() + w.mask = mask + w.want = wanted + w.unexpected = unexpected + w.ctxCancel = ctxCancel + w.got = nil + w.gotWanted = make([]*ipn.Notify, len(wanted)) + w.gotWantedCh = gotWantedCh + w.doneCh = doneCh + w.mu.Unlock() + + watchAddedCh := make(chan struct{}) + go func() { + defer close(doneCh) + if len(wanted) == 0 { + close(gotWantedCh) + if len(unexpected) == 0 { + close(watchAddedCh) + return + } + } + + var nextWantIdx int + w.lb.WatchNotificationsAs(ctx, w.actor, w.mask, func() { close(watchAddedCh) }, func(notify *ipn.Notify) (keepGoing bool) { + w.tb.Helper() + + w.mu.Lock() + defer w.mu.Unlock() + w.got = append(w.got, notify) + + wanted := false + for i := nextWantIdx; i < len(w.want); i++ { + if wanted = w.want[i].cond(w.tb, w.actor, notify); wanted { + w.gotWanted[i] = notify + nextWantIdx = i + 1 + break + } + } + + if wanted && nextWantIdx == len(w.want) { + close(w.gotWantedCh) + if len(w.unexpected) == 0 { + // If we have received the last wanted notification, + // and we don't have any handlers for the unexpected notifications, + // we can stop the watcher right away. + return false + } + + } + + if !wanted { + // If we've received a notification we didn't expect, + // it could either be an unwanted notification caused by a bug + // or just a miscellaneous one that's irrelevant for the current test. + // Call unexpected notification handlers, if any, to + // check and fail the test if necessary. + for _, h := range w.unexpected { + if h(w.tb, w.actor, notify) { + break + } + } + } + + return true + }) + + }() + <-watchAddedCh +} + +func (w *notificationWatcher) check() []*ipn.Notify { + w.tb.Helper() + + w.mu.Lock() + cancel := w.ctxCancel + gotWantedCh := w.gotWantedCh + checkUnexpected := len(w.unexpected) != 0 + doneCh := w.doneCh + w.mu.Unlock() + + // Wait for up to 10 seconds to receive expected notifications. + timeout := 10 * time.Second + for { + select { + case <-gotWantedCh: + if checkUnexpected { + gotWantedCh = nil + // But do not wait longer than 500ms for unexpected notifications after + // the expected notifications have been received. + timeout = 500 * time.Millisecond + continue + } + case <-doneCh: + // [LocalBackend.WatchNotificationsAs] has already returned, so no further + // notifications will be received. There's no reason to wait any longer. + case <-time.After(timeout): + } + cancel() + <-doneCh + break + } + + // Report missing notifications, if any, and log all received notifications, + // including both expected and unexpected ones. + w.mu.Lock() + defer w.mu.Unlock() + if hasMissing := slices.Contains(w.gotWanted, nil); hasMissing { + want := make([]string, len(w.want)) + got := make([]string, 0, len(w.want)) + for i, wn := range w.want { + want[i] = wn.name + if w.gotWanted[i] != nil { + got = append(got, wn.name) + } + } + w.tb.Errorf("Notifications(%s): got %q; want %q", actorDescriptionForTest(w.actor), strings.Join(got, ", "), strings.Join(want, ", ")) + for i, n := range w.got { + w.tb.Logf("%d. %v", i, n) + } + return nil + } + + return w.gotWanted +} + +func actorDescriptionForTest(actor ipnauth.Actor) string { + var parts []string + if actor != nil { + if name, _ := actor.Username(); name != "" { + parts = append(parts, name) + } + if uid := actor.UserID(); uid != "" { + parts = append(parts, string(uid)) + } + if clientID, _ := actor.ClientID(); clientID != ipnauth.NoClientID { + parts = append(parts, clientID.String()) + } + } + return fmt.Sprintf("Actor{%s}", strings.Join(parts, ", ")) +} + +func TestLoginNotifications(t *testing.T) { + const ( + enableLogging = true + controlURL = "https://localhost:1/" + loginURL = "https://localhost:1/1" + ) + + wantBrowseToURL := wantedNotification{ + name: "BrowseToURL", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil && *n.BrowseToURL != loginURL { + t.Errorf("BrowseToURL (%s): got %q; want %q", actorDescriptionForTest(actor), *n.BrowseToURL, loginURL) + return false + } + return n.BrowseToURL != nil + }, + } + unexpectedBrowseToURL := func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil { + t.Errorf("Unexpected BrowseToURL(%s): %v", actorDescriptionForTest(actor), n) + return true + } + return false + } + + tests := []struct { + name string + logInAs ipnauth.Actor + urlExpectedBy []ipnauth.Actor + urlUnexpectedBy []ipnauth.Actor + }{ + { + name: "NoObservers", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{}, // ensure that it does not panic if no one is watching + }, + { + name: "SingleUser", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}, &ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/OneWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/BothWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("456")}}, + }, + { + name: "DifferentUsers/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B"}}, + }, + { + name: "DifferentUsers/SameCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B", CID: ipnauth.ClientIDFrom("123")}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ControlURLSet: true, Prefs: ipn.Prefs{ControlURL: controlURL}}); err != nil { + t.Fatalf("(*EditPrefs).Start(): %v", err) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + sessions := make([]*notificationWatcher, 0, len(tt.urlExpectedBy)+len(tt.urlUnexpectedBy)) + for _, actor := range tt.urlExpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, []wantedNotification{wantBrowseToURL}) + sessions = append(sessions, session) + } + for _, actor := range tt.urlUnexpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, nil, unexpectedBrowseToURL) + sessions = append(sessions, session) + } + + if err := lb.StartLoginInteractiveAs(context.Background(), tt.logInAs); err != nil { + t.Fatal(err) + } + + lb.cc.(*mockControl).send(nil, loginURL, false, nil) + + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, sess := range sessions { + go func() { // check all sessions in parallel + sess.check() + wg.Done() + }() + } + wg.Wait() + }) + } +} diff --git a/ipn/ipnserver/actor.go b/ipn/ipnserver/actor.go index 761c9816cab27..63d4b183ca11d 100644 --- a/ipn/ipnserver/actor.go +++ b/ipn/ipnserver/actor.go @@ -31,6 +31,7 @@ type actor struct { logf logger.Logf ci *ipnauth.ConnIdentity + clientID ipnauth.ClientID isLocalSystem bool // whether the actor is the Windows' Local System identity. } @@ -39,7 +40,22 @@ func newActor(logf logger.Logf, c net.Conn) (*actor, error) { if err != nil { return nil, err } - return &actor{logf: logf, ci: ci, isLocalSystem: connIsLocalSystem(ci)}, nil + var clientID ipnauth.ClientID + if pid := ci.Pid(); pid != 0 { + // Derive [ipnauth.ClientID] from the PID of the connected client process. + // TODO(nickkhyl): This is transient and will be re-worked as we + // progress on tailscale/corp#18342. At minimum, we should use a 2-tuple + // (PID + StartTime) or a 3-tuple (PID + StartTime + UID) to identify + // the client process. This helps prevent security issues where a + // terminated client process's PID could be reused by a different + // process. This is not currently an issue as we allow only one user to + // connect anyway. + // Additionally, we should consider caching authentication results since + // operations like retrieving a username by SID might require network + // connectivity on domain-joined devices and/or be slow. + clientID = ipnauth.ClientIDFrom(pid) + } + return &actor{logf: logf, ci: ci, clientID: clientID, isLocalSystem: connIsLocalSystem(ci)}, nil } // IsLocalSystem implements [ipnauth.Actor]. @@ -61,6 +77,11 @@ func (a *actor) pid() int { return a.ci.Pid() } +// ClientID implements [ipnauth.Actor]. +func (a *actor) ClientID() (_ ipnauth.ClientID, ok bool) { + return a.clientID, a.clientID != ipnauth.NoClientID +} + // Username implements [ipnauth.Actor]. func (a *actor) Username() (string, error) { if a.ci == nil { diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 528304bab77d4..25ec1912131df 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -1231,7 +1231,7 @@ func (h *Handler) serveWatchIPNBus(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") ctx := r.Context() enc := json.NewEncoder(w) - h.b.WatchNotifications(ctx, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { + h.b.WatchNotificationsAs(ctx, h.Actor, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { err := enc.Encode(roNotify) if err != nil { h.logf("json.Encode: %v", err) @@ -1251,7 +1251,7 @@ func (h *Handler) serveLoginInteractive(w http.ResponseWriter, r *http.Request) http.Error(w, "want POST", http.StatusBadRequest) return } - h.b.StartLoginInteractive(r.Context()) + h.b.StartLoginInteractiveAs(r.Context(), h.Actor) w.WriteHeader(http.StatusNoContent) return } diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index fa54a1e756a7e..d89c46261815a 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -39,23 +39,6 @@ import ( "tailscale.com/wgengine" ) -var _ ipnauth.Actor = (*testActor)(nil) - -type testActor struct { - uid ipn.WindowsUserID - name string - isLocalSystem bool - isLocalAdmin bool -} - -func (u *testActor) UserID() ipn.WindowsUserID { return u.uid } - -func (u *testActor) Username() (string, error) { return u.name, nil } - -func (u *testActor) IsLocalSystem() bool { return u.isLocalSystem } - -func (u *testActor) IsLocalAdmin(operatorUID string) bool { return u.isLocalAdmin } - func TestValidHost(t *testing.T) { tests := []struct { host string @@ -207,7 +190,7 @@ func TestWhoIsArgTypes(t *testing.T) { func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { - return &Handler{Actor: &testActor{isLocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} + return &Handler{Actor: &ipnauth.TestActor{LocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} } tests := []struct { name string From 877fa504b429f662d714408397c0ed403a0eda01 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 18 Oct 2024 13:12:07 -0700 Subject: [PATCH 026/179] net/netcheck: remove arbitrary deadlines from GetReport() tests (#13832) GetReport() may have side effects when the caller enforces a deadline that is shorter than ReportTimeout. Updates #13783 Updates #13394 Signed-off-by: Jordan Whited --- net/netcheck/netcheck_test.go | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 02076f8d468e1..964014203f05d 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -38,7 +38,7 @@ func TestBasic(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() if err := c.Standalone(ctx, "127.0.0.1:0"); err != nil { @@ -117,7 +117,7 @@ func TestWorksWhenUDPBlocked(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() r, err := c.GetReport(ctx, dm, nil) @@ -872,3 +872,30 @@ func TestReportTimeouts(t *testing.T) { t.Errorf("ReportTimeout (%v) cannot be less than httpsProbeTimeout (%v)", ReportTimeout, httpsProbeTimeout) } } + +func TestNoUDPNilGetReportOpts(t *testing.T) { + blackhole, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to open blackhole STUN listener: %v", err) + } + defer blackhole.Close() + + dm := stuntest.DERPMapOf(blackhole.LocalAddr().String()) + for _, region := range dm.Regions { + for _, n := range region.Nodes { + n.STUNOnly = false // exercise ICMP & HTTPS probing + } + } + + c := newTestClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := c.GetReport(ctx, dm, nil) + if err != nil { + t.Fatal(err) + } + if r.UDP { + t.Fatal("unexpected working UDP") + } +} From e711ee5d226c3cc89790a54ffe8fbac7a20c67ed Mon Sep 17 00:00:00 2001 From: Mario Minardi Date: Fri, 18 Oct 2024 14:20:40 -0600 Subject: [PATCH 027/179] release/dist: clamp min / max version for synology package centre (#13857) Clamp the min and max version for DSM 7.0 and DSM 7.2 packages when we are building packages for the synology package centre. This change leaves packages destined for pkgs.tailscale.com with just the min version set to not break packages in the wild / our update flow. Updates https://github.com/tailscale/corp/issues/22908 Signed-off-by: Mario Minardi --- release/dist/synology/pkgs.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/release/dist/synology/pkgs.go b/release/dist/synology/pkgs.go index 7802470e167fe..ab89dbee3e19f 100644 --- a/release/dist/synology/pkgs.go +++ b/release/dist/synology/pkgs.go @@ -155,8 +155,22 @@ func (t *target) mkInfo(b *dist.Build, uncompressedSz int64) []byte { f("os_min_ver", "6.0.1-7445") f("os_max_ver", "7.0-40000") case 7: - f("os_min_ver", "7.0-40000") - f("os_max_ver", "") + if t.packageCenter { + switch t.dsmMinorVersion { + case 0: + f("os_min_ver", "7.0-40000") + f("os_max_ver", "7.2-60000") + case 2: + f("os_min_ver", "7.2-60000") + default: + panic(fmt.Sprintf("unsupported DSM major.minor version %s", t.dsmVersionString())) + } + } else { + // We do not clamp the os_max_ver currently for non-package center builds as + // the binaries for 7.0 and 7.2 are identical. + f("os_min_ver", "7.0-40000") + f("os_max_ver", "") + } default: panic(fmt.Sprintf("unsupported DSM major version %d", t.dsmMajorVersion)) } From fd77965f23a317cb6f7bc53d585ace2c771d5b48 Mon Sep 17 00:00:00 2001 From: Andrea Gottardo Date: Fri, 18 Oct 2024 17:35:46 -0700 Subject: [PATCH 028/179] net/tlsdial: call out firewalls blocking Tailscale in health warnings (#13840) Updates tailscale/tailscale#13839 Adds a new blockblame package which can detect common MITM SSL certificates used by network appliances. We use this in `tlsdial` to display a dedicated health warning when we cannot connect to control, and a network appliance MITM attack is detected. Signed-off-by: Andrea Gottardo --- cmd/derper/depaware.txt | 1 + cmd/k8s-operator/depaware.txt | 1 + cmd/tailscale/depaware.txt | 1 + cmd/tailscaled/depaware.txt | 1 + net/tlsdial/blockblame/blockblame.go | 104 ++++++++++++++++++++++ net/tlsdial/blockblame/blockblame_test.go | 54 +++++++++++ net/tlsdial/tlsdial.go | 32 ++++++- 7 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 net/tlsdial/blockblame/blockblame.go create mode 100644 net/tlsdial/blockblame/blockblame_test.go diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 417dbcfb0deb7..362b07882b268 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -113,6 +113,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/net/stunserver from tailscale.com/cmd/derper L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/derp/derphttp + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/ipn+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ tailscale.com/net/wsconn from tailscale.com/cmd/derper+ diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 66c2c8baef5fb..58a9aa472c143 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -735,6 +735,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/stun from tailscale.com/ipn/localapi+ L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 73aedc9e5e695..de534df8de397 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -121,6 +121,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/net/stun from tailscale.com/net/netcheck L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/cmd/tailscale/cli+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 10df37d797f4b..67d8489df769f 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -322,6 +322,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/stun from tailscale.com/ipn/localapi+ L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ diff --git a/net/tlsdial/blockblame/blockblame.go b/net/tlsdial/blockblame/blockblame.go new file mode 100644 index 0000000000000..57dc7a6e6d885 --- /dev/null +++ b/net/tlsdial/blockblame/blockblame.go @@ -0,0 +1,104 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package blockblame blames specific firewall manufacturers for blocking Tailscale, +// by analyzing the SSL certificate presented when attempting to connect to a remote +// server. +package blockblame + +import ( + "crypto/x509" + "strings" +) + +// VerifyCertificate checks if the given certificate c is issued by a firewall manufacturer +// that is known to block Tailscale connections. It returns true and the Manufacturer of +// the equipment if it is, or false and nil if it is not. +func VerifyCertificate(c *x509.Certificate) (m *Manufacturer, ok bool) { + for _, m := range Manufacturers { + if m.match != nil && m.match(c) { + return m, true + } + } + return nil, false +} + +// Manufacturer represents a firewall manufacturer that may be blocking Tailscale. +type Manufacturer struct { + // Name is the name of the firewall manufacturer to be + // mentioned in health warning messages, e.g. "Fortinet". + Name string + // match is a function that returns true if the given certificate looks like it might + // be issued by this manufacturer. + match matchFunc +} + +var Manufacturers = []*Manufacturer{ + { + Name: "Aruba Networks", + match: issuerContains("Aruba"), + }, + { + Name: "Cisco", + match: issuerContains("Cisco"), + }, + { + Name: "Fortinet", + match: matchAny( + issuerContains("Fortinet"), + certEmail("support@fortinet.com"), + ), + }, + { + Name: "Huawei", + match: certEmail("mobile@huawei.com"), + }, + { + Name: "Palo Alto Networks", + match: matchAny( + issuerContains("Palo Alto Networks"), + issuerContains("PAN-FW"), + ), + }, + { + Name: "Sophos", + match: issuerContains("Sophos"), + }, + { + Name: "Ubiquiti", + match: matchAny( + issuerContains("UniFi"), + issuerContains("Ubiquiti"), + ), + }, +} + +type matchFunc func(*x509.Certificate) bool + +func issuerContains(s string) matchFunc { + return func(c *x509.Certificate) bool { + return strings.Contains(strings.ToLower(c.Issuer.String()), strings.ToLower(s)) + } +} + +func certEmail(v string) matchFunc { + return func(c *x509.Certificate) bool { + for _, email := range c.EmailAddresses { + if strings.Contains(strings.ToLower(email), strings.ToLower(v)) { + return true + } + } + return false + } +} + +func matchAny(fs ...matchFunc) matchFunc { + return func(c *x509.Certificate) bool { + for _, f := range fs { + if f(c) { + return true + } + } + return false + } +} diff --git a/net/tlsdial/blockblame/blockblame_test.go b/net/tlsdial/blockblame/blockblame_test.go new file mode 100644 index 0000000000000..6d3592c60a3de --- /dev/null +++ b/net/tlsdial/blockblame/blockblame_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package blockblame + +import ( + "crypto/x509" + "encoding/pem" + "testing" +) + +const controlplaneDotTailscaleDotComPEM = ` +-----BEGIN CERTIFICATE----- +MIIDkzCCAxqgAwIBAgISA2GOahsftpp59yuHClbDuoduMAoGCCqGSM49BAMDMDIx +CzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MQswCQYDVQQDEwJF +NjAeFw0yNDEwMTIxNjE2NDVaFw0yNTAxMTAxNjE2NDRaMCUxIzAhBgNVBAMTGmNv +bnRyb2xwbGFuZS50YWlsc2NhbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD +QgAExfraDUc1t185zuGtZlnPDtEJJSDBqvHN4vQcXSzSTPSAdDYHcA8fL5woU2Kg +jK/2C0wm/rYy2Rre/ulhkS4wB6OCAhswggIXMA4GA1UdDwEB/wQEAwIHgDAdBgNV +HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E +FgQUpArnpDj8Yh6NTgMOZjDPx0TuLmcwHwYDVR0jBBgwFoAUkydGmAOpUWiOmNbE +QkjbI79YlNIwVQYIKwYBBQUHAQEESTBHMCEGCCsGAQUFBzABhhVodHRwOi8vZTYu +by5sZW5jci5vcmcwIgYIKwYBBQUHMAKGFmh0dHA6Ly9lNi5pLmxlbmNyLm9yZy8w +JQYDVR0RBB4wHIIaY29udHJvbHBsYW5lLnRhaWxzY2FsZS5jb20wEwYDVR0gBAww +CjAIBgZngQwBAgEwggEDBgorBgEEAdZ5AgQCBIH0BIHxAO8AdgDgkrP8DB3I52g2 +H95huZZNClJ4GYpy1nLEsE2lbW9UBAAAAZKBujCyAAAEAwBHMEUCIQDHMgUaL4H9 +ZJa090ZOpBeEVu3+t+EF4HlHI1NqAai6uQIgeY/lLfjAXfcVgxBHHR4zjd0SzhaP +TREHXzwxzN/8blkAdQDPEVbu1S58r/OHW9lpLpvpGnFnSrAX7KwB0lt3zsw7CAAA +AZKBujh8AAAEAwBGMEQCICQwhMk45t9aiFjfwOC/y6+hDbszqSCpIv63kFElweUy +AiAqTdkqmbqUVpnav5JdWkNERVAIlY4jqrThLsCLZYbNszAKBggqhkjOPQQDAwNn +ADBkAjALyfgAt1XQp1uSfxy4GapR5OsmjEMBRVq6IgsPBlCRBfmf0Q3/a6mF0pjb +Sj4oa+cCMEhZk4DmBTIdZY9zjuh8s7bXNfKxUQS0pEhALtXqyFr+D5dF7JcQo9+s +Z98JY7/PCA== +-----END CERTIFICATE-----` + +func TestVerifyCertificateOurControlPlane(t *testing.T) { + p, _ := pem.Decode([]byte(controlplaneDotTailscaleDotComPEM)) + if p == nil { + t.Fatalf("failed to extract certificate bytes for controlplane.tailscale.com") + return + } + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + return + } + m, found := VerifyCertificate(cert) + if found { + t.Fatalf("expected to not get a result for the controlplane.tailscale.com certificate") + } + if m != nil { + t.Fatalf("expected nil manufacturer for controlplane.tailscale.com certificate") + } +} diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index a49e7f0f730ee..7e847a8b6a656 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -27,6 +27,7 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/hostinfo" + "tailscale.com/net/tlsdial/blockblame" ) var counterFallbackOK int32 // atomic @@ -44,6 +45,16 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL") // Headscale, etc. var tlsdialWarningPrinted sync.Map // map[string]bool +var mitmBlockWarnable = health.Register(&health.Warnable{ + Code: "blockblame-mitm-detected", + Title: "Network may be blocking Tailscale", + Text: func(args health.Args) string { + return fmt.Sprintf("Network equipment from %q may be blocking Tailscale traffic on this network. Connect to another network, or contact your network administrator for assistance.", args["manufacturer"]) + }, + Severity: health.SeverityMedium, + ImpactsConnectivity: true, +}) + // Config returns a tls.Config for connecting to a server. // If base is non-nil, it's cloned as the base config before // being configured and returned. @@ -86,12 +97,29 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // Perform some health checks on this certificate before we do // any verification. + var cert *x509.Certificate var selfSignedIssuer string - if certs := cs.PeerCertificates; len(certs) > 0 && certIsSelfSigned(certs[0]) { - selfSignedIssuer = certs[0].Issuer.String() + if certs := cs.PeerCertificates; len(certs) > 0 { + cert = certs[0] + if certIsSelfSigned(cert) { + selfSignedIssuer = cert.Issuer.String() + } } if ht != nil { defer func() { + if retErr != nil && cert != nil { + // Is it a MITM SSL certificate from a well-known network appliance manufacturer? + // Show a dedicated warning. + m, ok := blockblame.VerifyCertificate(cert) + if ok { + log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name) + ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name}) + } else { + ht.SetHealthy(mitmBlockWarnable) + } + } else { + ht.SetHealthy(mitmBlockWarnable) + } if retErr != nil && selfSignedIssuer != "" { // Self-signed certs are never valid. // From c76a6e5167d4f669a91818d502a642c1634251e7 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 20 Oct 2024 13:22:31 -0700 Subject: [PATCH 029/179] derp: track client-advertised non-ideal DERP connections in more places In f77821fd63 (released in v1.72.0), we made the client tell a DERP server when the connection was not its ideal choice (the first node in its region). But we didn't do anything with that information until now. This adds a metric about how many such connections are on a given derper, and also adds a bit to the PeerPresentFlags bitmask so watchers can identify (and rebalance) them. Updates tailscale/corp#372 Change-Id: Ief8af448750aa6d598e5939a57c062f4e55962be Signed-off-by: Brad Fitzpatrick --- cmd/tailscale/depaware.txt | 2 +- derp/derp.go | 1 + derp/derp_server.go | 30 ++++++++++++++++++++++++++---- derp/derphttp/derphttp_client.go | 2 +- derp/derphttp/derphttp_server.go | 8 +++++++- 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index de534df8de397..765bbc483e57e 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -155,7 +155,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/clientmetric from tailscale.com/net/netcheck+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ - tailscale.com/util/ctxkey from tailscale.com/types/logger + tailscale.com/util/ctxkey from tailscale.com/types/logger+ 💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+ diff --git a/derp/derp.go b/derp/derp.go index f9b0706477358..878188cd20625 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -147,6 +147,7 @@ const ( PeerPresentIsRegular = 1 << 0 PeerPresentIsMeshPeer = 1 << 1 PeerPresentIsProber = 1 << 2 + PeerPresentNotIdeal = 1 << 3 // client said derp server is not its Region.Nodes[0] ideal node ) var bin = binary.BigEndian diff --git a/derp/derp_server.go b/derp/derp_server.go index 2a0f1aa2a38b1..ab0ab0a908a07 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -47,6 +47,7 @@ import ( "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/ctxkey" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/slicesx" @@ -57,6 +58,16 @@ import ( // verbosely log whenever DERP drops a packet. var verboseDropKeys = map[key.NodePublic]bool{} +// IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests +// to indicate that they're connecting to their ideal (Region.Nodes[0]) node. +// The HTTP header value is the name of the node they wish they were connected +// to. This is an optional header. +const IdealNodeHeader = "Ideal-Node" + +// IdealNodeContextKey is the context key used to pass the IdealNodeHeader value +// from the HTTP handler to the DERP server's Accept method. +var IdealNodeContextKey = ctxkey.New[string]("ideal-node", "") + func init() { keys := envknob.String("TS_DEBUG_VERBOSE_DROPS") if keys == "" { @@ -133,6 +144,7 @@ type Server struct { sentPong expvar.Int // number of pong frames enqueued to client accepts expvar.Int curClients expvar.Int + curClientsNotIdeal expvar.Int curHomeClients expvar.Int // ones with preferred dupClientKeys expvar.Int // current number of public keys we have 2+ connections for dupClientConns expvar.Int // current number of connections sharing a public key @@ -603,6 +615,9 @@ func (s *Server) registerClient(c *sclient) { } s.keyOfAddr[c.remoteIPPort] = c.key s.curClients.Add(1) + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(1) + } s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true) } @@ -693,6 +708,9 @@ func (s *Server) unregisterClient(c *sclient) { if c.preferred { s.curHomeClients.Add(-1) } + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(-1) + } } // addPeerGoneFromRegionWatcher adds a function to be called when peer is gone @@ -809,8 +827,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem return fmt.Errorf("receive client key: %v", err) } - clientAP, _ := netip.ParseAddrPort(remoteAddr) - if err := s.verifyClient(ctx, clientKey, clientInfo, clientAP.Addr()); err != nil { + remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) + if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil { return fmt.Errorf("client %v rejected: %v", clientKey, err) } @@ -820,8 +838,6 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem ctx, cancel := context.WithCancel(ctx) defer cancel() - remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) - c := &sclient{ connNum: connNum, s: s, @@ -838,6 +854,7 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem sendPongCh: make(chan [8]byte, 1), peerGone: make(chan peerGoneMsg), canMesh: s.isMeshPeer(clientInfo), + isNotIdealConn: IdealNodeContextKey.Value(ctx) != "", peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } @@ -1511,6 +1528,7 @@ type sclient struct { peerGone chan peerGoneMsg // write request that a peer is not at this server (not used by mesh peers) meshUpdate chan struct{} // write request to write peerStateChange canMesh bool // clientInfo had correct mesh token for inter-region routing + isNotIdealConn bool // client indicated it is not its ideal node in the region isDup atomic.Bool // whether more than 1 sclient for key is connected isDisabled atomic.Bool // whether sends to this peer are disabled due to active/active dups debug bool // turn on for verbose logging @@ -1546,6 +1564,9 @@ func (c *sclient) presentFlags() PeerPresentFlags { if c.canMesh { f |= PeerPresentIsMeshPeer } + if c.isNotIdealConn { + f |= PeerPresentNotIdeal + } if f == 0 { return PeerPresentIsRegular } @@ -2051,6 +2072,7 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_current_file_descriptors", expvar.Func(func() any { return metrics.CurrentFDs() })) m.Set("gauge_current_connections", &s.curClients) m.Set("gauge_current_home_connections", &s.curHomeClients) + m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) m.Set("gauge_clients_total", expvar.Func(func() any { return len(s.clientsMesh) })) m.Set("gauge_clients_local", expvar.Func(func() any { return len(s.clients) })) m.Set("gauge_clients_remote", expvar.Func(func() any { return len(s.clientsMesh) - len(s.clients) })) diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index b8cce8cdcb4fa..b695a52a89606 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -498,7 +498,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien req.Header.Set("Connection", "Upgrade") if !idealNodeInRegion && reg != nil { // This is purely informative for now (2024-07-06) for stats: - req.Header.Set("Ideal-Node", reg.Nodes[0].Name) + req.Header.Set(derp.IdealNodeHeader, reg.Nodes[0].Name) // TODO(bradfitz,raggi): start a time.AfterFunc for 30m-1h or so to // dialNode(reg.Nodes[0]) and see if we can even TCP connect to it. If // so, TLS handshake it as well (which is mixed up in this massive diff --git a/derp/derphttp/derphttp_server.go b/derp/derphttp/derphttp_server.go index 41ce86764f66a..ed7d3d7073866 100644 --- a/derp/derphttp/derphttp_server.go +++ b/derp/derphttp/derphttp_server.go @@ -21,6 +21,8 @@ const fastStartHeader = "Derp-Fast-Start" // Handler returns an http.Handler to be mounted at /derp, serving s. func Handler(s *derp.Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // These are installed both here and in cmd/derper. The check here // catches both cmd/derper run with DERP disabled (STUN only mode) as // well as DERP being run in tests with derphttp.Handler directly, @@ -66,7 +68,11 @@ func Handler(s *derp.Server) http.Handler { pubKey.UntypedHexString()) } - s.Accept(r.Context(), netConn, conn, netConn.RemoteAddr().String()) + if v := r.Header.Get(derp.IdealNodeHeader); v != "" { + ctx = derp.IdealNodeContextKey.WithValue(ctx, v) + } + + s.Accept(ctx, netConn, conn, netConn.RemoteAddr().String()) }) } From 72587ab03cd5b4dc751d007c7c5c060b96b39ec3 Mon Sep 17 00:00:00 2001 From: Erisa A Date: Mon, 21 Oct 2024 18:13:06 +0100 Subject: [PATCH 030/179] scripts/installer.sh: allow Archcraft for Arch packages (#13870) Fixes #13869 Signed-off-by: Erisa A --- scripts/installer.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/installer.sh b/scripts/installer.sh index 55315c0ce20f7..d2971978eebe7 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -224,7 +224,7 @@ main() { VERSION="leap/15.4" PACKAGETYPE="zypper" ;; - arch|archarm|endeavouros|blendos|garuda) + arch|archarm|endeavouros|blendos|garuda|archcraft) OS="arch" VERSION="" # rolling release PACKAGETYPE="pacman" From f8f53bb6d47526cd5819039d6fa52a050eabc22c Mon Sep 17 00:00:00 2001 From: Andrea Gottardo Date: Mon, 21 Oct 2024 13:40:43 -0700 Subject: [PATCH 031/179] health: remove SysDNSOS, add two Warnables for read+set system DNS config (#13874) --- health/health.go | 15 +-------------- net/dns/manager.go | 28 +++++++++++++++++++++++++--- net/dns/resolved.go | 9 ++++++--- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/health/health.go b/health/health.go index 216535d17c484..16b41f075fee8 100644 --- a/health/health.go +++ b/health/health.go @@ -128,9 +128,6 @@ const ( // SysDNS is the name of the net/dns subsystem. SysDNS = Subsystem("dns") - // SysDNSOS is the name of the net/dns OSConfigurator subsystem. - SysDNSOS = Subsystem("dns-os") - // SysDNSManager is the name of the net/dns manager subsystem. SysDNSManager = Subsystem("dns-manager") @@ -141,7 +138,7 @@ const ( var subsystemsWarnables = map[Subsystem]*Warnable{} func init() { - for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSOS, SysDNSManager, SysTKA} { + for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSManager, SysTKA} { w := Register(&Warnable{ Code: WarnableCode(s), Severity: SeverityMedium, @@ -510,22 +507,12 @@ func (t *Tracker) SetDNSHealth(err error) { t.setErr(SysDNS, err) } // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) DNSHealth() error { return t.get(SysDNS) } -// SetDNSOSHealth sets the state of the net/dns.OSConfigurator -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) SetDNSOSHealth(err error) { t.setErr(SysDNSOS, err) } - // SetDNSManagerHealth sets the state of the Linux net/dns manager's // discovery of the /etc/resolv.conf situation. // // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) SetDNSManagerHealth(err error) { t.setErr(SysDNSManager, err) } -// DNSOSHealth returns the net/dns.OSConfigurator error state. -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) DNSOSHealth() error { return t.get(SysDNSOS) } - // SetTKAHealth sets the health of the tailnet key authority. // // Deprecated: Warnables should be preferred over Subsystem errors. diff --git a/net/dns/manager.go b/net/dns/manager.go index 51a0fa12cba63..13cb2d84e1930 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -8,6 +8,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "io" "net" "net/netip" @@ -156,11 +157,11 @@ func (m *Manager) setLocked(cfg Config) error { return err } if err := m.os.SetDNS(ocfg); err != nil { - m.health.SetDNSOSHealth(err) + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) return err } - m.health.SetDNSOSHealth(nil) + m.health.SetHealthy(osConfigurationSetWarnable) m.config = &cfg return nil @@ -217,6 +218,26 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { return hosts } +var osConfigurationReadWarnable = health.Register(&health.Warnable{ + Code: "dns-read-os-config-failed", + Title: "Failed to read system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to fetch the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityLow, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + +var osConfigurationSetWarnable = health.Register(&health.Warnable{ + Code: "dns-set-os-config-failed", + Title: "Failed to set system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to set the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityMedium, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + // compileConfig converts cfg into a quad-100 resolver configuration // and an OS-level configuration. func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig, err error) { @@ -320,9 +341,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // This is currently (2022-10-13) expected on certain iOS and macOS // builds. } else { - m.health.SetDNSOSHealth(err) + m.health.SetUnhealthy(osConfigurationReadWarnable, health.Args{health.ArgError: err.Error()}) return resolver.Config{}, OSConfig{}, err } + m.health.SetHealthy(osConfigurationReadWarnable) } if baseCfg == nil { diff --git a/net/dns/resolved.go b/net/dns/resolved.go index d82d3fc31d80a..1a7c8604101db 100644 --- a/net/dns/resolved.go +++ b/net/dns/resolved.go @@ -163,9 +163,9 @@ func (m *resolvedManager) run(ctx context.Context) { } conn.Signal(signals) - // Reset backoff and SetNSOSHealth after successful on reconnect. + // Reset backoff and set osConfigurationSetWarnable to healthy after a successful reconnect. bo.BackOff(ctx, nil) - m.health.SetDNSOSHealth(nil) + m.health.SetHealthy(osConfigurationSetWarnable) return nil } @@ -243,9 +243,12 @@ func (m *resolvedManager) run(ctx context.Context) { // Set health while holding the lock, because this will // graciously serialize the resync's health outcome with a // concurrent SetDNS call. - m.health.SetDNSOSHealth(err) + if err != nil { m.logf("failed to configure systemd-resolved: %v", err) + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) + } else { + m.health.SetHealthy(osConfigurationSetWarnable) } } } From 0f4c9c0ecb133f2e7e3df2626e2a6a114d6dc251 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 21 Oct 2024 12:28:41 -0500 Subject: [PATCH 032/179] cmd/viewer: import types/views when generating a getter for a map field Fixes #13873 Signed-off-by: Nick Khyl --- cmd/viewer/viewer.go | 1 + cmd/viewer/viewer_test.go | 78 +++++++++++++++++++++++++++++++++++++++ util/codegen/codegen.go | 5 +++ 3 files changed, 84 insertions(+) create mode 100644 cmd/viewer/viewer_test.go diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 96223297b46e2..0c5868f3a86e6 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -258,6 +258,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi writeTemplate("unsupportedField") continue } + it.Import("tailscale.com/types/views") args.MapKeyType = it.QualifiedName(key) mElem := m.Elem() var template string diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go new file mode 100644 index 0000000000000..cd5f3d95f9c93 --- /dev/null +++ b/cmd/viewer/viewer_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "tailscale.com/util/codegen" +) + +func TestViewerImports(t *testing.T) { + tests := []struct { + name string + content string + typeNames []string + wantImports []string + }{ + { + name: "Map", + content: `type Test struct { Map map[string]int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + { + name: "Slice", + content: `type Test struct { Slice []int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0) + if err != nil { + fmt.Println("Error parsing:", err) + return + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + } + + conf := types.Config{} + pkg, err := conf.Check("", fset, []*ast.File{f}, info) + if err != nil { + t.Fatal(err) + } + + var output bytes.Buffer + tracker := codegen.NewImportTracker(pkg) + for i := range tt.typeNames { + typeName, ok := pkg.Scope().Lookup(tt.typeNames[i]).(*types.TypeName) + if !ok { + t.Fatalf("type %q does not exist", tt.typeNames[i]) + } + namedType, ok := typeName.Type().(*types.Named) + if !ok { + t.Fatalf("%q is not a named type", tt.typeNames[i]) + } + genView(&output, tracker, namedType, pkg) + } + + for _, pkgName := range tt.wantImports { + if !tracker.Has(pkgName) { + t.Errorf("missing import %q", pkgName) + } + } + }) + } +} diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index d998d925d9143..2f7781b681a24 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -97,6 +97,11 @@ func (it *ImportTracker) Import(pkg string) { } } +// Has reports whether the specified package has been imported. +func (it *ImportTracker) Has(pkg string) bool { + return it.packages[pkg] +} + func (it *ImportTracker) qualifier(pkg *types.Package) string { if it.thisPkg == pkg { return "" From d4d21a0bbf2c1bd6f0de1bc654d7bd475ef1661e Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Mon, 21 Oct 2024 16:17:28 -0700 Subject: [PATCH 033/179] net/tstun: restore tap mode functionality It had bit-rotted likely during the transition to vector io in 76389d8baf942b10a8f0f4201b7c4b0737a0172c. Tested on Ubuntu 24.04 by creating a netns and doing the DHCP dance to get an IP. Updates #2589 Signed-off-by: Maisem Ali --- cmd/k8s-operator/depaware.txt | 2 +- cmd/tailscaled/depaware.txt | 2 +- net/tstun/tap_linux.go | 121 +++++++++++++++++++++------------- net/tstun/tap_unsupported.go | 8 --- net/tstun/tun.go | 4 +- net/tstun/wrap.go | 41 +++--------- 6 files changed, 88 insertions(+), 90 deletions(-) delete mode 100644 net/tstun/tap_unsupported.go diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 58a9aa472c143..19d6808d75a3e 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -310,7 +310,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack+ gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ 💣 gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 67d8489df769f..26165d659afac 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -221,7 +221,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack+ gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ 💣 gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ diff --git a/net/tstun/tap_linux.go b/net/tstun/tap_linux.go index c721e6e2734b5..c366b05604f24 100644 --- a/net/tstun/tap_linux.go +++ b/net/tstun/tap_linux.go @@ -6,6 +6,7 @@ package tstun import ( + "bytes" "fmt" "net" "net/netip" @@ -20,10 +21,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/syncs" "tailscale.com/types/ipproto" + "tailscale.com/types/logger" "tailscale.com/util/multierr" ) @@ -35,13 +39,13 @@ var ourMAC = net.HardwareAddr{0x30, 0x2D, 0x66, 0xEC, 0x7A, 0x93} func init() { createTAP = createTAPLinux } -func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { +func createTAPLinux(logf logger.Logf, tapName, bridgeName string) (tun.Device, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) if err != nil { return nil, err } - dev, err := openDevice(fd, tapName, bridgeName) + dev, err := openDevice(logf, fd, tapName, bridgeName) if err != nil { unix.Close(fd) return nil, err @@ -50,7 +54,7 @@ func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { return dev, nil } -func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { +func openDevice(logf logger.Logf, fd int, tapName, bridgeName string) (tun.Device, error) { ifr, err := unix.NewIfreq(tapName) if err != nil { return nil, err @@ -71,7 +75,7 @@ func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { } } - return newTAPDevice(fd, tapName) + return newTAPDevice(logf, fd, tapName) } type etherType [2]byte @@ -91,7 +95,7 @@ const ( // handleTAPFrame handles receiving a raw TAP ethernet frame and reports whether // it's been handled (that is, whether it should NOT be passed to wireguard). -func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { +func (t *tapDevice) handleTAPFrame(ethBuf []byte) bool { if len(ethBuf) < ethernetFrameSize { // Corrupt. Ignore. @@ -164,8 +168,7 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { copy(res.HardwareAddressTarget(), req.HardwareAddressSender()) copy(res.ProtocolAddressTarget(), req.ProtocolAddressSender()) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{buf}, 0) + n, err := t.WriteEthernet(buf) if tapDebug { t.logf("tap: wrote ARP reply %v, %v", n, err) } @@ -182,7 +185,7 @@ const routerIP = "100.70.145.1" // must be in same netmask (currently hack at // handleDHCPRequest handles receiving a raw TAP ethernet frame and reports whether // it's been handled as a DHCP request. That is, it reports whether the frame should // be ignored by the caller and not passed on. -func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { +func (t *tapDevice) handleDHCPRequest(ethBuf []byte) bool { const udpHeader = 8 if len(ethBuf) < ethernetFrameSize+ipv4HeaderLen+udpHeader { if tapDebug { @@ -207,7 +210,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { if p.IPProto != ipproto.UDP || p.Src.Port() != 68 || p.Dst.Port() != 67 { // Not a DHCP request. if tapDebug { - t.logf("tap: DHCP wrong meta") + t.logf("tap: DHCP wrong meta: %+v", p) } return passOnPacket } @@ -250,8 +253,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP OFFER %v, %v", n, err) } @@ -278,8 +280,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(100, 100, 100, 100), 67), // src netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP ACK %v, %v", n, err) } @@ -291,6 +292,16 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { return consumePacket } +func writeEthernetFrame(buf []byte, srcMAC, dstMAC net.HardwareAddr, proto tcpip.NetworkProtocolNumber) { + // Ethernet header + eth := header.Ethernet(buf) + eth.Encode(&header.EthernetFields{ + SrcAddr: tcpip.LinkAddress(srcMAC), + DstAddr: tcpip.LinkAddress(dstMAC), + Type: proto, + }) +} + func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst netip.AddrPort) []byte { buf := make([]byte, header.EthernetMinimumSize+header.UDPMinimumSize+header.IPv4MinimumSize+len(payload)) payloadStart := len(buf) - len(payload) @@ -300,12 +311,7 @@ func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst net dstB := dst.Addr().As4() dstIP := tcpip.AddrFromSlice(dstB[:]) // Ethernet header - eth := header.Ethernet(buf) - eth.Encode(&header.EthernetFields{ - SrcAddr: tcpip.LinkAddress(srcMAC), - DstAddr: tcpip.LinkAddress(dstMAC), - Type: ipv4.ProtocolNumber, - }) + writeEthernetFrame(buf, srcMAC, dstMAC, ipv4.ProtocolNumber) // IP header ipbuf := buf[header.EthernetMinimumSize:] ip := header.IPv4(ipbuf) @@ -342,17 +348,18 @@ func run(prog string, args ...string) error { return nil } -func (t *Wrapper) destMAC() [6]byte { +func (t *tapDevice) destMAC() [6]byte { return t.destMACAtomic.Load() } -func newTAPDevice(fd int, tapName string) (tun.Device, error) { +func newTAPDevice(logf logger.Logf, fd int, tapName string) (tun.Device, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, err } file := os.NewFile(uintptr(fd), "/dev/tap") d := &tapDevice{ + logf: logf, file: file, events: make(chan tun.Event), name: tapName, @@ -360,20 +367,14 @@ func newTAPDevice(fd int, tapName string) (tun.Device, error) { return d, nil } -var ( - _ setWrapperer = &tapDevice{} -) - type tapDevice struct { file *os.File + logf func(format string, args ...any) events chan tun.Event name string - wrapper *Wrapper closeOnce sync.Once -} -func (t *tapDevice) setWrapper(wrapper *Wrapper) { - t.wrapper = wrapper + destMACAtomic syncs.AtomicValue[[6]byte] } func (t *tapDevice) File() *os.File { @@ -384,36 +385,63 @@ func (t *tapDevice) Name() (string, error) { return t.name, nil } +// Read reads an IP packet from the TAP device. It strips the ethernet frame header. func (t *tapDevice) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + n, err := t.ReadEthernet(buffs, sizes, offset) + if err != nil || n == 0 { + return n, err + } + // Strip the ethernet frame header. + copy(buffs[0][offset:], buffs[0][offset+ethernetFrameSize:offset+sizes[0]]) + sizes[0] -= ethernetFrameSize + return 1, nil +} + +// ReadEthernet reads a raw ethernet frame from the TAP device. +func (t *tapDevice) ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) { n, err := t.file.Read(buffs[0][offset:]) if err != nil { return 0, err } + if t.handleTAPFrame(buffs[0][offset : offset+n]) { + return 0, nil + } sizes[0] = n return 1, nil } +// WriteEthernet writes a raw ethernet frame to the TAP device. +func (t *tapDevice) WriteEthernet(buf []byte) (int, error) { + return t.file.Write(buf) +} + +// ethBufPool holds a pool of bytes.Buffers for use in [tapDevice.Write]. +var ethBufPool = syncs.Pool[*bytes.Buffer]{New: func() *bytes.Buffer { return new(bytes.Buffer) }} + +// Write writes a raw IP packet to the TAP device. It adds the ethernet frame header. func (t *tapDevice) Write(buffs [][]byte, offset int) (int, error) { errs := make([]error, 0) wrote := 0 + m := t.destMAC() + dstMac := net.HardwareAddr(m[:]) + buf := ethBufPool.Get() + defer ethBufPool.Put(buf) for _, buff := range buffs { - if offset < ethernetFrameSize { - errs = append(errs, fmt.Errorf("[unexpected] weird offset %d for TAP write", offset)) - return 0, multierr.New(errs...) - } - eth := buff[offset-ethernetFrameSize:] - dst := t.wrapper.destMAC() - copy(eth[:6], dst[:]) - copy(eth[6:12], ourMAC[:]) - et := etherTypeIPv4 - if buff[offset]>>4 == 6 { - et = etherTypeIPv6 + buf.Reset() + buf.Grow(header.EthernetMinimumSize + len(buff) - offset) + + var ebuf [14]byte + switch buff[offset] >> 4 { + case 4: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv4.ProtocolNumber) + case 6: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv6.ProtocolNumber) + default: + continue } - eth[12], eth[13] = et[0], et[1] - if tapDebug { - t.wrapper.logf("tap: tapWrite off=%v % x", offset, buff) - } - _, err := t.file.Write(buff[offset-ethernetFrameSize:]) + buf.Write(ebuf[:]) + buf.Write(buff[offset:]) + _, err := t.WriteEthernet(buf.Bytes()) if err != nil { errs = append(errs, err) } else { @@ -428,8 +456,7 @@ func (t *tapDevice) MTU() (int, error) { if err != nil { return 0, err } - err = unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr) - if err != nil { + if err := unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr); err != nil { return 0, err } return int(ifr.Uint32()), nil diff --git a/net/tstun/tap_unsupported.go b/net/tstun/tap_unsupported.go deleted file mode 100644 index 6792b229f6b79..0000000000000 --- a/net/tstun/tap_unsupported.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_tap - -package tstun - -func (*Wrapper) handleTAPFrame([]byte) bool { panic("unreachable") } diff --git a/net/tstun/tun.go b/net/tstun/tun.go index 66e209d1acb5a..9f5d42ecc3269 100644 --- a/net/tstun/tun.go +++ b/net/tstun/tun.go @@ -18,7 +18,7 @@ import ( ) // createTAP is non-nil on Linux. -var createTAP func(tapName, bridgeName string) (tun.Device, error) +var createTAP func(logf logger.Logf, tapName, bridgeName string) (tun.Device, error) // New returns a tun.Device for the requested device name, along with // the OS-dependent name that was allocated to the device. @@ -42,7 +42,7 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { default: return nil, "", errors.New("bogus tap argument") } - dev, err = createTAP(tapName, bridgeName) + dev, err = createTAP(logf, tapName, bridgeName) } else { dev, err = tun.CreateTUN(tunName, int(DefaultTUNMTU())) } diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index dcd43d5718ca8..b0765b13d3eda 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -109,9 +109,7 @@ type Wrapper struct { lastActivityAtomic mono.Time // time of last send or receive destIPActivity syncs.AtomicValue[map[netip.Addr]func()] - //lint:ignore U1000 used in tap_linux.go - destMACAtomic syncs.AtomicValue[[6]byte] - discoKey syncs.AtomicValue[key.DiscoPublic] + discoKey syncs.AtomicValue[key.DiscoPublic] // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time @@ -257,12 +255,6 @@ type tunVectorReadResult struct { dataOffset int } -type setWrapperer interface { - // setWrapper enables the underlying TUN/TAP to have access to the Wrapper. - // It MUST be called only once during initialization, other usage is unsafe. - setWrapper(*Wrapper) -} - // Start unblocks any Wrapper.Read calls that have already started // and makes the Wrapper functional. // @@ -313,10 +305,6 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) w.bufferConsumed <- struct{}{} w.noteActivity() - if sw, ok := w.tdev.(setWrapperer); ok { - sw.setWrapper(w) - } - return w } @@ -459,12 +447,18 @@ const ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype func (t *Wrapper) pollVector() { sizes := make([]int, len(t.vectorBuffer)) readOffset := PacketStartOffset + reader := t.tdev.Read if t.isTAP { - readOffset = PacketStartOffset - ethernetFrameSize + type tapReader interface { + ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) + } + if r, ok := t.tdev.(tapReader); ok { + readOffset = PacketStartOffset - ethernetFrameSize + reader = r.ReadEthernet + } } for range t.bufferConsumed { - DoRead: for i := range t.vectorBuffer { t.vectorBuffer[i] = t.vectorBuffer[i][:cap(t.vectorBuffer[i])] } @@ -474,7 +468,7 @@ func (t *Wrapper) pollVector() { if t.isClosed() { return } - n, err = t.tdev.Read(t.vectorBuffer[:], sizes, readOffset) + n, err = reader(t.vectorBuffer[:], sizes, readOffset) if t.isTAP && tapDebug { s := fmt.Sprintf("% x", t.vectorBuffer[0][:]) for strings.HasSuffix(s, " 00") { @@ -486,21 +480,6 @@ func (t *Wrapper) pollVector() { for i := range sizes[:n] { t.vectorBuffer[i] = t.vectorBuffer[i][:readOffset+sizes[i]] } - if t.isTAP { - if err == nil { - ethernetFrame := t.vectorBuffer[0][readOffset:] - if t.handleTAPFrame(ethernetFrame) { - goto DoRead - } - } - // Fall through. We got an IP packet. - if sizes[0] >= ethernetFrameSize { - t.vectorBuffer[0] = t.vectorBuffer[0][:readOffset+sizes[0]-ethernetFrameSize] - } - if tapDebug { - t.logf("tap regular frame: %x", t.vectorBuffer[0][PacketStartOffset:PacketStartOffset+sizes[0]]) - } - } t.sendVectorOutbound(tunVectorReadResult{ data: t.vectorBuffer[:n], dataOffset: PacketStartOffset, From 85241f8408fd73f47b776c87366d54d240440d24 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Mon, 21 Oct 2024 17:00:41 -0700 Subject: [PATCH 034/179] net/tstun: use /10 as subnet for TAP mode; read IP from netmap Few changes to resolve TODOs in the code: - Instead of using a hardcoded IP, get it from the netmap. - Use 100.100.100.100 as the gateway IP - Use the /10 CGNAT range instead of a random /24 Updates #2589 Signed-off-by: Maisem Ali --- net/tstun/tap_linux.go | 66 ++++++++++++++++++++++++++++-------------- net/tstun/wrap.go | 11 ++++++- 2 files changed, 54 insertions(+), 23 deletions(-) diff --git a/net/tstun/tap_linux.go b/net/tstun/tap_linux.go index c366b05604f24..8a00a96927c4d 100644 --- a/net/tstun/tap_linux.go +++ b/net/tstun/tap_linux.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/types/ipproto" "tailscale.com/types/logger" @@ -158,7 +159,7 @@ func (t *tapDevice) handleTAPFrame(ethBuf []byte) bool { // If the client's asking about their own IP, tell them it's // their own MAC. TODO(bradfitz): remove String allocs. - if net.IP(req.ProtocolAddressTarget()).String() == theClientIP { + if net.IP(req.ProtocolAddressTarget()).String() == t.clientIPv4.Load() { copy(res.HardwareAddressSender(), ethSrcMAC) } else { copy(res.HardwareAddressSender(), ourMAC[:]) @@ -178,9 +179,12 @@ func (t *tapDevice) handleTAPFrame(ethBuf []byte) bool { } } -// TODO(bradfitz): remove these hard-coded values and move from a /24 to a /10 CGNAT as the range. -const theClientIP = "100.70.145.3" // TODO: make dynamic from netmap -const routerIP = "100.70.145.1" // must be in same netmask (currently hack at /24) as theClientIP +var ( + // routerIP is the IP address of the DHCP server. + routerIP = net.ParseIP(tsaddr.TailscaleServiceIPString) + // cgnatNetMask is the netmask of the 100.64.0.0/10 CGNAT range. + cgnatNetMask = net.IPMask(net.ParseIP("255.192.0.0").To4()) +) // handleDHCPRequest handles receiving a raw TAP ethernet frame and reports whether // it's been handled as a DHCP request. That is, it reports whether the frame should @@ -228,17 +232,22 @@ func (t *tapDevice) handleDHCPRequest(ethBuf []byte) bool { } switch dp.MessageType() { case dhcpv4.MessageTypeDiscover: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } offer, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeOffer), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithDNS(routerIP), + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), dhcpv4.WithLeaseTime(3600), // hour works //dhcpv4.WithHwAddr(ethSrcMAC), - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), // TODO: wrong + dhcpv4.WithNetmask(cgnatNetMask), //dhcpv4.WithTransactionID(dp.TransactionID), ) if err != nil { @@ -258,16 +267,21 @@ func (t *tapDevice) handleDHCPRequest(ethBuf []byte) bool { t.logf("tap: wrote DHCP OFFER %v, %v", n, err) } case dhcpv4.MessageTypeRequest: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } ack, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeAck), - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), // Hello world - dhcpv4.WithLeaseTime(3600), // hour works - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), + dhcpv4.WithDNS(routerIP), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), // Hello world + dhcpv4.WithLeaseTime(3600), // hour works + dhcpv4.WithNetmask(cgnatNetMask), ) if err != nil { t.logf("error building DHCP ack: %v", err) @@ -368,15 +382,23 @@ func newTAPDevice(logf logger.Logf, fd int, tapName string) (tun.Device, error) } type tapDevice struct { - file *os.File - logf func(format string, args ...any) - events chan tun.Event - name string - closeOnce sync.Once + file *os.File + logf func(format string, args ...any) + events chan tun.Event + name string + closeOnce sync.Once + clientIPv4 syncs.AtomicValue[string] destMACAtomic syncs.AtomicValue[[6]byte] } +var _ setIPer = (*tapDevice)(nil) + +func (t *tapDevice) SetIP(ipV4, ipV6TODO netip.Addr) error { + t.clientIPv4.Store(ipV4.String()) + return nil +} + func (t *tapDevice) File() *os.File { return t.file } diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index b0765b13d3eda..0b858fc1c5653 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -802,10 +802,19 @@ func (pc *peerConfigTable) outboundPacketIsJailed(p *packet.Parsed) bool { return c.jailed } +type setIPer interface { + // SetIP sets the IP addresses of the TAP device. + SetIP(ipV4, ipV6 netip.Addr) error +} + // SetWGConfig is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { + if t.isTAP { + if sip, ok := t.tdev.(setIPer); ok { + sip.SetIP(findV4(wcfg.Addresses), findV6(wcfg.Addresses)) + } + } cfg := peerConfigTableFromWGConfig(wcfg) - old := t.peerConfig.Swap(cfg) if !reflect.DeepEqual(old, cfg) { t.logf("peer config: %v", cfg) From ae5bc88ebea2f96f67e54ba6886c63ee0af14b54 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Oct 2024 09:40:17 -0500 Subject: [PATCH 035/179] health: fix spurious warning about DERP home region '0' Updates #13650 Change-Id: I6b0f165f66da3f881a4caa25d2d9936dc2a7f22c Signed-off-by: Brad Fitzpatrick --- health/health.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/health/health.go b/health/health.go index 16b41f075fee8..3bebcb98356f4 100644 --- a/health/health.go +++ b/health/health.go @@ -1038,11 +1038,15 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { ArgDuration: d.Round(time.Second).String(), }) } - } else { + } else if homeDERP != 0 { t.setUnhealthyLocked(noDERPConnectionWarnable, Args{ ArgDERPRegionID: fmt.Sprint(homeDERP), ArgDERPRegionName: t.derpRegionNameLocked(homeDERP), }) + } else { + // No DERP home yet determined yet. There's probably some + // other problem or things are just starting up. + t.setHealthyLocked(noDERPConnectionWarnable) } if !t.ipnWantRunning { From b2665d9b89ee8c7be10a8e0a2fa36d35d21d8440 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Tue, 22 Oct 2024 14:17:48 -0500 Subject: [PATCH 036/179] net/netcheck: add a Now field to the netcheck Report This allows us to print the time that a netcheck was run, which is useful in debugging. Updates #10972 Signed-off-by: Andrew Dunham Change-Id: Id48d30d4eb6d5208efb2b1526a71d83fe7f9320b --- cmd/tailscale/cli/netcheck.go | 1 + net/netcheck/netcheck.go | 16 +++++++++------- net/netcheck/netcheck_test.go | 14 ++++++++++++++ 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go index 682cd99a3c6e4..312475eced978 100644 --- a/cmd/tailscale/cli/netcheck.go +++ b/cmd/tailscale/cli/netcheck.go @@ -136,6 +136,7 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error { } printf("\nReport:\n") + printf("\t* Time: %v\n", report.Now.Format(time.RFC3339Nano)) printf("\t* UDP: %v\n", report.UDP) if report.GlobalV4.IsValid() { printf("\t* IPv4: yes, %s\n", report.GlobalV4) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index bebf4c9b05461..1714837305ac1 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -85,13 +85,14 @@ const ( // Report contains the result of a single netcheck. type Report struct { - UDP bool // a UDP STUN round trip completed - IPv6 bool // an IPv6 STUN round trip completed - IPv4 bool // an IPv4 STUN round trip completed - IPv6CanSend bool // an IPv6 packet was able to be sent - IPv4CanSend bool // an IPv4 packet was able to be sent - OSHasIPv6 bool // could bind a socket to ::1 - ICMPv4 bool // an ICMPv4 round trip completed + Now time.Time // the time the report was run + UDP bool // a UDP STUN round trip completed + IPv6 bool // an IPv6 STUN round trip completed + IPv4 bool // an IPv4 STUN round trip completed + IPv6CanSend bool // an IPv6 packet was able to be sent + IPv4CanSend bool // an IPv4 packet was able to be sent + OSHasIPv6 bool // could bind a socket to ::1 + ICMPv4 bool // an ICMPv4 round trip completed // MappingVariesByDestIP is whether STUN results depend which // STUN server you're talking to (on IPv4). @@ -1297,6 +1298,7 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, c.prev = map[time.Time]*Report{} } now := c.timeNow() + r.Now = now.UTC() c.prev[now] = r c.last = r diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 964014203f05d..2780c9c44b08c 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -28,6 +28,9 @@ func newTestClient(t testing.TB) *Client { c := &Client{ NetMon: netmon.NewStatic(), Logf: t.Logf, + TimeNow: func() time.Time { + return time.Unix(1729624521, 0) + }, } return c } @@ -52,6 +55,9 @@ func TestBasic(t *testing.T) { if !r.UDP { t.Error("want UDP") } + if r.Now.IsZero() { + t.Error("Now is zero") + } if len(r.RegionLatency) != 1 { t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency) } @@ -130,6 +136,14 @@ func TestWorksWhenUDPBlocked(t *testing.T) { want := newReport() + // The Now field can't be compared with reflect.DeepEqual; check using + // the Equal method and then overwrite it so that the comparison below + // succeeds. + if !r.Now.Equal(c.TimeNow()) { + t.Errorf("Now = %v; want %v", r.Now, c.TimeNow()) + } + want.Now = r.Now + // The IPv4CanSend flag gets set differently across platforms. // On Windows this test detects false, while on Linux detects true. // That's not relevant to this test, so just accept what we're From 212270463b2916938a06db251621b7d2f15b08fb Mon Sep 17 00:00:00 2001 From: Paul Scott <408401+icio@users.noreply.github.com> Date: Thu, 24 Oct 2024 09:41:54 -0500 Subject: [PATCH 037/179] cmd/testwrapper: add pkg runtime to output (#13894) Fixes #13893 Signed-off-by: Paul Scott --- cmd/testwrapper/testwrapper.go | 25 ++++++++++++++++--------- cmd/testwrapper/testwrapper_test.go | 6 +++++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/cmd/testwrapper/testwrapper.go b/cmd/testwrapper/testwrapper.go index 9b8d7a7c17ba5..f6ff8f00a93ab 100644 --- a/cmd/testwrapper/testwrapper.go +++ b/cmd/testwrapper/testwrapper.go @@ -42,6 +42,7 @@ type testAttempt struct { testName string // "TestFoo" outcome string // "pass", "fail", "skip" logs bytes.Buffer + start, end time.Time isMarkedFlaky bool // set if the test is marked as flaky issueURL string // set if the test is marked as flaky @@ -132,11 +133,17 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te } pkg := goOutput.Package pkgTests := resultMap[pkg] + if pkgTests == nil { + pkgTests = make(map[string]*testAttempt) + resultMap[pkg] = pkgTests + } if goOutput.Test == "" { switch goOutput.Action { + case "start": + pkgTests[""] = &testAttempt{start: goOutput.Time} case "fail", "pass", "skip": for _, test := range pkgTests { - if test.outcome == "" { + if test.testName != "" && test.outcome == "" { test.outcome = "fail" ch <- test } @@ -144,15 +151,13 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te ch <- &testAttempt{ pkg: goOutput.Package, outcome: goOutput.Action, + start: pkgTests[""].start, + end: goOutput.Time, pkgFinished: true, } } continue } - if pkgTests == nil { - pkgTests = make(map[string]*testAttempt) - resultMap[pkg] = pkgTests - } testName := goOutput.Test if test, _, isSubtest := strings.Cut(goOutput.Test, "/"); isSubtest { testName = test @@ -168,8 +173,10 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te pkgTests[testName] = &testAttempt{ pkg: pkg, testName: testName, + start: goOutput.Time, } case "skip", "pass", "fail": + pkgTests[testName].end = goOutput.Time pkgTests[testName].outcome = goOutput.Action ch <- pkgTests[testName] case "output": @@ -213,7 +220,7 @@ func main() { firstRun.tests = append(firstRun.tests, &packageTests{Pattern: pkg}) } toRun := []*nextRun{firstRun} - printPkgOutcome := func(pkg, outcome string, attempt int) { + printPkgOutcome := func(pkg, outcome string, attempt int, runtime time.Duration) { if outcome == "skip" { fmt.Printf("?\t%s [skipped/no tests] \n", pkg) return @@ -225,10 +232,10 @@ func main() { outcome = "FAIL" } if attempt > 1 { - fmt.Printf("%s\t%s [attempt=%d]\n", outcome, pkg, attempt) + fmt.Printf("%s\t%s\t%.3fs\t[attempt=%d]\n", outcome, pkg, runtime.Seconds(), attempt) return } - fmt.Printf("%s\t%s\n", outcome, pkg) + fmt.Printf("%s\t%s\t%.3fs\n", outcome, pkg, runtime.Seconds()) } // Check for -coverprofile argument and filter it out @@ -307,7 +314,7 @@ func main() { // when a package times out. failed = true } - printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt) + printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt, tr.end.Sub(tr.start)) continue } if testingVerbose || tr.outcome == "fail" { diff --git a/cmd/testwrapper/testwrapper_test.go b/cmd/testwrapper/testwrapper_test.go index d7dbccd093ef8..fb2ed2c52cb2e 100644 --- a/cmd/testwrapper/testwrapper_test.go +++ b/cmd/testwrapper/testwrapper_test.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "regexp" "sync" "testing" ) @@ -76,7 +77,10 @@ func TestFlakeRun(t *testing.T) { t.Fatalf("go run . %s: %s with output:\n%s", testfile, err, out) } - want := []byte("ok\t" + testfile + " [attempt=2]") + // Replace the unpredictable timestamp with "0.00s". + out = regexp.MustCompile(`\t\d+\.\d\d\ds\t`).ReplaceAll(out, []byte("\t0.00s\t")) + + want := []byte("ok\t" + testfile + "\t0.00s\t[attempt=2]") if !bytes.Contains(out, want) { t.Fatalf("wanted output containing %q but got:\n%s", want, out) } From 7fe6e508588c6359fc51b0221aa1c20ac39e3eaa Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Thu, 24 Oct 2024 11:43:22 -0500 Subject: [PATCH 038/179] net/dns/resolver: fix test flake Updates #13902 Signed-off-by: Andrew Dunham Change-Id: Ib2def19caad17367e9a31786ac969278e65f51c6 --- net/dns/resolver/forwarder_test.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index e341186ecf45e..f3e592d4f1993 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/tstest" "tailscale.com/types/dnstype" ) @@ -276,6 +277,8 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Fatal("cannot skip both UDP and TCP servers") } + logf := tstest.WhileTestRunningLogger(tb) + tcpResponse := make([]byte, len(response)+2) binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) copy(tcpResponse[2:], response) @@ -329,13 +332,13 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Read the length header, then the buffer var length uint16 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { - tb.Logf("error reading length header: %v", err) + logf("error reading length header: %v", err) return } req := make([]byte, length) n, err := io.ReadFull(conn, req) if err != nil { - tb.Logf("error reading query: %v", err) + logf("error reading query: %v", err) return } req = req[:n] @@ -343,7 +346,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Write response if _, err := conn.Write(tcpResponse); err != nil { - tb.Logf("error writing response: %v", err) + logf("error writing response: %v", err) return } } @@ -367,7 +370,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on handleUDP := func(addr netip.AddrPort, req []byte) { onRequest(false, req) if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil { - tb.Logf("error writing response: %v", err) + logf("error writing response: %v", err) } } @@ -390,7 +393,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Cleanup(func() { tcpLn.Close() udpLn.Close() - tb.Logf("waiting for listeners to finish...") + logf("waiting for listeners to finish...") wg.Wait() }) return @@ -450,7 +453,8 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) } func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { - netMon, err := netmon.New(tb.Logf) + logf := tstest.WhileTestRunningLogger(tb) + netMon, err := netmon.New(logf) if err != nil { tb.Fatal(err) } @@ -458,7 +462,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports var dialer tsdial.Dialer dialer.SetNetMon(netMon) - fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil) + fwd := newForwarder(logf, netMon, nil, &dialer, new(health.Tracker), nil) if modify != nil { modify(fwd) } From e815ae0ec4b718486af9be3a30d3058b65b28c4e Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Tue, 8 Oct 2024 10:50:14 -0500 Subject: [PATCH 039/179] util/syspolicy, ipn/ipnlocal: update syspolicy package to utilize syspolicy/rsop In this PR, we update the syspolicy package to utilize syspolicy/rsop under the hood, and remove syspolicy.CachingHandler, syspolicy.windowsHandler and related code which is no longer used. We mark the syspolicy.Handler interface and RegisterHandler/SetHandlerForTest functions as deprecated, but keep them temporarily until they are no longer used in other repos. We also update the package to register setting definitions for all existing policy settings and to register the Registry-based, Windows-specific policy stores when running on Windows. Finally, we update existing internal and external tests to use the new API and add a few more tests and benchmarks. Updates #12687 Signed-off-by: Nick Khyl --- cmd/derper/depaware.txt | 15 +- cmd/k8s-operator/depaware.txt | 9 +- cmd/tailscale/depaware.txt | 10 +- cmd/tailscaled/depaware.txt | 9 +- ipn/ipnlocal/local_test.go | 225 ++++----------- util/syspolicy/caching_handler.go | 122 -------- util/syspolicy/caching_handler_test.go | 262 ----------------- util/syspolicy/handler.go | 114 +++++--- util/syspolicy/handler_test.go | 19 -- util/syspolicy/handler_windows.go | 105 ------- util/syspolicy/policy_keys.go | 103 ++++++- util/syspolicy/policy_keys_test.go | 95 +++++++ util/syspolicy/policy_keys_windows.go | 38 --- util/syspolicy/syspolicy.go | 152 +++++++--- util/syspolicy/syspolicy_test.go | 377 ++++++++++++++++++------- util/syspolicy/syspolicy_windows.go | 92 ++++++ 16 files changed, 822 insertions(+), 925 deletions(-) delete mode 100644 util/syspolicy/caching_handler.go delete mode 100644 util/syspolicy/caching_handler_test.go delete mode 100644 util/syspolicy/handler_test.go delete mode 100644 util/syspolicy/handler_windows.go create mode 100644 util/syspolicy/policy_keys_test.go delete mode 100644 util/syspolicy/policy_keys_windows.go create mode 100644 util/syspolicy/syspolicy_windows.go diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 362b07882b268..e20c4e556da8f 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -164,11 +164,16 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/slicesx from tailscale.com/cmd/derper+ tailscale.com/util/syspolicy from tailscale.com/ipn tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/util/syspolicy+ tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ + W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/derp+ tailscale.com/version/distro from tailscale.com/envknob+ @@ -189,7 +194,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ W golang.org/x/exp/constraints from tailscale.com/util/winutil - golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting + golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting+ L golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http @@ -250,7 +255,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/derper + flag from tailscale.com/cmd/derper+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -284,7 +289,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa os from crypto/rand+ os/exec from github.com/coreos/go-iptables/iptables+ os/signal from tailscale.com/cmd/derper - W os/user from tailscale.com/util/winutil + W os/user from tailscale.com/util/winutil+ path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 19d6808d75a3e..2ad3978c927d7 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -812,8 +812,11 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/slicesx from tailscale.com/appc+ tailscale.com/util/syspolicy from tailscale.com/control/controlclient+ tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/control/controlclient+ @@ -823,7 +826,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/vizerror from tailscale.com/tailcfg+ 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns + W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 765bbc483e57e..cce76a81e0bfb 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -174,14 +174,18 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/ipn tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy - tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ W 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate + W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 26165d659afac..b3a4aa86fba30 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -401,8 +401,11 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+ tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+ @@ -412,7 +415,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/vizerror from tailscale.com/tailcfg+ 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+ W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns + W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 9a8fa5e02df4f..5fee5d00ee36a 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -54,6 +54,8 @@ import ( "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -1559,94 +1561,6 @@ func dnsResponse(domain, address string) []byte { return must.Get(b.Finish()) } -type errorSyspolicyHandler struct { - t *testing.T - err error - key syspolicy.Key - allowKeys map[syspolicy.Key]*string -} - -func (h *errorSyspolicyHandler) ReadString(key string) (string, error) { - sk := syspolicy.Key(key) - if _, ok := h.allowKeys[sk]; !ok { - h.t.Errorf("ReadString: %q is not in list of permitted keys", h.key) - } - if sk == h.key { - return "", h.err - } - return "", syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) - return 0, syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) - return false, syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - return nil, syspolicy.ErrNoSuchKey -} - -type mockSyspolicyHandler struct { - t *testing.T - // stringPolicies is the collection of policies that we expect to see - // queried by the current test. If the policy is expected but unset, then - // use nil, otherwise use a string equal to the policy's desired value. - stringPolicies map[syspolicy.Key]*string - // stringArrayPolicies is the collection of policies that we expected to see - // queries by the current test, that return policy string arrays. - stringArrayPolicies map[syspolicy.Key][]string - // failUnknownPolicies is set if policies other than those in stringPolicies - // (uint64 or bool policies are not supported by mockSyspolicyHandler yet) - // should be considered a test failure if they are queried. - failUnknownPolicies bool -} - -func (h *mockSyspolicyHandler) ReadString(key string) (string, error) { - if s, ok := h.stringPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return "", syspolicy.ErrNoSuchKey - } - return *s, nil - } - if h.failUnknownPolicies { - h.t.Errorf("ReadString(%q) unexpectedly called", key) - } - return "", syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) - } - return 0, syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) - } - return false, syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - } - if s, ok := h.stringArrayPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return []string{}, syspolicy.ErrNoSuchKey - } - return s, nil - } - return nil, syspolicy.ErrNoSuchKey -} - func TestSetExitNodeIDPolicy(t *testing.T) { pfx := netip.MustParsePrefix tests := []struct { @@ -1856,23 +1770,18 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, test := range tests { t.Run(test.name, func(t *testing.T) { b := newTestBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: nil, - syspolicy.ExitNodeIP: nil, - }, - } - if test.exitNodeIDKey { - msh.stringPolicies[syspolicy.ExitNodeID] = &test.exitNodeID - } - if test.exitNodeIPKey { - msh.stringPolicies[syspolicy.ExitNodeIP] = &test.exitNodeIP - } - syspolicy.SetHandlerForTest(t, msh) + + policyStore := source.NewTestStoreOf(t, + source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID), + source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP), + ) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + if test.nm == nil { test.nm = new(netmap.NetworkMap) } @@ -1994,13 +1903,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { report: report, }, } - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, - } - syspolicy.SetHandlerForTest(t, msh) + + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := newTestLocalBackend(t) @@ -2049,13 +1958,11 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { } cc = newClient(t, opts) b.cc = cc - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, - } - syspolicy.SetHandlerForTest(t, msh) + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes()) peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes()) selfNode := tailcfg.Node{ @@ -2160,13 +2067,11 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) { DERPMap: derpMap, } b := newTestLocalBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, - } - syspolicy.SetHandlerForTest(t, msh) + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) b.netMap = nm b.lastSuggestedExitNode = peer1.StableID() b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report) @@ -2400,17 +2305,16 @@ func TestApplySysPolicy(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: make(map[syspolicy.Key]*string, len(tt.stringPolicies)), - } + settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies)) for p, v := range tt.stringPolicies { - v := v // construct a unique pointer for each policy value - msh.stringPolicies[p] = &v + settings = append(settings, source.TestSettingOf(p, v)) } - syspolicy.SetHandlerForTest(t, msh) + policyStore := source.NewTestStoreOf(t, settings...) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) t.Run("unit", func(t *testing.T) { prefs := tt.prefs.Clone() @@ -2546,35 +2450,19 @@ func TestPreferencePolicyInfo(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, pp := range preferencePolicies { t.Run(string(pp.key), func(t *testing.T) { - var h syspolicy.Handler - - allPolicies := make(map[syspolicy.Key]*string, len(preferencePolicies)+1) - allPolicies[syspolicy.ControlURL] = nil - for _, pp := range preferencePolicies { - allPolicies[pp.key] = nil + s := source.TestSetting[string]{ + Key: pp.key, + Error: tt.policyError, + Value: tt.policyValue, } - - if tt.policyError != nil { - h = &errorSyspolicyHandler{ - t: t, - err: tt.policyError, - key: pp.key, - allowKeys: allPolicies, - } - } else { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: allPolicies, - failUnknownPolicies: true, - } - msh.stringPolicies[pp.key] = &tt.policyValue - h = msh - } - syspolicy.SetHandlerForTest(t, h) + policyStore := source.NewTestStoreOf(t, s) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) @@ -3825,15 +3713,16 @@ func TestShouldAutoExitNode(t *testing.T) { expectedBool: false, }, } + + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To(tt.exitNodeIDPolicyValue), - }, - } - syspolicy.SetHandlerForTest(t, msh) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, tt.exitNodeIDPolicyValue, + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + got := shouldAutoExitNode() if got != tt.expectedBool { t.Fatalf("expected %v got %v for %v policy value", tt.expectedBool, got, tt.exitNodeIDPolicyValue) @@ -3971,17 +3860,13 @@ func TestFillAllowedSuggestions(t *testing.T) { want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, }, } + syspolicy.RegisterWellKnownSettingsForTest(t) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mh := mockSyspolicyHandler{ - t: t, - } - if tt.allowPolicy != nil { - mh.stringArrayPolicies = map[syspolicy.Key][]string{ - syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy, - } - } - syspolicy.SetHandlerForTest(t, &mh) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.AllowedSuggestedExitNodes, tt.allowPolicy, + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) got := fillAllowedSuggestions() if got == nil { diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go deleted file mode 100644 index 5192958bc45a5..0000000000000 --- a/util/syspolicy/caching_handler.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "sync" -) - -// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested -// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached, -// otherwise the actual error is returned and the next read for that key will retry using the handler. -type CachingHandler struct { - mu sync.Mutex - strings map[string]string - uint64s map[string]uint64 - bools map[string]bool - strArrs map[string][]string - notFound map[string]bool - handler Handler -} - -// NewCachingHandler creates a CachingHandler given a handler. -func NewCachingHandler(handler Handler) *CachingHandler { - return &CachingHandler{ - handler: handler, - strings: make(map[string]string), - uint64s: make(map[string]uint64), - bools: make(map[string]bool), - strArrs: make(map[string][]string), - notFound: make(map[string]bool), - } -} - -// ReadString reads the policy settings value string given the key. -// ReadString first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadString(key string) (string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strings[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return "", ErrNoSuchKey - } - val, err := ch.handler.ReadString(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return "", err - } else if err != nil { - return "", err - } - ch.strings[key] = val - return val, nil -} - -// ReadUInt64 reads the policy settings uint64 value given the key. -// ReadUInt64 first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.uint64s[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return 0, ErrNoSuchKey - } - val, err := ch.handler.ReadUInt64(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return 0, err - } else if err != nil { - return 0, err - } - ch.uint64s[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.bools[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return false, ErrNoSuchKey - } - val, err := ch.handler.ReadBoolean(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return false, err - } else if err != nil { - return false, err - } - ch.bools[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strArrs[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return nil, ErrNoSuchKey - } - val, err := ch.handler.ReadStringArray(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return nil, err - } else if err != nil { - return nil, err - } - ch.strArrs[key] = val - return val, nil -} diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go deleted file mode 100644 index 881f6ff83c0f8..0000000000000 --- a/util/syspolicy/caching_handler_test.go +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "testing" -) - -func TestHandlerReadString(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue string - handlerError error - preserveHandler bool - wantValue string - wantErr error - strings map[string]string - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - strings: map[string]string{"test": "foo"}, - wantValue: "foo", - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: "foo", - wantValue: "foo", - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - s: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.strings != nil { - cache.strings = tt.strings - } - got, err := cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } -} - -func TestHandlerReadUint64(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue uint64 - handlerError error - preserveHandler bool - wantValue uint64 - wantErr error - uint64s map[string]uint64 - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - uint64s: map[string]uint64{"test": 1}, - wantValue: 1, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: 1, - wantValue: 1, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - u64: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.uint64s != nil { - cache.uint64s = tt.uint64s - } - got, err := cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} - -func TestHandlerReadBool(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue bool - handlerError error - preserveHandler bool - wantValue bool - wantErr error - bools map[string]bool - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - bools: map[string]bool{"test": true}, - wantValue: true, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: true, - wantValue: true, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - b: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.bools != nil { - cache.bools = tt.bools - } - got, err := cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} diff --git a/util/syspolicy/handler.go b/util/syspolicy/handler.go index f1fad97709a3f..f511f0a562e8b 100644 --- a/util/syspolicy/handler.go +++ b/util/syspolicy/handler.go @@ -4,16 +4,17 @@ package syspolicy import ( - "errors" - "sync/atomic" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) -var ( - handlerUsed atomic.Bool - handler Handler = defaultHandler{} -) +// TODO(nickkhyl): delete this file once other repos are updated. // Handler reads system policies from OS-specific storage. +// +// Deprecated: implementing a [source.Store] should be preferred. type Handler interface { // ReadString reads the policy setting's string value for the given key. // It should return ErrNoSuchKey if the key does not have a value set. @@ -29,55 +30,88 @@ type Handler interface { ReadStringArray(key string) ([]string, error) } -// ErrNoSuchKey is returned by a Handler when the specified key does not have a -// value set. -var ErrNoSuchKey = errors.New("no such key") +// RegisterHandler wraps and registers the specified handler as the device's +// policy [source.Store] for the program's lifetime. +// +// Deprecated: using [RegisterStore] should be preferred. +func RegisterHandler(h Handler) { + rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h)) +} -// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple. -type defaultHandler struct{} +// TB is a subset of testing.TB that we use to set up test helpers. +// It's defined here to avoid pulling in the testing package. +type TB = internal.TB -func (defaultHandler) ReadString(_ string) (string, error) { - return "", ErrNoSuchKey +// SetHandlerForTest wraps and sets the specified handler as the device's policy +// [source.Store] for the duration of tb. +// +// Deprecated: using [MustRegisterStoreForTest] should be preferred. +func SetHandlerForTest(tb TB, h Handler) { + RegisterWellKnownSettingsForTest(tb) + MustRegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.DefaultScope(), WrapHandler(h)) } -func (defaultHandler) ReadUInt64(_ string) (uint64, error) { - return 0, ErrNoSuchKey +var _ source.Store = (*handlerStore)(nil) + +// handlerStore is a [source.Store] that calls the underlying [Handler]. +// +// TODO(nickkhyl): remove it when the corp and android repos are updated. +type handlerStore struct { + h Handler } -func (defaultHandler) ReadBoolean(_ string) (bool, error) { - return false, ErrNoSuchKey +// WrapHandler returns a [source.Store] that wraps the specified [Handler]. +func WrapHandler(h Handler) source.Store { + return handlerStore{h} } -func (defaultHandler) ReadStringArray(_ string) ([]string, error) { - return nil, ErrNoSuchKey +// Lock implements [source.Lockable]. +func (s handlerStore) Lock() error { + if lockable, ok := s.h.(source.Lockable); ok { + return lockable.Lock() + } + return nil } -// markHandlerInUse is called before handler methods are called. -func markHandlerInUse() { - handlerUsed.Store(true) +// Unlock implements [source.Lockable]. +func (s handlerStore) Unlock() { + if lockable, ok := s.h.(source.Lockable); ok { + lockable.Unlock() + } } -// RegisterHandler initializes the policy handler and ensures registration will happen once. -func RegisterHandler(h Handler) { - // Technically this assignment is not concurrency safe, but in the - // event that there was any risk of a data race, we will panic due to - // the CompareAndSwap failing. - handler = h - if !handlerUsed.CompareAndSwap(false, true) { - panic("handler was already used before registration") +// RegisterChangeCallback implements [source.Changeable]. +func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) { + if changeable, ok := s.h.(source.Changeable); ok { + return changeable.RegisterChangeCallback(callback) } + return func() {}, nil } -// TB is a subset of testing.TB that we use to set up test helpers. -// It's defined here to avoid pulling in the testing package. -type TB interface { - Helper() - Cleanup(func()) +// ReadString implements [source.Store]. +func (s handlerStore) ReadString(key setting.Key) (string, error) { + return s.h.ReadString(string(key)) } -func SetHandlerForTest(tb TB, h Handler) { - tb.Helper() - oldHandler := handler - handler = h - tb.Cleanup(func() { handler = oldHandler }) +// ReadUInt64 implements [source.Store]. +func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) { + return s.h.ReadUInt64(string(key)) +} + +// ReadBoolean implements [source.Store]. +func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) { + return s.h.ReadBoolean(string(key)) +} + +// ReadStringArray implements [source.Store]. +func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) { + return s.h.ReadStringArray(string(key)) +} + +// Done implements [source.Expirable]. +func (s handlerStore) Done() <-chan struct{} { + if expirable, ok := s.h.(source.Expirable); ok { + return expirable.Done() + } + return nil } diff --git a/util/syspolicy/handler_test.go b/util/syspolicy/handler_test.go deleted file mode 100644 index 39b18936f176d..0000000000000 --- a/util/syspolicy/handler_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import "testing" - -func TestDefaultHandlerReadValues(t *testing.T) { - var h defaultHandler - - got, err := h.ReadString(string(AdminConsoleVisibility)) - if got != "" || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", got, err) - } - result, err := h.ReadUInt64(string(LogSCMInteractions)) - if result != 0 || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", result, err) - } -} diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go deleted file mode 100644 index 661853ead5d53..0000000000000 --- a/util/syspolicy/handler_windows.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "fmt" - - "tailscale.com/util/clientmetric" - "tailscale.com/util/winutil" -) - -var ( - windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors") - windowsAny = clientmetric.NewGauge("windows_syspolicy_any") -) - -type windowsHandler struct{} - -func init() { - RegisterHandler(NewCachingHandler(windowsHandler{})) - - keyList := []struct { - isSet func(Key) bool - keys []Key - }{ - { - isSet: func(k Key) bool { - _, err := handler.ReadString(string(k)) - return err == nil - }, - keys: stringKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadBoolean(string(k)) - return err == nil - }, - keys: boolKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadUInt64(string(k)) - return err == nil - }, - keys: uint64Keys, - }, - } - - var anySet bool - for _, l := range keyList { - for _, k := range l.keys { - if !l.isSet(k) { - continue - } - clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1) - anySet = true - } - } - if anySet { - windowsAny.Set(1) - } -} - -func (windowsHandler) ReadString(key string) (string, error) { - s, err := winutil.GetPolicyString(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - - return s, err -} - -func (windowsHandler) ReadUInt64(key string) (uint64, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} - -func (windowsHandler) ReadBoolean(key string) (bool, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value != 0, err -} - -func (windowsHandler) ReadStringArray(key string) ([]string, error) { - value, err := winutil.GetPolicyStringArray(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go index ec0556a942cc6..162885b27fa67 100644 --- a/util/syspolicy/policy_keys.go +++ b/util/syspolicy/policy_keys.go @@ -3,10 +3,24 @@ package syspolicy -import "tailscale.com/util/syspolicy/setting" +import ( + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) +// Key is a string that uniquely identifies a policy and must remain unchanged +// once established and documented for a given policy setting. It may contain +// alphanumeric characters and zero or more [KeyPathSeparator]s to group +// individual policy settings into categories. type Key = setting.Key +// The const block below lists known policy keys. +// When adding a key to this list, remember to add a corresponding +// [setting.Definition] to [implicitDefinitions] below. +// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder. + const ( // Keys with a string value ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL. @@ -110,3 +124,90 @@ const ( // AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes. AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes" ) + +// implicitDefinitions is a list of [setting.Definition] that will be registered +// automatically when the policy setting definitions are first used by the syspolicy package hierarchy. +// This includes the first time a policy needs to be read from any source. +var implicitDefinitions = []*setting.Definition{ + // Device policy settings (can only be configured on a per-device basis): + setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(AuthKey, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(MachineCertificateSubject, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue), + + // User policy settings (can be configured on a user- or device-basis): + setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue), + setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue), + setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue), +} + +func init() { + internal.Init.MustDefer(func() error { + // Avoid implicit [setting.Definition] registration during tests. + // Each test should control which policy settings to register. + // Use [setting.SetDefinitionsForTest] to specify necessary definitions, + // or [setWellKnownSettingsForTest] to set implicit definitions for the test duration. + if testenv.InTest() { + return nil + } + for _, d := range implicitDefinitions { + setting.RegisterDefinition(d) + } + return nil + }) +} + +var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap] + +// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key, +// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist +// among implicit policy definitions. +func WellKnownSettingDefinition(k Key) (*setting.Definition, error) { + m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) { + return setting.DefinitionMapOf(implicitDefinitions) + }) + if err != nil { + return nil, err + } + if d, ok := m[k]; ok { + return d, nil + } + return nil, ErrNoSuchKey +} + +// RegisterWellKnownSettingsForTest registers all implicit setting definitions +// for the duration of the test. +func RegisterWellKnownSettingsForTest(tb TB) { + tb.Helper() + err := setting.SetDefinitionsForTest(tb, implicitDefinitions...) + if err != nil { + tb.Fatalf("Failed to register well-known settings: %v", err) + } +} diff --git a/util/syspolicy/policy_keys_test.go b/util/syspolicy/policy_keys_test.go new file mode 100644 index 0000000000000..4d3260f3e0e60 --- /dev/null +++ b/util/syspolicy/policy_keys_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "os" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKnownKeysRegistered(t *testing.T) { + keyConsts, err := listStringConsts[Key]("policy_keys.go") + if err != nil { + t.Fatalf("listStringConsts failed: %v", err) + } + + m, err := setting.DefinitionMapOf(implicitDefinitions) + if err != nil { + t.Fatalf("definitionMapOf failed: %v", err) + } + + for _, key := range keyConsts { + t.Run(string(key), func(t *testing.T) { + d := m[key] + if d == nil { + t.Fatalf("%q was not registered", key) + } + if d.Key() != key { + t.Fatalf("d.Key got: %s, want %s", d.Key(), key) + } + }) + } +} + +func TestNotAWellKnownSetting(t *testing.T) { + d, err := WellKnownSettingDefinition("TestSettingDoesNotExist") + if d != nil || err == nil { + t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey) + } +} + +func listStringConsts[T ~string](filename string) (map[string]T, error) { + fset := token.NewFileSet() + src, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + f, err := parser.ParseFile(fset, filename, src, 0) + if err != nil { + return nil, err + } + + consts := make(map[string]T) + typeName := reflect.TypeFor[T]().Name() + for _, d := range f.Decls { + g, ok := d.(*ast.GenDecl) + if !ok || g.Tok != token.CONST { + continue + } + + for _, s := range g.Specs { + vs, ok := s.(*ast.ValueSpec) + if !ok || len(vs.Names) != len(vs.Values) { + continue + } + if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName { + continue + } + + for i, n := range vs.Names { + lit, ok := vs.Values[i].(*ast.BasicLit) + if !ok { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i])) + } + val, err := strconv.Unquote(lit.Value) + if err != nil { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value) + } + consts[n.Name] = T(val) + } + } + } + + return consts, nil +} diff --git a/util/syspolicy/policy_keys_windows.go b/util/syspolicy/policy_keys_windows.go deleted file mode 100644 index 5e9a716957bdb..0000000000000 --- a/util/syspolicy/policy_keys_windows.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -var stringKeys = []Key{ - ControlURL, - LogTarget, - Tailnet, - ExitNodeID, - ExitNodeIP, - EnableIncomingConnections, - EnableServerMode, - ExitNodeAllowLANAccess, - EnableTailscaleDNS, - EnableTailscaleSubnets, - AdminConsoleVisibility, - NetworkDevicesVisibility, - TestMenuVisibility, - UpdateMenuVisibility, - RunExitNodeVisibility, - PreferencesMenuVisibility, - ExitNodeMenuVisibility, - AutoUpdateVisibility, - ResetToDefaultsVisibility, - KeyExpirationNoticeTime, - PostureChecking, - ManagedByOrganizationName, - ManagedByCaption, - ManagedByURL, -} - -var boolKeys = []Key{ - LogSCMInteractions, - FlushDNSOnSessionUnlock, -} - -var uint64Keys = []Key{} diff --git a/util/syspolicy/syspolicy.go b/util/syspolicy/syspolicy.go index abe42ed90f8c7..d925731c38b3a 100644 --- a/util/syspolicy/syspolicy.go +++ b/util/syspolicy/syspolicy.go @@ -1,51 +1,82 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package syspolicy provides functions to retrieve system settings of a device. +// Package syspolicy facilitates retrieval of the current policy settings +// applied to the device or user and receiving notifications when the policy +// changes. +// +// It provides functions that return specific policy settings by their unique +// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString], +// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration]. package syspolicy import ( "errors" + "fmt" + "reflect" "time" "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) -func GetString(key Key, defaultValue string) (string, error) { - markHandlerInUse() - v, err := handler.ReadString(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil +var ( + // ErrNotConfigured is returned when the requested policy setting is not configured. + ErrNotConfigured = setting.ErrNotConfigured + // ErrTypeMismatch is returned when there's a type mismatch between the actual type + // of the setting value and the expected type. + ErrTypeMismatch = setting.ErrTypeMismatch + // ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting + // has been registered with the specified key. + // + // This error is also returned by a (now deprecated) [Handler] when the specified + // key does not have a value set. While the package maintains compatibility with this + // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer + // [source.Store] implementations. + ErrNoSuchKey = setting.ErrNoSuchKey +) + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +// +// It is a shorthand for [rsop.RegisterStore]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*rsop.StoreRegistration, error) { + return rsop.RegisterStore(name, scope, store) +} + +// MustRegisterStoreForTest is like [rsop.RegisterStoreForTest], but it fails the test if the store could not be registered. +func MustRegisterStoreForTest(tb TB, name string, scope setting.PolicyScope, store source.Store) *rsop.StoreRegistration { + tb.Helper() + reg, err := rsop.RegisterStoreForTest(tb, name, scope, store) + if err != nil { + tb.Fatalf("Failed to register policy store %q as a %v policy source: %v", name, scope, err) } - return v, err + return reg +} + +// GetString returns a string policy setting with the specified key, +// or defaultValue if it does not exist. +func GetString(key Key, defaultValue string) (string, error) { + return getCurrentPolicySettingValue(key, defaultValue) } +// GetUint64 returns a numeric policy setting with the specified key, +// or defaultValue if it does not exist. func GetUint64(key Key, defaultValue uint64) (uint64, error) { - markHandlerInUse() - v, err := handler.ReadUInt64(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetBoolean returns a boolean policy setting with the specified key, +// or defaultValue if it does not exist. func GetBoolean(key Key, defaultValue bool) (bool, error) { - markHandlerInUse() - v, err := handler.ReadBoolean(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetStringArray returns a multi-string policy setting with the specified key, +// or defaultValue if it does not exist. func GetStringArray(key Key, defaultValue []string) ([]string, error) { - markHandlerInUse() - v, err := handler.ReadStringArray(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } // GetPreferenceOption loads a policy from the registry that can be @@ -55,13 +86,7 @@ func GetStringArray(key Key, defaultValue []string) ([]string, error) { // "always" and "never" remove the user's ability to make a selection. If not // present or set to a different value, "user-decides" is the default. func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { - s, err := GetString(name, "user-decides") - if err != nil { - return setting.ShowChoiceByPolicy, err - } - var opt setting.PreferenceOption - err = opt.UnmarshalText([]byte(s)) - return opt, err + return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy) } // GetVisibility loads a policy from the registry that can be managed @@ -70,13 +95,7 @@ func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { // true) or "hide" (return true). If not present or set to a different value, // "show" (return false) is the default. func GetVisibility(name Key) (setting.Visibility, error) { - s, err := GetString(name, "show") - if err != nil { - return setting.VisibleByPolicy, err - } - var visibility setting.Visibility - visibility.UnmarshalText([]byte(s)) - return visibility, nil + return getCurrentPolicySettingValue(name, setting.VisibleByPolicy) } // GetDuration loads a policy from the registry that can be managed @@ -85,15 +104,58 @@ func GetVisibility(name Key) (setting.Visibility, error) { // understands. If the registry value is "" or can not be processed, // defaultValue is returned instead. func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) { - opt, err := GetString(name, "") - if opt == "" || err != nil { - return defaultValue, err + d, err := getCurrentPolicySettingValue(name, defaultValue) + if err != nil { + return d, err } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { + if d < 0 { return defaultValue, nil } - return v, nil + return d, nil +} + +// RegisterChangeCallback adds a function that will be called whenever the effective policy +// for the default scope changes. The returned function can be used to unregister the callback. +func RegisterChangeCallback(cb rsop.PolicyChangeCallback) (unregister func(), err error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return nil, err + } + return effective.RegisterChangeCallback(cb), nil +} + +// getCurrentPolicySettingValue returns the value of the policy setting +// specified by its key from the [rsop.Policy] of the [setting.DefaultScope]. It +// returns def if the policy setting is not configured, or an error if it has +// an error or could not be converted to the specified type T. +func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return def, err + } + value, err := effective.Get().GetErr(key) + if err != nil { + if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) { + return def, nil + } + return def, err + } + if res, ok := value.(T); ok { + return res, nil + } + return convertPolicySettingValueTo(value, def) +} + +func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) { + // Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string + // if someone requests a string instead of the actual setting's value. + // TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests. + if reflect.TypeFor[T]().Kind() == reflect.String { + if str, ok := value.(fmt.Stringer); ok { + return any(str.String()).(T), nil + } + } + return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def) } // SelectControlURL returns the ControlURL to use based on a value in diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index 8280aa1dfbdac..a70a49d395c22 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -9,57 +9,15 @@ import ( "testing" "time" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) -// testHandler encompasses all data types returned when testing any of the syspolicy -// methods that involve getting a policy value. -// For keys and the corresponding values, check policy_keys.go. -type testHandler struct { - t *testing.T - key Key - s string - u64 uint64 - b bool - sArr []string - err error - calls int // used for testing reads from cache vs. handler -} - var someOtherError = errors.New("error other than not found") -func (th *testHandler) ReadString(key string) (string, error) { - if key != string(th.key) { - th.t.Errorf("ReadString(%q) want %q", key, th.key) - } - th.calls++ - return th.s, th.err -} - -func (th *testHandler) ReadUInt64(key string) (uint64, error) { - if key != string(th.key) { - th.t.Errorf("ReadUint64(%q) want %q", key, th.key) - } - th.calls++ - return th.u64, th.err -} - -func (th *testHandler) ReadBoolean(key string) (bool, error) { - if key != string(th.key) { - th.t.Errorf("ReadBool(%q) want %q", key, th.key) - } - th.calls++ - return th.b, th.err -} - -func (th *testHandler) ReadStringArray(key string) ([]string, error) { - if key != string(th.key) { - th.t.Errorf("ReadStringArray(%q) want %q", key, th.key) - } - th.calls++ - return th.sArr, th.err -} - func TestGetString(t *testing.T) { tests := []struct { name string @@ -69,23 +27,28 @@ func TestGetString(t *testing.T) { defaultValue string wantValue string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AdminConsoleVisibility, handlerValue: "hide", wantValue: "hide", + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non-blank default", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: "test", wantValue: "test", wantError: nil, @@ -95,24 +58,43 @@ func TestGetString(t *testing.T) { key: NetworkDevicesVisibility, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_NetworkDevices_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetString(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -129,7 +111,7 @@ func TestGetUint64(t *testing.T) { }{ { name: "read existing value", - key: KeyExpirationNoticeTime, + key: LogSCMInteractions, handlerValue: 1, wantValue: 1, }, @@ -137,14 +119,14 @@ func TestGetUint64(t *testing.T) { name: "read non-existing value", key: LogSCMInteractions, handlerValue: 0, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0, }, { name: "read non-existing value, non-zero default", key: LogSCMInteractions, defaultValue: 2, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 2, }, { @@ -157,14 +139,23 @@ func TestGetUint64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - u64: tt.handlerValue, - err: tt.handlerError, - }) + // None of the policy settings tested here are integers. + // In fact, we don't have any integer policies as of 2024-10-08. + // However, we can register each of them as an integer policy setting + // for the duration of the test, providing us with something to test against. + if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + s := source.TestSetting[uint64]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetUint64(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { @@ -183,45 +174,69 @@ func TestGetBoolean(t *testing.T) { defaultValue bool wantValue bool wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: FlushDNSOnSessionUnlock, handlerValue: true, wantValue: true, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1}, + }, }, { name: "read non-existing value", key: LogSCMInteractions, handlerValue: false, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: false, }, { name: "reading value returns other error", key: FlushDNSOnSessionUnlock, handlerError: someOtherError, - wantError: someOtherError, + wantError: someOtherError, // expect error... defaultValue: true, - wantValue: false, + wantValue: true, // ...AND default value if the handler fails. + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - b: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[bool]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetBoolean(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -234,29 +249,42 @@ func TestGetPreferenceOption(t *testing.T) { handlerError error wantValue setting.PreferenceOption wantError error + wantMetrics []metrics.TestState }{ { name: "always by policy", key: EnableIncomingConnections, handlerValue: "always", wantValue: setting.AlwaysByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "never by policy", key: EnableIncomingConnections, handlerValue: "never", wantValue: setting.NeverByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "use default", key: EnableIncomingConnections, handlerValue: "", wantValue: setting.ShowChoiceByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "read non-existing value", key: EnableIncomingConnections, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: setting.ShowChoiceByPolicy, }, { @@ -265,24 +293,43 @@ func TestGetPreferenceOption(t *testing.T) { handlerError: someOtherError, wantValue: setting.ShowChoiceByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + option, err := GetPreferenceOption(tt.key) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if option != tt.wantValue { t.Errorf("option=%v, want %v", option, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -295,24 +342,33 @@ func TestGetVisibility(t *testing.T) { handlerError error wantValue setting.Visibility wantError error + wantMetrics []metrics.TestState }{ { name: "hidden by policy", key: AdminConsoleVisibility, handlerValue: "hide", wantValue: setting.HiddenByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "visibility default", key: AdminConsoleVisibility, handlerValue: "show", wantValue: setting.VisibleByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: AdminConsoleVisibility, handlerValue: "show", - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: setting.VisibleByPolicy, }, { @@ -322,24 +378,43 @@ func TestGetVisibility(t *testing.T) { handlerError: someOtherError, wantValue: setting.VisibleByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AdminConsole_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + visibility, err := GetVisibility(tt.key) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if visibility != tt.wantValue { t.Errorf("visibility=%v, want %v", visibility, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -353,6 +428,7 @@ func TestGetDuration(t *testing.T) { defaultValue time.Duration wantValue time.Duration wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", @@ -360,25 +436,34 @@ func TestGetDuration(t *testing.T) { handlerValue: "2h", wantValue: 2 * time.Hour, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice", Value: 1}, + }, }, { name: "invalid duration value", key: KeyExpirationNoticeTime, handlerValue: "-20", wantValue: 24 * time.Hour, + wantError: errors.New(`time: missing unit in duration "-20"`), defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, { name: "read non-existing value", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 24 * time.Hour, defaultValue: 24 * time.Hour, }, { name: "read non-existing value different default", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0 * time.Second, defaultValue: 0 * time.Second, }, @@ -389,24 +474,43 @@ func TestGetDuration(t *testing.T) { wantValue: 24 * time.Hour, wantError: someOtherError, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + duration, err := GetDuration(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if duration != tt.wantValue { t.Errorf("duration=%v, want %v", duration, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -420,23 +524,28 @@ func TestGetStringArray(t *testing.T) { defaultValue []string wantValue []string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AllowedSuggestedExitNodes, handlerValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1}, + }, }, { name: "read non-existing value", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non nil default", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, wantError: nil, @@ -446,28 +555,68 @@ func TestGetStringArray(t *testing.T) { key: AllowedSuggestedExitNodes, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - sArr: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[[]string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetStringArray(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if !slices.Equal(tt.wantValue, value) { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } +func registerSingleSettingStoreForTest[T source.TestValueType](tb TB, s source.TestSetting[T]) { + policyStore := source.NewTestStoreOf(tb, s) + MustRegisterStoreForTest(tb, "TestStore", setting.DeviceScope, policyStore) +} + +func BenchmarkGetString(b *testing.B) { + loggerx.SetForTest(b, logger.Discard, logger.Discard) + RegisterWellKnownSettingsForTest(b) + + wantControlURL := "https://login.tailscale.com" + registerSingleSettingStoreForTest(b, source.TestSettingOf(ControlURL, wantControlURL)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com") + if gotControlURL != wantControlURL { + b.Fatalf("got %v; want %v", gotControlURL, wantControlURL) + } + } +} + func TestSelectControlURL(t *testing.T) { tests := []struct { reg, disk, want string @@ -499,3 +648,13 @@ func TestSelectControlURL(t *testing.T) { } } } + +func errorsMatchForTest(got, want error) bool { + if got == nil && want == nil { + return true + } + if got == nil || want == nil { + return false + } + return errors.Is(got, want) || got.Error() == want.Error() +} diff --git a/util/syspolicy/syspolicy_windows.go b/util/syspolicy/syspolicy_windows.go new file mode 100644 index 0000000000000..9d57e249e55e3 --- /dev/null +++ b/util/syspolicy/syspolicy_windows.go @@ -0,0 +1,92 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "errors" + "fmt" + "os/user" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" +) + +func init() { + // On Windows, we should automatically register the Registry-based policy + // store for the device. If we are running in a user's security context + // (e.g., we're the GUI), we should also register the Registry policy store for + // the user. In the future, we should register (and unregister) user policy + // stores whenever a user connects to (or disconnects from) the local backend. + // This ensures the backend is aware of the user's policy settings and can send + // them to the GUI/CLI/Web clients on demand or whenever they change. + // + // Other platforms, such as macOS, iOS and Android, should register their + // platform-specific policy stores via [RegisterStore] + // (or [RegisterHandler] until they implement the [source.Store] interface). + // + // External code, such as the ipnlocal package, may choose to register + // additional policy stores, such as config files and policies received from + // the control plane. + internal.Init.MustDefer(func() error { + // Do not register or use default policy stores during tests. + // Each test should set up its own necessary configurations. + if testenv.InTest() { + return nil + } + return configureSyspolicy(nil) + }) +} + +// configureSyspolicy configures syspolicy for use on Windows, +// either in test or regular builds depending on whether tb has a non-nil value. +func configureSyspolicy(tb internal.TB) error { + const localSystemSID = "S-1-5-18" + // Always create and register a machine policy store that reads + // policy settings from the HKEY_LOCAL_MACHINE registry hive. + machineStore, err := source.NewMachinePlatformPolicyStore() + if err != nil { + return fmt.Errorf("failed to create the machine policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore) + } + if err != nil { + return err + } + // Check whether the current process is running as Local System or not. + u, err := user.Current() + if err != nil { + return err + } + if u.Uid == localSystemSID { + return nil + } + // If it's not a Local System's process (e.g., it's the GUI rather than the tailscaled service), + // we should create and use a policy store for the current user that reads + // policy settings from that user's registry hive (HKEY_CURRENT_USER). + userStore, err := source.NewUserPlatformPolicyStore(0) + if err != nil { + return fmt.Errorf("failed to create the current user's policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore) + } + if err != nil { + return err + } + // And also set [setting.CurrentUserScope] as the [setting.DefaultScope], so [GetString], + // [GetVisibility] and similar functions would be returning a merged result + // of the machine's and user's policies. + if !setting.SetDefaultScope(setting.CurrentUserScope) { + return errors.New("current scope already set") + } + return nil +} From 6ab39b7bcd259a7cf4adb9331586a64698c85dcc Mon Sep 17 00:00:00 2001 From: Nick Kirby Date: Sat, 26 Oct 2024 13:03:36 +0100 Subject: [PATCH 040/179] cmd/k8s-operator: validate that tailscale.com/tailnet-ip annotation value is a valid IP Fixes #13836 Signed-off-by: Nick Kirby --- cmd/k8s-operator/operator_test.go | 142 ++++++++++++++++++++++++++++++ cmd/k8s-operator/svc.go | 11 ++- 2 files changed, 150 insertions(+), 3 deletions(-) diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 21e1d4313749e..a440fafb5cfc1 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -432,6 +432,148 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "invalid-ip" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "invalid-ip" could not be parsed as a valid IP Address, error: ParseAddr("invalid-ip"): unable to parse IP`, + }}, + }, + } + + expectEqual(t, fc, want, nil) +} + +func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "999.999.999.999" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "999.999.999.999" could not be parsed as a valid IP Address, error: ParseAddr("999.999.999.999"): IPv4 field has value >255`, + }}, + }, + } + + expectEqual(t, fc, want, nil) +} + func TestAnnotations(t *testing.T) { fc := fake.NewFakeClient() ft := &fakeTSClient{} diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index f45f922463113..3c6bc27a95cf0 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -358,9 +358,14 @@ func validateService(svc *corev1.Service) []string { violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q does not appear to be a valid MagicDNS name", AnnotationTailnetTargetFQDN, fqdn)) } } - - // TODO(irbekrm): validate that tailscale.com/tailnet-ip annotation is a - // valid IP address (tailscale/tailscale#13671). + if ipStr := svc.Annotations[AnnotationTailnetTargetIP]; ipStr != "" { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q could not be parsed as a valid IP Address, error: %s", AnnotationTailnetTargetIP, ipStr, err)) + } else if !ip.IsValid() { + violations = append(violations, fmt.Sprintf("parsed IP address in annotation %s: %q is not valid", AnnotationTailnetTargetIP, ipStr)) + } + } svcName := nameForService(svc) if err := dnsname.ValidLabel(svcName); err != nil { From 853fe3b7132959fb99648df3d5d7aec47a6734c1 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Sat, 26 Oct 2024 09:33:47 -0500 Subject: [PATCH 041/179] ipn/store/kubestore: cache state in memory (#13918) Cache state in memory on writes, read from memory in reads. kubestore was previously always reading state from a Secret. This change should fix bugs caused by temporary loss of access to kube API server and imporove overall performance Fixes #7671 Updates tailscale/tailscale#12079,tailscale/tailscale#13900 Signed-off-by: Maisem Ali Signed-off-by: Irbe Krumina Co-authored-by: Maisem Ali --- ipn/store/kubestore/store_kube.go | 81 +++++++++++++++++++------------ ipn/store/mem/store_mem.go | 17 +++++++ 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 00950bd3b2394..1e0e01c7b9609 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -13,19 +13,27 @@ import ( "time" "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" "tailscale.com/types/logger" ) +// TODO(irbekrm): should we bump this? should we have retries? See tailscale/tailscale#13024 +const timeout = 5 * time.Second + // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. type Store struct { client kubeclient.Client canPatch bool secretName string + + // memory holds the latest tailscale state. Writes write state to a kube Secret and memory, Reads read from + // memory. + memory mem.Store } -// New returns a new Store that persists to the named secret. +// New returns a new Store that persists to the named Secret. func New(_ logger.Logf, secretName string) (*Store, error) { c, err := kubeclient.New() if err != nil { @@ -39,11 +47,16 @@ func New(_ logger.Logf, secretName string) (*Store, error) { if err != nil { return nil, err } - return &Store{ + s := &Store{ client: c, canPatch: canPatch, secretName: secretName, - }, nil + } + // Load latest state from kube Secret if it already exists. + if err := s.loadState(); err != nil { + return nil, fmt.Errorf("error loading state from kube Secret: %w", err) + } + return s, nil } func (s *Store) SetDialer(d func(ctx context.Context, network, address string) (net.Conn, error)) { @@ -54,37 +67,17 @@ func (s *Store) String() string { return "kube.Store" } // ReadState implements the StateStore interface. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - secret, err := s.client.GetSecret(ctx, s.secretName) - if err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return nil, ipn.ErrStateNotExist - } - return nil, err - } - b, ok := secret.Data[sanitizeKey(id)] - if !ok { - return nil, ipn.ErrStateNotExist - } - return b, nil -} - -func sanitizeKey(k ipn.StateKey) string { - // The only valid characters in a Kubernetes secret key are alphanumeric, -, - // _, and . - return strings.Map(func(r rune) rune { - if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { - return r - } - return '_' - }, string(k)) + return s.memory.ReadState(ipn.StateKey(sanitizeKey(id))) } // WriteState implements the StateStore interface. -func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { + defer func() { + if err == nil { + s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() secret, err := s.client.GetSecret(ctx, s.secretName) @@ -137,3 +130,29 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { } return err } + +func (s *Store) loadState() error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + secret, err := s.client.GetSecret(ctx, s.secretName) + if err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return ipn.ErrStateNotExist + } + return err + } + s.memory.LoadFromMap(secret.Data) + return nil +} + +func sanitizeKey(k ipn.StateKey) string { + // The only valid characters in a Kubernetes secret key are alphanumeric, -, + // _, and . + return strings.Map(func(r rune) rune { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { + return r + } + return '_' + }, string(k)) +} diff --git a/ipn/store/mem/store_mem.go b/ipn/store/mem/store_mem.go index f3a308ae5dc4f..6f474ce993b43 100644 --- a/ipn/store/mem/store_mem.go +++ b/ipn/store/mem/store_mem.go @@ -9,8 +9,10 @@ import ( "encoding/json" "sync" + xmaps "golang.org/x/exp/maps" "tailscale.com/ipn" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // New returns a new Store. @@ -28,6 +30,7 @@ type Store struct { func (s *Store) String() string { return "mem.Store" } // ReadState implements the StateStore interface. +// It returns ipn.ErrStateNotExist if the state does not exist. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { s.mu.Lock() defer s.mu.Unlock() @@ -39,6 +42,7 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { } // WriteState implements the StateStore interface. +// It never returns an error. func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { s.mu.Lock() defer s.mu.Unlock() @@ -49,6 +53,19 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { return nil } +// LoadFromMap loads the in-memory cache from the provided map. +// Any existing content is cleared, and the provided map is +// copied into the cache. +func (s *Store) LoadFromMap(m map[string][]byte) { + s.mu.Lock() + defer s.mu.Unlock() + xmaps.Clear(s.cache) + for k, v := range m { + mak.Set(&s.cache, ipn.StateKey(k), v) + } + return +} + // LoadFromJSON attempts to unmarshal json content into the // in-memory cache. func (s *Store) LoadFromJSON(data []byte) error { From 9d1348fe212fccf52de11f4009e24a7436167fe7 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Sun, 27 Oct 2024 10:54:38 -0500 Subject: [PATCH 042/179] ipn/store/kubestore: don't error if state cannot be preloaded (#13926) Preloading of state from kube Secret should not error if the Secret does not exist. Updates tailscale/tailscale#7671 Signed-off-by: Irbe Krumina --- ipn/store/kubestore/store_kube.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 1e0e01c7b9609..2dcc08b6e4d1c 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -53,7 +53,7 @@ func New(_ logger.Logf, secretName string) (*Store, error) { secretName: secretName, } // Load latest state from kube Secret if it already exists. - if err := s.loadState(); err != nil { + if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist { return nil, fmt.Errorf("error loading state from kube Secret: %w", err) } return s, nil From 5d07c17b9395c513dc2d4674d63e33397ce794d5 Mon Sep 17 00:00:00 2001 From: Renato Aguiar Date: Mon, 28 Oct 2024 08:00:48 -0700 Subject: [PATCH 043/179] net/dns: fix blank lines being added to resolv.conf on OpenBSD (#13928) During resolv.conf update, old 'search' lines are cleared but '\n' is not deleted, leaving behind a new blank line on every update. This adds 's' flag to regexp, so '\n' is included in the match and deleted when old lines are cleared. Also, insert missing `\n` when updated 'search' line is appended to resolv.conf. Signed-off-by: Renato Aguiar --- net/dns/resolvd.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/net/dns/resolvd.go b/net/dns/resolvd.go index 9b067eb07b178..ad1a99c111997 100644 --- a/net/dns/resolvd.go +++ b/net/dns/resolvd.go @@ -57,6 +57,7 @@ func (m *resolvdManager) SetDNS(config OSConfig) error { if len(newSearch) > 1 { newResolvConf = append(newResolvConf, []byte(strings.Join(newSearch, " "))...) + newResolvConf = append(newResolvConf, '\n') } err = m.fs.WriteFile(resolvConf, newResolvConf, 0644) @@ -123,6 +124,6 @@ func (m resolvdManager) readResolvConf() (config OSConfig, err error) { } func removeSearchLines(orig []byte) []byte { - re := regexp.MustCompile(`(?m)^search\s+.+$`) + re := regexp.MustCompile(`(?ms)^search\s+.+$`) return re.ReplaceAll(orig, []byte("")) } From 41aac261064602c5eb14ccbacd0a684ffe3ae533 Mon Sep 17 00:00:00 2001 From: License Updater Date: Mon, 28 Oct 2024 15:02:34 +0000 Subject: [PATCH 044/179] licenses: update license notices Signed-off-by: License Updater --- licenses/apple.md | 8 ++++---- licenses/windows.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/licenses/apple.md b/licenses/apple.md index 751082d5b220f..36c654c59c026 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -73,13 +73,13 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/windows.md b/licenses/windows.md index 2a8e4e621a4a6..3f6650b9eaff2 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -65,15 +65,15 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE)) - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE)) - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE)) From c0a1ed86cbe5e8a8511b04fe1406b3903cd9f8b8 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 13 Sep 2024 11:35:47 -0700 Subject: [PATCH 045/179] tstest/natlab: add latency & loss simulation A simple implementation of latency and loss simulation, applied to writes to the ethernet interface of the NIC. The latency implementation could be optimized substantially later if necessary. Updates #13355 Signed-off-by: James Tucker --- tstest/natlab/vnet/conf.go | 21 +++++++++++++++++++++ tstest/natlab/vnet/conf_test.go | 15 ++++++++++++++- tstest/natlab/vnet/vnet.go | 23 +++++++++++++++++++++-- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index cf71a66743e1c..a37c22a6c8023 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "slices" + "time" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" @@ -279,10 +280,28 @@ type Network struct { svcs set.Set[NetworkService] + latency time.Duration // latency applied to interface writes + lossRate float64 // chance of packet loss (0.0 to 1.0) + // ... err error // carried error } +// SetLatency sets the simulated network latency for this network. +func (n *Network) SetLatency(d time.Duration) { + n.latency = d +} + +// SetPacketLoss sets the packet loss rate for this network 0.0 (no loss) to 1.0 (total loss). +func (n *Network) SetPacketLoss(rate float64) { + if rate < 0 { + rate = 0 + } else if rate > 1 { + rate = 1 + } + n.lossRate = rate +} + // SetBlackholedIPv4 sets whether the network should blackhole all IPv4 traffic // out to the Internet. (DHCP etc continues to work on the LAN.) func (n *Network) SetBlackholedIPv4(v bool) { @@ -361,6 +380,8 @@ func (s *Server) initFromConfig(c *Config) error { wanIP4: conf.wanIP4, lanIP4: conf.lanIP4, breakWAN4: conf.breakWAN4, + latency: conf.latency, + lossRate: conf.lossRate, nodesByIP4: map[netip.Addr]*node{}, nodesByMAC: map[MAC]*node{}, logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)), diff --git a/tstest/natlab/vnet/conf_test.go b/tstest/natlab/vnet/conf_test.go index 15d3c69ef52d9..6566ac8cf4610 100644 --- a/tstest/natlab/vnet/conf_test.go +++ b/tstest/natlab/vnet/conf_test.go @@ -3,7 +3,10 @@ package vnet -import "testing" +import ( + "testing" + "time" +) func TestConfig(t *testing.T) { tests := []struct { @@ -18,6 +21,16 @@ func TestConfig(t *testing.T) { c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) }, }, + { + name: "latency-and-loss", + setup: func(c *Config) { + n1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP) + n1.SetLatency(time.Second) + n1.SetPacketLoss(0.1) + c.AddNode(n1) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) + }, + }, { name: "indirect", setup: func(c *Config) { diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index e7991b3e6ef5d..92312c039bfc9 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -515,6 +515,8 @@ type network struct { wanIP4 netip.Addr // router's LAN IPv4, if any lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24) breakWAN4 bool // break WAN IPv4 connectivity + latency time.Duration // latency applied to interface writes + lossRate float64 // probability of dropping a packet (0.0 to 1.0) nodesByIP4 map[netip.Addr]*node // by LAN IPv4 nodesByMAC map[MAC]*node logf func(format string, args ...any) @@ -977,7 +979,7 @@ func (n *network) writeEth(res []byte) bool { for mac, nw := range n.writers.All() { if mac != srcMAC { num++ - nw.write(res) + n.conditionedWrite(nw, res) } } return num > 0 @@ -987,7 +989,7 @@ func (n *network) writeEth(res []byte) bool { return false } if nw, ok := n.writers.Load(dstMAC); ok { - nw.write(res) + n.conditionedWrite(nw, res) return true } @@ -1000,6 +1002,23 @@ func (n *network) writeEth(res []byte) bool { return false } +func (n *network) conditionedWrite(nw networkWriter, packet []byte) { + if n.lossRate > 0 && rand.Float64() < n.lossRate { + // packet lost + return + } + if n.latency > 0 { + // copy the packet as there's no guarantee packet is owned long enough. + // TODO(raggi): this could be optimized substantially if necessary, + // a pool of buffers and a cheaper delay mechanism are both obvious improvements. + var pkt = make([]byte, len(packet)) + copy(pkt, packet) + time.AfterFunc(n.latency, func() { nw.write(pkt) }) + } else { + nw.write(packet) + } +} + var ( macAllNodes = MAC{0: 0x33, 1: 0x33, 5: 0x01} macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02} From 0d76d7d21c951872433de708839025c8dfb304b3 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 11 Sep 2024 11:28:33 -0700 Subject: [PATCH 046/179] tool/gocross: remove trimpath from test builds trimpath can be inconvenient for IDEs and LSPs that do not always correctly handle module relative paths, and can also contribute to caching bugs taking effect. We rarely have a real need for trimpath of test produced binaries, so avoiding it should be a net win. Updates #2988 Signed-off-by: James Tucker --- tool/gocross/autoflags.go | 6 +++++- tool/gocross/autoflags_test.go | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tool/gocross/autoflags.go b/tool/gocross/autoflags.go index 020b19fa58446..b28d3bc5dd26e 100644 --- a/tool/gocross/autoflags.go +++ b/tool/gocross/autoflags.go @@ -35,7 +35,7 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ cc = "cc" targetOS = cmp.Or(env.Get("GOOS", ""), nativeGOOS) targetArch = cmp.Or(env.Get("GOARCH", ""), nativeGOARCH) - buildFlags = []string{"-trimpath"} + buildFlags = []string{} cgoCflags = []string{"-O3", "-std=gnu11", "-g"} cgoLdflags []string ldflags []string @@ -47,6 +47,10 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ subcommand = argv[1] } + if subcommand != "test" { + buildFlags = append(buildFlags, "-trimpath") + } + switch subcommand { case "build", "env", "install", "run", "test", "list": default: diff --git a/tool/gocross/autoflags_test.go b/tool/gocross/autoflags_test.go index 8f24dd8a32797..a0f3edfd2bb68 100644 --- a/tool/gocross/autoflags_test.go +++ b/tool/gocross/autoflags_test.go @@ -163,7 +163,6 @@ GOTOOLCHAIN=local (was ) TS_LINK_FAIL_REFLECT=0 (was )`, wantArgv: []string{ "gocross", "test", - "-trimpath", "-tags=tailscale_go,osusergo,netgo", "-ldflags", "-X tailscale.com/version.longStamp=1.2.3-long -X tailscale.com/version.shortStamp=1.2.3 -X tailscale.com/version.gitCommitStamp=abcd -X tailscale.com/version.extraGitCommitStamp=defg '-extldflags=-static'", "-race", From 94fa6d97c5a25269e9c68595d5fabfd847f9f7b4 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Mon, 28 Oct 2024 16:41:44 +0000 Subject: [PATCH 047/179] ipn/ipnlocal: log errors while fetching serial numbers If the client cannot fetch a serial number, write a log message helping the user understand what happened. Also, don't just return the error immediately, since we still have a chance to collect network interface addresses. Updates #5902 Signed-off-by: Anton Tolchanov --- ipn/ipnlocal/c2n.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index de6ca2321a741..c3ed32fd89bd5 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -332,12 +332,10 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http } if choice.ShouldEnable(b.Prefs().PostureChecking()) { - sns, err := posture.GetSerialNumbers(b.logf) + res.SerialNumbers, err = posture.GetSerialNumbers(b.logf) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + b.logf("c2n: GetSerialNumbers returned error: %v", err) } - res.SerialNumbers = sns // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release // and looks good in client metrics, remove this parameter and always report MAC From 11e96760ff119dcfa60139371570f761e0c26050 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 29 Oct 2024 13:40:33 +0000 Subject: [PATCH 048/179] wgengine/magicsock: fix stats packet counter on derp egress Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- wgengine/magicsock/endpoint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index ab9f3d47dd033..1ddde97524571 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -991,7 +991,7 @@ func (de *endpoint) send(buffs [][]byte) error { } if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, 1, txBytes) + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buffs), txBytes) } if allOk { return nil From 38af62c7b303d707ba5cc46148809921557e36aa Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 29 Oct 2024 13:35:12 +0000 Subject: [PATCH 049/179] ipn/ipnlocal: remove the primary routes gauge for now Not confident this is the right way to expose this, so let's remote it for now. Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- ipn/ipnlocal/local.go | 9 --------- tsnet/tsnet_test.go | 12 ------------ 2 files changed, 21 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index b01f3a0c0f16a..b91f1337af0ed 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -399,11 +399,6 @@ type metrics struct { // approvedRoutes is a metric that reports the number of network routes served by the local node and approved // by the control server. approvedRoutes *usermetric.Gauge - - // primaryRoutes is a metric that reports the number of primary network routes served by the local node. - // A route being a primary route implies that the route is currently served by this node, and not by another - // subnet router in a high availability configuration. - primaryRoutes *usermetric.Gauge } // clientGen is a func that creates a control plane client. @@ -454,8 +449,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo "tailscaled_advertised_routes", "Number of advertised network routes (e.g. by a subnet router)"), approvedRoutes: sys.UserMetricsRegistry().NewGauge( "tailscaled_approved_routes", "Number of approved network routes (e.g. by a subnet router)"), - primaryRoutes: sys.UserMetricsRegistry().NewGauge( - "tailscaled_primary_routes", "Number of network routes for which this node is a primary router (in high availability configuration)"), } b := &LocalBackend{ @@ -5477,7 +5470,6 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { // If there is no netmap, the client is going into a "turned off" // state so reset the metrics. b.metrics.approvedRoutes.Set(0) - b.metrics.primaryRoutes.Set(0) return } @@ -5506,7 +5498,6 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { } } b.metrics.approvedRoutes.Set(approved) - b.metrics.primaryRoutes.Set(float64(tsaddr.WithoutExitRoute(nm.SelfNode.PrimaryRoutes()).Len())) } for _, p := range nm.Peers { addNode(p) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 98c1fd4ab3462..7aebbdd4c39ca 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1080,13 +1080,6 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want) } - // The node is the primary subnet router for 2 routes: - // - 192.0.2.0/24 - // - 192.0.5.1/32 - if got, want := parsedMetrics1["tailscaled_primary_routes"], wantRoutes; got != want { - t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want) - } - // Verify that the amount of data recorded in bytes is higher or equal to the // 10 megabytes sent. inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] @@ -1131,11 +1124,6 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want) } - // The node is the primary subnet router for 0 routes - if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want { - t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want) - } - // Verify that the amount of data recorded in bytes is higher or equal than the // 10 megabytes sent. outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] From 9545e36007e5859b0a9aec4052bcb7f7837b0948 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Sat, 26 Oct 2024 18:28:22 +0100 Subject: [PATCH 050/179] cmd/tailscale/cli: add 'tailscale metrics' command - `tailscale metrics print`: to show metric values in console - `tailscale metrics write`: to write metrics to a file (with a tempfile & rename dance, which is atomic on Unix). Also, remove the `TS_DEBUG_USER_METRICS` envknob as we are getting more confident in these metrics. Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- cmd/tailscale/cli/cli.go | 1 + cmd/tailscale/cli/metrics.go | 88 ++++++++++++++++++++++++++++++++++++ ipn/localapi/localapi.go | 11 +---- 3 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 cmd/tailscale/cli/metrics.go diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index de6bc2a4e5e41..f786bcea5bdf7 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -189,6 +189,7 @@ change in the future. ipCmd, dnsCmd, statusCmd, + metricsCmd, pingCmd, ncCmd, sshCmd, diff --git a/cmd/tailscale/cli/metrics.go b/cmd/tailscale/cli/metrics.go new file mode 100644 index 0000000000000..d5fe9ad81cb70 --- /dev/null +++ b/cmd/tailscale/cli/metrics.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/atomicfile" +) + +var metricsCmd = &ffcli.Command{ + Name: "metrics", + ShortHelp: "Show Tailscale metrics", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics' command shows Tailscale user-facing metrics (as opposed +to internal metrics printed by 'tailscale debug metrics'). + +For more information about Tailscale metrics, refer to +https://tailscale.com/s/client-metrics + +`), + ShortUsage: "tailscale metrics [flags]", + UsageFunc: usageFuncNoDefaultValues, + Exec: runMetricsNoSubcommand, + Subcommands: []*ffcli.Command{ + { + Name: "print", + ShortUsage: "tailscale metrics print", + Exec: runMetricsPrint, + ShortHelp: "Prints current metric values in the Prometheus text exposition format", + }, + { + Name: "write", + ShortUsage: "tailscale metrics write ", + Exec: runMetricsWrite, + ShortHelp: "Writes metric values to a file", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics write' command writes metric values to a text file provided as its +only argument. It's meant to be used alongside Prometheus node exporter, allowing Tailscale +metrics to be consumed and exported by the textfile collector. + +As an example, to export Tailscale metrics on an Ubuntu system running node exporter, you +can regularly run 'tailscale metrics write /var/lib/prometheus/node-exporter/tailscaled.prom' +using cron or a systemd timer. + + `), + }, + }, +} + +// runMetricsNoSubcommand prints metric values if no subcommand is specified. +func runMetricsNoSubcommand(ctx context.Context, args []string) error { + if len(args) > 0 { + return fmt.Errorf("tailscale metrics: unknown subcommand: %s", args[0]) + } + + return runMetricsPrint(ctx, args) +} + +// runMetricsPrint prints metric values to stdout. +func runMetricsPrint(ctx context.Context, args []string) error { + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + Stdout.Write(out) + return nil +} + +// runMetricsWrite writes metric values to a file. +func runMetricsWrite(ctx context.Context, args []string) error { + if len(args) != 1 { + return errors.New("usage: tailscale metrics write ") + } + path := args[0] + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + return atomicfile.WriteFile(path, out, 0644) +} diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 25ec1912131df..1d580eca9ff95 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -62,7 +62,6 @@ import ( "tailscale.com/util/osdiag" "tailscale.com/util/progresstracking" "tailscale.com/util/rands" - "tailscale.com/util/testenv" "tailscale.com/version" "tailscale.com/wgengine/magicsock" ) @@ -570,15 +569,9 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { clientmetric.WritePrometheusExpositionFormat(w) } -// TODO(kradalby): Remove this once we have landed on a final set of -// metrics to export to clients and consider the metrics stable. -var debugUsermetricsEndpoint = envknob.RegisterBool("TS_DEBUG_USER_METRICS") - +// serveUserMetrics returns user-facing metrics in Prometheus text +// exposition format. func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) { - if !testenv.InTest() && !debugUsermetricsEndpoint() { - http.Error(w, "usermetrics debug flag not enabled", http.StatusForbidden) - return - } h.b.UserMetricsRegistry().Handler(w, r) } From 0f9a054cba58aa7f1c45d82f18be43ec0ffd592e Mon Sep 17 00:00:00 2001 From: Jonathan Nobels Date: Tue, 29 Oct 2024 13:49:29 -0400 Subject: [PATCH 051/179] tstest/tailmac: fix Host.app path generation (#13953) updates tailscale/corp#24197 Generation of the Host.app path was erroneous and tailmac run would not work unless the pwd was tailmac/bin. Now you can be able to invoke tailmac from anywhere. Signed-off-by: Jonathan Nobels --- tstest/tailmac/Swift/TailMac/TailMac.swift | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 56f651696e12c..6554d5debb459 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -100,7 +100,10 @@ extension Tailmac { mutating func run() { let process = Process() let stdOutPipe = Pipe() - let appPath = "./Host.app/Contents/MacOS/Host" + + let executablePath = CommandLine.arguments[0] + let executableDirectory = (executablePath as NSString).deletingLastPathComponent + let appPath = executableDirectory + "/Host.app/Contents/MacOS/Host" process.executableURL = URL( fileURLWithPath: appPath, @@ -109,7 +112,7 @@ extension Tailmac { ) if !FileManager.default.fileExists(atPath: appPath) { - fatalError("Could not find Host.app. This must be co-located with the tailmac utility") + fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility") } process.arguments = ["run", "--id", id] From aecb0ab76bb38c13d29b184380a53a5190c77302 Mon Sep 17 00:00:00 2001 From: Jonathan Nobels Date: Tue, 29 Oct 2024 13:49:51 -0400 Subject: [PATCH 052/179] tstest/tailmac: add support for mounting host directories in the guest (#13957) updates tailscale/corp#24197 tailmac run now supports the --share option which will allow you to specify a directory on the host which can be mounted in the guest using mount_virtiofs vmshare . Signed-off-by: Jonathan Nobels --- tstest/tailmac/Swift/Common/Config.swift | 1 + .../Swift/Common/TailMacConfigHelper.swift | 13 ++++++++++ tstest/tailmac/Swift/Host/HostCli.swift | 4 +++- tstest/tailmac/Swift/Host/VMController.swift | 7 ++++++ tstest/tailmac/Swift/TailMac/TailMac.swift | 24 +++++++++---------- 5 files changed, 35 insertions(+), 14 deletions(-) diff --git a/tstest/tailmac/Swift/Common/Config.swift b/tstest/tailmac/Swift/Common/Config.swift index 01d5069b0049d..18b68ae9b9d14 100644 --- a/tstest/tailmac/Swift/Common/Config.swift +++ b/tstest/tailmac/Swift/Common/Config.swift @@ -14,6 +14,7 @@ class Config: Codable { var mac = "52:cc:cc:cc:cc:01" var ethermac = "52:cc:cc:cc:ce:01" var port: UInt32 = 51009 + var sharedDir: String? // The virtual machines ID. Also double as the directory name under which // we will store configuration, block device, etc. diff --git a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift index 00f999a158c19..c0961c883fdbb 100644 --- a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift +++ b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift @@ -141,5 +141,18 @@ struct TailMacConfigHelper { func createKeyboardConfiguration() -> VZKeyboardConfiguration { return VZMacKeyboardConfiguration() } + + func createDirectoryShareConfiguration(tag: String) -> VZDirectorySharingDeviceConfiguration? { + guard let dir = config.sharedDir else { return nil } + + let sharedDir = VZSharedDirectory(url: URL(fileURLWithPath: dir), readOnly: false) + let share = VZSingleDirectoryShare(directory: sharedDir) + + // Create the VZVirtioFileSystemDeviceConfiguration and assign it a unique tag. + let sharingConfiguration = VZVirtioFileSystemDeviceConfiguration(tag: tag) + sharingConfiguration.share = share + + return sharingConfiguration + } } diff --git a/tstest/tailmac/Swift/Host/HostCli.swift b/tstest/tailmac/Swift/Host/HostCli.swift index 1318a09fa546e..c31478cc39d45 100644 --- a/tstest/tailmac/Swift/Host/HostCli.swift +++ b/tstest/tailmac/Swift/Host/HostCli.swift @@ -19,10 +19,12 @@ var config: Config = Config() extension HostCli { struct Run: ParsableCommand { @Option var id: String + @Option var share: String? mutating func run() { - print("Running vm with identifier \(id)") config = Config(id) + config.sharedDir = share + print("Running vm with identifier \(id) and sharedDir \(share ?? "")") _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) } } diff --git a/tstest/tailmac/Swift/Host/VMController.swift b/tstest/tailmac/Swift/Host/VMController.swift index 8774894c1157a..fe4a3828b18fe 100644 --- a/tstest/tailmac/Swift/Host/VMController.swift +++ b/tstest/tailmac/Swift/Host/VMController.swift @@ -95,6 +95,13 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()] virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()] + if let dir = config.sharedDir, let shareConfig = helper.createDirectoryShareConfiguration(tag: "vmshare") { + print("Sharing \(dir) as vmshare. Use: mount_virtiofs vmshare in the guest to mount.") + virtualMachineConfiguration.directorySharingDevices = [shareConfig] + } else { + print("No shared directory created. \(config.sharedDir ?? "none") was requested.") + } + try! virtualMachineConfiguration.validate() try! virtualMachineConfiguration.validateSaveRestoreSupport() diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 6554d5debb459..84aa5e498a008 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -95,6 +95,7 @@ extension Tailmac { extension Tailmac { struct Run: ParsableCommand { @Option(help: "The vm identifier") var id: String + @Option(help: "Optional share directory") var share: String? @Flag(help: "Tail the TailMac log output instead of returning immediatly") var tail mutating func run() { @@ -115,7 +116,12 @@ extension Tailmac { fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility") } - process.arguments = ["run", "--id", id] + var args = ["run", "--id", id] + if let share { + args.append("--share") + args.append(share) + } + process.arguments = args do { process.standardOutput = stdOutPipe @@ -124,26 +130,18 @@ extension Tailmac { fatalError("Unable to launch the vm process") } - // This doesn't print until we exit which is not ideal, but at least we - // get the output if tail != 0 { + // (jonathan)TODO: How do we get the process output in real time? + // The child process only seems to flush to stdout on completion let outHandle = stdOutPipe.fileHandleForReading - - let queue = OperationQueue() - NotificationCenter.default.addObserver( - forName: NSNotification.Name.NSFileHandleDataAvailable, - object: outHandle, queue: queue) - { - notification -> Void in - let data = outHandle.availableData + outHandle.readabilityHandler = { handle in + let data = handle.availableData if data.count > 0 { if let str = String(data: data, encoding: String.Encoding.utf8) { print(str) } } - outHandle.waitForDataInBackgroundAndNotify() } - outHandle.waitForDataInBackgroundAndNotify() process.waitUntilExit() } } From 856ea2376b59df8f84f96119559d4273588a04ac Mon Sep 17 00:00:00 2001 From: Tim Walters Date: Wed, 23 Oct 2024 14:27:00 -0500 Subject: [PATCH 053/179] wgengine/magicsock: log home DERP changes with latency This adds additional logging on DERP home changes to allow better troubleshooting. Updates tailscale/corp#18095 Signed-off-by: Tim Walters --- wgengine/magicsock/derp.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index bfee02f6e87da..704ce3c4ff2b5 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -158,10 +158,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) } else { connectedToControl = c.health.GetInPollNetMap() } + c.mu.Lock() + myDerp := c.myDerp + c.mu.Unlock() if !connectedToControl { - c.mu.Lock() - myDerp := c.myDerp - c.mu.Unlock() if myDerp != 0 { metricDERPHomeNoChangeNoControl.Add(1) return myDerp @@ -178,6 +178,11 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) // one. preferredDERP = c.pickDERPFallback() } + if preferredDERP != myDerp { + c.logf( + "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]", + c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds()) + } if !c.setNearestDERP(preferredDERP) { preferredDERP = 0 } From 1103044598ac2897a3f2f6687dc9d2b3d23f7da5 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Wed, 30 Oct 2024 05:45:31 -0500 Subject: [PATCH 054/179] cmd/k8s-operator,k8s-operator: add topology spread constraints to ProxyClass (#13959) Now when we have HA for egress proxies, it makes sense to support topology spread constraints that would allow users to define more complex topologies of how proxy Pods need to be deployed in relation with other Pods/across regions etc. Updates tailscale/tailscale#13406 Signed-off-by: Irbe Krumina --- .../crds/tailscale.com_proxyclasses.yaml | 176 ++++++++++++++++++ .../deploy/manifests/operator.yaml | 176 ++++++++++++++++++ cmd/k8s-operator/sts.go | 1 + cmd/k8s-operator/sts_test.go | 13 ++ k8s-operator/api.md | 1 + .../apis/v1alpha1/types_proxyclass.go | 4 + .../apis/v1alpha1/zz_generated.deepcopy.go | 7 + 7 files changed, 378 insertions(+) diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index 0fff30516a132..7086138c03afd 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -1896,6 +1896,182 @@ spec: Value is the taint value the toleration matches to. If the operator is Exists, the value should be empty, otherwise just a regular string. type: string + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. + By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + type: array + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + type: object + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + type: array + items: + type: string + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. + type: integer + format: int32 + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. + Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + type: integer + format: int32 + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. + We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string tailscale: description: |- TailscaleConfig contains options to configure the tailscale-specific diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 1a812b7362757..203a670664968 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -2323,6 +2323,182 @@ spec: type: string type: object type: array + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. + By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + items: + type: string + type: array + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. + format: int32 + type: integer + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. + Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + format: int32 + type: integer + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. + We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + type: object + type: array type: object type: object tailscale: diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index 6378a82636939..e89b9c93082cf 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -718,6 +718,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, ss.Spec.Template.Spec.NodeSelector = wantsPod.NodeSelector ss.Spec.Template.Spec.Affinity = wantsPod.Affinity ss.Spec.Template.Spec.Tolerations = wantsPod.Tolerations + ss.Spec.Template.Spec.TopologySpreadConstraints = wantsPod.TopologySpreadConstraints // Update containers. updateContainer := func(overlay *tsapi.Container, base corev1.Container) corev1.Container { diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index b2b2c8b93a2d7..7263c56c36bb9 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -18,6 +18,7 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/yaml" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/types/ptr" @@ -73,6 +74,16 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, + TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ + { + WhenUnsatisfiable: "DoNotSchedule", + TopologyKey: "kubernetes.io/hostname", + MaxSkew: 3, + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{"foo": "bar"}, + }, + }, + }, TailscaleContainer: &tsapi.Container{ SecurityContext: &corev1.SecurityContext{ Privileged: ptr.To(true), @@ -159,6 +170,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.InitContainers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleInitContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources @@ -201,6 +213,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, []corev1.EnvVar{{Name: "foo", Value: "bar"}, {Name: "TS_USERSPACE", Value: "true"}, {Name: "bar"}}...) diff --git a/k8s-operator/api.md b/k8s-operator/api.md index e8a6e248a2934..dae969516b9e7 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -381,6 +381,7 @@ _Appears in:_ | `nodeName` _string_ | Proxy Pod's node name.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `nodeSelector` _object (keys:string, values:string)_ | Proxy Pod's node selector.
By default Tailscale Kubernetes operator does not apply any node
selector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Proxy Pod's tolerations.
By default Tailscale Kubernetes operator does not apply any
tolerations.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `topologySpreadConstraints` _[TopologySpreadConstraint](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#topologyspreadconstraint-v1-core) array_ | Proxy Pod's topology spread constraints.
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ | | | #### ProxyClass diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 7f415bc340bd7..0a224b7960495 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -154,7 +154,11 @@ type Pod struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + // Proxy Pod's topology spread constraints. + // By default Tailscale Kubernetes operator does not apply any topology spread constraints. + // https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ // +optional + TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"` } type Metrics struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index ba4ff40e46dd5..f53165b886ec2 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -392,6 +392,13 @@ func (in *Pod) DeepCopyInto(out *Pod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.TopologySpreadConstraints != nil { + in, out := &in.TopologySpreadConstraints, &out.TopologySpreadConstraints + *out = make([]corev1.TopologySpreadConstraint, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pod. From 2336c340c4fc72758a8e7bae15062fb78f98d895 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 18 Oct 2024 10:18:06 -0500 Subject: [PATCH 055/179] util/syspolicy: implement a syspolicy store that reads settings from environment variables In this PR, we implement (but do not use yet, pending #13727 review) a syspolicy/source.Store that reads policy settings from environment variables. It converts a CamelCase setting.Key, such as AuthKey or ExitNodeID, to a SCREAMING_SNAKE_CASE, TS_-prefixed environment variable name, such as TS_AUTH_KEY and TS_EXIT_NODE_ID. It then looks up the variable and attempts to parse it according to the expected value type. If the environment variable is not set, the policy setting is considered not configured in this store (the syspolicy package will still read it from other sources). Similarly, if the environment variable has an invalid value for the setting type, it won't be used (though the reported/logged error will differ). Updates #13193 Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/internal/metrics/metrics.go | 2 +- util/syspolicy/setting/key.go | 2 +- util/syspolicy/source/env_policy_store.go | 159 ++++++++ .../syspolicy/source/env_policy_store_test.go | 354 ++++++++++++++++++ util/syspolicy/source/policy_store_windows.go | 6 +- 5 files changed, 518 insertions(+), 5 deletions(-) create mode 100644 util/syspolicy/source/env_policy_store.go create mode 100644 util/syspolicy/source/env_policy_store_test.go diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go index 2ea02278afc92..0a2aa1192fc53 100644 --- a/util/syspolicy/internal/metrics/metrics.go +++ b/util/syspolicy/internal/metrics/metrics.go @@ -284,7 +284,7 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { } func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { - name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") + name := strings.ReplaceAll(string(key), string(setting.KeyPathSeparator), "_") return newMetric([]string{name, metricScopeName(scope), suffix}, typ) } diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go index 406fde1321cc2..aa7606d36324a 100644 --- a/util/syspolicy/setting/key.go +++ b/util/syspolicy/setting/key.go @@ -10,4 +10,4 @@ package setting type Key string // KeyPathSeparator allows logical grouping of policy settings into categories. -const KeyPathSeparator = "/" +const KeyPathSeparator = '/' diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go new file mode 100644 index 0000000000000..2f07fffcaa22a --- /dev/null +++ b/util/syspolicy/source/env_policy_store.go @@ -0,0 +1,159 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "fmt" + "os" + "strconv" + "strings" + "unicode/utf8" + + "github.com/pkg/errors" + "tailscale.com/util/syspolicy/setting" +) + +var lookupEnv = os.LookupEnv // test hook + +var _ Store = (*EnvPolicyStore)(nil) + +// EnvPolicyStore is a [Store] that reads policy settings from environment variables. +type EnvPolicyStore struct{} + +// ReadString implements [Store]. +func (s *EnvPolicyStore) ReadString(key setting.Key) (string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil { + return "", err + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *EnvPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return 0, err + } + if str == "" { + return 0, setting.ErrNotConfigured + } + value, err := strconv.ParseUint(str, 0, 64) + if err != nil { + return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadBoolean implements [Store]. +func (s *EnvPolicyStore) ReadBoolean(key setting.Key) (bool, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return false, err + } + if str == "" { + return false, setting.ErrNotConfigured + } + value, err := strconv.ParseBool(str) + if err != nil { + return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadStringArray implements [Store]. +func (s *EnvPolicyStore) ReadStringArray(key setting.Key) ([]string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil || str == "" { + return nil, err + } + var dst int + res := strings.Split(str, ",") + for src := range res { + res[dst] = strings.TrimSpace(res[src]) + if res[dst] != "" { + dst++ + } + } + return res[0:dst], nil +} + +func (s *EnvPolicyStore) lookupSettingVariable(key setting.Key) (name, value string, err error) { + name, err = keyToEnvVarName(key) + if err != nil { + return "", "", err + } + value, ok := lookupEnv(name) + if !ok { + return name, "", setting.ErrNotConfigured + } + return name, value, nil +} + +var ( + errEmptyKey = errors.New("key must not be empty") + errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes") +) + +// keyToEnvVarName returns the environment variable name for a given policy +// setting key, or an error if the key is invalid. It converts CamelCase keys into +// underscore-separated words and prepends the variable name with the TS prefix. +// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc. +// +// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path. +// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once. +func keyToEnvVarName(key setting.Key) (string, error) { + if len(key) == 0 { + return "", errEmptyKey + } + + isLower := func(c byte) bool { return 'a' <= c && c <= 'z' } + isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' } + isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') } + isDigit := func(c byte) bool { return '0' <= c && c <= '9' } + + words := make([]string, 0, 8) + words = append(words, "TS") + var currentWord strings.Builder + for i := 0; i < len(key); i++ { + c := key[i] + if c >= utf8.RuneSelf { + return "", errInvalidKey + } + + var split bool + switch { + case isLower(c): + c -= 'a' - 'A' // make upper + split = currentWord.Len() > 0 && !isLetter(key[i-1]) + case isUpper(c): + if currentWord.Len() > 0 { + prevUpper := isUpper(key[i-1]) + nextLower := i < len(key)-1 && isLower(key[i+1]) + split = !prevUpper || nextLower // split on case transition + } + case isDigit(c): + split = currentWord.Len() > 0 && !isDigit(key[i-1]) + case c == setting.KeyPathSeparator: + words = append(words, currentWord.String()) + currentWord.Reset() + continue + default: + return "", errInvalidKey + } + + if split { + words = append(words, currentWord.String()) + currentWord.Reset() + } + + currentWord.WriteByte(c) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return strings.Join(words, "_"), nil +} diff --git a/util/syspolicy/source/env_policy_store_test.go b/util/syspolicy/source/env_policy_store_test.go new file mode 100644 index 0000000000000..364a6104d4f99 --- /dev/null +++ b/util/syspolicy/source/env_policy_store_test.go @@ -0,0 +1,354 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "errors" + "math" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKeyToVariableName(t *testing.T) { + tests := []struct { + name string + key setting.Key + want string + wantErr error + }{ + { + name: "empty", + key: "", + wantErr: errEmptyKey, + }, + { + name: "lowercase", + key: "tailnet", + want: "TS_TAILNET", + }, + { + name: "CamelCase", + key: "AuthKey", + want: "TS_AUTH_KEY", + }, + { + name: "LongerCamelCase", + key: "ManagedByOrganizationName", + want: "TS_MANAGED_BY_ORGANIZATION_NAME", + }, + { + name: "UPPERCASE", + key: "UPPERCASE", + want: "TS_UPPERCASE", + }, + { + name: "WithAbbrev/Front", + key: "DNSServer", + want: "TS_DNS_SERVER", + }, + { + name: "WithAbbrev/Middle", + key: "ExitNodeAllowLANAccess", + want: "TS_EXIT_NODE_ALLOW_LAN_ACCESS", + }, + { + name: "WithAbbrev/Back", + key: "ExitNodeID", + want: "TS_EXIT_NODE_ID", + }, + { + name: "WithDigits/Single/Front", + key: "0TestKey", + want: "TS_0_TEST_KEY", + }, + { + name: "WithDigits/Multi/Front", + key: "64TestKey", + want: "TS_64_TEST_KEY", + }, + { + name: "WithDigits/Single/Middle", + key: "Test0Key", + want: "TS_TEST_0_KEY", + }, + { + name: "WithDigits/Multi/Middle", + key: "Test64Key", + want: "TS_TEST_64_KEY", + }, + { + name: "WithDigits/Single/Back", + key: "TestKey0", + want: "TS_TEST_KEY_0", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TS_TEST_KEY_64", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TS_TEST_KEY_64", + }, + { + name: "WithPathSeparators/Single", + key: "Key/Subkey", + want: "TS_KEY_SUBKEY", + }, + { + name: "WithPathSeparators/Multi", + key: "Root/Level1/Level2", + want: "TS_ROOT_LEVEL_1_LEVEL_2", + }, + { + name: "Mixed", + key: "Network/DNSServer/IPAddress", + want: "TS_NETWORK_DNS_SERVER_IP_ADDRESS", + }, + { + name: "Non-Alphanumeric/NonASCII/1", + key: "ж", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/NonASCII/2", + key: "KeyжName", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Space", + key: "Key Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Punct", + key: "Key!Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Backslash", + key: `Key\Name`, + wantErr: errInvalidKey, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + got, err := keyToEnvVarName(tt.key) + checkError(t, err, tt.wantErr, true) + + if got != tt.want { + t.Fatalf("got %q; want %q", got, tt.want) + } + }) + } +} + +func TestEnvPolicyStore(t *testing.T) { + blankEnv := func(string) (string, bool) { return "", false } + makeEnv := func(wantName, value string) func(string) (string, bool) { + return func(gotName string) (string, bool) { + if gotName != wantName { + return "", false + } + return value, true + } + } + tests := []struct { + name string + key setting.Key + lookup func(string) (string, bool) + want any + wantErr error + }{ + { + name: "NotConfigured/String", + key: "AuthKey", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: "", + }, + { + name: "Configured/String/Empty", + key: "AuthKey", + lookup: makeEnv("TS_AUTH_KEY", ""), + want: "", + }, + { + name: "Configured/String/NonEmpty", + key: "AuthKey", + lookup: makeEnv("TS_AUTH_KEY", "ABC123"), + want: "ABC123", + }, + { + name: "NotConfigured/UInt64", + key: "IntegerSetting", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Empty", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", ""), + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Zero", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "0"), + want: uint64(0), + }, + { + name: "Configured/UInt64/NonZero", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "12345"), + want: uint64(12345), + }, + { + name: "Configured/UInt64/MaxUInt64", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), + want: uint64(math.MaxUint64), + }, + { + name: "Configured/UInt64/Negative", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "-1"), + wantErr: setting.ErrTypeMismatch, + want: uint64(0), + }, + { + name: "Configured/UInt64/Hex", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "0xDEADBEEF"), + want: uint64(0xDEADBEEF), + }, + { + name: "NotConfigured/Bool", + key: "LogSCMInteractions", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/Empty", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", ""), + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/True", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "true"), + want: true, + }, + { + name: "Configured/Bool/False", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "False"), + want: false, + }, + { + name: "Configured/Bool/1", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "1"), + want: true, + }, + { + name: "Configured/Bool/0", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "0"), + want: false, + }, + { + name: "Configured/Bool/Invalid", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "NotABool"), + wantErr: setting.ErrTypeMismatch, + want: false, + }, + { + name: "NotConfigured/StringArray", + key: "AllowedSuggestedExitNodes", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: []string(nil), + }, + { + name: "Configured/StringArray/Empty", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", ""), + want: []string(nil), + }, + { + name: "Configured/StringArray/Spaces", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", " \t "), + want: []string{}, + }, + { + name: "Configured/StringArray/Single", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), + want: []string{"NodeA"}, + }, + { + name: "Configured/StringArray/Multi", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), + want: []string{"NodeA", "NodeB", "NodeC"}, + }, + { + name: "Configured/StringArray/WithBlank", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), + want: []string{"NodeA", "NodeB"}, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + oldLookupEnv := lookupEnv + t.Cleanup(func() { lookupEnv = oldLookupEnv }) + lookupEnv = tt.lookup + + var got any + var err error + var store EnvPolicyStore + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.key) + case uint64: + got, err = store.ReadUInt64(tt.key) + case bool: + got, err = store.ReadBoolean(tt.key) + case []string: + got, err = store.ReadStringArray(tt.key) + } + checkError(t, err, tt.wantErr, false) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf + } + if (want == nil && got != nil) || + (want != nil && got == nil) || + (want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) { + f("gotErr: %v; wantErr: %v", got, want) + } +} diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go index f526b4ce1c666..86e2254e0a381 100644 --- a/util/syspolicy/source/policy_store_windows.go +++ b/util/syspolicy/source/policy_store_windows.go @@ -319,9 +319,9 @@ func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error // If there are no [setting.KeyPathSeparator]s in the key, the policy setting value // is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale. func splitSettingKey(key setting.Key) (path, valueName string) { - if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 { - path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`) - valueName = string(key[idx+len(setting.KeyPathSeparator):]) + if idx := strings.LastIndexByte(string(key), setting.KeyPathSeparator); idx != -1 { + path = strings.ReplaceAll(string(key[:idx]), string(setting.KeyPathSeparator), `\`) + valueName = string(key[idx+1:]) return path, valueName } return "", string(key) From 2cc1100d242df512612781187eaa898d0de133dc Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Wed, 30 Oct 2024 12:01:20 -0500 Subject: [PATCH 056/179] util/syspolicy/source: use errors instead of github.com/pkg/errors Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/source/env_policy_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go index 2f07fffcaa22a..61065ceff4d4c 100644 --- a/util/syspolicy/source/env_policy_store.go +++ b/util/syspolicy/source/env_policy_store.go @@ -4,13 +4,13 @@ package source import ( + "errors" "fmt" "os" "strconv" "strings" "unicode/utf8" - "github.com/pkg/errors" "tailscale.com/util/syspolicy/setting" ) From 2a2228f97b625b20f5d62092b9f17730078a7fb4 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Tue, 29 Oct 2024 11:24:46 -0500 Subject: [PATCH 057/179] util/syspolicy/setting: make setting.RawItem JSON-marshallable We add setting.RawValue, a new type that facilitates unmarshalling JSON numbers and arrays as uint64 and []string (instead of float64 and []any) for policy setting values. We then use it to make setting.RawItem JSON-marshallable and update the tests. Updates #12687 Signed-off-by: Nick Khyl --- types/opt/value.go | 2 +- util/syspolicy/setting/raw_item.go | 123 ++++++++++-- util/syspolicy/setting/raw_item_test.go | 101 ++++++++++ util/syspolicy/setting/snapshot_test.go | 251 ++++++++++++------------ 4 files changed, 336 insertions(+), 141 deletions(-) create mode 100644 util/syspolicy/setting/raw_item_test.go diff --git a/types/opt/value.go b/types/opt/value.go index 54fab7a538270..b47b03c81b026 100644 --- a/types/opt/value.go +++ b/types/opt/value.go @@ -36,7 +36,7 @@ func ValueOf[T any](v T) Value[T] { } // String implements [fmt.Stringer]. -func (o *Value[T]) String() string { +func (o Value[T]) String() string { if !o.set { return fmt.Sprintf("(empty[%T])", o.value) } diff --git a/util/syspolicy/setting/raw_item.go b/util/syspolicy/setting/raw_item.go index 30480d8923f71..cf46e54b76217 100644 --- a/util/syspolicy/setting/raw_item.go +++ b/util/syspolicy/setting/raw_item.go @@ -5,7 +5,11 @@ package setting import ( "fmt" + "reflect" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/opt" "tailscale.com/types/structs" ) @@ -17,10 +21,15 @@ import ( // or converted from strings, these setting types predate the typed policy // hierarchies, and must be supported at this layer. type RawItem struct { - _ structs.Incomparable - value any - err *ErrorText - origin *Origin // or nil + _ structs.Incomparable + data rawItemJSON +} + +// rawItemJSON holds JSON-marshallable data for [RawItem]. +type rawItemJSON struct { + Value RawValue `json:",omitzero"` + Error *ErrorText `json:",omitzero"` // or nil + Origin *Origin `json:",omitzero"` // or nil } // RawItemOf returns a [RawItem] with the specified value. @@ -30,20 +39,20 @@ func RawItemOf(value any) RawItem { // RawItemWith returns a [RawItem] with the specified value, error and origin. func RawItemWith(value any, err *ErrorText, origin *Origin) RawItem { - return RawItem{value: value, err: err, origin: origin} + return RawItem{data: rawItemJSON{Value: RawValue{opt.ValueOf(value)}, Error: err, Origin: origin}} } // Value returns the value of the policy setting, or nil if the policy setting // is not configured, or an error occurred while reading it. func (i RawItem) Value() any { - return i.value + return i.data.Value.Get() } // Error returns the error that occurred when reading the policy setting, // or nil if no error occurred. func (i RawItem) Error() error { - if i.err != nil { - return i.err + if i.data.Error != nil { + return i.data.Error } return nil } @@ -51,17 +60,103 @@ func (i RawItem) Error() error { // Origin returns an optional [Origin] indicating where the policy setting is // configured. func (i RawItem) Origin() *Origin { - return i.origin + return i.data.Origin } // String implements [fmt.Stringer]. func (i RawItem) String() string { var suffix string - if i.origin != nil { - suffix = fmt.Sprintf(" - {%v}", i.origin) + if i.data.Origin != nil { + suffix = fmt.Sprintf(" - {%v}", i.data.Origin) + } + if i.data.Error != nil { + return fmt.Sprintf("Error{%q}%s", i.data.Error.Error(), suffix) + } + return fmt.Sprintf("%v%s", i.data.Value.Value, suffix) +} + +// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +func (i RawItem) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { + return jsonv2.MarshalEncode(out, &i.data, opts) +} + +// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. +func (i *RawItem) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { + return jsonv2.UnmarshalDecode(in, &i.data, opts) +} + +// MarshalJSON implements [json.Marshaler]. +func (i RawItem) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(i) // uses MarshalJSONV2 +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (i *RawItem) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, i) // uses UnmarshalJSONV2 +} + +// RawValue represents a raw policy setting value read from a policy store. +// It is JSON-marshallable and facilitates unmarshalling of JSON values +// into corresponding policy setting types, with special handling for JSON numbers +// (unmarshalled as float64) and JSON string arrays (unmarshalled as []string). +// See also [RawValue.UnmarshalJSONV2]. +type RawValue struct { + opt.Value[any] +} + +// RawValueType is a constraint that permits raw setting value types. +type RawValueType interface { + bool | uint64 | string | []string +} + +// RawValueOf returns a new [RawValue] holding the specified value. +func RawValueOf[T RawValueType](v T) RawValue { + return RawValue{opt.ValueOf[any](v)} +} + +// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +func (v RawValue) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { + return jsonv2.MarshalEncode(out, v.Value, opts) +} + +// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2] by attempting to unmarshal +// a JSON value as one of the supported policy setting value types (bool, string, uint64, or []string), +// based on the JSON value type. It fails if the JSON value is an object, if it's a JSON number that +// cannot be represented as a uint64, or if a JSON array contains anything other than strings. +func (v *RawValue) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { + var valPtr any + switch k := in.PeekKind(); k { + case 't', 'f': + valPtr = new(bool) + case '"': + valPtr = new(string) + case '0': + valPtr = new(uint64) // unmarshal JSON numbers as uint64 + case '[', 'n': + valPtr = new([]string) // unmarshal arrays as string slices + case '{': + return fmt.Errorf("unexpected token: %v", k) + default: + panic("unreachable") } - if i.err != nil { - return fmt.Sprintf("Error{%q}%s", i.err.Error(), suffix) + if err := jsonv2.UnmarshalDecode(in, valPtr, opts); err != nil { + v.Value.Clear() + return err } - return fmt.Sprintf("%v%s", i.value, suffix) + value := reflect.ValueOf(valPtr).Elem().Interface() + v.Value = opt.ValueOf(value) + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (v RawValue) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(v) // uses MarshalJSONV2 } + +// UnmarshalJSON implements [json.Unmarshaler]. +func (v *RawValue) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, v) // uses UnmarshalJSONV2 +} + +// RawValues is a map of keyed setting values that can be read from a JSON. +type RawValues map[Key]RawValue diff --git a/util/syspolicy/setting/raw_item_test.go b/util/syspolicy/setting/raw_item_test.go new file mode 100644 index 0000000000000..05562d78c41f3 --- /dev/null +++ b/util/syspolicy/setting/raw_item_test.go @@ -0,0 +1,101 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "math" + "reflect" + "strconv" + "testing" + + jsonv2 "github.com/go-json-experiment/json" +) + +func TestMarshalUnmarshalRawValue(t *testing.T) { + tests := []struct { + name string + json string + want RawValue + wantErr bool + }{ + { + name: "Bool/True", + json: `true`, + want: RawValueOf(true), + }, + { + name: "Bool/False", + json: `false`, + want: RawValueOf(false), + }, + { + name: "String/Empty", + json: `""`, + want: RawValueOf(""), + }, + { + name: "String/NonEmpty", + json: `"Test"`, + want: RawValueOf("Test"), + }, + { + name: "StringSlice/Null", + json: `null`, + want: RawValueOf([]string(nil)), + }, + { + name: "StringSlice/Empty", + json: `[]`, + want: RawValueOf([]string{}), + }, + { + name: "StringSlice/NonEmpty", + json: `["A", "B", "C"]`, + want: RawValueOf([]string{"A", "B", "C"}), + }, + { + name: "StringSlice/NonStrings", + json: `[1, 2, 3]`, + wantErr: true, + }, + { + name: "Number/Integer/0", + json: `0`, + want: RawValueOf(uint64(0)), + }, + { + name: "Number/Integer/1", + json: `1`, + want: RawValueOf(uint64(1)), + }, + { + name: "Number/Integer/MaxUInt64", + json: strconv.FormatUint(math.MaxUint64, 10), + want: RawValueOf(uint64(math.MaxUint64)), + }, + { + name: "Number/Integer/Negative", + json: `-1`, + wantErr: true, + }, + { + name: "Object", + json: `{}`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got RawValue + gotErr := jsonv2.Unmarshal([]byte(tt.json), &got) + if (gotErr != nil) != tt.wantErr { + t.Fatalf("Error: got %v; want %v", gotErr, tt.wantErr) + } + + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Fatalf("Value: got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go index e198d4a58bfdb..297685e29bf2e 100644 --- a/util/syspolicy/setting/snapshot_test.go +++ b/util/syspolicy/setting/snapshot_test.go @@ -30,134 +30,134 @@ func TestMergeSnapshots(t *testing.T) { name: "first-nil", s1: nil, s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "first-empty", s1: NewSnapshot(map[Key]RawItem{}), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "second-nil", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{}), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "no-conflicts", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), }, { name: "with-conflicts", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), }, { name: "with-scope-first-wins", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { name: "with-scope-second-wins", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, DeviceScope), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { @@ -170,28 +170,27 @@ func TestMergeSnapshots(t *testing.T) { name: "with-scope-first-empty", s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}}, - DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true)}, DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope, NewNamedOrigin("TestPolicy", DeviceScope)), }, { name: "with-scope-second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{}), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), }, } @@ -244,9 +243,9 @@ func TestSnapshotEqual(t *testing.T) { name: "first-nil", s1: nil, s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, @@ -255,9 +254,9 @@ func TestSnapshotEqual(t *testing.T) { name: "first-empty", s1: NewSnapshot(map[Key]RawItem{}), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, @@ -265,9 +264,9 @@ func TestSnapshotEqual(t *testing.T) { { name: "second-nil", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, wantEqual: false, @@ -276,9 +275,9 @@ func TestSnapshotEqual(t *testing.T) { { name: "second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{}), wantEqual: false, @@ -287,14 +286,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-no-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: true, wantEqualItems: true, @@ -302,14 +301,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), wantEqual: true, wantEqualItems: true, @@ -317,14 +316,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-different-order-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting3": {value: false}, - "Setting1": {value: 123}, - "Setting2": {value: "String"}, + "Setting3": RawItemOf(false), + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), }, DeviceScope), wantEqual: true, wantEqualItems: true, @@ -332,14 +331,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-different-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, CurrentUserScope), wantEqual: false, wantEqualItems: true, @@ -347,14 +346,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "different-items-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }, DeviceScope), wantEqual: false, wantEqualItems: false, @@ -401,9 +400,9 @@ func TestSnapshotString(t *testing.T) { { name: "non-empty", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 2 * time.Hour}, - "Setting2": {value: VisibleByPolicy}, - "Setting3": {value: ShowChoiceByPolicy}, + "Setting1": RawItemOf(2 * time.Hour), + "Setting2": RawItemOf(VisibleByPolicy), + "Setting3": RawItemOf(ShowChoiceByPolicy), }, NewNamedOrigin("Test Policy", DeviceScope)), wantString: `{Test Policy (Device)} Setting1 = 2h0m0s @@ -413,14 +412,14 @@ Setting3 = user-decides`, { name: "non-empty-with-item-origin", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 42, origin: NewNamedOrigin("Test Policy", DeviceScope)}, + "Setting1": RawItemWith(42, nil, NewNamedOrigin("Test Policy", DeviceScope)), }), wantString: `Setting1 = 42 - {Test Policy (Device)}`, }, { name: "non-empty-with-item-error", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {err: NewErrorText("bang!")}, + "Setting1": RawItemWith(nil, NewErrorText("bang!"), nil), }), wantString: `Setting1 = Error{"bang!"}`, }, From 540e4c83d08ddfc506db35b27595fd818c14199c Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Tue, 29 Oct 2024 11:29:02 -0500 Subject: [PATCH 058/179] util/syspolicy/setting: make setting.Snapshot JSON-marshallable We make setting.Snapshot JSON-marshallable in preparation for returning it from the LocalAPI. Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/setting/snapshot.go | 45 ++++++++ util/syspolicy/setting/snapshot_test.go | 135 ++++++++++++++++++++++++ 2 files changed, 180 insertions(+) diff --git a/util/syspolicy/setting/snapshot.go b/util/syspolicy/setting/snapshot.go index 512bc487c5b98..0af2bae0f480a 100644 --- a/util/syspolicy/setting/snapshot.go +++ b/util/syspolicy/setting/snapshot.go @@ -4,11 +4,14 @@ package setting import ( + "errors" "iter" "maps" "slices" "strings" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" xmaps "golang.org/x/exp/maps" "tailscale.com/util/deephash" ) @@ -65,6 +68,9 @@ func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) { // Equal reports whether s and s2 are equal. func (s *Snapshot) Equal(s2 *Snapshot) bool { + if s == s2 { + return true + } if !s.EqualItems(s2) { return false } @@ -135,6 +141,45 @@ func (s *Snapshot) String() string { return sb.String() } +// snapshotJSON holds JSON-marshallable data for [Snapshot]. +type snapshotJSON struct { + Summary Summary `json:",omitzero"` + Settings map[Key]RawItem `json:",omitempty"` +} + +// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +func (s *Snapshot) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { + data := &snapshotJSON{} + if s != nil { + data.Summary = s.summary + data.Settings = s.m + } + return jsonv2.MarshalEncode(out, data, opts) +} + +// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. +func (s *Snapshot) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { + if s == nil { + return errors.New("s must not be nil") + } + data := &snapshotJSON{} + if err := jsonv2.UnmarshalDecode(in, data, opts); err != nil { + return err + } + *s = Snapshot{m: data.Settings, sig: deephash.Hash(&data.Settings), summary: data.Summary} + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (s *Snapshot) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(s) // uses MarshalJSONV2 +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (s *Snapshot) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 +} + // MergeSnapshots returns a [Snapshot] that contains all [RawItem]s // from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope]. // If there's a conflict between policy settings in the two snapshots, diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go index 297685e29bf2e..d41b362f06976 100644 --- a/util/syspolicy/setting/snapshot_test.go +++ b/util/syspolicy/setting/snapshot_test.go @@ -4,8 +4,13 @@ package setting import ( + "cmp" + "encoding/json" "testing" "time" + + jsonv2 "github.com/go-json-experiment/json" + "tailscale.com/util/syspolicy/internal" ) func TestMergeSnapshots(t *testing.T) { @@ -432,3 +437,133 @@ Setting3 = user-decides`, }) } } + +func TestMarshalUnmarshalSnapshot(t *testing.T) { + tests := []struct { + name string + snapshot *Snapshot + wantJSON string + wantBack *Snapshot + }{ + { + name: "Nil", + snapshot: (*Snapshot)(nil), + wantJSON: "null", + wantBack: NewSnapshot(nil), + }, + { + name: "Zero", + snapshot: &Snapshot{}, + wantJSON: "{}", + }, + { + name: "Bool/True", + snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(true)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": true}}}`, + }, + { + name: "Bool/False", + snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(false)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": false}}}`, + }, + { + name: "String/Non-Empty", + snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("StringValue")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": "StringValue"}}}`, + }, + { + name: "String/Empty", + snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": ""}}}`, + }, + { + name: "Integer/NonZero", + snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(42))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 42}}}`, + }, + { + name: "Integer/Zero", + snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(0))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 0}}}`, + }, + { + name: "String-List", + snapshot: NewSnapshot(map[Key]RawItem{"ListPolicy": RawItemOf([]string{"Value1", "Value2"})}), + wantJSON: `{"Settings": {"ListPolicy": {"Value": ["Value1", "Value2"]}}}`, + }, + { + name: "Empty/With-Summary", + snapshot: NewSnapshot( + map[Key]RawItem{}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}}`, + }, + { + name: "Setting/With-Summary", + snapshot: NewSnapshot( + map[Key]RawItem{"PolicySetting": RawItemOf(uint64(42))}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{ + "Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}, + "Settings": {"PolicySetting": {"Value": 42}} + }`, + }, + { + name: "Settings/With-Origins", + snapshot: NewSnapshot( + map[Key]RawItem{ + "SettingA": RawItemWith(uint64(42), nil, NewNamedOrigin("SourceA", DeviceScope)), + "SettingB": RawItemWith("B", nil, NewNamedOrigin("SourceB", CurrentProfileScope)), + "SettingC": RawItemWith(true, nil, NewNamedOrigin("SourceC", CurrentUserScope)), + }, + ), + wantJSON: `{ + "Settings": { + "SettingA": {"Value": 42, "Origin": {"Name": "SourceA", "Scope": "Device"}}, + "SettingB": {"Value": "B", "Origin": {"Name": "SourceB", "Scope": "Profile"}}, + "SettingC": {"Value": true, "Origin": {"Name": "SourceC", "Scope": "User"}} + } + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doTest := func(t *testing.T, useJSONv2 bool) { + var gotJSON []byte + var err error + if useJSONv2 { + gotJSON, err = jsonv2.Marshal(tt.snapshot) + } else { + gotJSON, err = json.Marshal(tt.snapshot) + } + if err != nil { + t.Fatal(err) + } + + if got, want, equal := internal.EqualJSONForTest(t, gotJSON, []byte(tt.wantJSON)); !equal { + t.Errorf("JSON: got %s; want %s", got, want) + } + + gotBack := &Snapshot{} + if useJSONv2 { + err = jsonv2.Unmarshal(gotJSON, &gotBack) + } else { + err = json.Unmarshal(gotJSON, &gotBack) + } + if err != nil { + t.Fatal(err) + } + + if wantBack := cmp.Or(tt.wantBack, tt.snapshot); !gotBack.Equal(wantBack) { + t.Errorf("Snapshot: got %+v; want %+v", gotBack, wantBack) + } + } + + t.Run("json", func(t *testing.T) { doTest(t, false) }) + t.Run("jsonv2", func(t *testing.T) { doTest(t, true) }) + }) + } +} From f81348a16b6dd8705cd75379daf3b7490185e841 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 30 Oct 2024 09:48:12 -0700 Subject: [PATCH 059/179] util/syspolicy/source: put EnvPolicyStore env keys in their own namespace ... all prefixed with TS_DEBUGSYSPOLICY_*. Updates #13193 Updates #12687 Updates #13855 Change-Id: Ia8024946f53e2b3afda4456a7bb85bbcf6d12bfc Signed-off-by: Brad Fitzpatrick --- util/syspolicy/source/env_policy_store.go | 2 +- .../syspolicy/source/env_policy_store_test.go | 85 ++++++++++--------- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go index 61065ceff4d4c..299132b4e11b3 100644 --- a/util/syspolicy/source/env_policy_store.go +++ b/util/syspolicy/source/env_policy_store.go @@ -114,7 +114,7 @@ func keyToEnvVarName(key setting.Key) (string, error) { isDigit := func(c byte) bool { return '0' <= c && c <= '9' } words := make([]string, 0, 8) - words = append(words, "TS") + words = append(words, "TS_DEBUGSYSPOLICY") var currentWord strings.Builder for i := 0; i < len(key); i++ { c := key[i] diff --git a/util/syspolicy/source/env_policy_store_test.go b/util/syspolicy/source/env_policy_store_test.go index 364a6104d4f99..9eacf6378b450 100644 --- a/util/syspolicy/source/env_policy_store_test.go +++ b/util/syspolicy/source/env_policy_store_test.go @@ -14,11 +14,11 @@ import ( "tailscale.com/util/syspolicy/setting" ) -func TestKeyToVariableName(t *testing.T) { +func TestKeyToEnvVarName(t *testing.T) { tests := []struct { name string key setting.Key - want string + want string // suffix after "TS_DEBUGSYSPOLICY_" wantErr error }{ { @@ -29,87 +29,87 @@ func TestKeyToVariableName(t *testing.T) { { name: "lowercase", key: "tailnet", - want: "TS_TAILNET", + want: "TAILNET", }, { name: "CamelCase", key: "AuthKey", - want: "TS_AUTH_KEY", + want: "AUTH_KEY", }, { name: "LongerCamelCase", key: "ManagedByOrganizationName", - want: "TS_MANAGED_BY_ORGANIZATION_NAME", + want: "MANAGED_BY_ORGANIZATION_NAME", }, { name: "UPPERCASE", key: "UPPERCASE", - want: "TS_UPPERCASE", + want: "UPPERCASE", }, { name: "WithAbbrev/Front", key: "DNSServer", - want: "TS_DNS_SERVER", + want: "DNS_SERVER", }, { name: "WithAbbrev/Middle", key: "ExitNodeAllowLANAccess", - want: "TS_EXIT_NODE_ALLOW_LAN_ACCESS", + want: "EXIT_NODE_ALLOW_LAN_ACCESS", }, { name: "WithAbbrev/Back", key: "ExitNodeID", - want: "TS_EXIT_NODE_ID", + want: "EXIT_NODE_ID", }, { name: "WithDigits/Single/Front", key: "0TestKey", - want: "TS_0_TEST_KEY", + want: "0_TEST_KEY", }, { name: "WithDigits/Multi/Front", key: "64TestKey", - want: "TS_64_TEST_KEY", + want: "64_TEST_KEY", }, { name: "WithDigits/Single/Middle", key: "Test0Key", - want: "TS_TEST_0_KEY", + want: "TEST_0_KEY", }, { name: "WithDigits/Multi/Middle", key: "Test64Key", - want: "TS_TEST_64_KEY", + want: "TEST_64_KEY", }, { name: "WithDigits/Single/Back", key: "TestKey0", - want: "TS_TEST_KEY_0", + want: "TEST_KEY_0", }, { name: "WithDigits/Multi/Back", key: "TestKey64", - want: "TS_TEST_KEY_64", + want: "TEST_KEY_64", }, { name: "WithDigits/Multi/Back", key: "TestKey64", - want: "TS_TEST_KEY_64", + want: "TEST_KEY_64", }, { name: "WithPathSeparators/Single", key: "Key/Subkey", - want: "TS_KEY_SUBKEY", + want: "KEY_SUBKEY", }, { name: "WithPathSeparators/Multi", key: "Root/Level1/Level2", - want: "TS_ROOT_LEVEL_1_LEVEL_2", + want: "ROOT_LEVEL_1_LEVEL_2", }, { name: "Mixed", key: "Network/DNSServer/IPAddress", - want: "TS_NETWORK_DNS_SERVER_IP_ADDRESS", + want: "NETWORK_DNS_SERVER_IP_ADDRESS", }, { name: "Non-Alphanumeric/NonASCII/1", @@ -142,8 +142,12 @@ func TestKeyToVariableName(t *testing.T) { got, err := keyToEnvVarName(tt.key) checkError(t, err, tt.wantErr, true) - if got != tt.want { - t.Fatalf("got %q; want %q", got, tt.want) + want := tt.want + if want != "" { + want = "TS_DEBUGSYSPOLICY_" + want + } + if got != want { + t.Fatalf("got %q; want %q", got, want) } }) } @@ -152,6 +156,7 @@ func TestKeyToVariableName(t *testing.T) { func TestEnvPolicyStore(t *testing.T) { blankEnv := func(string) (string, bool) { return "", false } makeEnv := func(wantName, value string) func(string) (string, bool) { + wantName = "TS_DEBUGSYSPOLICY_" + wantName return func(gotName string) (string, bool) { if gotName != wantName { return "", false @@ -176,13 +181,13 @@ func TestEnvPolicyStore(t *testing.T) { { name: "Configured/String/Empty", key: "AuthKey", - lookup: makeEnv("TS_AUTH_KEY", ""), + lookup: makeEnv("AUTH_KEY", ""), want: "", }, { name: "Configured/String/NonEmpty", key: "AuthKey", - lookup: makeEnv("TS_AUTH_KEY", "ABC123"), + lookup: makeEnv("AUTH_KEY", "ABC123"), want: "ABC123", }, { @@ -195,39 +200,39 @@ func TestEnvPolicyStore(t *testing.T) { { name: "Configured/UInt64/Empty", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", ""), + lookup: makeEnv("INTEGER_SETTING", ""), wantErr: setting.ErrNotConfigured, want: uint64(0), }, { name: "Configured/UInt64/Zero", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", "0"), + lookup: makeEnv("INTEGER_SETTING", "0"), want: uint64(0), }, { name: "Configured/UInt64/NonZero", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", "12345"), + lookup: makeEnv("INTEGER_SETTING", "12345"), want: uint64(12345), }, { name: "Configured/UInt64/MaxUInt64", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), + lookup: makeEnv("INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), want: uint64(math.MaxUint64), }, { name: "Configured/UInt64/Negative", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", "-1"), + lookup: makeEnv("INTEGER_SETTING", "-1"), wantErr: setting.ErrTypeMismatch, want: uint64(0), }, { name: "Configured/UInt64/Hex", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", "0xDEADBEEF"), + lookup: makeEnv("INTEGER_SETTING", "0xDEADBEEF"), want: uint64(0xDEADBEEF), }, { @@ -240,38 +245,38 @@ func TestEnvPolicyStore(t *testing.T) { { name: "Configured/Bool/Empty", key: "LogSCMInteractions", - lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", ""), + lookup: makeEnv("LOG_SCM_INTERACTIONS", ""), wantErr: setting.ErrNotConfigured, want: false, }, { name: "Configured/Bool/True", key: "LogSCMInteractions", - lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "true"), + lookup: makeEnv("LOG_SCM_INTERACTIONS", "true"), want: true, }, { name: "Configured/Bool/False", key: "LogSCMInteractions", - lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "False"), + lookup: makeEnv("LOG_SCM_INTERACTIONS", "False"), want: false, }, { name: "Configured/Bool/1", key: "LogSCMInteractions", - lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "1"), + lookup: makeEnv("LOG_SCM_INTERACTIONS", "1"), want: true, }, { name: "Configured/Bool/0", key: "LogSCMInteractions", - lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "0"), + lookup: makeEnv("LOG_SCM_INTERACTIONS", "0"), want: false, }, { name: "Configured/Bool/Invalid", key: "IntegerSetting", - lookup: makeEnv("TS_INTEGER_SETTING", "NotABool"), + lookup: makeEnv("INTEGER_SETTING", "NotABool"), wantErr: setting.ErrTypeMismatch, want: false, }, @@ -285,31 +290,31 @@ func TestEnvPolicyStore(t *testing.T) { { name: "Configured/StringArray/Empty", key: "AllowedSuggestedExitNodes", - lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", ""), + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", ""), want: []string(nil), }, { name: "Configured/StringArray/Spaces", key: "AllowedSuggestedExitNodes", - lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", " \t "), + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", " \t "), want: []string{}, }, { name: "Configured/StringArray/Single", key: "AllowedSuggestedExitNodes", - lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), want: []string{"NodeA"}, }, { name: "Configured/StringArray/Multi", key: "AllowedSuggestedExitNodes", - lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), want: []string{"NodeA", "NodeB", "NodeC"}, }, { name: "Configured/StringArray/WithBlank", key: "AllowedSuggestedExitNodes", - lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), want: []string{"NodeA", "NodeB"}, }, } From e1e22785b4cdef300ca89206f307f867ba262c6e Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 30 Oct 2024 11:31:36 -0700 Subject: [PATCH 060/179] net/netcheck: ensure prior preferred DERP is always in netchecks In an environment with unstable latency, such as upstream bufferbloat, there are cases where a full netcheck could drop the prior preferred DERP (likely home DERP) from future netcheck probe plans. This will then likely result in a home DERP having a missing sample on the next incremental netcheck, ultimately resulting in a home DERP move. This change does not fix our overall response to highly unstable latency, but it is an incremental improvement to prevent single spurious samples during a full netcheck from alone triggering a flapping condition, as now the prior changes to include historical latency will still provide the desired resistance, and the home DERP should not move unless latency is consistently worse over a 5 minute period. Note that there is a nomenclature and semantics issue remaining in the difference between a report preferred DERP and a home DERP. A report preferred DERP is aspirational, it is what will be picked as a home DERP if a home DERP connection needs to be established. A nodes home DERP may be different than a recent preferred DERP, in which case a lot of netcheck logic is fallible. In future enhancements much of the DERP move logic should move to consider the home DERP, rather than recent report preferred DERP. Updates #8603 Updates #13969 Signed-off-by: James Tucker --- net/netcheck/netcheck.go | 68 +++++++++++++++++++++++++++-------- net/netcheck/netcheck_test.go | 42 ++++++++++++++++++++-- 2 files changed, 93 insertions(+), 17 deletions(-) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 1714837305ac1..2c429862eb133 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -392,10 +392,11 @@ type probePlan map[string][]probe // sortRegions returns the regions of dm first sorted // from fastest to slowest (based on the 'last' report), // end in regions that have no data. -func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) { +func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*tailcfg.DERPRegion) { prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions)) for _, reg := range dm.Regions { - if reg.Avoid { + // include an otherwise avoid region if it is the current preferred region + if reg.Avoid && reg.RegionID != preferredDERP { continue } prev = append(prev, reg) @@ -420,9 +421,19 @@ func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) // a full report, all regions are scanned.) const numIncrementalRegions = 3 -// makeProbePlan generates the probe plan for a DERPMap, given the most -// recent report and whether IPv6 is configured on an interface. -func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (plan probePlan) { +// makeProbePlan generates the probe plan for a DERPMap, given the most recent +// report and the current home DERP. preferredDERP is passed independently of +// last (report) because last is currently nil'd to indicate a desire for a full +// netcheck. +// +// TODO(raggi,jwhited): refactor the callers and this function to be more clear +// about full vs. incremental netchecks, and remove the need for the history +// hiding. This was avoided in an incremental change due to exactly this kind of +// distant coupling. +// TODO(raggi): change from "preferred DERP" from a historical report to "home +// DERP" as in what DERP is the current home connection, this would further +// reduce flap events. +func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, preferredDERP int) (plan probePlan) { if last == nil || len(last.RegionLatency) == 0 { return makeProbePlanInitial(dm, ifState) } @@ -433,9 +444,34 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl had4 := len(last.RegionV4Latency) > 0 had6 := len(last.RegionV6Latency) > 0 hadBoth := have6if && had4 && had6 - for ri, reg := range sortRegions(dm, last) { - if ri == numIncrementalRegions { - break + // #13969 ensure that the home region is always probed. + // If a netcheck has unstable latency, such as a user with large amounts of + // bufferbloat or a highly congested connection, there are cases where a full + // netcheck may observe a one-off high latency to the current home DERP. Prior + // to the forced inclusion of the home DERP, this would result in an + // incremental netcheck following such an event to cause a home DERP move, with + // restoration back to the home DERP on the next full netcheck ~5 minutes later + // - which is highly disruptive when it causes shifts in geo routed subnet + // routers. By always including the home DERP in the incremental netcheck, we + // ensure that the home DERP is always probed, even if it observed a recenet + // poor latency sample. This inclusion enables the latency history checks in + // home DERP selection to still take effect. + // planContainsHome indicates whether the home DERP has been added to the probePlan, + // if there is no prior home, then there's no home to additionally include. + planContainsHome := preferredDERP == 0 + for ri, reg := range sortRegions(dm, last, preferredDERP) { + regIsHome := reg.RegionID == preferredDERP + if ri >= numIncrementalRegions { + // planned at least numIncrementalRegions regions and that includes the + // last home region (or there was none), plan complete. + if planContainsHome { + break + } + // planned at least numIncrementalRegions regions, but not the home region, + // check if this is the home region, if not, skip it. + if !regIsHome { + continue + } } var p4, p6 []probe do4 := have4if @@ -446,7 +482,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl tries := 1 isFastestTwo := ri < 2 - if isFastestTwo { + if isFastestTwo || regIsHome { tries = 2 } else if hadBoth { // For dual stack machines, make the 3rd & slower nodes alternate @@ -457,14 +493,15 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl do4, do6 = false, true } } - if !isFastestTwo && !had6 { + if !regIsHome && !isFastestTwo && !had6 { do6 = false } - if reg.RegionID == last.PreferredDERP { + if regIsHome { // But if we already had a DERP home, try extra hard to // make sure it's there so we don't flip flop around. tries = 4 + planContainsHome = true } for try := 0; try < tries; try++ { @@ -789,9 +826,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.curState = rs last := c.last - // Even if we're doing a non-incremental update, we may want to try our - // preferred DERP region for captive portal detection. Save that, if we - // have it. + // Extract preferredDERP from the last report, if available. This will be used + // in captive portal detection and DERP flapping suppression. Ideally this would + // be the current active home DERP rather than the last report preferred DERP, + // but only the latter is presently available. var preferredDERP int if last != nil { preferredDERP = last.PreferredDERP @@ -848,7 +886,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe var plan probePlan if opts == nil || !opts.OnlyTCP443 { - plan = makeProbePlan(dm, ifState, last) + plan = makeProbePlan(dm, ifState, last, preferredDERP) } // If we're doing a full probe, also check for a captive portal. We diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 2780c9c44b08c..f287978d225e7 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -590,6 +590,40 @@ func TestMakeProbePlan(t *testing.T) { "region-3-v4": []probe{p("3a", 4)}, }, }, + { + // #13969: ensure that the prior/current home region is always included in + // probe plans, so that we don't flap between regions due to a single major + // netcheck having excluded the home region due to a spuriously high sample. + name: "ensure_home_region_inclusion", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + RegionV4Latency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + }, + RegionV6Latency: map[int]time.Duration{ + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + PreferredDERP: 1, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 60*ms), p("1a", 4, 220*ms), p("1a", 4, 330*ms)}, + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 60*ms), p("1a", 6, 220*ms), p("1a", 6, 330*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)}, + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 36*ms)}, + "region-3-v6": []probe{p("3a", 6), p("3b", 6, 36*ms)}, + "region-4-v4": []probe{p("4a", 4)}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -597,7 +631,11 @@ func TestMakeProbePlan(t *testing.T) { HaveV6: tt.have6if, HaveV4: !tt.no4, } - got := makeProbePlan(tt.dm, ifState, tt.last) + preferredDERP := 0 + if tt.last != nil { + preferredDERP = tt.last.PreferredDERP + } + got := makeProbePlan(tt.dm, ifState, tt.last, preferredDERP) if !reflect.DeepEqual(got, tt.want) { t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want) } @@ -770,7 +808,7 @@ func TestSortRegions(t *testing.T) { report.RegionLatency[3] = time.Second * time.Duration(6) report.RegionLatency[4] = time.Second * time.Duration(0) report.RegionLatency[5] = time.Second * time.Duration(2) - sortedMap := sortRegions(unsortedMap, report) + sortedMap := sortRegions(unsortedMap, report, 0) // Sorting by latency this should result in rid: 5, 2, 1, 3 // rid 4 with latency 0 should be at the end From 532b26145a088c3946c37040dc4731dc4edcb7cf Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 29 Oct 2024 13:46:34 +0000 Subject: [PATCH 061/179] wgengine/magicsock: exclude disco from throughput metrics The user-facing metrics are intended to track data transmitted at the overlay network level. Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- wgengine/magicsock/derp.go | 14 ++++++++------ wgengine/magicsock/endpoint.go | 3 ++- wgengine/magicsock/magicsock.go | 7 ++++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 704ce3c4ff2b5..0204fa0f5a011 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -649,9 +649,10 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli } type derpWriteRequest struct { - addr netip.AddrPort - pubKey key.NodePublic - b []byte // copied; ownership passed to receiver + addr netip.AddrPort + pubKey key.NodePublic + b []byte // copied; ownership passed to receiver + isDisco bool } // runDerpWriter runs in a goroutine for the life of a DERP @@ -673,7 +674,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan if err != nil { c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) metricSendDERPError.Add(1) - } else { + } else if !wr.isDisco { c.metrics.outboundPacketsDERPTotal.Add(1) c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b))) } @@ -696,8 +697,6 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) // No data read occurred. Wait for another packet. continue } - c.metrics.inboundPacketsDERPTotal.Add(1) - c.metrics.inboundBytesDERPTotal.Add(int64(n)) sizes[0] = n eps[0] = ep return 1, nil @@ -737,6 +736,9 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en if stats := c.stats.Load(); stats != nil { stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, dm.n) } + + c.metrics.inboundPacketsDERPTotal.Add(1) + c.metrics.inboundBytesDERPTotal.Add(int64(n)) return n, ep } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 1ddde97524571..5e0ada6170c2f 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -983,7 +983,8 @@ func (de *endpoint) send(buffs [][]byte) error { allOk := true var txBytes int for _, buff := range buffs { - ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) + const isDisco = false + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff, isDisco) txBytes += len(buff) if !ok { allOk = false diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 72e59a2e72c62..705e42d9ef84c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1356,7 +1356,7 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // An example of when they might be different: sending to an // IPv6 address when the local machine doesn't have IPv6 support // returns (false, nil); it's not an error, but nothing was sent. -func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) { +func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool) (sent bool, err error) { if addr.Addr() != tailcfg.DerpMagicIPAddr { return c.sendUDP(addr, b) } @@ -1379,7 +1379,7 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (s case <-c.donec: metricSendDERPErrorClosed.Add(1) return false, errConnClosed - case ch <- derpWriteRequest{addr, pubKey, pkt}: + case ch <- derpWriteRequest{addr, pubKey, pkt, isDisco}: metricSendDERPQueued.Add(1) return true, nil default: @@ -1577,7 +1577,8 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi box := di.sharedKey.Seal(m.AppendMarshal(nil)) pkt = append(pkt, box...) - sent, err = c.sendAddr(dst, dstKey, pkt) + const isDisco = true + sent, err = c.sendAddr(dst, dstKey, pkt, isDisco) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) { node := "?" From b4f46c31bbf8f079a0e617997e8b86f3c94247bd Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 29 Oct 2024 09:19:40 +0000 Subject: [PATCH 062/179] wgengine/magicsock: export packet drop metric for outbound errors This required sharing the dropped packet metric between two packages (tstun and magicsock), so I've moved its definition to util/usermetric. Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- net/tstun/wrap.go | 44 ++++-------------- net/tstun/wrap_test.go | 6 +-- util/usermetric/metrics.go | 69 ++++++++++++++++++++++++++++ util/usermetric/usermetric.go | 3 ++ wgengine/magicsock/derp.go | 3 ++ wgengine/magicsock/magicsock.go | 15 +++++- wgengine/magicsock/magicsock_test.go | 25 ++++++++++ 7 files changed, 127 insertions(+), 38 deletions(-) create mode 100644 util/usermetric/metrics.go diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 0b858fc1c5653..c384abf9d4bbe 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -213,24 +213,14 @@ type Wrapper struct { } type metrics struct { - inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] - outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] + inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels] + outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels] } func registerMetrics(reg *usermetric.Registry) *metrics { return &metrics{ - inboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_inbound_dropped_packets_total", - "counter", - "Counts the number of dropped packets received by the node from other peers", - ), - outboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_outbound_dropped_packets_total", - "counter", - "Counts the number of packets dropped while being sent to other peers", - ), + inboundDroppedPacketsTotal: reg.DroppedPacketsInbound(), + outboundDroppedPacketsTotal: reg.DroppedPacketsOutbound(), } } @@ -886,8 +876,8 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf if filt.RunOut(p, t.filterFlags) != filter.Accept { metricPacketOutDropFilter.Add(1) - t.metrics.outboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, + t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonACL, }, 1) return filter.Drop, gro } @@ -1158,8 +1148,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca if outcome != filter.Accept { metricPacketInDropFilter.Add(1) - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonACL, }, 1) // Tell them, via TSMP, we're dropping them due to the ACL. @@ -1239,8 +1229,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { t.noteActivity() _, err := t.tdevWrite(buffs, offset) if err != nil { - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonError, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonError, }, int64(len(buffs))) } return len(buffs), err @@ -1482,20 +1472,6 @@ var ( metricPacketOutDropSelfDisco = clientmetric.NewCounter("tstun_out_to_wg_drop_self_disco") ) -type DropReason string - -const ( - DropReasonACL DropReason = "acl" - DropReasonError DropReason = "error" -) - -type dropPacketLabel struct { - // Reason indicates what we have done with the packet, and has the following values: - // - acl (rejected packets because of ACL) - // - error (rejected packets because of an error) - Reason DropReason -} - func (t *Wrapper) InstallCaptureHook(cb capture.Callback) { t.captureHook.Store(cb) } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 0ed0075b616ee..9ebedda837b0a 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -441,13 +441,13 @@ func TestFilter(t *testing.T) { } var metricInboundDroppedPacketsACL, metricInboundDroppedPacketsErr, metricOutboundDroppedPacketsACL int64 - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricInboundDroppedPacketsACL = m.Value() } - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonError}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonError}).(*expvar.Int); ok { metricInboundDroppedPacketsErr = m.Value() } - if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricOutboundDroppedPacketsACL = m.Value() } diff --git a/util/usermetric/metrics.go b/util/usermetric/metrics.go new file mode 100644 index 0000000000000..7f85989ff062a --- /dev/null +++ b/util/usermetric/metrics.go @@ -0,0 +1,69 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains user-facing metrics that are used by multiple packages. +// Use it to define more common metrics. Any changes to the registry and +// metric types should be in usermetric.go. + +package usermetric + +import ( + "sync" + + "tailscale.com/metrics" +) + +// Metrics contains user-facing metrics that are used by multiple packages. +type Metrics struct { + initOnce sync.Once + + droppedPacketsInbound *metrics.MultiLabelMap[DropLabels] + droppedPacketsOutbound *metrics.MultiLabelMap[DropLabels] +} + +// DropReason is the reason why a packet was dropped. +type DropReason string + +const ( + // ReasonACL means that the packet was not permitted by ACL. + ReasonACL DropReason = "acl" + + // ReasonError means that the packet was dropped because of an error. + ReasonError DropReason = "error" +) + +// DropLabels contains common label(s) for dropped packet counters. +type DropLabels struct { + Reason DropReason +} + +// initOnce initializes the common metrics. +func (r *Registry) initOnce() { + r.m.initOnce.Do(func() { + r.m.droppedPacketsInbound = NewMultiLabelMapWithRegistry[DropLabels]( + r, + "tailscaled_inbound_dropped_packets_total", + "counter", + "Counts the number of dropped packets received by the node from other peers", + ) + r.m.droppedPacketsOutbound = NewMultiLabelMapWithRegistry[DropLabels]( + r, + "tailscaled_outbound_dropped_packets_total", + "counter", + "Counts the number of packets dropped while being sent to other peers", + ) + }) +} + +// DroppedPacketsOutbound returns the outbound dropped packet metric, creating it +// if necessary. +func (r *Registry) DroppedPacketsOutbound() *metrics.MultiLabelMap[DropLabels] { + r.initOnce() + return r.m.droppedPacketsOutbound +} + +// DroppedPacketsInbound returns the inbound dropped packet metric. +func (r *Registry) DroppedPacketsInbound() *metrics.MultiLabelMap[DropLabels] { + r.initOnce() + return r.m.droppedPacketsInbound +} diff --git a/util/usermetric/usermetric.go b/util/usermetric/usermetric.go index c964e08a76395..7913a4ef0d5f8 100644 --- a/util/usermetric/usermetric.go +++ b/util/usermetric/usermetric.go @@ -19,6 +19,9 @@ import ( // Registry tracks user-facing metrics of various Tailscale subsystems. type Registry struct { vars expvar.Map + + // m contains common metrics owned by the registry. + m Metrics } // NewMultiLabelMapWithRegistry creates and register a new diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 0204fa0f5a011..e9f07086271d5 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -674,6 +674,9 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan if err != nil { c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) metricSendDERPError.Add(1) + if !wr.isDisco { + c.metrics.outboundPacketsDroppedErrors.Add(1) + } } else if !wr.isDisco { c.metrics.outboundPacketsDERPTotal.Add(1) c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b))) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 705e42d9ef84c..a9c6fa070e90f 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -127,6 +127,10 @@ type metrics struct { outboundBytesIPv4Total expvar.Int outboundBytesIPv6Total expvar.Int outboundBytesDERPTotal expvar.Int + + // outboundPacketsDroppedErrors is the total number of outbound packets + // dropped due to errors. + outboundPacketsDroppedErrors expvar.Int } // A Conn routes UDP packets and actively manages a list of its endpoints. @@ -605,6 +609,8 @@ func registerMetrics(reg *usermetric.Registry) *metrics { "counter", "Counts the number of bytes sent to other peers", ) + outboundPacketsDroppedErrors := reg.DroppedPacketsOutbound() + m := new(metrics) // Map clientmetrics to the usermetric counters. @@ -631,6 +637,8 @@ func registerMetrics(reg *usermetric.Registry) *metrics { outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total) outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal) + outboundPacketsDroppedErrors.Set(usermetric.DropLabels{Reason: usermetric.ReasonError}, &m.outboundPacketsDroppedErrors) + return m } @@ -1202,8 +1210,13 @@ func (c *Conn) networkDown() bool { return !c.networkUp.Load() } // Send implements conn.Bind. // // See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind.Send -func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error { +func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) (err error) { n := int64(len(buffs)) + defer func() { + if err != nil { + c.metrics.outboundPacketsDroppedErrors.Add(n) + } + }() metricSendData.Add(n) if c.networkDown() { metricSendDataNetworkDown.Add(n) diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 7e48e1daa2604..1b3f8ec73c16e 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -63,6 +63,7 @@ import ( "tailscale.com/types/nettype" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" + "tailscale.com/util/must" "tailscale.com/util/racebuild" "tailscale.com/util/set" "tailscale.com/util/usermetric" @@ -3083,3 +3084,27 @@ func TestMaybeRebindOnError(t *testing.T) { } }) } + +func TestNetworkDownSendErrors(t *testing.T) { + netMon := must.Get(netmon.New(t.Logf)) + defer netMon.Close() + + reg := new(usermetric.Registry) + conn := must.Get(NewConn(Options{ + DisablePortMapper: true, + Logf: t.Logf, + NetMon: netMon, + Metrics: reg, + })) + defer conn.Close() + + conn.SetNetworkUp(false) + if err := conn.Send([][]byte{{00}}, &lazyEndpoint{}); err == nil { + t.Error("expected error, got nil") + } + resp := httptest.NewRecorder() + reg.Handler(resp, new(http.Request)) + if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) { + t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String()) + } +} From 45354dab9bddc97acaa84b03b99448ac49b4c0cf Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Thu, 31 Oct 2024 14:45:57 +0000 Subject: [PATCH 063/179] ipn,tailcfg: add app connector config knob to conffile (#13942) Make it possible to advertise app connector via a new conffile field. Also bumps capver - conffile deserialization errors out if unknonw fields are set, so we need to know which clients understand the new field. Updates tailscale/tailscale#11113 Signed-off-by: Irbe Krumina --- ipn/conf.go | 6 ++++++ tailcfg/tailcfg.go | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ipn/conf.go b/ipn/conf.go index 6a67f40040c76..1b2831b03b6c6 100644 --- a/ipn/conf.go +++ b/ipn/conf.go @@ -32,6 +32,8 @@ type ConfigVAlpha struct { AdvertiseRoutes []netip.Prefix `json:",omitempty"` DisableSNAT opt.Bool `json:",omitempty"` + AppConnector *AppConnectorPrefs `json:",omitempty"` // advertise app connector; defaults to false (if nil or explicitly set to false) + NetfilterMode *string `json:",omitempty"` // "on", "off", "nodivert" NoStatefulFiltering opt.Bool `json:",omitempty"` @@ -137,5 +139,9 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) { mp.AutoUpdate = *c.AutoUpdate mp.AutoUpdateSet = AutoUpdatePrefsMask{ApplySet: true, CheckSet: true} } + if c.AppConnector != nil { + mp.AppConnector = *c.AppConnector + mp.AppConnectorSet = true + } return mp, nil } diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 0e1b1d4aef9bc..9e39a43364962 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -149,7 +149,8 @@ type CapabilityVersion int // - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works // - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849) // - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5) -const CurrentCapabilityVersion CapabilityVersion = 106 +// - 107: 2024-10-30: add App Connector to conffile (PR #13942) +const CurrentCapabilityVersion CapabilityVersion = 107 type StableID string From 3f626c0d774bc1b8a93be26a4aa8f2dadeb27ece Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Tue, 29 Oct 2024 15:22:49 -0500 Subject: [PATCH 064/179] cmd/tailscale/cli, client/tailscale, ipn/localapi: add tailscale syspolicy {list,reload} commands In this PR, we add the tailscale syspolicy command with two subcommands: list, which displays policy settings, and reload, which forces a reload of those settings. We also update the LocalAPI and LocalClient to facilitate these additions. Updates #12687 Signed-off-by: Nick Khyl --- client/tailscale/localclient.go | 28 ++++++++ cmd/k8s-operator/depaware.txt | 2 +- cmd/tailscale/cli/cli.go | 1 + cmd/tailscale/cli/syspolicy.go | 110 ++++++++++++++++++++++++++++++++ cmd/tailscaled/depaware.txt | 2 +- ipn/localapi/localapi.go | 50 +++++++++++++++ 6 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 cmd/tailscale/cli/syspolicy.go diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index df51dc1cab52c..9c2bcc467b0e2 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -40,6 +40,7 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/syspolicy/setting" ) // defaultLocalClient is the default LocalClient when using the legacy @@ -814,6 +815,33 @@ func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn return decodeJSON[*ipn.Prefs](body) } +// GetEffectivePolicy returns the effective policy for the specified scope. +func (lc *LocalClient) GetEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.get200(ctx, "/localapi/v0/policy/"+string(scopeID)) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} + +// ReloadEffectivePolicy reloads the effective policy for the specified scope +// by reading and merging policy settings from all applicable policy sources. +func (lc *LocalClient) ReloadEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.send(ctx, "POST", "/localapi/v0/policy/"+string(scopeID), 200, http.NoBody) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} + // GetDNSOSConfig returns the system DNS configuration for the current device. // That is, it returns the DNS configuration that the system would use if Tailscale weren't being used. func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) { diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 2ad3978c927d7..d62f2e225ca7e 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -814,7 +814,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source - tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index f786bcea5bdf7..130a11623351d 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -185,6 +185,7 @@ change in the future. logoutCmd, switchCmd, configureCmd, + syspolicyCmd, netcheckCmd, ipCmd, dnsCmd, diff --git a/cmd/tailscale/cli/syspolicy.go b/cmd/tailscale/cli/syspolicy.go new file mode 100644 index 0000000000000..06a19defb459a --- /dev/null +++ b/cmd/tailscale/cli/syspolicy.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "os" + "slices" + "text/tabwriter" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/util/syspolicy/setting" +) + +var syspolicyArgs struct { + json bool // JSON output mode +} + +var syspolicyCmd = &ffcli.Command{ + Name: "syspolicy", + ShortHelp: "Diagnose the MDM and system policy configuration", + LongHelp: "The 'tailscale syspolicy' command provides tools for diagnosing the MDM and system policy configuration.", + ShortUsage: "tailscale syspolicy ", + UsageFunc: usageFuncNoDefaultValues, + Subcommands: []*ffcli.Command{ + { + Name: "list", + ShortUsage: "tailscale syspolicy list", + Exec: runSysPolicyList, + ShortHelp: "Prints effective policy settings", + LongHelp: "The 'tailscale syspolicy list' subcommand displays the effective policy settings and their sources (e.g., MDM or environment variables).", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy list") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + { + Name: "reload", + ShortUsage: "tailscale syspolicy reload", + Exec: runSysPolicyReload, + ShortHelp: "Forces a reload of policy settings, even if no changes are detected, and prints the result", + LongHelp: "The 'tailscale syspolicy reload' subcommand forces a reload of policy settings, even if no changes are detected, and prints the result.", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy reload") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + }, +} + +func runSysPolicyList(ctx context.Context, args []string) error { + policy, err := localClient.GetEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil + +} + +func runSysPolicyReload(ctx context.Context, args []string) error { + policy, err := localClient.ReloadEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil +} + +func printPolicySettings(policy *setting.Snapshot) { + if syspolicyArgs.json { + json, err := json.MarshalIndent(policy, "", "\t") + if err != nil { + errf("syspolicy marshalling error: %v", err) + } else { + outln(string(json)) + } + return + } + if policy.Len() == 0 { + outln("No policy settings") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "Name\tOrigin\tValue\tError") + fmt.Fprintln(w, "----\t------\t-----\t-----") + for _, k := range slices.Sorted(policy.Keys()) { + setting, _ := policy.GetSetting(k) + var origin string + if o := setting.Origin(); o != nil { + origin = o.String() + } + if err := setting.Error(); err != nil { + fmt.Fprintf(w, "%s\t%s\t\t{%s}\n", k, origin, err) + } else { + fmt.Fprintf(w, "%s\t%s\t%s\t\n", k, origin, setting.Value()) + } + } + w.Flush() + + fmt.Println() + return +} diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index b3a4aa86fba30..53e4790d38eeb 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -403,7 +403,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source - tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 1d580eca9ff95..0d41725d83dbe 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -62,6 +62,8 @@ import ( "tailscale.com/util/osdiag" "tailscale.com/util/progresstracking" "tailscale.com/util/rands" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" "tailscale.com/version" "tailscale.com/wgengine/magicsock" ) @@ -76,6 +78,7 @@ var handler = map[string]localAPIHandler{ "cert/": (*Handler).serveCert, "file-put/": (*Handler).serveFilePut, "files/": (*Handler).serveFiles, + "policy/": (*Handler).servePolicy, "profiles/": (*Handler).serveProfiles, // The other /localapi/v0/NAME handlers are exact matches and contain only NAME @@ -1332,6 +1335,53 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) { e.Encode(prefs) } +func (h *Handler) servePolicy(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "policy access denied", http.StatusForbidden) + return + } + + suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/policy/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + + var scope setting.PolicyScope + if suffix == "" { + scope = setting.DefaultScope() + } else if err := scope.UnmarshalText([]byte(suffix)); err != nil { + http.Error(w, fmt.Sprintf("%q is not a valid scope", suffix), http.StatusBadRequest) + return + } + + policy, err := rsop.PolicyFor(scope) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var effectivePolicy *setting.Snapshot + switch r.Method { + case "GET": + effectivePolicy = policy.Get() + case "POST": + effectivePolicy, err = policy.Reload() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + default: + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(effectivePolicy) +} + type resJSON struct { Error string `json:",omitempty"` } From 3477bfd234523601e2788a51bb365422448278ed Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Thu, 31 Oct 2024 13:12:38 -0500 Subject: [PATCH 065/179] safeweb: add support for "/" and "/foo" handler distinction (#13980) By counting "/" elements in the pattern we catch many scenarios, but not the root-level handler. If either of the patterns is "/", compare the pattern length to pick the right one. Updates https://github.com/tailscale/corp/issues/8027 Signed-off-by: Andrew Lytvynov --- safeweb/http.go | 17 ++++++++++++++++- safeweb/http_test.go | 10 ++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/safeweb/http.go b/safeweb/http.go index 14c61336ac311..d8818d38605b8 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -225,12 +225,27 @@ const ( browserHandler ) +func (h handlerType) String() string { + switch h { + case browserHandler: + return "browser" + case apiHandler: + return "api" + default: + return "unknown" + } +} + // checkHandlerType returns either apiHandler or browserHandler, depending on // whether apiPattern or browserPattern is more specific (i.e. which pattern // contains more pathname components). If they are equally specific, it returns // unknownHandler. func checkHandlerType(apiPattern, browserPattern string) handlerType { - c := cmp.Compare(strings.Count(path.Clean(apiPattern), "/"), strings.Count(path.Clean(browserPattern), "/")) + apiPattern, browserPattern = path.Clean(apiPattern), path.Clean(browserPattern) + c := cmp.Compare(strings.Count(apiPattern, "/"), strings.Count(browserPattern, "/")) + if apiPattern == "/" || browserPattern == "/" { + c = cmp.Compare(len(apiPattern), len(browserPattern)) + } switch { case c > 0: return apiHandler diff --git a/safeweb/http_test.go b/safeweb/http_test.go index cec14b2b9bb8b..a2e2d7644cdf3 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -527,13 +527,13 @@ func TestGetMoreSpecificPattern(t *testing.T) { { desc: "same prefix", a: "/foo/bar/quux", - b: "/foo/bar/", + b: "/foo/bar/", // path.Clean will strip the trailing slash. want: apiHandler, }, { desc: "almost same prefix, but not a path component", a: "/goat/sheep/cheese", - b: "/goat/sheepcheese/", + b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash. want: apiHandler, }, { @@ -554,6 +554,12 @@ func TestGetMoreSpecificPattern(t *testing.T) { b: "///////", want: unknownHandler, }, + { + desc: "root-level", + a: "/latest", + b: "/", // path.Clean will NOT strip the trailing slash. + want: apiHandler, + }, } { t.Run(tt.desc, func(t *testing.T) { got := checkHandlerType(tt.a, tt.b) From 6985369479db2c9d5bacccbde6d66630a81eb1ab Mon Sep 17 00:00:00 2001 From: Andrea Gottardo Date: Thu, 31 Oct 2024 12:00:34 -0700 Subject: [PATCH 066/179] net/sockstats: prevent crash in setNetMon (#13985) --- net/sockstats/sockstats_tsgo.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/net/sockstats/sockstats_tsgo.go b/net/sockstats/sockstats_tsgo.go index af691302f8be8..fec9ec3b0dad2 100644 --- a/net/sockstats/sockstats_tsgo.go +++ b/net/sockstats/sockstats_tsgo.go @@ -279,7 +279,13 @@ func setNetMon(netMon *netmon.Monitor) { if ifName == "" { return } - ifIndex := state.Interface[ifName].Index + // DefaultRouteInterface and Interface are gathered at different points in time. + // Check for existence first, to avoid a nil pointer dereference. + iface, ok := state.Interface[ifName] + if !ok { + return + } + ifIndex := iface.Index sockStats.mu.Lock() defer sockStats.mu.Unlock() // Ignore changes to unknown interfaces -- it would require From ddbc950f466ff7fa4c0b2dfb11489311b0d384f2 Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Thu, 31 Oct 2024 14:13:29 -0500 Subject: [PATCH 067/179] safeweb: add support for custom CSP (#13975) To allow more flexibility with CSPs, add a fully customizable `CSP` type that can be provided in `Config` and encodes itself into the correct format. Preserve the `CSPAllowInlineStyles` option as is today, but maybe that'll get deprecated later in favor of the new CSP field. In particular, this allows for pages loading external JS, or inline JS with nonces or hashes (see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src#unsafe_inline_script) Updates https://github.com/tailscale/corp/issues/8027 Signed-off-by: Andrew Lytvynov --- safeweb/http.go | 88 +++++++++++++++++++++++++++++++++++++------- safeweb/http_test.go | 28 +++++++++----- 2 files changed, 92 insertions(+), 24 deletions(-) diff --git a/safeweb/http.go b/safeweb/http.go index d8818d38605b8..bd53eca5bbfd9 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -74,25 +74,74 @@ import ( crand "crypto/rand" "fmt" "log" + "maps" "net" "net/http" "net/url" "path" + "slices" "strings" "github.com/gorilla/csrf" ) -// The default Content-Security-Policy header. -var defaultCSP = strings.Join([]string{ - `default-src 'self'`, // origin is the only valid source for all content types - `script-src 'self'`, // disallow inline javascript - `frame-ancestors 'none'`, // disallow framing of the page - `form-action 'self'`, // disallow form submissions to other origins - `base-uri 'self'`, // disallow base URIs from other origins - `block-all-mixed-content`, // disallow mixed content when serving over HTTPS - `object-src 'self'`, // disallow embedding of resources from other origins -}, "; ") +// CSP is the value of a Content-Security-Policy header. Keys are CSP +// directives (like "default-src") and values are source expressions (like +// "'self'" or "https://tailscale.com"). A nil slice value is allowed for some +// directives like "upgrade-insecure-requests" that don't expect a list of +// source definitions. +type CSP map[string][]string + +// DefaultCSP is the recommended CSP to use when not loading resources from +// other domains and not embedding the current website. If you need to tweak +// the CSP, it is recommended to extend DefaultCSP instead of writing your own +// from scratch. +func DefaultCSP() CSP { + return CSP{ + "default-src": {"self"}, // origin is the only valid source for all content types + "frame-ancestors": {"none"}, // disallow framing of the page + "form-action": {"self"}, // disallow form submissions to other origins + "base-uri": {"self"}, // disallow base URIs from other origins + // TODO(awly): consider upgrade-insecure-requests in SecureContext + // instead, as this is deprecated. + "block-all-mixed-content": nil, // disallow mixed content when serving over HTTPS + } +} + +// Set sets the values for a given directive. Empty values are allowed, if the +// directive doesn't expect any (like "upgrade-insecure-requests"). +func (csp CSP) Set(directive string, values ...string) { + csp[directive] = values +} + +// Add adds a source expression to an existing directive. +func (csp CSP) Add(directive, value string) { + csp[directive] = append(csp[directive], value) +} + +// Del deletes a directive and all its values. +func (csp CSP) Del(directive string) { + delete(csp, directive) +} + +func (csp CSP) String() string { + keys := slices.Collect(maps.Keys(csp)) + slices.Sort(keys) + var s strings.Builder + for _, k := range keys { + s.WriteString(k) + for _, v := range csp[k] { + // Special values like 'self', 'none', 'unsafe-inline', etc., must + // be quoted. Do it implicitly as a convenience here. + if !strings.Contains(v, ".") && len(v) > 1 && v[0] != '\'' && v[len(v)-1] != '\'' { + v = "'" + v + "'" + } + s.WriteString(" " + v) + } + s.WriteString("; ") + } + return strings.TrimSpace(s.String()) +} // The default Strict-Transport-Security header. This header tells the browser // to exclusively use HTTPS for all requests to the origin for the next year. @@ -130,6 +179,9 @@ type Config struct { // startup. CSRFSecret []byte + // CSP is the Content-Security-Policy header to return with BrowserMux + // responses. + CSP CSP // CSPAllowInlineStyles specifies whether to include `style-src: // unsafe-inline` in the Content-Security-Policy header to permit the use of // inline CSS. @@ -168,6 +220,10 @@ func (c *Config) setDefaults() error { } } + if c.CSP == nil { + c.CSP = DefaultCSP() + } + return nil } @@ -199,16 +255,20 @@ func NewServer(config Config) (*Server, error) { if config.CookiesSameSiteLax { sameSite = csrf.SameSiteLaxMode } + if config.CSPAllowInlineStyles { + if _, ok := config.CSP["style-src"]; ok { + config.CSP.Add("style-src", "unsafe-inline") + } else { + config.CSP.Set("style-src", "self", "unsafe-inline") + } + } s := &Server{ Config: config, - csp: defaultCSP, + csp: config.CSP.String(), // only set Secure flag on CSRF cookies if we are in a secure context // as otherwise the browser will reject the cookie csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)), } - if config.CSPAllowInlineStyles { - s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'` - } s.h = cmp.Or(config.HTTPServer, &http.Server{}) if s.h.Handler != nil { return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler") diff --git a/safeweb/http_test.go b/safeweb/http_test.go index a2e2d7644cdf3..852ce326ba374 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -241,18 +241,26 @@ func TestCSRFProtection(t *testing.T) { func TestContentSecurityPolicyHeader(t *testing.T) { tests := []struct { name string + csp CSP apiRoute bool - wantCSP bool + wantCSP string }{ { - name: "default routes get CSP headers", - apiRoute: false, - wantCSP: true, + name: "default CSP", + wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`, + }, + { + name: "custom CSP", + csp: CSP{ + "default-src": {"'self'", "https://tailscale.com"}, + "upgrade-insecure-requests": nil, + }, + wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`, }, { name: "`/api/*` routes do not get CSP headers", apiRoute: true, - wantCSP: false, + wantCSP: "", }, } @@ -265,9 +273,9 @@ func TestContentSecurityPolicyHeader(t *testing.T) { var s *Server var err error if tt.apiRoute { - s, err = NewServer(Config{APIMux: h}) + s, err = NewServer(Config{APIMux: h, CSP: tt.csp}) } else { - s, err = NewServer(Config{BrowserMux: h}) + s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp}) } if err != nil { t.Fatal(err) @@ -279,8 +287,8 @@ func TestContentSecurityPolicyHeader(t *testing.T) { s.h.Handler.ServeHTTP(w, req) resp := w.Result() - if (resp.Header.Get("Content-Security-Policy") == "") == tt.wantCSP { - t.Fatalf("content security policy want: %v; got: %v", tt.wantCSP, resp.Header.Get("Content-Security-Policy")) + if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP { + t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got) } }) } @@ -397,7 +405,7 @@ func TestCSPAllowInlineStyles(t *testing.T) { csp := resp.Header.Get("Content-Security-Policy") allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'") if allowsStyles != allow { - t.Fatalf("CSP inline styles want: %v; got: %v", allow, allowsStyles) + t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp) } }) } From 84c88604728938a888ff3ca1bfb10c256a77e0f8 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Thu, 31 Oct 2024 15:13:08 -0600 Subject: [PATCH 068/179] util/syspolicy: add policy key for onboarding flow visibility Updates https://github.com/tailscale/corp/issues/23789 Signed-off-by: Aaron Klotz --- util/syspolicy/policy_keys.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go index 162885b27fa67..bb9a5d6cc5934 100644 --- a/util/syspolicy/policy_keys.go +++ b/util/syspolicy/policy_keys.go @@ -77,6 +77,9 @@ const ( // SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI. // When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker. SuggestedExitNodeVisibility Key = "SuggestedExitNode" + // OnboardingFlowVisibility controls the visibility of the onboarding flow in the client GUI. + // When this system policy is set to 'hide', the onboarding flow is never shown to the user. + OnboardingFlowVisibility Key = "OnboardingFlow" // Keys with a string value formatted for use with time.ParseDuration(). KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours @@ -166,6 +169,7 @@ var implicitDefinitions = []*setting.Definition{ setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue), setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(OnboardingFlowVisibility, setting.UserSetting, setting.VisibilityValue), } func init() { From 49de23cf1bae372996de797d86ced771ed314756 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 31 Oct 2024 19:25:00 -0700 Subject: [PATCH 069/179] net/netcheck: add addReportHistoryAndSetPreferredDERP() test case (#13989) Add an explicit case for exercising preferred DERP hysteresis around the branch that compares latencies on a percentage basis. Updates #cleanup Signed-off-by: Jordan Whited --- net/netcheck/netcheck_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index f287978d225e7..b4fbb4023dcc1 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -357,6 +357,15 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { wantPrevLen: 3, wantDERP: 2, // moved to d2 since d1 is gone }, + { + name: "preferred_derp_hysteresis_no_switch_pct", + steps: []step{ + {0 * time.Second, report("d1", 34*time.Millisecond, "d2", 35*time.Millisecond)}, + {1 * time.Second, report("d1", 34*time.Millisecond, "d2", 23*time.Millisecond)}, + }, + wantPrevLen: 2, + wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1 + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 0ffc7bf38b92f4b064d00fe89239f2da756de642 Mon Sep 17 00:00:00 2001 From: Renato Aguiar Date: Fri, 25 Oct 2024 18:25:39 -0700 Subject: [PATCH 070/179] Fix MagicDNS on OpenBSD Add OpenBSD to the list of platforms that need DNS reconfigured on link changes. Signed-off-by: Renato Aguiar --- wgengine/userspace.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgengine/userspace.go b/wgengine/userspace.go index fc204736a1da2..2dd0c4cd5da89 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -1236,7 +1236,7 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) { // and Apple platforms. if changed { switch runtime.GOOS { - case "linux", "android", "ios", "darwin": + case "linux", "android", "ios", "darwin", "openbsd": e.wgLock.Lock() dnsCfg := e.lastDNSConfig e.wgLock.Unlock() From d09e9d967f1fd6349a2bddefffe2e9e9f4b33044 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Thu, 31 Oct 2024 08:30:11 -0700 Subject: [PATCH 071/179] ipn/ipnlocal: reload prefs correctly on ReloadConfig We were only updating the ProfileManager and not going down the EditPrefs path which meant the prefs weren't applied till either the process restarted or some other pref changed. This makes it so that we reconfigure everything correctly when ReloadConfig is called. Updates #13032 Signed-off-by: Maisem Ali --- ipn/ipnlocal/local.go | 48 ++++++++++++++++++++++--------- ipn/ipnlocal/local_test.go | 59 ++++++++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index b91f1337af0ed..edd56f7c452f5 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -479,7 +479,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo mConn.SetNetInfoCallback(b.setNetInfo) if sys.InitialConfig != nil { - if err := b.setConfigLocked(sys.InitialConfig); err != nil { + if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err } } @@ -712,8 +712,8 @@ func (b *LocalBackend) SetDirectFileRoot(dir string) { // It returns (false, nil) if not running in declarative mode, (true, nil) on // success, or (false, error) on failure. func (b *LocalBackend) ReloadConfig() (ok bool, err error) { - b.mu.Lock() - defer b.mu.Unlock() + unlock := b.lockAndGetUnlock() + defer unlock() if b.conf == nil { return false, nil } @@ -721,18 +721,21 @@ func (b *LocalBackend) ReloadConfig() (ok bool, err error) { if err != nil { return false, err } - if err := b.setConfigLocked(conf); err != nil { + if err := b.setConfigLockedOnEntry(conf, unlock); err != nil { return false, fmt.Errorf("error setting config: %w", err) } return true, nil } -func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { - - // TODO(irbekrm): notify the relevant components to consume any prefs - // updates. Currently only initial configfile settings are applied - // immediately. +// initPrefsFromConfig initializes the backend's prefs from the provided config. +// This should only be called once, at startup. For updates at runtime, use +// [LocalBackend.setConfigLocked]. +func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error { + // TODO(maisem,bradfitz): combine this with setConfigLocked. This is called + // before anything is running, so there's no need to lock and we don't + // update any subsystems. At runtime, we both need to lock and update + // subsystems with the new prefs. p := b.pm.CurrentPrefs().AsStruct() mp, err := conf.Parsed.ToPrefs() if err != nil { @@ -742,13 +745,14 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil { return err } + b.setStaticEndpointsFromConfigLocked(conf) + b.conf = conf + return nil +} - defer func() { - b.conf = conf - }() - +func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config) { if conf.Parsed.StaticEndpoints == nil && (b.conf == nil || b.conf.Parsed.StaticEndpoints == nil) { - return nil + return } // Ensure that magicsock conn has the up to date static wireguard @@ -762,6 +766,22 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { ms.SetStaticEndpoints(views.SliceOf(conf.Parsed.StaticEndpoints)) } } +} + +// setConfigLockedOnEntry uses the provided config to update the backend's prefs +// and other state. +func (b *LocalBackend) setConfigLockedOnEntry(conf *conffile.Config, unlock unlockOnce) error { + defer unlock() + p := b.pm.CurrentPrefs().AsStruct() + mp, err := conf.Parsed.ToPrefs() + if err != nil { + return fmt.Errorf("error parsing config to prefs: %w", err) + } + p.ApplyEdits(&mp) + b.setStaticEndpointsFromConfigLocked(conf) + b.setPrefsLockedOnEntry(p, unlock) + + b.conf = conf return nil } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 5fee5d00ee36a..433679dda193e 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -13,6 +13,7 @@ import ( "net/http" "net/netip" "os" + "path/filepath" "reflect" "slices" "strings" @@ -32,6 +33,7 @@ import ( "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/store/mem" "tailscale.com/net/netcheck" @@ -432,16 +434,25 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) { } func newTestLocalBackend(t testing.TB) *LocalBackend { + return newTestLocalBackendWithSys(t, new(tsd.System)) +} + +// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System. +// If the state store or engine are not set in sys, they will be set to a new +// in-memory store and fake userspace engine, respectively. +func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend { var logf logger.Logf = logger.Discard - sys := new(tsd.System) - store := new(mem.Store) - sys.Set(store) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(new(mem.Store)) + } + if _, ok := sys.Engine.GetOK(); !ok { + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(eng.Close) + sys.Set(eng) } - t.Cleanup(eng.Close) - sys.Set(eng) lb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { t.Fatalf("NewLocalBackend: %v", err) @@ -4423,3 +4434,35 @@ func TestLoginNotifications(t *testing.T) { }) } } + +// TestConfigFileReload tests that the LocalBackend reloads its configuration +// when the configuration file changes. +func TestConfigFileReload(t *testing.T) { + cfg1 := `{"Hostname": "foo", "Version": "alpha0"}` + f := filepath.Join(t.TempDir(), "cfg") + must.Do(os.WriteFile(f, []byte(cfg1), 0600)) + sys := new(tsd.System) + sys.InitialConfig = must.Get(conffile.Load(f)) + lb := newTestLocalBackendWithSys(t, sys) + must.Do(lb.Start(ipn.Options{})) + + lb.mu.Lock() + hn := lb.hostinfo.Hostname + lb.mu.Unlock() + if hn != "foo" { + t.Fatalf("got %q; want %q", hn, "foo") + } + + cfg2 := `{"Hostname": "bar", "Version": "alpha0"}` + must.Do(os.WriteFile(f, []byte(cfg2), 0600)) + if !must.Get(lb.ReloadConfig()) { + t.Fatal("reload failed") + } + + lb.mu.Lock() + hn = lb.hostinfo.Hostname + lb.mu.Unlock() + if hn != "bar" { + t.Fatalf("got %q; want %q", hn, "bar") + } +} From 634cc2ba4a03714173f23915e933f9eed918c137 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 1 Nov 2024 10:50:40 -0700 Subject: [PATCH 072/179] wgengine/netstack: remove unused taildrive deps A filesystem was plumbed into netstack in 993acf4475b22d693 but hasn't been used since 2d5d6f5403f3. Remove it. Noticed while rebasing a Tailscale fork elsewhere. Updates tailscale/corp#16827 Change-Id: Ib76deeda205ffe912b77a59b9d22853ebff42813 Signed-off-by: Brad Fitzpatrick --- cmd/tailscaled/tailscaled.go | 2 -- cmd/tsconnect/wasm/wasm_js.go | 2 +- tsnet/tsnet.go | 2 +- wgengine/netstack/netstack.go | 29 +++++++++++++---------------- wgengine/netstack/netstack_test.go | 4 ++-- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 2831b4061973d..7a5ee03983f44 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -788,7 +788,6 @@ func runDebugServer(mux *http.ServeMux, addr string) { } func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) { - tfs, _ := sys.DriveForLocal.GetOK() ret, err := netstack.Create(logf, sys.Tun.Get(), sys.Engine.Get(), @@ -796,7 +795,6 @@ func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) { sys.Dialer.Get(), sys.DNSManager.Get(), sys.ProxyMapper(), - tfs, ) if err != nil { return nil, err diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index c35d543aabeae..d0bc991f2ca9d 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -115,7 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any { } sys.Set(eng) - ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { log.Fatalf("netstack.Create: %v", err) } diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 6751e0bb03cbe..7252d89fe9f64 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -546,7 +546,7 @@ func (s *Server) start() (reterr error) { sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry()) // TODO(oxtoacart): do we need to support Taildrive on tsnet, and if so, how? - ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 3185c5d556aa9..280f4b7bb5d3c 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" - "tailscale.com/drive" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" "tailscale.com/metrics" @@ -174,19 +173,18 @@ type Impl struct { // It can only be set before calling Start. ProcessSubnets bool - ipstack *stack.Stack - linkEP *linkEndpoint - tundev *tstun.Wrapper - e wgengine.Engine - pm *proxymap.Mapper - mc *magicsock.Conn - logf logger.Logf - dialer *tsdial.Dialer - ctx context.Context // alive until Close - ctxCancel context.CancelFunc // called on Close - lb *ipnlocal.LocalBackend // or nil - dns *dns.Manager - driveForLocal drive.FileSystemForLocal // or nil + ipstack *stack.Stack + linkEP *linkEndpoint + tundev *tstun.Wrapper + e wgengine.Engine + pm *proxymap.Mapper + mc *magicsock.Conn + logf logger.Logf + dialer *tsdial.Dialer + ctx context.Context // alive until Close + ctxCancel context.CancelFunc // called on Close + lb *ipnlocal.LocalBackend // or nil + dns *dns.Manager // loopbackPort, if non-nil, will enable Impl to loop back (dnat to // :loopbackPort) TCP & UDP flows originally @@ -288,7 +286,7 @@ func setTCPBufSizes(ipstack *stack.Stack) error { } // Create creates and populates a new Impl. -func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -382,7 +380,6 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi connsInFlightByClient: make(map[netip.Addr]int), packetsInFlight: make(map[stack.TransportEndpointID]struct{}), dns: dns, - driveForLocal: driveForLocal, } loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT") if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 { diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 1bfc76fef097f..a46dcf9dd6fc9 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -65,7 +65,7 @@ func TestInjectInboundLeak(t *testing.T) { t.Fatal(err) } - ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tb.Cleanup(func() { eng.Close() }) sys.Set(eng) - ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { tb.Fatal(err) } From b0626ff84c11f8ad5c680fdec214eb5981307f1c Mon Sep 17 00:00:00 2001 From: VimT Date: Fri, 20 Sep 2024 23:52:45 +0800 Subject: [PATCH 073/179] net/socks5: fix UDP relay in userspace-networking mode This commit addresses an issue with the SOCKS5 UDP relay functionality when using the --tun=userspace-networking option. Previously, UDP packets were not being correctly routed into the Tailscale network in this mode. Key changes: - Replace single UDP connection with a map of connections per target - Use c.srv.dial for creating connections to ensure proper routing Updates #7581 Change-Id: Iaaa66f9de6a3713218014cf3f498003a7cac9832 Signed-off-by: VimT --- net/socks5/socks5.go | 101 +++++++++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 38 deletions(-) diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 0d651537fac9a..db315d949b117 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -22,6 +22,7 @@ import ( "log" "net" "strconv" + "tailscale.com/syncs" "time" "tailscale.com/types/logger" @@ -81,6 +82,12 @@ const ( addrTypeNotSupported replyCode = 8 ) +// UDP conn default buffer size and read timeout. +const ( + bufferSize = 8 * 1024 + readTimeout = 5 * time.Second +) + // Server is a SOCKS5 proxy server. type Server struct { // Logf optionally specifies the logger to use. @@ -143,7 +150,8 @@ type Conn struct { clientConn net.Conn request *request - udpClientAddr net.Addr + udpClientAddr net.Addr + udpTargetConns syncs.Map[string, net.Conn] } // Run starts the new connection. @@ -276,15 +284,6 @@ func (c *Conn) handleUDP() error { } defer clientUDPConn.Close() - serverUDPConn, err := net.ListenPacket("udp", "[::]:0") - if err != nil { - res := errorResponse(generalFailure) - buf, _ := res.marshal() - c.clientConn.Write(buf) - return err - } - defer serverUDPConn.Close() - bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) if err != nil { return err @@ -305,14 +304,20 @@ func (c *Conn) handleUDP() error { } c.clientConn.Write(buf) - return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) + return c.transferUDP(c.clientConn, clientUDPConn) } -func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const bufferSize = 8 * 1024 - const readTimeout = 5 * time.Second + + // close all target udp connections when the client connection is closed + defer func() { + c.udpTargetConns.Range(func(_ string, conn net.Conn) bool { + _ = conn.Close() + return true + }) + }() // client -> target go func() { @@ -323,7 +328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta case <-ctx.Done(): return default: - err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + err := c.handleUDPRequest(ctx, clientConn, buf) if err != nil { if isTimeout(err) { continue @@ -337,21 +342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) getOrDialTargetConn( + ctx context.Context, + clientConn net.PacketConn, + targetAddr string, +) (net.Conn, error) { + host, port, err := splitHostPort(targetAddr) + if err != nil { + return nil, err + } + + conn, loaded := c.udpTargetConns.Load(targetAddr) + if loaded { + return conn, nil + } + conn, err = c.srv.dial(ctx, "udp", targetAddr) + if err != nil { + return nil, err + } + c.udpTargetConns.Store(targetAddr, conn) + // target -> client go func() { - defer cancel() buf := make([]byte, bufferSize) + addr := socksAddr{addrType: getAddrType(host), addr: host, port: port} for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + err := c.handleUDPResponse(clientConn, addr, conn, buf) if err != nil { if isTimeout(err) { continue } - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return } c.logf("udp transfer: handle udp response fail: %v", err) @@ -360,20 +394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() - // A UDP association terminates when the TCP connection that the UDP - // ASSOCIATE request arrived on terminates. RFC1928 - _, err := io.Copy(io.Discard, associatedTCP) - if err != nil { - err = fmt.Errorf("udp associated tcp conn: %w", err) - } - return err + return conn, nil } func (c *Conn) handleUDPRequest( + ctx context.Context, clientConn net.PacketConn, - targetConn net.PacketConn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) @@ -386,12 +413,14 @@ func (c *Conn) handleUDPRequest( if err != nil { return fmt.Errorf("parse udp request: %w", err) } - targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + + targetAddr := req.addr.hostPort() + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr) if err != nil { - c.logf("resolve target addr fail: %v", err) + return fmt.Errorf("dial target %s fail: %w", targetAddr, err) } - nn, err := targetConn.WriteTo(data, targetAddr) + nn, err := targetConn.Write(data) if err != nil { return fmt.Errorf("write to target %s fail: %w", targetAddr, err) } @@ -402,22 +431,18 @@ func (c *Conn) handleUDPRequest( } func (c *Conn) handleUDPResponse( - targetConn net.PacketConn, clientConn net.PacketConn, + targetAddr socksAddr, + targetConn net.Conn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) - n, addr, err := targetConn.ReadFrom(buf) + n, err := targetConn.Read(buf) if err != nil { return fmt.Errorf("read from target: %w", err) } - host, port, err := splitHostPort(addr.String()) - if err != nil { - return fmt.Errorf("split host port: %w", err) - } - hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + hdr := udpRequest{addr: targetAddr} pkt, err := hdr.marshal() if err != nil { return fmt.Errorf("marshal udp request: %w", err) From 43138c7a5c8815ea104499866440e34bb1220e93 Mon Sep 17 00:00:00 2001 From: VimT Date: Sat, 21 Sep 2024 14:37:51 +0800 Subject: [PATCH 074/179] net/socks5: optimize UDP relay Key changes: - No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns - Use socksAddr as map key for better type safety - Add test for multi udp target Updates #7581 Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c Signed-off-by: VimT --- net/socks5/socks5.go | 52 ++++++------ net/socks5/socks5_test.go | 166 +++++++++++++++++++++----------------- 2 files changed, 119 insertions(+), 99 deletions(-) diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index db315d949b117..4a5befa1d2fef 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -22,7 +22,6 @@ import ( "log" "net" "strconv" - "tailscale.com/syncs" "time" "tailscale.com/types/logger" @@ -151,7 +150,7 @@ type Conn struct { request *request udpClientAddr net.Addr - udpTargetConns syncs.Map[string, net.Conn] + udpTargetConns map[socksAddr]net.Conn } // Run starts the new connection. @@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // close all target udp connections when the client connection is closed - defer func() { - c.udpTargetConns.Range(func(_ string, conn net.Conn) bool { - _ = conn.Close() - return true - }) - }() - // client -> target go func() { defer cancel() + + c.udpTargetConns = make(map[socksAddr]net.Conn) + // close all target udp connections when the client connection is closed + defer func() { + for _, conn := range c.udpTargetConns { + _ = conn.Close() + } + }() + buf := make([]byte, bufferSize) for { select { @@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er func (c *Conn) getOrDialTargetConn( ctx context.Context, clientConn net.PacketConn, - targetAddr string, + targetAddr socksAddr, ) (net.Conn, error) { - host, port, err := splitHostPort(targetAddr) - if err != nil { - return nil, err - } - - conn, loaded := c.udpTargetConns.Load(targetAddr) - if loaded { + conn, exist := c.udpTargetConns[targetAddr] + if exist { return conn, nil } - conn, err = c.srv.dial(ctx, "udp", targetAddr) + conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort()) if err != nil { return nil, err } - c.udpTargetConns.Store(targetAddr, conn) + c.udpTargetConns[targetAddr] = conn // target -> client go func() { buf := make([]byte, bufferSize) - addr := socksAddr{addrType: getAddrType(host), addr: host, port: port} for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(clientConn, addr, conn, buf) + err := c.handleUDPResponse(clientConn, targetAddr, conn, buf) if err != nil { if isTimeout(err) { continue @@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest( return fmt.Errorf("parse udp request: %w", err) } - targetAddr := req.addr.hostPort() - targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr) + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr) if err != nil { - return fmt.Errorf("dial target %s fail: %w", targetAddr, err) + return fmt.Errorf("dial target %s fail: %w", req.addr, err) } nn, err := targetConn.Write(data) if err != nil { - return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + return fmt.Errorf("write to target %s fail: %w", req.addr, err) } if nn != len(data) { - return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) } return nil } @@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) { pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } + func (s socksAddr) hostPort() string { return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) } +func (s socksAddr) String() string { + return s.hostPort() +} + // response contains the contents of // a response packet sent from the proxy // to the client. diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 11ea59d4b57d1..bc6fac79fdcf9 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) { func TestUDP(t *testing.T) { // backend UDP server which we'll use SOCKS5 to connect to - listener, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) + newUDPEchoServer := func() net.PacketConn { + listener, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + go udpEchoServer(listener) + return listener } - backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port - go udpEchoServer(listener) + + const echoServerNumber = 3 + echoServerListener := make([]net.PacketConn, echoServerNumber) + for i := 0; i < echoServerNumber; i++ { + echoServerListener[i] = newUDPEchoServer() + } + defer func() { + for i := 0; i < echoServerNumber; i++ { + _ = echoServerListener[i].Close() + } + }() // SOCKS5 server socks5, err := net.Listen("tcp", ":0") @@ -184,84 +197,93 @@ func TestUDP(t *testing.T) { socks5Port := socks5.Addr().(*net.TCPAddr).Port go socks5Server(socks5) - // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) - if err != nil { - t.Fatal(err) - } - _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 1024) - n, err := conn.Read(buf) // server hello - if err != nil { - t.Fatal(err) - } - if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { - t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) - } + // make a socks5 udpAssociate conn + newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) { + // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) + if err != nil { + t.Fatal(err) + } + _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) // server hello + if err != nil { + t.Fatal(err) + } + if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired { + t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) + } - targetAddr := socksAddr{ - addrType: domainName, - addr: "localhost", - port: uint16(backendServerPort), - } - targetAddrPkt, err := targetAddr.marshal() - if err != nil { - t.Fatal(err) - } - _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust - if err != nil { - t.Fatal(err) - } + targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + targetAddrPkt, err := targetAddr.marshal() + if err != nil { + t.Fatal(err) + } + _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust + if err != nil { + t.Fatal(err) + } - n, err = conn.Read(buf) // server response - if err != nil { - t.Fatal(err) - } - if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { - t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + n, err = conn.Read(buf) // server response + if err != nil { + t.Fatal(err) + } + if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) { + t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + } + udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) + if err != nil { + t.Fatal(err) + } + + return conn, udpProxySocksAddr } - udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) - if err != nil { - t.Fatal(err) + + conn, udpProxySocksAddr := newUdpAssociateConn() + defer conn.Close() + + sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) { + udpPayload, err := (&udpRequest{addr: addr}).marshal() + if err != nil { + t.Fatal(err) + } + udpPayload = append(udpPayload, body...) + _, err = socks5UDPConn.Write(udpPayload) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := socks5UDPConn.Read(buf) + if err != nil { + t.Fatal(err) + } + _, responseBody, err = parseUDPRequest(buf[:n]) + if err != nil { + t.Fatal(err) + } + return responseBody } udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) if err != nil { t.Fatal(err) } - udpConn, err := net.DialUDP("udp", nil, udpProxyAddr) - if err != nil { - t.Fatal(err) - } - udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() - if err != nil { - t.Fatal(err) - } - udpPayload = append(udpPayload, []byte("Test")...) - _, err = udpConn.Write(udpPayload) // send udp package - if err != nil { - t.Fatal(err) - } - n, _, err = udpConn.ReadFrom(buf) - if err != nil { - t.Fatal(err) - } - _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response - if err != nil { - t.Fatal(err) - } - if string(responseBody) != "Test" { - t.Fatalf("got: %q want: Test", responseBody) - } - err = udpConn.Close() + socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr) if err != nil { t.Fatal(err) } - err = conn.Close() - if err != nil { - t.Fatal(err) + defer socks5UDPConn.Close() + + for i := 0; i < echoServerNumber; i++ { + port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port + addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)} + requestBody := []byte(fmt.Sprintf("Test %d", i)) + responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody) + if !bytes.Equal(requestBody, responseBody) { + t.Fatalf("got: %q want: %q", responseBody, requestBody) + } } } From 45da3a4b28715fc123af9d60b0284971e2be3096 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 3 Nov 2024 07:12:34 -0800 Subject: [PATCH 075/179] cmd/tsconnect: block after starting esbuild dev server Thanks to @davidbuzz for raising the issue in #13973. Fixes #8272 Fixes #13973 Change-Id: Ic413e14d34c82df3c70a97e591b90316b0b4946b Signed-off-by: Brad Fitzpatrick --- cmd/tsconnect/common.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/tsconnect/common.go b/cmd/tsconnect/common.go index a387c00c9758e..0b0813226383a 100644 --- a/cmd/tsconnect/common.go +++ b/cmd/tsconnect/common.go @@ -150,6 +150,7 @@ func runEsbuildServe(buildOptions esbuild.BuildOptions) { log.Fatalf("Cannot start esbuild server: %v", err) } log.Printf("Listening on http://%s:%d\n", result.Host, result.Port) + select {} } func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult { From d4222fae95c04102e75dbf97a8c3517a136881a4 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Oct 2024 13:53:34 -0500 Subject: [PATCH 076/179] tsnet: add accessor to get tsd.System Pulled of otherwise unrelated PR #13884. Updates tailscale/corp#22075 Change-Id: I5b539fcb4aca1b93406cf139c719a5e3c64ff7f7 Signed-off-by: Brad Fitzpatrick --- tsnet/tsnet.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 7252d89fe9f64..70084c103e104 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -126,6 +126,7 @@ type Server struct { initOnce sync.Once initErr error lb *ipnlocal.LocalBackend + sys *tsd.System netstack *netstack.Impl netMon *netmon.Monitor rootPath string // the state directory @@ -518,6 +519,7 @@ func (s *Server) start() (reterr error) { } sys := new(tsd.System) + s.sys = sys if err := s.startLogger(&closePool, sys.HealthTracker(), tsLogf); err != nil { return err } @@ -1227,6 +1229,13 @@ func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error { return nil } +// Sys returns a handle to the Tailscale subsystems of this node. +// +// This is not a stable API, nor are the APIs of the returned subsystems. +func (s *Server) Sys() *tsd.System { + return s.sys +} + type listenKey struct { network string host netip.Addr // or zero value for unspecified From 809a6eba80c94e7593b5f7d1604f1f4ac8a6b61c Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Mon, 4 Nov 2024 18:42:51 +0000 Subject: [PATCH 077/179] cmd/k8s-operator: allow to optionally configure tailscaled port (#14005) Updates tailscale/tailscale#13981 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/operator.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index d8dd403cc6097..116ba02e0ce1c 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -11,6 +11,7 @@ import ( "context" "os" "regexp" + "strconv" "strings" "time" @@ -150,6 +151,13 @@ func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, *tailscale.Client) { Hostname: hostname, Logf: zlog.Named("tailscaled").Debugf, } + if p := os.Getenv("TS_PORT"); p != "" { + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + startlog.Fatalf("TS_PORT %q cannot be parsed as uint16: %v", p, err) + } + s.Port = uint16(port) + } if kubeSecret != "" { st, err := kubestore.New(logger.Discard, kubeSecret) if err != nil { From 01185e436fd39c2aa499b3c56bcb08d6c4dc7b84 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Nov 2024 20:49:40 -0800 Subject: [PATCH 078/179] types/result, util/lineiter: add package for a result type, use it This adds a new generic result type (motivated by golang/go#70084) to try it out, and uses it in the new lineutil package (replacing the old lineread package), changing that package to return iterators: sometimes over []byte (when the input is all in memory), but sometimes iterators over results of []byte, if errors might happen at runtime. Updates #12912 Updates golang/go#70084 Change-Id: Iacdc1070e661b5fb163907b1e8b07ac7d51d3f83 Signed-off-by: Brad Fitzpatrick --- cmd/derper/depaware.txt | 3 +- cmd/k8s-operator/depaware.txt | 3 +- cmd/stund/depaware.txt | 3 +- cmd/tailscale/depaware.txt | 3 +- cmd/tailscaled/depaware.txt | 3 +- hostinfo/hostinfo.go | 24 ++++----- hostinfo/hostinfo_linux.go | 13 +++-- ipn/ipnlocal/ssh.go | 22 ++++---- net/netmon/interfaces_android.go | 51 ++++++++---------- net/netmon/interfaces_darwin_test.go | 24 ++++----- net/netmon/interfaces_linux.go | 37 ++++++------- net/netmon/netmon_linux_test.go | 2 + net/tshttpproxy/tshttpproxy_synology.go | 15 +++--- ssh/tailssh/tailssh_test.go | 13 ++--- ssh/tailssh/user.go | 18 +++---- types/result/result.go | 49 +++++++++++++++++ util/lineiter/lineiter.go | 72 +++++++++++++++++++++++++ util/lineiter/lineiter_test.go | 32 +++++++++++ util/pidowner/pidowner_linux.go | 20 +++---- version/distro/distro.go | 20 +++---- 20 files changed, 289 insertions(+), 138 deletions(-) create mode 100644 types/result/result.go create mode 100644 util/lineiter/lineiter.go create mode 100644 util/lineiter/lineiter_test.go diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index e20c4e556da8f..a3eec2046e926 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -140,6 +140,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/types/persist from tailscale.com/ipn tailscale.com/types/preftype from tailscale.com/ipn tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/client/tailscale+ tailscale.com/types/views from tailscale.com/ipn+ @@ -154,7 +155,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/fastuuid from tailscale.com/tsweb 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/health+ tailscale.com/util/multierr from tailscale.com/health+ diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index d62f2e225ca7e..74536c6c9d050 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -775,6 +775,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/cmd/k8s-operator+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/client/tailscale+ tailscale.com/types/views from tailscale.com/appc+ @@ -792,7 +793,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns+ tailscale.com/util/mak from tailscale.com/appc+ tailscale.com/util/multierr from tailscale.com/control/controlclient+ diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index a35f59516ee32..7031b18e2087e 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -67,6 +67,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar tailscale.com/types/logger from tailscale.com/tsweb tailscale.com/types/opt from tailscale.com/envknob+ tailscale.com/types/ptr from tailscale.com/tailcfg+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/tailcfg+ tailscale.com/types/tkatype from tailscale.com/tailcfg+ tailscale.com/types/views from tailscale.com/net/tsaddr+ @@ -74,7 +75,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/tailcfg tailscale.com/util/fastuuid from tailscale.com/tsweb - tailscale.com/util/lineread from tailscale.com/version/distro + tailscale.com/util/lineiter from tailscale.com/version/distro tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/slicesx from tailscale.com/tailcfg tailscale.com/util/vizerror from tailscale.com/tailcfg+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index cce76a81e0bfb..ac5440d2cfabf 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -148,6 +148,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/types/persist from tailscale.com/ipn tailscale.com/types/preftype from tailscale.com/cmd/tailscale/cli+ tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/types/key+ tailscale.com/types/views from tailscale.com/tailcfg+ @@ -162,7 +163,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/groupmember from tailscale.com/client/web 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/multierr from tailscale.com/control/controlhttp+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 53e4790d38eeb..31a0cb67cb568 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -364,6 +364,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/tka+ tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ @@ -381,7 +382,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns+ tailscale.com/util/mak from tailscale.com/control/controlclient+ tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+ diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index 3233a422dd6c3..3d4216922a12b 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -25,7 +25,7 @@ import ( "tailscale.com/types/ptr" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" "tailscale.com/version/distro" ) @@ -231,12 +231,12 @@ func desktop() (ret opt.Bool) { } seenDesktop := false - lineread.File("/proc/net/unix", func(line []byte) error { + for lr := range lineiter.File("/proc/net/unix") { + line, _ := lr.Value() seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(" @/tmp/dbus-")) seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(".X11-unix")) seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S("/wayland-1")) - return nil - }) + } ret.Set(seenDesktop) // Only cache after a minute - compositors might not have started yet. @@ -305,21 +305,21 @@ func inContainer() opt.Bool { ret.Set(true) return ret } - lineread.File("/proc/1/cgroup", func(line []byte) error { + for lr := range lineiter.File("/proc/1/cgroup") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("/docker/")) || mem.Contains(mem.B(line), mem.S("/lxc/")) { ret.Set(true) - return io.EOF // arbitrary non-nil error to stop loop + break } - return nil - }) - lineread.File("/proc/mounts", func(line []byte) error { + } + for lr := range lineiter.File("/proc/mounts") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) { ret.Set(true) - return io.EOF + break } - return nil - }) + } return ret } diff --git a/hostinfo/hostinfo_linux.go b/hostinfo/hostinfo_linux.go index 53d4187bc0c67..66484a3588027 100644 --- a/hostinfo/hostinfo_linux.go +++ b/hostinfo/hostinfo_linux.go @@ -12,7 +12,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/types/ptr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version/distro" ) @@ -106,15 +106,18 @@ func linuxVersionMeta() (meta versionMeta) { } m := map[string]string{} - lineread.File(propFile, func(line []byte) error { + for lr := range lineiter.File(propFile) { + line, err := lr.Value() + if err != nil { + break + } eq := bytes.IndexByte(line, '=') if eq == -1 { - return nil + continue } k, v := string(line[:eq]), strings.Trim(string(line[eq+1:]), `"'`) m[k] = v - return nil - }) + } if v := m["VERSION_CODENAME"]; v != "" { meta.DistroCodeName = v diff --git a/ipn/ipnlocal/ssh.go b/ipn/ipnlocal/ssh.go index fbeb19bd15bd1..383d03f5aa9be 100644 --- a/ipn/ipnlocal/ssh.go +++ b/ipn/ipnlocal/ssh.go @@ -27,7 +27,7 @@ import ( "github.com/tailscale/golang-x-crypto/ssh" "go4.org/mem" "tailscale.com/tailcfg" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/mak" ) @@ -80,30 +80,32 @@ func (b *LocalBackend) getSSHUsernames(req *tailcfg.C2NSSHUsernamesRequest) (*ta if err != nil { return nil, err } - lineread.Reader(bytes.NewReader(out), func(line []byte) error { + for line := range lineiter.Bytes(out) { line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '_' { - return nil + continue } add(string(line)) - return nil - }) + } default: - lineread.File("/etc/passwd", func(line []byte) error { + for lr := range lineiter.File("/etc/passwd") { + line, err := lr.Value() + if err != nil { + break + } line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '#' || line[0] == '_' { - return nil + continue } if mem.HasSuffix(mem.B(line), mem.S("/nologin")) || mem.HasSuffix(mem.B(line), mem.S("/false")) { - return nil + continue } colon := bytes.IndexByte(line, ':') if colon != -1 { add(string(line[:colon])) } - return nil - }) + } } return res, nil } diff --git a/net/netmon/interfaces_android.go b/net/netmon/interfaces_android.go index a96423eb6bfeb..26104e879a393 100644 --- a/net/netmon/interfaces_android.go +++ b/net/netmon/interfaces_android.go @@ -5,7 +5,6 @@ package netmon import ( "bytes" - "errors" "log" "net/netip" "os/exec" @@ -15,7 +14,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/netaddr" "tailscale.com/syncs" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) var ( @@ -34,11 +33,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. -var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -54,44 +48,42 @@ func likelyHomeRouterIPAndroid() (ret netip.Addr, myIP netip.Addr, ok bool) { } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + return likelyHomeRouterIP() + } + lineNum++ if lineNum == 1 { // Skip header line. - return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - return likelyHomeRouterIP() } if ret.IsValid() { // Try to get the local IP of the interface associated with @@ -144,23 +136,26 @@ func likelyHomeRouterIPHelper() (ret netip.Addr, _ netip.Addr, ok bool) { return } // Search for line like "default via 10.0.2.2 dev radio0 table 1016 proto static mtu 1500 " - lineread.Reader(out, func(line []byte) error { + for lr := range lineiter.Reader(out) { + line, err := lr.Value() + if err != nil { + break + } const pfx = "default via " if !mem.HasPrefix(mem.B(line), mem.S(pfx)) { - return nil + continue } line = line[len(pfx):] sp := bytes.IndexByte(line, ' ') if sp == -1 { - return nil + continue } ipb := line[:sp] if ip, err := netip.ParseAddr(string(ipb)); err == nil && ip.Is4() { ret = ip log.Printf("interfaces: found Android default route %v", ip) } - return nil - }) + } cmd.Process.Kill() cmd.Wait() return ret, netip.Addr{}, ret.IsValid() diff --git a/net/netmon/interfaces_darwin_test.go b/net/netmon/interfaces_darwin_test.go index d34040d60d31d..d756d13348bc3 100644 --- a/net/netmon/interfaces_darwin_test.go +++ b/net/netmon/interfaces_darwin_test.go @@ -4,14 +4,13 @@ package netmon import ( - "errors" "io" "net/netip" "os/exec" "testing" "go4.org/mem" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" ) @@ -73,31 +72,34 @@ func likelyHomeRouterIPDarwinExec() (ret netip.Addr, netif string, ok bool) { defer io.Copy(io.Discard, stdout) // clear the pipe to prevent hangs var f []mem.RO - lineread.Reader(stdout, func(lineb []byte) error { + for lr := range lineiter.Reader(stdout) { + lineb, err := lr.Value() + if err != nil { + break + } line := mem.B(lineb) if !mem.Contains(line, mem.S("default")) { - return nil + continue } f = mem.AppendFields(f[:0], line) if len(f) < 4 || !f[0].EqualString("default") { - return nil + continue } ipm, flagsm, netifm := f[1], f[2], f[3] if !mem.Contains(flagsm, mem.S("G")) { - return nil + continue } if mem.Contains(flagsm, mem.S("I")) { - return nil + continue } ip, err := netip.ParseAddr(string(mem.Append(nil, ipm))) if err == nil && ip.IsPrivate() { ret = ip netif = netifm.StringCopy() // We've found what we're looking for. - return errStopReadingNetstatTable + break } - return nil - }) + } return ret, netif, ret.IsValid() } @@ -110,5 +112,3 @@ func TestFetchRoutingTable(t *testing.T) { } } } - -var errStopReadingNetstatTable = errors.New("found private gateway") diff --git a/net/netmon/interfaces_linux.go b/net/netmon/interfaces_linux.go index 299f3101ea73b..d0fb15ababe9e 100644 --- a/net/netmon/interfaces_linux.go +++ b/net/netmon/interfaces_linux.go @@ -23,7 +23,7 @@ import ( "go4.org/mem" "golang.org/x/sys/unix" "tailscale.com/net/netaddr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func init() { @@ -32,11 +32,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. -var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -52,44 +47,42 @@ func likelyHomeRouterIPLinux() (ret netip.Addr, myIP netip.Addr, ok bool) { } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + log.Printf("interfaces: failed to read /proc/net/route: %v", err) + return ret, myIP, false + } lineNum++ if lineNum == 1 { // Skip header line. - return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - log.Printf("interfaces: failed to read /proc/net/route: %v", err) } if ret.IsValid() { // Try to get the local IP of the interface associated with diff --git a/net/netmon/netmon_linux_test.go b/net/netmon/netmon_linux_test.go index d09fac26aecee..75d7c646559f1 100644 --- a/net/netmon/netmon_linux_test.go +++ b/net/netmon/netmon_linux_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package netmon import ( diff --git a/net/tshttpproxy/tshttpproxy_synology.go b/net/tshttpproxy/tshttpproxy_synology.go index cda95764865d4..2e50d26d3a655 100644 --- a/net/tshttpproxy/tshttpproxy_synology.go +++ b/net/tshttpproxy/tshttpproxy_synology.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) // These vars are overridden for tests. @@ -76,21 +76,22 @@ func synologyProxiesFromConfig() (*url.URL, *url.URL, error) { func parseSynologyConfig(r io.Reader) (*url.URL, *url.URL, error) { cfg := map[string]string{} - if err := lineread.Reader(r, func(line []byte) error { + for lr := range lineiter.Reader(r) { + line, err := lr.Value() + if err != nil { + return nil, nil, err + } // accept and skip over empty lines line = bytes.TrimSpace(line) if len(line) == 0 { - return nil + continue } key, value, ok := strings.Cut(string(line), "=") if !ok { - return fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) + return nil, nil, fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) } cfg[string(key)] = string(value) - return nil - }); err != nil { - return nil, nil, err } if cfg["proxy_enabled"] != "yes" { diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 9e4f5ffd3d481..7ce0aeea3b2fa 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -48,7 +48,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/must" "tailscale.com/version/distro" "tailscale.com/wgengine" @@ -1123,14 +1123,11 @@ func TestSSH(t *testing.T) { func parseEnv(out []byte) map[string]string { e := map[string]string{} - lineread.Reader(bytes.NewReader(out), func(line []byte) error { - i := bytes.IndexByte(line, '=') - if i == -1 { - return nil + for line := range lineiter.Bytes(out) { + if i := bytes.IndexByte(line, '='); i != -1 { + e[string(line[:i])] = string(line[i+1:]) } - e[string(line[:i])] = string(line[i+1:]) - return nil - }) + } return e } diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go index 33ebb4db729de..15191813bdca6 100644 --- a/ssh/tailssh/user.go +++ b/ssh/tailssh/user.go @@ -6,7 +6,6 @@ package tailssh import ( - "io" "os" "os/exec" "os/user" @@ -18,7 +17,7 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/hostinfo" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/osuser" "tailscale.com/version/distro" ) @@ -110,15 +109,16 @@ func defaultPathForUser(u *user.User) string { } func defaultPathForUserOnNixOS(u *user.User) string { - var path string - lineread.File("/etc/pam/environment", func(lineb []byte) error { + for lr := range lineiter.File("/etc/pam/environment") { + lineb, err := lr.Value() + if err != nil { + return "" + } if v := pathFromPAMEnvLine(lineb, u); v != "" { - path = v - return io.EOF // stop iteration + return v } - return nil - }) - return path + } + return "" } func pathFromPAMEnvLine(line []byte, u *user.User) (path string) { diff --git a/types/result/result.go b/types/result/result.go new file mode 100644 index 0000000000000..6bd1c2ea62004 --- /dev/null +++ b/types/result/result.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package result contains the Of result type, which is +// either a value or an error. +package result + +// Of is either a T value or an error. +// +// Think of it like Rust or Swift's result types. +// It's named "Of" because the fully qualified name +// for callers reads result.Of[T]. +type Of[T any] struct { + v T // valid if Err is nil; invalid if Err is non-nil + err error +} + +// Value returns a new result with value v, +// without an error. +func Value[T any](v T) Of[T] { + return Of[T]{v: v} +} + +// Error returns a new result with error err. +// If err is nil, the returned result is equivalent +// to calling Value with T's zero value. +func Error[T any](err error) Of[T] { + return Of[T]{err: err} +} + +// MustValue returns r's result value. +// It panics if r.Err returns non-nil. +func (r Of[T]) MustValue() T { + if r.err != nil { + panic(r.err) + } + return r.v +} + +// Value returns r's result value and error. +func (r Of[T]) Value() (T, error) { + return r.v, r.err +} + +// Err returns r's error, if any. +// When r.Err returns nil, it's safe to call r.MustValue without it panicking. +func (r Of[T]) Err() error { + return r.err +} diff --git a/util/lineiter/lineiter.go b/util/lineiter/lineiter.go new file mode 100644 index 0000000000000..5cb1eeef3ee1d --- /dev/null +++ b/util/lineiter/lineiter.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineiter iterates over lines in things. +package lineiter + +import ( + "bufio" + "bytes" + "io" + "iter" + "os" + + "tailscale.com/types/result" +) + +// File returns an iterator that reads lines from the named file. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func File(name string) iter.Seq[result.Of[[]byte]] { + f, err := os.Open(name) + return reader(f, f, err) +} + +// Bytes returns an iterator over the lines in bs. +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func Bytes(bs []byte) iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + for len(bs) > 0 { + i := bytes.IndexByte(bs, '\n') + if i < 0 { + yield(bs) + return + } + if !yield(bs[:i]) { + return + } + bs = bs[i+1:] + } + } +} + +// Reader returns an iterator over the lines in r. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func Reader(r io.Reader) iter.Seq[result.Of[[]byte]] { + return reader(r, nil, nil) +} + +func reader(r io.Reader, c io.Closer, err error) iter.Seq[result.Of[[]byte]] { + return func(yield func(result.Of[[]byte]) bool) { + if err != nil { + yield(result.Error[[]byte](err)) + return + } + if c != nil { + defer c.Close() + } + bs := bufio.NewScanner(r) + for bs.Scan() { + if !yield(result.Value(bs.Bytes())) { + return + } + } + if err := bs.Err(); err != nil { + yield(result.Error[[]byte](err)) + } + } +} diff --git a/util/lineiter/lineiter_test.go b/util/lineiter/lineiter_test.go new file mode 100644 index 0000000000000..3373d5fe7b122 --- /dev/null +++ b/util/lineiter/lineiter_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lineiter + +import ( + "slices" + "strings" + "testing" +) + +func TestBytesLines(t *testing.T) { + var got []string + for line := range Bytes([]byte("foo\n\nbar\nbaz")) { + got = append(got, string(line)) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestReader(t *testing.T) { + var got []string + for line := range Reader(strings.NewReader("foo\n\nbar\nbaz")) { + got = append(got, string(line.MustValue())) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} diff --git a/util/pidowner/pidowner_linux.go b/util/pidowner/pidowner_linux.go index 2a5181f14e03c..a07f512427062 100644 --- a/util/pidowner/pidowner_linux.go +++ b/util/pidowner/pidowner_linux.go @@ -8,26 +8,26 @@ import ( "os" "strings" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func ownerOfPID(pid int) (userID string, err error) { file := fmt.Sprintf("/proc/%d/status", pid) - err = lineread.File(file, func(line []byte) error { + for lr := range lineiter.File(file) { + line, err := lr.Value() + if err != nil { + if os.IsNotExist(err) { + return "", ErrProcessNotFound + } + return "", err + } if len(line) < 4 || string(line[:4]) != "Uid:" { - return nil + continue } f := strings.Fields(string(line)) if len(f) >= 2 { userID = f[1] // real userid } - return nil - }) - if os.IsNotExist(err) { - return "", ErrProcessNotFound - } - if err != nil { - return } if userID == "" { return "", fmt.Errorf("missing Uid line in %s", file) diff --git a/version/distro/distro.go b/version/distro/distro.go index 8865a834b97d3..ce61137cf3280 100644 --- a/version/distro/distro.go +++ b/version/distro/distro.go @@ -6,13 +6,12 @@ package distro import ( "bytes" - "io" "os" "runtime" "strconv" "tailscale.com/types/lazy" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) type Distro string @@ -132,18 +131,19 @@ func DSMVersion() int { return v } // But when run from the command line, we have to read it from the file: - lineread.File("/etc/VERSION", func(line []byte) error { + for lr := range lineiter.File("/etc/VERSION") { + line, err := lr.Value() + if err != nil { + break // but otherwise ignore + } line = bytes.TrimSpace(line) if string(line) == `majorversion="7"` { - v = 7 - return io.EOF + return 7 } if string(line) == `majorversion="6"` { - v = 6 - return io.EOF + return 6 } - return nil - }) - return v + } + return 0 }) } From 065825e94c143bf50f997528332fd63cf47b6cda Mon Sep 17 00:00:00 2001 From: License Updater Date: Mon, 4 Nov 2024 15:02:25 +0000 Subject: [PATCH 079/179] licenses: update license notices Signed-off-by: License Updater --- licenses/windows.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/licenses/windows.md b/licenses/windows.md index 3f6650b9eaff2..8cef256853e56 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -57,8 +57,8 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/52804fd3056a/LICENSE)) - - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/6580b55d49ca/LICENSE)) + - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/8865133fd3ef/LICENSE)) + - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/28f7e73c7afb/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) From 8dcbd988f7653aa17b33094d3f917125414aeab6 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Wed, 23 Oct 2024 20:56:09 -0500 Subject: [PATCH 080/179] cmd/derper: show more information on home page - Basic description of DERP If configured to do so, also show - Mailto link to security@tailscale.com - Link to Tailscale Security Policies - Link to Tailscale Acceptable Use Policy Updates tailscale/corp#24092 Signed-off-by: Percy Wegmann --- cmd/derper/depaware.txt | 3 ++ cmd/derper/derper.go | 79 +++++++++++++++++++++++++++++---------- cmd/derper/derper_test.go | 29 ++++++++++++++ 3 files changed, 92 insertions(+), 19 deletions(-) diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index a3eec2046e926..8fa5334aa7ebe 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -264,6 +264,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from net/http/pprof+ + html/template from tailscale.com/cmd/derper io from bufio+ io/fs from crypto/x509+ io/ioutil from github.com/mitchellh/go-ps+ @@ -308,6 +309,8 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa sync/atomic from context+ syscall from crypto/rand+ text/tabwriter from runtime/pprof + text/template from html/template + text/template/parse from html/template+ time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 80c9dc44f138f..51be3abbe3b93 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -19,6 +19,7 @@ import ( "expvar" "flag" "fmt" + "html/template" "io" "log" "math" @@ -212,25 +213,16 @@ func main() { tsweb.AddBrowserHeaders(w) w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(200) - io.WriteString(w, ` -

DERP

-

- This is a Tailscale DERP server. -

-

- Documentation: -

- -`) - if !*runDERP { - io.WriteString(w, `

Status: disabled

`) - } - if tsweb.AllowDebugAccess(r) { - io.WriteString(w, "

Debug info at /debug/.

\n") + err := homePageTemplate.Execute(w, templateData{ + ShowAbuseInfo: validProdHostname.MatchString(*hostname), + Disabled: !*runDERP, + AllowDebug: tsweb.AllowDebugAccess(r), + }) + if err != nil { + if r.Context().Err() == nil { + log.Printf("homePageTemplate.Execute: %v", err) + } + return } })) mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -468,3 +460,52 @@ func init() { return 0 })) } + +type templateData struct { + ShowAbuseInfo bool + Disabled bool + AllowDebug bool +} + +// homePageTemplate renders the home page using [templateData]. +var homePageTemplate = template.Must(template.New("home").Parse(` +

DERP

+

+ This is a Tailscale DERP server. +

+ +

+ It provides STUN, interactive connectivity establishment, and relaying of end-to-end encrypted traffic + for Tailscale clients. +

+ +{{if .ShowAbuseInfo }} +

+ If you suspect abuse, please contact security@tailscale.com. +

+{{end}} + +

+ Documentation: +

+ + + +{{if .Disabled}} +

Status: disabled

+{{end}} + +{{if .AllowDebug}} +

Debug info at /debug/.

+{{end}} + + +`)) diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 553a78f9f6426..6ddf4455b0495 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -4,7 +4,9 @@ package main import ( + "bytes" "context" + "fmt" "net/http" "net/http/httptest" "strings" @@ -110,3 +112,30 @@ func TestDeps(t *testing.T) { }, }.Check(t) } + +func TestTemplate(t *testing.T) { + buf := &bytes.Buffer{} + err := homePageTemplate.Execute(buf, templateData{ + ShowAbuseInfo: true, + Disabled: true, + AllowDebug: true, + }) + if err != nil { + t.Fatal(err) + } + + str := buf.String() + if !strings.Contains(str, "If you suspect abuse") { + t.Error("Output is missing abuse mailto") + } + if !strings.Contains(str, "Tailscale Security Policies") { + t.Error("Output is missing Tailscale Security Policies link") + } + if !strings.Contains(str, "Status:") { + t.Error("Output is missing disabled status") + } + if !strings.Contains(str, "Debug info") { + t.Error("Output is missing debug info") + } + fmt.Println(buf.String()) +} From 8ba9b558d2a8efe172f7a005ec1e6572b60f05e2 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Thu, 7 Nov 2024 12:42:29 +0000 Subject: [PATCH 081/179] envknob,kube/kubetypes,cmd/k8s-operator: add app type for ProxyGroup (#14029) Sets a custom hostinfo app type for ProxyGroup replicas, similarly to how we do it for all other Kubernetes Operator managed components. Updates tailscale/tailscale#13406,tailscale/corp#22920 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/proxygroup.go | 2 +- cmd/k8s-operator/proxygroup_specs.go | 5 +++++ envknob/envknob.go | 2 +- kube/kubetypes/metrics.go | 17 ++++++++++------- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 1f9983aa98962..7dad9e573e151 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -47,7 +47,7 @@ const ( reasonProxyGroupInvalid = "ProxyGroupInvalid" ) -var gaugeProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupCount) +var gaugeProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupEgressCount) // ProxyGroupReconciler ensures cluster resources for a ProxyGroup definition. type ProxyGroupReconciler struct { diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index 9aa7ac3b008a3..f9d1ea52be221 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -15,6 +15,7 @@ import ( "sigs.k8s.io/yaml" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubetypes" "tailscale.com/types/ptr" ) @@ -146,6 +147,10 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa Name: "TS_USERSPACE", Value: "false", }, + { + Name: "TS_INTERNAL_APP", + Value: kubetypes.AppProxyGroupEgress, + }, } if tsFirewallMode != "" { diff --git a/envknob/envknob.go b/envknob/envknob.go index 59a6d90af213b..e74bfea71bdb3 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -411,7 +411,7 @@ func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION // Kubernetes Operator components. func App() string { a := os.Getenv("TS_INTERNAL_APP") - if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource { + if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource || a == kubetypes.AppProxyGroupEgress || a == kubetypes.AppProxyGroupIngress { return a } return "" diff --git a/kube/kubetypes/metrics.go b/kube/kubetypes/metrics.go index b183f1f6f79f7..63078385ad293 100644 --- a/kube/kubetypes/metrics.go +++ b/kube/kubetypes/metrics.go @@ -5,12 +5,14 @@ package kubetypes const ( // Hostinfo App values for the Tailscale Kubernetes Operator components. - AppOperator = "k8s-operator" - AppAPIServerProxy = "k8s-operator-proxy" - AppIngressProxy = "k8s-operator-ingress-proxy" - AppIngressResource = "k8s-operator-ingress-resource" - AppEgressProxy = "k8s-operator-egress-proxy" - AppConnector = "k8s-operator-connector-resource" + AppOperator = "k8s-operator" + AppAPIServerProxy = "k8s-operator-proxy" + AppIngressProxy = "k8s-operator-ingress-proxy" + AppIngressResource = "k8s-operator-ingress-resource" + AppEgressProxy = "k8s-operator-egress-proxy" + AppConnector = "k8s-operator-connector-resource" + AppProxyGroupEgress = "k8s-operator-proxygroup-egress" + AppProxyGroupIngress = "k8s-operator-proxygroup-ingress" // Clientmetrics for Tailscale Kubernetes Operator components MetricIngressProxyCount = "k8s_ingress_proxies" // L3 @@ -22,5 +24,6 @@ const ( MetricNameserverCount = "k8s_nameserver_resources" MetricRecorderCount = "k8s_recorder_resources" MetricEgressServiceCount = "k8s_egress_service_resources" - MetricProxyGroupCount = "k8s_proxygroup_resources" + MetricProxyGroupEgressCount = "k8s_proxygroup_egress_resources" + MetricProxyGroupIngressCount = "k8s_proxygroup_ingress_resources" ) From 3090461961e30fffb5a28b1432c47a627177a5a1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 7 Nov 2024 08:02:14 -0800 Subject: [PATCH 082/179] tsweb/varz: optimize some allocs, add helper func for others Updates #cleanup Updates tailscale/corp#23546 (noticed when doing this) Change-Id: Ia9f627fe32bb4955739b2787210ba18f5de27f4d Signed-off-by: Brad Fitzpatrick --- tsweb/varz/varz.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tsweb/varz/varz.go b/tsweb/varz/varz.go index 561b2487710e3..952ebc23134c2 100644 --- a/tsweb/varz/varz.go +++ b/tsweb/varz/varz.go @@ -23,10 +23,16 @@ import ( "tailscale.com/version" ) +// StaticStringVar returns a new expvar.Var that always returns s. +func StaticStringVar(s string) expvar.Var { + var v any = s // box s into an interface just once + return expvar.Func(func() any { return v }) +} + func init() { expvar.Publish("process_start_unix_time", expvar.Func(func() any { return timeStart.Unix() })) - expvar.Publish("version", expvar.Func(func() any { return version.Long() })) - expvar.Publish("go_version", expvar.Func(func() any { return runtime.Version() })) + expvar.Publish("version", StaticStringVar(version.Long())) + expvar.Publish("go_version", StaticStringVar(runtime.Version())) expvar.Publish("counter_uptime_sec", expvar.Func(func() any { return int64(Uptime().Seconds()) })) expvar.Publish("gauge_goroutines", expvar.Func(func() any { return runtime.NumGoroutine() })) } From 2c8859c2e725af2de59203c0b2d39b96f135cb60 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Thu, 7 Nov 2024 19:27:53 +0000 Subject: [PATCH 083/179] client/tailscale,ipn/{ipnlocal,localapi}: add a pre-shutdown localAPI endpoint that terminates control connections. (#14028) Adds a /disconnect-control local API endpoint that just shuts down control client. This can be run before shutting down an HA subnet router/app connector replica - it will ensure that all connection to control are dropped and control thus considers this node inactive and tells peers to switch over to another replica. Meanwhile the existing connections keep working (assuming that the replica is given some graceful shutdown period). Updates tailscale/tailscale#14020 Signed-off-by: Irbe Krumina --- client/tailscale/localclient.go | 11 +++++++++++ ipn/ipnlocal/local.go | 13 +++++++++++++ ipn/localapi/localapi.go | 17 +++++++++++++++++ 3 files changed, 41 insertions(+) diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 9c2bcc467b0e2..5eb66817698b7 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -1327,6 +1327,17 @@ func (lc *LocalClient) SetServeConfig(ctx context.Context, config *ipn.ServeConf return nil } +// DisconnectControl shuts down all connections to control, thus making control consider this node inactive. This can be +// run on HA subnet router or app connector replicas before shutting them down to ensure peers get told to switch over +// to another replica whilst there is still some grace period for the existing connections to terminate. +func (lc *LocalClient) DisconnectControl(ctx context.Context) error { + _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/disconnect-control", 200, nil, nil) + if err != nil { + return fmt.Errorf("error disconnecting control: %w", err) + } + return nil +} + // NetworkLockDisable shuts down network-lock across the tailnet. func (lc *LocalClient) NetworkLockDisable(ctx context.Context, secret []byte) error { if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/disable", 200, bytes.NewReader(secret)); err != nil { diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index edd56f7c452f5..337fa3d2b829a 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -800,6 +800,19 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { b.cc.SetPaused((b.state == ipn.Stopped && b.netMap != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest())) } +// DisconnectControl shuts down control client. This can be run before node shutdown to force control to consider this ndoe +// inactive. This can be used to ensure that nodes that are HA subnet router or app connector replicas are shutting +// down, clients switch over to other replicas whilst the existing connections are kept alive for some period of time. +func (b *LocalBackend) DisconnectControl() { + b.mu.Lock() + defer b.mu.Unlock() + cc := b.resetControlClientLocked() + if cc == nil { + return + } + cc.Shutdown() +} + // captivePortalDetectionInterval is the duration to wait in an unhealthy state with connectivity broken // before running captive portal detection. const captivePortalDetectionInterval = 2 * time.Second diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 0d41725d83dbe..dc8c089758371 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -100,6 +100,7 @@ var handler = map[string]localAPIHandler{ "derpmap": (*Handler).serveDERPMap, "dev-set-state-store": (*Handler).serveDevSetStateStore, "dial": (*Handler).serveDial, + "disconnect-control": (*Handler).disconnectControl, "dns-osconfig": (*Handler).serveDNSOSConfig, "dns-query": (*Handler).serveDNSQuery, "drive/fileserver-address": (*Handler).serveDriveServerAddr, @@ -952,6 +953,22 @@ func (h *Handler) servePprof(w http.ResponseWriter, r *http.Request) { servePprofFunc(w, r) } +// disconnectControl is the handler for local API /disconnect-control endpoint that shuts down control client, so that +// node no longer communicates with control. Doing this makes control consider this node inactive. This can be used +// before shutting down a replica of HA subnet router or app connector deployments to ensure that control tells the +// peers to switch over to another replica whilst still maintaining th existing peer connections. +func (h *Handler) disconnectControl(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + h.b.DisconnectControl() +} + func (h *Handler) reloadConfig(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { http.Error(w, "access denied", http.StatusForbidden) From 23880eb5b05368d30023f91c314c9cc2e19f4a90 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 7 Nov 2024 15:21:44 -0800 Subject: [PATCH 084/179] cmd/tailscaled: support "ts_omit_ssh" build tag to remove SSH Some environments would like to remove Tailscale SSH support for the binary for various reasons when not needed (either for peace of mind, or the ~1MB of binary space savings). Updates tailscale/corp#24454 Updates #1278 Updates #12614 Change-Id: Iadd6c5a393992c254b5dc9aa9a526916f96fd07a Signed-off-by: Brad Fitzpatrick --- cmd/tailscaled/deps_test.go | 30 ++++++++++++++++++++++++++++++ cmd/tailscaled/ssh.go | 2 +- tstest/deptest/deptest.go | 3 ++- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 cmd/tailscaled/deps_test.go diff --git a/cmd/tailscaled/deps_test.go b/cmd/tailscaled/deps_test.go new file mode 100644 index 0000000000000..2b4bc280d26cf --- /dev/null +++ b/cmd/tailscaled/deps_test.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestOmitSSH(t *testing.T) { + const msg = "unexpected with ts_omit_ssh" + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_ssh", + BadDeps: map[string]string{ + "tailscale.com/ssh/tailssh": msg, + "golang.org/x/crypto/ssh": msg, + "tailscale.com/sessionrecording": msg, + "github.com/anmitsu/go-shlex": msg, + "github.com/creack/pty": msg, + "github.com/kr/fs": msg, + "github.com/pkg/sftp": msg, + "github.com/u-root/u-root/pkg/termios": msg, + "tempfork/gliderlabs/ssh": msg, + }, + }.Check(t) +} diff --git a/cmd/tailscaled/ssh.go b/cmd/tailscaled/ssh.go index f7b0b367ead57..b10a3b7748719 100644 --- a/cmd/tailscaled/ssh.go +++ b/cmd/tailscaled/ssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || darwin || freebsd || openbsd +//go:build (linux || darwin || freebsd || openbsd) && !ts_omit_ssh package main diff --git a/tstest/deptest/deptest.go b/tstest/deptest/deptest.go index 57db2b79aa3c7..ba214de32398b 100644 --- a/tstest/deptest/deptest.go +++ b/tstest/deptest/deptest.go @@ -21,6 +21,7 @@ type DepChecker struct { GOOS string // optional GOARCH string // optional BadDeps map[string]string // package => why + Tags string // comma-separated } func (c DepChecker) Check(t *testing.T) { @@ -29,7 +30,7 @@ func (c DepChecker) Check(t *testing.T) { t.Skip("skipping dep tests on windows hosts") } t.Helper() - cmd := exec.Command("go", "list", "-json", ".") + cmd := exec.Command("go", "list", "-json", "-tags="+c.Tags, ".") var extraEnv []string if c.GOOS != "" { extraEnv = append(extraEnv, "GOOS="+c.GOOS) From c3306bfd15e761e0ad38e3e3970becd0d301e4c7 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 7 Nov 2024 15:59:19 -0800 Subject: [PATCH 085/179] control/controlhttp/controlhttpserver: split out Accept to its own package Otherwise all the clients only using control/controlhttp for the ts2021 HTTP client were also pulling in WebSocket libraries, as the server side always needs to speak websockets, but only GOOS=js clients speak it. This doesn't yet totally remove the websocket dependency on Linux because Linux has a envknob opt-in to act like GOOS=js for manual testing and force the use of WebSockets for DERP only (not control). We can put that behind a build tag in a future change to eliminate the dep on all GOOSes. Updates #1278 Change-Id: I4f60508f4cad52bf8c8943c8851ecee506b7ebc9 Signed-off-by: Brad Fitzpatrick --- cmd/k8s-operator/depaware.txt | 11 +++++----- cmd/tailscale/depaware.txt | 11 +++++----- cmd/tailscaled/depaware.txt | 11 +++++----- control/controlclient/noise_test.go | 4 ++-- control/controlhttp/client.go | 9 +++++---- control/controlhttp/constants.go | 9 --------- .../controlhttpcommon/controlhttpcommon.go | 15 ++++++++++++++ .../controlhttpserver.go} | 16 ++++++++------- control/controlhttp/http_test.go | 20 ++++++++++++++++--- tstest/integration/testcontrol/testcontrol.go | 4 ++-- 10 files changed, 68 insertions(+), 42 deletions(-) create mode 100644 control/controlhttp/controlhttpcommon/controlhttpcommon.go rename control/controlhttp/{server.go => controlhttpserver/controlhttpserver.go} (92%) diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 74536c6c9d050..cdd2ee722d25e 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -80,10 +80,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus github.com/bits-and-blooms/bitset from github.com/gaissmai/bart 💣 github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus - github.com/coder/websocket from tailscale.com/control/controlhttp+ - github.com/coder/websocket/internal/errd from github.com/coder/websocket - github.com/coder/websocket/internal/util from github.com/coder/websocket - github.com/coder/websocket/internal/xsync from github.com/coder/websocket + L github.com/coder/websocket from tailscale.com/derp/derphttp+ + L github.com/coder/websocket/internal/errd from github.com/coder/websocket + L github.com/coder/websocket/internal/util from github.com/coder/websocket + L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw 💣 github.com/davecgh/go-spew/spew from k8s.io/apimachinery/pkg/util/dump W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ @@ -658,6 +658,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+ tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ tailscale.com/derp from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ @@ -740,7 +741,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/tsd+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/omit from tailscale.com/ipn/conffile tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index ac5440d2cfabf..60af1de01c280 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -5,10 +5,10 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/coder/websocket from tailscale.com/control/controlhttp+ - github.com/coder/websocket/internal/errd from github.com/coder/websocket - github.com/coder/websocket/internal/util from github.com/coder/websocket - github.com/coder/websocket/internal/xsync from github.com/coder/websocket + L github.com/coder/websocket from tailscale.com/derp/derphttp+ + L github.com/coder/websocket/internal/errd from github.com/coder/websocket + L github.com/coder/websocket/internal/util from github.com/coder/websocket + L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+ W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode @@ -86,6 +86,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/cmd/tailscale/cli/ffcomplete/internal from tailscale.com/cmd/tailscale/cli/ffcomplete tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlhttp from tailscale.com/cmd/tailscale/cli + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/net/portmapper tailscale.com/derp from tailscale.com/derp/derphttp tailscale.com/derp/derphttp from tailscale.com/net/netcheck @@ -124,7 +125,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/safesocket from tailscale.com/client/tailscale+ tailscale.com/syncs from tailscale.com/cmd/tailscale/cli+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 31a0cb67cb568..707c0c065e0a1 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -79,10 +79,10 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm github.com/bits-and-blooms/bitset from github.com/gaissmai/bart - github.com/coder/websocket from tailscale.com/control/controlhttp+ - github.com/coder/websocket/internal/errd from github.com/coder/websocket - github.com/coder/websocket/internal/util from github.com/coder/websocket - github.com/coder/websocket/internal/xsync from github.com/coder/websocket + L github.com/coder/websocket from tailscale.com/derp/derphttp+ + L github.com/coder/websocket/internal/errd from github.com/coder/websocket + L github.com/coder/websocket/internal/util from github.com/coder/websocket + L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ @@ -249,6 +249,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ tailscale.com/derp from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/cmd/tailscaled+ @@ -327,7 +328,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/omit from tailscale.com/ipn/conffile tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index f2627bd0a50fa..69a3a6a36551d 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -15,7 +15,7 @@ import ( "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/internal/noiseconn" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -201,7 +201,7 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) { return nil } - cbConn, err := controlhttp.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) + cbConn, err := controlhttpserver.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) if err != nil { up.logf("controlhttp: Accept: %v", err) return diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 7e5263e3317fe..9b1d5a1a598e7 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -38,6 +38,7 @@ import ( "time" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/dnscache" @@ -571,9 +572,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad Method: "POST", URL: u, Header: http.Header{ - "Upgrade": []string{upgradeHeaderValue}, - "Connection": []string{"upgrade"}, - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + "Upgrade": []string{controlhttpcommon.UpgradeHeaderValue}, + "Connection": []string{"upgrade"}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }, } req = req.WithContext(ctx) @@ -597,7 +598,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad return nil, fmt.Errorf("httptrace didn't provide a connection") } - if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { + if next := resp.Header.Get("Upgrade"); next != controlhttpcommon.UpgradeHeaderValue { resp.Body.Close() return nil, fmt.Errorf("server switched to unexpected protocol %q", next) } diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index ea1725e76d438..0b550acccf866 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -18,15 +18,6 @@ import ( ) const ( - // upgradeHeader is the value of the Upgrade HTTP header used to - // indicate the Tailscale control protocol. - upgradeHeaderValue = "tailscale-control-protocol" - - // handshakeHeaderName is the HTTP request header that can - // optionally contain base64-encoded initial handshake - // payload, to save an RTT. - handshakeHeaderName = "X-Tailscale-Handshake" - // serverUpgradePath is where the server-side HTTP handler to // to do the protocol switch is located. serverUpgradePath = "/ts2021" diff --git a/control/controlhttp/controlhttpcommon/controlhttpcommon.go b/control/controlhttp/controlhttpcommon/controlhttpcommon.go new file mode 100644 index 0000000000000..a86b7ca04a7f4 --- /dev/null +++ b/control/controlhttp/controlhttpcommon/controlhttpcommon.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlhttpcommon contains common constants for used +// by the controlhttp client and controlhttpserver packages. +package controlhttpcommon + +// UpgradeHeader is the value of the Upgrade HTTP header used to +// indicate the Tailscale control protocol. +const UpgradeHeaderValue = "tailscale-control-protocol" + +// handshakeHeaderName is the HTTP request header that can +// optionally contain base64-encoded initial handshake +// payload, to save an RTT. +const HandshakeHeaderName = "X-Tailscale-Handshake" diff --git a/control/controlhttp/server.go b/control/controlhttp/controlhttpserver/controlhttpserver.go similarity index 92% rename from control/controlhttp/server.go rename to control/controlhttp/controlhttpserver/controlhttpserver.go index 7c3dd5618c4a3..47f049c180437 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/controlhttpserver/controlhttpserver.go @@ -3,7 +3,8 @@ //go:build !ios -package controlhttp +// Packet controlhttpserver contains the HTTP server side of the ts2021 control protocol. +package controlhttpserver import ( "context" @@ -18,6 +19,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/netutil" "tailscale.com/net/wsconn" "tailscale.com/types/key" @@ -45,12 +47,12 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri if next == "websocket" { return acceptWebsocket(ctx, w, r, private) } - if next != upgradeHeaderValue { + if next != controlhttpcommon.UpgradeHeaderValue { http.Error(w, "unknown next protocol", http.StatusBadRequest) return nil, fmt.Errorf("client requested unhandled next protocol %q", next) } - initB64 := r.Header.Get(handshakeHeaderName) + initB64 := r.Header.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { http.Error(w, "missing Tailscale handshake header", http.StatusBadRequest) return nil, errors.New("no tailscale handshake header in HTTP request") @@ -67,7 +69,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) @@ -117,7 +119,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri // speak HTTP) to a Tailscale control protocol base transport connection. func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, OriginPatterns: []string{"*"}, // Disable compression because we transmit Noise messages that are not // compressible. @@ -129,7 +131,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request if err != nil { return nil, fmt.Errorf("Could not accept WebSocket connection %v", err) } - if c.Subprotocol() != upgradeHeaderValue { + if c.Subprotocol() != controlhttpcommon.UpgradeHeaderValue { c.Close(websocket.StatusPolicyViolation, "client must speak the control subprotocol") return nil, fmt.Errorf("Unexpected subprotocol %q", c.Subprotocol()) } @@ -137,7 +139,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request c.Close(websocket.StatusPolicyViolation, "Could not parse parameters") return nil, fmt.Errorf("parse query parameters: %v", err) } - initB64 := r.Form.Get(handshakeHeaderName) + initB64 := r.Form.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { c.Close(websocket.StatusPolicyViolation, "missing Tailscale handshake parameter") return nil, errors.New("no tailscale handshake parameter in HTTP request") diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 8c8ed7f5701b0..00cc1e6cfd80b 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -23,12 +23,15 @@ import ( "time" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/tstest/deptest" "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -158,7 +161,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { return err } } - conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) if err != nil { log.Print(err) } @@ -529,7 +532,7 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) w.(http.Flusher).Flush() @@ -574,7 +577,7 @@ func TestDialPlan(t *testing.T) { close(done) }) var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, nil) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { log.Print(err) } else { @@ -816,3 +819,14 @@ func (c *closeTrackConn) Close() error { c.d.noteClose(c) return c.Conn.Close() } + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + GOOS: "darwin", + GOARCH: "arm64", + BadDeps: map[string]string{ + // Only the controlhttpserver needs WebSockets... + "github.com/coder/websocket": "controlhttp client shouldn't need websockets", + }, + }.Check(t) +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index bbcf277d171e1..2d6a843618627 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -26,7 +26,7 @@ import ( "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -288,7 +288,7 @@ func (s *Server) serveNoiseUpgrade(w http.ResponseWriter, r *http.Request) { s.mu.Lock() noisePrivate := s.noisePrivKey s.mu.Unlock() - cc, err := controlhttp.AcceptHTTP(ctx, w, r, noisePrivate, nil) + cc, err := controlhttpserver.AcceptHTTP(ctx, w, r, noisePrivate, nil) if err != nil { log.Printf("AcceptHTTP: %v", err) return From 020cacbe702463f14a5d2d5427819c491c7e6578 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 7 Nov 2024 16:49:47 -0800 Subject: [PATCH 086/179] derp/derphttp: don't link websockets other than on GOOS=js Or unless the new "ts_debug_websockets" build tag is set. Updates #1278 Change-Id: Ic4c4f81c1924250efd025b055585faec37a5491d Signed-off-by: Brad Fitzpatrick --- cmd/derper/depaware.txt | 2 +- cmd/k8s-operator/depaware.txt | 5 ----- cmd/tailscale/depaware.txt | 7 +----- cmd/tailscaled/depaware.txt | 5 ----- control/controlhttp/client_js.go | 5 +++-- .../controlhttpserver/controlhttpserver.go | 2 +- derp/derphttp/derphttp_client.go | 5 ++++- derp/derphttp/derphttp_test.go | 22 +++++++++++++++++++ derp/derphttp/websocket.go | 4 +++- derp/derphttp/websocket_stub.go | 8 +++++++ tstest/deptest/deptest.go | 17 ++++++++++---- 11 files changed, 56 insertions(+), 26 deletions(-) create mode 100644 derp/derphttp/websocket_stub.go diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 8fa5334aa7ebe..81a7f14f4a71c 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -116,7 +116,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/ipn+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ - tailscale.com/net/wsconn from tailscale.com/cmd/derper+ + tailscale.com/net/wsconn from tailscale.com/cmd/derper tailscale.com/paths from tailscale.com/client/tailscale 💣 tailscale.com/safesocket from tailscale.com/client/tailscale tailscale.com/syncs from tailscale.com/cmd/derper+ diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index cdd2ee722d25e..900d10efedc99 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -80,10 +80,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus github.com/bits-and-blooms/bitset from github.com/gaissmai/bart 💣 github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus - L github.com/coder/websocket from tailscale.com/derp/derphttp+ - L github.com/coder/websocket/internal/errd from github.com/coder/websocket - L github.com/coder/websocket/internal/util from github.com/coder/websocket - L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw 💣 github.com/davecgh/go-spew/spew from k8s.io/apimachinery/pkg/util/dump W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ @@ -741,7 +737,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/tsd+ - L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/omit from tailscale.com/ipn/conffile tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 60af1de01c280..d18d8887327fa 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -5,10 +5,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - L github.com/coder/websocket from tailscale.com/derp/derphttp+ - L github.com/coder/websocket/internal/errd from github.com/coder/websocket - L github.com/coder/websocket/internal/util from github.com/coder/websocket - L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+ W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode @@ -125,7 +121,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ - L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/safesocket from tailscale.com/client/tailscale+ tailscale.com/syncs from tailscale.com/cmd/tailscale/cli+ @@ -326,7 +321,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep reflect from archive/tar+ regexp from github.com/coreos/go-iptables/iptables+ regexp/syntax from regexp - runtime/debug from github.com/coder/websocket/internal/xsync+ + runtime/debug from tailscale.com+ slices from tailscale.com/client/web+ sort from compress/flate+ strconv from archive/tar+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 707c0c065e0a1..81cd53271cf9e 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -79,10 +79,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm github.com/bits-and-blooms/bitset from github.com/gaissmai/bart - L github.com/coder/websocket from tailscale.com/derp/derphttp+ - L github.com/coder/websocket/internal/errd from github.com/coder/websocket - L github.com/coder/websocket/internal/util from github.com/coder/websocket - L github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ @@ -328,7 +324,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ - L tailscale.com/net/wsconn from tailscale.com/derp/derphttp tailscale.com/omit from tailscale.com/ipn/conffile tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index 4b7126b52cf38..cc05b5b192766 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -12,6 +12,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/wsconn" ) @@ -42,11 +43,11 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { // Can't set HTTP headers on the websocket request, so we have to to send // the handshake via an HTTP header. RawQuery: url.Values{ - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }.Encode(), } wsConn, _, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, }) if err != nil { return nil, err diff --git a/control/controlhttp/controlhttpserver/controlhttpserver.go b/control/controlhttp/controlhttpserver/controlhttpserver.go index 47f049c180437..af320781069d1 100644 --- a/control/controlhttp/controlhttpserver/controlhttpserver.go +++ b/control/controlhttp/controlhttpserver/controlhttpserver.go @@ -3,7 +3,7 @@ //go:build !ios -// Packet controlhttpserver contains the HTTP server side of the ts2021 control protocol. +// Package controlhttpserver contains the HTTP server side of the ts2021 control protocol. package controlhttpserver import ( diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index b695a52a89606..c95d072b1a572 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -313,6 +313,9 @@ func (c *Client) preferIPv6() bool { var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error) func useWebsockets() bool { + if !canWebsockets { + return false + } if runtime.GOOS == "js" { return true } @@ -383,7 +386,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien var node *tailcfg.DERPNode // nil when using c.url to dial var idealNodeInRegion bool switch { - case useWebsockets(): + case canWebsockets && useWebsockets(): var urlStr string if c.url != nil { urlStr = c.url.String() diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index cfb3676cda16f..cf6032a5e6d43 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -17,7 +17,9 @@ import ( "tailscale.com/derp" "tailscale.com/net/netmon" + "tailscale.com/tstest/deptest" "tailscale.com/types/key" + "tailscale.com/util/set" ) func TestSendRecv(t *testing.T) { @@ -485,3 +487,23 @@ func TestProbe(t *testing.T) { } } } + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + GOOS: "darwin", + GOARCH: "arm64", + BadDeps: map[string]string{ + "github.com/coder/websocket": "shouldn't link websockets except on js/wasm", + }, + }.Check(t) + + deptest.DepChecker{ + GOOS: "darwin", + GOARCH: "arm64", + Tags: "ts_debug_websockets", + WantDeps: set.Of( + "github.com/coder/websocket", + ), + }.Check(t) + +} diff --git a/derp/derphttp/websocket.go b/derp/derphttp/websocket.go index 6ef47473a2532..9dd640ee37083 100644 --- a/derp/derphttp/websocket.go +++ b/derp/derphttp/websocket.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || js +//go:build js || ((linux || darwin) && ts_debug_websockets) package derphttp @@ -14,6 +14,8 @@ import ( "tailscale.com/net/wsconn" ) +const canWebsockets = true + func init() { dialWebsocketFunc = dialWebsocket } diff --git a/derp/derphttp/websocket_stub.go b/derp/derphttp/websocket_stub.go new file mode 100644 index 0000000000000..d84bfba571f80 --- /dev/null +++ b/derp/derphttp/websocket_stub.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(js || ((linux || darwin) && ts_debug_websockets)) + +package derphttp + +const canWebsockets = false diff --git a/tstest/deptest/deptest.go b/tstest/deptest/deptest.go index ba214de32398b..00faa8a386db8 100644 --- a/tstest/deptest/deptest.go +++ b/tstest/deptest/deptest.go @@ -13,15 +13,19 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strings" "testing" + + "tailscale.com/util/set" ) type DepChecker struct { - GOOS string // optional - GOARCH string // optional - BadDeps map[string]string // package => why - Tags string // comma-separated + GOOS string // optional + GOARCH string // optional + BadDeps map[string]string // package => why + WantDeps set.Set[string] // packages expected + Tags string // comma-separated } func (c DepChecker) Check(t *testing.T) { @@ -55,6 +59,11 @@ func (c DepChecker) Check(t *testing.T) { t.Errorf("package %q is not allowed as a dependency (env: %q); reason: %s", dep, extraEnv, why) } } + for dep := range c.WantDeps { + if !slices.Contains(res.Deps, dep) { + t.Errorf("expected package %q to be a dependency (env: %q)", dep, extraEnv) + } + } t.Logf("got %d dependencies", len(res.Deps)) } From 64d70fb718557f73a3cebdc41558405697b913ec Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Fri, 8 Nov 2024 13:21:38 +0000 Subject: [PATCH 087/179] ipn/ipnlocal: log a summary of posture identity response Perhaps I was too opimistic in #13323 thinking we won't need logs for this. Let's log a summary of the response without logging specific identifiers. Updates tailscale/corp#24437 Signed-off-by: Anton Tolchanov --- ipn/ipnlocal/c2n.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index c3ed32fd89bd5..8380689d1f066 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -350,6 +350,8 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http res.PostureDisabled = true } + b.logf("c2n: posture identity disabled=%v reported %d serials %d hwaddrs", res.PostureDisabled, len(res.SerialNumbers), len(res.IfaceHardwareAddrs)) + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(res) } From 6ff85846bcb5c8aeb35e2fa36808366ec4f148fb Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Fri, 8 Nov 2024 10:02:16 -0800 Subject: [PATCH 088/179] safeweb: add a Shutdown method to the Server type (#14048) Updates #14047 Change-Id: I2d20454c715b11ad9c6aad1d81445e05a170c3a2 Signed-off-by: M. J. Fromberger --- safeweb/http.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/safeweb/http.go b/safeweb/http.go index bd53eca5bbfd9..983ff2fad8031 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -71,6 +71,7 @@ package safeweb import ( "cmp" + "context" crand "crypto/rand" "fmt" "log" @@ -416,3 +417,7 @@ func (s *Server) ListenAndServe(addr string) error { func (s *Server) Close() error { return s.h.Close() } + +// Shutdown gracefully shuts down the server without interrupting any active +// connections. It has the same semantics as[http.Server.Shutdown]. +func (s *Server) Shutdown(ctx context.Context) error { return s.h.Shutdown(ctx) } From b9ecc50ce38d23cfacd5fae6360fa9742c0564a6 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Mon, 11 Nov 2024 11:43:54 +0000 Subject: [PATCH 089/179] cmd/k8s-operator,k8s-operator,kube/kubetypes: add an option to configure app connector via Connector spec (#13950) * cmd/k8s-operator,k8s-operator,kube/kubetypes: add an option to configure app connector via Connector spec Updates tailscale/tailscale#11113 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/connector.go | 67 ++++++++++--- cmd/k8s-operator/connector_test.go | 99 +++++++++++++++++++ .../deploy/crds/tailscale.com_connectors.yaml | 53 ++++++++-- .../deploy/manifests/operator.yaml | 53 ++++++++-- cmd/k8s-operator/operator_test.go | 4 +- cmd/k8s-operator/sts.go | 22 ++++- cmd/k8s-operator/testutils_test.go | 34 ++++++- k8s-operator/api.md | 23 ++++- k8s-operator/apis/v1alpha1/types_connector.go | 46 +++++++-- .../apis/v1alpha1/zz_generated.deepcopy.go | 25 +++++ kube/kubetypes/metrics.go | 1 + 11 files changed, 381 insertions(+), 46 deletions(-) diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 016166b4cda29..1c1df7c962b91 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -13,7 +13,8 @@ import ( "sync" "time" - "github.com/pkg/errors" + "errors" + "go.uber.org/zap" xslices "golang.org/x/exp/slices" corev1 "k8s.io/api/core/v1" @@ -58,6 +59,7 @@ type ConnectorReconciler struct { subnetRouters set.Slice[types.UID] // for subnet routers gauge exitNodes set.Slice[types.UID] // for exit nodes gauge + appConnectors set.Slice[types.UID] // for app connectors gauge } var ( @@ -67,6 +69,8 @@ var ( gaugeConnectorSubnetRouterResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithSubnetRouterCount) // gaugeConnectorExitNodeResources tracks the number of Connectors currently managed by this operator instance that are exit nodes. gaugeConnectorExitNodeResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithExitNodeCount) + // gaugeConnectorAppConnectorResources tracks the number of Connectors currently managed by this operator instance that are app connectors. + gaugeConnectorAppConnectorResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithAppConnectorCount) ) func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { @@ -108,13 +112,12 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque oldCnStatus := cn.Status.DeepCopy() setStatus := func(cn *tsapi.Connector, _ tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetConnectorCondition(cn, tsapi.ConnectorReady, status, reason, message, cn.Generation, a.clock, logger) + var updateErr error if !apiequality.Semantic.DeepEqual(oldCnStatus, cn.Status) { // An error encountered here should get returned by the Reconcile function. - if updateErr := a.Client.Status().Update(ctx, cn); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) - } + updateErr = a.Client.Status().Update(ctx, cn) } - return res, err + return res, errors.Join(err, updateErr) } if !slices.Contains(cn.Finalizers, FinalizerName) { @@ -150,6 +153,9 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque cn.Status.SubnetRoutes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } + if cn.Spec.AppConnector != nil { + cn.Status.IsAppConnector = true + } cn.Status.SubnetRoutes = "" return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } @@ -189,23 +195,37 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge sts.Connector.routes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() } + if cn.Spec.AppConnector != nil { + sts.Connector.isAppConnector = true + if len(cn.Spec.AppConnector.Routes) != 0 { + sts.Connector.routes = cn.Spec.AppConnector.Routes.Stringify() + } + } + a.mu.Lock() - if sts.Connector.isExitNode { + if cn.Spec.ExitNode { a.exitNodes.Add(cn.UID) } else { a.exitNodes.Remove(cn.UID) } - if sts.Connector.routes != "" { + if cn.Spec.SubnetRouter != nil { a.subnetRouters.Add(cn.GetUID()) } else { a.subnetRouters.Remove(cn.GetUID()) } + if cn.Spec.AppConnector != nil { + a.appConnectors.Add(cn.GetUID()) + } else { + a.appConnectors.Remove(cn.GetUID()) + } a.mu.Unlock() gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) _, err := a.ssr.Provision(ctx, logger, sts) @@ -248,12 +268,15 @@ func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger a.mu.Lock() a.subnetRouters.Remove(cn.UID) a.exitNodes.Remove(cn.UID) + a.appConnectors.Remove(cn.UID) a.mu.Unlock() gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) return true, nil } @@ -262,8 +285,14 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { // Connector fields are already validated at apply time with CEL validation // on custom resource fields. The checks here are a backup in case the // CEL validation breaks without us noticing. - if !(cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) { - return errors.New("invalid spec: a Connector must expose subnet routes or act as an exit node (or both)") + if cn.Spec.SubnetRouter == nil && !cn.Spec.ExitNode && cn.Spec.AppConnector == nil { + return errors.New("invalid spec: a Connector must be configured as at least one of subnet router, exit node or app connector") + } + if (cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) && cn.Spec.AppConnector != nil { + return errors.New("invalid spec: a Connector that is configured as an app connector must not be also configured as a subnet router or exit node") + } + if cn.Spec.AppConnector != nil { + return validateAppConnector(cn.Spec.AppConnector) } if cn.Spec.SubnetRouter == nil { return nil @@ -272,19 +301,27 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { } func validateSubnetRouter(sb *tsapi.SubnetRouter) error { - if len(sb.AdvertiseRoutes) < 1 { + if len(sb.AdvertiseRoutes) == 0 { return errors.New("invalid subnet router spec: no routes defined") } - var err error - for _, route := range sb.AdvertiseRoutes { + return validateRoutes(sb.AdvertiseRoutes) +} + +func validateAppConnector(ac *tsapi.AppConnector) error { + return validateRoutes(ac.Routes) +} + +func validateRoutes(routes tsapi.Routes) error { + var errs []error + for _, route := range routes { pfx, e := netip.ParsePrefix(string(route)) if e != nil { - err = errors.Wrap(err, fmt.Sprintf("route %s is invalid: %v", route, err)) + errs = append(errs, fmt.Errorf("route %v is invalid: %v", route, e)) continue } if pfx.Masked() != pfx { - err = errors.Wrap(err, fmt.Sprintf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) + errs = append(errs, fmt.Errorf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) } } - return err + return errors.Join(errs...) } diff --git a/cmd/k8s-operator/connector_test.go b/cmd/k8s-operator/connector_test.go index a4ba90d3d6683..7cdd83115e877 100644 --- a/cmd/k8s-operator/connector_test.go +++ b/cmd/k8s-operator/connector_test.go @@ -8,12 +8,14 @@ package main import ( "context" "testing" + "time" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" @@ -296,3 +298,100 @@ func TestConnectorWithProxyClass(t *testing.T) { expectReconciled(t, cr, "", "test") expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) } + +func TestConnectorWithAppConnector(t *testing.T) { + // Setup + cn := &tsapi.Connector{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + UID: types.UID("1234-UID"), + }, + TypeMeta: metav1.TypeMeta{ + Kind: tsapi.ConnectorKind, + APIVersion: "tailscale.io/v1alpha1", + }, + Spec: tsapi.ConnectorSpec{ + AppConnector: &tsapi.AppConnector{}, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(cn). + WithStatusSubresource(cn). + Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + fr := record.NewFakeRecorder(1) + cr := &ConnectorReconciler{ + Client: fc, + clock: cl, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + recorder: fr, + } + + // 1. Connector with app connnector is created and becomes ready + expectReconciled(t, cr, "", "test") + fullName, shortName := findGenName(t, fc, "", "test", "connector") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + parentType: "connector", + hostname: "test-connector", + app: kubetypes.AppConnector, + isAppConnector: true, + } + expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + // Connector's ready condition should be set to true + + cn.ObjectMeta.Finalizers = append(cn.ObjectMeta.Finalizers, "tailscale.com/finalizer") + cn.Status.IsAppConnector = true + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectEqual(t, fc, cn, nil) + + // 2. Connector with invalid app connector routes has status set to invalid + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("1.2.3.4/5")} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("1.2.3.4/5")} + expectReconciled(t, cr, "", "test") + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionFalse, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorInvalid, + Message: "Connector is invalid: route 1.2.3.4/5 has non-address bits set; expected 0.0.0.0/5", + }} + expectEqual(t, fc, cn, nil) + + // 3. Connector with valid app connnector routes becomes ready + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("10.88.2.21/32")} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("10.88.2.21/32")} + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectReconciled(t, cr, "", "test") +} diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml index 9614f74e6b162..4434c12835ba1 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml @@ -24,6 +24,10 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status @@ -66,10 +70,40 @@ spec: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status type: object properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. + If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + type: object + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + type: array + minItems: 1 + items: + type: string + format: cidr exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -90,9 +124,11 @@ spec: type: string subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. type: object required: - advertiseRoutes @@ -125,8 +161,10 @@ spec: type: string pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ x-kubernetes-validations: - - rule: has(self.subnetRouter) || self.exitNode == true - message: A Connector needs to be either an exit node or a subnet router, or both. + - rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + - rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' + message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -200,6 +238,9 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. type: boolean diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 203a670664968..9d8e9faf60816 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -53,6 +53,10 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status @@ -91,10 +95,40 @@ spec: More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. + If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + items: + format: cidr + type: string + minItems: 1 + type: array + type: object exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -115,9 +149,11 @@ spec: type: string subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. properties: advertiseRoutes: description: |- @@ -151,8 +187,10 @@ spec: type: array type: object x-kubernetes-validations: - - message: A Connector needs to be either an exit node or a subnet router, or both. - rule: has(self.subnetRouter) || self.exitNode == true + - message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + - message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. + rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -225,6 +263,9 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. type: boolean diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index a440fafb5cfc1..cc9927645621f 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -1388,7 +1388,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { parentType: "svc", hostname: "default-test", clusterTargetIP: "10.20.30.40", - confFileHash: "e09bededa0379920141cbd0b0dbdf9b8b66545877f9e8397423f5ce3e1ba439e", + confFileHash: "362360188dac62bca8013c8134929fed8efd84b1f410c00873d14a05709b5647", app: kubetypes.AppIngressProxy, } expectEqual(t, fc, expectedSTS(t, fc, o), nil) @@ -1399,7 +1399,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { mak.Set(&svc.Annotations, AnnotationHostname, "another-test") }) o.hostname = "another-test" - o.confFileHash = "5d754cf55463135ee34aa9821f2fd8483b53eb0570c3740c84a086304f427684" + o.confFileHash = "20db57cfabc3fc6490f6bb1dc85994e61d255cdfa2a56abb0141736e59f263ef" expectReconciled(t, sr, "default", "test") expectEqual(t, fc, expectedSTS(t, fc, o), nil) } diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index e89b9c93082cf..b6467b7981b67 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -132,10 +132,13 @@ type tailscaleSTSConfig struct { } type connector struct { - // routes is a list of subnet routes that this Connector should expose. + // routes is a list of routes that this Connector should advertise either as a subnet router or as an app + // connector. routes string // isExitNode defines whether this Connector should act as an exit node. isExitNode bool + // isAppConnector defines whether this Connector should act as an app connector. + isAppConnector bool } type tsnetServer interface { CertDomains() []string @@ -674,7 +677,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, } if stsCfg != nil && pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable { if stsCfg.TailnetTargetFQDN == "" && stsCfg.TailnetTargetIP == "" && !stsCfg.ForwardClusterTrafficViaL7IngressProxy { - enableMetrics(ss, pc) + enableMetrics(ss) } else if stsCfg.ForwardClusterTrafficViaL7IngressProxy { // TODO (irbekrm): fix this // For Ingress proxies that have been configured with @@ -763,7 +766,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } -func enableMetrics(ss *appsv1.StatefulSet, pc *tsapi.ProxyClass) { +func enableMetrics(ss *appsv1.StatefulSet) { for i, c := range ss.Spec.Template.Spec.Containers { if c.Name == "tailscale" { // Serve metrics on on :9001/debug/metrics. If @@ -803,11 +806,13 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co Locked: "false", Hostname: &stsC.Hostname, NoStatefulFiltering: "false", + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, } // For egress proxies only, we need to ensure that stateful filtering is // not in place so that traffic from cluster can be forwarded via // Tailscale IPs. + // TODO (irbekrm): set it to true always as this is now the default in core. if stsC.TailnetTargetFQDN != "" || stsC.TailnetTargetIP != "" { conf.NoStatefulFiltering = "true" } @@ -817,6 +822,9 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co return nil, fmt.Errorf("error calculating routes: %w", err) } conf.AdvertiseRoutes = routes + if stsC.Connector.isAppConnector { + conf.AppConnector.Advertise = true + } } if shouldAcceptRoutes(stsC.ProxyClass) { conf.AcceptRoutes = "true" @@ -831,9 +839,15 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co } conf.AuthKey = key } + capVerConfigs := make(map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) + capVerConfigs[107] = *conf + + // AppConnector config option is only understood by clients of capver 107 and newer. + conf.AppConnector = nil capVerConfigs[95] = *conf - // legacy config should not contain NoStatefulFiltering field. + + // StatefulFiltering is only understood by clients of capver 95 and newer. conf.NoStatefulFiltering.Clear() capVerConfigs[94] = *conf return capVerConfigs, nil diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 6b6297cbdd4fe..4b25d103c2a56 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -48,6 +48,7 @@ type configOpts struct { clusterTargetDNS string subnetRoutes string isExitNode bool + isAppConnector bool confFileHash string serveConfig *ipn.ServeConfig shouldEnableForwardingClusterTrafficViaIngress bool @@ -356,6 +357,7 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec Locked: "false", AuthKey: ptr.To("secret-authkey"), AcceptRoutes: "false", + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, } if opts.proxyClass != "" { t.Logf("applying configuration from ProxyClass %s", opts.proxyClass) @@ -370,6 +372,9 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec if opts.shouldRemoveAuthKey { conf.AuthKey = nil } + if opts.isAppConnector { + conf.AppConnector = &ipn.AppConnectorPrefs{Advertise: true} + } var routes []netip.Prefix if opts.subnetRoutes != "" || opts.isExitNode { r := opts.subnetRoutes @@ -384,22 +389,29 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec routes = append(routes, prefix) } } - conf.AdvertiseRoutes = routes - b, err := json.Marshal(conf) - if err != nil { - t.Fatalf("error marshalling tailscaled config") - } if opts.tailnetTargetFQDN != "" || opts.tailnetTargetIP != "" { conf.NoStatefulFiltering = "true" } else { conf.NoStatefulFiltering = "false" } + conf.AdvertiseRoutes = routes + bnn, err := json.Marshal(conf) + if err != nil { + t.Fatalf("error marshalling tailscaled config") + } + conf.AppConnector = nil bn, err := json.Marshal(conf) if err != nil { t.Fatalf("error marshalling tailscaled config") } + conf.NoStatefulFiltering.Clear() + b, err := json.Marshal(conf) + if err != nil { + t.Fatalf("error marshalling tailscaled config") + } mak.Set(&s.StringData, "tailscaled", string(b)) mak.Set(&s.StringData, "cap-95.hujson", string(bn)) + mak.Set(&s.StringData, "cap-107.hujson", string(bnn)) labels := map[string]string{ "tailscale.com/managed": "true", "tailscale.com/parent-resource": "test", @@ -674,5 +686,17 @@ func removeAuthKeyIfExistsModifier(t *testing.T) func(s *corev1.Secret) { } mak.Set(&secret.StringData, "cap-95.hujson", string(b)) } + if len(secret.StringData["cap-107.hujson"]) != 0 { + conf := &ipn.ConfigVAlpha{} + if err := json.Unmarshal([]byte(secret.StringData["cap-107.hujson"]), conf); err != nil { + t.Fatalf("error umarshalling 'cap-107.hujson' contents: %v", err) + } + conf.AuthKey = nil + b, err := json.Marshal(conf) + if err != nil { + t.Fatalf("error marshalling 'cap-107.huson' contents: %v", err) + } + mak.Set(&secret.StringData, "cap-107.hujson", string(b)) + } } } diff --git a/k8s-operator/api.md b/k8s-operator/api.md index dae969516b9e7..7b1aca3148e5b 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -21,6 +21,22 @@ +#### AppConnector + + + +AppConnector defines a Tailscale app connector node configured via Connector. + + + +_Appears in:_ +- [ConnectorSpec](#connectorspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `routes` _[Routes](#routes)_ | Routes are optional preconfigured routes for the domains routed via the app connector.
If not set, routes for the domains will be discovered dynamically.
If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may
also dynamically discover other routes.
https://tailscale.com/kb/1332/apps-best-practices#preconfiguration | | Format: cidr
MinItems: 1
Type: string
| + + #### Connector @@ -86,8 +102,9 @@ _Appears in:_ | `tags` _[Tags](#tags)_ | Tags that the Tailscale node will be tagged with.
Defaults to [tag:k8s].
To autoapprove the subnet routes or exit node defined by a Connector,
you can configure Tailscale ACLs to give these tags the necessary
permissions.
See https://tailscale.com/kb/1337/acl-syntax#autoapprovers.
If you specify custom tags here, you must also make the operator an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a Connector node has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| | `hostname` _[Hostname](#hostname)_ | Hostname is the tailnet hostname that should be assigned to the
Connector node. If unset, hostname defaults to name>-connector. Hostname can contain lower case letters, numbers and
dashes, it must not start or end with a dash and must be between 2
and 63 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$`
Type: string
| | `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that
contains configuration options that should be applied to the
resources created for this Connector. If unset, the operator will
create resources with the default configuration. | | | -| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector node should
expose to tailnet. If unset, none are exposed.
https://tailscale.com/kb/1019/subnets/ | | | -| `exitNode` _boolean_ | ExitNode defines whether the Connector node should act as a
Tailscale exit node. Defaults to false.
https://tailscale.com/kb/1103/exit-nodes | | | +| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector device should
expose to tailnet as a Tailscale subnet router.
https://tailscale.com/kb/1019/subnets/
If this field is unset, the device does not get configured as a Tailscale subnet router.
This field is mutually exclusive with the appConnector field. | | | +| `appConnector` _[AppConnector](#appconnector)_ | AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is
configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the
Connector does not act as an app connector.
Note that you will need to manually configure the permissions and the domains for the app connector via the
Admin panel.
Note also that the main tested and supported use case of this config option is to deploy an app connector on
Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose
cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have
tested or optimised for.
If you are using the app connector to access SaaS applications because you need a predictable egress IP that
can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows
via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT
device with a static IP address.
https://tailscale.com/kb/1281/app-connectors | | | +| `exitNode` _boolean_ | ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false.
This field is mutually exclusive with the appConnector field.
https://tailscale.com/kb/1103/exit-nodes | | | #### ConnectorStatus @@ -106,6 +123,7 @@ _Appears in:_ | `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#condition-v1-meta) array_ | List of status conditions to indicate the status of the Connector.
Known condition types are `ConnectorReady`. | | | | `subnetRoutes` _string_ | SubnetRoutes are the routes currently exposed to tailnet via this
Connector instance. | | | | `isExitNode` _boolean_ | IsExitNode is set to true if the Connector acts as an exit node. | | | +| `isAppConnector` _boolean_ | IsAppConnector is set to true if the Connector acts as an app connector. | | | | `tailnetIPs` _string array_ | TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6)
assigned to the Connector node. | | | | `hostname` _string_ | Hostname is the fully qualified domain name of the Connector node.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. | | | @@ -746,6 +764,7 @@ _Validation:_ - Type: string _Appears in:_ +- [AppConnector](#appconnector) - [SubnetRouter](#subnetrouter) diff --git a/k8s-operator/apis/v1alpha1/types_connector.go b/k8s-operator/apis/v1alpha1/types_connector.go index 27afd0838a388..0222584859bd6 100644 --- a/k8s-operator/apis/v1alpha1/types_connector.go +++ b/k8s-operator/apis/v1alpha1/types_connector.go @@ -22,6 +22,7 @@ var ConnectorKind = "Connector" // +kubebuilder:resource:scope=Cluster,shortName=cn // +kubebuilder:printcolumn:name="SubnetRoutes",type="string",JSONPath=`.status.subnetRoutes`,description="CIDR ranges exposed to tailnet by a subnet router defined via this Connector instance." // +kubebuilder:printcolumn:name="IsExitNode",type="string",JSONPath=`.status.isExitNode`,description="Whether this Connector instance defines an exit node." +// +kubebuilder:printcolumn:name="IsAppConnector",type="string",JSONPath=`.status.isAppConnector`,description="Whether this Connector instance is an app connector." // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ConnectorReady")].reason`,description="Status of the deployed Connector resources." // Connector defines a Tailscale node that will be deployed in the cluster. The @@ -55,7 +56,8 @@ type ConnectorList struct { } // ConnectorSpec describes a Tailscale node to be deployed in the cluster. -// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || self.exitNode == true",message="A Connector needs to be either an exit node or a subnet router, or both." +// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector)",message="A Connector needs to have at least one of exit node, subnet router or app connector configured." +// +kubebuilder:validation:XValidation:rule="!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))",message="The appConnector field is mutually exclusive with exitNode and subnetRouter fields." type ConnectorSpec struct { // Tags that the Tailscale node will be tagged with. // Defaults to [tag:k8s]. @@ -82,13 +84,31 @@ type ConnectorSpec struct { // create resources with the default configuration. // +optional ProxyClass string `json:"proxyClass,omitempty"` - // SubnetRouter defines subnet routes that the Connector node should - // expose to tailnet. If unset, none are exposed. + // SubnetRouter defines subnet routes that the Connector device should + // expose to tailnet as a Tailscale subnet router. // https://tailscale.com/kb/1019/subnets/ + // If this field is unset, the device does not get configured as a Tailscale subnet router. + // This field is mutually exclusive with the appConnector field. // +optional - SubnetRouter *SubnetRouter `json:"subnetRouter"` - // ExitNode defines whether the Connector node should act as a - // Tailscale exit node. Defaults to false. + SubnetRouter *SubnetRouter `json:"subnetRouter,omitempty"` + // AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + // configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + // Connector does not act as an app connector. + // Note that you will need to manually configure the permissions and the domains for the app connector via the + // Admin panel. + // Note also that the main tested and supported use case of this config option is to deploy an app connector on + // Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + // cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + // tested or optimised for. + // If you are using the app connector to access SaaS applications because you need a predictable egress IP that + // can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + // via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + // device with a static IP address. + // https://tailscale.com/kb/1281/app-connectors + // +optional + AppConnector *AppConnector `json:"appConnector,omitempty"` + // ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + // This field is mutually exclusive with the appConnector field. // https://tailscale.com/kb/1103/exit-nodes // +optional ExitNode bool `json:"exitNode"` @@ -104,6 +124,17 @@ type SubnetRouter struct { AdvertiseRoutes Routes `json:"advertiseRoutes"` } +// AppConnector defines a Tailscale app connector node configured via Connector. +type AppConnector struct { + // Routes are optional preconfigured routes for the domains routed via the app connector. + // If not set, routes for the domains will be discovered dynamically. + // If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + // also dynamically discover other routes. + // https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + // +optional + Routes Routes `json:"routes"` +} + type Tags []Tag func (tags Tags) Stringify() []string { @@ -156,6 +187,9 @@ type ConnectorStatus struct { // IsExitNode is set to true if the Connector acts as an exit node. // +optional IsExitNode bool `json:"isExitNode"` + // IsAppConnector is set to true if the Connector acts as an app connector. + // +optional + IsAppConnector bool `json:"isAppConnector"` // TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) // assigned to the Connector node. // +optional diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index f53165b886ec2..c2f69dc045314 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -13,6 +13,26 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *AppConnector) DeepCopyInto(out *AppConnector) { + *out = *in + if in.Routes != nil { + in, out := &in.Routes, &out.Routes + *out = make(Routes, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AppConnector. +func (in *AppConnector) DeepCopy() *AppConnector { + if in == nil { + return nil + } + out := new(AppConnector) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Connector) DeepCopyInto(out *Connector) { *out = *in @@ -85,6 +105,11 @@ func (in *ConnectorSpec) DeepCopyInto(out *ConnectorSpec) { *out = new(SubnetRouter) (*in).DeepCopyInto(*out) } + if in.AppConnector != nil { + in, out := &in.AppConnector, &out.AppConnector + *out = new(AppConnector) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ConnectorSpec. diff --git a/kube/kubetypes/metrics.go b/kube/kubetypes/metrics.go index 63078385ad293..63325182d29c4 100644 --- a/kube/kubetypes/metrics.go +++ b/kube/kubetypes/metrics.go @@ -21,6 +21,7 @@ const ( MetricConnectorResourceCount = "k8s_connector_resources" MetricConnectorWithSubnetRouterCount = "k8s_connector_subnetrouter_resources" MetricConnectorWithExitNodeCount = "k8s_connector_exitnode_resources" + MetricConnectorWithAppConnectorCount = "k8s_connector_appconnector_resources" MetricNameserverCount = "k8s_nameserver_resources" MetricRecorderCount = "k8s_recorder_resources" MetricEgressServiceCount = "k8s_egress_service_resources" From 00be1761b76671635b478a20187d83b166991924 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 11 Nov 2024 08:48:09 -0800 Subject: [PATCH 090/179] util/codegen: treat unique.Handle as an opaque value type It doesn't need a Clone method, like a time.Time, etc. And then, because Go 1.23+ uses unique.Handle internally for the netip package types, we can remove those special cases. Updates #14058 (pulled out from that PR) Updates tailscale/corp#24485 Change-Id: Iac3548a9417ccda5987f98e0305745a6e178b375 Signed-off-by: Brad Fitzpatrick --- util/codegen/codegen.go | 11 ++++++++--- util/codegen/codegen_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 2f7781b681a24..1b3af10e03ee1 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -277,11 +277,16 @@ func IsInvalid(t types.Type) bool { // It has special handling for some types that contain pointers // that we know are free from memory aliasing/mutation concerns. func ContainsPointers(typ types.Type) bool { - switch typ.String() { + s := typ.String() + switch s { case "time.Time": - // time.Time contains a pointer that does not need copying + // time.Time contains a pointer that does not need cloning. return false - case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort": + case "inet.af/netip.Addr": + return false + } + if strings.HasPrefix(s, "unique.Handle[") { + // unique.Handle contains a pointer that does not need cloning. return false } switch ft := typ.Underlying().(type) { diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go index 28ddaed2bac36..74715eecae6ef 100644 --- a/util/codegen/codegen_test.go +++ b/util/codegen/codegen_test.go @@ -10,6 +10,8 @@ import ( "strings" "sync" "testing" + "time" + "unique" "unsafe" "golang.org/x/exp/constraints" @@ -84,6 +86,16 @@ type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct { V T } +type StructWithUniqueHandle struct{ _ unique.Handle[[32]byte] } + +type StructWithTime struct{ _ time.Time } + +type StructWithNetipTypes struct { + _ netip.Addr + _ netip.AddrPort + _ netip.Prefix +} + type Interface interface { Method() } @@ -161,6 +173,18 @@ func TestGenericContainsPointers(t *testing.T) { typ: "PointerUnionParam", wantPointer: true, }, + { + typ: "StructWithUniqueHandle", + wantPointer: false, + }, + { + typ: "StructWithTime", + wantPointer: false, + }, + { + typ: "StructWithNetipTypes", + wantPointer: false, + }, } for _, tt := range tests { From 4e0fc037e67a86a0734f025e041ba7f04f4cc3d4 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 11 Nov 2024 13:08:47 -0800 Subject: [PATCH 091/179] all: use iterators over slice views more This gets close to all of the remaining ones. Updates #12912 Change-Id: I9c672bbed2654a6c5cab31e0cbece6c107d8c6fa Signed-off-by: Brad Fitzpatrick --- cmd/tsconnect/wasm/wasm_js.go | 8 ++++---- ipn/ipnlocal/drive.go | 5 ++--- ipn/ipnlocal/local.go | 24 ++++++++++-------------- ipn/ipnlocal/local_test.go | 4 +--- ipn/ipnlocal/network-lock.go | 6 ++---- ipn/ipnlocal/serve.go | 3 +-- ipn/ipnlocal/web_client.go | 4 ++-- net/ipset/ipset.go | 8 ++++---- net/tsaddr/tsaddr.go | 11 +++++------ net/tsdial/dnsmap.go | 9 ++++----- tsnet/tsnet.go | 3 +-- types/netmap/netmap.go | 15 +++++++-------- util/set/slice.go | 4 ++-- wgengine/magicsock/debughttp.go | 3 +-- wgengine/magicsock/endpoint.go | 7 +++---- wgengine/magicsock/magicsock.go | 10 ++++------ wgengine/netstack/netstack.go | 6 ++---- wgengine/pendopen.go | 3 +-- wgengine/userspace.go | 9 ++++----- wgengine/wgcfg/nmcfg/nmcfg.go | 6 ++---- 20 files changed, 62 insertions(+), 86 deletions(-) diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index d0bc991f2ca9d..4ea1cd89713cd 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -272,8 +272,8 @@ func (i *jsIPN) run(jsCallbacks js.Value) { name = p.Hostinfo().Hostname() } addrs := make([]string, p.Addresses().Len()) - for i := range p.Addresses().Len() { - addrs[i] = p.Addresses().At(i).Addr().String() + for i, ap := range p.Addresses().All() { + addrs[i] = ap.Addr().String() } return jsNetMapPeerNode{ jsNetMapNode: jsNetMapNode{ @@ -589,8 +589,8 @@ func mapSlice[T any, M any](a []T, f func(T) M) []M { func mapSliceView[T any, M any](a views.Slice[T], f func(T) M) []M { n := make([]M, a.Len()) - for i := range a.Len() { - n[i] = f(a.At(i)) + for i, v := range a.All() { + n[i] = f(v) } return n } diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 98d563d8746b1..fe3622ba40e3e 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -354,9 +354,8 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // Check that the peer is allowed to share with us. addresses := peer.Addresses() - for i := range addresses.Len() { - addr := addresses.At(i) - capsMap := b.PeerCaps(addr.Addr()) + for _, p := range addresses.All() { + capsMap := b.PeerCaps(p.Addr()) if capsMap.HasCapability(tailcfg.PeerCapabilityTaildriveSharer) { return true } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 337fa3d2b829a..493762fccab19 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1811,8 +1811,7 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod } for _, peer := range nm.Peers { - for i := range peer.Addresses().Len() { - addr := peer.Addresses().At(i) + for _, addr := range peer.Addresses().All() { if !addr.IsSingleIP() || addr.Addr() != prefs.ExitNodeIP { continue } @@ -4997,8 +4996,8 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock case ipn.Running: var addrStrs []string addrs := netMap.GetAddresses() - for i := range addrs.Len() { - addrStrs = append(addrStrs, addrs.At(i).Addr().String()) + for _, p := range addrs.All() { + addrStrs = append(addrStrs, p.Addr().String()) } systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) case ipn.NoState: @@ -6089,8 +6088,7 @@ func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error { func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { svcs := peer.Hostinfo().Services() - for i := range svcs.Len() { - s := svcs.At(i) + for _, s := range svcs.All() { switch s.Proto { case tailcfg.PeerAPI4: p4 = s.Port @@ -6122,8 +6120,7 @@ func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { var have4, have6 bool addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) + for _, a := range addrs.All() { if !a.IsSingleIP() { continue } @@ -6145,10 +6142,9 @@ func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { } func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { - for i := range n.Addresses().Len() { - a := n.Addresses().At(i) - if a.IsSingleIP() && pred(a.Addr()) { - return a.Addr() + for _, pfx := range n.Addresses().All() { + if pfx.IsSingleIP() && pred(pfx.Addr()) { + return pfx.Addr() } } return netip.Addr{} @@ -6378,8 +6374,8 @@ func peerCanProxyDNS(p tailcfg.NodeView) bool { // If p.Cap is not populated (e.g. older control server), then do the old // thing of searching through services. services := p.Hostinfo().Services() - for i := range services.Len() { - if s := services.At(i); s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { + for _, s := range services.All() { + if s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { return true } } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 433679dda193e..6dad2dba4deeb 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -3041,12 +3041,10 @@ func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeI var ret tailcfg.NodeView gotIDs := make([]tailcfg.StableNodeID, got.Len()) - for i := range got.Len() { - nv := got.At(i) + for i, nv := range got.All() { if !nv.Valid() { t.Fatalf("invalid node at index %v", i) } - gotIDs[i] = nv.StableID() if nv.StableID() == use { ret = nv diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index d20bf94eb971a..bf14d339ed890 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -430,8 +430,7 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per } bootstrapStateID := fmt.Sprintf("%d:%d", genesis.State.StateID1, genesis.State.StateID2) - for i := range persist.DisallowedTKAStateIDs().Len() { - stateID := persist.DisallowedTKAStateIDs().At(i) + for _, stateID := range persist.DisallowedTKAStateIDs().All() { if stateID == bootstrapStateID { return fmt.Errorf("TKA with stateID of %q is disallowed on this node", stateID) } @@ -572,8 +571,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { TailscaleIPs: make([]netip.Addr, 0, p.Addresses().Len()), NodeKey: p.Key(), } - for i := range p.Addresses().Len() { - addr := p.Addresses().At(i) + for _, addr := range p.Addresses().All() { if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) { fp.TailscaleIPs = append(fp.TailscaleIPs, addr.Addr()) } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 67d521f0968eb..61bed05527167 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -242,8 +242,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) + for _, a := range addrs.All() { for _, p := range ports { addrPort := netip.AddrPortFrom(a.Addr(), p) if _, ok := b.serveListeners[addrPort]; ok { diff --git a/ipn/ipnlocal/web_client.go b/ipn/ipnlocal/web_client.go index ccde9f01dced0..37fc31819dac4 100644 --- a/ipn/ipnlocal/web_client.go +++ b/ipn/ipnlocal/web_client.go @@ -121,8 +121,8 @@ func (b *LocalBackend) updateWebClientListenersLocked() { } addrs := b.netMap.GetAddresses() - for i := range addrs.Len() { - addrPort := netip.AddrPortFrom(addrs.At(i).Addr(), webClientPort) + for _, pfx := range addrs.All() { + addrPort := netip.AddrPortFrom(pfx.Addr(), webClientPort) if _, ok := b.webClientListeners[addrPort]; ok { continue // already listening } diff --git a/net/ipset/ipset.go b/net/ipset/ipset.go index 622fd61d05c16..27c1e27ed4180 100644 --- a/net/ipset/ipset.go +++ b/net/ipset/ipset.go @@ -82,8 +82,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool pathForTest("bart") // Built a bart table. t := &bart.Table[struct{}]{} - for i := range addrs.Len() { - t.Insert(addrs.At(i), struct{}{}) + for _, p := range addrs.All() { + t.Insert(p, struct{}{}) } return bartLookup(t) } @@ -99,8 +99,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool // General case: pathForTest("ip-map") m := set.Set[netip.Addr]{} - for i := range addrs.Len() { - m.Add(addrs.At(i).Addr()) + for _, p := range addrs.All() { + m.Add(p.Addr()) } return ipInMap(m) } diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index 88069538724b6..e7e0ba088bfd5 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -180,8 +180,7 @@ func PrefixIs6(p netip.Prefix) bool { return p.Addr().Is6() } // IPv6 /0 route. func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { var v4, v6 bool - for i := range rr.Len() { - r := rr.At(i) + for _, r := range rr.All() { if r == allIPv4 { v4 = true } else if r == allIPv6 { @@ -194,8 +193,8 @@ func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { // ContainsExitRoute reports whether rr contains at least one of IPv4 or // IPv6 /0 (exit) routes. func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() == 0 { + for _, r := range rr.All() { + if r.Bits() == 0 { return true } } @@ -205,8 +204,8 @@ func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { // ContainsNonExitSubnetRoutes reports whether v contains Subnet // Routes other than ExitNode Routes. func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() != 0 { + for _, r := range rr.All() { + if r.Bits() != 0 { return true } } diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index f5d13861bb65f..2ef1cb1f171c0 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -42,8 +42,8 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if dnsname.HasSuffix(nm.Name, suffix) { ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip } - for i := range addrs.Len() { - if addrs.At(i).Addr().Is4() { + for _, p := range addrs.All() { + if p.Addr().Is4() { have4 = true } } @@ -52,9 +52,8 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if p.Name() == "" { continue } - for i := range p.Addresses().Len() { - a := p.Addresses().At(i) - ip := a.Addr() + for _, pfx := range p.Addresses().All() { + ip := pfx.Addr() if ip.Is4() && !have4 { continue } diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 70084c103e104..34cab7385558b 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -433,8 +433,7 @@ func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { return } addrs := nm.GetAddresses() - for i := range addrs.Len() { - addr := addrs.At(i) + for _, addr := range addrs.All() { ip := addr.Addr() if ip.Is6() { ip6 = ip diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index 5e06229221e1f..94e872a5593ea 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -279,15 +279,14 @@ func (a *NetworkMap) equalConciseHeader(b *NetworkMap) bool { // in nodeConciseEqual in sync. func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { aip := make([]string, p.AllowedIPs().Len()) - for i := range aip { - a := p.AllowedIPs().At(i) - s := strings.TrimSuffix(fmt.Sprint(a), "/32") + for i, a := range p.AllowedIPs().All() { + s := strings.TrimSuffix(a.String(), "/32") aip[i] = s } - ep := make([]string, p.Endpoints().Len()) - for i := range ep { - e := p.Endpoints().At(i).String() + epStrs := make([]string, p.Endpoints().Len()) + for i, ep := range p.Endpoints().All() { + e := ep.String() // Align vertically on the ':' between IP and port colon := strings.IndexByte(e, ':') spaces := 0 @@ -295,7 +294,7 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { spaces++ colon-- } - ep[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) + epStrs[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) } derp := p.DERP() @@ -316,7 +315,7 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { discoShort, derp, strings.Join(aip, " "), - strings.Join(ep, " ")) + strings.Join(epStrs, " ")) } // nodeConciseEqual reports whether a and b are equal for the fields accessed by printPeerConcise. diff --git a/util/set/slice.go b/util/set/slice.go index 38551aee197ad..2fc65b82d1c6e 100644 --- a/util/set/slice.go +++ b/util/set/slice.go @@ -67,7 +67,7 @@ func (ss *Slice[T]) Add(vs ...T) { // AddSlice adds all elements in vs to the set. func (ss *Slice[T]) AddSlice(vs views.Slice[T]) { - for i := range vs.Len() { - ss.Add(vs.At(i)) + for _, v := range vs.All() { + ss.Add(v) } } diff --git a/wgengine/magicsock/debughttp.go b/wgengine/magicsock/debughttp.go index 6c07b0d5eaa83..aa109c242e27c 100644 --- a/wgengine/magicsock/debughttp.go +++ b/wgengine/magicsock/debughttp.go @@ -102,8 +102,7 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) peers := map[key.NodePublic]tailcfg.NodeView{} - for i := range c.peers.Len() { - p := c.peers.At(i) + for _, p := range c.peers.All() { peers[p.Key()] = p } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 5e0ada6170c2f..bbba3181ce453 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "errors" "fmt" + "iter" "math" "math/rand/v2" "net" @@ -1384,20 +1385,18 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p } func (de *endpoint) setEndpointsLocked(eps interface { - Len() int - At(i int) netip.AddrPort + All() iter.Seq2[int, netip.AddrPort] }) { for _, st := range de.endpointState { st.index = indexSentinelDeleted // assume deleted until updated in next loop } var newIpps []netip.AddrPort - for i := range eps.Len() { + for i, ipp := range eps.All() { if i > math.MaxInt16 { // Seems unlikely. break } - ipp := eps.At(i) if !ipp.IsValid() { de.c.logf("magicsock: bogus netmap endpoint from %v", eps) continue diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index a9c6fa070e90f..c361608ad4b23 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1120,8 +1120,8 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro // re-run. eps = c.endpointTracker.update(time.Now(), eps) - for i := range c.staticEndpoints.Len() { - addAddr(c.staticEndpoints.At(i), tailcfg.EndpointExplicitConf) + for _, ep := range c.staticEndpoints.All() { + addAddr(ep, tailcfg.EndpointExplicitConf) } if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() { @@ -2360,16 +2360,14 @@ func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { fmt.Fprintf(w, "derp=%v%s ", regionID, code) } - for i := range n.AllowedIPs().Len() { - a := n.AllowedIPs().At(i) + for _, a := range n.AllowedIPs().All() { if a.IsSingleIP() { fmt.Fprintf(w, "aip=%v ", a.Addr()) } else { fmt.Fprintf(w, "aip=%v ", a) } } - for i := range n.Endpoints().Len() { - ep := n.Endpoints().At(i) + for _, ep := range n.Endpoints().All() { fmt.Fprintf(w, "ep=%v ", ep) } })) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 280f4b7bb5d3c..20eac06e6b8fd 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -643,13 +643,11 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { newPfx := make(map[netip.Prefix]bool) if selfNode.Valid() { - for i := range selfNode.Addresses().Len() { - p := selfNode.Addresses().At(i) + for _, p := range selfNode.Addresses().All() { newPfx[p] = true } if ns.ProcessSubnets { - for i := range selfNode.AllowedIPs().Len() { - p := selfNode.AllowedIPs().At(i) + for _, p := range selfNode.AllowedIPs().All() { newPfx[p] = true } } diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 340c7e0f3f7be..7db07c685aa75 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -207,8 +207,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { ps, found := e.getPeerStatusLite(n.Key()) if !found { onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched - for i := range n.AllowedIPs().Len() { - r := n.AllowedIPs().At(i) + for _, r := range n.AllowedIPs().All() { if r.Bits() != 0 && r.Contains(flow.DstAddr()) { onlyZeroRoute = false break diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 2dd0c4cd5da89..81f8000e0d557 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -852,8 +852,7 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, // hasOverlap checks if there is a IPPrefix which is common amongst the two // provided slices. func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { - for i := range aips.Len() { - aip := aips.At(i) + for _, aip := range aips.All() { if views.SliceContains(rips, aip) { return true } @@ -1329,9 +1328,9 @@ func (e *userspaceEngine) mySelfIPMatchingFamily(dst netip.Addr) (src netip.Addr if addrs.Len() == 0 { return zero, errors.New("no self address in netmap") } - for i := range addrs.Len() { - if a := addrs.At(i); a.IsSingleIP() && a.Addr().BitLen() == dst.BitLen() { - return a.Addr(), nil + for _, p := range addrs.All() { + if p.IsSingleIP() && p.Addr().BitLen() == dst.BitLen() { + return p.Addr(), nil } } return zero, errors.New("no self address in netmap matching address family") diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index d156f7fcb0ef2..e7d5edf150537 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -40,8 +40,7 @@ func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { if !cidr.IsSingleIP() { return true } - for i := range node.Addresses().Len() { - selfCIDR := node.Addresses().At(i) + for _, selfCIDR := range node.Addresses().All() { if cidr == selfCIDR { return false } @@ -110,8 +109,7 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() cpeer.IsJailed = peer.IsJailed() - for i := range peer.AllowedIPs().Len() { - allowedIP := peer.AllowedIPs().At(i) + for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode { if didExitNodeWarn { // Don't log about both the IPv4 /0 and IPv6 /0. From d8a3683fdfc21e0dfe41f47b72c56230296d383b Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Tue, 12 Nov 2024 14:18:19 +0000 Subject: [PATCH 092/179] cmd/k8s-operator: restart ProxyGroup pods less (#14045) We currently annotate pods with a hash of the tailscaled config so that we can trigger pod restarts whenever it changes. However, the hash updates more frequently than is necessary causing more restarts than is necessary. This commit removes two causes; scaling up/down and removing the auth key after pods have initially authed to control. However, note that pods will still restart on scale-up/down because of the updated set of volumes mounted into each pod. Hopefully we can fix that in a planned follow-up PR. Updates #13406 Signed-off-by: Tom Proctor --- cmd/k8s-operator/proxygroup.go | 40 ++++++++++++++++------- cmd/k8s-operator/proxygroup_specs.go | 4 +++ cmd/k8s-operator/proxygroup_test.go | 48 ++++++++++++++++++++-------- 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 7dad9e573e151..6b76724662b6d 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -353,7 +353,7 @@ func (r *ProxyGroupReconciler) deleteTailnetDevice(ctx context.Context, id tailc func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, pg *tsapi.ProxyGroup, proxyClass *tsapi.ProxyClass) (hash string, err error) { logger := r.logger(pg.Name) - var allConfigs []tailscaledConfigs + var configSHA256Sum string for i := range pgReplicas(pg) { cfgSecret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -389,7 +389,6 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p if err != nil { return "", fmt.Errorf("error creating tailscaled config: %w", err) } - allConfigs = append(allConfigs, configs) for cap, cfg := range configs { cfgJSON, err := json.Marshal(cfg) @@ -399,6 +398,32 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p mak.Set(&cfgSecret.StringData, tsoperator.TailscaledConfigFileName(cap), string(cfgJSON)) } + // The config sha256 sum is a value for a hash annotation used to trigger + // pod restarts when tailscaled config changes. Any config changes apply + // to all replicas, so it is sufficient to only hash the config for the + // first replica. + // + // In future, we're aiming to eliminate restarts altogether and have + // pods dynamically reload their config when it changes. + if i == 0 { + sum := sha256.New() + for _, cfg := range configs { + // Zero out the auth key so it doesn't affect the sha256 hash when we + // remove it from the config after the pods have all authed. Otherwise + // all the pods will need to restart immediately after authing. + cfg.AuthKey = nil + b, err := json.Marshal(cfg) + if err != nil { + return "", err + } + if _, err := sum.Write(b); err != nil { + return "", err + } + } + + configSHA256Sum = fmt.Sprintf("%x", sum.Sum(nil)) + } + if existingCfgSecret != nil { logger.Debugf("patching the existing ProxyGroup config Secret %s", cfgSecret.Name) if err := r.Patch(ctx, cfgSecret, client.MergeFrom(existingCfgSecret)); err != nil { @@ -412,16 +437,7 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p } } - sum := sha256.New() - b, err := json.Marshal(allConfigs) - if err != nil { - return "", err - } - if _, err := sum.Write(b); err != nil { - return "", err - } - - return fmt.Sprintf("%x", sum.Sum(nil)), nil + return configSHA256Sum, nil } func pgTailscaledConfig(pg *tsapi.ProxyGroup, class *tsapi.ProxyClass, idx int32, authKey string, oldSecret *corev1.Secret) (tailscaledConfigs, error) { diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index f9d1ea52be221..27fd9ef716361 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -93,6 +93,10 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa c.Image = image c.VolumeMounts = func() []corev1.VolumeMount { var mounts []corev1.VolumeMount + + // TODO(tomhjp): Read config directly from the secret instead. The + // mounts change on scaling up/down which causes unnecessary restarts + // for pods that haven't meaningfully changed. for i := range pgReplicas(pg) { mounts = append(mounts, corev1.VolumeMount{ Name: fmt.Sprintf("tailscaledconfig-%d", i), diff --git a/cmd/k8s-operator/proxygroup_test.go b/cmd/k8s-operator/proxygroup_test.go index 445db7537ddb6..23f50cc7a576d 100644 --- a/cmd/k8s-operator/proxygroup_test.go +++ b/cmd/k8s-operator/proxygroup_test.go @@ -35,6 +35,8 @@ var defaultProxyClassAnnotations = map[string]string{ } func TestProxyGroup(t *testing.T) { + const initialCfgHash = "6632726be70cf224049580deb4d317bba065915b5fd415461d60ed621c91b196" + pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{ Name: "default-pc", @@ -80,6 +82,7 @@ func TestProxyGroup(t *testing.T) { tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "the ProxyGroup's ProxyClass default-pc is not yet in a ready state, waiting...", 0, cl, zl.Sugar()) expectEqual(t, fc, pg, nil) + expectProxyGroupResources(t, fc, pg, false, "") }) t.Run("observe_ProxyGroupCreating_status_reason", func(t *testing.T) { @@ -100,10 +103,11 @@ func TestProxyGroup(t *testing.T) { tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "0/2 ProxyGroup pods running", 0, cl, zl.Sugar()) expectEqual(t, fc, pg, nil) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) if expected := 1; reconciler.proxyGroups.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.proxyGroups.Len()) } - expectProxyGroupResources(t, fc, pg, true) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) keyReq := tailscale.KeyCapabilities{ Devices: tailscale.KeyDeviceCapabilities{ Create: tailscale.KeyDeviceCreateCapabilities{ @@ -135,7 +139,7 @@ func TestProxyGroup(t *testing.T) { } tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, reasonProxyGroupReady, 0, cl, zl.Sugar()) expectEqual(t, fc, pg, nil) - expectProxyGroupResources(t, fc, pg, true) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) }) t.Run("scale_up_to_3", func(t *testing.T) { @@ -146,6 +150,7 @@ func TestProxyGroup(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "2/3 ProxyGroup pods running", 0, cl, zl.Sugar()) expectEqual(t, fc, pg, nil) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) addNodeIDToStateSecrets(t, fc, pg) expectReconciled(t, reconciler, "", pg.Name) @@ -155,7 +160,7 @@ func TestProxyGroup(t *testing.T) { TailnetIPs: []string{"1.2.3.4", "::1"}, }) expectEqual(t, fc, pg, nil) - expectProxyGroupResources(t, fc, pg, true) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) }) t.Run("scale_down_to_1", func(t *testing.T) { @@ -163,11 +168,26 @@ func TestProxyGroup(t *testing.T) { mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { p.Spec = pg.Spec }) + expectReconciled(t, reconciler, "", pg.Name) + pg.Status.Devices = pg.Status.Devices[:1] // truncate to only the first device. expectEqual(t, fc, pg, nil) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash) + }) + + t.Run("trigger_config_change_and_observe_new_config_hash", func(t *testing.T) { + pc.Spec.TailscaleConfig = &tsapi.TailscaleConfig{ + AcceptRoutes: true, + } + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec = pc.Spec + }) - expectProxyGroupResources(t, fc, pg, true) + expectReconciled(t, reconciler, "", pg.Name) + + expectEqual(t, fc, pg, nil) + expectProxyGroupResources(t, fc, pg, true, "518a86e9fae64f270f8e0ec2a2ea6ca06c10f725035d3d6caca132cd61e42a74") }) t.Run("delete_and_cleanup", func(t *testing.T) { @@ -191,13 +211,13 @@ func TestProxyGroup(t *testing.T) { }) } -func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup, shouldExist bool) { +func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup, shouldExist bool, cfgHash string) { t.Helper() role := pgRole(pg, tsNamespace) roleBinding := pgRoleBinding(pg, tsNamespace) serviceAccount := pgServiceAccount(pg, tsNamespace) - statefulSet, err := pgStatefulSet(pg, tsNamespace, testProxyImage, "auto", "") + statefulSet, err := pgStatefulSet(pg, tsNamespace, testProxyImage, "auto", cfgHash) if err != nil { t.Fatal(err) } @@ -207,9 +227,7 @@ func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.Prox expectEqual(t, fc, role, nil) expectEqual(t, fc, roleBinding, nil) expectEqual(t, fc, serviceAccount, nil) - expectEqual(t, fc, statefulSet, func(ss *appsv1.StatefulSet) { - ss.Spec.Template.Annotations[podAnnotationLastSetConfigFileHash] = "" - }) + expectEqual(t, fc, statefulSet, nil) } else { expectMissing[rbacv1.Role](t, fc, role.Namespace, role.Name) expectMissing[rbacv1.RoleBinding](t, fc, roleBinding.Namespace, roleBinding.Name) @@ -218,11 +236,13 @@ func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.Prox } var expectedSecrets []string - for i := range pgReplicas(pg) { - expectedSecrets = append(expectedSecrets, - fmt.Sprintf("%s-%d", pg.Name, i), - fmt.Sprintf("%s-%d-config", pg.Name, i), - ) + if shouldExist { + for i := range pgReplicas(pg) { + expectedSecrets = append(expectedSecrets, + fmt.Sprintf("%s-%d", pg.Name, i), + fmt.Sprintf("%s-%d-config", pg.Name, i), + ) + } } expectSecrets(t, fc, expectedSecrets) } From e38522c081ff48add7db73077e7be18f38ea709d Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Tue, 12 Nov 2024 14:23:38 +0000 Subject: [PATCH 093/179] go.{mod,sum},build_docker.sh: bump mkctr, add ability to set OCI annotations for images (#14065) Updates tailscale/tailscale#12914 Signed-off-by: Irbe Krumina --- build_docker.sh | 11 ++++++++ go.mod | 32 +++++++++++----------- go.sum | 72 ++++++++++++++++++++++++------------------------- 3 files changed, 63 insertions(+), 52 deletions(-) diff --git a/build_docker.sh b/build_docker.sh index e8b1c8f28f450..9f39eb08ddf89 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -17,12 +17,20 @@ eval "$(./build_dist.sh shellvars)" DEFAULT_TARGET="client" DEFAULT_TAGS="v${VERSION_SHORT},v${VERSION_MINOR}" DEFAULT_BASE="tailscale/alpine-base:3.18" +# Set a few pre-defined OCI annotations. The source annotation is used by tools such as Renovate that scan the linked +# Github repo to find release notes for any new image tags. Note that for official Tailscale images the default +# annotations defined here will be overriden by release scripts that call this script. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md#pre-defined-annotation-keys +DEFAULT_ANNOTATIONS="org.opencontainers.image.source=https://github.com/tailscale/tailscale/blob/main/build_docker.sh,org.opencontainers.image.vendor=Tailscale" PUSH="${PUSH:-false}" TARGET="${TARGET:-${DEFAULT_TARGET}}" TAGS="${TAGS:-${DEFAULT_TAGS}}" BASE="${BASE:-${DEFAULT_BASE}}" PLATFORM="${PLATFORM:-}" # default to all platforms +# OCI annotations that will be added to the image. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md +ANNOTATIONS="${ANNOTATIONS:-${DEFAULT_ANNOTATIONS}}" case "$TARGET" in client) @@ -43,6 +51,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/containerboot ;; operator) @@ -60,6 +69,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/operator ;; k8s-nameserver) @@ -77,6 +87,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/k8s-nameserver ;; *) diff --git a/go.mod b/go.mod index 464db8313b5fd..b5451ab613663 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/golang/snappy v0.0.4 github.com/golangci/golangci-lint v1.57.1 github.com/google/go-cmp v0.6.0 - github.com/google/go-containerregistry v0.18.0 + github.com/google/go-containerregistry v0.20.2 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 @@ -55,7 +55,7 @@ require ( github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/jsimonetti/rtnetlink v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 - github.com/klauspost/compress v1.17.4 + github.com/klauspost/compress v1.17.11 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 @@ -80,7 +80,7 @@ require ( github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4 github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a - github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba + github.com/tailscale/mkctr v0.0.0-20241111153353-1a38f6676f10 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 @@ -100,8 +100,8 @@ require ( golang.org/x/mod v0.19.0 golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.16.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 + golang.org/x/sync v0.9.0 + golang.org/x/sys v0.27.0 golang.org/x/term v0.22.0 golang.org/x/time v0.5.0 golang.org/x/tools v0.23.0 @@ -125,7 +125,7 @@ require ( github.com/Antonboom/testifylint v1.2.0 // indirect github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/OpenPeeDeeP/depguard/v2 v2.2.0 // indirect github.com/alecthomas/go-check-sumtype v0.1.4 // indirect github.com/alexkohler/nakedret/v2 v2.0.4 // indirect @@ -138,7 +138,7 @@ require ( github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 // indirect github.com/dave/brenda v1.1.0 // indirect - github.com/docker/go-connections v0.4.0 // indirect + github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/ghostiam/protogetter v0.3.5 // indirect @@ -160,10 +160,10 @@ require ( github.com/ykadowak/zerologlint v0.1.5 // indirect go-simpler.org/musttag v0.9.0 // indirect go-simpler.org/sloglint v0.5.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 // indirect - go.opentelemetry.io/otel v1.22.0 // indirect - go.opentelemetry.io/otel/metric v1.22.0 // indirect - go.opentelemetry.io/otel/trace v1.22.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0 // indirect + go.opentelemetry.io/otel v1.32.0 // indirect + go.opentelemetry.io/otel/metric v1.32.0 // indirect + go.opentelemetry.io/otel/trace v1.32.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect ) @@ -220,10 +220,10 @@ require ( github.com/daixiang0/gci v0.12.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect - github.com/docker/cli v25.0.0+incompatible // indirect + github.com/docker/cli v27.3.1+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect - github.com/docker/docker v26.1.4+incompatible // indirect - github.com/docker/docker-credential-helpers v0.8.1 // indirect + github.com/docker/docker v27.3.1+incompatible // indirect + github.com/docker/docker-credential-helpers v0.8.2 // indirect github.com/emicklei/go-restful/v3 v3.11.2 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/ettle/strcase v0.2.0 // indirect @@ -322,7 +322,7 @@ require ( github.com/nunnatsa/ginkgolinter v0.16.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0-rc6 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect @@ -376,7 +376,7 @@ require ( github.com/ultraware/funlen v0.1.0 // indirect github.com/ultraware/whitespace v0.1.0 // indirect github.com/uudashr/gocognit v1.1.2 // indirect - github.com/vbatts/tar-split v0.11.5 // indirect + github.com/vbatts/tar-split v0.11.6 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yagipy/maintidx v1.0.0 // indirect diff --git a/go.sum b/go.sum index 549f559d001fd..55aa3b5357ff9 100644 --- a/go.sum +++ b/go.sum @@ -79,8 +79,8 @@ github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuN github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/OpenPeeDeeP/depguard/v2 v2.2.0 h1:vDfG60vDtIuf0MEOhmLlLLSzqaRM8EMcgJPdp74zmpA= github.com/OpenPeeDeeP/depguard/v2 v2.2.0/go.mod h1:CIzddKRvLBC4Au5aYP/i3nyaWQ+ClszLIuVocRiCYFQ= github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= @@ -277,16 +277,16 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= -github.com/docker/cli v25.0.0+incompatible h1:zaimaQdnX7fYWFqzN88exE9LDEvRslexpFowZBX6GoQ= -github.com/docker/cli v25.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v27.3.1+incompatible h1:qEGdFBF3Xu6SCvCYhc7CzaQTlBmqDuzxPDpigSyeKQQ= +github.com/docker/cli v27.3.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU= -github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo= -github.com/docker/docker-credential-helpers v0.8.1/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= -github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= -github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI= +github.com/docker/docker v27.3.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo= +github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dsnet/try v0.0.3 h1:ptR59SsrcFUYbT/FhAbKTV6iLkeD6O18qfIWRml2fqI= @@ -490,8 +490,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-containerregistry v0.18.0 h1:ShE7erKNPqRh5ue6Z9DUOlk04WsnFWPO6YGr3OxnfoQ= -github.com/google/go-containerregistry v0.18.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ= +github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= +github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -627,8 +627,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -749,8 +749,8 @@ github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc6 h1:XDqvyKsJEbRtATzkgItUqBA7QHk58yxX1Ov9HERHNqU= -github.com/opencontainers/image-spec v1.1.0-rc6/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= @@ -931,8 +931,8 @@ github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 h1:4chzWmimtJPx github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05/go.mod h1:PdCqy9JzfWMJf1H5UJW2ip33/d4YkoKN0r67yKH1mG8= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba h1:uNo1VCm/xg4alMkIKo8RWTKNx5y1otfVOcKbp+irkL4= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba/go.mod h1:DxnqIXBplij66U2ZkL688xy07q97qQ83P+TVueLiHq4= +github.com/tailscale/mkctr v0.0.0-20241111153353-1a38f6676f10 h1:ZB47BgnHcEHQJODkDubs5ZiNeJxMhcgzefV3lykRwVQ= +github.com/tailscale/mkctr v0.0.0-20241111153353-1a38f6676f10/go.mod h1:iDx/0Rr9VV/KanSUDpJ6I/ROf0sQ7OqljXc/esl0UIA= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= @@ -981,8 +981,8 @@ github.com/ultraware/whitespace v0.1.0 h1:O1HKYoh0kIeqE8sFqZf1o0qbORXUCOQFrlaQyZ github.com/ultraware/whitespace v0.1.0/go.mod h1:/se4r3beMFNmewJ4Xmz0nMQ941GJt+qmSHGP9emHYe0= github.com/uudashr/gocognit v1.1.2 h1:l6BAEKJqQH2UpKAPKdMfZf5kE4W/2xk8pfU1OVLvniI= github.com/uudashr/gocognit v1.1.2/go.mod h1:aAVdLURqcanke8h3vg35BC++eseDm66Z7KmchI5et4k= -github.com/vbatts/tar-split v0.11.5 h1:3bHCTIheBm1qFTcgh9oPu+nNBtX+XJIupG/vacinCts= -github.com/vbatts/tar-split v0.11.5/go.mod h1:yZbwRsSeGjusneWgA781EKej9HF8vme8okylkAeNKLk= +github.com/vbatts/tar-split v0.11.6 h1:4SjTW5+PU11n6fZenf2IPoV8/tz3AaYHMWjf23envGs= +github.com/vbatts/tar-split v0.11.6/go.mod h1:dqKNtesIOr2j2Qv3W/cHjnvk9I8+G7oAkFDFN6TCBEI= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= @@ -1022,20 +1022,20 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 h1:sv9kVfal0MK0wBMCOGr+HeJm9v803BkJxGrk2au7j08= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0/go.mod h1:SK2UL73Zy1quvRPonmOmRDiWk1KBV3LyIeeIxcEApWw= -go.opentelemetry.io/otel v1.22.0 h1:xS7Ku+7yTFvDfDraDIJVpw7XPyuHlB9MCiqqX5mcJ6Y= -go.opentelemetry.io/otel v1.22.0/go.mod h1:eoV4iAi3Ea8LkAEI9+GFT44O6T/D0GWAVFyZVCC6pMI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0 h1:DheMAlT6POBP+gh8RUH19EOTnQIor5QE0uSRPtzCpSw= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0/go.mod h1:wZcGmeVO9nzP67aYSLDqXNWK87EZWhi7JWj1v7ZXf94= +go.opentelemetry.io/otel v1.32.0 h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U= +go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0 h1:9M3+rhx7kZCIQQhQRYaZCdNu1V73tm4TvXs2ntl98C4= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0/go.mod h1:noq80iT8rrHP1SfybmPiRGc9dc5M8RPmGvtwo7Oo7tc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0 h1:FyjCyI9jVEfqhUh2MoSkmolPjfh5fp2hnV0b0irxH4Q= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0/go.mod h1:hYwym2nDEeZfG/motx0p7L7J1N1vyzIThemQsb4g2qY= -go.opentelemetry.io/otel/metric v1.22.0 h1:lypMQnGyJYeuYPhOM/bgjbFM6WE44W1/T45er4d8Hhg= -go.opentelemetry.io/otel/metric v1.22.0/go.mod h1:evJGjVpZv0mQ5QBRJoBF64yMuOf4xCWdXjK8pzFvliY= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx83XD0= -go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 h1:j9+03ymgYhPKmeXGk5Zu+cIZOlVzd9Zv7QIiyItjFBU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0/go.mod h1:Y5+XiUG4Emn1hTfciPzGPJaSI+RpDts6BnCIir0SLqk= +go.opentelemetry.io/otel/metric v1.32.0 h1:xV2umtmNcThh2/a/aCP+h64Xx5wsj8qqnkYZktzNa0M= +go.opentelemetry.io/otel/metric v1.32.0/go.mod h1:jH7CIbbK6SH2V2wE16W05BHCtIDzauciCRLoc/SyMv8= +go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= +go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= +go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM= +go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= @@ -1176,8 +1176,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1239,8 +1239,8 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= From cf41cec5a8da13809fab472e221aecd099009b6f Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Tue, 12 Nov 2024 17:13:26 +0000 Subject: [PATCH 094/179] cmd/{k8s-operator,containerboot},k8s-operator: remove support for proxies below capver 95. (#13986) Updates tailscale/tailscale#13984 Signed-off-by: Irbe Krumina --- cmd/containerboot/main.go | 9 ++++----- cmd/k8s-operator/operator_test.go | 4 ++-- cmd/k8s-operator/sts.go | 21 +++------------------ cmd/k8s-operator/testutils_test.go | 20 -------------------- k8s-operator/utils.go | 3 --- 5 files changed, 9 insertions(+), 48 deletions(-) diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 4c8ba58073c69..17131faae08b8 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -102,7 +102,6 @@ import ( "net/netip" "os" "os/signal" - "path" "path/filepath" "slices" "strings" @@ -731,7 +730,6 @@ func tailscaledConfigFilePath() string { } cv, err := kubeutils.CapVerFromFileName(e.Name()) if err != nil { - log.Printf("skipping file %q in tailscaled config directory %q: %v", e.Name(), dir, err) continue } if cv > maxCompatVer && cv <= tailcfg.CurrentCapabilityVersion { @@ -739,8 +737,9 @@ func tailscaledConfigFilePath() string { } } if maxCompatVer == -1 { - log.Fatalf("no tailscaled config file found in %q for current capability version %q", dir, tailcfg.CurrentCapabilityVersion) + log.Fatalf("no tailscaled config file found in %q for current capability version %d", dir, tailcfg.CurrentCapabilityVersion) } - log.Printf("Using tailscaled config file %q for capability version %q", maxCompatVer, tailcfg.CurrentCapabilityVersion) - return path.Join(dir, kubeutils.TailscaledConfigFileName(maxCompatVer)) + filePath := filepath.Join(dir, kubeutils.TailscaledConfigFileName(maxCompatVer)) + log.Printf("Using tailscaled config file %q to match current capability version %d", filePath, tailcfg.CurrentCapabilityVersion) + return filePath } diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index cc9927645621f..21ef08e520a26 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -1388,7 +1388,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { parentType: "svc", hostname: "default-test", clusterTargetIP: "10.20.30.40", - confFileHash: "362360188dac62bca8013c8134929fed8efd84b1f410c00873d14a05709b5647", + confFileHash: "a67b5ad3ff605531c822327e8f1a23dd0846e1075b722c13402f7d5d0ba32ba2", app: kubetypes.AppIngressProxy, } expectEqual(t, fc, expectedSTS(t, fc, o), nil) @@ -1399,7 +1399,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { mak.Set(&svc.Annotations, AnnotationHostname, "another-test") }) o.hostname = "another-test" - o.confFileHash = "20db57cfabc3fc6490f6bb1dc85994e61d255cdfa2a56abb0141736e59f263ef" + o.confFileHash = "888a993ebee20ad6be99623b45015339de117946850cf1252bede0b570e04293" expectReconciled(t, sr, "default", "test") expectEqual(t, fc, expectedSTS(t, fc, o), nil) } diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index b6467b7981b67..bdacec39b0e98 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -521,11 +521,6 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S Name: "TS_KUBE_SECRET", Value: proxySecret, }, - corev1.EnvVar{ - // Old tailscaled config key is still used for backwards compatibility. - Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", - Value: "/etc/tsconfig/tailscaled", - }, corev1.EnvVar{ // New style is in the form of cap-.hujson. Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", @@ -789,15 +784,9 @@ func readAuthKey(secret *corev1.Secret, key string) (*string, error) { return origConf.AuthKey, nil } -// tailscaledConfig takes a proxy config, a newly generated auth key if -// generated and a Secret with the previous proxy state and auth key and -// returns tailscaled configuration and a hash of that configuration. -// -// As of 2024-05-09 it also returns legacy tailscaled config without the -// later added NoStatefulFilter field to support proxies older than cap95. -// TODO (irbekrm): remove the legacy config once we no longer need to support -// versions older than cap94, -// https://tailscale.com/kb/1236/kubernetes-operator#operator-and-proxies +// tailscaledConfig takes a proxy config, a newly generated auth key if generated and a Secret with the previous proxy +// state and auth key and returns tailscaled config files for currently supported proxy versions and a hash of that +// configuration. func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *corev1.Secret) (tailscaledConfigs, error) { conf := &ipn.ConfigVAlpha{ Version: "alpha0", @@ -846,10 +835,6 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co // AppConnector config option is only understood by clients of capver 107 and newer. conf.AppConnector = nil capVerConfigs[95] = *conf - - // StatefulFiltering is only understood by clients of capver 95 and newer. - conf.NoStatefulFiltering.Clear() - capVerConfigs[94] = *conf return capVerConfigs, nil } diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 4b25d103c2a56..d42f1b7af89c7 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -71,7 +71,6 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef {Name: "TS_USERSPACE", Value: "false"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, }, SecurityContext: &corev1.SecurityContext{ @@ -230,7 +229,6 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps {Name: "TS_USERSPACE", Value: "true"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, {Name: "TS_SERVE_CONFIG", Value: "/etc/tailscaled/serve-config"}, {Name: "TS_INTERNAL_APP", Value: opts.app}, @@ -404,12 +402,6 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec if err != nil { t.Fatalf("error marshalling tailscaled config") } - conf.NoStatefulFiltering.Clear() - b, err := json.Marshal(conf) - if err != nil { - t.Fatalf("error marshalling tailscaled config") - } - mak.Set(&s.StringData, "tailscaled", string(b)) mak.Set(&s.StringData, "cap-95.hujson", string(bn)) mak.Set(&s.StringData, "cap-107.hujson", string(bnn)) labels := map[string]string{ @@ -662,18 +654,6 @@ func removeTargetPortsFromSvc(svc *corev1.Service) { func removeAuthKeyIfExistsModifier(t *testing.T) func(s *corev1.Secret) { return func(secret *corev1.Secret) { t.Helper() - if len(secret.StringData["tailscaled"]) != 0 { - conf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.StringData["tailscaled"]), conf); err != nil { - t.Fatalf("error unmarshalling 'tailscaled' contents: %v", err) - } - conf.AuthKey = nil - b, err := json.Marshal(conf) - if err != nil { - t.Fatalf("error marshalling updated 'tailscaled' config: %v", err) - } - mak.Set(&secret.StringData, "tailscaled", string(b)) - } if len(secret.StringData["cap-95.hujson"]) != 0 { conf := &ipn.ConfigVAlpha{} if err := json.Unmarshal([]byte(secret.StringData["cap-95.hujson"]), conf); err != nil { diff --git a/k8s-operator/utils.go b/k8s-operator/utils.go index a1f225fe601c8..420d7e49c7ec2 100644 --- a/k8s-operator/utils.go +++ b/k8s-operator/utils.go @@ -32,9 +32,6 @@ type Records struct { // TailscaledConfigFileName returns a tailscaled config file name in // format expected by containerboot for the given CapVer. func TailscaledConfigFileName(cap tailcfg.CapabilityVersion) string { - if cap < 95 { - return "tailscaled" - } return fmt.Sprintf("cap-%v.hujson", cap) } From 0c6bd9a33b184eadaaba426b1249e5fa2cd2f4b1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 13 Nov 2024 05:49:51 -0800 Subject: [PATCH 095/179] words: add a scale https://portsmouthbrewery.com/shilling-scale/ Any scale that includes "wee heavy" is a scale worth including. Updates #words Change-Id: I85fd7a64cf22e14f686f1093a220cb59c43e46ba Signed-off-by: Brad Fitzpatrick --- words/scales.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/words/scales.txt b/words/scales.txt index f27dfc5c4aa36..fdec078ee56dd 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -391,3 +391,4 @@ godzilla sirius vector cherimoya +shilling From 7c6562c861541bf1652f83425b18f618b84d8cde Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Wed, 13 Nov 2024 09:56:02 -0500 Subject: [PATCH 096/179] words: scale up our word count (#14082) Updates tailscale/corp#14698 Signed-off-by: Naman Sood --- words/scales.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/words/scales.txt b/words/scales.txt index fdec078ee56dd..c8041c0fcdb3e 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -392,3 +392,9 @@ sirius vector cherimoya shilling +kettle +kitchen +fahrenheit +rankine +piano +ruler From 1847f260428012701c61f8f86b72b530d15c1db3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:30:14 -0700 Subject: [PATCH 097/179] .github: Bump github/codeql-action from 3.26.11 to 3.27.1 (#14062) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 3.26.11 to 3.27.1. - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea...4f3212b61783c3c68e8309a0f18a699764811cda) --- updated-dependencies: - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql-analysis.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 4e266c6eae6ab..0ea73a93c3534 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -55,7 +55,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/init@4f3212b61783c3c68e8309a0f18a699764811cda # v3.27.1 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/autobuild@4f3212b61783c3c68e8309a0f18a699764811cda # v3.27.1 # ℹ️ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/analyze@4f3212b61783c3c68e8309a0f18a699764811cda # v3.27.1 From 0cfa217f3e6e2078b82d73bd177bc1d96c291fb2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:34:10 -0700 Subject: [PATCH 098/179] .github: Bump actions/upload-artifact from 4.4.0 to 4.4.3 (#13811) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.4.0 to 4.4.3. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/50769540e7f4bd5e21e526ee35c689e35e0d6874...b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bc70040b054bf..2fac634b4e93d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -461,7 +461,7 @@ jobs: run: | echo "artifacts_path=$(realpath .)" >> $GITHUB_ENV - name: upload crash - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 if: steps.run.outcome != 'success' && steps.build.outcome == 'success' with: name: artifacts From 4474dcea686ee4ef4263456e3ec497667d4ccf97 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:46:30 -0700 Subject: [PATCH 099/179] .github: Bump actions/cache from 4.1.0 to 4.1.2 (#13933) Bumps [actions/cache](https://github.com/actions/cache) from 4.1.0 to 4.1.2. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2...6849a6489940f00c2f30c0fb92c6274307ccb58a) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2fac634b4e93d..a97e4491796ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -80,7 +80,7 @@ jobs: - name: checkout uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -159,7 +159,7 @@ jobs: cache: false - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -260,7 +260,7 @@ jobs: - name: checkout uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -319,7 +319,7 @@ jobs: - name: checkout uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -367,7 +367,7 @@ jobs: - name: checkout uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache From 0c9ade46a4321a732887b5ad55bcac69a096390e Mon Sep 17 00:00:00 2001 From: Walter Poupore Date: Wed, 13 Nov 2024 09:25:12 -0800 Subject: [PATCH 100/179] words: Add scoville to scales.txt (#14084) https://en.wikipedia.org/wiki/Scoville_scale Updates #words Signed-off-by: Walter Poupore --- words/scales.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/words/scales.txt b/words/scales.txt index c8041c0fcdb3e..2fe849bb9cee1 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -398,3 +398,4 @@ fahrenheit rankine piano ruler +scoville From bfe5cd87606454e2d00631d2c29e0fa72443758c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:56:44 -0700 Subject: [PATCH 101/179] .github: Bump actions/setup-go from 5.0.2 to 5.1.0 (#13934) Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5.0.2 to 5.1.0. - [Release notes](https://github.com/actions/setup-go/releases) - [Commits](https://github.com/actions/setup-go/compare/0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32...41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed) --- updated-dependencies: - dependency-name: actions/setup-go dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql-analysis.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/test.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0ea73a93c3534..d9a287be32d8d 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -49,7 +49,7 @@ jobs: # Install a more recent Go that understands modern go.mod content. - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5.1.0 with: go-version-file: go.mod diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9c34debc5d2f4..6630e8de852ae 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + - uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5.1.0 with: go-version-file: go.mod cache: false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a97e4491796ed..f9bb5cae2235f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -153,7 +153,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5.1.0 with: go-version-file: go.mod cache: false From f593d3c5c0eee55fdf988085d27aa991dbfd5fd6 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 13 Nov 2024 07:36:43 -0800 Subject: [PATCH 102/179] cmd/tailscale/cli: add "help" alias for --help Fixes #14053 Change-Id: I0a13e11af089f02b0656fea0d316543c67591fb5 Signed-off-by: Brad Fitzpatrick --- cmd/tailscale/cli/cli.go | 9 +++++++-- cmd/tailscale/cli/cli_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 130a11623351d..66961b2e0086d 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -93,8 +93,13 @@ func Run(args []string) (err error) { args = CleanUpArgs(args) - if len(args) == 1 && (args[0] == "-V" || args[0] == "--version") { - args = []string{"version"} + if len(args) == 1 { + switch args[0] { + case "-V", "--version": + args = []string{"version"} + case "help": + args = []string{"--help"} + } } var warnOnce sync.Once diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index 4b75486715731..0444e914c7260 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "net/netip" "reflect" "strings" @@ -1480,3 +1481,33 @@ func TestParseNLArgs(t *testing.T) { }) } } + +func TestHelpAlias(t *testing.T) { + var stdout, stderr bytes.Buffer + tstest.Replace[io.Writer](t, &Stdout, &stdout) + tstest.Replace[io.Writer](t, &Stderr, &stderr) + + gotExit0 := false + defer func() { + if !gotExit0 { + t.Error("expected os.Exit(0) to be called") + return + } + if !strings.Contains(stderr.String(), "SUBCOMMANDS") { + t.Errorf("expected help output to contain SUBCOMMANDS; got stderr=%q; stdout=%q", stderr.String(), stdout.String()) + } + }() + defer func() { + if e := recover(); e != nil { + if strings.Contains(fmt.Sprint(e), "unexpected call to os.Exit(0)") { + gotExit0 = true + } else { + t.Errorf("unexpected panic: %v", e) + } + } + }() + err := Run([]string{"help"}) + if err != nil { + t.Fatalf("Run: %v", err) + } +} From e73cfd9700095406d7263c855bc47801f7e0a2da Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 14 Nov 2024 09:50:42 -0800 Subject: [PATCH 103/179] go.toolchain.rev: bump from Go 1.23.1 to Go 1.23.3 Updates #14100 Change-Id: I57f9d4260be15ce1daebe4a9782910aba3fb9dc9 Signed-off-by: Brad Fitzpatrick --- go.toolchain.rev | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.toolchain.rev b/go.toolchain.rev index 5d87594c25a31..500d853e5e4bd 100644 --- a/go.toolchain.rev +++ b/go.toolchain.rev @@ -1 +1 @@ -bf15628b759344c6fc7763795a405ba65b8be5d7 +96578f73d04e1a231fa2a495ad3fa97747785bc6 From 8fd471ce5748d2129dba584b4fa14b0d29229299 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 14 Nov 2024 09:44:16 -0800 Subject: [PATCH 104/179] control/controlclient: disable https on for http://localhost:$port URLs Previously we required the program to be running in a test or have TS_CONTROL_IS_PLAINTEXT_HTTP before we disabled its https fallback on "http" schema control URLs to localhost with ports. But nobody accidentally does all three of "http", explicit port number, localhost and doesn't mean it. And when they mean it, they're testing a localhost dev control server (like I was) and don't want 443 getting involved. As of the changes for #13597, this became more annoying in that we were trying to use a port which wasn't even available. Updates #13597 Change-Id: Icd00bca56043d2da58ab31de7aa05a3b269c490f Signed-off-by: Brad Fitzpatrick --- control/controlclient/noise.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 3994af056fc3b..2e7c70fd1b162 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -17,7 +17,6 @@ import ( "golang.org/x/net/http2" "tailscale.com/control/controlhttp" - "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/internal/noiseconn" "tailscale.com/net/dnscache" @@ -30,7 +29,6 @@ import ( "tailscale.com/util/mak" "tailscale.com/util/multierr" "tailscale.com/util/singleflight" - "tailscale.com/util/testenv" ) // NoiseClient provides a http.Client to connect to tailcontrol over @@ -107,11 +105,6 @@ type NoiseOpts struct { DialPlan func() *tailcfg.ControlDialPlan } -// controlIsPlaintext is whether we should assume that the controlplane is only accessible -// over plaintext HTTP (as the first hop, before the ts2021 encryption begins). -// This is used by some tests which don't have a real TLS certificate. -var controlIsPlaintext = envknob.RegisterBool("TS_CONTROL_IS_PLAINTEXT_HTTP") - // NewNoiseClient returns a new noiseClient for the provided server and machine key. // serverURL is of the form https://: (no trailing slash). // @@ -129,7 +122,7 @@ func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) { if u.Scheme == "http" { httpPort = port httpsPort = "443" - if (testenv.InTest() || controlIsPlaintext()) && (u.Hostname() == "127.0.0.1" || u.Hostname() == "localhost") { + if u.Hostname() == "127.0.0.1" || u.Hostname() == "localhost" { httpsPort = "" } } else { From c3c4c05331ca13a7a159e5b6307fd72a6d2d3a00 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 15 Nov 2024 07:12:56 -0800 Subject: [PATCH 105/179] tstest/integration/testcontrol: remove a vestigial unused parameter Back in the day this testcontrol package only spoke the nacl-boxed-based control protocol, which used this. Then we added ts2021, which didn't, but still sometimes used it. Then we removed the old mode and didn't remove this parameter in 2409661a0da956. Updates #11585 Change-Id: Ifd290bd7dbbb52b681b3599786437a15bc98b6a5 Signed-off-by: Brad Fitzpatrick --- tstest/integration/testcontrol/testcontrol.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 2d6a843618627..a6b2e1828b8fe 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -832,7 +832,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi w.WriteHeader(200) for { if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { s.logf("sendMapMsg of raw message: %v", err) return } @@ -864,7 +864,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi s.logf("json.Marshal: %v", err) return } - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { return } } @@ -895,7 +895,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi } break keepAliveLoop case <-keepAliveTimerCh: - if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil { + if err := s.sendMapMsg(w, compress, keepAliveMsg); err != nil { return } } @@ -1060,7 +1060,7 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo return mapResJSON, true } -func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg any) error { +func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { resBytes, err := s.encode(compress, msg) if err != nil { return err From 1355f622beca0db5794201ab8802804ab1299e2f Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Thu, 14 Nov 2024 14:21:30 -0600 Subject: [PATCH 106/179] cmd/derpprobe,prober: add ability to restrict derpprobe to a single region Updates #24522 Co-authored-by: Mario Minardi Signed-off-by: Percy Wegmann --- cmd/derpprobe/derpprobe.go | 4 ++++ prober/derp.go | 23 +++++++++++++++++++++++ prober/derp_test.go | 31 +++++++++++++++++++++++++++++-- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 5b7b77091de7f..8f04326b03980 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -29,6 +29,7 @@ var ( tlsInterval = flag.Duration("tls-interval", 15*time.Second, "TLS probe interval") bwInterval = flag.Duration("bw-interval", 0, "bandwidth probe interval (0 = no bandwidth probing)") bwSize = flag.Int64("bw-probe-size-bytes", 1_000_000, "bandwidth probe size") + regionCode = flag.String("region-code", "", "probe only this region (e.g. 'lax'); if left blank, all regions will be probed") ) func main() { @@ -47,6 +48,9 @@ func main() { if *bwInterval > 0 { opts = append(opts, prober.WithBandwidthProbing(*bwInterval, *bwSize)) } + if *regionCode != "" { + opts = append(opts, prober.WithRegion(*regionCode)) + } dp, err := prober.DERP(p, *derpMapURL, opts...) if err != nil { log.Fatal(err) diff --git a/prober/derp.go b/prober/derp.go index 0dadbe8c2fe06..b1ebc590d4f98 100644 --- a/prober/derp.go +++ b/prober/derp.go @@ -45,6 +45,9 @@ type derpProber struct { bwInterval time.Duration bwProbeSize int64 + // Optionally restrict probes to a single regionCode. + regionCode string + // Probe class for fetching & updating the DERP map. ProbeMap ProbeClass @@ -97,6 +100,14 @@ func WithTLSProbing(interval time.Duration) DERPOpt { } } +// WithRegion restricts probing to the specified region identified by its code +// (e.g. "lax"). This is case sensitive. +func WithRegion(regionCode string) DERPOpt { + return func(d *derpProber) { + d.regionCode = regionCode + } +} + // DERP creates a new derpProber. // // If derpMapURL is "local", the DERPMap is fetched via @@ -135,6 +146,10 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { defer d.Unlock() for _, region := range d.lastDERPMap.Regions { + if d.skipRegion(region) { + continue + } + for _, server := range region.Nodes { labels := Labels{ "region": region.RegionCode, @@ -316,6 +331,10 @@ func (d *derpProber) updateMap(ctx context.Context) error { d.lastDERPMapAt = time.Now() d.nodes = make(map[string]*tailcfg.DERPNode) for _, reg := range d.lastDERPMap.Regions { + if d.skipRegion(reg) { + continue + } + for _, n := range reg.Nodes { if existing, ok := d.nodes[n.Name]; ok { return fmt.Errorf("derpmap has duplicate nodes: %+v and %+v", existing, n) @@ -338,6 +357,10 @@ func (d *derpProber) ProbeUDP(ipaddr string, port int) ProbeClass { } } +func (d *derpProber) skipRegion(region *tailcfg.DERPRegion) bool { + return d.regionCode != "" && region.RegionCode != d.regionCode +} + func derpProbeUDP(ctx context.Context, ipStr string, port int) error { pc, err := net.ListenPacket("udp", ":0") if err != nil { diff --git a/prober/derp_test.go b/prober/derp_test.go index a34292a23b6f4..c084803e94f6a 100644 --- a/prober/derp_test.go +++ b/prober/derp_test.go @@ -44,6 +44,19 @@ func TestDerpProber(t *testing.T) { }, }, }, + 1: { + RegionID: 1, + RegionCode: "one", + Nodes: []*tailcfg.DERPNode{ + { + Name: "n3", + RegionID: 0, + HostName: "derpn3.tailscale.test", + IPv4: "1.1.1.1", + IPv6: "::1", + }, + }, + }, }, } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -68,6 +81,7 @@ func TestDerpProber(t *testing.T) { meshProbeFn: func(_, _ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, nodes: make(map[string]*tailcfg.DERPNode), probes: make(map[string]*Probe), + regionCode: "zero", } if err := dp.probeMapFn(context.Background()); err != nil { t.Errorf("unexpected probeMapFn() error: %s", err) @@ -84,9 +98,9 @@ func TestDerpProber(t *testing.T) { // Add one more node and check that probes got created. dm.Regions[0].Nodes = append(dm.Regions[0].Nodes, &tailcfg.DERPNode{ - Name: "n3", + Name: "n4", RegionID: 0, - HostName: "derpn3.tailscale.test", + HostName: "derpn4.tailscale.test", IPv4: "1.1.1.1", IPv6: "::1", }) @@ -113,6 +127,19 @@ func TestDerpProber(t *testing.T) { if len(dp.probes) != 4 { t.Errorf("unexpected probes: %+v", dp.probes) } + + // Stop filtering regions. + dp.regionCode = "" + if err := dp.probeMapFn(context.Background()); err != nil { + t.Errorf("unexpected probeMapFn() error: %s", err) + } + if len(dp.nodes) != 2 { + t.Errorf("unexpected nodes: %+v", dp.nodes) + } + // 6 regular probes + 2 mesh probe + if len(dp.probes) != 8 { + t.Errorf("unexpected probes: %+v", dp.probes) + } } func TestRunDerpProbeNodePair(t *testing.T) { From aefbed323f33e7e02ea87147e2264efcce39d3f6 Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Fri, 15 Nov 2024 16:14:06 -0500 Subject: [PATCH 107/179] ipn,tailcfg: add VIPService struct and c2n to fetch them from client (#14046) * ipn,tailcfg: add VIPService struct and c2n to fetch them from client Updates tailscale/corp#22743, tailscale/corp#22955 Signed-off-by: Naman Sood * more review fixes Signed-off-by: Naman Sood * don't mention PeerCapabilityServicesDestination since it's currently unused Signed-off-by: Naman Sood --------- Signed-off-by: Naman Sood --- ipn/ipnlocal/c2n.go | 9 ++++ ipn/ipnlocal/local.go | 48 +++++++++++++++++++++ ipn/ipnlocal/local_test.go | 88 ++++++++++++++++++++++++++++++++++++++ tailcfg/tailcfg.go | 29 ++++++++++++- tailcfg/tailcfg_clone.go | 1 + tailcfg/tailcfg_test.go | 11 +++++ tailcfg/tailcfg_view.go | 2 + 7 files changed, 187 insertions(+), 1 deletion(-) diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index 8380689d1f066..f3a4a3a3d2b29 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -77,6 +77,9 @@ var c2nHandlers = map[methodAndPath]c2nHandler{ // Linux netfilter. req("POST /netfilter-kind"): handleC2NSetNetfilterKind, + + // VIP services. + req("GET /vip-services"): handleC2NVIPServicesGet, } type c2nHandler func(*LocalBackend, http.ResponseWriter, *http.Request) @@ -269,6 +272,12 @@ func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.R w.WriteHeader(http.StatusNoContent) } +func handleC2NVIPServicesGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + b.logf("c2n: GET /vip-services received") + + json.NewEncoder(w).Encode(b.VIPServices()) +} + func handleC2NUpdateGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { b.logf("c2n: GET /update received") diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 493762fccab19..3c7296038233a 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -9,6 +9,7 @@ import ( "bytes" "cmp" "context" + "crypto/sha256" "encoding/base64" "encoding/json" "errors" @@ -4888,6 +4889,14 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip } hi.SSH_HostKeys = sshHostKeys + services := vipServicesFromPrefs(prefs) + if len(services) > 0 { + buf, _ := json.Marshal(services) + hi.ServicesHash = fmt.Sprintf("%02x", sha256.Sum256(buf)) + } else { + hi.ServicesHash = "" + } + // The Hostinfo.WantIngress field tells control whether this node wants to // be wired up for ingress connections. If harmless if it's accidentally // true; the actual policy is controlled in tailscaled by ServeConfig. But @@ -7485,3 +7494,42 @@ func maybeUsernameOf(actor ipnauth.Actor) string { } return username } + +// VIPServices returns the list of tailnet services that this node +// is serving as a destination for. +// The returned memory is owned by the caller. +func (b *LocalBackend) VIPServices() []*tailcfg.VIPService { + b.mu.Lock() + defer b.mu.Unlock() + return vipServicesFromPrefs(b.pm.CurrentPrefs()) +} + +func vipServicesFromPrefs(prefs ipn.PrefsView) []*tailcfg.VIPService { + // keyed by service name + var services map[string]*tailcfg.VIPService + + // TODO(naman): this envknob will be replaced with service-specific port + // information once we start storing that. + var allPortsServices []string + if env := envknob.String("TS_DEBUG_ALLPORTS_SERVICES"); env != "" { + allPortsServices = strings.Split(env, ",") + } + + for _, s := range allPortsServices { + mak.Set(&services, s, &tailcfg.VIPService{ + Name: s, + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }) + } + + for _, s := range prefs.AdvertiseServices().AsSlice() { + if services == nil || services[s] == nil { + mak.Set(&services, s, &tailcfg.VIPService{ + Name: s, + }) + } + services[s].Active = true + } + + return slices.Collect(maps.Values(services)) +} diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 6dad2dba4deeb..6d25a418fc6a8 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -30,6 +30,7 @@ import ( "tailscale.com/control/controlclient" "tailscale.com/drive" "tailscale.com/drive/driveimpl" + "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" @@ -4464,3 +4465,90 @@ func TestConfigFileReload(t *testing.T) { t.Fatalf("got %q; want %q", hn, "bar") } } + +func TestGetVIPServices(t *testing.T) { + tests := []struct { + name string + advertised []string + mapped []string + want []*tailcfg.VIPService + }{ + { + "advertised-only", + []string{"svc:abc", "svc:def"}, + []string{}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + }, + { + Name: "svc:def", + Active: true, + }, + }, + }, + { + "mapped-only", + []string{}, + []string{"svc:abc"}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + "mapped-and-advertised", + []string{"svc:abc"}, + []string{"svc:abc"}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + "mapped-and-advertised-separately", + []string{"svc:def"}, + []string{"svc:abc"}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + { + Name: "svc:def", + Active: true, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envknob.Setenv("TS_DEBUG_ALLPORTS_SERVICES", strings.Join(tt.mapped, ",")) + prefs := &ipn.Prefs{ + AdvertiseServices: tt.advertised, + } + got := vipServicesFromPrefs(prefs.View()) + slices.SortFunc(got, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name, b.Name) + }) + if !reflect.DeepEqual(tt.want, got) { + t.Logf("want:") + for _, s := range tt.want { + t.Logf("%+v", s) + } + t.Logf("got:") + for _, s := range got { + t.Logf("%+v", s) + } + t.Fail() + return + } + }) + } +} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 9e39a43364962..1b283a2fcebd2 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -150,7 +150,8 @@ type CapabilityVersion int // - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849) // - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5) // - 107: 2024-10-30: add App Connector to conffile (PR #13942) -const CurrentCapabilityVersion CapabilityVersion = 107 +// - 108: 2024-11-08: Client sends ServicesHash in Hostinfo, understands c2n GET /vip-services. +const CurrentCapabilityVersion CapabilityVersion = 108 type StableID string @@ -820,6 +821,7 @@ type Hostinfo struct { Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode AppConnector opt.Bool `json:",omitempty"` // if the client is running the app-connector service + ServicesHash string `json:",omitempty"` // opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n // Location represents geographical location data about a // Tailscale host. Location is optional and only set if @@ -830,6 +832,26 @@ type Hostinfo struct { // require changes to Hostinfo.Equal. } +// VIPService represents a service created on a tailnet from the +// perspective of a node providing that service. These services +// have an virtual IP (VIP) address pair distinct from the node's IPs. +type VIPService struct { + // Name is the name of the service, of the form `svc:dns-label`. + // See CheckServiceName for a validation func. + // Name uniquely identifies a service on a particular tailnet, + // and so also corresponds uniquely to the pair of IP addresses + // belonging to the VIP service. + Name string + + // Ports specify which ProtoPorts are made available by this node + // on the service's IPs. + Ports []ProtoPortRange + + // Active specifies whether new requests for the service should be + // sent to this node by control. + Active bool +} + // TailscaleSSHEnabled reports whether or not this node is acting as a // Tailscale SSH server. func (hi *Hostinfo) TailscaleSSHEnabled() bool { @@ -1429,6 +1451,11 @@ const ( // user groups as Kubernetes user groups. This capability is read by // peers that are Tailscale Kubernetes operator instances. PeerCapabilityKubernetes PeerCapability = "tailscale.com/cap/kubernetes" + + // PeerCapabilityServicesDestination grants a peer the ability to serve as + // a destination for a set of given VIP services, which is provided as the + // value of this key in NodeCapMap. + PeerCapabilityServicesDestination PeerCapability = "tailscale.com/cap/services-destination" ) // NodeCapMap is a map of capabilities to their optional values. It is valid for diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 61564f3f8bfd4..f4f02c01721dc 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -183,6 +183,7 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string Location *Location }{}) diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 0d06366771d6e..9f8c418a1ccf9 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -66,6 +66,7 @@ func TestHostinfoEqual(t *testing.T) { "Userspace", "UserspaceRouter", "AppConnector", + "ServicesHash", "Location", } if have := fieldsOf(reflect.TypeFor[Hostinfo]()); !reflect.DeepEqual(have, hiHandles) { @@ -240,6 +241,16 @@ func TestHostinfoEqual(t *testing.T) { &Hostinfo{AppConnector: opt.Bool("false")}, false, }, + { + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + true, + }, + { + &Hostinfo{ServicesHash: "084c799cd551dd1d8d5c5f9a5d593b2e931f5e36122ee5c793c1d08a19839cc0"}, + &Hostinfo{}, + false, + }, } for i, tt := range tests { got := tt.a.Equal(tt.b) diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index a3e19b0dcec7a..f275a6a9da5f2 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -318,6 +318,7 @@ func (v HostinfoView) Cloud() string { return v.ж.Clou func (v HostinfoView) Userspace() opt.Bool { return v.ж.Userspace } func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter } func (v HostinfoView) AppConnector() opt.Bool { return v.ж.AppConnector } +func (v HostinfoView) ServicesHash() string { return v.ж.ServicesHash } func (v HostinfoView) Location() *Location { if v.ж.Location == nil { return nil @@ -365,6 +366,7 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string Location *Location }{}) From 3b93fd9c4430332787e6d9ed6164efb63d3a9e8b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 15 Nov 2024 14:16:03 -0800 Subject: [PATCH 108/179] net/captivedetection: replace 10k log lines with ... less We see tons of logs of the form: 2024/11/15 19:57:29 netcheck: [v2] 76 available captive portal detection endpoints: [Endpoint{URL="http://192.73.240.161/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.240.121/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.240.132/generate_204", StatusCode=204, ExpectedContent="", 11:58SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://209.177.158.246/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://209.177.158.15/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://199.38.182.118/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.243.135/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.243.229/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.243.141/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.97.144/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.97.61/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.97.233/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.98.196/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.98.253/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://45.159.98.145/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://68.183.90.120/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://209.177.156.94/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.248.83/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://209.177.156.197/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://199.38.181.104/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://209.177.145.120/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://199.38.181.93/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://199.38.181.103/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.165.90/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.165.185/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.165.36/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.90.147/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.90.207/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.90.104/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://162.248.221.199/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://162.248.221.215/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://162.248.221.248/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.34.3.232/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.34.3.207/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.34.3.75/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.83.234.151/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.83.233.233/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.72.155.133/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.40.234.219/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.40.234.113/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://185.40.234.77/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.48.220/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.48.50/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.48.250/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.252.65/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.252.134/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.111.34.178/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.49.105/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.49.83/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://43.245.49.144/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.92.144/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.88.183/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.92.254/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://148.163.220.129/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://148.163.220.134/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://148.163.220.210/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.242.187/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.242.28/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.242.204/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.93.248/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.93.147/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://176.58.93.154/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://192.73.244.245/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.111.40.12/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://208.111.40.216/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://103.6.84.152/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://205.147.105.30/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://205.147.105.78/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.167.245/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.167.37/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://102.67.167.188/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://103.84.155.178/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://103.84.155.188/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://103.84.155.46/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=true, Provider=DERPMapOther} Endpoint{URL="http://controlplane.tailscale.com/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=false, Provider=Tailscale} Endpoint{URL="http://login.tailscale.com/generate_204", StatusCode=204, ExpectedContent="", SupportsTailscaleChallenge=false, Provider=Tailscale}] That can be much shorter. Also add a fast exit path to the concurrency on match. Doing 5 all at once is still pretty gratuitous, though. Updates #1634 Fixes #13019 Change-Id: Icdbb16572fca4477b0ee9882683a3ac6eb08e2f2 Signed-off-by: Brad Fitzpatrick --- net/captivedetection/captivedetection.go | 19 ++++++---- net/captivedetection/captivedetection_test.go | 37 +++++++++++++++---- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go index c6e8bca3a19a2..7d598d853349d 100644 --- a/net/captivedetection/captivedetection.go +++ b/net/captivedetection/captivedetection.go @@ -136,26 +136,31 @@ func interfaceNameDoesNotNeedCaptiveDetection(ifName string, goos string) bool { func (d *Detector) detectOnInterface(ctx context.Context, ifIndex int, endpoints []Endpoint) bool { defer d.httpClient.CloseIdleConnections() - d.logf("[v2] %d available captive portal detection endpoints: %v", len(endpoints), endpoints) + use := min(len(endpoints), 5) + endpoints = endpoints[:use] + d.logf("[v2] %d available captive portal detection endpoints; trying %v", len(endpoints), use) // We try to detect the captive portal more quickly by making requests to multiple endpoints concurrently. var wg sync.WaitGroup resultCh := make(chan bool, len(endpoints)) - for i, e := range endpoints { - if i >= 5 { - // Try a maximum of 5 endpoints, break out (returning false) if we run of attempts. - break - } + // Once any goroutine detects a captive portal, we shut down the others. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, ifIndex) if err != nil { - d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + if ctx.Err() == nil { + d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + } return } if found { + cancel() // one match is good enough resultCh <- true } }(e) diff --git a/net/captivedetection/captivedetection_test.go b/net/captivedetection/captivedetection_test.go index e74273afd922e..29a197d31f263 100644 --- a/net/captivedetection/captivedetection_test.go +++ b/net/captivedetection/captivedetection_test.go @@ -7,10 +7,12 @@ import ( "context" "runtime" "sync" + "sync/atomic" "testing" - "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/net/netmon" + "tailscale.com/syncs" + "tailscale.com/tstest/nettest" ) func TestAvailableEndpointsAlwaysAtLeastTwo(t *testing.T) { @@ -36,25 +38,46 @@ func TestDetectCaptivePortalReturnsFalse(t *testing.T) { } } -func TestAllEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13019") +func TestEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { + nettest.SkipIfNoNetwork(t) + d := NewDetector(t.Logf) endpoints := availableEndpoints(nil, 0, t.Logf, runtime.GOOS) + t.Logf("testing %d endpoints", len(endpoints)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var good atomic.Bool var wg sync.WaitGroup + sem := syncs.NewSemaphore(5) for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() - found, err := d.verifyCaptivePortalEndpoint(context.Background(), endpoint, 0) - if err != nil { - t.Errorf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + + if !sem.AcquireContext(ctx) { + return + } + defer sem.Release() + + found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, 0) + if err != nil && ctx.Err() == nil { + t.Logf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) } if found { - t.Errorf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + t.Logf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + return } + good.Store(true) + t.Logf("endpoint good: %v", endpoint) + cancel() }(e) } wg.Wait() + + if !good.Load() { + t.Errorf("no good endpoints found") + } } From f1e1048977b848c8ad8882d77b73e4dd25b1c3f9 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 12 Nov 2024 17:52:31 -0800 Subject: [PATCH 109/179] go.mod: bump tailscale/wireguard-go Updates #11899 Change-Id: Ibd75134a20798c84c7174ba3af639cf22836c7d7 Signed-off-by: Brad Fitzpatrick --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index b5451ab613663..92ba6b9c7b54d 100644 --- a/go.mod +++ b/go.mod @@ -85,7 +85,7 @@ require ( github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 - github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc + github.com/tailscale/wireguard-go v0.0.0-20241113014420-4e883d38c8d3 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e github.com/tc-hib/winres v0.2.1 github.com/tcnksm/go-httpstat v0.2.0 diff --git a/go.sum b/go.sum index 55aa3b5357ff9..fadfb22b1a0c8 100644 --- a/go.sum +++ b/go.sum @@ -941,8 +941,8 @@ github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:t github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc h1:cezaQN9pvKVaw56Ma5qr/G646uKIYP0yQf+OyWN/okc= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20241113014420-4e883d38c8d3 h1:dmoPb3dG27tZgMtrvqfD/LW4w7gA6BSWl8prCPNmkCQ= +github.com/tailscale/wireguard-go v0.0.0-20241113014420-4e883d38c8d3/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= From 5cae7c51bfaaf1adbc645580e48fc55caac9e1c0 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 16 Nov 2024 15:25:51 -0800 Subject: [PATCH 110/179] ipn: remove unused Notify.BackendLogID Updates #14129 Change-Id: I13b5df8765e786a4a919d6b2e72afe987000b2d1 Signed-off-by: Brad Fitzpatrick --- ipn/backend.go | 4 ---- ipn/ipnlocal/local.go | 5 +---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/ipn/backend.go b/ipn/backend.go index 76ad1910bf14c..5779727fef98a 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -100,7 +100,6 @@ type Notify struct { NetMap *netmap.NetworkMap // if non-nil, the new or current netmap Engine *EngineStatus // if non-nil, the new or current wireguard stats BrowseToURL *string // if non-nil, UI should open a browser right now - BackendLogID *string // if non-nil, the public logtail ID used by backend // FilesWaiting if non-nil means that files are buffered in // the Tailscale daemon and ready for local transfer to the @@ -173,9 +172,6 @@ func (n Notify) String() string { if n.BrowseToURL != nil { sb.WriteString("URL=<...> ") } - if n.BackendLogID != nil { - sb.WriteString("BackendLogID ") - } if n.FilesWaiting != nil { sb.WriteString("FilesWaiting ") } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 3c7296038233a..33025ed40cbe1 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -2157,10 +2157,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { blid := b.backendLogID.String() b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) - b.sendToLocked(ipn.Notify{ - BackendLogID: &blid, - Prefs: &prefs, - }, allClients) + b.sendToLocked(ipn.Notify{Prefs: &prefs}, allClients) if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) { // If we know that we're either logged in or meant to be From c2a7f17f2b378897f4545ad6f43891f150423487 Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Mon, 18 Nov 2024 09:55:54 -0800 Subject: [PATCH 111/179] sessionrecording: implement v2 recording endpoint support (#14105) The v2 endpoint supports HTTP/2 bidirectional streaming and acks for received bytes. This is used to detect when a recorder disappears to more quickly terminate the session. Updates https://github.com/tailscale/corp/issues/24023 Signed-off-by: Andrew Lytvynov --- k8s-operator/sessionrecording/hijacker.go | 2 +- .../sessionrecording/hijacker_test.go | 4 +- sessionrecording/connect.go | 320 ++++++++++++++---- sessionrecording/connect_test.go | 189 +++++++++++ ssh/tailssh/tailssh.go | 13 +- ssh/tailssh/tailssh_test.go | 61 ++-- 6 files changed, 500 insertions(+), 89 deletions(-) create mode 100644 sessionrecording/connect_test.go diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index f8ef951d415f0..43aa14e613887 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -102,7 +102,7 @@ type Hijacker struct { // connection succeeds. In case of success, returns a list with a single // successful recording attempt and an error channel. If the connection errors // after having been established, an error is sent down the channel. -type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) +type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) // Hijack hijacks a 'kubectl exec' session and configures for the session // contents to be sent to a recorder. diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index 440d9c94294c9..e166ce63b3c85 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/netip" "net/url" @@ -20,6 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/fakes" + "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) { h := &Hijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, - func(context.Context, string, string) (net.Conn, error), + sessionrecording.DialFunc, ) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") diff --git a/sessionrecording/connect.go b/sessionrecording/connect.go index db966ba2cdee2..94761393f885d 100644 --- a/sessionrecording/connect.go +++ b/sessionrecording/connect.go @@ -7,6 +7,8 @@ package sessionrecording import ( "context" + "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -14,12 +16,33 @@ import ( "net/http" "net/http/httptrace" "net/netip" + "sync/atomic" "time" + "golang.org/x/net/http2" "tailscale.com/tailcfg" + "tailscale.com/util/httpm" "tailscale.com/util/multierr" ) +const ( + // Timeout for an individual DialFunc call for a single recorder address. + perDialAttemptTimeout = 5 * time.Second + // Timeout for the V2 API HEAD probe request (supportsV2). + http2ProbeTimeout = 10 * time.Second + // Maximum timeout for trying all available recorders, including V2 API + // probes and dial attempts. + allDialAttemptsTimeout = 30 * time.Second +) + +// uploadAckWindow is the period of time to wait for an ackFrame from recorder +// before terminating the connection. This is a variable to allow overriding it +// in tests. +var uploadAckWindow = 30 * time.Second + +// DialFunc is a function for dialing the recorder. +type DialFunc func(ctx context.Context, network, host string) (net.Conn, error) + // ConnectToRecorder connects to the recorder at any of the provided addresses. // It returns the first successful response, or a multierr if all attempts fail. // @@ -32,19 +55,15 @@ import ( // attempts are in order the recorder(s) was attempted. If successful a // successful connection is made, the last attempt in the slice is the // attempt for connected recorder. -func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { +func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { if len(recs) == 0 { return nil, nil, nil, errors.New("no recorders configured") } // We use a special context for dialing the recorder, so that we can // limit the time we spend dialing to 30 seconds and still have an // unbounded context for the upload. - dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout) defer dialCancel() - hc, err := SessionRecordingClientForDialer(dialCtx, dial) - if err != nil { - return nil, nil, nil, err - } var errs []error var attempts []*tailcfg.SSHRecordingAttempt @@ -54,74 +73,230 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con } attempts = append(attempts, attempt) - // We dial the recorder and wait for it to send a 100-continue - // response before returning from this function. This ensures that - // the recorder is ready to accept the recording. - - // got100 is closed when we receive the 100-continue response. - got100 := make(chan struct{}) - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - Got100Continue: func() { - close(got100) - }, - }) - - pr, pw := io.Pipe() - req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr) + var pw io.WriteCloser + var errChan <-chan error + var err error + hc := clientHTTP2(dialCtx, dial) + // We need to probe V2 support using a separate HEAD request. Sending + // an HTTP/2 POST request to a HTTP/1 server will just "hang" until the + // request body is closed (instead of returning a 404 as one would + // expect). Sending a HEAD request without a body does not have that + // problem. + if supportsV2(ctx, hc, ap) { + pw, errChan, err = connectV2(ctx, hc, ap) + } else { + pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap) + } if err != nil { - err = fmt.Errorf("recording: error starting recording: %w", err) + err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err) attempt.FailureMessage = err.Error() errs = append(errs, err) continue } - // We set the Expect header to 100-continue, so that the recorder - // will send a 100-continue response before it starts reading the - // request body. - req.Header.Set("Expect", "100-continue") + return pw, attempts, errChan, nil + } + return nil, attempts, nil, multierr.New(errs...) +} - // errChan is used to indicate the result of the request. - errChan := make(chan error, 1) - go func() { - resp, err := hc.Do(req) - if err != nil { - errChan <- fmt.Errorf("recording: error starting recording: %w", err) +// supportsV2 checks whether a recorder instance supports the /v2/record +// endpoint. +func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool { + ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil) + if err != nil { + return false + } + resp, err := hc.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1 +} + +// connectV1 connects to the legacy /record endpoint on the recorder. It is +// used for backwards-compatibility with older tsrecorder instances. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + // We dial the recorder and wait for it to send a 100-continue + // response before returning from this function. This ensures that + // the recorder is ready to accept the recording. + + // got100 is closed when we receive the 100-continue response. + got100 := make(chan struct{}) + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got100Continue: func() { + close(got100) + }, + }) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr) + if err != nil { + return nil, nil, err + } + // We set the Expect header to 100-continue, so that the recorder + // will send a 100-continue response before it starts reading the + // request body. + req.Header.Set("Expect", "100-continue") + + // errChan is used to indicate the result of the request. + errChan := make(chan error, 1) + go func() { + defer close(errChan) + resp, err := hc.Do(req) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + return + } + }() + select { + case <-got100: + return pw, errChan, nil + case err := <-errChan: + // If we get an error before we get the 100-continue response, + // we need to try another recorder. + if err == nil { + // If the error is nil, we got a 200 response, which + // is unexpected as we haven't sent any data yet. + err = errors.New("recording: unexpected EOF") + } + return nil, nil, err + } +} + +// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2. +// It explicitly tracks ack frames sent in the response and terminates the +// connection if sent recording data is un-acked for uploadAckWindow. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + pr, pw := io.Pipe() + upload := &readCounter{r: pr} + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload) + if err != nil { + return nil, nil, err + } + + // With HTTP/2, hc.Do will not block while the request body is being sent. + // It will return immediately and allow us to consume the response body at + // the same time. + resp, err := hc.Do(req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status) + } + + errChan := make(chan error, 1) + acks := make(chan int64) + // Read acks from the response and send them to the acks channel. + go func() { + defer close(errChan) + defer close(acks) + defer resp.Body.Close() + defer pw.Close() + dec := json.NewDecoder(resp.Body) + for { + var frame v2ResponseFrame + if err := dec.Decode(&frame); err != nil { + if !errors.Is(err, io.EOF) { + errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err) + } return } - if resp.StatusCode != 200 { - errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + if frame.Error != "" { + errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error) return } - errChan <- nil - }() - select { - case <-got100: - case err := <-errChan: - // If we get an error before we get the 100-continue response, - // we need to try another recorder. - if err == nil { - // If the error is nil, we got a 200 response, which - // is unexpected as we haven't sent any data yet. - err = errors.New("recording: unexpected EOF") + select { + case acks <- frame.Ack: + case <-ctx.Done(): + return } - attempt.FailureMessage = err.Error() - errs = append(errs, err) - continue // try the next recorder } - return pw, attempts, errChan, nil - } - return nil, attempts, nil, multierr.New(errs...) + }() + // Track acks from the acks channel. + go func() { + // Hack for tests: some tests modify uploadAckWindow and reset it when + // the test ends. This can race with t.Reset call below. Making a copy + // here is a lazy workaround to not wait for this goroutine to exit in + // the test cases. + uploadAckWindow := uploadAckWindow + // This timer fires if we didn't receive an ack for too long. + t := time.NewTimer(uploadAckWindow) + defer t.Stop() + for { + select { + case <-t.C: + // Close the pipe which terminates the connection and cleans up + // other goroutines. Note that tsrecorder will send us ack + // frames even if there is no new data to ack. This helps + // detect broken recorder connection if the session is idle. + pr.CloseWithError(errNoAcks) + resp.Body.Close() + return + case _, ok := <-acks: + if !ok { + // acks channel closed means that the goroutine reading them + // finished, which means that the request has ended. + return + } + // TODO(awly): limit how far behind the received acks can be. This + // should handle scenarios where a session suddenly dumps a lot of + // output. + t.Reset(uploadAckWindow) + case <-ctx.Done(): + return + } + } + }() + + return pw, errChan, nil } -// SessionRecordingClientForDialer returns an http.Client that uses a clone of -// the provided Dialer's PeerTransport to dial connections. This is used to make -// requests to the session recording server to upload session recordings. It -// uses the provided dialCtx to dial connections, and limits a single dial to 5 -// seconds. -func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() +var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s") + +type v2ResponseFrame struct { + // Ack is the number of bytes received from the client so far. The bytes + // are not guaranteed to be durably stored yet. + Ack int64 `json:"ack,omitempty"` + // Error is an error encountered while storing the recording. Error is only + // ever set as the last frame in the response. + Error string `json:"error,omitempty"` +} +// readCounter is an io.Reader that counts how many bytes were read. +type readCounter struct { + r io.Reader + sent atomic.Int64 +} + +func (u *readCounter) Read(buf []byte) (int, error) { + n, err := u.r.Read(buf) + u.sent.Add(int64(n)) + return n, err +} + +// clientHTTP1 returns a claassic http.Client with a per-dial context. It uses +// dialCtx and adds a 5s timeout to it. +func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client { + tr := http.DefaultTransport.(*http.Transport).Clone() tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) defer cancel() go func() { select { @@ -132,7 +307,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context. }() return dial(perAttemptCtx, network, addr) } + return &http.Client{Transport: tr} +} + +// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c +// requests (HTTP/2 over plaintext). Unfortunately the same client does not +// work for HTTP/1 so we need to split these up. +func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client { return &http.Client{ - Transport: tr, - }, nil + Transport: &http2.Transport{ + // Allow "http://" scheme in URLs. + AllowHTTP: true, + // Pretend like we're using TLS, but actually use the provided + // DialFunc underneath. This is necessary to convince the transport + // to actually dial. + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) + defer cancel() + go func() { + select { + case <-perAttemptCtx.Done(): + case <-dialCtx.Done(): + cancel() + } + }() + return dial(perAttemptCtx, network, addr) + }, + }, + } } diff --git a/sessionrecording/connect_test.go b/sessionrecording/connect_test.go new file mode 100644 index 0000000000000..c0fcf6d40c617 --- /dev/null +++ b/sessionrecording/connect_test.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sessionrecording + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func TestConnectToRecorder(t *testing.T) { + tests := []struct { + desc string + http2 bool + // setup returns a recorder server mux, and a channel which sends the + // hash of the recording uploaded to it. The channel is expected to + // fire only once. + setup func(t *testing.T) (*http.ServeMux, <-chan []byte) + wantErr bool + }{ + { + desc: "v1 recorder", + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + uploadHash <- hash.Sum(nil) + }) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder", + http2: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + body := &readCounter{r: r.Body} + hash := sha256.New() + ctx, cancel := context.WithCancel(r.Context()) + go func() { + defer cancel() + if _, err := io.Copy(hash, body); err != nil { + t.Error(err) + } + }() + + // Send acks for received bytes. + tick := time.NewTicker(time.Millisecond) + defer tick.Stop() + enc := json.NewEncoder(w) + outer: + for { + select { + case <-ctx.Done(): + break outer + case <-tick.C: + if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil { + t.Errorf("writing ack frame: %v", err) + break outer + } + } + } + + uploadHash <- hash.Sum(nil) + }) + // Probing HEAD endpoint which always returns 200 OK. + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder no acks", + http2: true, + wantErr: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + // Make the client no-ack timeout quick for the test. + oldAckWindow := uploadAckWindow + uploadAckWindow = 100 * time.Millisecond + t.Cleanup(func() { uploadAckWindow = oldAckWindow }) + + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + // Consume the whole request body but don't send any acks + // back. + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + // Goes in the channel buffer, non-blocking. + uploadHash <- hash.Sum(nil) + + // Block until the parent test case ends to prevent the + // request termination. We want to exercise the ack + // tracking logic specifically. + ctx, cancel := context.WithCancel(r.Context()) + t.Cleanup(cancel) + <-ctx.Done() + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + mux, uploadHash := tt.setup(t) + + srv := httptest.NewUnstartedServer(mux) + if tt.http2 { + // Wire up h2c-compatible HTTP/2 server. This is optional + // because the v1 recorder didn't support HTTP/2 and we try to + // mimic that. + h2s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, h2s) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in server: %v", err) + } + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + + ctx := context.Background() + w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext) + if err != nil { + t.Fatalf("ConnectToRecorder: %v", err) + } + + // Send some random data and hash it to compare with the recorded + // data hash. + hash := sha256.New() + const numBytes = 1 << 20 // 1MB + if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil { + t.Fatalf("writing recording data: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("closing recording stream: %v", err) + } + if err := <-errc; err != nil && !tt.wantErr { + t.Fatalf("error from the channel: %v", err) + } else if err == nil && tt.wantErr { + t.Fatalf("did not receive expected error from the channel") + } + + if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) { + t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv) + } + }) + } +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9ade1847e6b27..7cb99c3813104 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1170,7 +1170,7 @@ func (ss *sshSession) run() { if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { - logf("stdout copy: %v, %T", err) + logf("stdout copy: %v", err) ss.cancelCtx(err) } } @@ -1520,9 +1520,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { go func() { err := <-errChan if err == nil { - // Success. - ss.logf("recording: finished uploading recording") - return + select { + case <-ss.ctx.Done(): + // Success. + ss.logf("recording: finished uploading recording") + return + default: + err = errors.New("recording upload ended before the SSH session") + } } if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 { lastAttempt := attempts[len(attempts)-1] diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 7ce0aeea3b2fa..ad9cb1e57b53d 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -33,6 +33,8 @@ import ( "time" gossh "github.com/tailscale/golang-x-crypto/ssh" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/memnet" @@ -481,10 +483,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { } var handler http.HandlerFunc - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { handler(w, r) - })) - defer recordingServer.Close() + }) s := &server{ logf: t.Logf, @@ -533,9 +534,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { { name: "upload-fails-after-starting", handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() r.Body.Read(make([]byte, 1)) time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusInternalServerError) }, sshCommand: "echo hello && sleep 1 && echo world", wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n", @@ -548,6 +550,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + s.logf = t.Logf tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup @@ -597,12 +600,12 @@ func TestMultipleRecorders(t *testing.T) { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } done := make(chan struct{}) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(done) - io.ReadAll(r.Body) w.WriteHeader(http.StatusOK) - })) - defer recordingServer.Close() + w.(http.Flusher).Flush() + io.ReadAll(r.Body) + }) badRecorder, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -610,15 +613,9 @@ func TestMultipleRecorders(t *testing.T) { badRecorderAddr := badRecorder.Addr().String() badRecorder.Close() - badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - })) - defer badRecordingServer500.Close() - - badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - })) - defer badRecordingServer200.Close() + badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) s := &server{ logf: t.Logf, @@ -630,7 +627,6 @@ func TestMultipleRecorders(t *testing.T) { Recorders: []netip.AddrPort{ netip.MustParseAddrPort(badRecorderAddr), netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()), - netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()), netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), }, OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ @@ -701,19 +697,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) { } var recording []byte ctx, cancel := context.WithTimeout(context.Background(), time.Second) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer cancel() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + var err error recording, err = io.ReadAll(r.Body) if err != nil { t.Error(err) return } - })) - defer recordingServer.Close() + }) s := &server{ - logf: logger.Discard, + logf: t.Logf, lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -1299,3 +1297,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) { t.Errorf("os/user.User has %v fields; this package assumes %v", got, want) } } + +func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) { + t.Errorf("v1 recording endpoint called") + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + mux.HandleFunc("POST /v2/record", handleRecord) + + h2s := &http2.Server{} + srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s)) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in recording server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + return srv +} From 93db50356536e89b70e5ca7650ab2abd36444fd2 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 15 Nov 2024 13:31:35 -0800 Subject: [PATCH 112/179] ipn/ipnlocal: add IPN Bus NotifyRateLimit watch bit NotifyRateLimit Limit spamming GUIs with boring updates to once in 3 seconds, unless the notification is relatively interesting and the GUI should update immediately. This is basically @barnstar's #14119 but with the logic moved to be per-watch-session (since the bit is per session), rather than globally. And this distinguishes notable Notify messages (such as state changes) and makes them send immediately. Updates tailscale/corp#24553 Change-Id: I79cac52cce85280ce351e65e76ea11e107b00b49 Signed-off-by: Brad Fitzpatrick --- cmd/tailscale/cli/debug.go | 5 + ipn/backend.go | 2 + ipn/ipnlocal/bus.go | 161 +++++++++++++++++++++++++++ ipn/ipnlocal/bus_test.go | 220 +++++++++++++++++++++++++++++++++++++ ipn/ipnlocal/local.go | 17 ++- 5 files changed, 395 insertions(+), 10 deletions(-) create mode 100644 ipn/ipnlocal/bus.go create mode 100644 ipn/ipnlocal/bus_test.go diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index fdde9ef096ae3..7f235e85c8ca7 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -213,6 +213,7 @@ var debugCmd = &ffcli.Command{ fs := newFlagSet("watch-ipn") fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") + fs.BoolVar(&watchIPNArgs.rateLimit, "rate-limit", true, "rate limit messags") fs.BoolVar(&watchIPNArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") return fs @@ -500,6 +501,7 @@ var watchIPNArgs struct { netmap bool initial bool showPrivateKey bool + rateLimit bool count int } @@ -511,6 +513,9 @@ func runWatchIPN(ctx context.Context, args []string) error { if !watchIPNArgs.showPrivateKey { mask |= ipn.NotifyNoPrivateKeys } + if watchIPNArgs.rateLimit { + mask |= ipn.NotifyRateLimit + } watcher, err := localClient.WatchIPNBus(ctx, mask) if err != nil { return err diff --git a/ipn/backend.go b/ipn/backend.go index 5779727fef98a..91a35df0d0da0 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -73,6 +73,8 @@ const ( NotifyInitialOutgoingFiles // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles NotifyInitialHealthState // if set, the first Notify message (sent immediately) will contain the current health.State of the client + + NotifyRateLimit // if set, rate limit spammy netmap updates to every few seconds ) // Notify is a communication from a backend (e.g. tailscaled) to a frontend diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go new file mode 100644 index 0000000000000..65cc2573a6bb4 --- /dev/null +++ b/ipn/ipnlocal/bus.go @@ -0,0 +1,161 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "time" + + "tailscale.com/ipn" + "tailscale.com/tstime" +) + +type rateLimitingBusSender struct { + fn func(*ipn.Notify) (keepGoing bool) + lastFlush time.Time // last call to fn, or zero value if none + interval time.Duration // 0 to flush immediately; non-zero to rate limit sends + clock tstime.DefaultClock // non-nil for testing + didSendTestHook func() // non-nil for testing + + // pending, if non-nil, is the pending notification that we + // haven't sent yet. We own this memory to mutate. + pending *ipn.Notify + + // flushTimer is non-nil if the timer is armed. + flushTimer tstime.TimerController // effectively a *time.Timer + flushTimerC <-chan time.Time // ... said ~Timer's C chan +} + +func (s *rateLimitingBusSender) close() { + if s.flushTimer != nil { + s.flushTimer.Stop() + } +} + +func (s *rateLimitingBusSender) flushChan() <-chan time.Time { + return s.flushTimerC +} + +func (s *rateLimitingBusSender) flush() (keepGoing bool) { + if n := s.pending; n != nil { + s.pending = nil + return s.flushNotify(n) + } + return true +} + +func (s *rateLimitingBusSender) flushNotify(n *ipn.Notify) (keepGoing bool) { + s.lastFlush = s.clock.Now() + return s.fn(n) +} + +// send conditionally sends n to the underlying fn, possibly rate +// limiting it, depending on whether s.interval is set, and whether +// n is a notable notification that the client (typically a GUI) would +// want to act on (render) immediately. +// +// It returns whether the caller should keep looping. +// +// The passed-in memory 'n' is owned by the caller and should +// not be mutated. +func (s *rateLimitingBusSender) send(n *ipn.Notify) (keepGoing bool) { + if s.interval <= 0 { + // No rate limiting case. + return s.fn(n) + } + if isNotableNotify(n) { + // Notable notifications are always sent immediately. + // But first send any boring one that was pending. + // TODO(bradfitz): there might be a boring one pending + // with a NetMap or Engine field that is redundant + // with the new one (n) with NetMap or Engine populated. + // We should clear the pending one's NetMap/Engine in + // that case. Or really, merge the two, but mergeBoringNotifies + // only handles the case of both sides being boring. + // So for now, flush both. + if !s.flush() { + return false + } + return s.flushNotify(n) + } + s.pending = mergeBoringNotifies(s.pending, n) + d := s.clock.Now().Sub(s.lastFlush) + if d > s.interval { + return s.flush() + } + nextFlushIn := s.interval - d + if s.flushTimer == nil { + s.flushTimer, s.flushTimerC = s.clock.NewTimer(nextFlushIn) + } else { + s.flushTimer.Reset(nextFlushIn) + } + return true +} + +func (s *rateLimitingBusSender) Run(ctx context.Context, ch <-chan *ipn.Notify) { + for { + select { + case <-ctx.Done(): + return + case n, ok := <-ch: + if !ok { + return + } + if !s.send(n) { + return + } + if f := s.didSendTestHook; f != nil { + f() + } + case <-s.flushChan(): + if !s.flush() { + return + } + } + } +} + +// mergeBoringNotify merges new notify 'src' into possibly-nil 'dst', +// either mutating 'dst' or allocating a new one if 'dst' is nil, +// returning the merged result. +// +// dst and src must both be "boring" (i.e. not notable per isNotifiableNotify). +func mergeBoringNotifies(dst, src *ipn.Notify) *ipn.Notify { + if dst == nil { + dst = &ipn.Notify{Version: src.Version} + } + if src.NetMap != nil { + dst.NetMap = src.NetMap + } + if src.Engine != nil { + dst.Engine = src.Engine + } + return dst +} + +// isNotableNotify reports whether n is a "notable" notification that +// should be sent on the IPN bus immediately (e.g. to GUIs) without +// rate limiting it for a few seconds. +// +// It effectively reports whether n contains any field set that's +// not NetMap or Engine. +func isNotableNotify(n *ipn.Notify) bool { + if n == nil { + return false + } + return n.State != nil || + n.SessionID != "" || + n.BackendLogID != nil || + n.BrowseToURL != nil || + n.LocalTCPPort != nil || + n.ClientVersion != nil || + n.Prefs != nil || + n.ErrMessage != nil || + n.LoginFinished != nil || + !n.DriveShares.IsNil() || + n.Health != nil || + len(n.IncomingFiles) > 0 || + len(n.OutgoingFiles) > 0 || + n.FilesWaiting != nil +} diff --git a/ipn/ipnlocal/bus_test.go b/ipn/ipnlocal/bus_test.go new file mode 100644 index 0000000000000..5c75ac54d688d --- /dev/null +++ b/ipn/ipnlocal/bus_test.go @@ -0,0 +1,220 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "reflect" + "slices" + "testing" + "time" + + "tailscale.com/drive" + "tailscale.com/ipn" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/views" +) + +func TestIsNotableNotify(t *testing.T) { + tests := []struct { + name string + notify *ipn.Notify + want bool + }{ + {"nil", nil, false}, + {"empty", &ipn.Notify{}, false}, + {"version", &ipn.Notify{Version: "foo"}, false}, + {"netmap", &ipn.Notify{NetMap: new(netmap.NetworkMap)}, false}, + {"engine", &ipn.Notify{Engine: new(ipn.EngineStatus)}, false}, + } + + // Then for all other fields, assume they're notable. + // We use reflect to catch fields that might be added in the future without + // remembering to update the [isNotableNotify] function. + rt := reflect.TypeFor[ipn.Notify]() + for i := range rt.NumField() { + n := &ipn.Notify{} + sf := rt.Field(i) + switch sf.Name { + case "_", "NetMap", "Engine", "Version": + // Already covered above or not applicable. + continue + case "DriveShares": + n.DriveShares = views.SliceOfViews[*drive.Share, drive.ShareView](make([]*drive.Share, 1)) + default: + rf := reflect.ValueOf(n).Elem().Field(i) + switch rf.Kind() { + case reflect.Pointer: + rf.Set(reflect.New(rf.Type().Elem())) + case reflect.String: + rf.SetString("foo") + case reflect.Slice: + rf.Set(reflect.MakeSlice(rf.Type(), 1, 1)) + default: + t.Errorf("unhandled field kind %v for %q", rf.Kind(), sf.Name) + } + } + + tests = append(tests, struct { + name string + notify *ipn.Notify + want bool + }{ + name: "field-" + rt.Field(i).Name, + notify: n, + want: true, + }) + } + + for _, tt := range tests { + if got := isNotableNotify(tt.notify); got != tt.want { + t.Errorf("%v: got %v; want %v", tt.name, got, tt.want) + } + } +} + +type rateLimitingBusSenderTester struct { + tb testing.TB + got []*ipn.Notify + clock *tstest.Clock + s *rateLimitingBusSender +} + +func (st *rateLimitingBusSenderTester) init() { + if st.s != nil { + return + } + st.clock = tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(1731777537, 0), // time I wrote this test :) + }) + st.s = &rateLimitingBusSender{ + clock: tstime.DefaultClock{Clock: st.clock}, + fn: func(n *ipn.Notify) bool { + st.got = append(st.got, n) + return true + }, + } +} + +func (st *rateLimitingBusSenderTester) send(n *ipn.Notify) { + st.tb.Helper() + st.init() + if !st.s.send(n) { + st.tb.Fatal("unexpected send failed") + } +} + +func (st *rateLimitingBusSenderTester) advance(d time.Duration) { + st.tb.Helper() + st.clock.Advance(d) + select { + case <-st.s.flushChan(): + if !st.s.flush() { + st.tb.Fatal("unexpected flush failed") + } + default: + } +} + +func TestRateLimitingBusSender(t *testing.T) { + nm1 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + nm2 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + eng1 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + eng2 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + + t.Run("unbuffered", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if !slices.Equal(st.got, []*ipn.Notify{nm1, nm2, eng1, eng2}) { + t.Errorf("got %d items; want 4 specific ones, unmodified", len(st.got)) + } + }) + + t.Run("buffered", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.send(&ipn.Notify{Version: "initial"}) + if len(st.got) != 1 { + t.Fatalf("got %d items; expected 1 (first to flush immediately)", len(st.got)) + } + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if len(st.got) != 1 { + if len(st.got) != 1 { + t.Fatalf("got %d items; expected still just that first 1", len(st.got)) + } + } + + // But moving the clock should flush the rest, collasced into one new one. + st.advance(5 * time.Second) + if len(st.got) != 2 { + t.Fatalf("got %d items; want 2", len(st.got)) + } + gotn := st.got[1] + if gotn.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", gotn.NetMap) + } + if gotn.Engine != eng2.Engine { + t.Errorf("got wrong Engine; got %p", gotn.Engine) + } + if t.Failed() { + t.Logf("failed Notify was: %v", logger.AsJSON(gotn)) + } + }) + + // Test the Run method + t.Run("run", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.s.lastFlush = st.clock.Now() // pretend we just flushed + + flushc := make(chan *ipn.Notify, 1) + st.s.fn = func(n *ipn.Notify) bool { + flushc <- n + return true + } + didSend := make(chan bool, 2) + st.s.didSendTestHook = func() { didSend <- true } + waitSend := func() { + select { + case <-didSend: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for call to send") + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + incoming := make(chan *ipn.Notify, 2) + go func() { + incoming <- nm1 + waitSend() + incoming <- nm2 + waitSend() + st.advance(5 * time.Second) + select { + case n := <-flushc: + if n.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", n.NetMap) + } + case <-time.After(10 * time.Second): + t.Error("timeout") + } + cancel() + }() + + st.s.Run(ctx, incoming) + }) +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 33025ed40cbe1..cbbea32aa8363 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -2780,20 +2780,17 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A go b.pollRequestEngineStatus(ctx) } - // TODO(marwan-at-work): check err // TODO(marwan-at-work): streaming background logs? defer b.DeleteForegroundSession(sessionID) - for { - select { - case <-ctx.Done(): - return - case n := <-ch: - if !fn(n) { - return - } - } + sender := &rateLimitingBusSender{fn: fn} + defer sender.close() + + if mask&ipn.NotifyRateLimit != 0 { + sender.interval = 3 * time.Second } + + sender.Run(ctx, ch) } // pollRequestEngineStatus calls b.e.RequestStatus every 2 seconds until ctx From da70a84a4babe00c2f07cb063e18098b795d6249 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 18 Nov 2024 12:04:12 -0800 Subject: [PATCH 113/179] ipn/ipnlocal: fix build, remove another Notify.BackendLogID reference that crept in I merged 5cae7c51bfa (removing Notify.BackendLogID) and 93db50356536e (adding another reference to Notify.BackendLogID) that didn't have merge conflicts, but didn't compile together. This removes the new reference, fixing the build. Updates #14129 Change-Id: I9bb68efd977342ea8822e525d656817235039a66 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/bus.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go index 65cc2573a6bb4..111a877d849d8 100644 --- a/ipn/ipnlocal/bus.go +++ b/ipn/ipnlocal/bus.go @@ -146,7 +146,6 @@ func isNotableNotify(n *ipn.Notify) bool { } return n.State != nil || n.SessionID != "" || - n.BackendLogID != nil || n.BrowseToURL != nil || n.LocalTCPPort != nil || n.ClientVersion != nil || From 00517c8189569171560c073cd983164ff7735e69 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Tue, 19 Nov 2024 13:07:19 +0000 Subject: [PATCH 114/179] kube/{kubeapi,kubeclient},ipn/store/kubestore,cmd/{containerboot,k8s-operator}: emit kube store Events (#14112) Adds functionality to kube client to emit Events. Updates kube store to emit Events when tailscaled state has been loaded, updated or if any errors where encountered during those operations. This should help in cases where an error related to state loading/updating caused the Pod to crash in a loop- unlike logs of the originally failed container instance, Events associated with the Pod will still be accessible even after N restarts. Updates tailscale/tailscale#14080 Signed-off-by: Irbe Krumina --- cmd/containerboot/kube.go | 4 +- cmd/containerboot/services.go | 2 +- .../deploy/chart/templates/proxy-rbac.yaml | 3 + .../deploy/manifests/operator.yaml | 8 + cmd/k8s-operator/deploy/manifests/proxy.yaml | 8 + .../deploy/manifests/userspace-proxy.yaml | 8 + cmd/k8s-operator/proxygroup_specs.go | 24 +- cmd/k8s-operator/testutils_test.go | 4 + ipn/store/kubestore/store_kube.go | 44 ++- kube/kubeapi/api.go | 57 +++- kube/kubeclient/client.go | 289 +++++++++++++----- kube/kubeclient/client_test.go | 151 +++++++++ kube/kubeclient/fake_client.go | 6 +- 13 files changed, 506 insertions(+), 102 deletions(-) create mode 100644 kube/kubeclient/client_test.go diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 908cc01efc25a..5a726c20b33e9 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -61,7 +61,7 @@ func deleteAuthKey(ctx context.Context, secretName string) error { Path: "/data/authkey", }, } - if err := kc.JSONPatchSecret(ctx, secretName, m); err != nil { + if err := kc.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil { if s, ok := err.(*kubeapi.Status); ok && s.Code == http.StatusUnprocessableEntity { // This is kubernetes-ese for "the field you asked to // delete already doesn't exist", aka no-op. @@ -81,7 +81,7 @@ func initKubeClient(root string) { kubeclient.SetRootPathForTesting(root) } var err error - kc, err = kubeclient.New() + kc, err = kubeclient.New("tailscale-container") if err != nil { log.Fatalf("Error creating kube client: %v", err) } diff --git a/cmd/containerboot/services.go b/cmd/containerboot/services.go index 4da7286b7ca0a..aed00250d001e 100644 --- a/cmd/containerboot/services.go +++ b/cmd/containerboot/services.go @@ -389,7 +389,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta Path: fmt.Sprintf("/data/%s", egressservices.KeyEgressServices), Value: bs, } - if err := ep.kc.JSONPatchSecret(ctx, ep.stateSecret, []kubeclient.JSONPatch{patch}); err != nil { + if err := ep.kc.JSONPatchResource(ctx, ep.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { return fmt.Errorf("error patching state Secret: %w", err) } ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() diff --git a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml index 1c15c9119f971..fa552a7c7e39a 100644 --- a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml @@ -16,6 +16,9 @@ rules: - apiGroups: [""] resources: ["secrets"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] +- apiGroups: [""] + resources: ["events"] + verbs: ["create", "patch", "get"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 9d8e9faf60816..c6d7deef59dea 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -4703,6 +4703,14 @@ rules: - patch - update - watch + - apiGroups: + - "" + resources: + - events + verbs: + - create + - patch + - get --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/manifests/proxy.yaml b/cmd/k8s-operator/deploy/manifests/proxy.yaml index a79d48d73ce0f..1ad63c2653361 100644 --- a/cmd/k8s-operator/deploy/manifests/proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/proxy.yaml @@ -30,6 +30,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: capabilities: add: diff --git a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml index 46b49a57b1909..6617f6d4b52fe 100644 --- a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml @@ -24,3 +24,11 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index 27fd9ef716361..b47cb39b1e9c6 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -126,15 +126,6 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa }, }, }, - { - Name: "POD_NAME", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - // Secret is named after the pod. - FieldPath: "metadata.name", - }, - }, - }, { Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)", @@ -147,10 +138,6 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)", }, - { - Name: "TS_USERSPACE", - Value: "false", - }, { Name: "TS_INTERNAL_APP", Value: kubetypes.AppProxyGroupEgress, @@ -171,7 +158,7 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa }) } - return envs + return append(c.Env, envs...) }() return ss, nil @@ -215,6 +202,15 @@ func pgRole(pg *tsapi.ProxyGroup, namespace string) *rbacv1.Role { return secrets }(), }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "create", + "patch", + "get", + }, + }, }, } } diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index d42f1b7af89c7..084f573e5e45a 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -70,6 +70,8 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "false"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, }, @@ -228,6 +230,8 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "true"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, {Name: "TS_SERVE_CONFIG", Value: "/etc/tailscaled/serve-config"}, diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 2dcc08b6e4d1c..462e6d43425ff 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -7,6 +7,7 @@ package kubestore import ( "context" "fmt" + "log" "net" "os" "strings" @@ -19,8 +20,18 @@ import ( "tailscale.com/types/logger" ) -// TODO(irbekrm): should we bump this? should we have retries? See tailscale/tailscale#13024 -const timeout = 5 * time.Second +const ( + // timeout is the timeout for a single state update that includes calls to the API server to write or read a + // state Secret and emit an Event. + timeout = 30 * time.Second + + reasonTailscaleStateUpdated = "TailscaledStateUpdated" + reasonTailscaleStateLoaded = "TailscaleStateLoaded" + reasonTailscaleStateUpdateFailed = "TailscaleStateUpdateFailed" + reasonTailscaleStateLoadFailed = "TailscaleStateLoadFailed" + eventTypeWarning = "Warning" + eventTypeNormal = "Normal" +) // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. type Store struct { @@ -35,7 +46,7 @@ type Store struct { // New returns a new Store that persists to the named Secret. func New(_ logger.Logf, secretName string) (*Store, error) { - c, err := kubeclient.New() + c, err := kubeclient.New("tailscale-state-store") if err != nil { return nil, err } @@ -72,13 +83,22 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { // WriteState implements the StateStore interface. func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer func() { if err == nil { s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs) } + if err != nil { + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating tailscaled state update Event: %v", err) + } + } else { + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateUpdated, "Successfully updated tailscaled state Secret"); err != nil { + log.Printf("kubestore: error creating tailscaled state Event: %v", err) + } + } + cancel() }() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() secret, err := s.client.GetSecret(ctx, s.secretName) if err != nil { @@ -107,7 +127,7 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { Value: map[string][]byte{sanitizeKey(id): bs}, }, } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { + if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil { return fmt.Errorf("error patching Secret %s with a /data field: %v", s.secretName, err) } return nil @@ -119,8 +139,8 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { Value: bs, }, } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { - return fmt.Errorf("error patching Secret %s with /data/%s field", s.secretName, sanitizeKey(id)) + if err := s.client.JSONPatchResource(ctx, s.secretName, kubeclient.TypeSecrets, m); err != nil { + return fmt.Errorf("error patching Secret %s with /data/%s field: %v", s.secretName, sanitizeKey(id), err) } return nil } @@ -131,7 +151,7 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { return err } -func (s *Store) loadState() error { +func (s *Store) loadState() (err error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -140,8 +160,14 @@ func (s *Store) loadState() error { if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { return ipn.ErrStateNotExist } + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateLoadFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } return err } + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateLoaded, "Successfully loaded tailscaled state from Secret"); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } s.memory.LoadFromMap(secret.Data) return nil } diff --git a/kube/kubeapi/api.go b/kube/kubeapi/api.go index 0e42437a69a2a..a2ae8cc79f20d 100644 --- a/kube/kubeapi/api.go +++ b/kube/kubeapi/api.go @@ -7,7 +7,9 @@ // dependency size for those consumers when adding anything new here. package kubeapi -import "time" +import ( + "time" +) // Note: The API types are copied from k8s.io/api{,machinery} to not introduce a // module dependency on the Kubernetes API as it pulls in many more dependencies. @@ -151,6 +153,57 @@ type Secret struct { Data map[string][]byte `json:"data,omitempty"` } +// Event contains a subset of fields from corev1.Event. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7034 +// It is copied here to avoid having to import kube libraries. +type Event struct { + TypeMeta `json:",inline"` + ObjectMeta `json:"metadata"` + Message string `json:"message,omitempty"` + Reason string `json:"reason,omitempty"` + Source EventSource `json:"source,omitempty"` // who is emitting this Event + Type string `json:"type,omitempty"` // Normal or Warning + // InvolvedObject is the subject of the Event. `kubectl describe` will, for most object types, display any + // currently present cluster Events matching the object (but you probably want to set UID for this to work). + InvolvedObject ObjectReference `json:"involvedObject"` + Count int32 `json:"count,omitempty"` // how many times Event was observed + FirstTimestamp time.Time `json:"firstTimestamp,omitempty"` + LastTimestamp time.Time `json:"lastTimestamp,omitempty"` +} + +// EventSource includes a subset of fields from corev1.EventSource. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7007 +// It is copied here to avoid having to import kube libraries. +type EventSource struct { + // Component is the name of the component that is emitting the Event. + Component string `json:"component,omitempty"` +} + +// ObjectReference contains a subset of fields from corev1.ObjectReference. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L6902 +// It is copied here to avoid having to import kube libraries. +type ObjectReference struct { + // Kind of the referent. + // More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + // +optional + Kind string `json:"kind,omitempty"` + // Namespace of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/ + // +optional + Namespace string `json:"namespace,omitempty"` + // Name of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + // +optional + Name string `json:"name,omitempty"` + // UID of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#uids + // +optional + UID string `json:"uid,omitempty"` + // API version of the referent. + // +optional + APIVersion string `json:"apiVersion,omitempty"` +} + // Status is a return value for calls that don't return other objects. type Status struct { TypeMeta `json:",inline"` @@ -186,6 +239,6 @@ type Status struct { Code int `json:"code,omitempty"` } -func (s *Status) Error() string { +func (s Status) Error() string { return s.Message } diff --git a/kube/kubeclient/client.go b/kube/kubeclient/client.go index e8ddec75d1584..d4309448df030 100644 --- a/kube/kubeclient/client.go +++ b/kube/kubeclient/client.go @@ -23,16 +23,21 @@ import ( "net/url" "os" "path/filepath" + "strings" "sync" "time" "tailscale.com/kube/kubeapi" + "tailscale.com/tstime" "tailscale.com/util/multierr" ) const ( saPath = "/var/run/secrets/kubernetes.io/serviceaccount" defaultURL = "https://kubernetes.default.svc" + + TypeSecrets = "secrets" + typeEvents = "events" ) // rootPathForTests is set by tests to override the root path to the @@ -57,8 +62,13 @@ type Client interface { GetSecret(context.Context, string) (*kubeapi.Secret, error) UpdateSecret(context.Context, *kubeapi.Secret) error CreateSecret(context.Context, *kubeapi.Secret) error + // Event attempts to ensure an event with the specified options associated with the Pod in which we are + // currently running. This is best effort - if the client is not able to create events, this operation will be a + // no-op. If there is already an Event with the given reason for the current Pod, it will get updated (only + // count and timestamp are expected to change), else a new event will be created. + Event(_ context.Context, typ, reason, msg string) error StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error - JSONPatchSecret(context.Context, string, []JSONPatch) error + JSONPatchResource(_ context.Context, resourceName string, resourceType string, patches []JSONPatch) error CheckSecretPermissions(context.Context, string) (bool, bool, error) SetDialer(dialer func(context.Context, string, string) (net.Conn, error)) SetURL(string) @@ -66,15 +76,24 @@ type Client interface { type client struct { mu sync.Mutex + name string url string - ns string + podName string + podUID string + ns string // Pod namespace client *http.Client token string tokenExpiry time.Time + cl tstime.Clock + // hasEventsPerms is true if client can emit Events for the Pod in which it runs. If it is set to false any + // calls to Events() will be a no-op. + hasEventsPerms bool + // kubeAPIRequest sends a request to the kube API server. It can set to a fake in tests. + kubeAPIRequest kubeAPIRequestFunc } // New returns a new client -func New() (Client, error) { +func New(name string) (Client, error) { ns, err := readFile("namespace") if err != nil { return nil, err @@ -87,9 +106,11 @@ func New() (Client, error) { if ok := cp.AppendCertsFromPEM(caCert); !ok { return nil, fmt.Errorf("kube: error in creating root cert pool") } - return &client{ - url: defaultURL, - ns: string(ns), + c := &client{ + url: defaultURL, + ns: string(ns), + name: name, + cl: tstime.DefaultClock{}, client: &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -97,7 +118,10 @@ func New() (Client, error) { }, }, }, - }, nil + } + c.kubeAPIRequest = newKubeAPIRequest(c) + c.setEventPerms() + return c, nil } // SetURL sets the URL to use for the Kubernetes API. @@ -115,14 +139,14 @@ func (c *client) SetDialer(dialer func(ctx context.Context, network, addr string func (c *client) expireToken() { c.mu.Lock() defer c.mu.Unlock() - c.tokenExpiry = time.Now() + c.tokenExpiry = c.cl.Now() } func (c *client) getOrRenewToken() (string, error) { c.mu.Lock() defer c.mu.Unlock() tk, te := c.token, c.tokenExpiry - if time.Now().Before(te) { + if c.cl.Now().Before(te) { return tk, nil } @@ -131,17 +155,10 @@ func (c *client) getOrRenewToken() (string, error) { return "", err } c.token = string(tkb) - c.tokenExpiry = time.Now().Add(30 * time.Minute) + c.tokenExpiry = c.cl.Now().Add(30 * time.Minute) return c.token, nil } -func (c *client) secretURL(name string) string { - if name == "" { - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets", c.url, c.ns) - } - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets/%s", c.url, c.ns, name) -} - func getError(resp *http.Response) error { if resp.StatusCode == 200 || resp.StatusCode == 201 { // These are the only success codes returned by the Kubernetes API. @@ -161,36 +178,41 @@ func setHeader(key, value string) func(*http.Request) { } } -// doRequest performs an HTTP request to the Kubernetes API. -// If in is not nil, it is expected to be a JSON-encodable object and will be -// sent as the request body. -// If out is not nil, it is expected to be a pointer to an object that can be -// decoded from JSON. -// If the request fails with a 401, the token is expired and a new one is -// requested. -func (c *client) doRequest(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { - req, err := c.newRequest(ctx, method, url, in) - if err != nil { - return err - } - for _, opt := range opts { - opt(req) - } - resp, err := c.client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if err := getError(resp); err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { - c.expireToken() +type kubeAPIRequestFunc func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error + +// newKubeAPIRequest returns a function that can perform an HTTP request to the Kubernetes API. +func newKubeAPIRequest(c *client) kubeAPIRequestFunc { + // If in is not nil, it is expected to be a JSON-encodable object and will be + // sent as the request body. + // If out is not nil, it is expected to be a pointer to an object that can be + // decoded from JSON. + // If the request fails with a 401, the token is expired and a new one is + // requested. + f := func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { + req, err := c.newRequest(ctx, method, url, in) + if err != nil { + return err } - return err - } - if out != nil { - return json.NewDecoder(resp.Body).Decode(out) + for _, opt := range opts { + opt(req) + } + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if err := getError(resp); err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { + c.expireToken() + } + return err + } + if out != nil { + return json.NewDecoder(resp.Body).Decode(out) + } + return nil } - return nil + return f } func (c *client) newRequest(ctx context.Context, method, url string, in any) (*http.Request, error) { @@ -226,7 +248,7 @@ func (c *client) newRequest(ctx context.Context, method, url string, in any) (*h // GetSecret fetches the secret from the Kubernetes API. func (c *client) GetSecret(ctx context.Context, name string) (*kubeapi.Secret, error) { s := &kubeapi.Secret{Data: make(map[string][]byte)} - if err := c.doRequest(ctx, "GET", c.secretURL(name), nil, s); err != nil { + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, TypeSecrets), nil, s); err != nil { return nil, err } return s, nil @@ -235,16 +257,16 @@ func (c *client) GetSecret(ctx context.Context, name string) (*kubeapi.Secret, e // CreateSecret creates a secret in the Kubernetes API. func (c *client) CreateSecret(ctx context.Context, s *kubeapi.Secret) error { s.Namespace = c.ns - return c.doRequest(ctx, "POST", c.secretURL(""), s, nil) + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", TypeSecrets), s, nil) } // UpdateSecret updates a secret in the Kubernetes API. func (c *client) UpdateSecret(ctx context.Context, s *kubeapi.Secret) error { - return c.doRequest(ctx, "PUT", c.secretURL(s.Name), s, nil) + return c.kubeAPIRequest(ctx, "PUT", c.resourceURL(s.Name, TypeSecrets), s, nil) } // JSONPatch is a JSON patch operation. -// It currently (2023-03-02) only supports "add" and "remove" operations. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. // // https://tools.ietf.org/html/rfc6902 type JSONPatch struct { @@ -253,22 +275,22 @@ type JSONPatch struct { Value any `json:"value,omitempty"` } -// JSONPatchSecret updates a secret in the Kubernetes API using a JSON patch. -// It currently (2023-03-02) only supports "add" and "remove" operations. -func (c *client) JSONPatchSecret(ctx context.Context, name string, patch []JSONPatch) error { - for _, p := range patch { +// JSONPatchResource updates a resource in the Kubernetes API using a JSON patch. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. +func (c *client) JSONPatchResource(ctx context.Context, name, typ string, patches []JSONPatch) error { + for _, p := range patches { if p.Op != "remove" && p.Op != "add" && p.Op != "replace" { return fmt.Errorf("unsupported JSON patch operation: %q", p.Op) } } - return c.doRequest(ctx, "PATCH", c.secretURL(name), patch, nil, setHeader("Content-Type", "application/json-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", c.resourceURL(name, typ), patches, nil, setHeader("Content-Type", "application/json-patch+json")) } // StrategicMergePatchSecret updates a secret in the Kubernetes API using a // strategic merge patch. // If a fieldManager is provided, it will be used to track the patch. func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s *kubeapi.Secret, fieldManager string) error { - surl := c.secretURL(name) + surl := c.resourceURL(name, TypeSecrets) if fieldManager != "" { uv := url.Values{ "fieldManager": {fieldManager}, @@ -277,7 +299,66 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * } s.Namespace = c.ns s.Name = name - return c.doRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) +} + +// Event tries to ensure an Event associated with the Pod in which we are running. It is best effort - the event will be +// created if the kube client on startup was able to determine the name and UID of this Pod from POD_NAME,POD_UID env +// vars and if permissions check for event creation succeeded. Events are keyed on opts.Reason- if an Event for the +// current Pod with that reason already exists, its count and first timestamp will be updated, else a new Event will be +// created. +func (c *client) Event(ctx context.Context, typ, reason, msg string) error { + if !c.hasEventsPerms { + return nil + } + name := c.nameForEvent(reason) + ev, err := c.getEvent(ctx, name) + now := c.cl.Now() + if err != nil { + if !IsNotFoundErr(err) { + return err + } + // Event not found - create it + ev := kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: name, + Namespace: c.ns, + }, + Type: typ, + Reason: reason, + Message: msg, + Source: kubeapi.EventSource{ + Component: c.name, + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: c.podName, + Namespace: c.ns, + UID: c.podUID, + Kind: "Pod", + APIVersion: "v1", + }, + + FirstTimestamp: now, + LastTimestamp: now, + Count: 1, + } + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", typeEvents), &ev, nil) + } + // If the Event already exists, we patch its count and last timestamp. This ensures that when users run 'kubectl + // describe pod...', they see the event just once (but with a message of how many times it has appeared over + // last timestamp - first timestamp period of time). + count := ev.Count + 1 + countPatch := JSONPatch{ + Op: "replace", + Value: count, + Path: "/count", + } + tsPatch := JSONPatch{ + Op: "replace", + Value: now, + Path: "/lastTimestamp", + } + return c.JSONPatchResource(ctx, name, typeEvents, []JSONPatch{countPatch, tsPatch}) } // CheckSecretPermissions checks the secret access permissions of the current @@ -293,7 +374,7 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) (canPatch, canCreate bool, err error) { var errs []error for _, verb := range []string{"get", "update"} { - ok, err := c.checkPermission(ctx, verb, secretName) + ok, err := c.checkPermission(ctx, verb, TypeSecrets, secretName) if err != nil { log.Printf("error checking %s permission on secret %s: %v", verb, secretName, err) } else if !ok { @@ -303,12 +384,12 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) if len(errs) > 0 { return false, false, multierr.New(errs...) } - canPatch, err = c.checkPermission(ctx, "patch", secretName) + canPatch, err = c.checkPermission(ctx, "patch", TypeSecrets, secretName) if err != nil { log.Printf("error checking patch permission on secret %s: %v", secretName, err) return false, false, nil } - canCreate, err = c.checkPermission(ctx, "create", secretName) + canCreate, err = c.checkPermission(ctx, "create", TypeSecrets, secretName) if err != nil { log.Printf("error checking create permission on secret %s: %v", secretName, err) return false, false, nil @@ -316,19 +397,64 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) return canPatch, canCreate, nil } -// checkPermission reports whether the current pod has permission to use the -// given verb (e.g. get, update, patch, create) on secretName. -func (c *client) checkPermission(ctx context.Context, verb, secretName string) (bool, error) { +func IsNotFoundErr(err error) bool { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return true + } + return false +} + +// setEventPerms checks whether this client will be able to write tailscaled Events to its Pod and updates the state +// accordingly. If it determines that the client can not write Events, any subsequent calls to client.Event will be a +// no-op. +func (c *client) setEventPerms() { + name := os.Getenv("POD_NAME") + uid := os.Getenv("POD_UID") + hasPerms := false + defer func() { + c.podName = name + c.podUID = uid + c.hasEventsPerms = hasPerms + if !hasPerms { + log.Printf(`kubeclient: this client is not able to write tailscaled Events to the Pod in which it is running. + To help with future debugging you can make it able write Events by giving it get,create,patch permissions for Events in the Pod namespace + and setting POD_NAME, POD_UID env vars for the Pod.`) + } + }() + if name == "" || uid == "" { + return + } + for _, verb := range []string{"get", "create", "patch"} { + can, err := c.checkPermission(context.Background(), verb, typeEvents, "") + if err != nil { + log.Printf("kubeclient: error checking Events permissions: %v", err) + return + } + if !can { + return + } + } + hasPerms = true + return +} + +// checkPermission reports whether the current pod has permission to use the given verb (e.g. get, update, patch, +// create) on the given resource type. If name is not an empty string, will check the check will be for resource with +// the given name only. +func (c *client) checkPermission(ctx context.Context, verb, typ, name string) (bool, error) { + ra := map[string]any{ + "namespace": c.ns, + "verb": verb, + "resource": typ, + } + if name != "" { + ra["name"] = name + } sar := map[string]any{ "apiVersion": "authorization.k8s.io/v1", "kind": "SelfSubjectAccessReview", "spec": map[string]any{ - "resourceAttributes": map[string]any{ - "namespace": c.ns, - "verb": verb, - "resource": "secrets", - "name": secretName, - }, + "resourceAttributes": ra, }, } var res struct { @@ -337,15 +463,32 @@ func (c *client) checkPermission(ctx context.Context, verb, secretName string) ( } `json:"status"` } url := c.url + "/apis/authorization.k8s.io/v1/selfsubjectaccessreviews" - if err := c.doRequest(ctx, "POST", url, sar, &res); err != nil { + if err := c.kubeAPIRequest(ctx, "POST", url, sar, &res); err != nil { return false, err } return res.Status.Allowed, nil } -func IsNotFoundErr(err error) bool { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return true +// resourceURL returns a URL that can be used to interact with the given resource type and, if name is not empty string, +// the named resource of that type. +// Note that this only works for core/v1 resource types. +func (c *client) resourceURL(name, typ string) string { + if name == "" { + return fmt.Sprintf("%s/api/v1/namespaces/%s/%s", c.url, c.ns, typ) } - return false + return fmt.Sprintf("%s/api/v1/namespaces/%s/%s/%s", c.url, c.ns, typ, name) +} + +// nameForEvent returns a name for the Event that uniquely identifies Event with that reason for the current Pod. +func (c *client) nameForEvent(reason string) string { + return fmt.Sprintf("%s.%s.%s", c.podName, c.podUID, strings.ToLower(reason)) +} + +// getEvent fetches the event from the Kubernetes API. +func (c *client) getEvent(ctx context.Context, name string) (*kubeapi.Event, error) { + e := &kubeapi.Event{} + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, typeEvents), nil, e); err != nil { + return nil, err + } + return e, nil } diff --git a/kube/kubeclient/client_test.go b/kube/kubeclient/client_test.go new file mode 100644 index 0000000000000..6b5e8171c5a76 --- /dev/null +++ b/kube/kubeclient/client_test.go @@ -0,0 +1,151 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubeclient + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/kube/kubeapi" + "tailscale.com/tstest" +) + +func Test_client_Event(t *testing.T) { + cl := &tstest.Clock{} + tests := []struct { + name string + typ string + reason string + msg string + argSets []args + wantErr bool + }{ + { + name: "new_event_gets_created", + typ: "Normal", + reason: "TestReason", + msg: "TestMessage", + argSets: []args{ + { // request to GET event returns not found + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setErr: &kubeapi.Status{Code: 404}, + }, + { // sends POST request to create event + wantsMethod: "POST", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events", + wantsIn: &kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: "test-pod.test-uid.testreason", + Namespace: "test-ns", + }, + Type: "Normal", + Reason: "TestReason", + Message: "TestMessage", + Source: kubeapi.EventSource{ + Component: "test-client", + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: "test-pod", + UID: "test-uid", + Namespace: "test-ns", + APIVersion: "v1", + Kind: "Pod", + }, + FirstTimestamp: cl.Now(), + LastTimestamp: cl.Now(), + Count: 1, + }, + }, + }, + }, + { + name: "existing_event_gets_patched", + typ: "Warning", + reason: "TestReason", + msg: "TestMsg", + argSets: []args{ + { // request to GET event does not error - this is enough to assume that event exists + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setOut: []byte(`{"count":2}`), + }, + { // sends PATCH request to update the event + wantsMethod: "PATCH", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + wantsIn: []JSONPatch{ + {Op: "replace", Path: "/count", Value: int32(3)}, + {Op: "replace", Path: "/lastTimestamp", Value: cl.Now()}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &client{ + cl: cl, + name: "test-client", + podName: "test-pod", + podUID: "test-uid", + url: "test-apiserver", + ns: "test-ns", + kubeAPIRequest: fakeKubeAPIRequest(t, tt.argSets), + hasEventsPerms: true, + } + if err := c.Event(context.Background(), tt.typ, tt.reason, tt.msg); (err != nil) != tt.wantErr { + t.Errorf("client.Event() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// args is a set of values for testing a single call to client.kubeAPIRequest. +type args struct { + // wantsMethod is the expected value of 'method' arg. + wantsMethod string + // wantsURL is the expected value of 'url' arg. + wantsURL string + // wantsIn is the expected value of 'in' arg. + wantsIn any + // setOut can be set to a byte slice representing valid JSON. If set 'out' arg will get set to the unmarshalled + // JSON object. + setOut []byte + // setErr is the error that kubeAPIRequest will return. + setErr error +} + +// fakeKubeAPIRequest can be used to test that a series of calls to client.kubeAPIRequest gets called with expected +// values and to set these calls to return preconfigured values. 'argSets' should be set to a slice of expected +// arguments and should-be return values of a series of kubeAPIRequest calls. +func fakeKubeAPIRequest(t *testing.T, argSets []args) kubeAPIRequestFunc { + count := 0 + f := func(ctx context.Context, gotMethod, gotUrl string, gotIn, gotOut any, opts ...func(*http.Request)) error { + t.Helper() + if count >= len(argSets) { + t.Fatalf("unexpected call to client.kubeAPIRequest, expected %d calls, but got a %dth call", len(argSets), count+1) + } + a := argSets[count] + if gotMethod != a.wantsMethod { + t.Errorf("[%d] got method %q, wants method %q", count, gotMethod, a.wantsMethod) + } + if gotUrl != a.wantsURL { + t.Errorf("[%d] got URL %q, wants URL %q", count, gotMethod, a.wantsMethod) + } + if d := cmp.Diff(gotIn, a.wantsIn); d != "" { + t.Errorf("[%d] unexpected payload (-want + got):\n%s", count, d) + } + if len(a.setOut) != 0 { + if err := json.Unmarshal(a.setOut, gotOut); err != nil { + t.Fatalf("[%d] error unmarshalling output: %v", count, err) + } + } + count++ + return a.setErr + } + return f +} diff --git a/kube/kubeclient/fake_client.go b/kube/kubeclient/fake_client.go index 3cef3d27ee0df..5716ca31b2f4c 100644 --- a/kube/kubeclient/fake_client.go +++ b/kube/kubeclient/fake_client.go @@ -29,7 +29,11 @@ func (fc *FakeClient) SetDialer(dialer func(ctx context.Context, network, addr s func (fc *FakeClient) StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error { return nil } -func (fc *FakeClient) JSONPatchSecret(context.Context, string, []JSONPatch) error { +func (fc *FakeClient) Event(context.Context, string, string, string) error { + return nil +} + +func (fc *FakeClient) JSONPatchResource(context.Context, string, string, []JSONPatch) error { return nil } func (fc *FakeClient) UpdateSecret(context.Context, *kubeapi.Secret) error { return nil } From bb3d0cae5f7669a4d665c2c282be770b9297650d Mon Sep 17 00:00:00 2001 From: License Updater Date: Mon, 18 Nov 2024 15:02:33 +0000 Subject: [PATCH 115/179] licenses: update license notices Signed-off-by: License Updater --- licenses/apple.md | 31 +++++++++++++++---------------- licenses/tailscale.md | 12 ++++++------ licenses/windows.md | 30 +++++++++++++++--------------- 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/licenses/apple.md b/licenses/apple.md index 36c654c59c026..aae006c95ede4 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -12,24 +12,23 @@ See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.32.4/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.23/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.23/internal/endpoints/v2/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.32.4/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.0/service/internal/accept-encoding/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.0/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.0/internal/sync/singleflight/LICENSE)) - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) - - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) @@ -48,9 +47,9 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) @@ -74,12 +73,12 @@ See also the dependencies in the [Tailscale CLI][]. - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fc45aab8:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.30.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.9.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.27.0:LICENSE)) - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.20.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/tailscale.md b/licenses/tailscale.md index b1303d2a6dd8e..8f05acedcf93f 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -58,9 +58,9 @@ Some packages may only be included on certain architectures or operating systems - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/kballard/go-shellquote](https://pkg.go.dev/github.com/kballard/go-shellquote) ([MIT](https://github.com/kballard/go-shellquote/blob/95032a82bc51/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/kr/fs](https://pkg.go.dev/github.com/kr/fs) ([BSD-3-Clause](https://github.com/kr/fs/blob/v0.1.0/LICENSE)) - [github.com/mattn/go-colorable](https://pkg.go.dev/github.com/mattn/go-colorable) ([MIT](https://github.com/mattn/go-colorable/blob/v0.1.13/LICENSE)) @@ -84,7 +84,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE)) - [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/4e883d38c8d3/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md)) @@ -98,8 +98,8 @@ Some packages may only be included on certain architectures or operating systems - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE)) - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.16.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.9.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.27.0:LICENSE)) - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) diff --git a/licenses/windows.md b/licenses/windows.md index 8cef256853e56..4cb35e8de2785 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -13,22 +13,22 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/alexbrainman/sspi](https://pkg.go.dev/github.com/alexbrainman/sspi) ([BSD-3-Clause](https://github.com/alexbrainman/sspi/blob/1a75b4708caa/LICENSE)) - [github.com/apenwarr/fixconsole](https://pkg.go.dev/github.com/apenwarr/fixconsole) ([Apache-2.0](https://github.com/apenwarr/fixconsole/blob/5a9f6489cc29/LICENSE)) - [github.com/apenwarr/w32](https://pkg.go.dev/github.com/apenwarr/w32) ([BSD-3-Clause](https://github.com/apenwarr/w32/blob/aa00fece76ab/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.32.4/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.23/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.23/internal/endpoints/v2/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.32.4/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.0/service/internal/accept-encoding/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.0/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.0/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/b75a8a7d7eb0/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) @@ -44,9 +44,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) @@ -66,14 +66,14 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fc45aab8:LICENSE)) - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE)) - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.30.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.9.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.27.0:LICENSE)) - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.20.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE)) From d62baa45e646c243b0a38e71e7cf76508a1b6c76 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 19 Nov 2024 09:07:32 -0800 Subject: [PATCH 116/179] version: validate Long format on Android builds Updates #14069 Change-Id: I134a90db561dacc4b1c1c66ccadac135b5d64cf3 Signed-off-by: Brad Fitzpatrick --- version/version.go | 40 ++++++++++++++++++++++++++++++++ version/version_checkformat.go | 17 ++++++++++++++ version/version_internal_test.go | 28 ++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 version/version_checkformat.go create mode 100644 version/version_internal_test.go diff --git a/version/version.go b/version/version.go index 4b96d15eaa336..5edea22ca6df0 100644 --- a/version/version.go +++ b/version/version.go @@ -7,6 +7,7 @@ package version import ( "fmt" "runtime/debug" + "strconv" "strings" tailscaleroot "tailscale.com" @@ -169,3 +170,42 @@ func majorMinorPatch() string { ret, _, _ := strings.Cut(Short(), "-") return ret } + +func isValidLongWithTwoRepos(v string) bool { + s := strings.Split(v, "-") + if len(s) != 3 { + return false + } + hexChunk := func(s string) bool { + if len(s) < 6 { + return false + } + for i := range len(s) { + b := s[i] + if (b < '0' || b > '9') && (b < 'a' || b > 'f') { + return false + } + } + return true + } + + v, t, g := s[0], s[1], s[2] + if !strings.HasPrefix(t, "t") || !strings.HasPrefix(g, "g") || + !hexChunk(t[1:]) || !hexChunk(g[1:]) { + return false + } + nums := strings.Split(v, ".") + if len(nums) != 3 { + return false + } + for i, n := range nums { + bits := 8 + if i == 2 { + bits = 16 + } + if _, err := strconv.ParseUint(n, 10, bits); err != nil { + return false + } + } + return true +} diff --git a/version/version_checkformat.go b/version/version_checkformat.go new file mode 100644 index 0000000000000..8a24eda13f080 --- /dev/null +++ b/version/version_checkformat.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go && android + +package version + +import "fmt" + +func init() { + // For official Android builds using the tailscale_go toolchain, + // panic if the builder is screwed up we fail to stamp a valid + // version string. + if !isValidLongWithTwoRepos(Long()) { + panic(fmt.Sprintf("malformed version.Long value %q", Long())) + } +} diff --git a/version/version_internal_test.go b/version/version_internal_test.go new file mode 100644 index 0000000000000..ce6bd627042d6 --- /dev/null +++ b/version/version_internal_test.go @@ -0,0 +1,28 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import "testing" + +func TestIsValidLongWithTwoRepos(t *testing.T) { + tests := []struct { + long string + want bool + }{ + {"1.2.3-t01234abcde-g01234abcde", true}, + {"1.2.259-t01234abcde-g01234abcde", true}, // big patch version + {"1.2.3-t01234abcde", false}, // missing repo + {"1.2.3-g01234abcde", false}, // missing repo + {"1.2.3-g01234abcde", false}, // missing repo + {"-t01234abcde-g01234abcde", false}, + {"1.2.3", false}, + {"1.2.3-t01234abcde-g", false}, + {"1.2.3-t01234abcde-gERRBUILDINFO", false}, + } + for _, tt := range tests { + if got := isValidLongWithTwoRepos(tt.long); got != tt.want { + t.Errorf("IsValidLongWithTwoRepos(%q) = %v; want %v", tt.long, got, tt.want) + } + } +} From 810da91a9e3e4b2a9fe0e8aba21b10ed5cf9db34 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 19 Nov 2024 10:28:26 -0800 Subject: [PATCH 117/179] version: fix earlier test/wording mistakes Updates #14069 Change-Id: I1d2fd8a8ab6591af11bfb83748b94342a8ac718f Signed-off-by: Brad Fitzpatrick --- version/version_checkformat.go | 2 +- version/version_internal_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/version/version_checkformat.go b/version/version_checkformat.go index 8a24eda13f080..05a97d1912dbe 100644 --- a/version/version_checkformat.go +++ b/version/version_checkformat.go @@ -9,7 +9,7 @@ import "fmt" func init() { // For official Android builds using the tailscale_go toolchain, - // panic if the builder is screwed up we fail to stamp a valid + // panic if the builder is screwed up and we fail to stamp a valid // version string. if !isValidLongWithTwoRepos(Long()) { panic(fmt.Sprintf("malformed version.Long value %q", Long())) diff --git a/version/version_internal_test.go b/version/version_internal_test.go index ce6bd627042d6..19aeab44228bd 100644 --- a/version/version_internal_test.go +++ b/version/version_internal_test.go @@ -14,7 +14,6 @@ func TestIsValidLongWithTwoRepos(t *testing.T) { {"1.2.259-t01234abcde-g01234abcde", true}, // big patch version {"1.2.3-t01234abcde", false}, // missing repo {"1.2.3-g01234abcde", false}, // missing repo - {"1.2.3-g01234abcde", false}, // missing repo {"-t01234abcde-g01234abcde", false}, {"1.2.3", false}, {"1.2.3-t01234abcde-g", false}, From 48343ee6738548dd85e908ea14d5f69338123ec1 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Tue, 19 Nov 2024 10:55:58 -0700 Subject: [PATCH 118/179] util/winutil/s4u: fix token handle leak Fixes #14156 Signed-off-by: Aaron Klotz --- util/winutil/s4u/s4u_windows.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/util/winutil/s4u/s4u_windows.go b/util/winutil/s4u/s4u_windows.go index a12b4786a0d06..8926aaedc5071 100644 --- a/util/winutil/s4u/s4u_windows.go +++ b/util/winutil/s4u/s4u_windows.go @@ -17,6 +17,7 @@ import ( "slices" "strconv" "strings" + "sync" "sync/atomic" "unsafe" @@ -128,9 +129,10 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe if err != nil { return nil, err } + tokenCloseOnce := sync.OnceFunc(func() { token.Close() }) defer func() { if err != nil { - token.Close() + tokenCloseOnce() } }() @@ -162,6 +164,7 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe sessToken.Close() } }() + tokenCloseOnce() } userProfile, err := winutil.LoadUserProfile(sessToken, u) From 9f33aeb649f279412f6b7b24a61506ef37fadb47 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Mon, 11 Nov 2024 16:51:58 +0000 Subject: [PATCH 119/179] wgengine/filter: actually use the passed CapTestFunc [capver 109] Initial support for SrcCaps was added in 5ec01bf but it was not actually working without this. Updates #12542 Signed-off-by: Anton Tolchanov --- tailcfg/tailcfg.go | 5 +++-- wgengine/filter/filter.go | 21 +++++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 1b283a2fcebd2..897e8d27f7f7b 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -142,7 +142,7 @@ type CapabilityVersion int // - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers // - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information // - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT -// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542) +// - 100: 2024-06-18: Initial support for filtertype.Match.SrcCaps - actually usable in capver 109 (issue #12542) // - 101: 2024-07-01: Client supports SSH agent forwarding when handling connections with /bin/su // - 102: 2024-07-12: NodeAttrDisableMagicSockCryptoRouting support // - 103: 2024-07-24: Client supports NodeAttrDisableCaptivePortalDetection @@ -151,7 +151,8 @@ type CapabilityVersion int // - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5) // - 107: 2024-10-30: add App Connector to conffile (PR #13942) // - 108: 2024-11-08: Client sends ServicesHash in Hostinfo, understands c2n GET /vip-services. -const CurrentCapabilityVersion CapabilityVersion = 108 +// - 109: 2024-11-18: Client supports filtertype.Match.SrcCaps (issue #12542) +const CurrentCapabilityVersion CapabilityVersion = 109 type StableID string diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 56224ac5d3fbc..9e5d8a37f2b24 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -202,16 +202,17 @@ func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, } f := &Filter{ - logf: logf, - matches4: matchesFamily(matches, netip.Addr.Is4), - matches6: matchesFamily(matches, netip.Addr.Is6), - cap4: capMatchesFunc(matches, netip.Addr.Is4), - cap6: capMatchesFunc(matches, netip.Addr.Is6), - local4: ipset.FalseContainsIPFunc(), - local6: ipset.FalseContainsIPFunc(), - logIPs4: ipset.FalseContainsIPFunc(), - logIPs6: ipset.FalseContainsIPFunc(), - state: state, + logf: logf, + matches4: matchesFamily(matches, netip.Addr.Is4), + matches6: matchesFamily(matches, netip.Addr.Is6), + cap4: capMatchesFunc(matches, netip.Addr.Is4), + cap6: capMatchesFunc(matches, netip.Addr.Is6), + local4: ipset.FalseContainsIPFunc(), + local6: ipset.FalseContainsIPFunc(), + logIPs4: ipset.FalseContainsIPFunc(), + logIPs6: ipset.FalseContainsIPFunc(), + state: state, + srcIPHasCap: capTest, } if localNets != nil { p := localNets.Prefixes() From 303a4a1dfb2408e4dbe07bf4ddc66457bac85d03 Mon Sep 17 00:00:00 2001 From: James Stocker Date: Wed, 20 Nov 2024 07:43:59 +0100 Subject: [PATCH 120/179] Make the deployment of an IngressClass optional, default to true (#14153) Fixes tailscale/tailscale#14152 Signed-off-by: James Stocker jamesrstocker@gmail.com Co-authored-by: James Stocker --- cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml | 2 ++ cmd/k8s-operator/deploy/chart/values.yaml | 3 +++ 2 files changed, 5 insertions(+) diff --git a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml index 2a1fa81b42793..208d58ee10f08 100644 --- a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml @@ -1,3 +1,4 @@ +{{- if .Values.ingressClass.enabled }} apiVersion: networking.k8s.io/v1 kind: IngressClass metadata: @@ -6,3 +7,4 @@ metadata: spec: controller: tailscale.com/ts-ingress # controller name currently can not be changed # parameters: {} # currently no parameters are supported +{{- end }} diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index e6f4cada44de7..b24ba37b05360 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -54,6 +54,9 @@ operatorConfig: # - name: EXTRA_VAR2 # value: "value2" +# In the case that you already have a tailscale ingressclass in your cluster (or vcluster), you can disable the creation here +ingressClass: + enabled: true # proxyConfig contains configuraton that will be applied to any ingress/egress # proxies created by the operator. From ebeb5da202c00c41a3c87ebf687f89a2fc70bb90 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Wed, 20 Nov 2024 14:22:34 +0000 Subject: [PATCH 121/179] cmd/k8s-operator,kube/kubeclient,docs/k8s: update rbac to emit events + small fixes (#14164) This is a follow-up to #14112 where our internal kube client was updated to allow it to emit Events - this updates our sample kube manifests and tsrecorder manifest templates so they can benefit from this functionality. Updates tailscale/tailscale#14080 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/tsrecorder_specs.go | 17 +++++++++++++++++ docs/k8s/proxy.yaml | 8 ++++++++ docs/k8s/role.yaml | 3 +++ docs/k8s/sidecar.yaml | 8 ++++++++ docs/k8s/subnet.yaml | 8 ++++++++ docs/k8s/userspace-sidecar.yaml | 8 ++++++++ kube/kubeclient/client_test.go | 2 +- 7 files changed, 53 insertions(+), 1 deletion(-) diff --git a/cmd/k8s-operator/tsrecorder_specs.go b/cmd/k8s-operator/tsrecorder_specs.go index 4a74fb7e03442..4a7bf988773a6 100644 --- a/cmd/k8s-operator/tsrecorder_specs.go +++ b/cmd/k8s-operator/tsrecorder_specs.go @@ -130,6 +130,15 @@ func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { fmt.Sprintf("%s-0", tsr.Name), // Contains the node state. }, }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "get", + "create", + "patch", + }, + }, }, } } @@ -203,6 +212,14 @@ func env(tsr *tsapi.Recorder) []corev1.EnvVar { }, }, }, + { + Name: "POD_UID", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.uid", + }, + }, + }, { Name: "TS_STATE", Value: "kube:$(POD_NAME)", diff --git a/docs/k8s/proxy.yaml b/docs/k8s/proxy.yaml index 2ab7ed334395d..78e97c83b2be9 100644 --- a/docs/k8s/proxy.yaml +++ b/docs/k8s/proxy.yaml @@ -44,6 +44,14 @@ spec: value: "{{TS_DEST_IP}}" - name: TS_AUTH_ONCE value: "true" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: capabilities: add: diff --git a/docs/k8s/role.yaml b/docs/k8s/role.yaml index 6d6a8117d1bbd..d7d0846ab29a6 100644 --- a/docs/k8s/role.yaml +++ b/docs/k8s/role.yaml @@ -13,3 +13,6 @@ rules: resourceNames: ["{{TS_KUBE_SECRET}}"] resources: ["secrets"] verbs: ["get", "update", "patch"] +- apiGroups: [""] # "" indicates the core API group + resources: ["events"] + verbs: ["get", "create", "patch"] diff --git a/docs/k8s/sidecar.yaml b/docs/k8s/sidecar.yaml index 7efd32a38d0ac..6baa6d5458d49 100644 --- a/docs/k8s/sidecar.yaml +++ b/docs/k8s/sidecar.yaml @@ -26,6 +26,14 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: capabilities: add: diff --git a/docs/k8s/subnet.yaml b/docs/k8s/subnet.yaml index 4b7066fb3460a..1af146be689e6 100644 --- a/docs/k8s/subnet.yaml +++ b/docs/k8s/subnet.yaml @@ -28,6 +28,14 @@ spec: optional: true - name: TS_ROUTES value: "{{TS_ROUTES}}" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: capabilities: add: diff --git a/docs/k8s/userspace-sidecar.yaml b/docs/k8s/userspace-sidecar.yaml index fc4ed63502dbc..ee19b10a5e5dd 100644 --- a/docs/k8s/userspace-sidecar.yaml +++ b/docs/k8s/userspace-sidecar.yaml @@ -27,3 +27,11 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/kube/kubeclient/client_test.go b/kube/kubeclient/client_test.go index 6b5e8171c5a76..31878befe4106 100644 --- a/kube/kubeclient/client_test.go +++ b/kube/kubeclient/client_test.go @@ -134,7 +134,7 @@ func fakeKubeAPIRequest(t *testing.T, argSets []args) kubeAPIRequestFunc { t.Errorf("[%d] got method %q, wants method %q", count, gotMethod, a.wantsMethod) } if gotUrl != a.wantsURL { - t.Errorf("[%d] got URL %q, wants URL %q", count, gotMethod, a.wantsMethod) + t.Errorf("[%d] got URL %q, wants URL %q", count, gotUrl, a.wantsURL) } if d := cmp.Diff(gotIn, a.wantsIn); d != "" { t.Errorf("[%d] unexpected payload (-want + got):\n%s", count, d) From ebaf33a80c5872a2d1156aa3bb55f82f3ce1b97b Mon Sep 17 00:00:00 2001 From: James Scott Date: Wed, 20 Nov 2024 12:28:25 -0800 Subject: [PATCH 122/179] net/tsaddr: extract IsTailscaleIPv4 from IsTailscaleIP (#14169) Extracts tsaddr.IsTailscaleIPv4 out of tsaddr.IsTailscaleIP. This will allow for checking valid Tailscale assigned IPv4 addresses without checking IPv6 addresses. Updates #14168 Updates tailscale/corp#24620 Signed-off-by: James Scott --- net/tsaddr/tsaddr.go | 10 ++++-- net/tsaddr/tsaddr_test.go | 68 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index e7e0ba088bfd5..06e6a26ddb721 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -66,15 +66,21 @@ const ( TailscaleServiceIPv6String = "fd7a:115c:a1e0::53" ) -// IsTailscaleIP reports whether ip is an IP address in a range that +// IsTailscaleIP reports whether IP is an IP address in a range that // Tailscale assigns from. func IsTailscaleIP(ip netip.Addr) bool { if ip.Is4() { - return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) + return IsTailscaleIPv4(ip) } return TailscaleULARange().Contains(ip) } +// IsTailscaleIPv4 reports whether an IPv4 IP is an IP address that +// Tailscale assigns from. +func IsTailscaleIPv4(ip netip.Addr) bool { + return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) +} + // TailscaleULARange returns the IPv6 Unique Local Address range that // is the superset range that Tailscale assigns out of. func TailscaleULARange() netip.Prefix { diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index 4aa2f8c60f5b3..43977352b157d 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -222,3 +222,71 @@ func TestContainsExitRoute(t *testing.T) { } } } + +func TestIsTailscaleIPv4(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: false, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIPv4(tt.in); got != tt.want { + t.Errorf("IsTailscaleIPv4() = %v, want %v", got, tt.want) + } + } +} + +func TestIsTailscaleIP(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: true, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIP(tt.in); got != tt.want { + t.Errorf("IsTailscaleIP() = %v, want %v", got, tt.want) + } + } +} From 02cafbe1cadfcd82d22beb9138d4673169fcdc82 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 20 Nov 2024 14:19:59 -0800 Subject: [PATCH 123/179] tsweb: change RequestID format to have a date in it So we can locate them in logs more easily. Updates tailscale/corp#24721 Change-Id: Ia766c75608050dde7edc99835979a6e9bb328df2 Signed-off-by: Brad Fitzpatrick --- cmd/derper/depaware.txt | 6 ++---- cmd/derper/derper_test.go | 1 + cmd/stund/depaware.txt | 6 ++---- tsweb/request_id.go | 13 ++++++++----- tsweb/tsweb_test.go | 22 ++++++++++++++++++++++ 5 files changed, 35 insertions(+), 13 deletions(-) diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 81a7f14f4a71c..076074f2554a1 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -27,7 +27,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/xt from github.com/google/nftables/expr+ - github.com/google/uuid from tailscale.com/util/fastuuid github.com/hdevalence/ed25519consensus from tailscale.com/tka L github.com/josharian/native from github.com/mdlayher/netlink+ L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon @@ -152,7 +151,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa 💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/hostinfo+ - tailscale.com/util/fastuuid from tailscale.com/tsweb 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale tailscale.com/util/lineiter from tailscale.com/hostinfo+ @@ -160,6 +158,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/mak from tailscale.com/health+ tailscale.com/util/multierr from tailscale.com/health+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/rands from tailscale.com/tsweb tailscale.com/util/set from tailscale.com/derp+ tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/cmd/derper+ @@ -244,7 +243,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa crypto/tls from golang.org/x/crypto/acme+ crypto/x509 from crypto/tls+ crypto/x509/pkix from crypto/x509+ - database/sql/driver from github.com/google/uuid embed from crypto/internal/nistec+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ @@ -276,7 +274,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ - math/rand/v2 from tailscale.com/util/fastuuid+ + math/rand/v2 from internal/concurrent+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 6ddf4455b0495..08d2e9cbf97c2 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -109,6 +109,7 @@ func TestDeps(t *testing.T) { "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", "tailscale.com/net/packet": "not needed in derper", "github.com/gaissmai/bart": "not needed in derper", + "database/sql/driver": "not needed in derper", // previously came in via github.com/google/uuid }, }.Check(t) } diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index 7031b18e2087e..34a71c43e0010 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -8,7 +8,6 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ - github.com/google/uuid from tailscale.com/util/fastuuid 💣 github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ @@ -74,9 +73,9 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar tailscale.com/util/ctxkey from tailscale.com/tsweb+ L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/tailcfg - tailscale.com/util/fastuuid from tailscale.com/tsweb tailscale.com/util/lineiter from tailscale.com/version/distro tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/rands from tailscale.com/tsweb tailscale.com/util/slicesx from tailscale.com/tailcfg tailscale.com/util/vizerror from tailscale.com/tailcfg+ tailscale.com/version from tailscale.com/envknob+ @@ -133,7 +132,6 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar crypto/tls from net/http+ crypto/x509 from crypto/tls crypto/x509/pkix from crypto/x509 - database/sql/driver from github.com/google/uuid embed from crypto/internal/nistec+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ @@ -164,7 +162,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from math/big+ - math/rand/v2 from tailscale.com/util/fastuuid+ + math/rand/v2 from internal/concurrent+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart diff --git a/tsweb/request_id.go b/tsweb/request_id.go index 8516b8f72161e..46e52385240ca 100644 --- a/tsweb/request_id.go +++ b/tsweb/request_id.go @@ -6,9 +6,10 @@ package tsweb import ( "context" "net/http" + "time" "tailscale.com/util/ctxkey" - "tailscale.com/util/fastuuid" + "tailscale.com/util/rands" ) // RequestID is an opaque identifier for a HTTP request, used to correlate @@ -41,10 +42,12 @@ const RequestIDHeader = "X-Tailscale-Request-Id" // GenerateRequestID generates a new request ID with the current format. func GenerateRequestID() RequestID { - // REQ-1 indicates the version of the RequestID pattern. It is - // currently arbitrary but allows for forward compatible - // transitions if needed. - return RequestID("REQ-1" + fastuuid.NewUUID().String()) + // Return a string of the form "REQ-<...>" + // Previously we returned "REQ-1". + // Now we return "REQ-2" version, where the "2" doubles as the year 2YYY + // in a leading date. + now := time.Now().UTC() + return RequestID("REQ-" + now.Format("20060102150405") + rands.HexString(16)) } // SetRequestID is an HTTP middleware that injects a RequestID in the diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 13840c01225e3..d4c9721e97215 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -1307,6 +1307,28 @@ func TestBucket(t *testing.T) { } } +func TestGenerateRequestID(t *testing.T) { + t0 := time.Now() + got := GenerateRequestID() + t.Logf("Got: %q", got) + if !strings.HasPrefix(string(got), "REQ-2") { + t.Errorf("expect REQ-2 prefix; got %q", got) + } + const wantLen = len("REQ-2024112022140896f8ead3d3f3be27") + if len(got) != wantLen { + t.Fatalf("len = %d; want %d", len(got), wantLen) + } + d := got[len("REQ-"):][:14] + timeBack, err := time.Parse("20060102150405", string(d)) + if err != nil { + t.Fatalf("parsing time back: %v", err) + } + elapsed := timeBack.Sub(t0) + if elapsed > 3*time.Second { // allow for slow github actions runners :) + t.Fatalf("time back was %v; want within 3s", elapsed) + } +} + func ExampleMiddlewareStack() { // setHeader returns a middleware that sets header k = vs. setHeader := func(k string, vs ...string) Middleware { From 70d1241ca697a677145df84cf844f9c9cadd1bbc Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 20 Nov 2024 16:43:32 -0800 Subject: [PATCH 124/179] util/fastuuid: delete unused package Its sole user was deleted in 02cafbe1cadfc. And it has no public users: https://pkg.go.dev/tailscale.com/util/fastuuid?tab=importedby And nothing in other Tailsale repos that I can find. Updates tailscale/corp#24721 Change-Id: I8755770a255a91c6c99f596e6d10c303b3ddf213 Signed-off-by: Brad Fitzpatrick --- util/fastuuid/fastuuid.go | 56 -------------------------- util/fastuuid/fastuuid_test.go | 72 ---------------------------------- 2 files changed, 128 deletions(-) delete mode 100644 util/fastuuid/fastuuid.go delete mode 100644 util/fastuuid/fastuuid_test.go diff --git a/util/fastuuid/fastuuid.go b/util/fastuuid/fastuuid.go deleted file mode 100644 index 4b115ea4e4974..0000000000000 --- a/util/fastuuid/fastuuid.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package fastuuid implements a UUID construction using an in process CSPRNG. -package fastuuid - -import ( - crand "crypto/rand" - "encoding/binary" - "io" - "math/rand/v2" - "sync" - - "github.com/google/uuid" -) - -// NewUUID returns a new UUID using a pool of generators, good for highly -// concurrent use. -func NewUUID() uuid.UUID { - g := pool.Get().(*generator) - defer pool.Put(g) - return g.newUUID() -} - -var pool = sync.Pool{ - New: func() any { - return newGenerator() - }, -} - -type generator struct { - rng rand.ChaCha8 -} - -func seed() [32]byte { - var r [32]byte - if _, err := io.ReadFull(crand.Reader, r[:]); err != nil { - panic(err) - } - return r -} - -func newGenerator() *generator { - return &generator{ - rng: *rand.NewChaCha8(seed()), - } -} - -func (g *generator) newUUID() uuid.UUID { - var u uuid.UUID - binary.NativeEndian.PutUint64(u[:8], g.rng.Uint64()) - binary.NativeEndian.PutUint64(u[8:], g.rng.Uint64()) - u[6] = (u[6] & 0x0f) | 0x40 // Version 4 - u[8] = (u[8] & 0x3f) | 0x80 // Variant 10 - return u -} diff --git a/util/fastuuid/fastuuid_test.go b/util/fastuuid/fastuuid_test.go deleted file mode 100644 index f0d9939043850..0000000000000 --- a/util/fastuuid/fastuuid_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package fastuuid - -import ( - "testing" - - "github.com/google/uuid" -) - -func TestNewUUID(t *testing.T) { - g := pool.Get().(*generator) - defer pool.Put(g) - u := g.newUUID() - if u[6] != (u[6]&0x0f)|0x40 { - t.Errorf("version bits are incorrect") - } - if u[8] != (u[8]&0x3f)|0x80 { - t.Errorf("variant bits are incorrect") - } -} - -func BenchmarkBasic(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - for range b.N { - NewUUID() - } - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - for range b.N { - uuid.New() - } - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - for range b.N { - uuid.New() - } - }) -} - -func BenchmarkParallel(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - NewUUID() - } - }) - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) -} From af4c3a4a1baba868996bc9ed022d67ebe0320873 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 20 Nov 2024 17:48:06 -0500 Subject: [PATCH 125/179] cmd/tailscale/cli: create netmon in debug ts2021 Otherwise we'll see a panic if we hit the dnsfallback code and try to call NewDialer with a nil NetMon. Updates #14161 Signed-off-by: Andrew Dunham Change-Id: I81c6e72376599b341cb58c37134c2a948b97cf5f --- cmd/tailscale/cli/debug.go | 7 +++++++ control/controlhttp/constants.go | 2 ++ 2 files changed, 9 insertions(+) diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 7f235e85c8ca7..78bd708e54fee 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -36,6 +36,7 @@ import ( "tailscale.com/hostinfo" "tailscale.com/internal/noiseconn" "tailscale.com/ipn" + "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "tailscale.com/net/tshttpproxy" "tailscale.com/paths" @@ -850,6 +851,11 @@ func runTS2021(ctx context.Context, args []string) error { logf = log.Printf } + netMon, err := netmon.New(logger.WithPrefix(logf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating netmon: %w", err) + } + noiseDialer := &controlhttp.Dialer{ Hostname: ts2021Args.host, HTTPPort: "80", @@ -859,6 +865,7 @@ func runTS2021(ctx context.Context, args []string) error { ProtocolVersion: uint16(ts2021Args.version), Dialer: dialFunc, Logf: logf, + NetMon: netMon, } const tries = 2 for i := range tries { diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 0b550acccf866..971212d63b994 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -76,6 +76,8 @@ type Dialer struct { // dropped. Logf logger.Logf + // NetMon is the [netmon.Monitor] to use for this Dialer. It must be + // non-nil. NetMon *netmon.Monitor // HealthTracker, if non-nil, is the health tracker to use. From 0c8c7c0f901f8a5e6cefe1334f3d2e0ad4db7b69 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 20 Nov 2024 16:14:13 -0800 Subject: [PATCH 126/179] net/tsaddr: include test input in test failure output https://go.dev/wiki/CodeReviewComments#useful-test-failures (Previously it was using subtests with names including the input, but once those went away, there was no context left) Updates #14169 Change-Id: Ib217028183a3d001fe4aee58f2edb746b7b3aa88 Signed-off-by: Brad Fitzpatrick --- net/tsaddr/tsaddr_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index 43977352b157d..9ac1ce3036299 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -252,7 +252,7 @@ func TestIsTailscaleIPv4(t *testing.T) { } for _, tt := range tests { if got := IsTailscaleIPv4(tt.in); got != tt.want { - t.Errorf("IsTailscaleIPv4() = %v, want %v", got, tt.want) + t.Errorf("IsTailscaleIPv4(%v) = %v, want %v", tt.in, got, tt.want) } } } @@ -286,7 +286,7 @@ func TestIsTailscaleIP(t *testing.T) { } for _, tt := range tests { if got := IsTailscaleIP(tt.in); got != tt.want { - t.Errorf("IsTailscaleIP() = %v, want %v", got, tt.want) + t.Errorf("IsTailscaleIP(%v) = %v, want %v", tt.in, got, tt.want) } } } From e3c6ca43d3e3cad27714d07b3a9ec20141c9c65c Mon Sep 17 00:00:00 2001 From: Andrea Gottardo Date: Thu, 21 Nov 2024 12:56:41 -0800 Subject: [PATCH 127/179] cli: present risk warning when setting up app connector on macOS (#14181) --- cmd/tailscale/cli/risks.go | 13 ++++++++++--- cmd/tailscale/cli/set.go | 7 +++++++ cmd/tailscale/cli/up.go | 6 ++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go index 4cfa50d581ed4..acb50e723c585 100644 --- a/cmd/tailscale/cli/risks.go +++ b/cmd/tailscale/cli/risks.go @@ -17,11 +17,18 @@ import ( ) var ( - riskTypes []string - riskLoseSSH = registerRiskType("lose-ssh") - riskAll = registerRiskType("all") + riskTypes []string + riskLoseSSH = registerRiskType("lose-ssh") + riskMacAppConnector = registerRiskType("mac-app-connector") + riskAll = registerRiskType("all") ) +const riskMacAppConnectorMessage = ` +You are trying to configure an app connector on macOS, which is not officially supported due to system limitations. This may result in performance and reliability issues. + +Do not use a macOS app connector for any mission-critical purposes. For the best experience, Linux is the only recommended platform for app connectors. +` + func registerRiskType(riskType string) string { riskTypes = append(riskTypes, riskType) return riskType diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 2e1251f04a4b9..e8e5f0c51e15b 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -10,6 +10,7 @@ import ( "fmt" "net/netip" "os/exec" + "runtime" "strings" "github.com/peterbourgon/ff/v3/ffcli" @@ -203,6 +204,12 @@ func runSet(ctx context.Context, args []string) (retErr error) { } } + if runtime.GOOS == "darwin" && maskedPrefs.AppConnector.Advertise { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, setArgs.acceptedRisks); err != nil { + return err + } + } + if maskedPrefs.RunSSHSet { wantSSH, haveSSH := maskedPrefs.RunSSH, curPrefs.RunSSH if err := presentSSHToggleRisk(wantSSH, haveSSH, setArgs.acceptedRisks); err != nil { diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 782df407deb18..6c5c6f337f909 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -379,6 +379,12 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus return false, nil, err } + if runtime.GOOS == "darwin" && env.upArgs.advertiseConnector { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, env.upArgs.acceptedRisks); err != nil { + return false, nil, err + } + } + if env.upArgs.forceReauth && isSSHOverTailscale() { if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will result in your SSH session disconnecting.`, env.upArgs.acceptedRisks); err != nil { return false, nil, err From c59ab6baacf3ddc96982b0f6cacd683157e8bc41 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 22 Nov 2024 06:53:46 +0000 Subject: [PATCH 128/179] cmd/k8s-operator/deploy: ensure that operator can write kube state Events (#14177) A small follow-up to #14112- ensures that the operator itself can emit Events for its kube state store changes. Updates tailscale/tailscale#14080 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/deploy/chart/templates/deployment.yaml | 8 ++++++++ cmd/k8s-operator/deploy/manifests/operator.yaml | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index c428d5d1e751e..2653f21595ba7 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -81,6 +81,14 @@ spec: - name: PROXY_DEFAULT_CLASS value: {{ .Values.proxyConfig.defaultProxyClass }} {{- end }} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid {{- with .Values.operatorConfig.extraEnv }} {{- toYaml . | nindent 12 }} {{- end }} diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index c6d7deef59dea..4035afabaf4ab 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -4783,6 +4783,14 @@ spec: value: "false" - name: PROXY_FIREWALL_MODE value: auto + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid image: tailscale/k8s-operator:unstable imagePullPolicy: Always name: operator From 74d4652144f11ace04612496095d658414ab09db Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Fri, 22 Nov 2024 15:41:07 +0000 Subject: [PATCH 129/179] cmd/{containerboot,k8s-operator},k8s-operator: new options to expose user metrics (#14035) containerboot: Adds 3 new environment variables for containerboot, `TS_LOCAL_ADDR_PORT` (default `"${POD_IP}:9002"`), `TS_METRICS_ENABLED` (default `false`), and `TS_DEBUG_ADDR_PORT` (default `""`), to configure metrics and debug endpoints. In a follow-up PR, the health check endpoint will be updated to use the `TS_LOCAL_ADDR_PORT` if `TS_HEALTHCHECK_ADDR_PORT` hasn't been set. Users previously only had access to internal debug metrics (which are unstable and not recommended) via passing the `--debug` flag to tailscaled, but can now set `TS_METRICS_ENABLED=true` to expose the stable metrics documented at https://tailscale.com/kb/1482/client-metrics at `/metrics` on the addr/port specified by `TS_LOCAL_ADDR_PORT`. Users can also now configure a debug endpoint more directly via the `TS_DEBUG_ADDR_PORT` environment variable. This is not recommended for production use, but exposes an internal set of debug metrics and pprof endpoints. operator: The `ProxyClass` CRD's `.spec.metrics.enable` field now enables serving the stable user metrics documented at https://tailscale.com/kb/1482/client-metrics at `/metrics` on the same "metrics" container port that debug metrics were previously served on. To smooth the transition for anyone relying on the way the operator previously consumed this field, we also _temporarily_ serve tailscaled's internal debug metrics on the same `/debug/metrics` path as before, until 1.82.0 when debug metrics will be turned off by default even if `.spec.metrics.enable` is set. At that point, anyone who wishes to continue using the internal debug metrics (not recommended) will need to set the new `ProxyClass` field `.spec.statefulSet.pod.tailscaleContainer.debug.enable`. Users who wish to opt out of the transitional behaviour, where enabling `.spec.metrics.enable` also enables debug metrics, can set `.spec.statefulSet.pod.tailscaleContainer.debug.enable` to false (recommended). Separately but related, the operator will no longer specify a host port for the "metrics" container port definition. This caused scheduling conflicts when k8s needs to schedule more than one proxy per node, and was not necessary for allowing the pod's port to be exposed to prometheus scrapers. Updates #11292 --------- Co-authored-by: Kristoffer Dalby Signed-off-by: Tom Proctor --- cmd/containerboot/healthz.go | 2 +- cmd/containerboot/main.go | 8 ++ cmd/containerboot/metrics.go | 91 +++++++++++++++++++ cmd/containerboot/settings.go | 22 ++++- cmd/containerboot/tailscaled.go | 6 ++ .../crds/tailscale.com_proxyclasses.yaml | 45 ++++++++- .../deploy/manifests/operator.yaml | 45 ++++++++- cmd/k8s-operator/proxyclass.go | 4 + cmd/k8s-operator/proxyclass_test.go | 53 +++++++++++ cmd/k8s-operator/sts.go | 90 +++++++++++++++--- cmd/k8s-operator/sts_test.go | 74 ++++++++++++--- k8s-operator/api.md | 19 +++- .../apis/v1alpha1/types_proxyclass.go | 27 +++++- .../apis/v1alpha1/zz_generated.deepcopy.go | 20 ++++ 14 files changed, 472 insertions(+), 34 deletions(-) create mode 100644 cmd/containerboot/metrics.go diff --git a/cmd/containerboot/healthz.go b/cmd/containerboot/healthz.go index fb7fccd968816..12e7ee9f8db73 100644 --- a/cmd/containerboot/healthz.go +++ b/cmd/containerboot/healthz.go @@ -39,7 +39,7 @@ func runHealthz(addr string, h *healthz) { log.Fatalf("error listening on the provided health endpoint address %q: %v", addr, err) } mux := http.NewServeMux() - mux.Handle("/healthz", h) + mux.Handle("GET /healthz", h) log.Printf("Running healthcheck endpoint at %s/healthz", addr) hs := &http.Server{Handler: mux} diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 17131faae08b8..313e8deb0b93c 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -178,6 +178,14 @@ func main() { } defer killTailscaled() + if cfg.LocalAddrPort != "" && cfg.MetricsEnabled { + m := &metrics{ + lc: client, + debugEndpoint: cfg.DebugAddrPort, + } + runMetrics(cfg.LocalAddrPort, m) + } + if cfg.EnableForwardingOptimizations { if err := client.SetUDPGROForwarding(bootCtx); err != nil { log.Printf("[unexpected] error enabling UDP GRO forwarding: %v", err) diff --git a/cmd/containerboot/metrics.go b/cmd/containerboot/metrics.go new file mode 100644 index 0000000000000..e88406f97c9c6 --- /dev/null +++ b/cmd/containerboot/metrics.go @@ -0,0 +1,91 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + + "tailscale.com/client/tailscale" + "tailscale.com/client/tailscale/apitype" +) + +// metrics is a simple metrics HTTP server, if enabled it forwards requests to +// the tailscaled's LocalAPI usermetrics endpoint at /localapi/v0/usermetrics. +type metrics struct { + debugEndpoint string + lc *tailscale.LocalClient +} + +func proxy(w http.ResponseWriter, r *http.Request, url string, do func(*http.Request) (*http.Response, error)) { + req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to construct request: %s", err), http.StatusInternalServerError) + return + } + req.Header = r.Header.Clone() + + resp, err := do(req) + if err != nil { + http.Error(w, fmt.Sprintf("failed to proxy request: %s", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + w.WriteHeader(resp.StatusCode) + for key, val := range resp.Header { + for _, v := range val { + w.Header().Add(key, v) + } + } + if _, err := io.Copy(w, resp.Body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (m *metrics) handleMetrics(w http.ResponseWriter, r *http.Request) { + localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi/v0/usermetrics" + proxy(w, r, localAPIURL, m.lc.DoLocalRequest) +} + +func (m *metrics) handleDebug(w http.ResponseWriter, r *http.Request) { + if m.debugEndpoint == "" { + http.Error(w, "debug endpoint not configured", http.StatusNotFound) + return + } + + debugURL := "http://" + m.debugEndpoint + r.URL.Path + proxy(w, r, debugURL, http.DefaultClient.Do) +} + +// runMetrics runs a simple HTTP metrics endpoint at /metrics, forwarding +// requests to tailscaled's /localapi/v0/usermetrics API. +// +// In 1.78.x and 1.80.x, it also proxies debug paths to tailscaled's debug +// endpoint if configured to ease migration for a breaking change serving user +// metrics instead of debug metrics on the "metrics" port. +func runMetrics(addr string, m *metrics) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("error listening on the provided metrics endpoint address %q: %v", addr, err) + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /metrics", m.handleMetrics) + mux.HandleFunc("/debug/", m.handleDebug) // TODO(tomhjp): Remove for 1.82.0 release. + + log.Printf("Running metrics endpoint at %s/metrics", addr) + ms := &http.Server{Handler: mux} + + go func() { + if err := ms.Serve(ln); err != nil { + log.Fatalf("failed running metrics endpoint: %v", err) + } + }() +} diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index 742713e7700de..c877682b95742 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -67,11 +67,18 @@ type settings struct { PodIP string PodIPv4 string PodIPv6 string - HealthCheckAddrPort string + HealthCheckAddrPort string // TODO(tomhjp): use the local addr/port instead. + LocalAddrPort string + MetricsEnabled bool + DebugAddrPort string EgressSvcsCfgPath string } func configFromEnv() (*settings, error) { + defaultLocalAddrPort := "" + if v, ok := os.LookupEnv("POD_IP"); ok && v != "" { + defaultLocalAddrPort = fmt.Sprintf("%s:9002", v) + } cfg := &settings{ AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), Hostname: defaultEnv("TS_HOSTNAME", ""), @@ -98,6 +105,9 @@ func configFromEnv() (*settings, error) { PodIP: defaultEnv("POD_IP", ""), EnableForwardingOptimizations: defaultBool("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS", false), HealthCheckAddrPort: defaultEnv("TS_HEALTHCHECK_ADDR_PORT", ""), + LocalAddrPort: defaultEnv("TS_LOCAL_ADDR_PORT", defaultLocalAddrPort), + MetricsEnabled: defaultBool("TS_METRICS_ENABLED", false), + DebugAddrPort: defaultEnv("TS_DEBUG_ADDR_PORT", ""), EgressSvcsCfgPath: defaultEnv("TS_EGRESS_SERVICES_CONFIG_PATH", ""), } podIPs, ok := os.LookupEnv("POD_IPS") @@ -175,6 +185,16 @@ func (s *settings) validate() error { return fmt.Errorf("error parsing TS_HEALTH_CHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) } } + if s.LocalAddrPort != "" { + if _, err := netip.ParseAddrPort(s.LocalAddrPort); err != nil { + return fmt.Errorf("error parsing TS_LOCAL_ADDR_PORT value %q: %w", s.LocalAddrPort, err) + } + } + if s.DebugAddrPort != "" { + if _, err := netip.ParseAddrPort(s.DebugAddrPort); err != nil { + return fmt.Errorf("error parsing TS_DEBUG_ADDR_PORT value %q: %w", s.DebugAddrPort, err) + } + } return nil } diff --git a/cmd/containerboot/tailscaled.go b/cmd/containerboot/tailscaled.go index 53fb7e703be45..d8da49b033d06 100644 --- a/cmd/containerboot/tailscaled.go +++ b/cmd/containerboot/tailscaled.go @@ -90,6 +90,12 @@ func tailscaledArgs(cfg *settings) []string { if cfg.TailscaledConfigFilePath != "" { args = append(args, "--config="+cfg.TailscaledConfigFilePath) } + // Once enough proxy versions have been released for all the supported + // versions to understand this cfg setting, the operator can stop + // setting TS_TAILSCALED_EXTRA_ARGS for the debug flag. + if cfg.DebugAddrPort != "" && !strings.Contains(cfg.DaemonExtraArgs, cfg.DebugAddrPort) { + args = append(args, "--debug="+cfg.DebugAddrPort) + } if cfg.DaemonExtraArgs != "" { args = append(args, strings.Fields(cfg.DaemonExtraArgs)...) } diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index 7086138c03afd..4c24a1633284e 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -73,7 +73,12 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean statefulSet: @@ -1249,6 +1254,25 @@ spec: description: Configuration for the proxy container running tailscale. type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean env: description: |- List of environment variables to set in the container. @@ -1553,6 +1577,25 @@ spec: description: Configuration for the proxy init container that enables forwarding. type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean env: description: |- List of environment variables to set in the container. diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 4035afabaf4ab..f764fc09ac353 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -540,7 +540,12 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean required: @@ -1716,6 +1721,25 @@ spec: tailscaleContainer: description: Configuration for the proxy container running tailscale. properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. @@ -2020,6 +2044,25 @@ spec: tailscaleInitContainer: description: Configuration for the proxy init container that enables forwarding. properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. diff --git a/cmd/k8s-operator/proxyclass.go b/cmd/k8s-operator/proxyclass.go index 882a9030fa75d..13f217f3c1685 100644 --- a/cmd/k8s-operator/proxyclass.go +++ b/cmd/k8s-operator/proxyclass.go @@ -160,6 +160,10 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "image"), tc.Image, err.Error())) } } + + if tc.Debug != nil { + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "debug"), tc.Debug, "debug settings cannot be configured on the init container")) + } } } } diff --git a/cmd/k8s-operator/proxyclass_test.go b/cmd/k8s-operator/proxyclass_test.go index eb68811fc6b94..fb17f5fe5e3ee 100644 --- a/cmd/k8s-operator/proxyclass_test.go +++ b/cmd/k8s-operator/proxyclass_test.go @@ -135,3 +135,56 @@ func TestProxyClass(t *testing.T) { expectReconciled(t, pcr, "", "test") expectEvents(t, fr, expectedEvents) } + +func TestValidateProxyClass(t *testing.T) { + for name, tc := range map[string]struct { + pc *tsapi.ProxyClass + valid bool + }{ + "empty": { + valid: true, + pc: &tsapi.ProxyClass{}, + }, + "debug_enabled_for_main_container": { + valid: true, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + "debug_enabled_for_init_container": { + valid: false, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleInitContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + pcr := &ProxyClassReconciler{} + err := pcr.validate(tc.pc) + valid := err == nil + if valid != tc.valid { + t.Errorf("expected valid=%v, got valid=%v, err=%v", tc.valid, valid, err) + } + }) + } +} diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index bdacec39b0e98..5df476478c987 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -476,7 +476,7 @@ var proxyYaml []byte //go:embed deploy/manifests/userspace-proxy.yaml var userspaceProxyYaml []byte -func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string, configs map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) (*appsv1.StatefulSet, error) { +func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string, _ map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) (*appsv1.StatefulSet, error) { ss := new(appsv1.StatefulSet) if sts.ServeConfig != nil && sts.ForwardClusterTrafficViaL7IngressProxy != true { // If forwarding cluster traffic via is required we need non-userspace + NET_ADMIN + forwarding if err := yaml.Unmarshal(userspaceProxyYaml, &ss); err != nil { @@ -666,24 +666,42 @@ func mergeStatefulSetLabelsOrAnnots(current, custom map[string]string, managed [ return custom } +func debugSetting(pc *tsapi.ProxyClass) bool { + if pc == nil || + pc.Spec.StatefulSet == nil || + pc.Spec.StatefulSet.Pod == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug == nil { + // This default will change to false in 1.82.0. + return pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + } + + return pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug.Enable +} + func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, stsCfg *tailscaleSTSConfig, logger *zap.SugaredLogger) *appsv1.StatefulSet { if pc == nil || ss == nil { return ss } - if stsCfg != nil && pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable { - if stsCfg.TailnetTargetFQDN == "" && stsCfg.TailnetTargetIP == "" && !stsCfg.ForwardClusterTrafficViaL7IngressProxy { - enableMetrics(ss) - } else if stsCfg.ForwardClusterTrafficViaL7IngressProxy { + + metricsEnabled := pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + debugEnabled := debugSetting(pc) + if metricsEnabled || debugEnabled { + isEgress := stsCfg != nil && (stsCfg.TailnetTargetFQDN != "" || stsCfg.TailnetTargetIP != "") + isForwardingL7Ingress := stsCfg != nil && stsCfg.ForwardClusterTrafficViaL7IngressProxy + if isEgress { // TODO (irbekrm): fix this // For Ingress proxies that have been configured with // tailscale.com/experimental-forward-cluster-traffic-via-ingress // annotation, all cluster traffic is forwarded to the // Ingress backend(s). - logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") - } else { + logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for egress proxies.") + } else if isForwardingL7Ingress { // TODO (irbekrm): fix this // For egress proxies, currently all cluster traffic is forwarded to the tailnet target. logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") + } else { + enableEndpoints(ss, metricsEnabled, debugEnabled) } } @@ -761,16 +779,58 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } -func enableMetrics(ss *appsv1.StatefulSet) { +func enableEndpoints(ss *appsv1.StatefulSet, metrics, debug bool) { for i, c := range ss.Spec.Template.Spec.Containers { if c.Name == "tailscale" { - // Serve metrics on on :9001/debug/metrics. If - // we didn't specify Pod IP here, the proxy would, in - // some cases, also listen to its Tailscale IP- we don't - // want folks to start relying on this side-effect as a - // feature. - ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, corev1.ContainerPort{Name: "metrics", Protocol: "TCP", HostPort: 9001, ContainerPort: 9001}) + if debug { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve tailscaled's debug metrics on on + // :9001/debug/metrics. If we didn't specify Pod IP + // here, the proxy would, in some cases, also listen to its + // Tailscale IP- we don't want folks to start relying on this + // side-effect as a feature. + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001", + }, + // TODO(tomhjp): Can remove this env var once 1.76.x is no + // longer supported. + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + ) + + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "debug", + Protocol: "TCP", + ContainerPort: 9001, + }, + ) + } + + if metrics { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve client metrics on :9002/metrics. + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_METRICS_ENABLED", + Value: "true", + }, + ) + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "metrics", + Protocol: "TCP", + ContainerPort: 9002, + }, + ) + } + break } } diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index 7263c56c36bb9..7986d1b9164eb 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -125,10 +125,26 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { }, }, } - proxyClassMetrics := &tsapi.ProxyClass{ - Spec: tsapi.ProxyClassSpec{ - Metrics: &tsapi.Metrics{Enable: true}, - }, + + proxyClassWithMetricsDebug := func(metrics bool, debug *bool) *tsapi.ProxyClass { + return &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + Metrics: &tsapi.Metrics{Enable: metrics}, + StatefulSet: func() *tsapi.StatefulSet { + if debug == nil { + return nil + } + + return &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{Enable: *debug}, + }, + }, + } + }(), + }, + } } var userspaceProxySS, nonUserspaceProxySS appsv1.StatefulSet @@ -184,7 +200,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { gotSS := applyProxyClassToStatefulSet(proxyClassAllOpts, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 2. Test that a ProxyClass with custom labels and annotations for @@ -197,7 +213,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 3. Test that a ProxyClass with all fields set gets correctly applied @@ -221,7 +237,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.Containers[0].Image = "ghcr.io/my-repo/tailscale:v0.01testsomething" gotSS = applyProxyClassToStatefulSet(proxyClassAllOpts, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) } // 4. Test that a ProxyClass with custom labels and annotations gets correctly applied @@ -233,16 +249,48 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + } + + // 5. Metrics enabled defaults to enabling both metrics and debug. + wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_METRICS_ENABLED", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{ + {Name: "debug", Protocol: "TCP", ContainerPort: 9001}, + {Name: "metrics", Protocol: "TCP", ContainerPort: 9002}, + } + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, nil), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + } + + // 6. Enable _just_ metrics by explicitly disabling debug. + wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_METRICS_ENABLED", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9002}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, ptr.To(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } - // 5. Test that a ProxyClass with metrics enabled gets correctly applied to a StatefulSet. + // 7. Enable _just_ debug without metrics. wantSS = nonUserspaceProxySS.DeepCopy() - wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9001, HostPort: 9001}} - gotSS = applyProxyClassToStatefulSet(proxyClassMetrics, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "debug", Protocol: "TCP", ContainerPort: 9001}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(false, ptr.To(true)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } } diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 7b1aca3148e5b..640d8fb07bc54 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -146,6 +146,7 @@ _Appears in:_ | `imagePullPolicy` _[PullPolicy](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#pullpolicy-v1-core)_ | Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | Enum: [Always Never IfNotPresent]
| | `resources` _[ResourceRequirements](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#resourcerequirements-v1-core)_ | Container resource requirements.
By default Tailscale Kubernetes operator does not apply any resource
requirements. The amount of resources required wil depend on the
amount of resources the operator needs to parse, usage patterns and
cluster size.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#resources | | | | `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context by the operator.
By default the operator:
- sets 'privileged: true' for the init container
- set NET_ADMIN capability for tailscale container for proxies that
are created for Services or Connector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `debug` _[Debug](#debug)_ | Configuration for enabling extra debug information in the container.
Not recommended for production use. | | | #### DNSConfig @@ -248,6 +249,22 @@ _Appears in:_ | `nameserver` _[NameserverStatus](#nameserverstatus)_ | Nameserver describes the status of nameserver cluster resources. | | | +#### Debug + + + + + + + +_Appears in:_ +- [Container](#container) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/
and internal debug metrics endpoint at :9001/debug/metrics, where
9001 is a container port named "debug". The endpoints and their responses
may change in backwards incompatible ways in the future, and should not
be considered stable.
In 1.78.x and 1.80.x, this setting will default to the value of
.spec.metrics.enable, and requests to the "metrics" port matching the
mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x,
this setting will default to false, and no requests will be proxied. | | | + + #### Env @@ -309,7 +326,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9001/debug/metrics.
Defaults to false. | | | +| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9002/metrics.
In 1.78.x and 1.80.x, this field also serves as the default value for
.spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both
fields will independently default to false.
Defaults to false. | | | #### Name diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 0a224b7960495..7e408cd0a7338 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -163,7 +163,12 @@ type Pod struct { type Metrics struct { // Setting enable to true will make the proxy serve Tailscale metrics - // at :9001/debug/metrics. + // at :9002/metrics. + // + // In 1.78.x and 1.80.x, this field also serves as the default value for + // .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + // fields will independently default to false. + // // Defaults to false. Enable bool `json:"enable"` } @@ -209,6 +214,26 @@ type Container struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context // +optional SecurityContext *corev1.SecurityContext `json:"securityContext,omitempty"` + // Configuration for enabling extra debug information in the container. + // Not recommended for production use. + // +optional + Debug *Debug `json:"debug,omitempty"` +} + +type Debug struct { + // Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + // and internal debug metrics endpoint at :9001/debug/metrics, where + // 9001 is a container port named "debug". The endpoints and their responses + // may change in backwards incompatible ways in the future, and should not + // be considered stable. + // + // In 1.78.x and 1.80.x, this setting will default to the value of + // .spec.metrics.enable, and requests to the "metrics" port matching the + // mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + // this setting will default to false, and no requests will be proxied. + // + // +optional + Enable bool `json:"enable"` } type Env struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index c2f69dc045314..07e46f3f5cde8 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -163,6 +163,11 @@ func (in *Container) DeepCopyInto(out *Container) { *out = new(corev1.SecurityContext) (*in).DeepCopyInto(*out) } + if in.Debug != nil { + in, out := &in.Debug, &out.Debug + *out = new(Debug) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Container. @@ -281,6 +286,21 @@ func (in *DNSConfigStatus) DeepCopy() *DNSConfigStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Debug) DeepCopyInto(out *Debug) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Debug. +func (in *Debug) DeepCopy() *Debug { + if in == nil { + return nil + } + out := new(Debug) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Env) DeepCopyInto(out *Env) { *out = *in From 462e1fc503fa2c26d8ff1a70a641ebb835ac9f8f Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 08:25:54 -0600 Subject: [PATCH 130/179] ipn/{ipnlocal,localapi}, wgengine/netstack: call (*LocalBackend).Shutdown when tests that create them complete We have several places where LocalBackend instances are created for testing, but they are rarely shut down when the tests that created them exit. In this PR, we update newTestLocalBackend and similar functions to use testing.TB.Cleanup(lb.Shutdown) to ensure LocalBackend instances are properly shut down during test cleanup. Updates #12687 Signed-off-by: Nick Khyl --- ipn/ipnlocal/local_test.go | 2 ++ ipn/ipnlocal/state_test.go | 3 +++ ipn/localapi/localapi_test.go | 1 + wgengine/netstack/netstack_test.go | 2 ++ 4 files changed, 8 insertions(+) diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 6d25a418fc6a8..f30ff6adb6a5b 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -458,6 +458,7 @@ func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } @@ -4109,6 +4110,7 @@ func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.DisablePortMapperForTest() b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index bebd0152b5a36..ef4b0ed62809f 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -309,6 +309,7 @@ func TestStateMachine(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.DisablePortMapperForTest() var cc, previousCC *mockControl @@ -942,6 +943,7 @@ func TestEditPrefsHasNoKeys(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.hostinfo = &tailcfg.Hostinfo{OS: "testos"} b.pm.SetPrefs((&ipn.Prefs{ Persist: &persist.Persist{ @@ -1023,6 +1025,7 @@ func TestWGEngineStatusRace(t *testing.T) { sys.Set(eng) b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) c.Assert(err, qt.IsNil) + t.Cleanup(b.Shutdown) var cc *mockControl b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index d89c46261815a..145910830e80f 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -349,6 +349,7 @@ func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index a46dcf9dd6fc9..823acee9156b7 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -64,6 +64,7 @@ func TestInjectInboundLeak(t *testing.T) { if err != nil { t.Fatal(err) } + t.Cleanup(lb.Shutdown) ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { @@ -126,6 +127,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { if err != nil { tb.Fatalf("NewLocalBackend: %v", err) } + tb.Cleanup(lb.Shutdown) ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) if config != nil { From 8e5cfbe4ab11713e383b3ff0d978f116320de2a3 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 09:05:01 -0600 Subject: [PATCH 131/179] util/syspolicy/rsop: reduce policyReloadMinDelay and policyReloadMaxDelay when in tests These delays determine how soon syspolicy change callbacks are invoked after a policy setting is updated in a policy source. For tests, we shorten these delays to minimize unnecessary wait times. This adjustment only affects tests that subscribe to policy change notifications and modify policy settings after they have already been set. Initial policy settings are always available immediately without delay. Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/rsop/resultant_policy.go | 7 +++++++ util/syspolicy/rsop/resultant_policy_test.go | 13 ++++--------- util/syspolicy/rsop/store_registration.go | 4 ++++ 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go index 019b8f602f86d..b811a00eed77b 100644 --- a/util/syspolicy/rsop/resultant_policy.go +++ b/util/syspolicy/rsop/resultant_policy.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/setting" @@ -447,3 +448,9 @@ func (p *Policy) Close() { go p.closeInternal() } } + +func setForTest[T any](tb internal.TB, target *T, newValue T) { + oldValue := *target + tb.Cleanup(func() { *target = oldValue }) + *target = newValue +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go index b2408c7f71519..e4bfb1a886878 100644 --- a/util/syspolicy/rsop/resultant_policy_test.go +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -574,9 +574,6 @@ func TestPolicyChangeHasChanged(t *testing.T) { } func TestChangePolicySetting(t *testing.T) { - setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) - setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) - // Register policy settings used in this test. settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) @@ -589,6 +586,10 @@ func TestChangePolicySetting(t *testing.T) { if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { t.Fatalf("Failed to register policy store: %v", err) } + + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + policy, err := policyForTest(t, setting.DeviceScope) if err != nil { t.Fatalf("Failed to get effective policy: %v", err) @@ -978,9 +979,3 @@ func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { }) return policy, nil } - -func setForTest[T any](tb testing.TB, target *T, newValue T) { - oldValue := *target - tb.Cleanup(func() { *target = oldValue }) - *target = newValue -} diff --git a/util/syspolicy/rsop/store_registration.go b/util/syspolicy/rsop/store_registration.go index 09c83e98804ca..f9836846e18ee 100644 --- a/util/syspolicy/rsop/store_registration.go +++ b/util/syspolicy/rsop/store_registration.go @@ -7,6 +7,7 @@ import ( "errors" "sync" "sync/atomic" + "time" "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/setting" @@ -33,6 +34,9 @@ func RegisterStore(name string, scope setting.PolicyScope, store source.Store) ( // RegisterStoreForTest is like [RegisterStore], but unregisters the store when // tb and all its subtests complete. func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + setForTest(tb, &policyReloadMinDelay, 10*time.Millisecond) + setForTest(tb, &policyReloadMaxDelay, 500*time.Millisecond) + reg, err := RegisterStore(name, scope, store) if err == nil { tb.Cleanup(func() { From 50bf32a0ba13935273e200d52b9327821f25efc5 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 09:35:55 -0600 Subject: [PATCH 132/179] cmd/tailscaled: flush DNS if FlushDNSOnSessionUnlock is true upon receiving a session change notification In this PR, we move the syspolicy.FlushDNSOnSessionUnlock check from service startup to when a session change notification is received. This ensures that the most recent policy setting value is used if it has changed since the service started. We also plan to handle session change notifications for unrelated reasons and need to decouple notification subscriptions from DNS anyway. Updates #12687 Updates tailscale/corp#18342 Signed-off-by: Nick Khyl --- cmd/tailscaled/tailscaled_windows.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 35c878f38ece3..67f9744659d4d 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -160,10 +160,7 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch changes <- svc.Status{State: svc.StartPending} syslogf("Service start pending") - svcAccepts := svc.AcceptStop - if flushDNSOnSessionUnlock, _ := syspolicy.GetBoolean(syspolicy.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { - svcAccepts |= svc.AcceptSessionChange - } + svcAccepts := svc.AcceptStop | svc.AcceptSessionChange ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -371,13 +368,15 @@ func handleSessionChange(chgRequest svc.ChangeRequest) { return } - log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") - go func() { - err := dns.Flush() - if err != nil { - log.Printf("Error flushing DNS on session unlock: %v", err) - } - }() + if flushDNSOnSessionUnlock, _ := syspolicy.GetBoolean(syspolicy.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { + log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") + go func() { + err := dns.Flush() + if err != nil { + log.Printf("Error flushing DNS on session unlock: %v", err) + } + }() + } } var ( From 7c8f663d7059467353d9cd0fdae7b83bb1d4b998 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 09:52:08 -0600 Subject: [PATCH 133/179] cmd/tailscaled: log SCM interactions if the policy setting is enabled at the time of interaction This updates the syspolicy.LogSCMInteractions check to run at the time of an interaction, just before logging a message, instead of during service startup. This ensures the most recent policy setting is used if it has changed since the service started. Updates #12687 Signed-off-by: Nick Khyl --- cmd/tailscaled/tailscaled_windows.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 67f9744659d4d..786c5d8330939 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -134,14 +134,13 @@ func runWindowsService(pol *logpolicy.Policy) error { logger.Logf(log.Printf).JSON(1, "SupportInfo", osdiag.SupportInfo(osdiag.LogSupportInfoReasonStartup)) }() - if logSCMInteractions, _ := syspolicy.GetBoolean(syspolicy.LogSCMInteractions, false); logSCMInteractions { - syslog, err := eventlog.Open(serviceName) - if err == nil { - syslogf = func(format string, args ...any) { + if syslog, err := eventlog.Open(serviceName); err == nil { + syslogf = func(format string, args ...any) { + if logSCMInteractions, _ := syspolicy.GetBoolean(syspolicy.LogSCMInteractions, false); logSCMInteractions { syslog.Info(0, fmt.Sprintf(format, args...)) } - defer syslog.Close() } + defer syslog.Close() } syslogf("Service entering svc.Run") From 2ab66d9698cc77f27598f3642be4159c36231c65 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 21 Nov 2024 19:29:20 -0600 Subject: [PATCH 134/179] ipn/ipnlocal: move syspolicy handling from setExitNodeID to applySysPolicy This moves code that handles ExitNodeID/ExitNodeIP syspolicy settings from (*LocalBackend).setExitNodeID to applySysPolicy. Updates #12687 Signed-off-by: Nick Khyl --- ipn/ipnlocal/local.go | 71 ++++++++++++++++++++------------------ ipn/ipnlocal/local_test.go | 30 ++++++++++------ 2 files changed, 56 insertions(+), 45 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index cbbea32aa8363..7c0ddc90cc7e5 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1489,10 +1489,10 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.logf("SetControlClientStatus failed to select auto exit node: %v", err) } } - if setExitNodeID(prefs, curNetMap, b.lastSuggestedExitNode) { + if applySysPolicy(prefs, b.lastSuggestedExitNode) { prefsChanged = true } - if applySysPolicy(prefs) { + if setExitNodeID(prefs, curNetMap) { prefsChanged = true } @@ -1658,12 +1658,37 @@ var preferencePolicies = []preferencePolicyInfo{ // applySysPolicy overwrites configured preferences with policies that may be // configured by the system administrator in an OS-specific way. -func applySysPolicy(prefs *ipn.Prefs) (anyChange bool) { +func applySysPolicy(prefs *ipn.Prefs, lastSuggestedExitNode tailcfg.StableNodeID) (anyChange bool) { if controlURL, err := syspolicy.GetString(syspolicy.ControlURL, prefs.ControlURL); err == nil && prefs.ControlURL != controlURL { prefs.ControlURL = controlURL anyChange = true } + if exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, ""); exitNodeIDStr != "" { + exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) + if shouldAutoExitNode() && lastSuggestedExitNode != "" { + exitNodeID = lastSuggestedExitNode + } + // Note: when exitNodeIDStr == "auto" && lastSuggestedExitNode == "", + // then exitNodeID is now "auto" which will never match a peer's node ID. + // When there is no a peer matching the node ID, traffic will blackhole, + // preventing accidental non-exit-node usage when a policy is in effect that requires an exit node. + if prefs.ExitNodeID != exitNodeID || prefs.ExitNodeIP.IsValid() { + anyChange = true + } + prefs.ExitNodeID = exitNodeID + prefs.ExitNodeIP = netip.Addr{} + } else if exitNodeIPStr, _ := syspolicy.GetString(syspolicy.ExitNodeIP, ""); exitNodeIPStr != "" { + exitNodeIP, err := netip.ParseAddr(exitNodeIPStr) + if exitNodeIP.IsValid() && err == nil { + if prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP { + anyChange = true + } + prefs.ExitNodeID = "" + prefs.ExitNodeIP = exitNodeIP + } + } + for _, opt := range preferencePolicies { if po, err := syspolicy.GetPreferenceOption(opt.key); err == nil { curVal := opt.get(prefs.View()) @@ -1770,30 +1795,7 @@ func (b *LocalBackend) updateNetmapDeltaLocked(muts []netmap.NodeMutation) (hand // setExitNodeID updates prefs to reference an exit node by ID, rather // than by IP. It returns whether prefs was mutated. -func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNode tailcfg.StableNodeID) (prefsChanged bool) { - if exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, ""); exitNodeIDStr != "" { - exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) - if shouldAutoExitNode() && lastSuggestedExitNode != "" { - exitNodeID = lastSuggestedExitNode - } - // Note: when exitNodeIDStr == "auto" && lastSuggestedExitNode == "", then exitNodeID is now "auto" which will never match a peer's node ID. - // When there is no a peer matching the node ID, traffic will blackhole, preventing accidental non-exit-node usage when a policy is in effect that requires an exit node. - changed := prefs.ExitNodeID != exitNodeID || prefs.ExitNodeIP.IsValid() - prefs.ExitNodeID = exitNodeID - prefs.ExitNodeIP = netip.Addr{} - return changed - } - - oldExitNodeID := prefs.ExitNodeID - if exitNodeIPStr, _ := syspolicy.GetString(syspolicy.ExitNodeIP, ""); exitNodeIPStr != "" { - exitNodeIP, err := netip.ParseAddr(exitNodeIPStr) - if exitNodeIP.IsValid() && err == nil { - prefsChanged = prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP - prefs.ExitNodeID = "" - prefs.ExitNodeIP = exitNodeIP - } - } - +func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap) (prefsChanged bool) { if nm == nil { // No netmap, can't resolve anything. return false @@ -1811,6 +1813,7 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod prefsChanged = true } + oldExitNodeID := prefs.ExitNodeID for _, peer := range nm.Peers { for _, addr := range peer.Addresses().All() { if !addr.IsSingleIP() || addr.Addr() != prefs.ExitNodeIP { @@ -1820,7 +1823,7 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod // reference it directly for next time. prefs.ExitNodeID = peer.StableID() prefs.ExitNodeIP = netip.Addr{} - return oldExitNodeID != prefs.ExitNodeID + return prefsChanged || oldExitNodeID != prefs.ExitNodeID } } @@ -3844,12 +3847,12 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) if oldp.Valid() { newp.Persist = oldp.Persist().AsStruct() // caller isn't allowed to override this } - // setExitNodeID returns whether it updated b.prefs, but - // everything in this function treats b.prefs as completely new - // anyway. No-op if no exit node resolution is needed. - setExitNodeID(newp, netMap, b.lastSuggestedExitNode) - // applySysPolicy does likewise so we can also ignore its return value. - applySysPolicy(newp) + // applySysPolicyToPrefsLocked returns whether it updated newp, + // but everything in this function treats b.prefs as completely new + // anyway, so its return value can be ignored here. + applySysPolicy(newp, b.lastSuggestedExitNode) + // setExitNodeID does likewise. No-op if no exit node resolution is needed. + setExitNodeID(newp, netMap) // We do this to avoid holding the lock while doing everything else. oldHi := b.hostinfo diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index f30ff6adb6a5b..c5bd512658991 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -1789,10 +1789,13 @@ func TestSetExitNodeIDPolicy(t *testing.T) { t.Run(test.name, func(t *testing.T) { b := newTestBackend(t) - policyStore := source.NewTestStoreOf(t, - source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID), - source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP), - ) + policyStore := source.NewTestStore(t) + if test.exitNodeIDKey { + policyStore.SetStrings(source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID)) + } + if test.exitNodeIPKey { + policyStore.SetStrings(source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP)) + } syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) if test.nm == nil { @@ -1806,7 +1809,16 @@ func TestSetExitNodeIDPolicy(t *testing.T) { b.netMap = test.nm b.pm = pm b.lastSuggestedExitNode = test.lastSuggestedExitNode - changed := setExitNodeID(b.pm.prefs.AsStruct(), test.nm, tailcfg.StableNodeID(test.lastSuggestedExitNode)) + + prefs := b.pm.prefs.AsStruct() + if changed := applySysPolicy(prefs, test.lastSuggestedExitNode) || setExitNodeID(prefs, test.nm); changed != test.prefsChanged { + t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) + } + + // Both [LocalBackend.SetPrefsForTest] and [LocalBackend.EditPrefs] + // apply syspolicy settings to the current profile's preferences. Therefore, + // we pass the current, unmodified preferences and expect the effective + // preferences to change. b.SetPrefsForTest(pm.CurrentPrefs().AsStruct()) if got := b.pm.prefs.ExitNodeID(); got != tailcfg.StableNodeID(test.exitNodeIDWant) { @@ -1819,10 +1831,6 @@ func TestSetExitNodeIDPolicy(t *testing.T) { } else if got.String() != test.exitNodeIPWant { t.Errorf("got %v want %v", got, test.exitNodeIPWant) } - - if changed != test.prefsChanged { - t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) - } }) } } @@ -2332,7 +2340,7 @@ func TestApplySysPolicy(t *testing.T) { t.Run("unit", func(t *testing.T) { prefs := tt.prefs.Clone() - gotAnyChange := applySysPolicy(prefs) + gotAnyChange := applySysPolicy(prefs, "") if gotAnyChange && prefs.Equals(&tt.prefs) { t.Errorf("anyChange but prefs is unchanged: %v", prefs.Pretty()) @@ -2480,7 +2488,7 @@ func TestPreferencePolicyInfo(t *testing.T) { prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) - gotAnyChange := applySysPolicy(prefs) + gotAnyChange := applySysPolicy(prefs, "") if gotAnyChange != tt.wantChange { t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantChange) From eb3cd3291106dc603316e4df65ad85cc0d3b3e6b Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 08:45:53 -0600 Subject: [PATCH 135/179] ipn/ipnlocal: update ipn.Prefs when there's a change in syspolicy settings In this PR, we update ipnlocal.NewLocalBackend to subscribe to policy change notifications and reapply syspolicy settings to the current profile's ipn.Prefs whenever a change occurs. Updates #12687 Signed-off-by: Nick Khyl --- ipn/ipnlocal/local.go | 102 ++++++++++++++++++++++-------- ipn/ipnlocal/local_test.go | 123 +++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 26 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 7c0ddc90cc7e5..8763581f1e3b3 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -106,6 +106,7 @@ import ( "tailscale.com/util/rands" "tailscale.com/util/set" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/systemd" "tailscale.com/util/testenv" "tailscale.com/util/uniq" @@ -178,27 +179,28 @@ type watchSession struct { // state machine generates events back out to zero or more components. type LocalBackend struct { // Elements that are thread-safe or constant after construction. - ctx context.Context // canceled by Close - ctxCancel context.CancelFunc // cancels ctx - logf logger.Logf // general logging - keyLogf logger.Logf // for printing list of peers on change - statsLogf logger.Logf // for printing peers stats on change - sys *tsd.System - health *health.Tracker // always non-nil - metrics metrics - e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys - store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys - dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys - pushDeviceToken syncs.AtomicValue[string] - backendLogID logid.PublicID - unregisterNetMon func() - unregisterHealthWatch func() - portpoll *portlist.Poller // may be nil - portpollOnce sync.Once // guards starting readPoller - varRoot string // or empty if SetVarRoot never called - logFlushFunc func() // or nil if SetLogFlusher wasn't called - em *expiryManager // non-nil - sshAtomicBool atomic.Bool + ctx context.Context // canceled by Close + ctxCancel context.CancelFunc // cancels ctx + logf logger.Logf // general logging + keyLogf logger.Logf // for printing list of peers on change + statsLogf logger.Logf // for printing peers stats on change + sys *tsd.System + health *health.Tracker // always non-nil + metrics metrics + e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys + store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys + dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys + pushDeviceToken syncs.AtomicValue[string] + backendLogID logid.PublicID + unregisterNetMon func() + unregisterHealthWatch func() + unregisterSysPolicyWatch func() + portpoll *portlist.Poller // may be nil + portpollOnce sync.Once // guards starting readPoller + varRoot string // or empty if SetVarRoot never called + logFlushFunc func() // or nil if SetLogFlusher wasn't called + em *expiryManager // non-nil + sshAtomicBool atomic.Bool // webClientAtomicBool controls whether the web client is running. This should // be true unless the disable-web-client node attribute has been set. webClientAtomicBool atomic.Bool @@ -410,7 +412,7 @@ type clientGen func(controlclient.Options) (controlclient.Client, error) // but is not actually running. // // If dialer is nil, a new one is made. -func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (*LocalBackend, error) { +func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (_ *LocalBackend, err error) { e := sys.Engine.Get() store := sys.StateStore.Get() dialer := sys.Dialer.Get() @@ -485,6 +487,15 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } + if b.unregisterSysPolicyWatch, err = b.registerSysPolicyWatch(); err != nil { + return nil, err + } + defer func() { + if err != nil { + b.unregisterSysPolicyWatch() + } + }() + netMon := sys.NetMon.Get() b.sockstatLogger, err = sockstatlog.NewLogger(logpolicy.LogsDir(logf), logf, logID, netMon, sys.HealthTracker()) if err != nil { @@ -981,6 +992,7 @@ func (b *LocalBackend) Shutdown() { b.unregisterNetMon() b.unregisterHealthWatch() + b.unregisterSysPolicyWatch() if cc != nil { cc.Shutdown() } @@ -1703,6 +1715,40 @@ func applySysPolicy(prefs *ipn.Prefs, lastSuggestedExitNode tailcfg.StableNodeID return anyChange } +// registerSysPolicyWatch subscribes to syspolicy change notifications +// and immediately applies the effective syspolicy settings to the current profile. +func (b *LocalBackend) registerSysPolicyWatch() (unregister func(), err error) { + if unregister, err = syspolicy.RegisterChangeCallback(b.sysPolicyChanged); err != nil { + return nil, fmt.Errorf("syspolicy: LocalBacked failed to register policy change callback: %v", err) + } + if prefs, anyChange := b.applySysPolicy(); anyChange { + b.logf("syspolicy: changed initial profile prefs: %v", prefs.Pretty()) + } + return unregister, nil +} + +// applySysPolicy overwrites the current profile's preferences with policies +// that may be configured by the system administrator in an OS-specific way. +// +// b.mu must not be held. +func (b *LocalBackend) applySysPolicy() (_ ipn.PrefsView, anyChange bool) { + unlock := b.lockAndGetUnlock() + prefs := b.pm.CurrentPrefs().AsStruct() + if !applySysPolicy(prefs, b.lastSuggestedExitNode) { + unlock.UnlockEarly() + return prefs.View(), false + } + return b.setPrefsLockedOnEntry(prefs, unlock), true +} + +// sysPolicyChanged is a callback triggered by syspolicy when it detects +// a change in one or more syspolicy settings. +func (b *LocalBackend) sysPolicyChanged(*rsop.PolicyChange) { + if prefs, anyChange := b.applySysPolicy(); anyChange { + b.logf("syspolicy: changed profile prefs: %v", prefs.Pretty()) + } +} + var _ controlclient.NetmapDeltaUpdater = (*LocalBackend)(nil) // UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. @@ -3889,10 +3935,14 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) } prefs := newp.View() - if err := b.pm.SetPrefs(prefs, ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + np := b.pm.CurrentProfile().NetworkProfile + if netMap != nil { + np = ipn.NetworkProfile{ + MagicDNSName: b.netMap.MagicDNSSuffix(), + DomainName: b.netMap.DomainName(), + } + } + if err := b.pm.SetPrefs(prefs, np); err != nil { b.logf("failed to save new controlclient state: %v", err) } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index c5bd512658991..b1be86392185d 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -4562,3 +4562,126 @@ func TestGetVIPServices(t *testing.T) { }) } } + +func TestUpdatePrefsOnSysPolicyChange(t *testing.T) { + const enableLogging = false + + type fieldChange struct { + name string + want any + } + + wantPrefsChanges := func(want ...fieldChange) *wantedNotification { + return &wantedNotification{ + name: "Prefs", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + prefs := reflect.Indirect(reflect.ValueOf(n.Prefs.AsStruct())) + for _, f := range want { + got := prefs.FieldByName(f.name).Interface() + if !reflect.DeepEqual(got, f.want) { + t.Errorf("%v: got %v; want %v", f.name, got, f.want) + } + } + } + return n.Prefs != nil + }, + } + } + + unexpectedPrefsChange := func(t testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + t.Errorf("Unexpected Prefs: %v", n.Prefs.Pretty()) + return true + } + return false + } + + tests := []struct { + name string + initialPrefs *ipn.Prefs + stringSettings []source.TestSetting[string] + want *wantedNotification + }{ + { + name: "ShieldsUp/True", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableIncomingConnections, "never")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", true}), + }, + { + name: "ShieldsUp/False", + initialPrefs: &ipn.Prefs{ShieldsUp: true}, + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableIncomingConnections, "always")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", false}), + }, + { + name: "ExitNodeID", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.ExitNodeID, "foo")}, + want: wantPrefsChanges(fieldChange{"ExitNodeID", tailcfg.StableNodeID("foo")}), + }, + { + name: "EnableRunExitNode", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableRunExitNode, "always")}, + want: wantPrefsChanges(fieldChange{"AdvertiseRoutes", []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}}), + }, + { + name: "Multiple", + initialPrefs: &ipn.Prefs{ + ExitNodeAllowLANAccess: true, + }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(syspolicy.EnableServerMode, "always"), + source.TestSettingOf(syspolicy.ExitNodeAllowLANAccess, "never"), + source.TestSettingOf(syspolicy.ExitNodeIP, "127.0.0.1"), + }, + want: wantPrefsChanges( + fieldChange{"ForceDaemon", true}, + fieldChange{"ExitNodeAllowLANAccess", false}, + fieldChange{"ExitNodeIP", netip.MustParseAddr("127.0.0.1")}, + ), + }, + { + name: "NoChange", + initialPrefs: &ipn.Prefs{ + CorpDNS: true, + ExitNodeID: "foo", + AdvertiseRoutes: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(syspolicy.EnableTailscaleDNS, "always"), + source.TestSettingOf(syspolicy.ExitNodeID, "foo"), + source.TestSettingOf(syspolicy.EnableRunExitNode, "always"), + }, + want: nil, // syspolicy settings match the preferences; no change notification is expected. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + syspolicy.RegisterWellKnownSettingsForTest(t) + store := source.NewTestStoreOf[string](t) + syspolicy.MustRegisterStoreForTest(t, "TestSource", setting.DeviceScope, store) + + lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if tt.initialPrefs != nil { + lb.SetPrefsForTest(tt.initialPrefs) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + nw := newNotificationWatcher(t, lb, &ipnauth.TestActor{}) + if tt.want != nil { + nw.watch(0, []wantedNotification{*tt.want}) + } else { + nw.watch(0, nil, unexpectedPrefsChange) + } + + store.SetStrings(tt.stringSettings...) + + nw.check() + }) + } +} From 3353f154bb341c9ed9e05ef21e5475f922986def Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 09:28:56 -0600 Subject: [PATCH 136/179] control/controlclient: use the most recent syspolicy.MachineCertificateSubject value This PR removes the sync.Once wrapper around retrieving the MachineCertificateSubject policy setting value, ensuring the most recent version is always used if it changes after the service starts. Although this policy setting is used by a very limited number of customers, recent support escalations have highlighted issues caused by outdated or incorrect policy values being applied. Updates #12687 Signed-off-by: Nick Khyl --- control/controlclient/sign_supported.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index 0e3dd038e4ed7..a5d42ad7df4a2 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -13,7 +13,6 @@ import ( "crypto/x509" "errors" "fmt" - "sync" "time" "github.com/tailscale/certstore" @@ -22,11 +21,6 @@ import ( "tailscale.com/util/syspolicy" ) -var getMachineCertificateSubjectOnce struct { - sync.Once - v string // Subject of machine certificate to search for -} - // getMachineCertificateSubject returns the exact name of a Subject that needs // to be present in an identity's certificate chain to sign a RegisterRequest, // formatted as per pkix.Name.String(). The Subject may be that of the identity @@ -37,11 +31,8 @@ var getMachineCertificateSubjectOnce struct { // // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" func getMachineCertificateSubject() string { - getMachineCertificateSubjectOnce.Do(func() { - getMachineCertificateSubjectOnce.v, _ = syspolicy.GetString(syspolicy.MachineCertificateSubject, "") - }) - - return getMachineCertificateSubjectOnce.v + machineCertSubject, _ := syspolicy.GetString(syspolicy.MachineCertificateSubject, "") + return machineCertSubject } var ( From 36b7449feafcf5450261193c2507a07f0fedcfa0 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 22 Nov 2024 09:57:26 -0600 Subject: [PATCH 137/179] ipn/ipnlocal: rebuild allowed suggested exit nodes when syspolicy changes In this PR, we update LocalBackend to rebuild the set of allowed suggested exit nodes whenever the AllowedSuggestedExitNodes syspolicy setting changes. Additionally, we request a new suggested exit node when this occurs, enabling its use if the ExitNodeID syspolicy setting is set to auto:any. Updates #12687 Signed-off-by: Nick Khyl --- ipn/ipnlocal/local.go | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8763581f1e3b3..fdbd5cf52beb0 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -87,7 +87,6 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/empty" "tailscale.com/types/key" - "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -356,6 +355,12 @@ type LocalBackend struct { // avoid unnecessary churn between multiple equally-good options. lastSuggestedExitNode tailcfg.StableNodeID + // allowedSuggestedExitNodes is a set of exit nodes permitted by the most recent + // [syspolicy.AllowedSuggestedExitNodes] value. The allowedSuggestedExitNodesMu + // mutex guards access to this set. + allowedSuggestedExitNodesMu sync.Mutex + allowedSuggestedExitNodes set.Set[tailcfg.StableNodeID] + // refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available. refreshAutoExitNode bool @@ -1724,6 +1729,7 @@ func (b *LocalBackend) registerSysPolicyWatch() (unregister func(), err error) { if prefs, anyChange := b.applySysPolicy(); anyChange { b.logf("syspolicy: changed initial profile prefs: %v", prefs.Pretty()) } + b.refreshAllowedSuggestions() return unregister, nil } @@ -1743,7 +1749,20 @@ func (b *LocalBackend) applySysPolicy() (_ ipn.PrefsView, anyChange bool) { // sysPolicyChanged is a callback triggered by syspolicy when it detects // a change in one or more syspolicy settings. -func (b *LocalBackend) sysPolicyChanged(*rsop.PolicyChange) { +func (b *LocalBackend) sysPolicyChanged(policy *rsop.PolicyChange) { + if policy.HasChanged(syspolicy.AllowedSuggestedExitNodes) { + b.refreshAllowedSuggestions() + // Re-evaluate exit node suggestion now that the policy setting has changed. + b.mu.Lock() + _, err := b.suggestExitNodeLocked(nil) + b.mu.Unlock() + if err != nil && !errors.Is(err, ErrNoPreferredDERP) { + b.logf("failed to select auto exit node: %v", err) + } + // If [syspolicy.ExitNodeID] is set to `auto:any`, the suggested exit node ID + // will be used when [applySysPolicy] updates the current profile's prefs. + } + if prefs, anyChange := b.applySysPolicy(); anyChange { b.logf("syspolicy: changed profile prefs: %v", prefs.Pretty()) } @@ -7197,7 +7216,7 @@ func (b *LocalBackend) suggestExitNodeLocked(netMap *netmap.NetworkMap) (respons lastReport := b.MagicConn().GetLastNetcheckReport(b.ctx) prevSuggestion := b.lastSuggestedExitNode - res, err := suggestExitNode(lastReport, netMap, prevSuggestion, randomRegion, randomNode, getAllowedSuggestions()) + res, err := suggestExitNode(lastReport, netMap, prevSuggestion, randomRegion, randomNode, b.getAllowedSuggestions()) if err != nil { return res, err } @@ -7211,6 +7230,22 @@ func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionRes return b.suggestExitNodeLocked(nil) } +// getAllowedSuggestions returns a set of exit nodes permitted by the most recent +// [syspolicy.AllowedSuggestedExitNodes] value. Callers must not mutate the returned set. +func (b *LocalBackend) getAllowedSuggestions() set.Set[tailcfg.StableNodeID] { + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + return b.allowedSuggestedExitNodes +} + +// refreshAllowedSuggestions rebuilds the set of permitted exit nodes +// from the current [syspolicy.AllowedSuggestedExitNodes] value. +func (b *LocalBackend) refreshAllowedSuggestions() { + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + b.allowedSuggestedExitNodes = fillAllowedSuggestions() +} + // selectRegionFunc returns a DERP region from the slice of candidate regions. // The value is returned, not the slice index. type selectRegionFunc func(views.Slice[int]) int @@ -7220,8 +7255,6 @@ type selectRegionFunc func(views.Slice[int]) int // choice. type selectNodeFunc func(nodes views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView -var getAllowedSuggestions = lazy.SyncFunc(fillAllowedSuggestions) - func fillAllowedSuggestions() set.Set[tailcfg.StableNodeID] { nodes, err := syspolicy.GetStringArray(syspolicy.AllowedSuggestedExitNodes, nil) if err != nil { From f6431185b0cd196acbefdda9fec523ed4d408aed Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 22 Nov 2024 14:26:42 -0800 Subject: [PATCH 138/179] net/netmon: catch ParseRIB panic to gather buffer data Updates #14201 Updates golang/go#70528 Signed-off-by: James Tucker --- net/netmon/netmon_darwin.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/net/netmon/netmon_darwin.go b/net/netmon/netmon_darwin.go index cc630112523fa..a5096889b84fd 100644 --- a/net/netmon/netmon_darwin.go +++ b/net/netmon/netmon_darwin.go @@ -56,7 +56,15 @@ func (m *darwinRouteMon) Receive() (message, error) { if err != nil { return nil, err } - msgs, err := route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + msgs, err := func() (msgs []route.Message, err error) { + defer func() { + if recover() != nil { + msgs = nil + err = fmt.Errorf("panic parsing route message") + } + }() + return route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + }() if err != nil { if debugRouteMessages { m.logf("read %d bytes (% 02x), failed to parse RIB: %v", n, m.buf[:n], err) From ba3523fc3f62835bcddba683e37257ed7d53493c Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Sat, 23 Nov 2024 08:51:40 +0000 Subject: [PATCH 139/179] cmd/containerboot: preserve headers of metrics endpoints responses (#14204) Updates tailscale/tailscale#11292 Signed-off-by: Irbe Krumina --- cmd/containerboot/metrics.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/containerboot/metrics.go b/cmd/containerboot/metrics.go index e88406f97c9c6..874774d7a4cf2 100644 --- a/cmd/containerboot/metrics.go +++ b/cmd/containerboot/metrics.go @@ -38,12 +38,12 @@ func proxy(w http.ResponseWriter, r *http.Request, url string, do func(*http.Req } defer resp.Body.Close() - w.WriteHeader(resp.StatusCode) for key, val := range resp.Header { for _, v := range val { w.Header().Add(key, v) } } + w.WriteHeader(resp.StatusCode) if _, err := io.Copy(w, resp.Body); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } From 788121f47536f2947e514370b45eaa1029a54488 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 25 Nov 2024 10:10:32 -0600 Subject: [PATCH 140/179] docs/windows/policy: update ADMX policy definitions to reflect the syspolicy settings We add a policy definition for the AllowedSuggestedExitNodes syspolicy setting, allowing admins to configure a list of exit node IDs to be used as a pool for automatic suggested exit node selection. We update definitions for policy settings configurable on both a per-user and per-machine basis, such as UI customizations, to specify class="Both". Lastly, we update the help text for existing policy definitions to include a link to the KB article as the last line instead of in the first paragraph. Updates #12687 Updates tailscale/corp#19681 Signed-off-by: Nick Khyl --- docs/windows/policy/en-US/tailscale.adml | 111 ++++++++++++++--------- docs/windows/policy/tailscale.admx | 31 +++++-- 2 files changed, 91 insertions(+), 51 deletions(-) diff --git a/docs/windows/policy/en-US/tailscale.adml b/docs/windows/policy/en-US/tailscale.adml index 7a658422cd7f6..ebf1a5905f6e9 100644 --- a/docs/windows/policy/en-US/tailscale.adml +++ b/docs/windows/policy/en-US/tailscale.adml @@ -15,16 +15,18 @@ Tailscale version 1.58.0 and later Tailscale version 1.62.0 and later Tailscale version 1.74.0 and later + Tailscale version 1.78.0 and later Tailscale UI customization Settings Require using a specific Tailscale coordination server +If you disable or do not configure this policy, the Tailscale SaaS coordination server will be used by default, but a non-standard Tailscale coordination server can be configured using the CLI. + +See https://tailscale.com/kb/1315/mdm-keys#set-a-custom-control-server-url for more details.]]> Require using a specific Tailscale log server Specify which Tailnet should be used for Login +See https://tailscale.com/kb/1315/mdm-keys#set-a-suggested-or-required-tailnet for more details.]]> Specify the auth key to authenticate devices without user interaction Require using a specific Exit Node +If you do not configure this policy, no exit node will be used by default but an exit node (if one is available and permitted by ACLs) can be chosen by the user if desired. + +See https://tailscale.com/kb/1315/mdm-keys#force-an-exit-node-to-always-be-used and https://tailscale.com/kb/1103/exit-nodes for more details.]]> + Limit automated Exit Node suggestions to specific nodes + Allow incoming connections +If you do not configure this policy, then Allow Incoming Connections depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-to-allow-incoming-connections and https://tailscale.com/kb/1072/client-preferences#allow-incoming-connections for more details.]]> Run Tailscale in Unattended Mode +If you do not configure this policy, then Run Unattended depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-unattended-mode and https://tailscale.com/kb/1088/run-unattended for more details.]]> Allow Local Network Access when an Exit Node is in use +If you do not configure this policy, then Allow Local Network Access depends on what is selected in the Exit Node submenu. + +See https://tailscale.com/kb/1315/mdm-keys#toggle-local-network-access-when-an-exit-node-is-in-use and https://tailscale.com/kb/1103/exit-nodes#step-4-use-the-exit-node for more details.]]> Use Tailscale DNS Settings +If you do not configure this policy, then Use Tailscale DNS depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-uses-tailscale-dns-settings for more details.]]> Use Tailscale Subnets +If you do not configure this policy, then Use Tailscale Subnets depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-accepts-tailscale-subnets or https://tailscale.com/kb/1019/subnets for more details.]]> Automatically install updates +If you do not configure this policy, then Automatically Install Updates depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1067/update#auto-updates for more details.]]> Run Tailscale as an Exit Node - Show the "Admin Panel" menu item - + Show the "Admin Console" menu item + Show the "Debug" submenu +If you disable this policy, the Debug submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-debug-menu for more details.]]> Show the "Update Available" menu item +If you disable this policy, the Update Available item will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-update-menu for more details.]]> Show the "Run Exit Node" menu item +If you disable this policy, the Run Exit Node item will be hidden from the Exit Node submenu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-run-as-exit-node-menu-item for more details.]]> Show the "Preferences" submenu +If you disable this policy, the Preferences submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-preferences-menu for more details.]]> Show the "Exit Node" submenu +If you disable this policy, the Exit Node submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-exit-node-picker for more details.]]> Specify a custom key expiration notification time +If you disable or don't configure this policy, the default time period will be used (as of Tailscale 1.56, this is 24 hours). + +See https://tailscale.com/kb/1315/mdm-keys#set-the-key-expiration-notice-period for more details.]]> Log extra details about service events Collect data for posture checking +If you do not configure this policy, then data collection depends on if it has been enabled from the CLI (as of Tailscale 1.56), it may be present in the GUI in later versions. + +See https://tailscale.com/kb/1315/mdm-keys#enable-gathering-device-posture-data and https://tailscale.com/kb/1326/device-identity for more details.]]> Show the "Managed By {Organization}" menu item Exit Node: + + Target IDs: + diff --git a/docs/windows/policy/tailscale.admx b/docs/windows/policy/tailscale.admx index e70f124ed1a36..f941525c4fc9c 100644 --- a/docs/windows/policy/tailscale.admx +++ b/docs/windows/policy/tailscale.admx @@ -50,6 +50,10 @@ displayName="$(string.SINCE_V1_74)"> + + + @@ -94,7 +98,14 @@ - + > + + + + + + + @@ -197,7 +208,7 @@ - + @@ -207,7 +218,7 @@ hide - + @@ -217,7 +228,7 @@ hide - + @@ -227,7 +238,7 @@ hide - + @@ -237,7 +248,7 @@ hide - + @@ -247,7 +258,7 @@ hide - + @@ -257,7 +268,7 @@ hide - + @@ -267,7 +278,7 @@ hide - + @@ -276,7 +287,7 @@ - + From 4d33f30f91eb7debdf90c8770990801f3857e30c Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 25 Nov 2024 12:00:16 -0800 Subject: [PATCH 141/179] net/netmon: improve panic reporting from #14202 I was hoping we'd catch an example input quickly, but the reporter had rebooted their machine and it is no longer exhibiting the behavior. As such this code may be sticking around quite a bit longer and we might encounter other errors, so include the panic in the log entry. Updates #14201 Updates #14202 Updates golang/go#70528 Signed-off-by: James Tucker --- net/netmon/netmon_darwin.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/net/netmon/netmon_darwin.go b/net/netmon/netmon_darwin.go index a5096889b84fd..e89e2d04794e5 100644 --- a/net/netmon/netmon_darwin.go +++ b/net/netmon/netmon_darwin.go @@ -58,9 +58,12 @@ func (m *darwinRouteMon) Receive() (message, error) { } msgs, err := func() (msgs []route.Message, err error) { defer func() { - if recover() != nil { + // TODO(raggi,#14201): remove once we've got a fix from + // golang/go#70528. + msg := recover() + if msg != nil { msgs = nil - err = fmt.Errorf("panic parsing route message") + err = fmt.Errorf("panic in route.ParseRIB: %s", msg) } }() return route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) From 26de518413277e0869b815c373f694f6b5d18562 Mon Sep 17 00:00:00 2001 From: Mario Minardi Date: Tue, 26 Nov 2024 10:45:03 -0700 Subject: [PATCH 142/179] ipn/ipnlocal: only check CanUseExitNode if we are attempting to use one (#14230) In https://github.com/tailscale/tailscale/pull/13726 we added logic to `checkExitNodePrefsLocked` to error out on platforms where using an exit node is unsupported in order to give users more obvious feedback than having this silently fail downstream. The above change neglected to properly check whether the device in question was actually trying to use an exit node when doing the check and was incorrectly returning an error on any calls to `checkExitNodePrefsLocked` on platforms where using an exit node is not supported as a result. This change remedies this by adding a check to see whether the device is attempting to use an exit node before doing the `CanUseExitNode` check. Updates https://github.com/tailscale/corp/issues/24835 Signed-off-by: Mario Minardi --- ipn/ipnlocal/local.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index fdbd5cf52beb0..278614c0b90dd 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -3740,11 +3740,16 @@ func updateExitNodeUsageWarning(p ipn.PrefsView, state *netmon.State, healthTrac } func (b *LocalBackend) checkExitNodePrefsLocked(p *ipn.Prefs) error { + tryingToUseExitNode := p.ExitNodeIP.IsValid() || p.ExitNodeID != "" + if !tryingToUseExitNode { + return nil + } + if err := featureknob.CanUseExitNode(); err != nil { return err } - if (p.ExitNodeIP.IsValid() || p.ExitNodeID != "") && p.AdvertisesExitNode() { + if p.AdvertisesExitNode() { return errors.New("Cannot advertise an exit node and use an exit node at the same time.") } return nil From a62f7183e4f121a66a7ab32b474d7c5b3f349286 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Tue, 26 Nov 2024 13:11:55 -0600 Subject: [PATCH 143/179] cmd/tailscale/cli: fix format string Updates #12687 Signed-off-by: Nick Khyl --- cmd/tailscale/cli/syspolicy.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/tailscale/cli/syspolicy.go b/cmd/tailscale/cli/syspolicy.go index 06a19defb459a..0e903db397c7d 100644 --- a/cmd/tailscale/cli/syspolicy.go +++ b/cmd/tailscale/cli/syspolicy.go @@ -98,9 +98,9 @@ func printPolicySettings(policy *setting.Snapshot) { origin = o.String() } if err := setting.Error(); err != nil { - fmt.Fprintf(w, "%s\t%s\t\t{%s}\n", k, origin, err) + fmt.Fprintf(w, "%s\t%s\t\t{%v}\n", k, origin, err) } else { - fmt.Fprintf(w, "%s\t%s\t%s\t\n", k, origin, setting.Value()) + fmt.Fprintf(w, "%s\t%s\t%v\t\n", k, origin, setting.Value()) } } w.Flush() From e87b71ec3c7bded3fadf44cb9374df5de5e213d6 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Tue, 26 Nov 2024 17:50:29 -0500 Subject: [PATCH 144/179] control/controlhttp: set *health.Tracker in tests Observed during another PR: https://github.com/tailscale/tailscale/actions/runs/12040045880/job/33569141807 Updates #cleanup Signed-off-by: Andrew Dunham Change-Id: I9e0f49a35485fa2e097892737e5e3c95bf775a90 --- control/controlhttp/http_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 00cc1e6cfd80b..aef916ef651c1 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -25,6 +25,7 @@ import ( "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/control/controlhttp/controlhttpserver" + "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/socks5" @@ -228,6 +229,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { omitCertErrorLogging: true, testFallbackDelay: fallbackDelay, Clock: clock, + HealthTracker: new(health.Tracker), } if param.httpInDial { @@ -729,6 +731,7 @@ func TestDialPlan(t *testing.T) { omitCertErrorLogging: true, testFallbackDelay: 50 * time.Millisecond, Clock: clock, + HealthTracker: new(health.Tracker), } conn, err := a.dial(ctx) From bb80f14ff42c0e167eb34d65428a63a81d1090a2 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 26 Nov 2024 18:13:17 +0000 Subject: [PATCH 145/179] ipn/localapi: count localapi requests to metric endpoints Updates tailscale/corp#22075 Signed-off-by: Anton Tolchanov --- ipn/localapi/localapi.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index dc8c089758371..ea931b0280ccf 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -563,6 +563,7 @@ func (h *Handler) serveLogTap(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { + metricDebugMetricsCalls.Add(1) // Require write access out of paranoia that the metrics // might contain something sensitive. if !h.PermitWrite { @@ -576,6 +577,7 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { // serveUserMetrics returns user-facing metrics in Prometheus text // exposition format. func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) { + metricUserMetricsCalls.Add(1) h.b.UserMetricsRegistry().Handler(w, r) } @@ -2972,7 +2974,9 @@ var ( metricInvalidRequests = clientmetric.NewCounter("localapi_invalid_requests") // User-visible LocalAPI endpoints. - metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") + metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") + metricDebugMetricsCalls = clientmetric.NewCounter("localapi_debugmetric_requests") + metricUserMetricsCalls = clientmetric.NewCounter("localapi_usermetric_requests") ) // serveSuggestExitNode serves a POST endpoint for returning a suggested exit node. From bac3af06f5a2e7dcf2976e4d8e846eab0a52b514 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 27 Nov 2024 11:18:04 -0800 Subject: [PATCH 146/179] logtail: avoid bytes.Buffer allocation (#11858) Re-use a pre-allocated bytes.Buffer struct and shallow the copy the result of bytes.NewBuffer into it to avoid allocating the struct. Note that we're only reusing the bytes.Buffer struct itself and not the underling []byte temporarily stored within it. Updates #cleanup Updates tailscale/corp#18514 Updates golang/go#67004 Signed-off-by: Joe Tsai --- logtail/logtail.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/logtail/logtail.go b/logtail/logtail.go index 9df164273d74c..13e8e85fd40f7 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -213,6 +213,7 @@ type Logger struct { procSequence uint64 flushTimer tstime.TimerController // used when flushDelay is >0 writeBuf [bufferSize]byte // owned by Write for reuse + bytesBuf bytes.Buffer // owned by appendTextOrJSONLocked for reuse jsonDec jsontext.Decoder // owned by appendTextOrJSONLocked for reuse shutdownStartMu sync.Mutex // guards the closing of shutdownStart @@ -725,9 +726,16 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // whether it contains the reserved "logtail" name at the top-level. var logtailKeyOffset, logtailValOffset, logtailValLength int validJSON := func() bool { - // TODO(dsnet): Avoid allocation of bytes.Buffer struct. + // The jsontext.NewDecoder API operates on an io.Reader, for which + // bytes.Buffer provides a means to convert a []byte into an io.Reader. + // However, bytes.NewBuffer normally allocates unless + // we immediately shallow copy it into a pre-allocated Buffer struct. + // See https://go.dev/issue/67004. + l.bytesBuf = *bytes.NewBuffer(src) + defer func() { l.bytesBuf = bytes.Buffer{} }() // avoid pinning src + dec := &l.jsonDec - dec.Reset(bytes.NewBuffer(src)) + dec.Reset(&l.bytesBuf) if tok, err := dec.ReadToken(); tok.Kind() != '{' || err != nil { return false } From 41e56cedf8eea406e48ec1def6e7ea13a0c303fd Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 20 Nov 2024 11:46:14 +0100 Subject: [PATCH 147/179] health: move health metrics test to health_test Updates #13420 Signed-off-by: Kristoffer Dalby --- health/health.go | 4 +++- health/health_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++- tsnet/tsnet_test.go | 31 ---------------------------- 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/health/health.go b/health/health.go index 3bebcb98356f4..079b3195c8e86 100644 --- a/health/health.go +++ b/health/health.go @@ -331,7 +331,7 @@ func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { ) t.metricHealthMessage.Set(metricHealthMessageLabel{ - Type: "warning", + Type: MetricLabelWarning, }, expvar.Func(func() any { if t.nil() { return 0 @@ -1283,6 +1283,8 @@ func (t *Tracker) LastNoiseDialWasRecent() bool { return dur < 2*time.Minute } +const MetricLabelWarning = "warning" + type metricHealthMessageLabel struct { // TODO: break down by warnable.severity as well? Type string diff --git a/health/health_test.go b/health/health_test.go index 8107c1cf09db5..69e586066cdd6 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -7,11 +7,13 @@ import ( "fmt" "reflect" "slices" + "strconv" "testing" "time" "tailscale.com/tailcfg" "tailscale.com/types/opt" + "tailscale.com/util/usermetric" ) func TestAppendWarnableDebugFlags(t *testing.T) { @@ -273,7 +275,7 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow bool }{ { - desc: "nil CientVersion", + desc: "nil ClientVersion", check: true, cv: nil, wantWarnable: nil, @@ -348,3 +350,47 @@ func TestShowUpdateWarnable(t *testing.T) { }) } } + +func TestHealthMetric(t *testing.T) { + tests := []struct { + desc string + check bool + apply opt.Bool + cv *tailcfg.ClientVersion + wantMetricCount int + }{ + // When running in dev, and not initialising the client, there will be two warnings + // by default: + // - is-using-unstable-version + // - wantrunning-false + { + desc: "base-warnings", + check: true, + cv: nil, + wantMetricCount: 2, + }, + // with: update-available + { + desc: "update-warning", + check: true, + cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, + wantMetricCount: 3, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + tr := &Tracker{ + checkForUpdates: tt.check, + applyUpdates: tt.apply, + latestVersion: tt.cv, + } + tr.SetMetricsRegistry(&usermetric.Registry{}) + if val := tr.metricHealthMessage.Get(metricHealthMessageLabel{Type: MetricLabelWarning}).String(); val != strconv.Itoa(tt.wantMetricCount) { + t.Fatalf("metric value: %q, want: %q", val, strconv.Itoa(tt.wantMetricCount)) + } + for _, w := range tr.CurrentState().Warnings { + t.Logf("warning: %v", w) + } + }) + } +} diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 7aebbdd4c39ca..0f904ad2d749f 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -38,7 +38,6 @@ import ( "golang.org/x/net/proxy" "tailscale.com/client/tailscale" "tailscale.com/cmd/testwrapper/flakytest" - "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" @@ -822,16 +821,6 @@ func TestUDPConn(t *testing.T) { } } -// testWarnable is a Warnable that is used within this package for testing purposes only. -var testWarnable = health.Register(&health.Warnable{ - Code: "test-warnable-tsnet", - Title: "Test warnable", - Severity: health.SeverityLow, - Text: func(args health.Args) string { - return args[health.ArgError] - }, -}) - func parseMetrics(m []byte) (map[string]float64, error) { metrics := make(map[string]float64) @@ -1045,11 +1034,6 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - status1, err := lc1.Status(ctxLc) - if err != nil { - t.Fatal(err) - } - parsedMetrics1, err := parseMetrics(metrics1) if err != nil { t.Fatal(err) @@ -1075,11 +1059,6 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, want) } - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics1[`tailscaled_health_messages{type="warning"}`], float64(len(status1.Health)); got != want { - t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want) - } - // Verify that the amount of data recorded in bytes is higher or equal to the // 10 megabytes sent. inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] @@ -1097,11 +1076,6 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - status2, err := lc2.Status(ctx) - if err != nil { - t.Fatal(err) - } - parsedMetrics2, err := parseMetrics(metrics2) if err != nil { t.Fatal(err) @@ -1119,11 +1093,6 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, want) } - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics2[`tailscaled_health_messages{type="warning"}`], float64(len(status2.Health)); got != want { - t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want) - } - // Verify that the amount of data recorded in bytes is higher or equal than the // 10 megabytes sent. outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] From 06d929f9ac87b0683a55ebd004d15899a0122f71 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 25 Nov 2024 10:15:04 +0100 Subject: [PATCH 148/179] tsnet: send less data in metrics integration test this commit reduced the amount of data sent in the metrics data integration test from 10MB to 1MB. On various machines 10MB was quite flaky, while 1MB has not failed once on 10000 runs. Updates #13420 Signed-off-by: Kristoffer Dalby --- tsnet/tsnet_test.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 0f904ad2d749f..aae034b617ba3 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1015,8 +1015,8 @@ func TestUserMetrics(t *testing.T) { mustDirect(t, t.Logf, lc1, lc2) - // 10 megabytes - bytesToSend := 10 * 1024 * 1024 + // 1 megabytes + bytesToSend := 1 * 1024 * 1024 // This asserts generates some traffic, it is factored out // of TestUDPConn. @@ -1059,14 +1059,13 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, want) } - // Verify that the amount of data recorded in bytes is higher or equal to the - // 10 megabytes sent. + // Verify that the amount of data recorded in bytes is higher or equal to the data sent inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] if inboundBytes1 < float64(bytesToSend) { t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) } - // But ensure that it is not too much higher than the 10 megabytes sent. + // But ensure that it is not too much higher than the data sent. if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) } @@ -1093,14 +1092,13 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, want) } - // Verify that the amount of data recorded in bytes is higher or equal than the - // 10 megabytes sent. + // Verify that the amount of data recorded in bytes is higher or equal than the data sent. outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] if outboundBytes2 < float64(bytesToSend) { t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) } - // But ensure that it is not too much higher than the 10 megabytes sent. + // But ensure that it is not too much higher than the data sent. if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) } From e55899386b1f6d9f4b02d7a3349efdf83e162504 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 25 Nov 2024 13:36:37 +0100 Subject: [PATCH 149/179] tsnet: split bytes and routes metrics tests Updates #13420 Signed-off-by: Kristoffer Dalby --- tsnet/tsnet_test.go | 184 +++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 61 deletions(-) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index aae034b617ba3..dbd010ce6ce8b 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -936,16 +936,136 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC return nil } -func TestUserMetrics(t *testing.T) { +func TestUserMetricsByteCounters(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - controlURL, c := startControl(t) - s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") + controlURL, _ := startControl(t) + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") + lc1, err := s1.LocalClient() + if err != nil { + t.Fatal(err) + } + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // Force an update to the netmap to ensure that the metrics are up-to-date. + s1.lb.DebugForceNetmapUpdate() + s2.lb.DebugForceNetmapUpdate() + + // Wait for both nodes to have a peer in their netmap. + waitForCondition(t, "waiting for netmaps to contain peer", 90*time.Second, func() bool { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + status1, err := lc1.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + status2, err := lc2.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + return len(status1.Peers()) > 0 && len(status2.Peers()) > 0 + }) + + // ping to make sure the connection is up. + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatalf("pinging: %s", err) + } + t.Logf("ping success: %#+v", res) + + mustDirect(t, t.Logf, lc1, lc2) + + // 1 megabytes + bytesToSend := 1 * 1024 * 1024 + + // This asserts generates some traffic, it is factored out + // of TestUDPConn. + start := time.Now() + err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) + if err != nil { + t.Fatalf("Failed to send packets: %v", err) + } + t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) + + ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelLc() + metrics1, err := lc1.UserMetrics(ctxLc) + if err != nil { + t.Fatal(err) + } + + parsedMetrics1, err := parseMetrics(metrics1) + if err != nil { + t.Fatal(err) + } + + // Allow the metrics for the bytes sent to be off by 15%. + bytesSentTolerance := 1.15 + + t.Logf("Metrics1:\n%s\n", metrics1) + + // Verify that the amount of data recorded in bytes is higher or equal to the data sent + inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] + if inboundBytes1 < float64(bytesToSend) { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) + } + + // But ensure that it is not too much higher than the data sent. + if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) + } + + metrics2, err := lc2.UserMetrics(ctx) + if err != nil { + t.Fatal(err) + } + + parsedMetrics2, err := parseMetrics(metrics2) + if err != nil { + t.Fatal(err) + } + + t.Logf("Metrics2:\n%s\n", metrics2) + + // Verify that the amount of data recorded in bytes is higher or equal than the data sent. + outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] + if outboundBytes2 < float64(bytesToSend) { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) + } + + // But ensure that it is not too much higher than the data sent. + if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) + } +} + +func TestUserMetricsRouteGauges(t *testing.T) { + // Windows does not seem to support or report back routes when running in + // userspace via tsnet. So, we skip this check on Windows. + // TODO(kradalby): Figure out if this is correct. + if runtime.GOOS == "windows" { + t.Skipf("skipping on windows") + } + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controlURL, c := startControl(t) + s1, _, s1PubKey := startServer(t, ctx, controlURL, "s1") + s2, _, _ := startServer(t, ctx, controlURL, "s2") + s1.lb.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AdvertiseRoutes: []netip.Prefix{ @@ -973,24 +1093,11 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - // ping to make sure the connection is up. - res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) - if err != nil { - t.Fatalf("pinging: %s", err) - } - t.Logf("ping success: %#+v", res) - - ht := s1.lb.HealthTracker() - ht.SetUnhealthy(testWarnable, health.Args{"Text": "Hello world 1"}) - // Force an update to the netmap to ensure that the metrics are up-to-date. s1.lb.DebugForceNetmapUpdate() s2.lb.DebugForceNetmapUpdate() wantRoutes := float64(2) - if runtime.GOOS == "windows" { - wantRoutes = 0 - } // Wait for the routes to be propagated to node 1 to ensure // that the metrics are up-to-date. @@ -1002,31 +1109,11 @@ func TestUserMetrics(t *testing.T) { t.Logf("getting status: %s", err) return false } - if runtime.GOOS == "windows" { - // Windows does not seem to support or report back routes when running in - // userspace via tsnet. So, we skip this check on Windows. - // TODO(kradalby): Figure out if this is correct. - return true - } // Wait for the primary routes to reach our desired routes, which is wantRoutes + 1, because // the PrimaryRoutes list will contain a exit node route, which the metric does not count. return status1.Self.PrimaryRoutes != nil && status1.Self.PrimaryRoutes.Len() == int(wantRoutes)+1 }) - mustDirect(t, t.Logf, lc1, lc2) - - // 1 megabytes - bytesToSend := 1 * 1024 * 1024 - - // This asserts generates some traffic, it is factored out - // of TestUDPConn. - start := time.Now() - err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) - if err != nil { - t.Fatalf("Failed to send packets: %v", err) - } - t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) - ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) defer cancelLc() metrics1, err := lc1.UserMetrics(ctxLc) @@ -1039,9 +1126,6 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - // Allow the metrics for the bytes sent to be off by 15%. - bytesSentTolerance := 1.15 - t.Logf("Metrics1:\n%s\n", metrics1) // The node is advertising 4 routes: @@ -1059,17 +1143,6 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, want) } - // Verify that the amount of data recorded in bytes is higher or equal to the data sent - inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] - if inboundBytes1 < float64(bytesToSend) { - t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) - } - - // But ensure that it is not too much higher than the data sent. - if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { - t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) - } - metrics2, err := lc2.UserMetrics(ctx) if err != nil { t.Fatal(err) @@ -1091,17 +1164,6 @@ func TestUserMetrics(t *testing.T) { if got, want := parsedMetrics2["tailscaled_approved_routes"], 0.0; got != want { t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, want) } - - // Verify that the amount of data recorded in bytes is higher or equal than the data sent. - outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] - if outboundBytes2 < float64(bytesToSend) { - t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) - } - - // But ensure that it is not too much higher than the data sent. - if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { - t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) - } } func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() bool) { From 225d8f5a881f01d3cb3ec05a56d6134188061d71 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 25 Nov 2024 14:14:08 +0100 Subject: [PATCH 150/179] tsnet: validate sent data in metrics test Updates #13420 Signed-off-by: Kristoffer Dalby --- tsnet/tsnet_test.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index dbd010ce6ce8b..fea68f6d4e93a 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -894,9 +894,11 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC for { got := make([]byte, bytesCount) n, err := conn.Read(got) - if n != bytesCount { - logf("read %d bytes, want %d", n, bytesCount) + if err != nil { + allReceived <- fmt.Errorf("failed reading packet, %s", err) + return } + got = got[:n] select { case <-stopReceive: @@ -904,13 +906,17 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC default: } - if err != nil { - allReceived <- fmt.Errorf("failed reading packet, %s", err) - return - } - total += n logf("received %d/%d bytes, %.2f %%", total, bytesCount, (float64(total) / (float64(bytesCount)) * 100)) + + // Validate the received bytes to be the same as the sent bytes. + for _, b := range string(got) { + if b != 'A' { + allReceived <- fmt.Errorf("received unexpected byte: %c", b) + return + } + } + if total == bytesCount { break } From caba123008359e0987232a554277a64504be3f6c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 25 Nov 2024 16:00:21 +0100 Subject: [PATCH 151/179] wgengine/magicsock: packet/bytes metrics should not count disco Updates #13420 Signed-off-by: Kristoffer Dalby --- wgengine/magicsock/magicsock.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index c361608ad4b23..805716e61daae 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1267,7 +1267,7 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err // sendUDP sends UDP packet b to ipp. // See sendAddr's docs on the return value meanings. -func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { +func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, err error) { if runtime.GOOS == "js" { return false, errNoUDP } @@ -1276,7 +1276,7 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { metricSendUDPError.Add(1) _ = c.maybeRebindOnError(runtime.GOOS, err) } else { - if sent { + if sent && !isDisco { switch { case ipp.Addr().Is4(): c.metrics.outboundPacketsIPv4Total.Add(1) @@ -1371,7 +1371,7 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // returns (false, nil); it's not an error, but nothing was sent. func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool) (sent bool, err error) { if addr.Addr() != tailcfg.DerpMagicIPAddr { - return c.sendUDP(addr, b) + return c.sendUDP(addr, b, isDisco) } regionID := int(addr.Port()) From 61dd2662eca775c8a3f6700a0194db5816de5049 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 28 Nov 2024 12:45:40 +0100 Subject: [PATCH 152/179] tsnet: remove flaky test marker from metrics Updates #13420 Signed-off-by: Kristoffer Dalby --- tsnet/tsnet_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index fea68f6d4e93a..14d600817ad70 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -943,14 +943,14 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC } func TestUserMetricsByteCounters(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") - tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() controlURL, _ := startControl(t) s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + defer s1.Close() s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() lc1, err := s1.LocalClient() if err != nil { @@ -1063,14 +1063,14 @@ func TestUserMetricsRouteGauges(t *testing.T) { if runtime.GOOS == "windows" { t.Skipf("skipping on windows") } - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") - tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() controlURL, c := startControl(t) s1, _, s1PubKey := startServer(t, ctx, controlURL, "s1") + defer s1.Close() s2, _, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() s1.lb.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ From f8587e321ead0889b05a46ca03beef6f75f9d65d Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 29 Nov 2024 10:37:25 +0000 Subject: [PATCH 153/179] cmd/k8s-operator: fix port name change bug for egress ProxyGroup proxies (#14247) Ensure that the ExternalName Service port names are always synced to the ClusterIP Service, to fix a bug where if users created a Service with a single unnamed port and later changed to 1+ named ports, the operator attempted to apply an invalid multi-port Service with an unnamed port. Also, fixes a small internal issue where not-yet Service status conditons were lost on a spec update. Updates tailscale/tailscale#10102 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/egress-services.go | 24 +++++++- cmd/k8s-operator/egress-services_test.go | 75 +++++++++++++++++------- cmd/k8s-operator/testutils_test.go | 2 +- 3 files changed, 77 insertions(+), 24 deletions(-) diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 98ed943669cd0..a562f0170eea1 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -136,9 +136,8 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re } if !slices.Contains(svc.Finalizers, FinalizerName) { - l.Infof("configuring tailnet service") // logged exactly once svc.Finalizers = append(svc.Finalizers, FinalizerName) - if err := esr.Update(ctx, svc); err != nil { + if err := esr.updateSvcSpec(ctx, svc); err != nil { err := fmt.Errorf("failed to add finalizer: %w", err) r := svcConfiguredReason(svc, false, l) tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) @@ -198,7 +197,7 @@ func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1 if svc.Spec.ExternalName != clusterIPSvcFQDN { l.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) svc.Spec.ExternalName = clusterIPSvcFQDN - if err = esr.Update(ctx, svc); err != nil { + if err = esr.updateSvcSpec(ctx, svc); err != nil { err = fmt.Errorf("error updating ExternalName Service: %w", err) return err } @@ -222,6 +221,15 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s found := false for _, wantsPM := range svc.Spec.Ports { if wantsPM.Port == pm.Port && strings.EqualFold(string(wantsPM.Protocol), string(pm.Protocol)) { + // We don't use the port name to distinguish this port internally, but Kubernetes + // require that, for Service ports with more than one name each port is uniquely named. + // So we can always pick the port name from the ExternalName Service as at this point we + // know that those are valid names because Kuberentes already validated it once. Note + // that users could have changed an unnamed port to a named port and might have changed + // port names- this should still work. + // https://kubernetes.io/docs/concepts/services-networking/service/#multi-port-services + // See also https://github.com/tailscale/tailscale/issues/13406#issuecomment-2507230388 + clusterIPSvc.Spec.Ports[i].Name = wantsPM.Name found = true break } @@ -714,3 +722,13 @@ func epsPortsFromSvc(svc *corev1.Service) (ep []discoveryv1.EndpointPort) { } return ep } + +// updateSvcSpec ensures that the given Service's spec is updated in cluster, but the local Service object still retains +// the not-yet-applied status. +// TODO(irbekrm): once we do SSA for these patch updates, this will no longer be needed. +func (esr *egressSvcsReconciler) updateSvcSpec(ctx context.Context, svc *corev1.Service) error { + st := svc.Status.DeepCopy() + err := esr.Update(ctx, svc) + svc.Status = *st + return err +} diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index ac77339853ebe..06fe977ecc130 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -105,28 +105,40 @@ func TestTailscaleEgressServices(t *testing.T) { condition(tsapi.ProxyGroupReady, metav1.ConditionTrue, "", "", clock), } }) - // Quirks of the fake client. - mustUpdateStatus(t, fc, "default", "test", func(svc *corev1.Service) { - svc.Status.Conditions = []metav1.Condition{} + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_retain_one_unnamed_port", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports }) expectReconciled(t, esr, "default", "test") - // Verify that a ClusterIP Service has been created. - name := findGenNameForEgressSvcResources(t, fc, svc) - expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) - clusterSvc := mustGetClusterIPSvc(t, fc, name) - // Verify that an EndpointSlice has been created. - expectEqual(t, fc, endpointSlice(name, svc, clusterSvc), nil) - // Verify that ConfigMap contains configuration for the new egress service. - mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm) - r := svcConfiguredReason(svc, true, zl.Sugar()) - // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the - // CluterIP Service. - svc.Status.Conditions = []metav1.Condition{ - condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), - } - svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} - svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) - expectEqual(t, fc, svc, nil) + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_add_two_named_ports", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports + }) + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_add_udp_port", func(t *testing.T) { + svc.Spec.Ports = append(svc.Spec.Ports, corev1.ServicePort{Port: 53, Protocol: "UDP", Name: "dns"}) + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports + }) + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_change_protocol", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}, {Port: 53, Protocol: "TCP", Name: "tcp_dns"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports + }) + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) }) t.Run("delete_external_name_service", func(t *testing.T) { @@ -143,6 +155,29 @@ func TestTailscaleEgressServices(t *testing.T) { }) } +func validateReadyService(t *testing.T, fc client.WithWatch, esr *egressSvcsReconciler, svc *corev1.Service, clock *tstest.Clock, zl *zap.Logger, cm *corev1.ConfigMap) { + expectReconciled(t, esr, "default", "test") + // Verify that a ClusterIP Service has been created. + name := findGenNameForEgressSvcResources(t, fc, svc) + expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) + clusterSvc := mustGetClusterIPSvc(t, fc, name) + // Verify that an EndpointSlice has been created. + expectEqual(t, fc, endpointSlice(name, svc, clusterSvc), nil) + // Verify that ConfigMap contains configuration for the new egress service. + mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm) + r := svcConfiguredReason(svc, true, zl.Sugar()) + // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the + // CluterIP Service. + svc.Status.Conditions = []metav1.Condition{ + condition(tsapi.EgressSvcValid, metav1.ConditionTrue, "EgressSvcValid", "EgressSvcValid", clock), + condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), + } + svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} + svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) + expectEqual(t, fc, svc, nil) + +} + func condition(typ tsapi.ConditionType, st metav1.ConditionStatus, r, msg string, clock tstime.Clock) metav1.Condition { return metav1.Condition{ Type: string(typ), diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 084f573e5e45a..5795a0aaefeb8 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -650,7 +650,7 @@ func removeHashAnnotation(sts *appsv1.StatefulSet) { func removeTargetPortsFromSvc(svc *corev1.Service) { newPorts := make([]corev1.ServicePort, 0) for _, p := range svc.Spec.Ports { - newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port}) + newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port, Name: p.Name}) } svc.Spec.Ports = newPorts } From 44c8892c1818d777423e58464686659d67756451 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 29 Nov 2024 15:32:18 +0000 Subject: [PATCH 154/179] Makefile,./build_docker.sh: update kube operator image build target name (#14251) Updates tailscale/corp#24540 Updates tailscale/tailscale#12914 Signed-off-by: Irbe Krumina --- Makefile | 2 +- build_docker.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 98c3d36cc1c9e..960f13885c11c 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,7 @@ publishdevoperator: ## Build and publish k8s-operator image to location specifie @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... must not be tailscale/k8s-operator" && exit 1) @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) - TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=operator ./build_docker.sh + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-operator ./build_docker.sh publishdevnameserver: ## Build and publish k8s-nameserver image to location specified by ${REPO} @test -n "${REPO}" || (echo "REPO=... required; e.g. REPO=ghcr.io/${USER}/tailscale" && exit 1) diff --git a/build_docker.sh b/build_docker.sh index 9f39eb08ddf89..f9632ea0a06d3 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -54,7 +54,7 @@ case "$TARGET" in --annotations="${ANNOTATIONS}" \ /usr/local/bin/containerboot ;; - operator) + k8s-operator) DEFAULT_REPOS="tailscale/k8s-operator" REPOS="${REPOS:-${DEFAULT_REPOS}}" go run github.com/tailscale/mkctr \ From 13faa64c142148b1f8c8afd22d61e4a0de651b98 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 29 Nov 2024 15:44:58 +0000 Subject: [PATCH 155/179] cmd/k8s-operator: always set stateful filtering to false (#14216) Updates tailscale/tailscale#12108 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/operator_test.go | 4 ++-- cmd/k8s-operator/sts.go | 9 +-------- cmd/k8s-operator/testutils_test.go | 20 ++++++++------------ 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 21ef08e520a26..e46cdd7fe6e45 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -1388,7 +1388,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { parentType: "svc", hostname: "default-test", clusterTargetIP: "10.20.30.40", - confFileHash: "a67b5ad3ff605531c822327e8f1a23dd0846e1075b722c13402f7d5d0ba32ba2", + confFileHash: "acf3467364b0a3ba9b8ee0dd772cb7c2f0bf585e288fa99b7fe4566009ed6041", app: kubetypes.AppIngressProxy, } expectEqual(t, fc, expectedSTS(t, fc, o), nil) @@ -1399,7 +1399,7 @@ func TestTailscaledConfigfileHash(t *testing.T) { mak.Set(&svc.Annotations, AnnotationHostname, "another-test") }) o.hostname = "another-test" - o.confFileHash = "888a993ebee20ad6be99623b45015339de117946850cf1252bede0b570e04293" + o.confFileHash = "d4cc13f09f55f4f6775689004f9a466723325b84d2b590692796bfe22aeaa389" expectReconciled(t, sr, "default", "test") expectEqual(t, fc, expectedSTS(t, fc, o), nil) } diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index 5df476478c987..b12b1cdd011d7 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -854,17 +854,10 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co AcceptRoutes: "false", // AcceptRoutes defaults to true Locked: "false", Hostname: &stsC.Hostname, - NoStatefulFiltering: "false", + NoStatefulFiltering: "true", // Explicitly enforce default value, see #14216 AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, } - // For egress proxies only, we need to ensure that stateful filtering is - // not in place so that traffic from cluster can be forwarded via - // Tailscale IPs. - // TODO (irbekrm): set it to true always as this is now the default in core. - if stsC.TailnetTargetFQDN != "" || stsC.TailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" - } if stsC.Connector != nil { routes, err := netutil.CalcAdvertiseRoutes(stsC.Connector.routes, stsC.Connector.isExitNode) if err != nil { diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 5795a0aaefeb8..8f06f5979cbf4 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -353,13 +353,14 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec mak.Set(&s.StringData, "serve-config", string(serveConfigBs)) } conf := &ipn.ConfigVAlpha{ - Version: "alpha0", - AcceptDNS: "false", - Hostname: &opts.hostname, - Locked: "false", - AuthKey: ptr.To("secret-authkey"), - AcceptRoutes: "false", - AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, + Version: "alpha0", + AcceptDNS: "false", + Hostname: &opts.hostname, + Locked: "false", + AuthKey: ptr.To("secret-authkey"), + AcceptRoutes: "false", + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, + NoStatefulFiltering: "true", } if opts.proxyClass != "" { t.Logf("applying configuration from ProxyClass %s", opts.proxyClass) @@ -391,11 +392,6 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec routes = append(routes, prefix) } } - if opts.tailnetTargetFQDN != "" || opts.tailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" - } else { - conf.NoStatefulFiltering = "false" - } conf.AdvertiseRoutes = routes bnn, err := json.Marshal(conf) if err != nil { From a68efe2088c07c1abad537965be724bdc8273044 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Oct 2024 13:53:34 -0500 Subject: [PATCH 156/179] cmd/checkmetrics: add command for checking metrics against kb This commit adds a command to validate that all the metrics that are registring in the client are also present in a path or url. It is intended to be ran from the KB against the latest version of tailscale. Updates tailscale/corp#24066 Updates tailscale/corp#22075 Co-Authored-By: Brad Fitzpatrick Signed-off-by: Kristoffer Dalby --- cmd/checkmetrics/checkmetrics.go | 131 +++++++++++++++++++++++++++++++ util/usermetric/usermetric.go | 11 +++ 2 files changed, 142 insertions(+) create mode 100644 cmd/checkmetrics/checkmetrics.go diff --git a/cmd/checkmetrics/checkmetrics.go b/cmd/checkmetrics/checkmetrics.go new file mode 100644 index 0000000000000..fb9e8ab4c61ec --- /dev/null +++ b/cmd/checkmetrics/checkmetrics.go @@ -0,0 +1,131 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// checkmetrics validates that all metrics in the tailscale client-metrics +// are documented in a given path or URL. +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/util/httpm" +) + +var ( + kbPath = flag.String("kb-path", "", "filepath to the client-metrics knowledge base") + kbUrl = flag.String("kb-url", "", "URL to the client-metrics knowledge base page") +) + +func main() { + flag.Parse() + if *kbPath == "" && *kbUrl == "" { + log.Fatalf("either -kb-path or -kb-url must be set") + } + + var control testcontrol.Server + ts := httptest.NewServer(&control) + defer ts.Close() + + td, err := os.MkdirTemp("", "testcontrol") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(td) + + // tsnet is used not used as a Tailscale client, but as a way to + // boot up Tailscale, have all the metrics registered, and then + // verifiy that all the metrics are documented. + tsn := &tsnet.Server{ + Dir: td, + Store: new(mem.Store), + UserLogf: log.Printf, + Ephemeral: true, + ControlURL: ts.URL, + } + if err := tsn.Start(); err != nil { + log.Fatal(err) + } + defer tsn.Close() + + log.Printf("checking that all metrics are documented, looking for: %s", tsn.Sys().UserMetricsRegistry().MetricNames()) + + if *kbPath != "" { + kb, err := readKB(*kbPath) + if err != nil { + log.Fatalf("reading kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbPath, missing) + } + } + + if *kbUrl != "" { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + kb, err := getKB(ctx, *kbUrl) + if err != nil { + log.Fatalf("getting kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbUrl, missing) + } + } +} + +func readKB(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("reading file: %w", err) + } + + return string(b), nil +} + +func getKB(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, httpm.GET, url, nil) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("getting kb page: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading body: %w", err) + } + return string(b), nil +} + +func undocumentedMetrics(b string, metrics []string) []string { + var missing []string + for _, metric := range metrics { + if !strings.Contains(b, metric) { + missing = append(missing, metric) + } + } + return missing +} diff --git a/util/usermetric/usermetric.go b/util/usermetric/usermetric.go index 7913a4ef0d5f8..74e9447a64bbb 100644 --- a/util/usermetric/usermetric.go +++ b/util/usermetric/usermetric.go @@ -14,6 +14,7 @@ import ( "tailscale.com/metrics" "tailscale.com/tsweb/varz" + "tailscale.com/util/set" ) // Registry tracks user-facing metrics of various Tailscale subsystems. @@ -106,3 +107,13 @@ func (r *Registry) String() string { return sb.String() } + +// Metrics returns the name of all the metrics in the registry. +func (r *Registry) MetricNames() []string { + ret := make(set.Set[string]) + r.vars.Do(func(kv expvar.KeyValue) { + ret.Add(kv.Key) + }) + + return ret.Slice() +} From 24095e489716b7ec4a6bbe1978dd15ae442af73e Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Mon, 2 Dec 2024 12:18:09 +0000 Subject: [PATCH 157/179] cmd/containerboot: serve health on local endpoint (#14246) * cmd/containerboot: serve health on local endpoint We introduced stable (user) metrics in #14035, and `TS_LOCAL_ADDR_PORT` with it. Rather than requiring users to specify a new addr/port combination for each new local endpoint they want the container to serve, this combines the health check endpoint onto the local addr/port used by metrics if `TS_ENABLE_HEALTH_CHECK` is used instead of `TS_HEALTHCHECK_ADDR_PORT`. `TS_LOCAL_ADDR_PORT` now defaults to binding to all interfaces on 9002 so that it works more seamlessly and with less configuration in environments other than Kubernetes, where the operator always overrides the default anyway. In particular, listening on localhost would not be accessible from outside the container, and many scripted container environments do not know the IP address of the container before it's started. Listening on all interfaces allows users to just set one env var (`TS_ENABLE_METRICS` or `TS_ENABLE_HEALTH_CHECK`) to get a fully functioning local endpoint they can query from outside the container. Updates #14035, #12898 Signed-off-by: Tom Proctor --- cmd/containerboot/healthz.go | 35 ++++---- cmd/containerboot/main.go | 76 +++++++++++++---- cmd/containerboot/main_test.go | 150 ++++++++++++++++++++++++++++++++- cmd/containerboot/metrics.go | 22 ++--- cmd/containerboot/settings.go | 28 ++++-- cmd/k8s-operator/sts.go | 2 +- cmd/k8s-operator/sts_test.go | 4 +- 7 files changed, 251 insertions(+), 66 deletions(-) diff --git a/cmd/containerboot/healthz.go b/cmd/containerboot/healthz.go index 12e7ee9f8db73..895290733cf5f 100644 --- a/cmd/containerboot/healthz.go +++ b/cmd/containerboot/healthz.go @@ -7,7 +7,6 @@ package main import ( "log" - "net" "net/http" "sync" ) @@ -23,29 +22,29 @@ type healthz struct { func (h *healthz) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.Lock() defer h.Unlock() + if h.hasAddrs { w.Write([]byte("ok")) } else { - http.Error(w, "node currently has no tailscale IPs", http.StatusInternalServerError) + http.Error(w, "node currently has no tailscale IPs", http.StatusServiceUnavailable) } } -// runHealthz runs a simple HTTP health endpoint on /healthz, listening on the -// provided address. A containerized tailscale instance is considered healthy if -// it has at least one tailnet IP address. -func runHealthz(addr string, h *healthz) { - lis, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalf("error listening on the provided health endpoint address %q: %v", addr, err) +func (h *healthz) update(healthy bool) { + h.Lock() + defer h.Unlock() + + if h.hasAddrs != healthy { + log.Println("Setting healthy", healthy) } - mux := http.NewServeMux() + h.hasAddrs = healthy +} + +// healthHandlers registers a simple health handler at /healthz. +// A containerized tailscale instance is considered healthy if +// it has at least one tailnet IP address. +func healthHandlers(mux *http.ServeMux) *healthz { + h := &healthz{} mux.Handle("GET /healthz", h) - log.Printf("Running healthcheck endpoint at %s/healthz", addr) - hs := &http.Server{Handler: mux} - - go func() { - if err := hs.Serve(lis); err != nil { - log.Fatalf("failed running health endpoint: %v", err) - } - }() + return h } diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 313e8deb0b93c..0af9062a5f314 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -52,11 +52,17 @@ // ${TS_CERT_DOMAIN}, it will be replaced with the value of the available FQDN. // It cannot be used in conjunction with TS_DEST_IP. The file is watched for changes, // and will be re-applied when it changes. -// - TS_HEALTHCHECK_ADDR_PORT: if specified, an HTTP health endpoint will be -// served at /healthz at the provided address, which should be in form [
]:. -// If not set, no health check will be run. If set to :, addr will default to 0.0.0.0 -// The health endpoint will return 200 OK if this node has at least one tailnet IP address, -// otherwise returns 503. +// - TS_HEALTHCHECK_ADDR_PORT: deprecated, use TS_ENABLE_HEALTH_CHECK instead and optionally +// set TS_LOCAL_ADDR_PORT. Will be removed in 1.82.0. +// - TS_LOCAL_ADDR_PORT: the address and port to serve local metrics and health +// check endpoints if enabled via TS_ENABLE_METRICS and/or TS_ENABLE_HEALTH_CHECK. +// Defaults to [::]:9002, serving on all available interfaces. +// - TS_ENABLE_METRICS: if true, a metrics endpoint will be served at /metrics on +// the address specified by TS_LOCAL_ADDR_PORT. See https://tailscale.com/kb/1482/client-metrics +// for more information on the metrics exposed. +// - TS_ENABLE_HEALTH_CHECK: if true, a health check endpoint will be served at /healthz on +// the address specified by TS_LOCAL_ADDR_PORT. The health endpoint will return 200 +// OK if this node has at least one tailnet IP address, otherwise returns 503. // NB: the health criteria might change in the future. // - TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR: if specified, a path to a // directory that containers tailscaled config in file. The config file needs to be @@ -99,6 +105,7 @@ import ( "log" "math" "net" + "net/http" "net/netip" "os" "os/signal" @@ -178,12 +185,32 @@ func main() { } defer killTailscaled() - if cfg.LocalAddrPort != "" && cfg.MetricsEnabled { - m := &metrics{ - lc: client, - debugEndpoint: cfg.DebugAddrPort, + var healthCheck *healthz + if cfg.HealthCheckAddrPort != "" { + mux := http.NewServeMux() + + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.HealthCheckAddrPort) + healthCheck = healthHandlers(mux) + + close := runHTTPServer(mux, cfg.HealthCheckAddrPort) + defer close() + } + + if cfg.localMetricsEnabled() || cfg.localHealthEnabled() { + mux := http.NewServeMux() + + if cfg.localMetricsEnabled() { + log.Printf("Running metrics endpoint at %s/metrics", cfg.LocalAddrPort) + metricsHandlers(mux, client, cfg.DebugAddrPort) } - runMetrics(cfg.LocalAddrPort, m) + + if cfg.localHealthEnabled() { + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.LocalAddrPort) + healthCheck = healthHandlers(mux) + } + + close := runHTTPServer(mux, cfg.LocalAddrPort) + defer close() } if cfg.EnableForwardingOptimizations { @@ -328,9 +355,6 @@ authLoop: certDomain = new(atomic.Pointer[string]) certDomainChanged = make(chan bool, 1) - - h = &healthz{} // http server for the healthz endpoint - healthzRunner = sync.OnceFunc(func() { runHealthz(cfg.HealthCheckAddrPort, h) }) ) if cfg.ServeConfigPath != "" { go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client) @@ -556,11 +580,8 @@ runLoop: } } - if cfg.HealthCheckAddrPort != "" { - h.Lock() - h.hasAddrs = len(addrs) != 0 - h.Unlock() - healthzRunner() + if healthCheck != nil { + healthCheck.update(len(addrs) != 0) } if egressSvcsNotify != nil { egressSvcsNotify <- n @@ -751,3 +772,22 @@ func tailscaledConfigFilePath() string { log.Printf("Using tailscaled config file %q to match current capability version %d", filePath, tailcfg.CurrentCapabilityVersion) return filePath } + +func runHTTPServer(mux *http.ServeMux, addr string) (close func() error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("failed to listen on addr %q: %v", addr, err) + } + srv := &http.Server{Handler: mux} + + go func() { + if err := srv.Serve(ln); err != nil { + log.Fatalf("failed running server: %v", err) + } + }() + + return func() error { + err := srv.Shutdown(context.Background()) + return errors.Join(err, ln.Close()) + } +} diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 5c92787ce6079..47d7c19cfa78f 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -101,6 +101,24 @@ func TestContainerBoot(t *testing.T) { argFile := filepath.Join(d, "args") runningSockPath := filepath.Join(d, "tmp/tailscaled.sock") + var localAddrPort, healthAddrPort int + for _, p := range []*int{&localAddrPort, &healthAddrPort} { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to open listener: %v", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("Failed to close listener: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + *p = port + } + metricsURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/metrics", port) + } + healthURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/healthz", port) + } type phase struct { // If non-nil, send this IPN bus notification (and remember it as the @@ -119,6 +137,8 @@ func TestContainerBoot(t *testing.T) { // WantFatalLog is the fatal log message we expect from containerboot. // If set for a phase, the test will finish on that phase. WantFatalLog string + + EndpointStatuses map[string]int } runningNotify := &ipn.Notify{ State: ptr.To(ipn.Running), @@ -147,6 +167,11 @@ func TestContainerBoot(t *testing.T) { "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", }, + // No metrics or health by default. + EndpointStatuses: map[string]int{ + metricsURL(9002): -1, + healthURL(9002): -1, + }, }, { Notify: runningNotify, @@ -700,6 +725,104 @@ func TestContainerBoot(t *testing.T) { }, }, }, + { + Name: "metrics_enabled", + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", localAddrPort), + "TS_ENABLE_METRICS": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): 200, + healthURL(localAddrPort): -1, + }, + }, { + Notify: runningNotify, + }, + }, + }, + { + Name: "health_enabled", + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", localAddrPort), + "TS_ENABLE_HEALTH_CHECK": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): -1, + healthURL(localAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): -1, + healthURL(localAddrPort): 200, + }, + }, + }, + }, + { + Name: "metrics_and_health_on_same_port", + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_ENABLE_HEALTH_CHECK": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): 200, + healthURL(localAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): 200, + healthURL(localAddrPort): 200, + }, + }, + }, + }, + { + Name: "local_metrics_and_deprecated_health", + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_HEALTHCHECK_ADDR_PORT": fmt.Sprintf("[::]:%d", healthAddrPort), + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): 200, + healthURL(healthAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(localAddrPort): 200, + healthURL(healthAddrPort): 200, + }, + }, + }, + }, } for _, test := range tests { @@ -796,7 +919,26 @@ func TestContainerBoot(t *testing.T) { return nil }) if err != nil { - t.Fatal(err) + t.Fatalf("phase %d: %v", i, err) + } + + for url, want := range p.EndpointStatuses { + err := tstest.WaitFor(2*time.Second, func() error { + resp, err := http.Get(url) + if err != nil && want != -1 { + return fmt.Errorf("GET %s: %v", url, err) + } + if want > 0 && resp.StatusCode != want { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("GET %s, want %d, got %d\n%s", url, want, resp.StatusCode, string(body)) + } + + return nil + }) + if err != nil { + t.Fatalf("phase %d: %v", i, err) + } } } waitLogLine(t, 2*time.Second, cbOut, "Startup complete, waiting for shutdown signal") @@ -955,6 +1097,12 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { panic(fmt.Sprintf("unsupported method %q", r.Method)) } + case "/localapi/v0/usermetrics": + if r.Method != "GET" { + panic(fmt.Sprintf("unsupported method %q", r.Method)) + } + w.Write([]byte("fake metrics")) + return default: panic(fmt.Sprintf("unsupported path %q", r.URL.Path)) } diff --git a/cmd/containerboot/metrics.go b/cmd/containerboot/metrics.go index 874774d7a4cf2..a8b9222a5ab1f 100644 --- a/cmd/containerboot/metrics.go +++ b/cmd/containerboot/metrics.go @@ -8,8 +8,6 @@ package main import ( "fmt" "io" - "log" - "net" "net/http" "tailscale.com/client/tailscale" @@ -64,28 +62,18 @@ func (m *metrics) handleDebug(w http.ResponseWriter, r *http.Request) { proxy(w, r, debugURL, http.DefaultClient.Do) } -// runMetrics runs a simple HTTP metrics endpoint at /metrics, forwarding +// metricsHandlers registers a simple HTTP metrics handler at /metrics, forwarding // requests to tailscaled's /localapi/v0/usermetrics API. // // In 1.78.x and 1.80.x, it also proxies debug paths to tailscaled's debug // endpoint if configured to ease migration for a breaking change serving user // metrics instead of debug metrics on the "metrics" port. -func runMetrics(addr string, m *metrics) { - ln, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalf("error listening on the provided metrics endpoint address %q: %v", addr, err) +func metricsHandlers(mux *http.ServeMux, lc *tailscale.LocalClient, debugAddrPort string) { + m := &metrics{ + lc: lc, + debugEndpoint: debugAddrPort, } - mux := http.NewServeMux() mux.HandleFunc("GET /metrics", m.handleMetrics) mux.HandleFunc("/debug/", m.handleDebug) // TODO(tomhjp): Remove for 1.82.0 release. - - log.Printf("Running metrics endpoint at %s/metrics", addr) - ms := &http.Server{Handler: mux} - - go func() { - if err := ms.Serve(ln); err != nil { - log.Fatalf("failed running metrics endpoint: %v", err) - } - }() } diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index c877682b95742..1262a0e1872ec 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -67,18 +67,15 @@ type settings struct { PodIP string PodIPv4 string PodIPv6 string - HealthCheckAddrPort string // TODO(tomhjp): use the local addr/port instead. + HealthCheckAddrPort string LocalAddrPort string MetricsEnabled bool + HealthCheckEnabled bool DebugAddrPort string EgressSvcsCfgPath string } func configFromEnv() (*settings, error) { - defaultLocalAddrPort := "" - if v, ok := os.LookupEnv("POD_IP"); ok && v != "" { - defaultLocalAddrPort = fmt.Sprintf("%s:9002", v) - } cfg := &settings{ AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), Hostname: defaultEnv("TS_HOSTNAME", ""), @@ -105,8 +102,9 @@ func configFromEnv() (*settings, error) { PodIP: defaultEnv("POD_IP", ""), EnableForwardingOptimizations: defaultBool("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS", false), HealthCheckAddrPort: defaultEnv("TS_HEALTHCHECK_ADDR_PORT", ""), - LocalAddrPort: defaultEnv("TS_LOCAL_ADDR_PORT", defaultLocalAddrPort), - MetricsEnabled: defaultBool("TS_METRICS_ENABLED", false), + LocalAddrPort: defaultEnv("TS_LOCAL_ADDR_PORT", "[::]:9002"), + MetricsEnabled: defaultBool("TS_ENABLE_METRICS", false), + HealthCheckEnabled: defaultBool("TS_ENABLE_HEALTH_CHECK", false), DebugAddrPort: defaultEnv("TS_DEBUG_ADDR_PORT", ""), EgressSvcsCfgPath: defaultEnv("TS_EGRESS_SERVICES_CONFIG_PATH", ""), } @@ -181,11 +179,12 @@ func (s *settings) validate() error { return errors.New("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS is not supported in userspace mode") } if s.HealthCheckAddrPort != "" { + log.Printf("[warning] TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0. Please use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT instead.") if _, err := netip.ParseAddrPort(s.HealthCheckAddrPort); err != nil { - return fmt.Errorf("error parsing TS_HEALTH_CHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) + return fmt.Errorf("error parsing TS_HEALTHCHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) } } - if s.LocalAddrPort != "" { + if s.localMetricsEnabled() || s.localHealthEnabled() { if _, err := netip.ParseAddrPort(s.LocalAddrPort); err != nil { return fmt.Errorf("error parsing TS_LOCAL_ADDR_PORT value %q: %w", s.LocalAddrPort, err) } @@ -195,6 +194,9 @@ func (s *settings) validate() error { return fmt.Errorf("error parsing TS_DEBUG_ADDR_PORT value %q: %w", s.DebugAddrPort, err) } } + if s.HealthCheckEnabled && s.HealthCheckAddrPort != "" { + return errors.New("TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0, use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT") + } return nil } @@ -292,6 +294,14 @@ func hasKubeStateStore(cfg *settings) bool { return cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" } +func (cfg *settings) localMetricsEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.MetricsEnabled +} + +func (cfg *settings) localHealthEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.HealthCheckEnabled +} + // defaultEnv returns the value of the given envvar name, or defVal if // unset. func defaultEnv(name, defVal string) string { diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index b12b1cdd011d7..73c54a93d0373 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -818,7 +818,7 @@ func enableEndpoints(ss *appsv1.StatefulSet, metrics, debug bool) { Value: "$(POD_IP):9002", }, corev1.EnvVar{ - Name: "TS_METRICS_ENABLED", + Name: "TS_ENABLE_METRICS", Value: "true", }, ) diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index 7986d1b9164eb..05aafaee6a5d4 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -258,7 +258,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, - corev1.EnvVar{Name: "TS_METRICS_ENABLED", Value: "true"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, ) wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{ {Name: "debug", Protocol: "TCP", ContainerPort: 9001}, @@ -273,7 +273,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS = nonUserspaceProxySS.DeepCopy() wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, - corev1.EnvVar{Name: "TS_METRICS_ENABLED", Value: "true"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, ) wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9002}} gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, ptr.To(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) From 8d0c690f89971fa3ac30e3cba235cef8b2a81006 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 2 Dec 2024 08:58:21 -0800 Subject: [PATCH 158/179] net/netcheck: clean up ICMP probe AddrPort lookup Fixes #14200 Change-Id: Ib086814cf63dda5de021403fe1db4fb2a798eaae Signed-off-by: Brad Fitzpatrick --- net/netcheck/netcheck.go | 53 ++++++++++++++++++++--------------- net/netcheck/netcheck_test.go | 12 ++++---- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 2c429862eb133..0bb9305683e56 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -1221,17 +1221,19 @@ func (c *Client) measureICMPLatency(ctx context.Context, reg *tailcfg.DERPRegion // Try pinging the first node in the region node := reg.Nodes[0] - // Get the IPAddr by asking for the UDP address that we would use for - // STUN and then using that IP. - // - // TODO(andrew-d): this is a bit ugly - nodeAddr := c.nodeAddr(ctx, node, probeIPv4) - if !nodeAddr.IsValid() { + if node.STUNPort < 0 { + // If STUN is disabled on a node, interpret that as meaning don't measure latency. + return 0, false, nil + } + const unusedPort = 0 + stunAddrPort, ok := c.nodeAddrPort(ctx, node, unusedPort, probeIPv4) + if !ok { return 0, false, fmt.Errorf("no address for node %v (v4-for-icmp)", node.Name) } + ip := stunAddrPort.Addr() addr := &net.IPAddr{ - IP: net.IP(nodeAddr.Addr().AsSlice()), - Zone: nodeAddr.Addr().Zone(), + IP: net.IP(ip.AsSlice()), + Zone: ip.Zone(), } // Use the unique node.Name field as the packet data to reduce the @@ -1478,8 +1480,8 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe return } - addr := c.nodeAddr(ctx, node, probe.proto) - if !addr.IsValid() { + addr, ok := c.nodeAddrPort(ctx, node, node.STUNPort, probe.proto) + if !ok { c.logf("netcheck.runProbe: named node %q has no %v address", probe.node, probe.proto) return } @@ -1528,12 +1530,17 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe c.vlogf("sent to %v", addr) } -// proto is 4 or 6 -// If it returns nil, the node is skipped. -func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeProto) (ap netip.AddrPort) { - port := cmp.Or(n.STUNPort, 3478) +// nodeAddrPort returns the IP:port to send a STUN queries to for a given node. +// +// The provided port should be n.STUNPort, which may be negative to disable STUN. +// If STUN is disabled for this node, it returns ok=false. +// The port parameter is separate for the ICMP caller to provide a fake value. +// +// proto is [probeIPv4] or [probeIPv6]. +func (c *Client) nodeAddrPort(ctx context.Context, n *tailcfg.DERPNode, port int, proto probeProto) (_ netip.AddrPort, ok bool) { + var zero netip.AddrPort if port < 0 || port > 1<<16-1 { - return + return zero, false } if n.STUNTestIP != "" { ip, err := netip.ParseAddr(n.STUNTestIP) @@ -1546,7 +1553,7 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if proto == probeIPv6 && ip.Is4() { return } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } switch proto { @@ -1554,20 +1561,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if n.IPv4 != "" { ip, _ := netip.ParseAddr(n.IPv4) if !ip.Is4() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } case probeIPv6: if n.IPv6 != "" { ip, _ := netip.ParseAddr(n.IPv6) if !ip.Is6() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } default: - return + return zero, false } // The default lookup function if we don't set UseDNSCache is to use net.DefaultResolver. @@ -1609,13 +1616,13 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP addrs, err := lookupIPAddr(ctx, n.HostName) for _, a := range addrs { if (a.Is4() && probeIsV4) || (a.Is6() && !probeIsV4) { - return netip.AddrPortFrom(a, uint16(port)) + return netip.AddrPortFrom(a, uint16(port)), true } } if err != nil { c.logf("netcheck: DNS lookup error for %q (node %q region %v): %v", n.HostName, n.Name, n.RegionID, err) } - return + return zero, false } func regionHasDERPNode(r *tailcfg.DERPRegion) bool { diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index b4fbb4023dcc1..23891efcc6e48 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -887,8 +887,8 @@ func TestNodeAddrResolve(t *testing.T) { c.UseDNSCache = tt t.Run("IPv4", func(t *testing.T) { - ap := c.nodeAddr(ctx, dn, probeIPv4) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv4) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is4() { @@ -902,8 +902,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Skipf("IPv6 may not work on this machine") } - ap := c.nodeAddr(ctx, dn, probeIPv6) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv6) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is6() { @@ -912,8 +912,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Logf("got IPv6 addr: %v", ap) }) t.Run("IPv6 Failure", func(t *testing.T) { - ap := c.nodeAddr(ctx, dnV4Only, probeIPv6) - if ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dnV4Only, dn.STUNPort, probeIPv6) + if ok { t.Fatalf("expected no addr but got: %v", ap) } t.Logf("correctly got invalid addr") From 3f545725392a0cd3185a12961f71fd87b6b956e2 Mon Sep 17 00:00:00 2001 From: KevinLiang10 <37811973+KevinLiang10@users.noreply.github.com> Date: Thu, 28 Nov 2024 12:49:37 -0500 Subject: [PATCH 159/179] IPN: Update ServeConfig to accept configuration for Services. This commit updates ServeConfig to allow configuration to Services (VIPServices for now) via Serve. The scope of this commit is only adding the Services field to ServeConfig. The field doesn't actually allow packet flowing yet. The purpose of this commit is to unblock other work on k8s end. Updates #22953 Signed-off-by: KevinLiang10 <37811973+KevinLiang10@users.noreply.github.com> --- ipn/doc.go | 2 +- ipn/ipn_clone.go | 49 ++++++++++++++++++++++++++++++++ ipn/ipn_view.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++- ipn/serve.go | 21 ++++++++++++++ 4 files changed, 144 insertions(+), 2 deletions(-) diff --git a/ipn/doc.go b/ipn/doc.go index 4b3810be1f734..9a0bbb800b556 100644 --- a/ipn/doc.go +++ b/ipn/doc.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/viewer -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/viewer -type=Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig // Package ipn implements the interactions between the Tailscale cloud // control plane and the local network stack. diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 0e9698faf4488..34d7ba9a66364 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -105,6 +105,16 @@ func (src *ServeConfig) Clone() *ServeConfig { } } } + if dst.Services != nil { + dst.Services = map[string]*ServiceConfig{} + for k, v := range src.Services { + if v == nil { + dst.Services[k] = nil + } else { + dst.Services[k] = v.Clone() + } + } + } dst.AllowFunnel = maps.Clone(src.AllowFunnel) if dst.Foreground != nil { dst.Foreground = map[string]*ServeConfig{} @@ -123,11 +133,50 @@ func (src *ServeConfig) Clone() *ServeConfig { var _ServeConfigCloneNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[string]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) +// Clone makes a deep copy of ServiceConfig. +// The result aliases no memory with the original. +func (src *ServiceConfig) Clone() *ServiceConfig { + if src == nil { + return nil + } + dst := new(ServiceConfig) + *dst = *src + if dst.TCP != nil { + dst.TCP = map[uint16]*TCPPortHandler{} + for k, v := range src.TCP { + if v == nil { + dst.TCP[k] = nil + } else { + dst.TCP[k] = ptr.To(*v) + } + } + } + if dst.Web != nil { + dst.Web = map[HostPort]*WebServerConfig{} + for k, v := range src.Web { + if v == nil { + dst.Web[k] = nil + } else { + dst.Web[k] = v.Clone() + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigCloneNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + // Clone makes a deep copy of TCPPortHandler. // The result aliases no memory with the original. func (src *TCPPortHandler) Clone() *TCPPortHandler { diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index 83a7aebb1de43..bc67531e4253d 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -18,7 +18,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig // View returns a readonly view of Prefs. func (p *Prefs) View() PrefsView { @@ -195,6 +195,12 @@ func (v ServeConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServer }) } +func (v ServeConfigView) Services() views.MapFn[string, *ServiceConfig, ServiceConfigView] { + return views.MapFnOf(v.ж.Services, func(t *ServiceConfig) ServiceConfigView { + return t.View() + }) +} + func (v ServeConfigView) AllowFunnel() views.Map[HostPort, bool] { return views.MapOf(v.ж.AllowFunnel) } @@ -210,11 +216,77 @@ func (v ServeConfigView) ETag() string { return v.ж.ETag } var _ServeConfigViewNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[string]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) +// View returns a readonly view of ServiceConfig. +func (p *ServiceConfig) View() ServiceConfigView { + return ServiceConfigView{ж: p} +} + +// ServiceConfigView provides a read-only view over ServiceConfig. +// +// Its methods should only be called if `Valid()` returns true. +type ServiceConfigView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *ServiceConfig +} + +// Valid reports whether underlying value is non-nil. +func (v ServiceConfigView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v ServiceConfigView) AsStruct() *ServiceConfig { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v ServiceConfigView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *ServiceConfigView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x ServiceConfig + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v ServiceConfigView) TCP() views.MapFn[uint16, *TCPPortHandler, TCPPortHandlerView] { + return views.MapFnOf(v.ж.TCP, func(t *TCPPortHandler) TCPPortHandlerView { + return t.View() + }) +} + +func (v ServiceConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServerConfigView] { + return views.MapFnOf(v.ж.Web, func(t *WebServerConfig) WebServerConfigView { + return t.View() + }) +} +func (v ServiceConfigView) Tun() bool { return v.ж.Tun } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigViewNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + // View returns a readonly view of TCPPortHandler. func (p *TCPPortHandler) View() TCPPortHandlerView { return TCPPortHandlerView{ж: p} diff --git a/ipn/serve.go b/ipn/serve.go index 5c0a97ed3ffa9..49e0d9fa3d67a 100644 --- a/ipn/serve.go +++ b/ipn/serve.go @@ -24,6 +24,23 @@ func ServeConfigKey(profileID ProfileID) StateKey { return StateKey("_serve/" + profileID) } +// ServiceConfig contains the config information for a single service. +// it contains a bool to indicate if the service is in Tun mode (L3 forwarding). +// If the service is not in Tun mode, the service is configured by the L4 forwarding +// (TCP ports) and/or the L7 forwarding (http handlers) information. +type ServiceConfig struct { + // TCP are the list of TCP port numbers that tailscaled should handle for + // the Tailscale IP addresses. (not subnet routers, etc) + TCP map[uint16]*TCPPortHandler `json:",omitempty"` + + // Web maps from "$SNI_NAME:$PORT" to a set of HTTP handlers + // keyed by mount point ("/", "/foo", etc) + Web map[HostPort]*WebServerConfig `json:",omitempty"` + + // Tun determines if the service should be using L3 forwarding (Tun mode). + Tun bool `json:",omitempty"` +} + // ServeConfig is the JSON type stored in the StateStore for // StateKey "_serve/$PROFILE_ID" as returned by ServeConfigKey. type ServeConfig struct { @@ -35,6 +52,10 @@ type ServeConfig struct { // keyed by mount point ("/", "/foo", etc) Web map[HostPort]*WebServerConfig `json:",omitempty"` + // Services maps from service name to a ServiceConfig. Which describes the + // L3, L4, and L7 forwarding information for the service. + Services map[string]*ServiceConfig `json:",omitempty"` + // AllowFunnel is the set of SNI:port values for which funnel // traffic is allowed, from trusted ingress peers. AllowFunnel map[HostPort]bool `json:",omitempty"` From eabb424275c8c90dab8d3e0130edea2de432695e Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Tue, 3 Dec 2024 07:01:14 +0000 Subject: [PATCH 160/179] cmd/k8s-operator,docs/k8s: run tun mode proxies in privileged containers (#14262) We were previously relying on unintended behaviour by runc where all containers where by default given read/write/mknod permissions for tun devices. This behaviour was removed in https://github.com/opencontainers/runc/pull/3468 and released in runc 1.2. Containerd container runtime, used by Docker and majority of Kubernetes distributions bumped runc to 1.2 in 1.7.24 https://github.com/containerd/containerd/releases/tag/v1.7.24 thus breaking our reference tun mode Tailscale Kubernetes manifests and Kubernetes operator proxies. This PR changes the all Kubernetes container configs that run Tailscale in tun mode to privileged. This should not be a breaking change because all these containers would run in a Pod that already has a privileged init container. Updates tailscale/tailscale#14256 Updates tailscale/tailscale#10814 Signed-off-by: Irbe Krumina --- .../crds/tailscale.com_proxyclasses.yaml | 22 ++++++++++--------- .../deploy/manifests/operator.yaml | 22 ++++++++++--------- cmd/k8s-operator/deploy/manifests/proxy.yaml | 4 +--- cmd/k8s-operator/testutils_test.go | 4 +--- docs/k8s/proxy.yaml | 4 +--- docs/k8s/sidecar.yaml | 4 +--- docs/k8s/subnet.yaml | 4 +--- k8s-operator/api.md | 2 +- .../apis/v1alpha1/types_proxyclass.go | 11 +++++----- 9 files changed, 36 insertions(+), 41 deletions(-) diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index 4c24a1633284e..ad2e8f2432b2e 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -1384,11 +1384,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: @@ -1707,11 +1708,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index f764fc09ac353..9b90919fb4a15 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -1851,11 +1851,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: @@ -2174,11 +2175,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: diff --git a/cmd/k8s-operator/deploy/manifests/proxy.yaml b/cmd/k8s-operator/deploy/manifests/proxy.yaml index 1ad63c2653361..3c9a3eaa36c56 100644 --- a/cmd/k8s-operator/deploy/manifests/proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/proxy.yaml @@ -39,6 +39,4 @@ spec: fieldRef: fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 8f06f5979cbf4..5f016e91dcc7c 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -76,9 +76,7 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, }, SecurityContext: &corev1.SecurityContext{ - Capabilities: &corev1.Capabilities{ - Add: []corev1.Capability{"NET_ADMIN"}, - }, + Privileged: ptr.To(true), }, ImagePullPolicy: "Always", } diff --git a/docs/k8s/proxy.yaml b/docs/k8s/proxy.yaml index 78e97c83b2be9..048fd7a5bddf9 100644 --- a/docs/k8s/proxy.yaml +++ b/docs/k8s/proxy.yaml @@ -53,6 +53,4 @@ spec: fieldRef: fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/sidecar.yaml b/docs/k8s/sidecar.yaml index 6baa6d5458d49..520e4379ad9ee 100644 --- a/docs/k8s/sidecar.yaml +++ b/docs/k8s/sidecar.yaml @@ -35,6 +35,4 @@ spec: fieldRef: fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/subnet.yaml b/docs/k8s/subnet.yaml index 1af146be689e6..ef4e4748c0ceb 100644 --- a/docs/k8s/subnet.yaml +++ b/docs/k8s/subnet.yaml @@ -37,6 +37,4 @@ spec: fieldRef: fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 640d8fb07bc54..730bed210118f 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -145,7 +145,7 @@ _Appears in:_ | `image` _string_ | Container image name. By default images are pulled from
docker.io/tailscale/tailscale, but the official images are also
available at ghcr.io/tailscale/tailscale. Specifying image name here
will override any proxy image values specified via the Kubernetes
operator's Helm chart values or PROXY_IMAGE env var in the operator
Deployment.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | | | `imagePullPolicy` _[PullPolicy](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#pullpolicy-v1-core)_ | Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | Enum: [Always Never IfNotPresent]
| | `resources` _[ResourceRequirements](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#resourcerequirements-v1-core)_ | Container resource requirements.
By default Tailscale Kubernetes operator does not apply any resource
requirements. The amount of resources required wil depend on the
amount of resources the operator needs to parse, usage patterns and
cluster size.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#resources | | | -| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context by the operator.
By default the operator:
- sets 'privileged: true' for the init container
- set NET_ADMIN capability for tailscale container for proxies that
are created for Services or Connector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context set by the operator.
By default the operator sets the Tailscale container and the Tailscale init container to privileged
for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup.
You can reduce the permissions of the Tailscale container to cap NET_ADMIN by
installing device plugin in your cluster and configuring the proxies tun device to be created
by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | | `debug` _[Debug](#debug)_ | Configuration for enabling extra debug information in the container.
Not recommended for production use. | | | diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 7e408cd0a7338..71fbf24390d55 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -206,11 +206,12 @@ type Container struct { // +optional Resources corev1.ResourceRequirements `json:"resources,omitempty"` // Container security context. - // Security context specified here will override the security context by the operator. - // By default the operator: - // - sets 'privileged: true' for the init container - // - set NET_ADMIN capability for tailscale container for proxies that - // are created for Services or Connector. + // Security context specified here will override the security context set by the operator. + // By default the operator sets the Tailscale container and the Tailscale init container to privileged + // for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + // You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + // installing device plugin in your cluster and configuring the proxies tun device to be created + // by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context // +optional SecurityContext *corev1.SecurityContext `json:"securityContext,omitempty"` From 9f9063e624c66d295d286d2f7bc85c02dfd46d4f Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Tue, 3 Dec 2024 12:35:25 +0000 Subject: [PATCH 161/179] cmd/k8s-operator,k8s-operator,go.mod: optionally create ServiceMonitor (#14248) * cmd/k8s-operator,k8s-operator,go.mod: optionally create ServiceMonitor Adds a new spec.metrics.serviceMonitor field to ProxyClass. If that's set to true (and metrics are enabled), the operator will create a Prometheus ServiceMonitor for each proxy to which the ProxyClass applies. Additionally, create a metrics Service for each proxy that has metrics enabled. Updates tailscale/tailscale#11292 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/connector.go | 3 +- cmd/k8s-operator/depaware.txt | 2 +- .../deploy/chart/templates/operator-rbac.yaml | 7 + .../crds/tailscale.com_proxyclasses.yaml | 21 ++ .../deploy/manifests/operator.yaml | 41 +++ cmd/k8s-operator/ingress.go | 3 +- cmd/k8s-operator/ingress_test.go | 122 ++++++++ cmd/k8s-operator/metrics_resources.go | 272 ++++++++++++++++++ cmd/k8s-operator/operator.go | 77 ++++- cmd/k8s-operator/proxyclass.go | 25 +- cmd/k8s-operator/proxyclass_test.go | 23 +- cmd/k8s-operator/proxygroup.go | 17 ++ cmd/k8s-operator/proxygroup_test.go | 32 ++- cmd/k8s-operator/sts.go | 28 +- cmd/k8s-operator/svc.go | 11 +- cmd/k8s-operator/testutils_test.go | 148 ++++++++++ go.mod | 2 +- k8s-operator/api.md | 19 +- k8s-operator/apis/v1alpha1/register.go | 7 + .../apis/v1alpha1/types_proxyclass.go | 17 ++ .../apis/v1alpha1/zz_generated.deepcopy.go | 22 +- 21 files changed, 877 insertions(+), 22 deletions(-) create mode 100644 cmd/k8s-operator/metrics_resources.go diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 1c1df7c962b91..1ed6fd1556d5f 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -189,6 +189,7 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge isExitNode: cn.Spec.ExitNode, }, ProxyClassName: proxyClass, + proxyType: proxyTypeConnector, } if cn.Spec.SubnetRouter != nil && len(cn.Spec.SubnetRouter.AdvertiseRoutes) > 0 { @@ -253,7 +254,7 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge } func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger *zap.SugaredLogger, cn *tsapi.Connector) (bool, error) { - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector"), proxyTypeConnector); err != nil { return false, fmt.Errorf("failed to cleanup Connector resources: %w", err) } else if !done { logger.Debugf("Connector cleanup not done yet, waiting for next reconcile") diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 900d10efedc99..d1d687432863b 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -378,7 +378,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/storage/v1beta1 from k8s.io/client-go/applyconfigurations/storage/v1beta1+ k8s.io/api/storagemigration/v1alpha1 from k8s.io/client-go/applyconfigurations/storagemigration/v1alpha1+ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 - 💣 k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion + 💣 k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion+ k8s.io/apimachinery/pkg/api/equality from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ k8s.io/apimachinery/pkg/api/errors from k8s.io/apimachinery/pkg/util/managedfields/internal+ k8s.io/apimachinery/pkg/api/meta from k8s.io/apimachinery/pkg/api/validation+ diff --git a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml index ede61070b4399..a56edfe0d1b80 100644 --- a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml @@ -30,6 +30,10 @@ rules: - apiGroups: ["tailscale.com"] resources: ["recorders", "recorders/status"] verbs: ["get", "list", "watch", "update"] +- apiGroups: ["apiextensions.k8s.io"] + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + resourceNames: ["servicemonitors.monitoring.coreos.com"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -65,6 +69,9 @@ rules: - apiGroups: ["rbac.authorization.k8s.io"] resources: ["roles", "rolebindings"] verbs: ["get", "create", "patch", "update", "list", "watch"] +- apiGroups: ["monitoring.coreos.com"] + resources: ["servicemonitors"] + verbs: ["get", "list", "update", "create", "delete"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index ad2e8f2432b2e..9b45deedb62b7 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -74,6 +74,8 @@ spec: description: |- Setting enable to true will make the proxy serve Tailscale metrics at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. In 1.78.x and 1.80.x, this field also serves as the default value for .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both @@ -81,6 +83,25 @@ spec: Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + type: object + required: + - enable + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + x-kubernetes-validations: + - rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' + message: ServiceMonitor can only be enabled if metrics are enabled statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 9b90919fb4a15..210a7b43463e5 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -541,6 +541,8 @@ spec: description: |- Setting enable to true will make the proxy serve Tailscale metrics at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. In 1.78.x and 1.80.x, this field also serves as the default value for .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both @@ -548,9 +550,28 @@ spec: Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + required: + - enable + type: object required: - enable type: object + x-kubernetes-validations: + - message: ServiceMonitor can only be enabled if metrics are enabled + rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale @@ -4648,6 +4669,16 @@ rules: - list - watch - update + - apiGroups: + - apiextensions.k8s.io + resourceNames: + - servicemonitors.monitoring.coreos.com + resources: + - customresourcedefinitions + verbs: + - get + - list + - watch --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -4728,6 +4759,16 @@ rules: - update - list - watch + - apiGroups: + - monitoring.coreos.com + resources: + - servicemonitors + verbs: + - get + - list + - update + - create + - delete --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role diff --git a/cmd/k8s-operator/ingress.go b/cmd/k8s-operator/ingress.go index acc90d465093a..40a5d09283193 100644 --- a/cmd/k8s-operator/ingress.go +++ b/cmd/k8s-operator/ingress.go @@ -90,7 +90,7 @@ func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress"), proxyTypeIngressResource); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -268,6 +268,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga Tags: tags, ChildResourceLabels: crl, ProxyClassName: proxyClass, + proxyType: proxyTypeIngressResource, } if val := ing.GetAnnotations()[AnnotationExperimentalForwardClusterTrafficViaL7IngresProxy]; val == "true" { diff --git a/cmd/k8s-operator/ingress_test.go b/cmd/k8s-operator/ingress_test.go index 38a041dde07f9..e695cc649408c 100644 --- a/cmd/k8s-operator/ingress_test.go +++ b/cmd/k8s-operator/ingress_test.go @@ -12,6 +12,7 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -271,3 +272,124 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { opts.proxyClass = "" expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) } + +func TestTailscaleIngressWithServiceMonitor(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "metrics", Generation: 1}, + Spec: tsapi.ProxyClassSpec{ + Metrics: &tsapi.Metrics{ + Enable: true, + ServiceMonitor: &tsapi.ServiceMonitor{Enable: true}, + }, + }, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: 1, + }}}, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc, tsIngressClass). + WithStatusSubresource(pc). + Build() + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + // 1. Enable metrics- expect metrics Service to be created + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + Labels: map[string]string{ + "tailscale.com/proxy-class": "metrics", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"default-test"}}, + }, + }, + } + mustCreate(t, fc, ing) + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + Ports: []corev1.ServicePort{{ + Port: 8080, + Name: "http"}, + }, + }, + }) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + tailscaleNamespace: "operator-ns", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + enableMetrics: true, + namespaced: true, + proxyType: proxyTypeIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) + expectEqual(t, fc, expectedMetricsService(opts), nil) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + // 2. Enable ServiceMonitor - should not error when there is no ServiceMonitor CRD in cluster + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + }) + expectReconciled(t, ingR, "default", "test") + // 3. Create ServiceMonitor CRD and reconcile- ServiceMonitor should get created + mustCreate(t, fc, crd) + expectReconciled(t, ingR, "default", "test") + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) +} diff --git a/cmd/k8s-operator/metrics_resources.go b/cmd/k8s-operator/metrics_resources.go new file mode 100644 index 0000000000000..4881436e8e184 --- /dev/null +++ b/cmd/k8s-operator/metrics_resources.go @@ -0,0 +1,272 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" +) + +const ( + labelMetricsTarget = "tailscale.com/metrics-target" + + // These labels get transferred from the metrics Service to the ingested Prometheus metrics. + labelPromProxyType = "ts_proxy_type" + labelPromProxyParentName = "ts_proxy_parent_name" + labelPromProxyParentNamespace = "ts_proxy_parent_namespace" + labelPromJob = "ts_prom_job" + + serviceMonitorCRD = "servicemonitors.monitoring.coreos.com" +) + +// ServiceMonitor contains a subset of fields of servicemonitors.monitoring.coreos.com Custom Resource Definition. +// Duplicating it here allows us to avoid importing prometheus-operator library. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L40 +type ServiceMonitor struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata"` + Spec ServiceMonitorSpec `json:"spec"` +} + +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L55 +type ServiceMonitorSpec struct { + // Endpoints defines the endpoints to be scraped on the selected Service(s). + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L82 + Endpoints []ServiceMonitorEndpoint `json:"endpoints"` + // JobLabel is the label on the Service whose value will become the value of the Prometheus job label for the metrics ingested via this ServiceMonitor. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L66 + JobLabel string `json:"jobLabel"` + // NamespaceSelector selects the namespace of Service(s) that this ServiceMonitor allows to scrape. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 + NamespaceSelector ServiceMonitorNamespaceSelector `json:"namespaceSelector,omitempty"` + // Selector is the label selector for Service(s) that this ServiceMonitor allows to scrape. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L85 + Selector metav1.LabelSelector `json:"selector"` + // TargetLabels are labels on the selected Service that should be applied as Prometheus labels to the ingested metrics. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L72 + TargetLabels []string `json:"targetLabels"` +} + +// ServiceMonitorNamespaceSelector selects namespaces in which Prometheus operator will attempt to find Services for +// this ServiceMonitor. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 +type ServiceMonitorNamespaceSelector struct { + MatchNames []string `json:"matchNames,omitempty"` +} + +// ServiceMonitorEndpoint defines an endpoint of Service to scrape. We only define port here. Prometheus by default +// scrapes /metrics path, which is what we want. +type ServiceMonitorEndpoint struct { + // Port is the name of the Service port that Prometheus will scrape. + Port string `json:"port,omitempty"` +} + +func reconcileMetricsResources(ctx context.Context, logger *zap.SugaredLogger, opts *metricsOpts, pc *tsapi.ProxyClass, cl client.Client) error { + if opts.proxyType == proxyTypeEgress { + // Metrics are currently not being enabled for standalone egress proxies. + return nil + } + if pc == nil || pc.Spec.Metrics == nil || !pc.Spec.Metrics.Enable { + return maybeCleanupMetricsResources(ctx, opts, cl) + } + metricsSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.proxyStsName), + Namespace: opts.tsNamespace, + Labels: metricsResourceLabels(opts), + }, + Spec: corev1.ServiceSpec{ + Selector: opts.proxyLabels, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } + var err error + metricsSvc, err = createOrUpdate(ctx, cl, opts.tsNamespace, metricsSvc, func(svc *corev1.Service) { + svc.Spec.Ports = metricsSvc.Spec.Ports + svc.Spec.Selector = metricsSvc.Spec.Selector + }) + if err != nil { + return fmt.Errorf("error ensuring metrics Service: %w", err) + } + + crdExists, err := hasServiceMonitorCRD(ctx, cl) + if err != nil { + return fmt.Errorf("error verifying that %q CRD exists: %w", serviceMonitorCRD, err) + } + if !crdExists { + return nil + } + + if pc.Spec.Metrics.ServiceMonitor == nil || !pc.Spec.Metrics.ServiceMonitor.Enable { + return maybeCleanupServiceMonitor(ctx, cl, opts.proxyStsName, opts.tsNamespace) + } + + logger.Info("ensuring ServiceMonitor for metrics Service %s/%s", metricsSvc.Namespace, metricsSvc.Name) + svcMonitor, err := newServiceMonitor(metricsSvc) + if err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + // We don't use createOrUpdate here because that does not work with unstructured types. We also do not update + // the ServiceMonitor because it is not expected that any of its fields would change. Currently this is good + // enough, but in future we might want to add logic to create-or-update unstructured types. + err = cl.Get(ctx, client.ObjectKeyFromObject(metricsSvc), svcMonitor.DeepCopy()) + if apierrors.IsNotFound(err) { + if err := cl.Create(ctx, svcMonitor); err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + return nil + } + if err != nil { + return fmt.Errorf("error getting ServiceMonitor: %w", err) + } + return nil +} + +// maybeCleanupMetricsResources ensures that any metrics resources created for a proxy are deleted. Only metrics Service +// gets deleted explicitly because the ServiceMonitor has Service's owner reference, so gets garbage collected +// automatically. +func maybeCleanupMetricsResources(ctx context.Context, opts *metricsOpts, cl client.Client) error { + sel := metricsSvcSelector(opts.proxyLabels, opts.proxyType) + return cl.DeleteAllOf(ctx, &corev1.Service{}, client.InNamespace(opts.tsNamespace), client.MatchingLabels(sel)) +} + +// maybeCleanupServiceMonitor cleans up any ServiceMonitor created for the named proxy StatefulSet. +func maybeCleanupServiceMonitor(ctx context.Context, cl client.Client, stsName, ns string) error { + smName := metricsResourceName(stsName) + sm := serviceMonitorTemplate(smName, ns) + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + return fmt.Errorf("error building ServiceMonitor: %w", err) + } + err = cl.Get(ctx, types.NamespacedName{Name: smName, Namespace: ns}, u) + if apierrors.IsNotFound(err) { + return nil // nothing to do + } + if err != nil { + return fmt.Errorf("error verifying if ServiceMonitor %s/%s exists: %w", ns, stsName, err) + } + return cl.Delete(ctx, u) +} + +// newServiceMonitor takes a metrics Service created for a proxy and constructs and returns a ServiceMonitor for that +// proxy that can be applied to the kube API server. +// The ServiceMonitor is returned as Unstructured type - this allows us to avoid importing prometheus-operator API server client/schema. +func newServiceMonitor(metricsSvc *corev1.Service) (*unstructured.Unstructured, error) { + sm := serviceMonitorTemplate(metricsSvc.Name, metricsSvc.Namespace) + sm.ObjectMeta.Labels = metricsSvc.Labels + sm.ObjectMeta.OwnerReferences = []metav1.OwnerReference{*metav1.NewControllerRef(metricsSvc, corev1.SchemeGroupVersion.WithKind("Service"))} + sm.Spec = ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: metricsSvc.Labels}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{metricsSvc.Namespace}, + }, + JobLabel: labelPromJob, + TargetLabels: []string{ + labelPromProxyParentName, + labelPromProxyParentNamespace, + labelPromProxyType, + }, + } + return serviceMonitorToUnstructured(sm) +} + +// serviceMonitorToUnstructured takes a ServiceMonitor and converts it to Unstructured type that can be used by the c/r +// client in Kubernetes API server calls. +func serviceMonitorToUnstructured(sm *ServiceMonitor) (*unstructured.Unstructured, error) { + contents, err := runtime.DefaultUnstructuredConverter.ToUnstructured(sm) + if err != nil { + return nil, fmt.Errorf("error converting ServiceMonitor to Unstructured: %w", err) + } + u := &unstructured.Unstructured{} + u.SetUnstructuredContent(contents) + u.SetGroupVersionKind(sm.GroupVersionKind()) + return u, nil +} + +// metricsResourceName returns name for metrics Service and ServiceMonitor for a proxy StatefulSet. +func metricsResourceName(stsName string) string { + // Maximum length of StatefulSet name if 52 chars, so this is fine. + return fmt.Sprintf("%s-metrics", stsName) +} + +// metricsResourceLabels constructs labels that will be applied to metrics Service and metrics ServiceMonitor for a +// proxy. +func metricsResourceLabels(opts *metricsOpts) map[string]string { + lbls := map[string]string{ + LabelManaged: "true", + labelMetricsTarget: opts.proxyStsName, + labelPromProxyType: opts.proxyType, + labelPromProxyParentName: opts.proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(opts.proxyType) { + lbls[labelPromProxyParentNamespace] = opts.proxyLabels[LabelParentNamespace] + } + lbls[labelPromJob] = promJobName(opts) + return lbls +} + +// promJobName constructs the value of the Prometheus job label that will apply to all metrics for a ServiceMonitor. +func promJobName(opts *metricsOpts) string { + // Include parent resource namespace for proxies created for namespaced types. + if opts.proxyType == proxyTypeIngressResource || opts.proxyType == proxyTypeIngressService { + return fmt.Sprintf("ts_%s_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentNamespace], opts.proxyLabels[LabelParentName]) + } + return fmt.Sprintf("ts_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentName]) +} + +// metricsSvcSelector returns the minimum label set to uniquely identify a metrics Service for a proxy. +func metricsSvcSelector(proxyLabels map[string]string, proxyType string) map[string]string { + sel := map[string]string{ + labelPromProxyType: proxyType, + labelPromProxyParentName: proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(proxyType) { + sel[labelPromProxyParentNamespace] = proxyLabels[LabelParentNamespace] + } + return sel +} + +// serviceMonitorTemplate returns a base ServiceMonitor type that, when converted to Unstructured, is a valid type that +// can be used in kube API server calls via the c/r client. +func serviceMonitorTemplate(name, ns string) *ServiceMonitor { + return &ServiceMonitor{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns, + }, + } +} + +type metricsOpts struct { + proxyStsName string // name of StatefulSet for proxy + tsNamespace string // namespace in which Tailscale is installed + proxyLabels map[string]string // labels of the proxy StatefulSet + proxyType string +} + +func isNamespacedProxyType(typ string) bool { + return typ == proxyTypeIngressResource || typ == proxyTypeIngressService +} diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index 116ba02e0ce1c..ebb2c4578ab93 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -24,8 +24,11 @@ import ( discoveryv1 "k8s.io/api/discovery/v1" networkingv1 "k8s.io/api/networking/v1" rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/rest" + toolscache "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" @@ -239,21 +242,29 @@ func runReconcilers(opts reconcilerOpts) { nsFilter := cache.ByObject{ Field: client.InNamespace(opts.tailscaleNamespace).AsSelector(), } + // We watch the ServiceMonitor CRD to ensure that reconcilers are re-triggered if user's workflows result in the + // ServiceMonitor CRD applied after some of our resources that define ServiceMonitor creation. This selector + // ensures that we only watch the ServiceMonitor CRD and that we don't cache full contents of it. + serviceMonitorSelector := cache.ByObject{ + Field: fields.SelectorFromSet(fields.Set{"metadata.name": serviceMonitorCRD}), + Transform: crdTransformer(startlog), + } mgrOpts := manager.Options{ // TODO (irbekrm): stricter filtering what we watch/cache/call // reconcilers on. c/r by default starts a watch on any // resources that we GET via the controller manager's client. Cache: cache.Options{ ByObject: map[client.Object]cache.ByObject{ - &corev1.Secret{}: nsFilter, - &corev1.ServiceAccount{}: nsFilter, - &corev1.Pod{}: nsFilter, - &corev1.ConfigMap{}: nsFilter, - &appsv1.StatefulSet{}: nsFilter, - &appsv1.Deployment{}: nsFilter, - &discoveryv1.EndpointSlice{}: nsFilter, - &rbacv1.Role{}: nsFilter, - &rbacv1.RoleBinding{}: nsFilter, + &corev1.Secret{}: nsFilter, + &corev1.ServiceAccount{}: nsFilter, + &corev1.Pod{}: nsFilter, + &corev1.ConfigMap{}: nsFilter, + &appsv1.StatefulSet{}: nsFilter, + &appsv1.Deployment{}: nsFilter, + &discoveryv1.EndpointSlice{}: nsFilter, + &rbacv1.Role{}: nsFilter, + &rbacv1.RoleBinding{}: nsFilter, + &apiextensionsv1.CustomResourceDefinition{}: serviceMonitorSelector, }, }, Scheme: tsapi.GlobalScheme, @@ -422,8 +433,13 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("could not create egress EndpointSlices reconciler: %v", err) } + // ProxyClass reconciler gets triggered on ServiceMonitor CRD changes to ensure that any ProxyClasses, that + // define that a ServiceMonitor should be created, were set to invalid because the CRD did not exist get + // reconciled if the CRD is applied at a later point. + serviceMonitorFilter := handler.EnqueueRequestsFromMapFunc(proxyClassesWithServiceMonitor(mgr.GetClient(), opts.log)) err = builder.ControllerManagedBy(mgr). For(&tsapi.ProxyClass{}). + Watches(&apiextensionsv1.CustomResourceDefinition{}, serviceMonitorFilter). Complete(&ProxyClassReconciler{ Client: mgr.GetClient(), recorder: eventRecorder, @@ -1018,6 +1034,49 @@ func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger, ns } } +// proxyClassesWithServiceMonitor returns an event handler that, given that the event is for the Prometheus +// ServiceMonitor CRD, returns all ProxyClasses that define that a ServiceMonitor should be created. +func proxyClassesWithServiceMonitor(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an object that is not a CustomResourceDefinition") + return nil + } + if crd.Name != serviceMonitorCRD { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an unexpected CRD %q", crd.Name) + return nil + } + pcl := &tsapi.ProxyClassList{} + if err := cl.List(ctx, pcl); err != nil { + logger.Debugf("[unexpected] error listing ProxyClasses: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, pc := range pcl.Items { + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{Namespace: pc.Namespace, Name: pc.Name}, + }) + } + } + return reqs + } +} + +// crdTransformer gets called before a CRD is stored to c/r cache, it removes the CRD spec to reduce memory consumption. +func crdTransformer(log *zap.SugaredLogger) toolscache.TransformFunc { + return func(o any) (any, error) { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + log.Infof("[unexpected] CRD transformer called for a non-CRD type") + return crd, nil + } + crd.Spec = apiextensionsv1.CustomResourceDefinitionSpec{} + return crd, nil + } +} + // indexEgressServices adds a local index to a cached Tailscale egress Services meant to be exposed on a ProxyGroup. The // index is used a list filter. func indexEgressServices(o client.Object) []string { diff --git a/cmd/k8s-operator/proxyclass.go b/cmd/k8s-operator/proxyclass.go index 13f217f3c1685..ad3cfc9fd02d2 100644 --- a/cmd/k8s-operator/proxyclass.go +++ b/cmd/k8s-operator/proxyclass.go @@ -15,6 +15,7 @@ import ( dockerref "github.com/distribution/reference" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" apivalidation "k8s.io/apimachinery/pkg/api/validation" @@ -95,7 +96,7 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re pcr.mu.Unlock() oldPCStatus := pc.Status.DeepCopy() - if errs := pcr.validate(pc); errs != nil { + if errs := pcr.validate(ctx, pc); errs != nil { msg := fmt.Sprintf(messageProxyClassInvalid, errs.ToAggregate().Error()) pcr.recorder.Event(pc, corev1.EventTypeWarning, reasonProxyClassInvalid, msg) tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, pc.Generation, pcr.clock, logger) @@ -111,7 +112,7 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re return reconcile.Result{}, nil } -func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations field.ErrorList) { +func (pcr *ProxyClassReconciler) validate(ctx context.Context, pc *tsapi.ProxyClass) (violations field.ErrorList) { if sts := pc.Spec.StatefulSet; sts != nil { if len(sts.Labels) > 0 { if errs := metavalidation.ValidateLabels(sts.Labels, field.NewPath(".spec.statefulSet.labels")); errs != nil { @@ -167,6 +168,16 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel } } } + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + found, err := hasServiceMonitorCRD(ctx, pcr.Client) + if err != nil { + pcr.logger.Infof("[unexpected]: error retrieving %q CRD: %v", serviceMonitorCRD, err) + // best effort validation - don't error out here + } else if !found { + msg := fmt.Sprintf("ProxyClass defines that a ServiceMonitor custom resource should be created, but %q CRD was not found", serviceMonitorCRD) + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "metrics", "serviceMonitor"), "enable", msg)) + } + } // We do not validate embedded fields (security context, resource // requirements etc) as we inherit upstream validation for those fields. // Invalid values would get rejected by upstream validations at apply @@ -174,6 +185,16 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel return violations } +func hasServiceMonitorCRD(ctx context.Context, cl client.Client) (bool, error) { + sm := &apiextensionsv1.CustomResourceDefinition{} + if err := cl.Get(ctx, types.NamespacedName{Name: serviceMonitorCRD}, sm); apierrors.IsNotFound(err) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + // maybeCleanup removes tailscale.com finalizer and ensures that the ProxyClass // is no longer counted towards k8s_proxyclass_resources. func (pcr *ProxyClassReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, pc *tsapi.ProxyClass) error { diff --git a/cmd/k8s-operator/proxyclass_test.go b/cmd/k8s-operator/proxyclass_test.go index fb17f5fe5e3ee..e6e16e9f9d59f 100644 --- a/cmd/k8s-operator/proxyclass_test.go +++ b/cmd/k8s-operator/proxyclass_test.go @@ -8,10 +8,12 @@ package main import ( + "context" "testing" "time" "go.uber.org/zap" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -134,6 +136,25 @@ func TestProxyClass(t *testing.T) { "Warning CustomTSEnvVar ProxyClass overrides the default value for EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future."} expectReconciled(t, pcr, "", "test") expectEvents(t, fr, expectedEvents) + + // 6. A ProxyClass with ServiceMonitor enabled and in a cluster that has not ServiceMonitor CRD is invalid + pc.Spec.Metrics = &tsapi.Metrics{Enable: true, ServiceMonitor: &tsapi.ServiceMonitor{Enable: true}} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec = pc.Spec + }) + expectReconciled(t, pcr, "", "test") + msg = `ProxyClass is not valid: spec.metrics.serviceMonitor: Invalid value: "enable": ProxyClass defines that a ServiceMonitor custom resource should be created, but "servicemonitors.monitoring.coreos.com" CRD was not found` + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc, nil) + expectedEvent = "Warning ProxyClassInvalid " + msg + expectEvents(t, fr, []string{expectedEvent}) + + // 7. A ProxyClass with ServiceMonitor enabled and in a cluster that does have the ServiceMonitor CRD is valid + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + mustCreate(t, fc, crd) + expectReconciled(t, pcr, "", "test") + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc, nil) } func TestValidateProxyClass(t *testing.T) { @@ -180,7 +201,7 @@ func TestValidateProxyClass(t *testing.T) { } { t.Run(name, func(t *testing.T) { pcr := &ProxyClassReconciler{} - err := pcr.validate(tc.pc) + err := pcr.validate(context.Background(), tc.pc) valid := err == nil if valid != tc.valid { t.Errorf("expected valid=%v, got valid=%v, err=%v", tc.valid, valid, err) diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 6b76724662b6d..1aefbd2f6caef 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -259,6 +259,15 @@ func (r *ProxyGroupReconciler) maybeProvision(ctx context.Context, pg *tsapi.Pro }); err != nil { return fmt.Errorf("error provisioning StatefulSet: %w", err) } + mo := &metricsOpts{ + tsNamespace: r.tsNamespace, + proxyStsName: pg.Name, + proxyLabels: pgLabels(pg.Name, nil), + proxyType: "proxygroup", + } + if err := reconcileMetricsResources(ctx, logger, mo, proxyClass, r.Client); err != nil { + return fmt.Errorf("error reconciling metrics resources: %w", err) + } if err := r.cleanupDanglingResources(ctx, pg); err != nil { return fmt.Errorf("error cleaning up dangling resources: %w", err) @@ -327,6 +336,14 @@ func (r *ProxyGroupReconciler) maybeCleanup(ctx context.Context, pg *tsapi.Proxy } } + mo := &metricsOpts{ + proxyLabels: pgLabels(pg.Name, nil), + tsNamespace: r.tsNamespace, + proxyType: "proxygroup"} + if err := maybeCleanupMetricsResources(ctx, mo, r.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } + logger.Infof("cleaned up ProxyGroup resources") r.mu.Lock() r.proxyGroups.Remove(pg.UID) diff --git a/cmd/k8s-operator/proxygroup_test.go b/cmd/k8s-operator/proxygroup_test.go index 23f50cc7a576d..9c4df9e4f9302 100644 --- a/cmd/k8s-operator/proxygroup_test.go +++ b/cmd/k8s-operator/proxygroup_test.go @@ -17,6 +17,7 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" @@ -76,6 +77,13 @@ func TestProxyGroup(t *testing.T) { l: zl.Sugar(), clock: cl, } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + opts := configOpts{ + proxyType: "proxygroup", + stsName: pg.Name, + parentType: "proxygroup", + tailscaleNamespace: "tailscale", + } t.Run("proxyclass_not_ready", func(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) @@ -190,6 +198,27 @@ func TestProxyGroup(t *testing.T) { expectProxyGroupResources(t, fc, pg, true, "518a86e9fae64f270f8e0ec2a2ea6ca06c10f725035d3d6caca132cd61e42a74") }) + t.Run("enable_metrics", func(t *testing.T) { + pc.Spec.Metrics = &tsapi.Metrics{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec = pc.Spec + }) + expectReconciled(t, reconciler, "", pg.Name) + expectEqual(t, fc, expectedMetricsService(opts), nil) + }) + t.Run("enable_service_monitor_no_crd", func(t *testing.T) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec.Metrics = pc.Spec.Metrics + }) + expectReconciled(t, reconciler, "", pg.Name) + }) + t.Run("create_crd_expect_service_monitor", func(t *testing.T) { + mustCreate(t, fc, crd) + expectReconciled(t, reconciler, "", pg.Name) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + }) + t.Run("delete_and_cleanup", func(t *testing.T) { if err := fc.Delete(context.Background(), pg); err != nil { t.Fatal(err) @@ -197,7 +226,7 @@ func TestProxyGroup(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) - expectMissing[tsapi.Recorder](t, fc, "", pg.Name) + expectMissing[tsapi.ProxyGroup](t, fc, "", pg.Name) if expected := 0; reconciler.proxyGroups.Len() != expected { t.Fatalf("expected %d ProxyGroups, got %d", expected, reconciler.proxyGroups.Len()) } @@ -206,6 +235,7 @@ func TestProxyGroup(t *testing.T) { if diff := cmp.Diff(tsClient.deleted, []string{"nodeid-1", "nodeid-2", "nodeid-0"}); diff != "" { t.Fatalf("unexpected deleted devices (-got +want):\n%s", diff) } + expectMissing[corev1.Service](t, reconciler, "tailscale", metricsResourceName(pg.Name)) // The fake client does not clean up objects whose owner has been // deleted, so we can't test for the owned resources getting deleted. }) diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index 73c54a93d0373..5de30154cd59a 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -94,6 +94,12 @@ const ( podAnnotationLastSetTailnetTargetFQDN = "tailscale.com/operator-last-set-ts-tailnet-target-fqdn" // podAnnotationLastSetConfigFileHash is sha256 hash of the current tailscaled configuration contents. podAnnotationLastSetConfigFileHash = "tailscale.com/operator-last-set-config-file-hash" + + proxyTypeEgress = "egress_service" + proxyTypeIngressService = "ingress_service" + proxyTypeIngressResource = "ingress_resource" + proxyTypeConnector = "connector" + proxyTypeProxyGroup = "proxygroup" ) var ( @@ -122,6 +128,8 @@ type tailscaleSTSConfig struct { Hostname string Tags []string // if empty, use defaultTags + proxyType string + // Connector specifies a configuration of a Connector instance if that's // what this StatefulSet should be created for. Connector *connector @@ -197,14 +205,22 @@ func (a *tailscaleSTSReconciler) Provision(ctx context.Context, logger *zap.Suga if err != nil { return nil, fmt.Errorf("failed to reconcile statefulset: %w", err) } - + mo := &metricsOpts{ + proxyStsName: hsvc.Name, + tsNamespace: hsvc.Namespace, + proxyLabels: hsvc.Labels, + proxyType: sts.proxyType, + } + if err = reconcileMetricsResources(ctx, logger, mo, sts.ProxyClass, a.Client); err != nil { + return nil, fmt.Errorf("failed to ensure metrics resources: %w", err) + } return hsvc, nil } // Cleanup removes all resources associated that were created by Provision with // the given labels. It returns true when all resources have been removed, // otherwise it returns false and the caller should retry later. -func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string) (done bool, _ error) { +func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string, typ string) (done bool, _ error) { // Need to delete the StatefulSet first, and delete it with foreground // cascading deletion. That way, the pod that's writing to the Secret will // stop running before we start looking at the Secret's contents, and @@ -257,6 +273,14 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare return false, err } } + mo := &metricsOpts{ + proxyLabels: labels, + tsNamespace: a.operatorNamespace, + proxyType: typ, + } + if err := maybeCleanupMetricsResources(ctx, mo, a.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } return true, nil } diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index 3c6bc27a95cf0..6afc56f976121 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -152,7 +152,12 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc")); err != nil { + proxyTyp := proxyTypeEgress + if a.shouldExpose(svc) { + proxyTyp = proxyTypeIngressService + } + + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc"), proxyTyp); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -256,6 +261,10 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga ChildResourceLabels: crl, ProxyClassName: proxyClass, } + sts.proxyType = proxyTypeEgress + if a.shouldExpose(svc) { + sts.proxyType = proxyTypeIngressService + } a.mu.Lock() if a.shouldExposeClusterIP(svc) { diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 5f016e91dcc7c..f6ae29b62fefc 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -8,6 +8,7 @@ package main import ( "context" "encoding/json" + "fmt" "net/netip" "reflect" "strings" @@ -21,6 +22,7 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" @@ -39,7 +41,10 @@ type configOpts struct { secretName string hostname string namespace string + tailscaleNamespace string + namespaced bool parentType string + proxyType string priorityClassName string firewallMode string tailnetTargetIP string @@ -56,6 +61,7 @@ type configOpts struct { app string shouldRemoveAuthKey bool secretExtraData map[string][]byte + enableMetrics bool } func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.StatefulSet { @@ -150,6 +156,29 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Name: "TS_INTERNAL_APP", Value: opts.app, }) + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, + corev1.ContainerPort{Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } ss := &appsv1.StatefulSet{ TypeMeta: metav1.TypeMeta{ Kind: "StatefulSet", @@ -241,6 +270,29 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps {Name: "serve-config", ReadOnly: true, MountPath: "/etc/tailscaled"}, }, } + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, corev1.ContainerPort{ + Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } volumes := []corev1.Volume{ { Name: "tailscaledconfig", @@ -335,6 +387,87 @@ func expectedHeadlessService(name string, parentType string) *corev1.Service { } } +func expectedMetricsService(opts configOpts) *corev1.Service { + labels := metricsLabels(opts) + selector := map[string]string{ + "tailscale.com/managed": "true", + "tailscale.com/parent-resource": "test", + "tailscale.com/parent-resource-type": opts.parentType, + } + if opts.namespaced { + selector["tailscale.com/parent-resource-ns"] = opts.namespace + } + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.stsName), + Namespace: opts.tailscaleNamespace, + Labels: labels, + }, + Spec: corev1.ServiceSpec{ + Selector: selector, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } +} + +func metricsLabels(opts configOpts) map[string]string { + promJob := fmt.Sprintf("ts_%s_default_test", opts.proxyType) + if !opts.namespaced { + promJob = fmt.Sprintf("ts_%s_test", opts.proxyType) + } + labels := map[string]string{ + "tailscale.com/managed": "true", + "tailscale.com/metrics-target": opts.stsName, + "ts_prom_job": promJob, + "ts_proxy_type": opts.proxyType, + "ts_proxy_parent_name": "test", + } + if opts.namespaced { + labels["ts_proxy_parent_namespace"] = "default" + } + return labels +} + +func expectedServiceMonitor(t *testing.T, opts configOpts) *unstructured.Unstructured { + t.Helper() + labels := metricsLabels(opts) + name := metricsResourceName(opts.stsName) + sm := &ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: opts.tailscaleNamespace, + Labels: labels, + ResourceVersion: "1", + OwnerReferences: []metav1.OwnerReference{{APIVersion: "v1", Kind: "Service", Name: name, BlockOwnerDeletion: ptr.To(true), Controller: ptr.To(true)}}, + }, + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + Spec: ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: labels}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{opts.tailscaleNamespace}, + }, + JobLabel: "ts_prom_job", + TargetLabels: []string{ + "ts_proxy_parent_name", + "ts_proxy_parent_namespace", + "ts_proxy_type", + }, + }, + } + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + t.Fatalf("error converting ServiceMonitor to unstructured: %v", err) + } + return u +} + func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Secret { t.Helper() s := &corev1.Secret{ @@ -502,6 +635,21 @@ func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want } } +func expectEqualUnstructured(t *testing.T, client client.Client, want *unstructured.Unstructured) { + t.Helper() + got := &unstructured.Unstructured{} + got.SetGroupVersionKind(want.GroupVersionKind()) + if err := client.Get(context.Background(), types.NamespacedName{ + Name: want.GetName(), + Namespace: want.GetNamespace(), + }, got); err != nil { + t.Fatalf("getting %q: %v", want.GetName(), err) + } + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("unexpected contents of Unstructured (-got +want):\n%s", diff) + } +} + func expectMissing[T any, O ptrObject[T]](t *testing.T, client client.Client, ns, name string) { t.Helper() obj := O(new(T)) diff --git a/go.mod b/go.mod index 92ba6b9c7b54d..1924e93ed5d32 100644 --- a/go.mod +++ b/go.mod @@ -396,7 +396,7 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 howett.net/plist v1.0.0 // indirect - k8s.io/apiextensions-apiserver v0.30.3 // indirect + k8s.io/apiextensions-apiserver v0.30.3 k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 730bed210118f..08e1284fe82e7 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -326,7 +326,8 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9002/metrics.
In 1.78.x and 1.80.x, this field also serves as the default value for
.spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both
fields will independently default to false.
Defaults to false. | | | +| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9002/metrics.
A metrics Service named -metrics will also be created in the operator's namespace and will
serve the metrics at :9002/metrics.
In 1.78.x and 1.80.x, this field also serves as the default value for
.spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both
fields will independently default to false.
Defaults to false. | | | +| `serviceMonitor` _[ServiceMonitor](#servicemonitor)_ | Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics.
The ServiceMonitor will select the metrics Service that gets created when metrics are enabled.
The ingested metrics for each Service monitor will have labels to identify the proxy:
ts_proxy_type: ingress_service\|ingress_resource\|connector\|proxygroup
ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup)
ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped)
job: ts__[]_ | | | #### Name @@ -836,6 +837,22 @@ _Appears in:_ | `name` _string_ | The name of a Kubernetes Secret in the operator's namespace that contains
credentials for writing to the configured bucket. Each key-value pair
from the secret's data will be mounted as an environment variable. It
should include keys for AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY if
using a static access key. | | | +#### ServiceMonitor + + + + + + + +_Appears in:_ +- [Metrics](#metrics) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. | | | + + #### StatefulSet diff --git a/k8s-operator/apis/v1alpha1/register.go b/k8s-operator/apis/v1alpha1/register.go index 70b411d120994..0880ac975732e 100644 --- a/k8s-operator/apis/v1alpha1/register.go +++ b/k8s-operator/apis/v1alpha1/register.go @@ -10,6 +10,7 @@ import ( "tailscale.com/k8s-operator/apis" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -39,12 +40,18 @@ func init() { localSchemeBuilder.Register(addKnownTypes) GlobalScheme = runtime.NewScheme() + // Add core types if err := scheme.AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add k8s.io scheme: %s", err)) } + // Add tailscale.com types if err := AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add tailscale.com scheme: %s", err)) } + // Add apiextensions types (CustomResourceDefinitions/CustomResourceDefinitionLists) + if err := apiextensionsv1.AddToScheme(GlobalScheme); err != nil { + panic(fmt.Sprintf("failed to add apiextensions.k8s.io scheme: %s", err)) + } } // Adds the list of known types to api.Scheme. diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 71fbf24390d55..ef9a071d02bbe 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -161,9 +161,12 @@ type Pod struct { TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"` } +// +kubebuilder:validation:XValidation:rule="!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)",message="ServiceMonitor can only be enabled if metrics are enabled" type Metrics struct { // Setting enable to true will make the proxy serve Tailscale metrics // at :9002/metrics. + // A metrics Service named -metrics will also be created in the operator's namespace and will + // serve the metrics at :9002/metrics. // // In 1.78.x and 1.80.x, this field also serves as the default value for // .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both @@ -171,6 +174,20 @@ type Metrics struct { // // Defaults to false. Enable bool `json:"enable"` + // Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + // The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + // The ingested metrics for each Service monitor will have labels to identify the proxy: + // ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + // ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + // ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + // job: ts__[]_ + // +optional + ServiceMonitor *ServiceMonitor `json:"serviceMonitor"` +} + +type ServiceMonitor struct { + // If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + Enable bool `json:"enable"` } type Container struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index 07e46f3f5cde8..29c71cb90f309 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -319,6 +319,11 @@ func (in *Env) DeepCopy() *Env { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Metrics) DeepCopyInto(out *Metrics) { *out = *in + if in.ServiceMonitor != nil { + in, out := &in.ServiceMonitor, &out.ServiceMonitor + *out = new(ServiceMonitor) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Metrics. @@ -526,7 +531,7 @@ func (in *ProxyClassSpec) DeepCopyInto(out *ProxyClassSpec) { if in.Metrics != nil { in, out := &in.Metrics, &out.Metrics *out = new(Metrics) - **out = **in + (*in).DeepCopyInto(*out) } if in.TailscaleConfig != nil { in, out := &in.TailscaleConfig, &out.TailscaleConfig @@ -991,6 +996,21 @@ func (in *S3Secret) DeepCopy() *S3Secret { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServiceMonitor) DeepCopyInto(out *ServiceMonitor) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceMonitor. +func (in *ServiceMonitor) DeepCopy() *ServiceMonitor { + if in == nil { + return nil + } + out := new(ServiceMonitor) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *StatefulSet) DeepCopyInto(out *StatefulSet) { *out = *in From efdfd547979fc09ea30d96bf31fcc06cadc538f3 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Tue, 3 Dec 2024 15:02:42 +0000 Subject: [PATCH 162/179] cmd/k8s-operator: avoid port collision with metrics endpoint (#14185) When the operator enables metrics on a proxy, it uses the port 9001, and in the near future it will start using 9002 for the debug endpoint as well. Make sure we don't choose ports from a range that includes 9001 so that we never clash. Setting TS_SOCKS5_SERVER, TS_HEALTHCHECK_ADDR_PORT, TS_OUTBOUND_HTTP_PROXY_LISTEN, and PORT could also open arbitrary ports, so we will need to document that users should not choose ports from the 10000-11000 range for those settings. Updates #13406 Signed-off-by: Tom Proctor --- cmd/k8s-operator/egress-services.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index a562f0170eea1..17746b47012d0 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -51,12 +51,12 @@ const ( labelSvcType = "tailscale.com/svc-type" // ingress or egress typeEgress = "egress" // maxPorts is the maximum number of ports that can be exposed on a - // container. In practice this will be ports in range [3000 - 4000). The + // container. In practice this will be ports in range [10000 - 11000). The // high range should make it easier to distinguish container ports from // the tailnet target ports for debugging purposes (i.e when reading - // netfilter rules). The limit of 10000 is somewhat arbitrary, the + // netfilter rules). The limit of 1000 is somewhat arbitrary, the // assumption is that this would not be hit in practice. - maxPorts = 10000 + maxPorts = 1000 indexEgressProxyGroup = ".metadata.annotations.egress-proxy-group" ) @@ -254,7 +254,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s if !found { // Calculate a free port to expose on container and add // a new PortMap to the ClusterIP Service. - if usedPorts.Len() == maxPorts { + if usedPorts.Len() >= maxPorts { // TODO(irbekrm): refactor to avoid extra reconciles here. Low priority as in practice, // the limit should not be hit. return nil, false, fmt.Errorf("unable to allocate additional ports on ProxyGroup %s, %d ports already used. Create another ProxyGroup or open an issue if you believe this is unexpected.", proxyGroupName, maxPorts) @@ -548,13 +548,13 @@ func svcNameBase(s string) string { } } -// unusedPort returns a port in range [3000 - 4000). The caller must ensure that -// usedPorts does not contain all ports in range [3000 - 4000). +// unusedPort returns a port in range [10000 - 11000). The caller must ensure that +// usedPorts does not contain all ports in range [10000 - 11000). func unusedPort(usedPorts sets.Set[int32]) int32 { foundFreePort := false var suggestPort int32 for !foundFreePort { - suggestPort = rand.Int32N(maxPorts) + 3000 + suggestPort = rand.Int32N(maxPorts) + 10000 if !usedPorts.Has(suggestPort) { foundFreePort = true } From cbf1a4efe97a5424010a967285d71cf6ee4458ab Mon Sep 17 00:00:00 2001 From: Oliver Rahner Date: Tue, 3 Dec 2024 18:00:40 +0100 Subject: [PATCH 163/179] cmd/k8s-operator/deploy/chart: allow reading OAuth creds from a CSI driver's volume and annotating operator's Service account (#14264) cmd/k8s-operator/deploy/chart: allow reading OAuth creds from a CSI driver's volume and annotating operator's Service account Updates #14264 Signed-off-by: Oliver Rahner --- .../deploy/chart/templates/deployment.yaml | 10 +++++++--- .../deploy/chart/templates/operator-rbac.yaml | 4 ++++ cmd/k8s-operator/deploy/chart/values.yaml | 20 ++++++++++++++++++- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index 2653f21595ba7..1b9b97186b6ca 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -35,9 +35,13 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} volumes: - - name: oauth - secret: - secretName: operator-oauth + - name: oauth + {{- with .Values.oauthSecretVolume }} + {{- toYaml . | nindent 10 }} + {{- else }} + secret: + secretName: operator-oauth + {{- end }} containers: - name: operator {{- with .Values.operatorConfig.securityContext }} diff --git a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml index a56edfe0d1b80..637bdf793c2b9 100644 --- a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml @@ -6,6 +6,10 @@ kind: ServiceAccount metadata: name: operator namespace: {{ .Release.Namespace }} + {{- with .Values.operatorConfig.serviceAccountAnnotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index b24ba37b05360..2d1effc255dc5 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -3,11 +3,26 @@ # Operator oauth credentials. If set a Kubernetes Secret with the provided # values will be created in the operator namespace. If unset a Secret named -# operator-oauth must be precreated. +# operator-oauth must be precreated or oauthSecretVolume needs to be adjusted. +# This block will be overridden by oauthSecretVolume, if set. oauth: {} # clientId: "" # clientSecret: "" +# Secret volume. +# If set it defines the volume the oauth secrets will be mounted from. +# The volume needs to contain two files named `client_id` and `client_secret`. +# If unset the volume will reference the Secret named operator-oauth. +# This block will override the oauth block. +oauthSecretVolume: {} + # csi: + # driver: secrets-store.csi.k8s.io + # readOnly: true + # volumeAttributes: + # secretProviderClass: tailscale-oauth + # + ## NAME is pre-defined! + # installCRDs determines whether tailscale.com CRDs should be installed as part # of chart installation. We do not use Helm's CRD installation mechanism as that # does not allow for upgrading CRDs. @@ -40,6 +55,9 @@ operatorConfig: podAnnotations: {} podLabels: {} + serviceAccountAnnotations: {} + # eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/tailscale-operator-role + tolerations: [] affinity: {} From aa43388363bbb34835bc721cddc246e3f357d187 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Wed, 4 Dec 2024 06:46:51 +0000 Subject: [PATCH 164/179] cmd/k8s-operator: fix a bunch of status equality checks (#14270) Updates tailscale/tailscale#14269 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/connector.go | 2 +- cmd/k8s-operator/egress-services-readiness.go | 2 +- cmd/k8s-operator/egress-services.go | 2 +- cmd/k8s-operator/nameserver.go | 14 +++++++------- cmd/k8s-operator/proxyclass.go | 2 +- cmd/k8s-operator/proxygroup.go | 2 +- cmd/k8s-operator/svc.go | 4 ++-- cmd/k8s-operator/tsrecorder.go | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 1ed6fd1556d5f..dfeee6be1cb85 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -113,7 +113,7 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque setStatus := func(cn *tsapi.Connector, _ tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetConnectorCondition(cn, tsapi.ConnectorReady, status, reason, message, cn.Generation, a.clock, logger) var updateErr error - if !apiequality.Semantic.DeepEqual(oldCnStatus, cn.Status) { + if !apiequality.Semantic.DeepEqual(oldCnStatus, &cn.Status) { // An error encountered here should get returned by the Reconcile function. updateErr = a.Client.Status().Update(ctx, cn) } diff --git a/cmd/k8s-operator/egress-services-readiness.go b/cmd/k8s-operator/egress-services-readiness.go index f6991145f88fc..f1964d452633c 100644 --- a/cmd/k8s-operator/egress-services-readiness.go +++ b/cmd/k8s-operator/egress-services-readiness.go @@ -64,7 +64,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re oldStatus := svc.Status.DeepCopy() defer func() { tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, st, reason, msg, esrr.clock, l) - if !apiequality.Semantic.DeepEqual(oldStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esrr.Status().Update(ctx, svc)) } }() diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 17746b47012d0..a08c0b71563f0 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -123,7 +123,7 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re oldStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esr.Status().Update(ctx, svc)) } }() diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 52577c929acea..6a9a6be935642 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -86,7 +86,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ return reconcile.Result{}, nil } logger.Info("Cleaning up DNSConfig resources") - if err := a.maybeCleanup(ctx, &dnsCfg, logger); err != nil { + if err := a.maybeCleanup(&dnsCfg); err != nil { logger.Errorf("error cleaning up reconciler resource: %v", err) return res, err } @@ -100,9 +100,9 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ } oldCnStatus := dnsCfg.Status.DeepCopy() - setStatus := func(dnsCfg *tsapi.DNSConfig, conditionType tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { + setStatus := func(dnsCfg *tsapi.DNSConfig, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetDNSConfigCondition(dnsCfg, tsapi.NameserverReady, status, reason, message, dnsCfg.Generation, a.clock, logger) - if !apiequality.Semantic.DeepEqual(oldCnStatus, dnsCfg.Status) { + if !apiequality.Semantic.DeepEqual(oldCnStatus, &dnsCfg.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := a.Client.Status().Update(ctx, dnsCfg); updateErr != nil { err = errors.Wrap(err, updateErr.Error()) @@ -118,7 +118,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ msg := "invalid cluster configuration: more than one tailscale.com/dnsconfigs found. Please ensure that no more than one is created." logger.Error(msg) a.recorder.Event(&dnsCfg, corev1.EventTypeWarning, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) - setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) + setStatus(&dnsCfg, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) } if !slices.Contains(dnsCfg.Finalizers, FinalizerName) { @@ -127,7 +127,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ if err := a.Update(ctx, &dnsCfg); err != nil { msg := fmt.Sprintf(messageNameserverCreationFailed, err) logger.Error(msg) - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) + return setStatus(&dnsCfg, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) } } if err := a.maybeProvision(ctx, &dnsCfg, logger); err != nil { @@ -149,7 +149,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ dnsCfg.Status.Nameserver = &tsapi.NameserverStatus{ IP: ip, } - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) + return setStatus(&dnsCfg, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) } logger.Info("nameserver Service does not have an IP address allocated, waiting...") return reconcile.Result{}, nil @@ -188,7 +188,7 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa // maybeCleanup removes DNSConfig from being tracked. The cluster resources // created, will be automatically garbage collected as they are owned by the // DNSConfig. -func (a *NameserverReconciler) maybeCleanup(ctx context.Context, dnsCfg *tsapi.DNSConfig, logger *zap.SugaredLogger) error { +func (a *NameserverReconciler) maybeCleanup(dnsCfg *tsapi.DNSConfig) error { a.mu.Lock() a.managedNameservers.Remove(dnsCfg.UID) a.mu.Unlock() diff --git a/cmd/k8s-operator/proxyclass.go b/cmd/k8s-operator/proxyclass.go index ad3cfc9fd02d2..b781af05adaaa 100644 --- a/cmd/k8s-operator/proxyclass.go +++ b/cmd/k8s-operator/proxyclass.go @@ -103,7 +103,7 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re } else { tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, pc.Generation, pcr.clock, logger) } - if !apiequality.Semantic.DeepEqual(oldPCStatus, pc.Status) { + if !apiequality.Semantic.DeepEqual(oldPCStatus, &pc.Status) { if err := pcr.Client.Status().Update(ctx, pc); err != nil { logger.Errorf("error updating ProxyClass status: %v", err) return reconcile.Result{}, err diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 1aefbd2f6caef..344cd9ae065f4 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -110,7 +110,7 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ oldPGStatus := pg.Status.DeepCopy() setStatusReady := func(pg *tsapi.ProxyGroup, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, status, reason, message, pg.Generation, r.clock, logger) - if !apiequality.Semantic.DeepEqual(oldPGStatus, pg.Status) { + if !apiequality.Semantic.DeepEqual(oldPGStatus, &pg.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := r.Client.Status().Update(ctx, pg); updateErr != nil { err = errors.Wrap(err, updateErr.Error()) diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index 6afc56f976121..cbf50c81f6630 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -131,7 +131,7 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } @@ -196,7 +196,7 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index cfe38c50af311..4445578a63920 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -102,7 +102,7 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques oldTSRStatus := tsr.Status.DeepCopy() setStatusReady := func(tsr *tsapi.Recorder, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, status, reason, message, tsr.Generation, r.clock, logger) - if !apiequality.Semantic.DeepEqual(oldTSRStatus, tsr.Status) { + if !apiequality.Semantic.DeepEqual(oldTSRStatus, &tsr.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := r.Client.Status().Update(ctx, tsr); updateErr != nil { err = errors.Wrap(err, updateErr.Error()) From 2aac91688883090d892f01a2953cc0318aee9c90 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Wed, 4 Dec 2024 12:00:04 +0000 Subject: [PATCH 165/179] cmd/{containerboot,k8s-operator},kube/kubetypes: kube Ingress L7 proxies only advertise HTTPS endpoint when ready (#14171) cmd/containerboot,kube/kubetypes,cmd/k8s-operator: detect if Ingress is created in a tailnet that has no HTTPS This attempts to make Kubernetes Operator L7 Ingress setup failures more explicit: - the Ingress resource now only advertises HTTPS endpoint via status.ingress.loadBalancer.hostname when/if the proxy has succesfully loaded serve config - the proxy attempts to catch cases where HTTPS is disabled for the tailnet and logs a warning Updates tailscale/tailscale#12079 Updates tailscale/tailscale#10407 Signed-off-by: Irbe Krumina --- cmd/containerboot/kube.go | 92 ++++++++++----- cmd/containerboot/kube_test.go | 42 +++---- cmd/containerboot/main.go | 49 ++++++-- cmd/containerboot/main_test.go | 36 +++--- cmd/containerboot/serve.go | 60 ++++++++-- cmd/containerboot/settings.go | 4 +- cmd/k8s-operator/connector.go | 10 +- cmd/k8s-operator/ingress.go | 12 +- cmd/k8s-operator/ingress_test.go | 148 ++++++++++++++++++++++++ cmd/k8s-operator/sts.go | 93 +++++++++++---- cmd/k8s-operator/svc.go | 10 +- kube/kubetypes/{metrics.go => types.go} | 15 +++ 12 files changed, 443 insertions(+), 128 deletions(-) rename kube/kubetypes/{metrics.go => types.go} (59%) diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 5a726c20b33e9..643eef385ee0c 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -9,30 +9,55 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "net/netip" "os" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" ) -// storeDeviceID writes deviceID to 'device_id' data field of the named -// Kubernetes Secret. -func storeDeviceID(ctx context.Context, secretName string, deviceID tailcfg.StableNodeID) error { +// kubeClient is a wrapper around Tailscale's internal kube client that knows how to talk to the kube API server. We use +// this rather than any of the upstream Kubernetes client libaries to avoid extra imports. +type kubeClient struct { + kubeclient.Client + stateSecret string +} + +func newKubeClient(root string, stateSecret string) (*kubeClient, error) { + if root != "/" { + // If we are running in a test, we need to set the root path to the fake + // service account directory. + kubeclient.SetRootPathForTesting(root) + } + var err error + kc, err := kubeclient.New("tailscale-container") + if err != nil { + return nil, fmt.Errorf("Error creating kube client: %w", err) + } + if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" { + // Derive the API server address from the environment variables + // Used to set http server in tests, or optionally enabled by flag + kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) + } + return &kubeClient{Client: kc, stateSecret: stateSecret}, nil +} + +// storeDeviceID writes deviceID to 'device_id' data field of the client's state Secret. +func (kc *kubeClient) storeDeviceID(ctx context.Context, deviceID tailcfg.StableNodeID) error { s := &kubeapi.Secret{ Data: map[string][]byte{ - "device_id": []byte(deviceID), + kubetypes.KeyDeviceID: []byte(deviceID), }, } - return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container") + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } -// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields -// 'device_ips', 'device_fqdn' of the named Kubernetes Secret. -func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, addresses []netip.Prefix) error { +// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields 'device_ips', 'device_fqdn' of client's +// state Secret. +func (kc *kubeClient) storeDeviceEndpoints(ctx context.Context, fqdn string, addresses []netip.Prefix) error { var ips []string for _, addr := range addresses { ips = append(ips, addr.Addr().String()) @@ -44,16 +69,28 @@ func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, a s := &kubeapi.Secret{ Data: map[string][]byte{ - "device_fqdn": []byte(fqdn), - "device_ips": deviceIPs, + kubetypes.KeyDeviceFQDN: []byte(fqdn), + kubetypes.KeyDeviceIPs: deviceIPs, }, } - return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container") + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") +} + +// storeHTTPSEndpoint writes an HTTPS endpoint exposed by this device via 'tailscale serve' to the client's state +// Secret. In practice this will be the same value that gets written to 'device_fqdn', but this should only be called +// when the serve config has been successfully set up. +func (kc *kubeClient) storeHTTPSEndpoint(ctx context.Context, ep string) error { + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyHTTPSEndpoint: []byte(ep), + }, + } + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } // deleteAuthKey deletes the 'authkey' field of the given kube // secret. No-op if there is no authkey in the secret. -func deleteAuthKey(ctx context.Context, secretName string) error { +func (kc *kubeClient) deleteAuthKey(ctx context.Context) error { // m is a JSON Patch data structure, see https://jsonpatch.com/ or RFC 6902. m := []kubeclient.JSONPatch{ { @@ -61,7 +98,7 @@ func deleteAuthKey(ctx context.Context, secretName string) error { Path: "/data/authkey", }, } - if err := kc.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil { + if err := kc.JSONPatchResource(ctx, kc.stateSecret, kubeclient.TypeSecrets, m); err != nil { if s, ok := err.(*kubeapi.Status); ok && s.Code == http.StatusUnprocessableEntity { // This is kubernetes-ese for "the field you asked to // delete already doesn't exist", aka no-op. @@ -72,22 +109,19 @@ func deleteAuthKey(ctx context.Context, secretName string) error { return nil } -var kc kubeclient.Client - -func initKubeClient(root string) { - if root != "/" { - // If we are running in a test, we need to set the root path to the fake - // service account directory. - kubeclient.SetRootPathForTesting(root) +// storeCapVerUID stores the current capability version of tailscale and, if provided, UID of the Pod in the tailscale +// state Secret. +// These two fields are used by the Kubernetes Operator to observe the current capability version of tailscaled running in this container. +func (kc *kubeClient) storeCapVerUID(ctx context.Context, podUID string) error { + capVerS := fmt.Sprintf("%d", tailcfg.CurrentCapabilityVersion) + d := map[string][]byte{ + kubetypes.KeyCapVer: []byte(capVerS), } - var err error - kc, err = kubeclient.New("tailscale-container") - if err != nil { - log.Fatalf("Error creating kube client: %v", err) + if podUID != "" { + d[kubetypes.KeyPodUID] = []byte(podUID) } - if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" { - // Derive the API server address from the environment variables - // Used to set http server in tests, or optionally enabled by flag - kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) + s := &kubeapi.Secret{ + Data: d, } + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } diff --git a/cmd/containerboot/kube_test.go b/cmd/containerboot/kube_test.go index 1a5730548838f..2ba69af7c0f57 100644 --- a/cmd/containerboot/kube_test.go +++ b/cmd/containerboot/kube_test.go @@ -21,7 +21,7 @@ func TestSetupKube(t *testing.T) { cfg *settings wantErr bool wantCfg *settings - kc kubeclient.Client + kc *kubeClient }{ { name: "TS_AUTHKEY set, state Secret exists", @@ -29,14 +29,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, nil }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -48,14 +48,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -67,14 +67,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -87,14 +87,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 403} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -111,11 +111,11 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, errors.New("broken") }, - }, + }}, wantErr: true, }, { @@ -127,14 +127,14 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, }, { // Interactive login using URL in Pod logs @@ -145,28 +145,28 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{}, nil }, - }, + }}, }, { name: "TS_AUTHKEY not set, state Secret contains auth key, we do not have RBAC to patch it", cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", }, @@ -177,14 +177,14 @@ func TestSetupKube(t *testing.T) { cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return true, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", AuthKey: "foo", @@ -194,9 +194,9 @@ func TestSetupKube(t *testing.T) { } for _, tt := range tests { - kc = tt.kc + kc := tt.kc t.Run(tt.name, func(t *testing.T) { - if err := tt.cfg.setupKube(context.Background()); (err != nil) != tt.wantErr { + if err := tt.cfg.setupKube(context.Background(), kc); (err != nil) != tt.wantErr { t.Errorf("settings.setupKube() error = %v, wantErr %v", err, tt.wantErr) } if diff := cmp.Diff(*tt.cfg, *tt.wantCfg); diff != "" { diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 0af9062a5f314..ad1c0db201aa5 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -121,6 +121,7 @@ import ( "tailscale.com/client/tailscale" "tailscale.com/ipn" kubeutils "tailscale.com/k8s-operator" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/ptr" @@ -167,9 +168,13 @@ func main() { bootCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() + var kc *kubeClient if cfg.InKubernetes { - initKubeClient(cfg.Root) - if err := cfg.setupKube(bootCtx); err != nil { + kc, err = newKubeClient(cfg.Root, cfg.KubeSecret) + if err != nil { + log.Fatalf("error initializing kube client: %v", err) + } + if err := cfg.setupKube(bootCtx, kc); err != nil { log.Fatalf("error setting up for running on Kubernetes: %v", err) } } @@ -319,12 +324,16 @@ authLoop: } } + // Remove any serve config and advertised HTTPS endpoint that may have been set by a previous run of + // containerboot, but only if we're providing a new one. if cfg.ServeConfigPath != "" { - // Remove any serve config that may have been set by a previous run of - // containerboot, but only if we're providing a new one. + log.Printf("serve proxy: unsetting previous config") if err := client.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { log.Fatalf("failed to unset serve config: %v", err) } + if err := kc.storeHTTPSEndpoint(ctx, ""); err != nil { + log.Fatalf("failed to update HTTPS endpoint in tailscale state: %v", err) + } } if hasKubeStateStore(cfg) && isTwoStepConfigAuthOnce(cfg) { @@ -332,11 +341,17 @@ authLoop: // authkey is no longer needed. We don't strictly need to // wipe it, but it's good hygiene. log.Printf("Deleting authkey from kube secret") - if err := deleteAuthKey(ctx, cfg.KubeSecret); err != nil { + if err := kc.deleteAuthKey(ctx); err != nil { log.Fatalf("deleting authkey from kube secret: %v", err) } } + if hasKubeStateStore(cfg) { + if err := kc.storeCapVerUID(ctx, cfg.PodUID); err != nil { + log.Fatalf("storing capability version and UID: %v", err) + } + } + w, err = client.WatchIPNBus(ctx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { log.Fatalf("rewatching tailscaled for updates after auth: %v", err) @@ -355,10 +370,10 @@ authLoop: certDomain = new(atomic.Pointer[string]) certDomainChanged = make(chan bool, 1) + + triggerWatchServeConfigChanges sync.Once ) - if cfg.ServeConfigPath != "" { - go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client) - } + var nfr linuxfw.NetfilterRunner if isL3Proxy(cfg) { nfr, err = newNetfilterRunner(log.Printf) @@ -459,7 +474,7 @@ runLoop: // fails. deviceID := n.NetMap.SelfNode.StableID() if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { - if err := storeDeviceID(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID()); err != nil { + if err := kc.storeDeviceID(ctx, n.NetMap.SelfNode.StableID()); err != nil { log.Fatalf("storing device ID in Kubernetes Secret: %v", err) } } @@ -532,8 +547,11 @@ runLoop: resetTimer(false) backendAddrs = newBackendAddrs } - if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) != 0 { - cd := n.NetMap.DNS.CertDomains[0] + if cfg.ServeConfigPath != "" { + cd := certDomainFromNetmap(n.NetMap) + if cd == "" { + cd = kubetypes.ValueNoHTTPS + } prev := certDomain.Swap(ptr.To(cd)) if prev == nil || *prev != cd { select { @@ -575,7 +593,7 @@ runLoop: // TODO (irbekrm): instead of using the IP and FQDN, have some other mechanism for the proxy signal that it is 'Ready'. deviceEndpoints := []any{n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses()} if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { - if err := storeDeviceEndpoints(ctx, cfg.KubeSecret, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { + if err := kc.storeDeviceEndpoints(ctx, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { log.Fatalf("storing device IPs and FQDN in Kubernetes Secret: %v", err) } } @@ -583,6 +601,13 @@ runLoop: if healthCheck != nil { healthCheck.update(len(addrs) != 0) } + + if cfg.ServeConfigPath != "" { + triggerWatchServeConfigChanges.Do(func() { + go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client, kc) + }) + } + if egressSvcsNotify != nil { egressSvcsNotify <- n } diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 47d7c19cfa78f..83e001b62c09e 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -120,6 +120,8 @@ func TestContainerBoot(t *testing.T) { return fmt.Sprintf("http://127.0.0.1:%d/healthz", port) } + capver := fmt.Sprintf("%d", tailcfg.CurrentCapabilityVersion) + type phase struct { // If non-nil, send this IPN bus notification (and remember it as the // initial update for any future new watchers, then wait for all the @@ -478,10 +480,11 @@ func TestContainerBoot(t *testing.T) { { Notify: runningNotify, WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, }, }, }, @@ -571,9 +574,10 @@ func TestContainerBoot(t *testing.T) { "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", }, WantKubeSecret: map[string]string{ - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, }, }, }, @@ -600,10 +604,11 @@ func TestContainerBoot(t *testing.T) { { Notify: runningNotify, WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, }, }, { @@ -618,10 +623,11 @@ func TestContainerBoot(t *testing.T) { }, }, WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "new-name.test.ts.net", - "device_id": "newID", - "device_ips": `["100.64.0.1"]`, + "authkey": "tskey-key", + "device_fqdn": "new-name.test.ts.net", + "device_id": "newID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, }, }, }, diff --git a/cmd/containerboot/serve.go b/cmd/containerboot/serve.go index 6c22b3eeb651e..29ee7347f0c14 100644 --- a/cmd/containerboot/serve.go +++ b/cmd/containerboot/serve.go @@ -19,6 +19,8 @@ import ( "github.com/fsnotify/fsnotify" "tailscale.com/client/tailscale" "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/netmap" ) // watchServeConfigChanges watches path for changes, and when it sees one, reads @@ -26,21 +28,21 @@ import ( // applies it to lc. It exits when ctx is canceled. cdChanged is a channel that // is written to when the certDomain changes, causing the serve config to be // re-read and applied. -func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *tailscale.LocalClient) { +func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *tailscale.LocalClient, kc *kubeClient) { if certDomainAtomic == nil { - panic("cd must not be nil") + panic("certDomainAtomic must not be nil") } var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event if w, err := fsnotify.NewWatcher(); err != nil { - log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) + log.Printf("serve proxy: failed to create fsnotify watcher, timer-only mode: %v", err) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() tickChan = ticker.C } else { defer w.Close() if err := w.Add(filepath.Dir(path)); err != nil { - log.Fatalf("failed to add fsnotify watch: %v", err) + log.Fatalf("serve proxy: failed to add fsnotify watch: %v", err) } eventChan = w.Events } @@ -59,24 +61,60 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan // k8s handles these mounts. So just re-read the file and apply it // if it's changed. } - if certDomain == "" { - continue - } sc, err := readServeConfig(path, certDomain) if err != nil { - log.Fatalf("failed to read serve config: %v", err) + log.Fatalf("serve proxy: failed to read serve config: %v", err) } if prevServeConfig != nil && reflect.DeepEqual(sc, prevServeConfig) { continue } - log.Printf("Applying serve config") - if err := lc.SetServeConfig(ctx, sc); err != nil { - log.Fatalf("failed to set serve config: %v", err) + validateHTTPSServe(certDomain, sc) + if err := updateServeConfig(ctx, sc, certDomain, lc); err != nil { + log.Fatalf("serve proxy: error updating serve config: %v", err) + } + if err := kc.storeHTTPSEndpoint(ctx, certDomain); err != nil { + log.Fatalf("serve proxy: error storing HTTPS endpoint: %v", err) } prevServeConfig = sc } } +func certDomainFromNetmap(nm *netmap.NetworkMap) string { + if len(nm.DNS.CertDomains) == 0 { + return "" + } + return nm.DNS.CertDomains[0] +} + +func updateServeConfig(ctx context.Context, sc *ipn.ServeConfig, certDomain string, lc *tailscale.LocalClient) error { + // TODO(irbekrm): This means that serve config that does not expose HTTPS endpoint will not be set for a tailnet + // that does not have HTTPS enabled. We probably want to fix this. + if certDomain == kubetypes.ValueNoHTTPS { + return nil + } + log.Printf("serve proxy: applying serve config") + return lc.SetServeConfig(ctx, sc) +} + +func validateHTTPSServe(certDomain string, sc *ipn.ServeConfig) { + if certDomain != kubetypes.ValueNoHTTPS || !hasHTTPSEndpoint(sc) { + return + } + log.Printf( + `serve proxy: this node is configured as a proxy that exposes an HTTPS endpoint to tailnet, + (perhaps a Kubernetes operator Ingress proxy) but it is not able to issue TLS certs, so this will likely not work. + To make it work, ensure that HTTPS is enabled for your tailnet, see https://tailscale.com/kb/1153/enabling-https for more details.`) +} + +func hasHTTPSEndpoint(cfg *ipn.ServeConfig) bool { + for _, tcpCfg := range cfg.TCP { + if tcpCfg.HTTPS { + return true + } + } + return false +} + // readServeConfig reads the ipn.ServeConfig from path, replacing // ${TS_CERT_DOMAIN} with certDomain. func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) { diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index 1262a0e1872ec..4fae58584cec7 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -67,6 +67,7 @@ type settings struct { PodIP string PodIPv4 string PodIPv6 string + PodUID string HealthCheckAddrPort string LocalAddrPort string MetricsEnabled bool @@ -107,6 +108,7 @@ func configFromEnv() (*settings, error) { HealthCheckEnabled: defaultBool("TS_ENABLE_HEALTH_CHECK", false), DebugAddrPort: defaultEnv("TS_DEBUG_ADDR_PORT", ""), EgressSvcsCfgPath: defaultEnv("TS_EGRESS_SERVICES_CONFIG_PATH", ""), + PodUID: defaultEnv("POD_UID", ""), } podIPs, ok := os.LookupEnv("POD_IPS") if ok { @@ -203,7 +205,7 @@ func (s *settings) validate() error { // setupKube is responsible for doing any necessary configuration and checks to // ensure that tailscale state storage and authentication mechanism will work on // Kubernetes. -func (cfg *settings) setupKube(ctx context.Context) error { +func (cfg *settings) setupKube(ctx context.Context, kc *kubeClient) error { if cfg.KubeSecret == "" { return nil } diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index dfeee6be1cb85..1cce02fbba974 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -234,21 +234,21 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge return err } - _, tsHost, ips, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return err } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for connector pod to finish auth") + if dev == nil || dev.hostname == "" { + logger.Debugf("no Tailscale hostname known yet, waiting for Connector Pod to finish auth") // No hostname yet. Wait for the connector pod to auth. cn.Status.TailnetIPs = nil cn.Status.Hostname = "" return nil } - cn.Status.TailnetIPs = ips - cn.Status.Hostname = tsHost + cn.Status.TailnetIPs = dev.ips + cn.Status.Hostname = dev.hostname return nil } diff --git a/cmd/k8s-operator/ingress.go b/cmd/k8s-operator/ingress.go index 40a5d09283193..749869b2264eb 100644 --- a/cmd/k8s-operator/ingress.go +++ b/cmd/k8s-operator/ingress.go @@ -279,12 +279,12 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return fmt.Errorf("failed to provision: %w", err) } - _, tsHost, _, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { - return fmt.Errorf("failed to get device ID: %w", err) + return fmt.Errorf("failed to retrieve Ingress HTTPS endpoint status: %w", err) } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for proxy pod to finish auth") + if dev == nil || dev.ingressDNSName == "" { + logger.Debugf("no Ingress DNS name known yet, waiting for proxy Pod initialize and start serving Ingress") // No hostname yet. Wait for the proxy pod to auth. ing.Status.LoadBalancer.Ingress = nil if err := a.Status().Update(ctx, ing); err != nil { @@ -293,10 +293,10 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - logger.Debugf("setting ingress hostname to %q", tsHost) + logger.Debugf("setting Ingress hostname to %q", dev.ingressDNSName) ing.Status.LoadBalancer.Ingress = []networkingv1.IngressLoadBalancerIngress{ { - Hostname: tsHost, + Hostname: dev.ingressDNSName, Ports: []networkingv1.IngressPortStatus{ { Protocol: "TCP", diff --git a/cmd/k8s-operator/ingress_test.go b/cmd/k8s-operator/ingress_test.go index e695cc649408c..c4332908a08f9 100644 --- a/cmd/k8s-operator/ingress_test.go +++ b/cmd/k8s-operator/ingress_test.go @@ -142,6 +142,154 @@ func TestTailscaleIngress(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailscaleIngressHostname(t *testing.T) { + tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} + fc := fake.NewFakeClient(tsIngressClass) + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + + // 1. Resources get created for regular Ingress + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"default-test"}}, + }, + }, + } + mustCreate(t, fc, ing) + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + Ports: []corev1.ServicePort{{ + Port: 8080, + Name: "http"}, + }, + }, + }) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + mustCreate(t, fc, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fullName, + Namespace: "operator-ns", + UID: "test-uid", + }, + }) + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + + // 2. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint set + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Finalizers = append(ing.Finalizers, "tailscale.com/finalizer") + + expectEqual(t, fc, ing, nil) + + // 3. Ingress proxy with capability version >= 110 advertises HTTPS endpoint + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectEqual(t, fc, ing, nil) + + // 4. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint ready + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("no-https")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer.Ingress = nil + expectEqual(t, fc, ing, nil) + + // 5. Ingress proxy's state has https_endpoints set, but its capver is not matching Pod UID (downgrade) + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("not-the-right-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("bar.tailnetxyz.ts.net")) + }) + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectReconciled(t, ingR, "default", "test") + expectEqual(t, fc, ing, nil) +} + func TestTailscaleIngressWithProxyClass(t *testing.T) { // Setup pc := &tsapi.ProxyClass{ diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index 5de30154cd59a..ff7c074a8b425 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -15,6 +15,7 @@ import ( "net/http" "os" "slices" + "strconv" "strings" "go.uber.org/zap" @@ -197,11 +198,11 @@ func (a *tailscaleSTSReconciler) Provision(ctx context.Context, logger *zap.Suga } sts.ProxyClass = proxyClass - secretName, tsConfigHash, configs, err := a.createOrGetSecret(ctx, logger, sts, hsvc) + secretName, tsConfigHash, _, err := a.createOrGetSecret(ctx, logger, sts, hsvc) if err != nil { return nil, fmt.Errorf("failed to create or get API key secret: %w", err) } - _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretName, tsConfigHash, configs) + _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretName, tsConfigHash) if err != nil { return nil, fmt.Errorf("failed to reconcile statefulset: %w", err) } @@ -246,21 +247,21 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare return false, nil } - id, _, _, err := a.DeviceInfo(ctx, labels) + dev, err := a.DeviceInfo(ctx, labels, logger) if err != nil { return false, fmt.Errorf("getting device info: %w", err) } - if id != "" { - logger.Debugf("deleting device %s from control", string(id)) - if err := a.tsClient.DeleteDevice(ctx, string(id)); err != nil { + if dev != nil && dev.id != "" { + logger.Debugf("deleting device %s from control", string(dev.id)) + if err := a.tsClient.DeleteDevice(ctx, string(dev.id)); err != nil { errResp := &tailscale.ErrResponse{} if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { - logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) + logger.Debugf("device %s not found, likely because it has already been deleted from control", string(dev.id)) } else { return false, fmt.Errorf("deleting device: %w", err) } } else { - logger.Debugf("device %s deleted from control", string(id)) + logger.Debugf("device %s deleted from control", string(dev.id)) } } @@ -440,40 +441,66 @@ func sanitizeConfigBytes(c ipn.ConfigVAlpha) string { // that acts as an operator proxy. It retrieves info from a Kubernetes Secret // labeled with the provided labels. // Either of device ID, hostname and IPs can be empty string if not found in the Secret. -func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { +func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string, logger *zap.SugaredLogger) (dev *device, err error) { sec, err := getSingleObject[corev1.Secret](ctx, a.Client, a.operatorNamespace, childLabels) if err != nil { - return "", "", nil, err + return dev, err } if sec == nil { - return "", "", nil, nil + return dev, nil + } + pod := new(corev1.Pod) + if err := a.Get(ctx, types.NamespacedName{Namespace: sec.Namespace, Name: sec.Name}, pod); err != nil && !apierrors.IsNotFound(err) { + return dev, nil } - return deviceInfo(sec) + return deviceInfo(sec, pod, logger) +} + +// device contains tailscale state of a proxy device as gathered from its tailscale state Secret. +type device struct { + id tailcfg.StableNodeID // device's stable ID + hostname string // MagicDNS name of the device + ips []string // Tailscale IPs of the device + // ingressDNSName is the L7 Ingress DNS name. In practice this will be the same value as hostname, but only set + // when the device has been configured to serve traffic on it via 'tailscale serve'. + ingressDNSName string } -func deviceInfo(sec *corev1.Secret) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { - id = tailcfg.StableNodeID(sec.Data["device_id"]) +func deviceInfo(sec *corev1.Secret, pod *corev1.Pod, log *zap.SugaredLogger) (dev *device, err error) { + id := tailcfg.StableNodeID(sec.Data[kubetypes.KeyDeviceID]) if id == "" { - return "", "", nil, nil + return dev, nil } + dev = &device{id: id} // Kubernetes chokes on well-formed FQDNs with the trailing dot, so we have // to remove it. - hostname = strings.TrimSuffix(string(sec.Data["device_fqdn"]), ".") - if hostname == "" { + dev.hostname = strings.TrimSuffix(string(sec.Data[kubetypes.KeyDeviceFQDN]), ".") + if dev.hostname == "" { // Device ID gets stored and retrieved in a different flow than // FQDN and IPs. A device that acts as Kubernetes operator - // proxy, but whose route setup has failed might have an device + // proxy, but whose route setup has failed might have a device // ID, but no FQDN/IPs. If so, return the ID, to allow the // operator to clean up such devices. - return id, "", nil, nil + return dev, nil + } + // TODO(irbekrm): we fall back to using the hostname field to determine Ingress's hostname to ensure backwards + // compatibility. In 1.82 we can remove this fallback mechanism. + dev.ingressDNSName = dev.hostname + if proxyCapVer(sec, pod, log) >= 109 { + dev.ingressDNSName = strings.TrimSuffix(string(sec.Data[kubetypes.KeyHTTPSEndpoint]), ".") + if strings.EqualFold(dev.ingressDNSName, kubetypes.ValueNoHTTPS) { + dev.ingressDNSName = "" + } } - if rawDeviceIPs, ok := sec.Data["device_ips"]; ok { + if rawDeviceIPs, ok := sec.Data[kubetypes.KeyDeviceIPs]; ok { + ips := make([]string, 0) if err := json.Unmarshal(rawDeviceIPs, &ips); err != nil { - return "", "", nil, err + return nil, err } + dev.ips = ips } - return id, hostname, ips, nil + return dev, nil } func newAuthKey(ctx context.Context, tsClient tsClient, tags []string) (string, error) { @@ -500,7 +527,7 @@ var proxyYaml []byte //go:embed deploy/manifests/userspace-proxy.yaml var userspaceProxyYaml []byte -func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string, _ map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) (*appsv1.StatefulSet, error) { +func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string) (*appsv1.StatefulSet, error) { ss := new(appsv1.StatefulSet) if sts.ServeConfig != nil && sts.ForwardClusterTrafficViaL7IngressProxy != true { // If forwarding cluster traffic via is required we need non-userspace + NET_ADMIN + forwarding if err := yaml.Unmarshal(userspaceProxyYaml, &ss); err != nil { @@ -1084,3 +1111,23 @@ func nameForService(svc *corev1.Service) string { func isValidFirewallMode(m string) bool { return m == "auto" || m == "nftables" || m == "iptables" } + +// proxyCapVer accepts a proxy state Secret and a proxy Pod returns the capability version of a proxy Pod. +// This is best effort - if the capability version can not (currently) be determined, it returns -1. +func proxyCapVer(sec *corev1.Secret, pod *corev1.Pod, log *zap.SugaredLogger) tailcfg.CapabilityVersion { + if sec == nil || pod == nil { + return tailcfg.CapabilityVersion(-1) + } + if len(sec.Data[kubetypes.KeyCapVer]) == 0 || len(sec.Data[kubetypes.KeyPodUID]) == 0 { + return tailcfg.CapabilityVersion(-1) + } + capVer, err := strconv.Atoi(string(sec.Data[kubetypes.KeyCapVer])) + if err != nil { + log.Infof("[unexpected]: unexpected capability version in proxy's state Secret, expected an integer, got %q", string(sec.Data[kubetypes.KeyCapVer])) + return tailcfg.CapabilityVersion(-1) + } + if !strings.EqualFold(string(pod.ObjectMeta.UID), string(sec.Data[kubetypes.KeyPodUID])) { + return tailcfg.CapabilityVersion(-1) + } + return tailcfg.CapabilityVersion(capVer) +} diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index cbf50c81f6630..314ac2398af65 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -320,11 +320,11 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - _, tsHost, tsIPs, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return fmt.Errorf("failed to get device ID: %w", err) } - if tsHost == "" { + if dev == nil || dev.hostname == "" { msg := "no Tailscale hostname known yet, waiting for proxy pod to finish auth" logger.Debug(msg) // No hostname yet. Wait for the proxy pod to auth. @@ -333,9 +333,9 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - logger.Debugf("setting Service LoadBalancer status to %q, %s", tsHost, strings.Join(tsIPs, ", ")) + logger.Debugf("setting Service LoadBalancer status to %q, %s", dev.hostname, strings.Join(dev.ips, ", ")) ingress := []corev1.LoadBalancerIngress{ - {Hostname: tsHost}, + {Hostname: dev.hostname}, } clusterIPAddr, err := netip.ParseAddr(svc.Spec.ClusterIP) if err != nil { @@ -343,7 +343,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga tsoperator.SetServiceCondition(svc, tsapi.ProxyReady, metav1.ConditionFalse, reasonProxyFailed, msg, a.clock, logger) return errors.New(msg) } - for _, ip := range tsIPs { + for _, ip := range dev.ips { addr, err := netip.ParseAddr(ip) if err != nil { continue diff --git a/kube/kubetypes/metrics.go b/kube/kubetypes/types.go similarity index 59% rename from kube/kubetypes/metrics.go rename to kube/kubetypes/types.go index 63325182d29c4..3c97d8c7da2c5 100644 --- a/kube/kubetypes/metrics.go +++ b/kube/kubetypes/types.go @@ -27,4 +27,19 @@ const ( MetricEgressServiceCount = "k8s_egress_service_resources" MetricProxyGroupEgressCount = "k8s_proxygroup_egress_resources" MetricProxyGroupIngressCount = "k8s_proxygroup_ingress_resources" + + // Keys that containerboot writes to state file that can be used to determine its state. + // fields set in Tailscale state Secret. These are mostly used by the Tailscale Kubernetes operator to determine + // the state of this tailscale device. + KeyDeviceID string = "device_id" // node stable ID of the device + KeyDeviceFQDN string = "device_fqdn" // device's tailnet hostname + KeyDeviceIPs string = "device_ips" // device's tailnet IPs + KeyPodUID string = "pod_uid" // Pod UID + // KeyCapVer contains Tailscale capability version of this proxy instance. + KeyCapVer string = "tailscale_capver" + // KeyHTTPSEndpoint is a name of a field that can be set to the value of any HTTPS endpoint currently exposed by + // this device to the tailnet. This is used by the Kubernetes operator Ingress proxy to communicate to the operator + // that cluster workloads behind the Ingress can now be accessed via the given DNS name over HTTPS. + KeyHTTPSEndpoint string = "https_endpoint" + ValueNoHTTPS string = "no-https" ) From 74069774bee3aeb52637a58587ddfb0369f69676 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 4 Dec 2024 08:41:37 -0800 Subject: [PATCH 166/179] net/tstun: remove tailscaled_outbound_dropped_packets_total reason=acl metric for now Updates #14280 Change-Id: Idff102b3d7650fc9dfbe0c340168806bdf542d76 Signed-off-by: Brad Fitzpatrick --- net/tstun/wrap.go | 7 ++++--- net/tstun/wrap_test.go | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index c384abf9d4bbe..deb8bc0944a37 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -876,9 +876,10 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf if filt.RunOut(p, t.filterFlags) != filter.Accept { metricPacketOutDropFilter.Add(1) - t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{ - Reason: usermetric.ReasonACL, - }, 1) + // TODO(#14280): increment a t.metrics.outboundDroppedPacketsTotal here + // once we figure out & document what labels to use for multicast, + // link-local-unicast, IP fragments, etc. But they're not + // usermetric.ReasonACL. return filter.Drop, gro } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 9ebedda837b0a..a3dfe7d86c914 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -453,7 +453,7 @@ func TestFilter(t *testing.T) { assertMetricPackets(t, "inACL", 3, metricInboundDroppedPacketsACL) assertMetricPackets(t, "inError", 0, metricInboundDroppedPacketsErr) - assertMetricPackets(t, "outACL", 1, metricOutboundDroppedPacketsACL) + assertMetricPackets(t, "outACL", 0, metricOutboundDroppedPacketsACL) } func assertMetricPackets(t *testing.T, metricName string, want, got int64) { From 7f9ebc0a83f82922787f4e8336b9f626d895a08c Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 4 Dec 2024 12:02:59 -0800 Subject: [PATCH 167/179] cmd/tailscale,net/netcheck: add debug feature to force preferred DERP This provides an interface for a user to force a preferred DERP outcome for all future netchecks that will take precedence unless the forced region is unreachable. The option does not persist and will be lost when the daemon restarts. Updates tailscale/corp#18997 Updates tailscale/corp#24755 Signed-off-by: James Tucker --- client/tailscale/localclient.go | 11 +++++++ cmd/tailscale/cli/debug.go | 25 +++++++++++++++ ipn/ipnlocal/local.go | 6 ++++ ipn/localapi/localapi.go | 7 +++++ net/netcheck/netcheck.go | 28 +++++++++++++++++ net/netcheck/netcheck_test.go | 56 ++++++++++++++++++++++++++++++++- wgengine/magicsock/magicsock.go | 8 +++++ 7 files changed, 140 insertions(+), 1 deletion(-) diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 5eb66817698b7..34c094a63fbf9 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -493,6 +493,17 @@ func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { return nil } +// DebugActionBody invokes a debug action with a body parameter, such as +// "debug-force-prefer-derp". +// These are development tools and subject to change or removal over time. +func (lc *LocalClient) DebugActionBody(ctx context.Context, action string, rbody io.Reader) error { + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, rbody) + if err != nil { + return fmt.Errorf("error %w: %s", err, body) + } + return nil +} + // DebugResultJSON invokes a debug action and returns its result as something JSON-able. // These are development tools and subject to change or removal over time. func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, error) { diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 78bd708e54fee..04b343e760e3c 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -175,6 +175,12 @@ var debugCmd = &ffcli.Command{ Exec: localAPIAction("pick-new-derp"), ShortHelp: "Switch to some other random DERP home region for a short time", }, + { + Name: "force-prefer-derp", + ShortUsage: "tailscale debug force-prefer-derp", + Exec: forcePreferDERP, + ShortHelp: "Prefer the given region ID if reachable (until restart, or 0 to clear)", + }, { Name: "force-netmap-update", ShortUsage: "tailscale debug force-netmap-update", @@ -577,6 +583,25 @@ func runDERPMap(ctx context.Context, args []string) error { return nil } +func forcePreferDERP(ctx context.Context, args []string) error { + var n int + if len(args) != 1 { + return errors.New("expected exactly one integer argument") + } + n, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("expected exactly one integer argument: %w", err) + } + b, err := json.Marshal(n) + if err != nil { + return fmt.Errorf("failed to marshal DERP region: %w", err) + } + if err := localClient.DebugActionBody(ctx, "force-prefer-derp", bytes.NewReader(b)); err != nil { + return fmt.Errorf("failed to force preferred DERP: %w", err) + } + return nil +} + func localAPIAction(action string) func(context.Context, []string) error { return func(ctx context.Context, args []string) error { if len(args) > 0 { diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 278614c0b90dd..f456d49844f1e 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -2920,6 +2920,12 @@ func (b *LocalBackend) DebugPickNewDERP() error { return b.sys.MagicSock.Get().DebugPickNewDERP() } +// DebugForcePreferDERP forwards to netcheck.DebugForcePreferDERP. +// See its docs. +func (b *LocalBackend) DebugForcePreferDERP(n int) { + b.sys.MagicSock.Get().DebugForcePreferDERP(n) +} + // send delivers n to the connected frontend and any API watchers from // LocalBackend.WatchNotifications (via the LocalAPI). // diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index ea931b0280ccf..c14a4bdf285df 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -634,6 +634,13 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { } case "pick-new-derp": err = h.b.DebugPickNewDERP() + case "force-prefer-derp": + var n int + err = json.NewDecoder(r.Body).Decode(&n) + if err != nil { + break + } + h.b.DebugForcePreferDERP(n) case "": err = fmt.Errorf("missing parameter 'action'") default: diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 0bb9305683e56..d8f5e1d4996c6 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -236,6 +236,10 @@ type Client struct { // If false, the default net.Resolver will be used, with no caching. UseDNSCache bool + // if non-zero, force this DERP region to be preferred in all reports where + // the DERP is found to be reachable. + ForcePreferredDERP int + // For tests testEnoughRegions int testCaptivePortalDelay time.Duration @@ -780,6 +784,12 @@ func (o *GetReportOpts) getLastDERPActivity(region int) time.Time { return o.GetLastDERPActivity(region) } +func (c *Client) SetForcePreferredDERP(region int) { + c.mu.Lock() + defer c.mu.Unlock() + c.ForcePreferredDERP = region +} + // GetReport gets a report. The 'opts' argument is optional and can be nil. // Callers are discouraged from passing a ctx with an arbitrary deadline as this // may cause GetReport to return prematurely before all reporting methods have @@ -1277,6 +1287,9 @@ func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) { if r.CaptivePortal != "" { fmt.Fprintf(w, " captiveportal=%v", r.CaptivePortal) } + if c.ForcePreferredDERP != 0 { + fmt.Fprintf(w, " force=%v", c.ForcePreferredDERP) + } fmt.Fprintf(w, " derp=%v", r.PreferredDERP) if r.PreferredDERP != 0 { fmt.Fprintf(w, " derpdist=") @@ -1435,6 +1448,21 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, // which undoes any region change we made above. r.PreferredDERP = prevDERP } + if c.ForcePreferredDERP != 0 { + // If the forced DERP region probed successfully, or has recent traffic, + // use it. + _, haveLatencySample := r.RegionLatency[c.ForcePreferredDERP] + var recentActivity bool + if lastHeard := rs.opts.getLastDERPActivity(c.ForcePreferredDERP); !lastHeard.IsZero() { + now := c.timeNow() + recentActivity = lastHeard.After(rs.start) + recentActivity = recentActivity || lastHeard.After(now.Add(-PreferredDERPFrameTime)) + } + + if haveLatencySample || recentActivity { + r.PreferredDERP = c.ForcePreferredDERP + } + } } func updateLatency(m map[int]time.Duration, regionID int, d time.Duration) { diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 23891efcc6e48..88c19623d0f0a 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -201,6 +201,7 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { steps []step homeParams *tailcfg.DERPHomeParams opts *GetReportOpts + forcedDERP int // if non-zero, force this DERP to be the preferred one wantDERP int // want PreferredDERP on final step wantPrevLen int // wanted len(c.prev) }{ @@ -366,12 +367,65 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { wantPrevLen: 2, wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1 }, + { + name: "forced_two", + steps: []step{ + {time.Second, report("d1", 2, "d2", 3)}, + {2 * time.Second, report("d1", 4, "d2", 3)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_unavailable", + steps: []step{ + {time.Second, report("d1", 2, "d2", 1)}, + {2 * time.Second, report("d1", 4)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, + { + name: "forced_two_no_probe_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {2 * time.Second, report("d1", 4)}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime.Add(time.Second), + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_no_probe_no_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {PreferredDERPFrameTime + time.Second, report("d1", 4)}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime, + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fakeTime := startTime c := &Client{ - TimeNow: func() time.Time { return fakeTime }, + TimeNow: func() time.Time { return fakeTime }, + ForcePreferredDERP: tt.forcedDERP, } dm := &tailcfg.DERPMap{HomeParams: tt.homeParams} rs := &reportState{ diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 805716e61daae..bff905caa5ae4 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -3013,6 +3013,14 @@ func (c *Conn) DebugPickNewDERP() error { return errors.New("too few regions") } +func (c *Conn) DebugForcePreferDERP(n int) { + c.mu.Lock() + defer c.mu.Unlock() + + c.logf("magicsock: [debug] force preferred DERP set to: %d", n) + c.netChecker.SetForcePreferredDERP(n) +} + // portableTrySetSocketBuffer sets SO_SNDBUF and SO_RECVBUF on pconn to socketBufferSize, // logging an error if it occurs. func portableTrySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { From df94a1487076f744742d5b5c3a234d628bfd2bb5 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Thu, 5 Dec 2024 12:11:22 +0000 Subject: [PATCH 168/179] cmd/k8s-operator: don't error for transient failures (#14073) Every so often, the ProxyGroup and other controllers lose an optimistic locking race with other controllers that update the objects they create. Stop treating this as an error event, and instead just log an info level log line for it. Fixes #14072 Signed-off-by: Tom Proctor --- cmd/k8s-operator/connector.go | 17 +++++++++++++---- cmd/k8s-operator/dnsrecords.go | 11 ++++++++++- cmd/k8s-operator/egress-services.go | 10 +++++++++- cmd/k8s-operator/ingress.go | 10 +++++++++- cmd/k8s-operator/nameserver.go | 8 +++++++- cmd/k8s-operator/proxygroup.go | 18 +++++++++++++++--- cmd/k8s-operator/svc.go | 10 +++++++++- cmd/k8s-operator/tsrecorder.go | 17 ++++++++++++----- 8 files changed, 84 insertions(+), 17 deletions(-) diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 1cce02fbba974..c243036cbabd9 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -10,6 +10,7 @@ import ( "fmt" "net/netip" "slices" + "strings" "sync" "time" @@ -35,6 +36,7 @@ import ( const ( reasonConnectorCreationFailed = "ConnectorCreationFailed" + reasonConnectorCreating = "ConnectorCreating" reasonConnectorCreated = "ConnectorCreated" reasonConnectorInvalid = "ConnectorInvalid" @@ -134,17 +136,24 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque } if err := a.validate(cn); err != nil { - logger.Errorf("error validating Connector spec: %w", err) message := fmt.Sprintf(messageConnectorInvalid, err) a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorInvalid, message) return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorInvalid, message) } if err = a.maybeProvisionConnector(ctx, logger, cn); err != nil { - logger.Errorf("error creating Connector resources: %w", err) + reason := reasonConnectorCreationFailed message := fmt.Sprintf(messageConnectorCreationFailed, err) - a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorCreationFailed, message) - return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorCreationFailed, message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonConnectorCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + a.recorder.Eventf(cn, corev1.EventTypeWarning, reason, message) + } + + return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reason, message) } logger.Info("Connector resources synced") diff --git a/cmd/k8s-operator/dnsrecords.go b/cmd/k8s-operator/dnsrecords.go index bba87bf255910..f91dd49ec255e 100644 --- a/cmd/k8s-operator/dnsrecords.go +++ b/cmd/k8s-operator/dnsrecords.go @@ -10,6 +10,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" @@ -98,7 +99,15 @@ func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile. return reconcile.Result{}, nil } - return reconcile.Result{}, dnsRR.maybeProvision(ctx, headlessSvc, logger) + if err := dnsRR.maybeProvision(ctx, headlessSvc, logger); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeProvision ensures that dnsrecords ConfigMap contains a record for the diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index a08c0b71563f0..7544376fb2e65 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -156,7 +156,15 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re return res, err } - return res, esr.maybeProvision(ctx, svc, l) + if err := esr.maybeProvision(ctx, svc, l); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + l.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return res, nil } func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (err error) { diff --git a/cmd/k8s-operator/ingress.go b/cmd/k8s-operator/ingress.go index 749869b2264eb..3eb47dfb00ad3 100644 --- a/cmd/k8s-operator/ingress.go +++ b/cmd/k8s-operator/ingress.go @@ -76,7 +76,15 @@ func (a *IngressReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, a.maybeCleanup(ctx, logger, ing) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, ing) + if err := a.maybeProvision(ctx, logger, ing); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, ing *networkingv1.Ingress) error { diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 6a9a6be935642..ef0762a1234e6 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "slices" + "strings" "sync" _ "embed" @@ -131,7 +132,12 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ } } if err := a.maybeProvision(ctx, &dnsCfg, logger); err != nil { - return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + return reconcile.Result{}, nil + } else { + return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + } } a.mu.Lock() diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 344cd9ae065f4..39b7ccc01f6fb 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "slices" + "strings" "sync" "github.com/pkg/errors" @@ -45,6 +46,9 @@ const ( reasonProxyGroupReady = "ProxyGroupReady" reasonProxyGroupCreating = "ProxyGroupCreating" reasonProxyGroupInvalid = "ProxyGroupInvalid" + + // Copied from k8s.io/apiserver/pkg/registry/generic/registry/store.go@cccad306d649184bf2a0e319ba830c53f65c445c + optimisticLockErrorMsg = "the object has been modified; please apply your changes to the latest version and try again" ) var gaugeProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupEgressCount) @@ -166,9 +170,17 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ } if err = r.maybeProvision(ctx, pg, proxyClass); err != nil { - err = fmt.Errorf("error provisioning ProxyGroup resources: %w", err) - r.recorder.Eventf(pg, corev1.EventTypeWarning, reasonProxyGroupCreationFailed, err.Error()) - return setStatusReady(pg, metav1.ConditionFalse, reasonProxyGroupCreationFailed, err.Error()) + reason := reasonProxyGroupCreationFailed + msg := fmt.Sprintf("error provisioning ProxyGroup resources: %s", err) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonProxyGroupCreating + msg = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(msg) + } else { + r.recorder.Eventf(pg, corev1.EventTypeWarning, reason, msg) + } + return setStatusReady(pg, metav1.ConditionFalse, reason, msg) } desiredReplicas := int(pgReplicas(pg)) diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index 314ac2398af65..70c810b256c99 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -121,7 +121,15 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, a.maybeCleanup(ctx, logger, svc) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, svc) + if err := a.maybeProvision(ctx, logger, svc); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeCleanup removes any existing resources related to serving svc over tailscale. diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index 4445578a63920..44ce731fe82d6 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "slices" + "strings" "sync" "github.com/pkg/errors" @@ -38,6 +39,7 @@ import ( const ( reasonRecorderCreationFailed = "RecorderCreationFailed" + reasonRecorderCreating = "RecorderCreating" reasonRecorderCreated = "RecorderCreated" reasonRecorderInvalid = "RecorderInvalid" @@ -119,23 +121,28 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques logger.Infof("ensuring Recorder is set up") tsr.Finalizers = append(tsr.Finalizers, FinalizerName) if err := r.Update(ctx, tsr); err != nil { - logger.Errorf("error adding finalizer: %w", err) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, reasonRecorderCreationFailed) } } if err := r.validate(tsr); err != nil { - logger.Errorf("error validating Recorder spec: %w", err) message := fmt.Sprintf("Recorder is invalid: %s", err) r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderInvalid, message) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderInvalid, message) } if err = r.maybeProvision(ctx, tsr); err != nil { - logger.Errorf("error creating Recorder resources: %w", err) + reason := reasonRecorderCreationFailed message := fmt.Sprintf("failed creating Recorder: %s", err) - r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) - return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonRecorderCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) + } + return setStatusReady(tsr, metav1.ConditionFalse, reason, message) } logger.Info("Recorder resources synced") From 614c6126435f2f63090586a1a5835379f5d77874 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Thu, 5 Dec 2024 13:21:03 +0000 Subject: [PATCH 169/179] net/netcheck: preserve STUN port defaulting to 3478 (#14289) Updates tailscale/tailscale#14287 Signed-off-by: Irbe Krumina --- net/netcheck/netcheck.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index d8f5e1d4996c6..7930f88f6dce6 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -1570,6 +1570,9 @@ func (c *Client) nodeAddrPort(ctx context.Context, n *tailcfg.DERPNode, port int if port < 0 || port > 1<<16-1 { return zero, false } + if port == 0 { + port = 3478 + } if n.STUNTestIP != "" { ip, err := netip.ParseAddr(n.STUNTestIP) if err != nil { From 87546a5edf6b6503a87eeb2d666baba57398a066 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Dec 2024 09:40:40 -0800 Subject: [PATCH 170/179] cmd/derper: allow absent SNI when using manual certs and IP literal for hostname Updates #11776 Change-Id: I81756415feb630da093833accc3074903ebd84a7 Signed-off-by: Brad Fitzpatrick --- cmd/derper/cert.go | 14 ++++-- cmd/derper/cert_test.go | 97 +++++++++++++++++++++++++++++++++++++++ cmd/derper/derper.go | 2 +- cmd/derper/derper_test.go | 2 - 4 files changed, 108 insertions(+), 7 deletions(-) create mode 100644 cmd/derper/cert_test.go diff --git a/cmd/derper/cert.go b/cmd/derper/cert.go index db84aa515d257..623fa376f452c 100644 --- a/cmd/derper/cert.go +++ b/cmd/derper/cert.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "errors" "fmt" + "net" "net/http" "path/filepath" "regexp" @@ -53,8 +54,9 @@ func certProviderByCertMode(mode, dir, hostname string) (certProvider, error) { } type manualCertManager struct { - cert *tls.Certificate - hostname string + cert *tls.Certificate + hostname string // hostname or IP address of server + noHostname bool // whether hostname is an IP address } // NewManualCertManager returns a cert provider which read certificate by given hostname on create. @@ -74,7 +76,11 @@ func NewManualCertManager(certdir, hostname string) (certProvider, error) { if err := x509Cert.VerifyHostname(hostname); err != nil { return nil, fmt.Errorf("cert invalid for hostname %q: %w", hostname, err) } - return &manualCertManager{cert: &cert, hostname: hostname}, nil + return &manualCertManager{ + cert: &cert, + hostname: hostname, + noHostname: net.ParseIP(hostname) != nil, + }, nil } func (m *manualCertManager) TLSConfig() *tls.Config { @@ -88,7 +94,7 @@ func (m *manualCertManager) TLSConfig() *tls.Config { } func (m *manualCertManager) getCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - if hi.ServerName != m.hostname { + if hi.ServerName != m.hostname && !m.noHostname { return nil, fmt.Errorf("cert mismatch with hostname: %q", hi.ServerName) } diff --git a/cmd/derper/cert_test.go b/cmd/derper/cert_test.go new file mode 100644 index 0000000000000..a379e5c04c32e --- /dev/null +++ b/cmd/derper/cert_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "testing" + "time" +) + +// Verify that in --certmode=manual mode, we can use a bare IP address +// as the --hostname and that GetCertificate will return it. +func TestCertIP(t *testing.T) { + dir := t.TempDir() + const hostname = "1.2.3.4" + + priv, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatal(err) + } + ip := net.ParseIP(hostname) + if ip == nil { + t.Fatalf("invalid IP address %q", hostname) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Tailscale Test Corp"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + certOut, err := os.Create(filepath.Join(dir, hostname+".crt")) + if err != nil { + t.Fatal(err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + t.Fatalf("Failed to write data to cert.pem: %v", err) + } + if err := certOut.Close(); err != nil { + t.Fatalf("Error closing cert.pem: %v", err) + } + + keyOut, err := os.OpenFile(filepath.Join(dir, hostname+".key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + t.Fatal(err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + t.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + t.Fatalf("Failed to write data to key.pem: %v", err) + } + if err := keyOut.Close(); err != nil { + t.Fatalf("Error closing key.pem: %v", err) + } + + cp, err := certProviderByCertMode("manual", dir, hostname) + if err != nil { + t.Fatal(err) + } + back, err := cp.TLSConfig().GetCertificate(&tls.ClientHelloInfo{ + ServerName: "", // no SNI + }) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if back == nil { + t.Fatalf("GetCertificate returned nil") + } +} diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 51be3abbe3b93..6e24e0ab14b3d 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -58,7 +58,7 @@ var ( configPath = flag.String("c", "", "config file path") certMode = flag.String("certmode", "letsencrypt", "mode for getting a cert. possible options: manual, letsencrypt") certDir = flag.String("certdir", tsweb.DefaultCertDir("derper-certs"), "directory to store LetsEncrypt certs, if addr's port is :443") - hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443") + hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443. When --certmode=manual, this can be an IP address to avoid SNI checks") runSTUN = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.") runDERP = flag.Bool("derp", true, "whether to run a DERP server. The only reason to set this false is if you're decommissioning a server but want to keep its bootstrap DNS functionality still running.") diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 08d2e9cbf97c2..6dce1fcdfebdd 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -6,7 +6,6 @@ package main import ( "bytes" "context" - "fmt" "net/http" "net/http/httptest" "strings" @@ -138,5 +137,4 @@ func TestTemplate(t *testing.T) { if !strings.Contains(str, "Debug info") { t.Error("Output is missing debug info") } - fmt.Println(buf.String()) } From 0267fe83b200f1702a2fa0a395442c02a053fadb Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 5 Dec 2024 13:16:48 -0600 Subject: [PATCH 171/179] VERSION.txt: this is v1.78.0 Signed-off-by: Nick Khyl --- .bencher/config.yaml | 2 +- .gitattributes | 4 +- .github/ISSUE_TEMPLATE/bug_report.yml | 162 +- .github/ISSUE_TEMPLATE/config.yml | 14 +- .github/ISSUE_TEMPLATE/feature_request.yml | 84 +- .github/dependabot.yml | 42 +- AUTHORS | 34 +- CODEOWNERS | 2 +- CODE_OF_CONDUCT.md | 270 ++-- LICENSE | 56 +- PATENTS | 48 +- SECURITY.md | 16 +- VERSION.txt | 2 +- atomicfile/atomicfile.go | 102 +- atomicfile/atomicfile_test.go | 94 +- chirp/chirp.go | 326 ++-- chirp/chirp_test.go | 384 ++--- client/tailscale/apitype/controltype.go | 38 +- client/tailscale/dns.go | 466 +++--- client/tailscale/example/servetls/servetls.go | 56 +- client/tailscale/keys.go | 332 ++-- client/tailscale/routes.go | 190 +-- client/tailscale/tailnet.go | 84 +- client/web/qnap.go | 254 +-- client/web/src/assets/icons/arrow-right.svg | 8 +- .../web/src/assets/icons/arrow-up-circle.svg | 10 +- client/web/src/assets/icons/check-circle.svg | 8 +- client/web/src/assets/icons/check.svg | 6 +- client/web/src/assets/icons/chevron-down.svg | 6 +- client/web/src/assets/icons/eye.svg | 22 +- client/web/src/assets/icons/search.svg | 8 +- .../web/src/assets/icons/tailscale-icon.svg | 36 +- .../web/src/assets/icons/tailscale-logo.svg | 40 +- client/web/src/assets/icons/user.svg | 8 +- client/web/src/assets/icons/x-circle.svg | 10 +- client/web/synology.go | 118 +- clientupdate/distsign/distsign.go | 972 +++++------ clientupdate/distsign/roots.go | 108 +- clientupdate/distsign/roots/crawshaw-root.pem | 6 +- .../roots/distsign-prod-root-1-pub.pem | 6 +- clientupdate/distsign/roots_test.go | 32 +- cmd/addlicense/main.go | 146 +- cmd/cloner/cloner_test.go | 120 +- cmd/containerboot/test_tailscale.sh | 16 +- cmd/containerboot/test_tailscaled.sh | 76 +- cmd/get-authkey/.gitignore | 2 +- cmd/gitops-pusher/.gitignore | 2 +- cmd/gitops-pusher/README.md | 96 +- cmd/gitops-pusher/cache.go | 132 +- cmd/gitops-pusher/gitops-pusher_test.go | 110 +- cmd/k8s-operator/deploy/chart/.helmignore | 46 +- cmd/k8s-operator/deploy/chart/Chart.yaml | 58 +- .../chart/templates/apiserverproxy-rbac.yaml | 52 +- .../deploy/chart/templates/oauth-secret.yaml | 26 +- .../deploy/manifests/authproxy-rbac.yaml | 46 +- cmd/mkmanifest/main.go | 102 +- cmd/mkpkg/main.go | 268 ++-- cmd/mkversion/mkversion.go | 88 +- cmd/nardump/README.md | 14 +- cmd/nardump/nardump.go | 368 ++--- cmd/nginx-auth/.gitignore | 8 +- cmd/nginx-auth/README.md | 322 ++-- cmd/nginx-auth/deb/postinst.sh | 28 +- cmd/nginx-auth/deb/postrm.sh | 38 +- cmd/nginx-auth/deb/prerm.sh | 16 +- cmd/nginx-auth/mkdeb.sh | 64 +- cmd/nginx-auth/nginx-auth.go | 256 +-- cmd/nginx-auth/rpm/postrm.sh | 18 +- cmd/nginx-auth/rpm/prerm.sh | 18 +- cmd/nginx-auth/tailscale.nginx-auth.service | 22 +- cmd/nginx-auth/tailscale.nginx-auth.socket | 16 +- cmd/pgproxy/README.md | 84 +- cmd/printdep/printdep.go | 82 +- cmd/sniproxy/.gitignore | 2 +- cmd/sniproxy/handlers_test.go | 318 ++-- cmd/sniproxy/server.go | 654 ++++---- cmd/sniproxy/server_test.go | 190 +-- cmd/sniproxy/sniproxy.go | 582 +++---- cmd/speedtest/speedtest.go | 242 +-- cmd/ssh-auth-none-demo/ssh-auth-none-demo.go | 374 ++--- cmd/sync-containers/main.go | 428 ++--- cmd/tailscale/cli/diag.go | 148 +- cmd/tailscale/cli/diag_other.go | 30 +- cmd/tailscale/cli/set_test.go | 262 +-- cmd/tailscale/cli/ssh_exec.go | 48 +- cmd/tailscale/cli/ssh_exec_js.go | 32 +- cmd/tailscale/cli/ssh_exec_windows.go | 74 +- cmd/tailscale/cli/ssh_unix.go | 98 +- cmd/tailscale/cli/web_test.go | 90 +- cmd/tailscale/generate.go | 16 +- cmd/tailscale/tailscale.go | 52 +- cmd/tailscale/windows-manifest.xml | 26 +- cmd/tailscaled/childproc/childproc.go | 38 +- cmd/tailscaled/generate.go | 16 +- cmd/tailscaled/install_darwin.go | 398 ++--- cmd/tailscaled/install_windows.go | 248 +-- cmd/tailscaled/proxy.go | 160 +- cmd/tailscaled/sigpipe.go | 24 +- cmd/tailscaled/tailscaled.defaults | 16 +- cmd/tailscaled/tailscaled.openrc | 50 +- cmd/tailscaled/tailscaled_bird.go | 34 +- cmd/tailscaled/tailscaled_notwindows.go | 28 +- cmd/tailscaled/windows-manifest.xml | 26 +- cmd/tailscaled/with_cli.go | 46 +- cmd/testwrapper/args_test.go | 194 +-- cmd/testwrapper/flakytest/flakytest.go | 88 +- cmd/testwrapper/flakytest/flakytest_test.go | 86 +- cmd/tsconnect/.gitignore | 6 +- cmd/tsconnect/README.md | 98 +- cmd/tsconnect/README.pkg.md | 6 +- cmd/tsconnect/build-pkg.go | 198 +-- cmd/tsconnect/dev-pkg.go | 36 +- cmd/tsconnect/dev.go | 36 +- cmd/tsconnect/dist/placeholder | 4 +- cmd/tsconnect/index.html | 40 +- cmd/tsconnect/package.json | 50 +- cmd/tsconnect/package.json.tmpl | 32 +- cmd/tsconnect/serve.go | 288 ++-- cmd/tsconnect/src/app/app.tsx | 294 ++-- cmd/tsconnect/src/app/go-panic-display.tsx | 40 +- cmd/tsconnect/src/app/header.tsx | 74 +- cmd/tsconnect/src/app/index.css | 148 +- cmd/tsconnect/src/app/index.ts | 72 +- cmd/tsconnect/src/app/ssh.tsx | 314 ++-- cmd/tsconnect/src/app/url-display.tsx | 62 +- cmd/tsconnect/src/lib/js-state-store.ts | 26 +- cmd/tsconnect/src/pkg/pkg.css | 16 +- cmd/tsconnect/src/pkg/pkg.ts | 80 +- cmd/tsconnect/src/types/esbuild.d.ts | 28 +- cmd/tsconnect/src/types/wasm_js.d.ts | 206 +-- cmd/tsconnect/tailwind.config.js | 16 +- cmd/tsconnect/tsconfig.json | 30 +- cmd/tsconnect/tsconnect.go | 142 +- cmd/tsconnect/yarn.lock | 1426 ++++++++--------- cmd/tsshd/tsshd.go | 24 +- control/controlbase/conn.go | 816 +++++----- control/controlbase/handshake.go | 988 ++++++------ control/controlbase/interop_test.go | 512 +++--- control/controlbase/messages.go | 174 +- control/controlclient/sign.go | 84 +- control/controlclient/sign_supported_test.go | 472 +++--- control/controlclient/sign_unsupported.go | 32 +- control/controlclient/status.go | 250 +-- control/controlhttp/client_common.go | 34 +- derp/README.md | 120 +- derp/testdata/example_ss.txt | 16 +- disco/disco_fuzzer.go | 34 +- disco/disco_test.go | 236 +-- disco/pcap.go | 80 +- docs/bird/sample_bird.conf | 32 +- docs/bird/tailscale_bird.conf | 8 +- docs/k8s/Makefile | 50 +- docs/k8s/rolebinding.yaml | 26 +- docs/k8s/sa.yaml | 12 +- docs/sysv/tailscale.init | 126 +- doctor/doctor.go | 158 +- doctor/doctor_test.go | 98 +- doctor/permissions/permissions_bsd.go | 46 +- doctor/permissions/permissions_linux.go | 124 +- doctor/permissions/permissions_other.go | 34 +- doctor/permissions/permissions_test.go | 24 +- doctor/routetable/routetable.go | 68 +- envknob/envknob_nottest.go | 32 +- envknob/envknob_testable.go | 46 +- envknob/logknob/logknob.go | 170 +- envknob/logknob/logknob_test.go | 204 +-- gomod_test.go | 50 +- hostinfo/hostinfo_darwin.go | 42 +- hostinfo/hostinfo_freebsd.go | 128 +- hostinfo/hostinfo_test.go | 102 +- hostinfo/hostinfo_uname.go | 76 +- hostinfo/wol.go | 212 +-- ipn/ipnlocal/breaktcp_darwin.go | 60 +- ipn/ipnlocal/breaktcp_linux.go | 60 +- ipn/ipnlocal/expiry_test.go | 602 +++---- ipn/ipnlocal/peerapi_h2c.go | 40 +- ipn/ipnlocal/testdata/example.com-key.pem | 54 +- ipn/ipnlocal/testdata/example.com.pem | 50 +- ipn/ipnlocal/testdata/rootCA.pem | 58 +- ipn/ipnserver/proxyconnect_js.go | 20 +- ipn/ipnserver/server_test.go | 92 +- ipn/localapi/disabled_stubs.go | 30 +- ipn/localapi/pprof.go | 56 +- ipn/policy/policy.go | 94 +- ipn/store/awsstore/store_aws.go | 372 ++--- ipn/store/awsstore/store_aws_stub.go | 36 +- ipn/store/awsstore/store_aws_test.go | 328 ++-- ipn/store/stores_test.go | 358 ++--- ipn/store_test.go | 96 +- jsondb/db.go | 114 +- jsondb/db_test.go | 110 +- licenses/licenses.go | 42 +- log/filelogger/log.go | 456 +++--- log/filelogger/log_test.go | 54 +- logpolicy/logpolicy_test.go | 72 +- logtail/.gitignore | 12 +- logtail/README.md | 18 +- logtail/api.md | 388 ++--- logtail/example/logreprocess/demo.sh | 172 +- logtail/example/logreprocess/logreprocess.go | 230 +-- logtail/example/logtail/logtail.go | 92 +- logtail/filch/filch.go | 568 +++---- logtail/filch/filch_stub.go | 46 +- logtail/filch/filch_unix.go | 60 +- logtail/filch/filch_windows.go | 86 +- metrics/fds_linux.go | 82 +- metrics/fds_notlinux.go | 16 +- metrics/metrics.go | 326 ++-- net/art/art_test.go | 40 +- net/art/table.go | 1282 +++++++-------- net/dns/debian_resolvconf.go | 368 ++--- net/dns/direct_notlinux.go | 20 +- net/dns/flush_default.go | 20 +- net/dns/ini.go | 60 +- net/dns/ini_test.go | 76 +- net/dns/noop.go | 34 +- net/dns/resolvconf-workaround.sh | 124 +- net/dns/resolvconf.go | 60 +- net/dns/resolvconffile/resolvconffile.go | 248 +-- net/dns/resolvconfpath_default.go | 22 +- net/dns/resolvconfpath_gokrazy.go | 22 +- net/dns/resolver/doh_test.go | 198 +-- net/dns/resolver/macios_ext.go | 52 +- net/dns/resolver/tsdns_server_test.go | 666 ++++---- net/dns/utf.go | 110 +- net/dns/utf_test.go | 48 +- net/dnscache/dnscache_test.go | 484 +++--- net/dnscache/messagecache_test.go | 582 +++---- net/dnsfallback/update-dns-fallbacks.go | 90 +- net/memnet/conn.go | 228 +-- net/memnet/conn_test.go | 42 +- net/memnet/listener.go | 200 +-- net/memnet/listener_test.go | 66 +- net/memnet/memnet.go | 16 +- net/memnet/pipe.go | 488 +++--- net/memnet/pipe_test.go | 234 +-- net/netaddr/netaddr.go | 98 +- net/neterror/neterror.go | 164 +- net/neterror/neterror_linux.go | 52 +- net/neterror/neterror_linux_test.go | 108 +- net/neterror/neterror_windows.go | 32 +- net/netkernelconf/netkernelconf.go | 10 +- net/netknob/netknob.go | 58 +- net/netmon/netmon_darwin_test.go | 54 +- net/netmon/netmon_freebsd.go | 112 +- net/netmon/netmon_linux.go | 580 +++---- net/netmon/netmon_polling.go | 42 +- net/netmon/polling.go | 172 +- net/netns/netns_android.go | 150 +- net/netns/netns_default.go | 44 +- net/netns/netns_linux_test.go | 28 +- net/netns/netns_test.go | 156 +- net/netns/socks.go | 38 +- net/netstat/netstat.go | 70 +- net/netstat/netstat_noimpl.go | 28 +- net/netstat/netstat_test.go | 42 +- net/packet/doc.go | 30 +- net/packet/header.go | 132 +- net/packet/icmp.go | 56 +- net/packet/icmp6_test.go | 158 +- net/packet/ip4.go | 232 +-- net/packet/ip6.go | 152 +- net/packet/tsmp_test.go | 146 +- net/packet/udp4.go | 116 +- net/packet/udp6.go | 108 +- net/ping/ping.go | 686 ++++---- net/ping/ping_test.go | 700 ++++---- net/portmapper/pcp_test.go | 124 +- net/proxymux/mux.go | 288 ++-- net/routetable/routetable_darwin.go | 72 +- net/routetable/routetable_freebsd.go | 56 +- net/routetable/routetable_other.go | 34 +- net/sockstats/sockstats.go | 242 +-- net/sockstats/sockstats_noop.go | 76 +- net/sockstats/sockstats_tsgo_darwin.go | 60 +- net/speedtest/speedtest.go | 174 +- net/speedtest/speedtest_client.go | 82 +- net/speedtest/speedtest_server.go | 292 ++-- net/speedtest/speedtest_test.go | 166 +- net/stun/stun.go | 624 ++++---- net/stun/stun_fuzzer.go | 24 +- net/tcpinfo/tcpinfo.go | 102 +- net/tcpinfo/tcpinfo_darwin.go | 66 +- net/tcpinfo/tcpinfo_linux.go | 66 +- net/tcpinfo/tcpinfo_other.go | 30 +- net/tlsdial/deps_test.go | 16 +- net/tsdial/dnsmap_test.go | 250 +-- net/tsdial/dohclient.go | 200 +-- net/tsdial/dohclient_test.go | 62 +- net/tshttpproxy/mksyscall.go | 22 +- net/tshttpproxy/tshttpproxy_linux.go | 48 +- net/tshttpproxy/tshttpproxy_synology_test.go | 752 ++++----- net/tshttpproxy/tshttpproxy_windows.go | 552 +++---- net/tstun/fake.go | 116 +- net/tstun/ifstatus_noop.go | 36 +- net/tstun/ifstatus_windows.go | 218 +-- net/tstun/linkattrs_linux.go | 126 +- net/tstun/linkattrs_notlinux.go | 24 +- net/tstun/mtu.go | 322 ++-- net/tstun/mtu_test.go | 198 +-- net/tstun/tun_linux.go | 206 +-- net/tstun/tun_macos.go | 50 +- net/tstun/tun_notwindows.go | 24 +- packages/deb/deb.go | 364 ++--- packages/deb/deb_test.go | 410 ++--- paths/migrate.go | 116 +- paths/paths.go | 184 +-- paths/paths_windows.go | 200 +-- portlist/clean.go | 72 +- portlist/clean_test.go | 114 +- portlist/netstat_test.go | 184 +-- portlist/poller.go | 244 +-- portlist/portlist.go | 160 +- portlist/portlist_macos.go | 460 +++--- portlist/portlist_windows.go | 206 +-- posture/serialnumber_macos.go | 148 +- posture/serialnumber_notmacos_test.go | 76 +- posture/serialnumber_test.go | 32 +- pull-toolchain.sh | 32 +- release/deb/debian.postrm.sh | 34 +- release/deb/debian.prerm.sh | 14 +- release/dist/memoize.go | 172 +- release/dist/synology/files/Tailscale.sc | 10 +- release/dist/synology/files/config | 22 +- release/dist/synology/files/index.cgi | 4 +- release/dist/synology/files/logrotate-dsm6 | 16 +- release/dist/synology/files/logrotate-dsm7 | 16 +- release/dist/synology/files/privilege-dsm6 | 14 +- release/dist/synology/files/privilege-dsm7 | 14 +- .../files/privilege-dsm7.for-package-center | 26 +- release/dist/synology/files/resource | 20 +- .../dist/synology/files/scripts/postupgrade | 4 +- .../dist/synology/files/scripts/preupgrade | 4 +- .../synology/files/scripts/start-stop-status | 258 +-- release/dist/unixpkgs/pkgs.go | 944 +++++------ release/dist/unixpkgs/targets.go | 254 +-- release/release.go | 30 +- release/rpm/rpm.postinst.sh | 82 +- release/rpm/rpm.postrm.sh | 16 +- release/rpm/rpm.prerm.sh | 16 +- safesocket/safesocket_test.go | 24 +- smallzstd/testdata | 28 +- smallzstd/zstd.go | 156 +- syncs/locked.go | 64 +- syncs/locked_test.go | 240 +-- syncs/shardedmap.go | 276 ++-- syncs/shardedmap_test.go | 162 +- tailcfg/proto_port_range.go | 374 ++--- tailcfg/proto_port_range_test.go | 262 +-- tailcfg/tka.go | 528 +++--- taildrop/delete.go | 410 ++--- taildrop/delete_test.go | 304 ++-- taildrop/resume_test.go | 148 +- taildrop/retrieve.go | 356 ++-- tempfork/gliderlabs/ssh/LICENSE | 54 +- tempfork/gliderlabs/ssh/README.md | 192 +-- tempfork/gliderlabs/ssh/agent.go | 166 +- tempfork/gliderlabs/ssh/conn.go | 110 +- tempfork/gliderlabs/ssh/context.go | 328 ++-- tempfork/gliderlabs/ssh/context_test.go | 98 +- tempfork/gliderlabs/ssh/doc.go | 90 +- tempfork/gliderlabs/ssh/example_test.go | 100 +- tempfork/gliderlabs/ssh/options.go | 168 +- tempfork/gliderlabs/ssh/options_test.go | 222 +-- tempfork/gliderlabs/ssh/server.go | 918 +++++------ tempfork/gliderlabs/ssh/server_test.go | 256 +-- tempfork/gliderlabs/ssh/session.go | 772 ++++----- tempfork/gliderlabs/ssh/session_test.go | 880 +++++----- tempfork/gliderlabs/ssh/ssh.go | 312 ++-- tempfork/gliderlabs/ssh/ssh_test.go | 34 +- tempfork/gliderlabs/ssh/tcpip.go | 386 ++--- tempfork/gliderlabs/ssh/tcpip_test.go | 170 +- tempfork/gliderlabs/ssh/util.go | 314 ++-- tempfork/gliderlabs/ssh/wrap.go | 66 +- tempfork/heap/heap.go | 242 +-- tka/aum_test.go | 506 +++--- tka/builder.go | 360 ++--- tka/builder_test.go | 540 +++---- tka/deeplink.go | 442 ++--- tka/deeplink_test.go | 104 +- tka/key.go | 318 ++-- tka/key_test.go | 194 +-- tka/state.go | 630 ++++---- tka/state_test.go | 520 +++--- tka/sync_test.go | 754 ++++----- tka/tailchonk_test.go | 1386 ++++++++-------- tka/tka_test.go | 1308 +++++++-------- tool/binaryen.rev | 2 +- tool/go | 14 +- tool/gocross/env.go | 262 +-- tool/gocross/env_test.go | 198 +-- tool/gocross/exec_other.go | 40 +- tool/gocross/exec_unix.go | 24 +- tool/helm | 138 +- tool/helm.rev | 2 +- tool/node | 130 +- tool/wasm-opt | 148 +- tool/yarn | 86 +- tool/yarn.rev | 2 +- tsnet/example/tshello/tshello.go | 120 +- .../tsnet-http-client/tsnet-http-client.go | 88 +- tsnet/example/web-client/web-client.go | 92 +- tsnet/example_tshello_test.go | 144 +- tstest/allocs.go | 100 +- tstest/archtest/qemu_test.go | 146 +- tstest/clock.go | 1388 ++++++++-------- tstest/deptest/deptest_test.go | 20 +- tstest/integration/gen_deps.go | 130 +- tstest/integration/vms/README.md | 190 +-- tstest/integration/vms/distros.hujson | 78 +- tstest/integration/vms/distros_test.go | 28 +- tstest/integration/vms/dns_tester.go | 108 +- tstest/integration/vms/doc.go | 12 +- tstest/integration/vms/harness_test.go | 484 +++--- tstest/integration/vms/nixos_test.go | 462 +++--- tstest/integration/vms/regex_flag.go | 58 +- tstest/integration/vms/regex_flag_test.go | 42 +- tstest/integration/vms/runner.nix | 178 +- tstest/integration/vms/squid.conf | 76 +- tstest/integration/vms/top_level_test.go | 248 +-- tstest/integration/vms/udp_tester.go | 154 +- tstest/log_test.go | 94 +- tstest/natlab/firewall.go | 312 ++-- tstest/natlab/nat.go | 504 +++--- tstest/tstest.go | 190 +-- tstest/tstest_test.go | 48 +- tstime/mono/mono.go | 254 +-- tstime/rate/rate.go | 180 +-- tstime/tstime.go | 370 ++--- tstime/tstime_test.go | 72 +- tsweb/debug_test.go | 416 ++--- tsweb/promvarz/promvarz_test.go | 76 +- types/appctype/appconnector_test.go | 156 +- types/dnstype/dnstype.go | 136 +- types/empty/message.go | 26 +- types/flagtype/flagtype.go | 90 +- types/ipproto/ipproto.go | 398 ++--- types/key/chal.go | 182 +-- types/key/control.go | 136 +- types/key/control_test.go | 76 +- types/key/disco_test.go | 166 +- types/key/machine.go | 528 +++--- types/key/machine_test.go | 238 +-- types/key/nl_test.go | 96 +- types/lazy/unsync.go | 198 +-- types/lazy/unsync_test.go | 280 ++-- types/logger/rusage.go | 46 +- types/logger/rusage_stub.go | 22 +- types/logger/rusage_syscall.go | 58 +- types/logger/tokenbucket.go | 126 +- types/netlogtype/netlogtype.go | 200 +-- types/netlogtype/netlogtype_test.go | 78 +- types/netmap/netmap_test.go | 636 ++++---- types/nettype/nettype.go | 130 +- types/preftype/netfiltermode.go | 92 +- types/ptr/ptr.go | 20 +- types/structs/structs.go | 30 +- types/tkatype/tkatype.go | 80 +- types/tkatype/tkatype_test.go | 86 +- util/cibuild/cibuild.go | 28 +- util/cstruct/cstruct.go | 356 ++-- util/cstruct/cstruct_example_test.go | 146 +- util/deephash/debug.go | 74 +- util/deephash/pointer.go | 228 +-- util/deephash/pointer_norace.go | 26 +- util/deephash/pointer_race.go | 198 +-- util/deephash/testtype/testtype.go | 30 +- util/dirwalk/dirwalk.go | 106 +- util/dirwalk/dirwalk_linux.go | 334 ++-- util/dirwalk/dirwalk_test.go | 182 +-- util/goroutines/goroutines.go | 186 +-- util/goroutines/goroutines_test.go | 58 +- util/groupmember/groupmember.go | 58 +- util/hashx/block512.go | 394 ++--- util/httphdr/httphdr.go | 394 ++--- util/httphdr/httphdr_test.go | 192 +-- util/httpm/httpm.go | 72 +- util/httpm/httpm_test.go | 74 +- util/jsonutil/types.go | 32 +- util/jsonutil/unmarshal.go | 178 +- util/lineread/lineread.go | 74 +- util/linuxfw/linuxfwtest/linuxfwtest.go | 62 +- .../linuxfwtest/linuxfwtest_unsupported.go | 36 +- util/linuxfw/nftables_types.go | 190 +-- util/mak/mak.go | 140 +- util/mak/mak_test.go | 176 +- util/multierr/multierr.go | 272 ++-- util/must/must.go | 50 +- util/osdiag/mksyscall.go | 26 +- util/osdiag/osdiag_windows_test.go | 256 +-- util/osshare/filesharingstatus_noop.go | 24 +- util/pidowner/pidowner.go | 48 +- util/pidowner/pidowner_noimpl.go | 16 +- util/pidowner/pidowner_windows.go | 70 +- util/precompress/precompress.go | 258 +-- util/quarantine/quarantine.go | 28 +- util/quarantine/quarantine_darwin.go | 112 +- util/quarantine/quarantine_default.go | 28 +- util/quarantine/quarantine_windows.go | 58 +- util/race/race_test.go | 198 +-- util/racebuild/off.go | 16 +- util/racebuild/on.go | 16 +- util/racebuild/racebuild.go | 12 +- util/rands/rands.go | 50 +- util/rands/rands_test.go | 30 +- util/set/handle.go | 56 +- util/set/slice_test.go | 112 +- util/sysresources/memory.go | 20 +- util/sysresources/memory_bsd.go | 32 +- util/sysresources/memory_darwin.go | 32 +- util/sysresources/memory_linux.go | 38 +- util/sysresources/memory_unsupported.go | 16 +- util/sysresources/sysresources.go | 12 +- util/sysresources/sysresources_test.go | 50 +- util/systemd/doc.go | 26 +- util/systemd/systemd_linux.go | 154 +- util/systemd/systemd_nonlinux.go | 18 +- util/testenv/testenv.go | 42 +- util/truncate/truncate_test.go | 72 +- util/uniq/slice.go | 124 +- util/winutil/authenticode/mksyscall.go | 36 +- util/winutil/policy/policy_windows.go | 310 ++-- util/winutil/policy/policy_windows_test.go | 76 +- version/.gitignore | 20 +- version/cmdname.go | 278 ++-- version/cmdname_ios.go | 36 +- version/cmp_test.go | 164 +- version/export_test.go | 28 +- version/print.go | 66 +- version/race.go | 20 +- version/race_off.go | 20 +- version/version_test.go | 102 +- wgengine/bench/bench.go | 818 +++++----- wgengine/bench/bench_test.go | 216 +-- wgengine/bench/trafficgen.go | 518 +++--- wgengine/capture/capture.go | 476 +++--- wgengine/magicsock/blockforever_conn.go | 110 +- wgengine/magicsock/endpoint_default.go | 44 +- wgengine/magicsock/endpoint_stub.go | 26 +- wgengine/magicsock/endpoint_tracker.go | 496 +++--- wgengine/magicsock/magicsock_unix_test.go | 120 +- wgengine/magicsock/peermtu_darwin.go | 102 +- wgengine/magicsock/peermtu_linux.go | 98 +- wgengine/magicsock/peermtu_unix.go | 84 +- wgengine/mem_ios.go | 40 +- wgengine/netstack/netstack_linux.go | 38 +- wgengine/router/runner.go | 240 +-- wgengine/watchdog_js.go | 34 +- wgengine/wgcfg/device.go | 136 +- wgengine/wgcfg/device_test.go | 522 +++--- wgengine/wgcfg/parser.go | 372 ++--- wgengine/winnet/winnet_windows.go | 52 +- words/words.go | 116 +- words/words_test.go | 76 +- 554 files changed, 44582 insertions(+), 44582 deletions(-) diff --git a/.bencher/config.yaml b/.bencher/config.yaml index 220bd9d3b7dc0..b60c5c352d48a 100644 --- a/.bencher/config.yaml +++ b/.bencher/config.yaml @@ -1 +1 @@ -suppress_failure_on_regression: true +suppress_failure_on_regression: true diff --git a/.gitattributes b/.gitattributes index 3eb52878271f3..38a6b06a3147f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ -go.mod filter=go-mod -*.go diff=golang +go.mod filter=go-mod +*.go diff=golang diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 9163171c90248..688de14440a46 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,81 +1,81 @@ -name: Bug report -description: File a bug report. If you need help, contact support instead -labels: [needs-triage, bug] -body: - - type: markdown - attributes: - value: | - Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. - Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. - - type: textarea - id: what-happened - attributes: - label: What is the issue? - description: What happened? What did you expect to happen? - validations: - required: true - - type: textarea - id: steps - attributes: - label: Steps to reproduce - description: What are the steps you took that hit this issue? - validations: - required: false - - type: textarea - id: changes - attributes: - label: Are there any recent changes that introduced the issue? - description: If so, what are those changes? - validations: - required: false - - type: dropdown - id: os - attributes: - label: OS - description: What OS are you using? You may select more than one. - multiple: true - options: - - Linux - - macOS - - Windows - - iOS - - Android - - Synology - - Other - validations: - required: false - - type: input - id: os-version - attributes: - label: OS version - description: What OS version are you using? - placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 - validations: - required: false - - type: input - id: ts-version - attributes: - label: Tailscale version - description: What Tailscale version are you using? - placeholder: e.g., 1.14.4 - validations: - required: false - - type: textarea - id: other-software - attributes: - label: Other software - description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? - validations: - required: false - - type: input - id: bug-report - attributes: - label: Bug report - description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. - placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a bug report! +name: Bug report +description: File a bug report. If you need help, contact support instead +labels: [needs-triage, bug] +body: + - type: markdown + attributes: + value: | + Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. + Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. + - type: textarea + id: what-happened + attributes: + label: What is the issue? + description: What happened? What did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce + description: What are the steps you took that hit this issue? + validations: + required: false + - type: textarea + id: changes + attributes: + label: Are there any recent changes that introduced the issue? + description: If so, what are those changes? + validations: + required: false + - type: dropdown + id: os + attributes: + label: OS + description: What OS are you using? You may select more than one. + multiple: true + options: + - Linux + - macOS + - Windows + - iOS + - Android + - Synology + - Other + validations: + required: false + - type: input + id: os-version + attributes: + label: OS version + description: What OS version are you using? + placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 + validations: + required: false + - type: input + id: ts-version + attributes: + label: Tailscale version + description: What Tailscale version are you using? + placeholder: e.g., 1.14.4 + validations: + required: false + - type: textarea + id: other-software + attributes: + label: Other software + description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? + validations: + required: false + - type: input + id: bug-report + attributes: + label: Bug report + description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. + placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a bug report! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 3f4a31534b7d7..e3c44b6a1ab0a 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,8 +1,8 @@ -blank_issues_enabled: true -contact_links: - - name: Support - url: https://tailscale.com/contact/support/ - about: Contact us for support - - name: Troubleshooting - url: https://tailscale.com/kb/1023/troubleshooting +blank_issues_enabled: true +contact_links: + - name: Support + url: https://tailscale.com/contact/support/ + about: Contact us for support + - name: Troubleshooting + url: https://tailscale.com/kb/1023/troubleshooting about: See the troubleshooting guide for help addressing common issues \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index f7538627483ab..02ecae13c5acd 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,42 +1,42 @@ -name: Feature request -description: Propose a new feature -title: "FR: " -labels: [needs-triage, fr] -body: - - type: markdown - attributes: - value: | - Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). - Tell us about your idea! - - type: textarea - id: problem - attributes: - label: What are you trying to do? - description: Tell us about the problem you're trying to solve. - validations: - required: false - - type: textarea - id: solution - attributes: - label: How should we solve this? - description: If you have an idea of how you'd like to see this feature work, let us know. - validations: - required: false - - type: textarea - id: alternative - attributes: - label: What is the impact of not solving this? - description: (How) Are you currently working around the issue? - validations: - required: false - - type: textarea - id: context - attributes: - label: Anything else? - description: Any additional context to share, e.g., links - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a feature request! +name: Feature request +description: Propose a new feature +title: "FR: " +labels: [needs-triage, fr] +body: + - type: markdown + attributes: + value: | + Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). + Tell us about your idea! + - type: textarea + id: problem + attributes: + label: What are you trying to do? + description: Tell us about the problem you're trying to solve. + validations: + required: false + - type: textarea + id: solution + attributes: + label: How should we solve this? + description: If you have an idea of how you'd like to see this feature work, let us know. + validations: + required: false + - type: textarea + id: alternative + attributes: + label: What is the impact of not solving this? + description: (How) Are you currently working around the issue? + validations: + required: false + - type: textarea + id: context + attributes: + label: Anything else? + description: Any additional context to share, e.g., links + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a feature request! diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 14c912905363e..225132e5485c0 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,21 +1,21 @@ -# Documentation for this file can be found at: -# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates -version: 2 -updates: - ## Disabled between releases. We reenable it briefly after every - ## stable release, pull in all changes, and close it again so that - ## the tree remains more stable during development and the upstream - ## changes have time to soak before the next release. - # - package-ecosystem: "gomod" - # directory: "/" - # schedule: - # interval: "daily" - # commit-message: - # prefix: "go.mod:" - # open-pull-requests-limit: 100 - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: "weekly" - commit-message: - prefix: ".github:" +# Documentation for this file can be found at: +# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates +version: 2 +updates: + ## Disabled between releases. We reenable it briefly after every + ## stable release, pull in all changes, and close it again so that + ## the tree remains more stable during development and the upstream + ## changes have time to soak before the next release. + # - package-ecosystem: "gomod" + # directory: "/" + # schedule: + # interval: "daily" + # commit-message: + # prefix: "go.mod:" + # open-pull-requests-limit: 100 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + commit-message: + prefix: ".github:" diff --git a/AUTHORS b/AUTHORS index 03d5932c04746..3fafc44923b2c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,17 +1,17 @@ -# This is the official list of Tailscale -# authors for copyright purposes. -# -# Names should be added to this file as one of -# Organization's name -# Individual's name -# Individual's name -# -# Please keep the list sorted. -# -# You do not need to add entries to this list, and we don't actively -# populate this list. If you do want to be acknowledged explicitly as -# a copyright holder, though, then please send a PR referencing your -# earlier contributions and clarifying whether it's you or your -# company that owns the rights to your contribution. - -Tailscale Inc. +# This is the official list of Tailscale +# authors for copyright purposes. +# +# Names should be added to this file as one of +# Organization's name +# Individual's name +# Individual's name +# +# Please keep the list sorted. +# +# You do not need to add entries to this list, and we don't actively +# populate this list. If you do want to be acknowledged explicitly as +# a copyright holder, though, then please send a PR referencing your +# earlier contributions and clarifying whether it's you or your +# company that owns the rights to your contribution. + +Tailscale Inc. diff --git a/CODEOWNERS b/CODEOWNERS index af9b0d9f95928..76edf10061958 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -/tailcfg/ @tailscale/control-protocol-owners +/tailcfg/ @tailscale/control-protocol-owners diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index be5564ef4a3de..cf4e6ddbe4c31 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,135 +1,135 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation -in our community a harassment-free experience for everyone, regardless -of age, body size, visible or invisible disability, ethnicity, sex -characteristics, gender identity and expression, level of experience, -education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -We pledge to act and interact in ways that contribute to an open, -welcoming, diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for -our community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our - mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or - political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in - a professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our -standards of acceptable behavior and will take appropriate and fair -corrective action in response to any behavior that they deem -inappropriate, threatening, offensive, or harmful. - -Community leaders have the right and responsibility to remove, edit, -or reject comments, commits, code, wiki edits, issues, and other -contributions that are not aligned to this Code of Conduct, and will -communicate reasons for moderation decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also -applies when an individual is officially representing the community in -public spaces. Examples of representing our community include using an -official e-mail address, posting via an official social media account, -or acting as an appointed representative at an online or offline -event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported to the community leaders responsible for enforcement -at [info@tailscale.com](mailto:info@tailscale.com). All complaints -will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and -security of the reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in -determining the consequences for any action they deem in violation of -this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior -deemed unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, -providing clarity around the nature of the violation and an -explanation of why the behavior was inappropriate. A public apology -may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued -behavior. No interaction with the people involved, including -unsolicited interaction with those enforcing the Code of Conduct, for -a specified period of time. This includes avoiding interactions in -community spaces as well as external channels like social -media. Violating these terms may lead to a temporary or permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, -including sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or -public communication with the community for a specified period of -time. No public or private interaction with the people involved, -including unsolicited interaction with those enforcing the Code of -Conduct, is allowed during this period. Violating these terms may lead -to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of -community standards, including sustained inappropriate behavior, -harassment of an individual, or aggression toward or disparagement of -classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction -within the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor -Covenant][homepage], version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of -conduct enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the -FAQ at https://www.contributor-covenant.org/faq. Translations are -available at https://www.contributor-covenant.org/translations. - +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation +in our community a harassment-free experience for everyone, regardless +of age, body size, visible or invisible disability, ethnicity, sex +characteristics, gender identity and expression, level of experience, +education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, +welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for +our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our + mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or + political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in + a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our +standards of acceptable behavior and will take appropriate and fair +corrective action in response to any behavior that they deem +inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, +or reject comments, commits, code, wiki edits, issues, and other +contributions that are not aligned to this Code of Conduct, and will +communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also +applies when an individual is officially representing the community in +public spaces. Examples of representing our community include using an +official e-mail address, posting via an official social media account, +or acting as an appointed representative at an online or offline +event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior +may be reported to the community leaders responsible for enforcement +at [info@tailscale.com](mailto:info@tailscale.com). All complaints +will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and +security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in +determining the consequences for any action they deem in violation of +this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior +deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, +providing clarity around the nature of the violation and an +explanation of why the behavior was inappropriate. A public apology +may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued +behavior. No interaction with the people involved, including +unsolicited interaction with those enforcing the Code of Conduct, for +a specified period of time. This includes avoiding interactions in +community spaces as well as external channels like social +media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, +including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or +public communication with the community for a specified period of +time. No public or private interaction with the people involved, +including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead +to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of +community standards, including sustained inappropriate behavior, +harassment of an individual, or aggression toward or disparagement of +classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction +within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor +Covenant][homepage], version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of +conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the +FAQ at https://www.contributor-covenant.org/faq. Translations are +available at https://www.contributor-covenant.org/translations. + diff --git a/LICENSE b/LICENSE index 394db19e4aa5c..3d511c30c1ff5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,28 +1,28 @@ -BSD 3-Clause License - -Copyright (c) 2020 Tailscale Inc & AUTHORS. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +BSD 3-Clause License + +Copyright (c) 2020 Tailscale Inc & AUTHORS. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/PATENTS b/PATENTS index 560a2b8f0e401..b001fb9c1b0b2 100644 --- a/PATENTS +++ b/PATENTS @@ -1,24 +1,24 @@ -Additional IP Rights Grant (Patents) - -"This implementation" means the copyrightable works distributed by -Tailscale Inc. as part of the Tailscale project. - -Tailscale Inc. hereby grants to You a perpetual, worldwide, -non-exclusive, no-charge, royalty-free, irrevocable (except as stated -in this section) patent license to make, have made, use, offer to -sell, sell, import, transfer and otherwise run, modify and propagate -the contents of this implementation of Tailscale, where such license -applies only to those patent claims, both currently owned or -controlled by Tailscale Inc. and acquired in the future, licensable -by Tailscale Inc. that are necessarily infringed by this -implementation of Tailscale. This grant does not include claims that -would be infringed only as a consequence of further modification of -this implementation. If you or your agent or exclusive licensee -institute or order or agree to the institution of patent litigation -against any entity (including a cross-claim or counterclaim in a -lawsuit) alleging that this implementation of Tailscale or any code -incorporated within this implementation of Tailscale constitutes -direct or contributory patent infringement, or inducement of patent -infringement, then any patent rights granted to you under this License -for this implementation of Tailscale shall terminate as of the date -such litigation is filed. +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Tailscale Inc. as part of the Tailscale project. + +Tailscale Inc. hereby grants to You a perpetual, worldwide, +non-exclusive, no-charge, royalty-free, irrevocable (except as stated +in this section) patent license to make, have made, use, offer to +sell, sell, import, transfer and otherwise run, modify and propagate +the contents of this implementation of Tailscale, where such license +applies only to those patent claims, both currently owned or +controlled by Tailscale Inc. and acquired in the future, licensable +by Tailscale Inc. that are necessarily infringed by this +implementation of Tailscale. This grant does not include claims that +would be infringed only as a consequence of further modification of +this implementation. If you or your agent or exclusive licensee +institute or order or agree to the institution of patent litigation +against any entity (including a cross-claim or counterclaim in a +lawsuit) alleging that this implementation of Tailscale or any code +incorporated within this implementation of Tailscale constitutes +direct or contributory patent infringement, or inducement of patent +infringement, then any patent rights granted to you under this License +for this implementation of Tailscale shall terminate as of the date +such litigation is filed. diff --git a/SECURITY.md b/SECURITY.md index 26702b14143c3..e8cd9a326c787 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,8 +1,8 @@ -# Security Policy - -## Reporting a Vulnerability - -You can report vulnerabilities privately to -[security@tailscale.com](mailto:security@tailscale.com). Tailscale -staff will triage the issue, and work with you on a coordinated -disclosure timeline. +# Security Policy + +## Reporting a Vulnerability + +You can report vulnerabilities privately to +[security@tailscale.com](mailto:security@tailscale.com). Tailscale +staff will triage the issue, and work with you on a coordinated +disclosure timeline. diff --git a/VERSION.txt b/VERSION.txt index 79e15fd49370a..54227249d1ff9 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.77.0 +1.78.0 diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index 5c18e85a896eb..b95c7cbe14964 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package atomicfile contains code related to writing to filesystems -// atomically. -// -// This package should be considered internal; its API is not stable. -package atomicfile // import "tailscale.com/atomicfile" - -import ( - "fmt" - "os" - "path/filepath" - "runtime" -) - -// WriteFile writes data to filename+some suffix, then renames it into filename. -// The perm argument is ignored on Windows. If the target filename already -// exists but is not a regular file, WriteFile returns an error. -func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { - fi, err := os.Stat(filename) - if err == nil && !fi.Mode().IsRegular() { - return fmt.Errorf("%s already exists and is not a regular file", filename) - } - f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") - if err != nil { - return err - } - tmpName := f.Name() - defer func() { - if err != nil { - f.Close() - os.Remove(tmpName) - } - }() - if _, err := f.Write(data); err != nil { - return err - } - if runtime.GOOS != "windows" { - if err := f.Chmod(perm); err != nil { - return err - } - } - if err := f.Sync(); err != nil { - return err - } - if err := f.Close(); err != nil { - return err - } - return os.Rename(tmpName, filename) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package atomicfile contains code related to writing to filesystems +// atomically. +// +// This package should be considered internal; its API is not stable. +package atomicfile // import "tailscale.com/atomicfile" + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +// WriteFile writes data to filename+some suffix, then renames it into filename. +// The perm argument is ignored on Windows. If the target filename already +// exists but is not a regular file, WriteFile returns an error. +func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { + fi, err := os.Stat(filename) + if err == nil && !fi.Mode().IsRegular() { + return fmt.Errorf("%s already exists and is not a regular file", filename) + } + f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") + if err != nil { + return err + } + tmpName := f.Name() + defer func() { + if err != nil { + f.Close() + os.Remove(tmpName) + } + }() + if _, err := f.Write(data); err != nil { + return err + } + if runtime.GOOS != "windows" { + if err := f.Chmod(perm); err != nil { + return err + } + } + if err := f.Sync(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return os.Rename(tmpName, filename) +} diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go index 78c93e664f738..b7a78765b745e 100644 --- a/atomicfile/atomicfile_test.go +++ b/atomicfile/atomicfile_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package atomicfile - -import ( - "net" - "os" - "path/filepath" - "runtime" - "strings" - "testing" -) - -func TestDoesNotOverwriteIrregularFiles(t *testing.T) { - // Per tailscale/tailscale#7658 as one example, almost any imagined use of - // atomicfile.Write should likely not attempt to overwrite an irregular file - // such as a device node, socket, or named pipe. - - const filename = "TestDoesNotOverwriteIrregularFiles" - var path string - // macOS private temp does not allow unix socket creation, but /tmp does. - if runtime.GOOS == "darwin" { - path = filepath.Join("/tmp", filename) - t.Cleanup(func() { os.Remove(path) }) - } else { - path = filepath.Join(t.TempDir(), filename) - } - - // The least troublesome thing to make that is not a file is a unix socket. - // Making a null device sadly requires root. - l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - err = WriteFile(path, []byte("hello"), 0644) - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), "is not a regular file") { - t.Fatalf("unexpected error: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package atomicfile + +import ( + "net" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestDoesNotOverwriteIrregularFiles(t *testing.T) { + // Per tailscale/tailscale#7658 as one example, almost any imagined use of + // atomicfile.Write should likely not attempt to overwrite an irregular file + // such as a device node, socket, or named pipe. + + const filename = "TestDoesNotOverwriteIrregularFiles" + var path string + // macOS private temp does not allow unix socket creation, but /tmp does. + if runtime.GOOS == "darwin" { + path = filepath.Join("/tmp", filename) + t.Cleanup(func() { os.Remove(path) }) + } else { + path = filepath.Join(t.TempDir(), filename) + } + + // The least troublesome thing to make that is not a file is a unix socket. + // Making a null device sadly requires root. + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + err = WriteFile(path, []byte("hello"), 0644) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "is not a regular file") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/chirp/chirp.go b/chirp/chirp.go index 9653877221778..1b448f2394106 100644 --- a/chirp/chirp.go +++ b/chirp/chirp.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package chirp implements a client to communicate with the BIRD Internet -// Routing Daemon. -package chirp - -import ( - "bufio" - "fmt" - "net" - "strings" - "time" -) - -const ( - // Maximum amount of time we should wait when reading a response from BIRD. - responseTimeout = 10 * time.Second -) - -// New creates a BIRDClient. -func New(socket string) (*BIRDClient, error) { - return newWithTimeout(socket, responseTimeout) -} - -func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { - conn, err := net.Dial("unix", socket) - if err != nil { - return nil, fmt.Errorf("failed to connect to BIRD: %w", err) - } - defer func() { - if err != nil { - conn.Close() - } - }() - - b := &BIRDClient{ - socket: socket, - conn: conn, - scanner: bufio.NewScanner(conn), - timeNow: time.Now, - timeout: timeout, - } - // Read and discard the first line as that is the welcome message. - if _, err := b.readResponse(); err != nil { - return nil, err - } - return b, nil -} - -// BIRDClient handles communication with the BIRD Internet Routing Daemon. -type BIRDClient struct { - socket string - conn net.Conn - scanner *bufio.Scanner - timeNow func() time.Time - timeout time.Duration -} - -// Close closes the underlying connection to BIRD. -func (b *BIRDClient) Close() error { return b.conn.Close() } - -// DisableProtocol disables the provided protocol. -func (b *BIRDClient) DisableProtocol(protocol string) error { - out, err := b.exec("disable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { - return nil - } - return fmt.Errorf("failed to disable %s: %v", protocol, out) -} - -// EnableProtocol enables the provided protocol. -func (b *BIRDClient) EnableProtocol(protocol string) error { - out, err := b.exec("enable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { - return nil - } - return fmt.Errorf("failed to enable %s: %v", protocol, out) -} - -// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 - -// Each session of the CLI consists of a sequence of request and replies, -// slightly resembling the FTP and SMTP protocols. -// Requests are commands encoded as a single line of text, -// replies are sequences of lines starting with a four-digit code -// followed by either a space (if it's the last line of the reply) or -// a minus sign (when the reply is going to continue with the next line), -// the rest of the line contains a textual message semantics of which depends on the numeric code. -// If a reply line has the same code as the previous one and it's a continuation line, -// the whole prefix can be replaced by a single white space character. -// -// Reply codes starting with 0 stand for ‘action successfully completed’ messages, -// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. - -func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { - if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { - return "", err - } - if _, err := fmt.Fprintln(b.conn); err != nil { - return "", err - } - return b.readResponse() -} - -// hasResponseCode reports whether the provided byte slice is -// prefixed with a BIRD response code. -// Equivalent regex: `^\d{4}[ -]`. -func hasResponseCode(s []byte) bool { - if len(s) < 5 { - return false - } - for _, b := range s[:4] { - if '0' <= b && b <= '9' { - continue - } - return false - } - return s[4] == ' ' || s[4] == '-' -} - -func (b *BIRDClient) readResponse() (string, error) { - // Set the read timeout before we start reading anything. - if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - - var resp strings.Builder - var done bool - for !done { - if !b.scanner.Scan() { - if err := b.scanner.Err(); err != nil { - return "", err - } - - return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) - } - out := b.scanner.Bytes() - if _, err := resp.Write(out); err != nil { - return "", err - } - if hasResponseCode(out) { - done = out[4] == ' ' - } - if !done { - resp.WriteRune('\n') - } - } - return resp.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package chirp implements a client to communicate with the BIRD Internet +// Routing Daemon. +package chirp + +import ( + "bufio" + "fmt" + "net" + "strings" + "time" +) + +const ( + // Maximum amount of time we should wait when reading a response from BIRD. + responseTimeout = 10 * time.Second +) + +// New creates a BIRDClient. +func New(socket string) (*BIRDClient, error) { + return newWithTimeout(socket, responseTimeout) +} + +func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to connect to BIRD: %w", err) + } + defer func() { + if err != nil { + conn.Close() + } + }() + + b := &BIRDClient{ + socket: socket, + conn: conn, + scanner: bufio.NewScanner(conn), + timeNow: time.Now, + timeout: timeout, + } + // Read and discard the first line as that is the welcome message. + if _, err := b.readResponse(); err != nil { + return nil, err + } + return b, nil +} + +// BIRDClient handles communication with the BIRD Internet Routing Daemon. +type BIRDClient struct { + socket string + conn net.Conn + scanner *bufio.Scanner + timeNow func() time.Time + timeout time.Duration +} + +// Close closes the underlying connection to BIRD. +func (b *BIRDClient) Close() error { return b.conn.Close() } + +// DisableProtocol disables the provided protocol. +func (b *BIRDClient) DisableProtocol(protocol string) error { + out, err := b.exec("disable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { + return nil + } + return fmt.Errorf("failed to disable %s: %v", protocol, out) +} + +// EnableProtocol enables the provided protocol. +func (b *BIRDClient) EnableProtocol(protocol string) error { + out, err := b.exec("enable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { + return nil + } + return fmt.Errorf("failed to enable %s: %v", protocol, out) +} + +// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 + +// Each session of the CLI consists of a sequence of request and replies, +// slightly resembling the FTP and SMTP protocols. +// Requests are commands encoded as a single line of text, +// replies are sequences of lines starting with a four-digit code +// followed by either a space (if it's the last line of the reply) or +// a minus sign (when the reply is going to continue with the next line), +// the rest of the line contains a textual message semantics of which depends on the numeric code. +// If a reply line has the same code as the previous one and it's a continuation line, +// the whole prefix can be replaced by a single white space character. +// +// Reply codes starting with 0 stand for ‘action successfully completed’ messages, +// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. + +func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { + if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { + return "", err + } + if _, err := fmt.Fprintln(b.conn); err != nil { + return "", err + } + return b.readResponse() +} + +// hasResponseCode reports whether the provided byte slice is +// prefixed with a BIRD response code. +// Equivalent regex: `^\d{4}[ -]`. +func hasResponseCode(s []byte) bool { + if len(s) < 5 { + return false + } + for _, b := range s[:4] { + if '0' <= b && b <= '9' { + continue + } + return false + } + return s[4] == ' ' || s[4] == '-' +} + +func (b *BIRDClient) readResponse() (string, error) { + // Set the read timeout before we start reading anything. + if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + + var resp strings.Builder + var done bool + for !done { + if !b.scanner.Scan() { + if err := b.scanner.Err(); err != nil { + return "", err + } + + return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) + } + out := b.scanner.Bytes() + if _, err := resp.Write(out); err != nil { + return "", err + } + if hasResponseCode(out) { + done = out[4] == ' ' + } + if !done { + resp.WriteRune('\n') + } + } + return resp.String(), nil +} diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index 2549c163fd819..b8947a796c996 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -1,192 +1,192 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package chirp - -import ( - "bufio" - "errors" - "fmt" - "net" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" -) - -type fakeBIRD struct { - net.Listener - protocolsEnabled map[string]bool - sock string -} - -func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - pe := make(map[string]bool) - for _, p := range protocols { - pe[p] = false - } - return &fakeBIRD{ - Listener: l, - protocolsEnabled: pe, - sock: sock, - } -} - -func (fb *fakeBIRD) listen() error { - for { - c, err := fb.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - go fb.handle(c) - } -} - -func (fb *fakeBIRD) handle(c net.Conn) { - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - sc := bufio.NewScanner(c) - for sc.Scan() { - cmd := sc.Text() - args := strings.Split(cmd, " ") - switch args[0] { - case "enable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if en { - fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) - } else { - fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = true - case "disable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if !en { - fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) - } else { - fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = false - } - } -} - -func TestChirp(t *testing.T) { - fb := newFakeBIRD(t, "tailscale") - defer fb.Close() - go fb.listen() - c, err := New(fb.sock) - if err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("rando"); err == nil { - t.Fatalf("enabling %q succeeded", "rando") - } - if err := c.DisableProtocol("rando"); err == nil { - t.Fatalf("disabling %q succeeded", "rando") - } -} - -type hangingListener struct { - net.Listener - t *testing.T - done chan struct{} - wg sync.WaitGroup - sock string -} - -func newHangingListener(t *testing.T) *hangingListener { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - return &hangingListener{ - Listener: l, - t: t, - done: make(chan struct{}), - sock: sock, - } -} - -func (hl *hangingListener) Stop() { - hl.Close() - close(hl.done) - hl.wg.Wait() -} - -func (hl *hangingListener) listen() error { - for { - c, err := hl.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - hl.wg.Add(1) - go hl.handle(c) - } -} - -func (hl *hangingListener) handle(c net.Conn) { - defer hl.wg.Done() - - // Write our fake first line of response so that we get into the read loop - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - hl.t.Logf("connection still hanging") - case <-hl.done: - return - } - } -} - -func TestChirpTimeout(t *testing.T) { - fb := newHangingListener(t) - defer fb.Stop() - go fb.listen() - - c, err := newWithTimeout(fb.sock, 500*time.Millisecond) - if err != nil { - t.Fatal(err) - } - - err = c.EnableProtocol("tailscale") - if err == nil { - t.Fatal("got err=nil, want timeout") - } - if !os.IsTimeout(err) { - t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package chirp + +import ( + "bufio" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +type fakeBIRD struct { + net.Listener + protocolsEnabled map[string]bool + sock string +} + +func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + pe := make(map[string]bool) + for _, p := range protocols { + pe[p] = false + } + return &fakeBIRD{ + Listener: l, + protocolsEnabled: pe, + sock: sock, + } +} + +func (fb *fakeBIRD) listen() error { + for { + c, err := fb.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + go fb.handle(c) + } +} + +func (fb *fakeBIRD) handle(c net.Conn) { + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + sc := bufio.NewScanner(c) + for sc.Scan() { + cmd := sc.Text() + args := strings.Split(cmd, " ") + switch args[0] { + case "enable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if en { + fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) + } else { + fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = true + case "disable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if !en { + fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) + } else { + fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = false + } + } +} + +func TestChirp(t *testing.T) { + fb := newFakeBIRD(t, "tailscale") + defer fb.Close() + go fb.listen() + c, err := New(fb.sock) + if err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("rando"); err == nil { + t.Fatalf("enabling %q succeeded", "rando") + } + if err := c.DisableProtocol("rando"); err == nil { + t.Fatalf("disabling %q succeeded", "rando") + } +} + +type hangingListener struct { + net.Listener + t *testing.T + done chan struct{} + wg sync.WaitGroup + sock string +} + +func newHangingListener(t *testing.T) *hangingListener { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + return &hangingListener{ + Listener: l, + t: t, + done: make(chan struct{}), + sock: sock, + } +} + +func (hl *hangingListener) Stop() { + hl.Close() + close(hl.done) + hl.wg.Wait() +} + +func (hl *hangingListener) listen() error { + for { + c, err := hl.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + hl.wg.Add(1) + go hl.handle(c) + } +} + +func (hl *hangingListener) handle(c net.Conn) { + defer hl.wg.Done() + + // Write our fake first line of response so that we get into the read loop + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hl.t.Logf("connection still hanging") + case <-hl.done: + return + } + } +} + +func TestChirpTimeout(t *testing.T) { + fb := newHangingListener(t) + defer fb.Stop() + go fb.listen() + + c, err := newWithTimeout(fb.sock, 500*time.Millisecond) + if err != nil { + t.Fatal(err) + } + + err = c.EnableProtocol("tailscale") + if err == nil { + t.Fatal("got err=nil, want timeout") + } + if !os.IsTimeout(err) { + t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) + } +} diff --git a/client/tailscale/apitype/controltype.go b/client/tailscale/apitype/controltype.go index 9a623be319606..a9a76065f711e 100644 --- a/client/tailscale/apitype/controltype.go +++ b/client/tailscale/apitype/controltype.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package apitype - -type DNSConfig struct { - Resolvers []DNSResolver `json:"resolvers"` - FallbackResolvers []DNSResolver `json:"fallbackResolvers"` - Routes map[string][]DNSResolver `json:"routes"` - Domains []string `json:"domains"` - Nameservers []string `json:"nameservers"` - Proxied bool `json:"proxied"` - TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` -} - -type DNSResolver struct { - Addr string `json:"addr"` - BootstrapResolution []string `json:"bootstrapResolution,omitempty"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package apitype + +type DNSConfig struct { + Resolvers []DNSResolver `json:"resolvers"` + FallbackResolvers []DNSResolver `json:"fallbackResolvers"` + Routes map[string][]DNSResolver `json:"routes"` + Domains []string `json:"domains"` + Nameservers []string `json:"nameservers"` + Proxied bool `json:"proxied"` + TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` +} + +type DNSResolver struct { + Addr string `json:"addr"` + BootstrapResolution []string `json:"bootstrapResolution,omitempty"` +} diff --git a/client/tailscale/dns.go b/client/tailscale/dns.go index f198742b3ca51..12b9e15c8b7a5 100644 --- a/client/tailscale/dns.go +++ b/client/tailscale/dns.go @@ -1,233 +1,233 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - - "tailscale.com/client/tailscale/apitype" -) - -// DNSNameServers is returned when retrieving the list of nameservers. -// It is also the structure provided when setting nameservers. -type DNSNameServers struct { - DNS []string `json:"dns"` // DNS name servers -} - -// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. -// -// It includes the MagicDNS status since nameservers changes may affect MagicDNS. -type DNSNameServersPostResponse struct { - DNS []string `json:"dns"` // DNS name servers - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -// DNSSearchpaths is the list of search paths for a given domain. -type DNSSearchPaths struct { - SearchPaths []string `json:"searchPaths"` // DNS search paths -} - -// DNSPreferences is the preferences set for a given tailnet. -// -// It includes MagicDNS which can be turned on or off. To enable MagicDNS, -// there must be at least one nameserver. When all nameservers are removed, -// MagicDNS is disabled. -type DNSPreferences struct { - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - data, err := json.Marshal(&postData) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - req.Header.Set("Content-Type", "application/json") - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -// DNSConfig retrieves the DNSConfig settings for a domain. -func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSConfig: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "config") - if err != nil { - return nil, err - } - var dnsResp apitype.DNSConfig - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) - } - }() - var dnsResp apitype.DNSConfig - b, err := c.dnsPOSTRequest(ctx, "config", cfg) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -// NameServers retrieves the list of nameservers set for a domain. -func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.NameServers: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "nameservers") - if err != nil { - return nil, err - } - var dnsResp DNSNameServers - err = json.Unmarshal(b, &dnsResp) - return dnsResp.DNS, err -} - -// SetNameServers sets the list of nameservers for a tailnet to the list provided -// by the user. -// -// It returns the new list of nameservers and the MagicDNS status in case it was -// affected by the change. For example, removing all nameservers will turn off -// MagicDNS. -func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetNameServers: %w", err) - } - }() - dnsReq := DNSNameServers{DNS: nameservers} - b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// DNSPreferences retrieves the DNS preferences set for a tailnet. -// -// It returns the status of MagicDNS. -func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSPreferences: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "preferences") - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SetDNSPreferences sets the DNS preferences for a tailnet. -// -// MagicDNS can only be enabled when there is at least one nameserver provided. -// When all nameservers are removed, MagicDNS is disabled and will stay disabled, -// unless explicitly enabled by a user again. -func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) - } - }() - dnsReq := DNSPreferences{MagicDNS: magicDNS} - b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) - if err != nil { - return - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SearchPaths retrieves the list of searchpaths set for a tailnet. -func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SearchPaths: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "searchpaths") - if err != nil { - return nil, err - } - var dnsResp *DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} - -// SetSearchPaths sets the list of searchpaths for a tailnet. -func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) - } - }() - dnsReq := DNSSearchPaths{SearchPaths: searchpaths} - b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) - if err != nil { - return nil, err - } - var dnsResp DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/client/tailscale/apitype" +) + +// DNSNameServers is returned when retrieving the list of nameservers. +// It is also the structure provided when setting nameservers. +type DNSNameServers struct { + DNS []string `json:"dns"` // DNS name servers +} + +// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. +// +// It includes the MagicDNS status since nameservers changes may affect MagicDNS. +type DNSNameServersPostResponse struct { + DNS []string `json:"dns"` // DNS name servers + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +// DNSSearchpaths is the list of search paths for a given domain. +type DNSSearchPaths struct { + SearchPaths []string `json:"searchPaths"` // DNS search paths +} + +// DNSPreferences is the preferences set for a given tailnet. +// +// It includes MagicDNS which can be turned on or off. To enable MagicDNS, +// there must be at least one nameserver. When all nameservers are removed, +// MagicDNS is disabled. +type DNSPreferences struct { + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + data, err := json.Marshal(&postData) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + req.Header.Set("Content-Type", "application/json") + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +// DNSConfig retrieves the DNSConfig settings for a domain. +func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSConfig: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "config") + if err != nil { + return nil, err + } + var dnsResp apitype.DNSConfig + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) + } + }() + var dnsResp apitype.DNSConfig + b, err := c.dnsPOSTRequest(ctx, "config", cfg) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +// NameServers retrieves the list of nameservers set for a domain. +func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.NameServers: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "nameservers") + if err != nil { + return nil, err + } + var dnsResp DNSNameServers + err = json.Unmarshal(b, &dnsResp) + return dnsResp.DNS, err +} + +// SetNameServers sets the list of nameservers for a tailnet to the list provided +// by the user. +// +// It returns the new list of nameservers and the MagicDNS status in case it was +// affected by the change. For example, removing all nameservers will turn off +// MagicDNS. +func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetNameServers: %w", err) + } + }() + dnsReq := DNSNameServers{DNS: nameservers} + b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// DNSPreferences retrieves the DNS preferences set for a tailnet. +// +// It returns the status of MagicDNS. +func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSPreferences: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "preferences") + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SetDNSPreferences sets the DNS preferences for a tailnet. +// +// MagicDNS can only be enabled when there is at least one nameserver provided. +// When all nameservers are removed, MagicDNS is disabled and will stay disabled, +// unless explicitly enabled by a user again. +func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) + } + }() + dnsReq := DNSPreferences{MagicDNS: magicDNS} + b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) + if err != nil { + return + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SearchPaths retrieves the list of searchpaths set for a tailnet. +func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SearchPaths: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "searchpaths") + if err != nil { + return nil, err + } + var dnsResp *DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} + +// SetSearchPaths sets the list of searchpaths for a tailnet. +func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) + } + }() + dnsReq := DNSSearchPaths{SearchPaths: searchpaths} + b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) + if err != nil { + return nil, err + } + var dnsResp DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} diff --git a/client/tailscale/example/servetls/servetls.go b/client/tailscale/example/servetls/servetls.go index f48e90d163527..e426cbea2b375 100644 --- a/client/tailscale/example/servetls/servetls.go +++ b/client/tailscale/example/servetls/servetls.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The servetls program shows how to run an HTTPS server -// using a Tailscale cert via LetsEncrypt. -package main - -import ( - "crypto/tls" - "io" - "log" - "net/http" - - "tailscale.com/client/tailscale" -) - -func main() { - s := &http.Server{ - TLSConfig: &tls.Config{ - GetCertificate: tailscale.GetCertificate, - }, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "

Hello from Tailscale!

It works.") - }), - } - log.Printf("Running TLS server on :443 ...") - log.Fatal(s.ListenAndServeTLS("", "")) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The servetls program shows how to run an HTTPS server +// using a Tailscale cert via LetsEncrypt. +package main + +import ( + "crypto/tls" + "io" + "log" + "net/http" + + "tailscale.com/client/tailscale" +) + +func main() { + s := &http.Server{ + TLSConfig: &tls.Config{ + GetCertificate: tailscale.GetCertificate, + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "

Hello from Tailscale!

It works.") + }), + } + log.Printf("Running TLS server on :443 ...") + log.Fatal(s.ListenAndServeTLS("", "")) +} diff --git a/client/tailscale/keys.go b/client/tailscale/keys.go index 84bcdfae6aeeb..ae5f721b74d6d 100644 --- a/client/tailscale/keys.go +++ b/client/tailscale/keys.go @@ -1,166 +1,166 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "time" -) - -// Key represents a Tailscale API or auth key. -type Key struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires"` - Capabilities KeyCapabilities `json:"capabilities"` -} - -// KeyCapabilities are the capabilities of a Key. -type KeyCapabilities struct { - Devices KeyDeviceCapabilities `json:"devices,omitempty"` -} - -// KeyDeviceCapabilities are the device-related capabilities of a Key. -type KeyDeviceCapabilities struct { - Create KeyDeviceCreateCapabilities `json:"create"` -} - -// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. -type KeyDeviceCreateCapabilities struct { - Reusable bool `json:"reusable"` - Ephemeral bool `json:"ephemeral"` - Preauthorized bool `json:"preauthorized"` - Tags []string `json:"tags,omitempty"` -} - -// Keys returns the list of keys for the current user. -func (c *Client) Keys(ctx context.Context) ([]string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var keys struct { - Keys []*Key `json:"keys"` - } - if err := json.Unmarshal(b, &keys); err != nil { - return nil, err - } - ret := make([]string, 0, len(keys.Keys)) - for _, k := range keys.Keys { - ret = append(ret, k.ID) - } - return ret, nil -} - -// CreateKey creates a new key for the current user. Currently, only auth keys -// can be created. It returns the secret key itself, which cannot be retrieved again -// later, and the key metadata. -// -// To create a key with a specific expiry, use CreateKeyWithExpiry. -func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { - return c.CreateKeyWithExpiry(ctx, caps, 0) -} - -// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. -// -// The time is truncated to a whole number of seconds. If zero, that means no expiration. -func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { - - // convert expirySeconds to an int64 (seconds) - expirySeconds := int64(expiry.Seconds()) - if expirySeconds < 0 { - return "", nil, fmt.Errorf("expiry must be positive") - } - if expirySeconds == 0 && expiry != 0 { - return "", nil, fmt.Errorf("non-zero expiry must be at least one second") - } - - keyRequest := struct { - Capabilities KeyCapabilities `json:"capabilities"` - ExpirySeconds int64 `json:"expirySeconds,omitempty"` - }{caps, int64(expirySeconds)} - bs, err := json.Marshal(keyRequest) - if err != nil { - return "", nil, err - } - - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) - if err != nil { - return "", nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return "", nil, err - } - if resp.StatusCode != http.StatusOK { - return "", nil, handleErrorResponse(b, resp) - } - - var key struct { - Key - Secret string `json:"key"` - } - if err := json.Unmarshal(b, &key); err != nil { - return "", nil, err - } - return key.Secret, &key.Key, nil -} - -// Key returns the metadata for the given key ID. Currently, capabilities are -// only returned for auth keys, API keys only return general metadata. -func (c *Client) Key(ctx context.Context, id string) (*Key, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var key Key - if err := json.Unmarshal(b, &key); err != nil { - return nil, err - } - return &key, nil -} - -// DeleteKey deletes the key with the given ID. -func (c *Client) DeleteKey(ctx context.Context, id string) error { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) - if err != nil { - return err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +// Key represents a Tailscale API or auth key. +type Key struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` + Capabilities KeyCapabilities `json:"capabilities"` +} + +// KeyCapabilities are the capabilities of a Key. +type KeyCapabilities struct { + Devices KeyDeviceCapabilities `json:"devices,omitempty"` +} + +// KeyDeviceCapabilities are the device-related capabilities of a Key. +type KeyDeviceCapabilities struct { + Create KeyDeviceCreateCapabilities `json:"create"` +} + +// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. +type KeyDeviceCreateCapabilities struct { + Reusable bool `json:"reusable"` + Ephemeral bool `json:"ephemeral"` + Preauthorized bool `json:"preauthorized"` + Tags []string `json:"tags,omitempty"` +} + +// Keys returns the list of keys for the current user. +func (c *Client) Keys(ctx context.Context) ([]string, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var keys struct { + Keys []*Key `json:"keys"` + } + if err := json.Unmarshal(b, &keys); err != nil { + return nil, err + } + ret := make([]string, 0, len(keys.Keys)) + for _, k := range keys.Keys { + ret = append(ret, k.ID) + } + return ret, nil +} + +// CreateKey creates a new key for the current user. Currently, only auth keys +// can be created. It returns the secret key itself, which cannot be retrieved again +// later, and the key metadata. +// +// To create a key with a specific expiry, use CreateKeyWithExpiry. +func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { + return c.CreateKeyWithExpiry(ctx, caps, 0) +} + +// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. +// +// The time is truncated to a whole number of seconds. If zero, that means no expiration. +func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { + + // convert expirySeconds to an int64 (seconds) + expirySeconds := int64(expiry.Seconds()) + if expirySeconds < 0 { + return "", nil, fmt.Errorf("expiry must be positive") + } + if expirySeconds == 0 && expiry != 0 { + return "", nil, fmt.Errorf("non-zero expiry must be at least one second") + } + + keyRequest := struct { + Capabilities KeyCapabilities `json:"capabilities"` + ExpirySeconds int64 `json:"expirySeconds,omitempty"` + }{caps, int64(expirySeconds)} + bs, err := json.Marshal(keyRequest) + if err != nil { + return "", nil, err + } + + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) + if err != nil { + return "", nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return "", nil, err + } + if resp.StatusCode != http.StatusOK { + return "", nil, handleErrorResponse(b, resp) + } + + var key struct { + Key + Secret string `json:"key"` + } + if err := json.Unmarshal(b, &key); err != nil { + return "", nil, err + } + return key.Secret, &key.Key, nil +} + +// Key returns the metadata for the given key ID. Currently, capabilities are +// only returned for auth keys, API keys only return general metadata. +func (c *Client) Key(ctx context.Context, id string) (*Key, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var key Key + if err := json.Unmarshal(b, &key); err != nil { + return nil, err + } + return &key, nil +} + +// DeleteKey deletes the key with the given ID. +func (c *Client) DeleteKey(ctx context.Context, id string) error { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) + if err != nil { + return err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + return nil +} diff --git a/client/tailscale/routes.go b/client/tailscale/routes.go index 5912fc46c09a6..41415d1b44c29 100644 --- a/client/tailscale/routes.go +++ b/client/tailscale/routes.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/netip" -) - -// Routes contains the lists of subnet routes that are currently advertised by a device, -// as well as the subnets that are enabled to be routed by the device. -type Routes struct { - AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` - EnabledRoutes []netip.Prefix `json:"enabledRoutes"` -} - -// Routes retrieves the list of subnet routes that have been enabled for a device. -// The routes that are returned are not necessarily advertised by the device, -// they have only been preapproved. -func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.Routes: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var sr Routes - err = json.Unmarshal(b, &sr) - return &sr, err -} - -type postRoutesParams struct { - Routes []netip.Prefix `json:"routes"` -} - -// SetRoutes updates the list of subnets that are enabled for a device. -// Subnets must be parsable by net/netip.ParsePrefix. -// Subnets do not have to be currently advertised by a device, they may be pre-enabled. -// Returns the updated list of enabled and advertised subnet routes in a *Routes object. -func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetRoutes: %w", err) - } - }() - params := &postRoutesParams{Routes: subnets} - data, err := json.Marshal(params) - if err != nil { - return nil, err - } - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var srr *Routes - if err := json.Unmarshal(b, &srr); err != nil { - return nil, err - } - return srr, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/netip" +) + +// Routes contains the lists of subnet routes that are currently advertised by a device, +// as well as the subnets that are enabled to be routed by the device. +type Routes struct { + AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` + EnabledRoutes []netip.Prefix `json:"enabledRoutes"` +} + +// Routes retrieves the list of subnet routes that have been enabled for a device. +// The routes that are returned are not necessarily advertised by the device, +// they have only been preapproved. +func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.Routes: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var sr Routes + err = json.Unmarshal(b, &sr) + return &sr, err +} + +type postRoutesParams struct { + Routes []netip.Prefix `json:"routes"` +} + +// SetRoutes updates the list of subnets that are enabled for a device. +// Subnets must be parsable by net/netip.ParsePrefix. +// Subnets do not have to be currently advertised by a device, they may be pre-enabled. +// Returns the updated list of enabled and advertised subnet routes in a *Routes object. +func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetRoutes: %w", err) + } + }() + params := &postRoutesParams{Routes: subnets} + data, err := json.Marshal(params) + if err != nil { + return nil, err + } + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var srr *Routes + if err := json.Unmarshal(b, &srr); err != nil { + return nil, err + } + return srr, err +} diff --git a/client/tailscale/tailnet.go b/client/tailscale/tailnet.go index 2539e7f235b0e..eef2dca2014ad 100644 --- a/client/tailscale/tailnet.go +++ b/client/tailscale/tailnet.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "context" - "fmt" - "net/http" - "net/url" - - "tailscale.com/util/httpm" -) - -// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. -func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) - req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) - if err != nil { - return err - } - - c.setAuth(req) - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "tailscale.com/util/httpm" +) + +// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. +func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) + req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) + if err != nil { + return err + } + + c.setAuth(req) + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + + return nil +} diff --git a/client/web/qnap.go b/client/web/qnap.go index 9bde64bf5885b..8fa5ee174bae6 100644 --- a/client/web/qnap.go +++ b/client/web/qnap.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// qnap.go contains handlers and logic, such as authentication, -// that is specific to running the web client on QNAP. - -package web - -import ( - "crypto/tls" - "encoding/xml" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" -) - -// authorizeQNAP authenticates the logged-in QNAP user and verifies that they -// are authorized to use the web client. -// If the user is not authorized to use the client, an error is returned. -func authorizeQNAP(r *http.Request) (authorized bool, err error) { - _, resp, err := qnapAuthn(r) - if err != nil { - return false, err - } - if resp.IsAdmin == 0 { - return false, errors.New("user is not an admin") - } - - return true, nil -} - -type qnapAuthResponse struct { - AuthPassed int `xml:"authPassed"` - IsAdmin int `xml:"isAdmin"` - AuthSID string `xml:"authSid"` - ErrorValue int `xml:"errorValue"` -} - -func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { - user, err := r.Cookie("NAS_USER") - if err != nil { - return "", nil, err - } - token, err := r.Cookie("qtoken") - if err == nil { - return qnapAuthnQtoken(r, user.Value, token.Value) - } - sid, err := r.Cookie("NAS_SID") - if err == nil { - return qnapAuthnSid(r, user.Value, sid.Value) - } - return "", nil, fmt.Errorf("not authenticated by any mechanism") -} - -// qnapAuthnURL returns the auth URL to use by inferring where the UI is -// running based on the request URL. This is necessary because QNAP has so -// many options, see https://github.com/tailscale/tailscale/issues/7108 -// and https://github.com/tailscale/tailscale/issues/6903 -func qnapAuthnURL(requestUrl string, query url.Values) string { - in, err := url.Parse(requestUrl) - scheme := "" - host := "" - if err != nil || in.Scheme == "" { - log.Printf("Cannot parse QNAP login URL %v", err) - - // try localhost and hope for the best - scheme = "http" - host = "localhost" - } else { - scheme = in.Scheme - host = in.Host - } - - u := url.URL{ - Scheme: scheme, - Host: host, - Path: "/cgi-bin/authLogin.cgi", - RawQuery: query.Encode(), - } - - return u.String() -} - -func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "qtoken": []string{token}, - "user": []string{user}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "sid": []string{sid}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { - // QNAP Force HTTPS mode uses a self-signed certificate. Even importing - // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor - // SAN. See https://github.com/tailscale/tailscale/issues/6903 - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{Transport: tr} - resp, err := client.Get(url) - if err != nil { - return "", nil, err - } - defer resp.Body.Close() - out, err := io.ReadAll(resp.Body) - if err != nil { - return "", nil, err - } - authResp := &qnapAuthResponse{} - if err := xml.Unmarshal(out, authResp); err != nil { - return "", nil, err - } - if authResp.AuthPassed == 0 { - return "", nil, fmt.Errorf("not authenticated") - } - return user, authResp, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// qnap.go contains handlers and logic, such as authentication, +// that is specific to running the web client on QNAP. + +package web + +import ( + "crypto/tls" + "encoding/xml" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" +) + +// authorizeQNAP authenticates the logged-in QNAP user and verifies that they +// are authorized to use the web client. +// If the user is not authorized to use the client, an error is returned. +func authorizeQNAP(r *http.Request) (authorized bool, err error) { + _, resp, err := qnapAuthn(r) + if err != nil { + return false, err + } + if resp.IsAdmin == 0 { + return false, errors.New("user is not an admin") + } + + return true, nil +} + +type qnapAuthResponse struct { + AuthPassed int `xml:"authPassed"` + IsAdmin int `xml:"isAdmin"` + AuthSID string `xml:"authSid"` + ErrorValue int `xml:"errorValue"` +} + +func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { + user, err := r.Cookie("NAS_USER") + if err != nil { + return "", nil, err + } + token, err := r.Cookie("qtoken") + if err == nil { + return qnapAuthnQtoken(r, user.Value, token.Value) + } + sid, err := r.Cookie("NAS_SID") + if err == nil { + return qnapAuthnSid(r, user.Value, sid.Value) + } + return "", nil, fmt.Errorf("not authenticated by any mechanism") +} + +// qnapAuthnURL returns the auth URL to use by inferring where the UI is +// running based on the request URL. This is necessary because QNAP has so +// many options, see https://github.com/tailscale/tailscale/issues/7108 +// and https://github.com/tailscale/tailscale/issues/6903 +func qnapAuthnURL(requestUrl string, query url.Values) string { + in, err := url.Parse(requestUrl) + scheme := "" + host := "" + if err != nil || in.Scheme == "" { + log.Printf("Cannot parse QNAP login URL %v", err) + + // try localhost and hope for the best + scheme = "http" + host = "localhost" + } else { + scheme = in.Scheme + host = in.Host + } + + u := url.URL{ + Scheme: scheme, + Host: host, + Path: "/cgi-bin/authLogin.cgi", + RawQuery: query.Encode(), + } + + return u.String() +} + +func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "qtoken": []string{token}, + "user": []string{user}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "sid": []string{sid}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { + // QNAP Force HTTPS mode uses a self-signed certificate. Even importing + // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor + // SAN. See https://github.com/tailscale/tailscale/issues/6903 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + resp, err := client.Get(url) + if err != nil { + return "", nil, err + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, err + } + authResp := &qnapAuthResponse{} + if err := xml.Unmarshal(out, authResp); err != nil { + return "", nil, err + } + if authResp.AuthPassed == 0 { + return "", nil, fmt.Errorf("not authenticated") + } + return user, authResp, nil +} diff --git a/client/web/src/assets/icons/arrow-right.svg b/client/web/src/assets/icons/arrow-right.svg index fbc4bb7ae3b7a..0a32ef4844395 100644 --- a/client/web/src/assets/icons/arrow-right.svg +++ b/client/web/src/assets/icons/arrow-right.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/arrow-up-circle.svg b/client/web/src/assets/icons/arrow-up-circle.svg index e9d009eb6bf65..e64c836be71c9 100644 --- a/client/web/src/assets/icons/arrow-up-circle.svg +++ b/client/web/src/assets/icons/arrow-up-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/src/assets/icons/check-circle.svg b/client/web/src/assets/icons/check-circle.svg index 4daeed514d1ff..6c5ee519e6d35 100644 --- a/client/web/src/assets/icons/check-circle.svg +++ b/client/web/src/assets/icons/check-circle.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/check.svg b/client/web/src/assets/icons/check.svg index efa11685d772c..70027536a6960 100644 --- a/client/web/src/assets/icons/check.svg +++ b/client/web/src/assets/icons/check.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/chevron-down.svg b/client/web/src/assets/icons/chevron-down.svg index afc98f255d4e5..993744c2fa287 100644 --- a/client/web/src/assets/icons/chevron-down.svg +++ b/client/web/src/assets/icons/chevron-down.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/eye.svg b/client/web/src/assets/icons/eye.svg index b0b21ed3f701c..e277674777814 100644 --- a/client/web/src/assets/icons/eye.svg +++ b/client/web/src/assets/icons/eye.svg @@ -1,11 +1,11 @@ - - - - - - - - - - - + + + + + + + + + + + diff --git a/client/web/src/assets/icons/search.svg b/client/web/src/assets/icons/search.svg index 782cd90eee1d8..08eb2d3dc3b8f 100644 --- a/client/web/src/assets/icons/search.svg +++ b/client/web/src/assets/icons/search.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/tailscale-icon.svg b/client/web/src/assets/icons/tailscale-icon.svg index d6052fe5e7cd6..de3c975ce1d53 100644 --- a/client/web/src/assets/icons/tailscale-icon.svg +++ b/client/web/src/assets/icons/tailscale-icon.svg @@ -1,18 +1,18 @@ - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/tailscale-logo.svg b/client/web/src/assets/icons/tailscale-logo.svg index 6d5c7ce0caae3..94a9cc4ee906e 100644 --- a/client/web/src/assets/icons/tailscale-logo.svg +++ b/client/web/src/assets/icons/tailscale-logo.svg @@ -1,20 +1,20 @@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/user.svg b/client/web/src/assets/icons/user.svg index 29d86f0499956..7fa3d26034d8c 100644 --- a/client/web/src/assets/icons/user.svg +++ b/client/web/src/assets/icons/user.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/x-circle.svg b/client/web/src/assets/icons/x-circle.svg index 49afc5a0366fe..d6259c9177672 100644 --- a/client/web/src/assets/icons/x-circle.svg +++ b/client/web/src/assets/icons/x-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/synology.go b/client/web/synology.go index 922489d78af16..5480263834893 100644 --- a/client/web/synology.go +++ b/client/web/synology.go @@ -1,59 +1,59 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// synology.go contains handlers and logic, such as authentication, -// that is specific to running the web client on Synology. - -package web - -import ( - "errors" - "fmt" - "net/http" - "os/exec" - "strings" - - "tailscale.com/util/groupmember" -) - -// authorizeSynology authenticates the logged-in Synology user and verifies -// that they are authorized to use the web client. -// If the user is authenticated, but not authorized to use the client, an error is returned. -func authorizeSynology(r *http.Request) (authorized bool, err error) { - if !hasSynoToken(r) { - return false, nil - } - - // authenticate the Synology user - cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") - out, err := cmd.CombinedOutput() - if err != nil { - return false, fmt.Errorf("auth: %v: %s", err, out) - } - user := strings.TrimSpace(string(out)) - - // check if the user is in the administrators group - isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) - if err != nil { - return false, err - } - if !isAdmin { - return false, errors.New("not a member of administrators group") - } - - return true, nil -} - -// hasSynoToken returns true if the request include a SynoToken used for synology auth. -func hasSynoToken(r *http.Request) bool { - if r.Header.Get("X-Syno-Token") != "" { - return true - } - if r.URL.Query().Get("SynoToken") != "" { - return true - } - if r.Method == "POST" && r.FormValue("SynoToken") != "" { - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// synology.go contains handlers and logic, such as authentication, +// that is specific to running the web client on Synology. + +package web + +import ( + "errors" + "fmt" + "net/http" + "os/exec" + "strings" + + "tailscale.com/util/groupmember" +) + +// authorizeSynology authenticates the logged-in Synology user and verifies +// that they are authorized to use the web client. +// If the user is authenticated, but not authorized to use the client, an error is returned. +func authorizeSynology(r *http.Request) (authorized bool, err error) { + if !hasSynoToken(r) { + return false, nil + } + + // authenticate the Synology user + cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") + out, err := cmd.CombinedOutput() + if err != nil { + return false, fmt.Errorf("auth: %v: %s", err, out) + } + user := strings.TrimSpace(string(out)) + + // check if the user is in the administrators group + isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) + if err != nil { + return false, err + } + if !isAdmin { + return false, errors.New("not a member of administrators group") + } + + return true, nil +} + +// hasSynoToken returns true if the request include a SynoToken used for synology auth. +func hasSynoToken(r *http.Request) bool { + if r.Header.Get("X-Syno-Token") != "" { + return true + } + if r.URL.Query().Get("SynoToken") != "" { + return true + } + if r.Method == "POST" && r.FormValue("SynoToken") != "" { + return true + } + return false +} diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index eba4b9267b119..aae6201539c59 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -1,486 +1,486 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package distsign implements signature and validation of arbitrary -// distributable files. -// -// There are 3 parties in this exchange: -// - builder, which creates files, signs them with signing keys and publishes -// to server -// - server, which distributes public signing keys, files and signatures -// - client, which downloads files and signatures from server, and validates -// the signatures -// -// There are 2 types of keys: -// - signing keys, that sign individual distributable files on the builder -// - root keys, that sign signing keys and are kept offline -// -// root keys -(sign)-> signing keys -(sign)-> files -// -// All keys are asymmetric Ed25519 key pairs. -// -// The server serves static files under some known prefix. The kinds of files are: -// - distsign.pub - bundle of PEM-encoded public signing keys -// - distsign.pub.sig - signature of distsign.pub using one of the root keys -// - $file - any distributable file -// - $file.sig - signature of $file using any of the signing keys -// -// The root public keys are baked into the client software at compile time. -// These keys are long-lived and prove the validity of current signing keys -// from distsign.pub. To rotate root keys, a new client release must be -// published, they are not rotated dynamically. There are multiple root keys in -// different locations specifically to allow this rotation without using the -// discarded root key for any new signatures. -// -// The signing public keys are fetched by the client dynamically before every -// download and can be rotated more readily, assuming that most deployed -// clients trust the root keys used to issue fresh signing keys. -package distsign - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "encoding/pem" - "errors" - "fmt" - "hash" - "io" - "log" - "net/http" - "net/url" - "os" - "time" - - "github.com/hdevalence/ed25519consensus" - "golang.org/x/crypto/blake2s" - "tailscale.com/net/tshttpproxy" - "tailscale.com/types/logger" - "tailscale.com/util/httpm" - "tailscale.com/util/must" -) - -const ( - pemTypeRootPrivate = "ROOT PRIVATE KEY" - pemTypeRootPublic = "ROOT PUBLIC KEY" - pemTypeSigningPrivate = "SIGNING PRIVATE KEY" - pemTypeSigningPublic = "SIGNING PUBLIC KEY" - - downloadSizeLimit = 1 << 29 // 512MB - signingKeysSizeLimit = 1 << 20 // 1MB - signatureSizeLimit = ed25519.SignatureSize -) - -// RootKey is a root key used to sign signing keys. -type RootKey struct { - k ed25519.PrivateKey -} - -// GenerateRootKey generates a new root key pair and encodes it as PEM. -func GenerateRootKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseRootKey parses the PEM-encoded private root key. The key must be in the -// same format as returned by GenerateRootKey. -func ParseRootKey(privKey []byte) (*RootKey, error) { - k, err := parsePrivateKey(privKey, pemTypeRootPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &RootKey{k: k}, nil -} - -// SignSigningKeys signs the bundle of public signing keys. The bundle must be -// a sequence of PEM blocks joined with newlines. -func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { - if _, err := ParseSigningKeyBundle(pubBundle); err != nil { - return nil, err - } - return ed25519.Sign(r.k, pubBundle), nil -} - -// SigningKey is a signing key used to sign packages. -type SigningKey struct { - k ed25519.PrivateKey -} - -// GenerateSigningKey generates a new signing key pair and encodes it as PEM. -func GenerateSigningKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseSigningKey parses the PEM-encoded private signing key. The key must be -// in the same format as returned by GenerateSigningKey. -func ParseSigningKey(privKey []byte) (*SigningKey, error) { - k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &SigningKey{k: k}, nil -} - -// SignPackageHash signs the hash and the length of a package. Use PackageHash -// to compute the inputs. -func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { - if len <= 0 { - return nil, fmt.Errorf("package length must be positive, got %d", len) - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - return ed25519.Sign(s.k, msg), nil -} - -// PackageHash is a hash.Hash that counts the number of bytes written. Use it -// to get the hash and length inputs to SigningKey.SignPackageHash. -type PackageHash struct { - hash.Hash - len int64 -} - -// NewPackageHash returns an initialized PackageHash using BLAKE2s. -func NewPackageHash() *PackageHash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen with a nil key passed to blake2s. - panic(err) - } - return &PackageHash{Hash: h} -} - -func (ph *PackageHash) Write(b []byte) (int, error) { - ph.len += int64(len(b)) - return ph.Hash.Write(b) -} - -// Reset the PackageHash to its initial state. -func (ph *PackageHash) Reset() { - ph.len = 0 - ph.Hash.Reset() -} - -// Len returns the total number of bytes written. -func (ph *PackageHash) Len() int64 { return ph.len } - -// Client downloads and validates files from a distribution server. -type Client struct { - logf logger.Logf - roots []ed25519.PublicKey - pkgsAddr *url.URL -} - -// NewClient returns a new client for distribution server located at pkgsAddr, -// and uses embedded root keys from the roots/ subdirectory of this package. -func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { - if logf == nil { - logf = log.Printf - } - u, err := url.Parse(pkgsAddr) - if err != nil { - return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) - } - return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil -} - -func (c *Client) url(path string) string { - return c.pkgsAddr.JoinPath(path).String() -} - -// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. -// The file is downloaded to dstPath and its signature is validated using the -// embedded root keys. Download returns an error if anything goes wrong with -// the actual file download or with signature validation. -func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcPath) - sigURL := srcURL + ".sig" - - c.logf("Downloading %q", srcURL) - dstPathUnverified := dstPath + ".unverified" - hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) - if err != nil { - return err - } - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return err - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - if !VerifyAny(sigPub, msg, sig) { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) - } - c.logf("Signature OK") - - if err := os.Rename(dstPathUnverified, dstPath); err != nil { - return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) - } - - return nil -} - -// ValidateLocalBinary fetches the latest signature associated with the binary -// at srcURLPath and uses it to validate the file located on disk via -// localFilePath. ValidateLocalBinary returns an error if anything goes wrong -// with the signature download or with signature validation. -func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcURLPath) - sigURL := srcURL + ".sig" - - localFile, err := os.Open(localFilePath) - if err != nil { - return err - } - defer localFile.Close() - - h := NewPackageHash() - _, err = io.Copy(h, localFile) - if err != nil { - return err - } - hash, hashLen := h.Sum(nil), h.Len() - - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return err - } - - msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) - if !VerifyAny(sigPub, msg, sig) { - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) - } - c.logf("Signature OK") - - return nil -} - -// signingKeys fetches current signing keys from the server and validates them -// against the roots. Should be called before validation of any downloaded file -// to get the fresh keys. -func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { - keyURL := c.url("distsign.pub") - sigURL := keyURL + ".sig" - raw, err := fetch(keyURL, signingKeysSizeLimit) - if err != nil { - return nil, err - } - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return nil, err - } - if !VerifyAny(c.roots, raw, sig) { - return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) - } - - keys, err := ParseSigningKeyBundle(raw) - if err != nil { - return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) - } - return keys, nil -} - -// fetch reads the response body from url into memory, up to limit bytes. -func fetch(url string, limit int64) ([]byte, error) { - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(io.LimitReader(resp.Body, limit)) -} - -// download writes the response body of url into a local file at dst, up to -// limit bytes. On success, the returned value is a BLAKE2s hash of the file. -func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) - - res, err := hc.Do(headReq) - if err != nil { - return nil, 0, err - } - if res.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) - } - if res.ContentLength <= 0 { - return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) - } - c.logf("Download size: %v", res.ContentLength) - - dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) - dlRes, err := hc.Do(dlReq) - if err != nil { - return nil, 0, err - } - defer dlRes.Body.Close() - // TODO(bradfitz): resume from existing partial file on disk - if dlRes.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) - } - - of, err := os.Create(dst) - if err != nil { - return nil, 0, err - } - defer of.Close() - pw := &progressWriter{total: res.ContentLength, logf: c.logf} - h := NewPackageHash() - n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) - if err != nil { - return nil, n, err - } - if n != res.ContentLength { - return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) - } - if err := dlRes.Body.Close(); err != nil { - return nil, n, err - } - if err := of.Close(); err != nil { - return nil, n, err - } - pw.print() - - return h.Sum(nil), h.Len(), nil -} - -type progressWriter struct { - done int64 - total int64 - lastPrint time.Time - logf logger.Logf -} - -func (pw *progressWriter) Write(p []byte) (n int, err error) { - pw.done += int64(len(p)) - if time.Since(pw.lastPrint) > 2*time.Second { - pw.print() - } - return len(p), nil -} - -func (pw *progressWriter) print() { - pw.lastPrint = time.Now() - pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) -} - -func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, errors.New("failed to decode PEM data") - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - if b.Type != typeTag { - return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PrivateKeySize { - return nil, errors.New("private key has incorrect length for an Ed25519 private key") - } - return ed25519.PrivateKey(b.Bytes), nil -} - -// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. -func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeSigningPublic) -} - -// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. -func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeRootPublic) -} - -func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { - var keys []ed25519.PublicKey - for len(bundle) > 0 { - pub, rest, err := parsePublicKey(bundle, typeTag) - if err != nil { - return nil, err - } - keys = append(keys, pub) - bundle = rest - } - if len(keys) == 0 { - return nil, errors.New("no signing keys found in the bundle") - } - return keys, nil -} - -func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { - pub, rest, err := parsePublicKey(data, typeTag) - if err != nil { - return nil, err - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - return pub, err -} - -func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, nil, errors.New("failed to decode PEM data") - } - if b.Type != typeTag { - return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PublicKeySize { - return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") - } - return ed25519.PublicKey(b.Bytes), rest, nil -} - -// VerifyAny verifies whether sig is valid for msg using any of the keys. -// VerifyAny will panic if any of the keys have the wrong size for Ed25519. -func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { - for _, k := range keys { - if ed25519consensus.Verify(k, msg, sig) { - return true - } - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package distsign implements signature and validation of arbitrary +// distributable files. +// +// There are 3 parties in this exchange: +// - builder, which creates files, signs them with signing keys and publishes +// to server +// - server, which distributes public signing keys, files and signatures +// - client, which downloads files and signatures from server, and validates +// the signatures +// +// There are 2 types of keys: +// - signing keys, that sign individual distributable files on the builder +// - root keys, that sign signing keys and are kept offline +// +// root keys -(sign)-> signing keys -(sign)-> files +// +// All keys are asymmetric Ed25519 key pairs. +// +// The server serves static files under some known prefix. The kinds of files are: +// - distsign.pub - bundle of PEM-encoded public signing keys +// - distsign.pub.sig - signature of distsign.pub using one of the root keys +// - $file - any distributable file +// - $file.sig - signature of $file using any of the signing keys +// +// The root public keys are baked into the client software at compile time. +// These keys are long-lived and prove the validity of current signing keys +// from distsign.pub. To rotate root keys, a new client release must be +// published, they are not rotated dynamically. There are multiple root keys in +// different locations specifically to allow this rotation without using the +// discarded root key for any new signatures. +// +// The signing public keys are fetched by the client dynamically before every +// download and can be rotated more readily, assuming that most deployed +// clients trust the root keys used to issue fresh signing keys. +package distsign + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "log" + "net/http" + "net/url" + "os" + "time" + + "github.com/hdevalence/ed25519consensus" + "golang.org/x/crypto/blake2s" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/logger" + "tailscale.com/util/httpm" + "tailscale.com/util/must" +) + +const ( + pemTypeRootPrivate = "ROOT PRIVATE KEY" + pemTypeRootPublic = "ROOT PUBLIC KEY" + pemTypeSigningPrivate = "SIGNING PRIVATE KEY" + pemTypeSigningPublic = "SIGNING PUBLIC KEY" + + downloadSizeLimit = 1 << 29 // 512MB + signingKeysSizeLimit = 1 << 20 // 1MB + signatureSizeLimit = ed25519.SignatureSize +) + +// RootKey is a root key used to sign signing keys. +type RootKey struct { + k ed25519.PrivateKey +} + +// GenerateRootKey generates a new root key pair and encodes it as PEM. +func GenerateRootKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseRootKey parses the PEM-encoded private root key. The key must be in the +// same format as returned by GenerateRootKey. +func ParseRootKey(privKey []byte) (*RootKey, error) { + k, err := parsePrivateKey(privKey, pemTypeRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &RootKey{k: k}, nil +} + +// SignSigningKeys signs the bundle of public signing keys. The bundle must be +// a sequence of PEM blocks joined with newlines. +func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { + if _, err := ParseSigningKeyBundle(pubBundle); err != nil { + return nil, err + } + return ed25519.Sign(r.k, pubBundle), nil +} + +// SigningKey is a signing key used to sign packages. +type SigningKey struct { + k ed25519.PrivateKey +} + +// GenerateSigningKey generates a new signing key pair and encodes it as PEM. +func GenerateSigningKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseSigningKey parses the PEM-encoded private signing key. The key must be +// in the same format as returned by GenerateSigningKey. +func ParseSigningKey(privKey []byte) (*SigningKey, error) { + k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &SigningKey{k: k}, nil +} + +// SignPackageHash signs the hash and the length of a package. Use PackageHash +// to compute the inputs. +func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { + if len <= 0 { + return nil, fmt.Errorf("package length must be positive, got %d", len) + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + return ed25519.Sign(s.k, msg), nil +} + +// PackageHash is a hash.Hash that counts the number of bytes written. Use it +// to get the hash and length inputs to SigningKey.SignPackageHash. +type PackageHash struct { + hash.Hash + len int64 +} + +// NewPackageHash returns an initialized PackageHash using BLAKE2s. +func NewPackageHash() *PackageHash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen with a nil key passed to blake2s. + panic(err) + } + return &PackageHash{Hash: h} +} + +func (ph *PackageHash) Write(b []byte) (int, error) { + ph.len += int64(len(b)) + return ph.Hash.Write(b) +} + +// Reset the PackageHash to its initial state. +func (ph *PackageHash) Reset() { + ph.len = 0 + ph.Hash.Reset() +} + +// Len returns the total number of bytes written. +func (ph *PackageHash) Len() int64 { return ph.len } + +// Client downloads and validates files from a distribution server. +type Client struct { + logf logger.Logf + roots []ed25519.PublicKey + pkgsAddr *url.URL +} + +// NewClient returns a new client for distribution server located at pkgsAddr, +// and uses embedded root keys from the roots/ subdirectory of this package. +func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { + if logf == nil { + logf = log.Printf + } + u, err := url.Parse(pkgsAddr) + if err != nil { + return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) + } + return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil +} + +func (c *Client) url(path string) string { + return c.pkgsAddr.JoinPath(path).String() +} + +// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. +// The file is downloaded to dstPath and its signature is validated using the +// embedded root keys. Download returns an error if anything goes wrong with +// the actual file download or with signature validation. +func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcPath) + sigURL := srcURL + ".sig" + + c.logf("Downloading %q", srcURL) + dstPathUnverified := dstPath + ".unverified" + hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) + if err != nil { + return err + } + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return err + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + if !VerifyAny(sigPub, msg, sig) { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) + } + c.logf("Signature OK") + + if err := os.Rename(dstPathUnverified, dstPath); err != nil { + return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) + } + + return nil +} + +// ValidateLocalBinary fetches the latest signature associated with the binary +// at srcURLPath and uses it to validate the file located on disk via +// localFilePath. ValidateLocalBinary returns an error if anything goes wrong +// with the signature download or with signature validation. +func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcURLPath) + sigURL := srcURL + ".sig" + + localFile, err := os.Open(localFilePath) + if err != nil { + return err + } + defer localFile.Close() + + h := NewPackageHash() + _, err = io.Copy(h, localFile) + if err != nil { + return err + } + hash, hashLen := h.Sum(nil), h.Len() + + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return err + } + + msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) + if !VerifyAny(sigPub, msg, sig) { + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) + } + c.logf("Signature OK") + + return nil +} + +// signingKeys fetches current signing keys from the server and validates them +// against the roots. Should be called before validation of any downloaded file +// to get the fresh keys. +func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { + keyURL := c.url("distsign.pub") + sigURL := keyURL + ".sig" + raw, err := fetch(keyURL, signingKeysSizeLimit) + if err != nil { + return nil, err + } + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return nil, err + } + if !VerifyAny(c.roots, raw, sig) { + return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) + } + + keys, err := ParseSigningKeyBundle(raw) + if err != nil { + return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) + } + return keys, nil +} + +// fetch reads the response body from url into memory, up to limit bytes. +func fetch(url string, limit int64) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(io.LimitReader(resp.Body, limit)) +} + +// download writes the response body of url into a local file at dst, up to +// limit bytes. On success, the returned value is a BLAKE2s hash of the file. +func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) + + res, err := hc.Do(headReq) + if err != nil { + return nil, 0, err + } + if res.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) + } + if res.ContentLength <= 0 { + return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) + } + c.logf("Download size: %v", res.ContentLength) + + dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) + dlRes, err := hc.Do(dlReq) + if err != nil { + return nil, 0, err + } + defer dlRes.Body.Close() + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) + } + + of, err := os.Create(dst) + if err != nil { + return nil, 0, err + } + defer of.Close() + pw := &progressWriter{total: res.ContentLength, logf: c.logf} + h := NewPackageHash() + n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) + if err != nil { + return nil, n, err + } + if n != res.ContentLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) + } + if err := dlRes.Body.Close(); err != nil { + return nil, n, err + } + if err := of.Close(); err != nil { + return nil, n, err + } + pw.print() + + return h.Sum(nil), h.Len(), nil +} + +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time + logf logger.Logf +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second { + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) +} + +func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PrivateKeySize { + return nil, errors.New("private key has incorrect length for an Ed25519 private key") + } + return ed25519.PrivateKey(b.Bytes), nil +} + +// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. +func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeSigningPublic) +} + +// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. +func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeRootPublic) +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { + var keys []ed25519.PublicKey + for len(bundle) > 0 { + pub, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, pub) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no signing keys found in the bundle") + } + return keys, nil +} + +func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { + pub, rest, err := parsePublicKey(data, typeTag) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + return pub, err +} + +func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PublicKeySize { + return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") + } + return ed25519.PublicKey(b.Bytes), rest, nil +} + +// VerifyAny verifies whether sig is valid for msg using any of the keys. +// VerifyAny will panic if any of the keys have the wrong size for Ed25519. +func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { + for _, k := range keys { + if ed25519consensus.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/clientupdate/distsign/roots.go b/clientupdate/distsign/roots.go index d5b47b7b62e92..df86557979ecd 100644 --- a/clientupdate/distsign/roots.go +++ b/clientupdate/distsign/roots.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import ( - "crypto/ed25519" - "embed" - "errors" - "fmt" - "path" - "path/filepath" - "sync" -) - -//go:embed roots -var rootsFS embed.FS - -var roots = sync.OnceValue(func() []ed25519.PublicKey { - roots, err := parseRoots() - if err != nil { - panic(err) - } - return roots -}) - -func parseRoots() ([]ed25519.PublicKey, error) { - files, err := rootsFS.ReadDir("roots") - if err != nil { - return nil, err - } - var keys []ed25519.PublicKey - for _, f := range files { - if !f.Type().IsRegular() { - continue - } - if filepath.Ext(f.Name()) != ".pem" { - continue - } - raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) - if err != nil { - return nil, err - } - key, err := parseSinglePublicKey(raw, pemTypeRootPublic) - if err != nil { - return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) - } - keys = append(keys, key) - } - if len(keys) == 0 { - return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") - } - return keys, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import ( + "crypto/ed25519" + "embed" + "errors" + "fmt" + "path" + "path/filepath" + "sync" +) + +//go:embed roots +var rootsFS embed.FS + +var roots = sync.OnceValue(func() []ed25519.PublicKey { + roots, err := parseRoots() + if err != nil { + panic(err) + } + return roots +}) + +func parseRoots() ([]ed25519.PublicKey, error) { + files, err := rootsFS.ReadDir("roots") + if err != nil { + return nil, err + } + var keys []ed25519.PublicKey + for _, f := range files { + if !f.Type().IsRegular() { + continue + } + if filepath.Ext(f.Name()) != ".pem" { + continue + } + raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) + if err != nil { + return nil, err + } + key, err := parseSinglePublicKey(raw, pemTypeRootPublic) + if err != nil { + return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) + } + keys = append(keys, key) + } + if len(keys) == 0 { + return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") + } + return keys, nil +} diff --git a/clientupdate/distsign/roots/crawshaw-root.pem b/clientupdate/distsign/roots/crawshaw-root.pem index f80b9aec78b11..897a38295b6b0 100755 --- a/clientupdate/distsign/roots/crawshaw-root.pem +++ b/clientupdate/distsign/roots/crawshaw-root.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem index d5d6516ab0368..e2f937ed3b0d1 100644 --- a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem +++ b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots_test.go b/clientupdate/distsign/roots_test.go index 7a94529538ef1..ae0dfbc22d5bd 100644 --- a/clientupdate/distsign/roots_test.go +++ b/clientupdate/distsign/roots_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import "testing" - -func TestParseRoots(t *testing.T) { - roots, err := parseRoots() - if err != nil { - t.Fatal(err) - } - if len(roots) == 0 { - t.Error("parseRoots returned no root keys") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import "testing" + +func TestParseRoots(t *testing.T) { + roots, err := parseRoots() + if err != nil { + t.Fatal(err) + } + if len(roots) == 0 { + t.Error("parseRoots returned no root keys") + } +} diff --git a/cmd/addlicense/main.go b/cmd/addlicense/main.go index a8fd9dd4ab96a..58ef7a4711c93 100644 --- a/cmd/addlicense/main.go +++ b/cmd/addlicense/main.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program addlicense adds a license header to a file. -// It is intended for use with 'go generate', -// so it has a slightly weird usage. -package main - -import ( - "flag" - "fmt" - "os" - "os/exec" -) - -var ( - file = flag.String("file", "", "file to modify") -) - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: addlicense -file FILE -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` -addlicense adds a Tailscale license to the beginning of file. - -It is intended for use with 'go generate', so it also runs a subcommand, -which presumably creates the file. - -Sample usage: - -addlicense -file pull_strings.go stringer -type=pull -`[1:]) - os.Exit(2) -} - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - } - cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err := cmd.Run() - check(err) - b, err := os.ReadFile(*file) - check(err) - f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) - check(err) - _, err = fmt.Fprint(f, license) - check(err) - _, err = f.Write(b) - check(err) - err = f.Close() - check(err) -} - -func check(err error) { - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} - -var license = ` -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -`[1:] +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program addlicense adds a license header to a file. +// It is intended for use with 'go generate', +// so it has a slightly weird usage. +package main + +import ( + "flag" + "fmt" + "os" + "os/exec" +) + +var ( + file = flag.String("file", "", "file to modify") +) + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: addlicense -file FILE +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` +addlicense adds a Tailscale license to the beginning of file. + +It is intended for use with 'go generate', so it also runs a subcommand, +which presumably creates the file. + +Sample usage: + +addlicense -file pull_strings.go stringer -type=pull +`[1:]) + os.Exit(2) +} + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + } + cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + check(err) + b, err := os.ReadFile(*file) + check(err) + f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) + check(err) + _, err = fmt.Fprint(f, license) + check(err) + _, err = f.Write(b) + check(err) + err = f.Close() + check(err) +} + +func check(err error) { + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +var license = ` +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +`[1:] diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index d8d5df3cb040c..83d33ab0e615b 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "reflect" - "testing" - - "tailscale.com/cmd/cloner/clonerex" -) - -func TestSliceContainer(t *testing.T) { - num := 5 - examples := []struct { - name string - in *clonerex.SliceContainer - }{ - { - name: "nil", - in: nil, - }, - { - name: "zero", - in: &clonerex.SliceContainer{}, - }, - { - name: "empty", - in: &clonerex.SliceContainer{ - Slice: []*int{}, - }, - }, - { - name: "nils", - in: &clonerex.SliceContainer{ - Slice: []*int{nil, nil, nil, nil, nil}, - }, - }, - { - name: "one", - in: &clonerex.SliceContainer{ - Slice: []*int{&num}, - }, - }, - { - name: "several", - in: &clonerex.SliceContainer{ - Slice: []*int{&num, &num, &num, &num, &num}, - }, - }, - } - - for _, ex := range examples { - t.Run(ex.name, func(t *testing.T) { - out := ex.in.Clone() - if !reflect.DeepEqual(ex.in, out) { - t.Errorf("Clone() = %v, want %v", out, ex.in) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "reflect" + "testing" + + "tailscale.com/cmd/cloner/clonerex" +) + +func TestSliceContainer(t *testing.T) { + num := 5 + examples := []struct { + name string + in *clonerex.SliceContainer + }{ + { + name: "nil", + in: nil, + }, + { + name: "zero", + in: &clonerex.SliceContainer{}, + }, + { + name: "empty", + in: &clonerex.SliceContainer{ + Slice: []*int{}, + }, + }, + { + name: "nils", + in: &clonerex.SliceContainer{ + Slice: []*int{nil, nil, nil, nil, nil}, + }, + }, + { + name: "one", + in: &clonerex.SliceContainer{ + Slice: []*int{&num}, + }, + }, + { + name: "several", + in: &clonerex.SliceContainer{ + Slice: []*int{&num, &num, &num, &num, &num}, + }, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + out := ex.in.Clone() + if !reflect.DeepEqual(ex.in, out) { + t.Errorf("Clone() = %v, want %v", out, ex.in) + } + }) + } +} diff --git a/cmd/containerboot/test_tailscale.sh b/cmd/containerboot/test_tailscale.sh index 1fa10abb18185..dd56adf044bd4 100644 --- a/cmd/containerboot/test_tailscale.sh +++ b/cmd/containerboot/test_tailscale.sh @@ -1,8 +1,8 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale CLI (and also iptables and ip6tables) that -# records its arguments and exits successfully. -# -# It is used by main_test.go to test the behavior of containerboot. - -echo $0 $@ >>$TS_TEST_RECORD_ARGS +#!/usr/bin/env bash +# +# This is a fake tailscale CLI (and also iptables and ip6tables) that +# records its arguments and exits successfully. +# +# It is used by main_test.go to test the behavior of containerboot. + +echo $0 $@ >>$TS_TEST_RECORD_ARGS diff --git a/cmd/containerboot/test_tailscaled.sh b/cmd/containerboot/test_tailscaled.sh index 335e2cb0dcfd1..b7404a0a9d368 100644 --- a/cmd/containerboot/test_tailscaled.sh +++ b/cmd/containerboot/test_tailscaled.sh @@ -1,38 +1,38 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale daemon that records its arguments, symlinks a -# fake LocalAPI socket into place, and does nothing until terminated. -# -# It is used by main_test.go to test the behavior of containerboot. - -set -eu - -echo $0 $@ >>$TS_TEST_RECORD_ARGS - -socket="" -while [[ $# -gt 0 ]]; do - case $1 in - --socket=*) - socket="${1#--socket=}" - shift - ;; - --socket) - shift - socket="$1" - shift - ;; - *) - shift - ;; - esac -done - -if [[ -z "$socket" ]]; then - echo "didn't find socket path in args" - exit 1 -fi - -ln -s "$TS_TEST_SOCKET" "$socket" -trap 'rm -f "$socket"' EXIT - -while sleep 10; do :; done +#!/usr/bin/env bash +# +# This is a fake tailscale daemon that records its arguments, symlinks a +# fake LocalAPI socket into place, and does nothing until terminated. +# +# It is used by main_test.go to test the behavior of containerboot. + +set -eu + +echo $0 $@ >>$TS_TEST_RECORD_ARGS + +socket="" +while [[ $# -gt 0 ]]; do + case $1 in + --socket=*) + socket="${1#--socket=}" + shift + ;; + --socket) + shift + socket="$1" + shift + ;; + *) + shift + ;; + esac +done + +if [[ -z "$socket" ]]; then + echo "didn't find socket path in args" + exit 1 +fi + +ln -s "$TS_TEST_SOCKET" "$socket" +trap 'rm -f "$socket"' EXIT + +while sleep 10; do :; done diff --git a/cmd/get-authkey/.gitignore b/cmd/get-authkey/.gitignore index 3f9c9fb90e68e..e00856fa12524 100644 --- a/cmd/get-authkey/.gitignore +++ b/cmd/get-authkey/.gitignore @@ -1 +1 @@ -get-authkey +get-authkey diff --git a/cmd/gitops-pusher/.gitignore b/cmd/gitops-pusher/.gitignore index 5044522494b23..eeed6e4bf5b1a 100644 --- a/cmd/gitops-pusher/.gitignore +++ b/cmd/gitops-pusher/.gitignore @@ -1 +1 @@ -version-cache.json +version-cache.json diff --git a/cmd/gitops-pusher/README.md b/cmd/gitops-pusher/README.md index 9f77ea970e033..b08125397a1ec 100644 --- a/cmd/gitops-pusher/README.md +++ b/cmd/gitops-pusher/README.md @@ -1,48 +1,48 @@ -# gitops-pusher - -This is a small tool to help people achieve a -[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL -changes. This tool is intended to be used in a CI flow that looks like this: - -```yaml -name: Tailscale ACL syncing - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -jobs: - acls: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Setup Go environment - uses: actions/setup-go@v3.2.0 - - - name: Install gitops-pusher - run: go install tailscale.com/cmd/gitops-pusher@latest - - - name: Deploy ACL - if: github.event_name == 'push' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply - - - name: ACL tests - if: github.event_name == 'pull_request' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson test -``` - -Change the value of the `--policy-file` flag to point to the policy file on -disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) -format. +# gitops-pusher + +This is a small tool to help people achieve a +[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL +changes. This tool is intended to be used in a CI flow that looks like this: + +```yaml +name: Tailscale ACL syncing + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + acls: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Setup Go environment + uses: actions/setup-go@v3.2.0 + + - name: Install gitops-pusher + run: go install tailscale.com/cmd/gitops-pusher@latest + + - name: Deploy ACL + if: github.event_name == 'push' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply + + - name: ACL tests + if: github.event_name == 'pull_request' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson test +``` + +Change the value of the `--policy-file` flag to point to the policy file on +disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) +format. diff --git a/cmd/gitops-pusher/cache.go b/cmd/gitops-pusher/cache.go index 6792e5e63e9cc..89225e6f86309 100644 --- a/cmd/gitops-pusher/cache.go +++ b/cmd/gitops-pusher/cache.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/json" - "os" -) - -// Cache contains cached information about the last time this tool was run. -// -// This is serialized to a JSON file that should NOT be checked into git. -// It should be managed with either CI cache tools or stored locally somehow. The -// exact mechanism is irrelevant as long as it is consistent. -// -// This allows gitops-pusher to detect external ACL changes. I'm not sure what to -// call this problem, so I've been calling it the "three version problem" in my -// notes. The basic problem is that at any given time we only have two versions -// of the ACL file at any given point. In order to check if there has been -// tampering of the ACL files in the admin panel, we need to have a _third_ version -// to compare against. -// -// In this case I am not storing the old ACL entirely (though that could be a -// reasonable thing to add in the future), but only its sha256sum. This allows -// us to detect if the shasum in control matches the shasum we expect, and if that -// expectation fails, then we can react accordingly. -type Cache struct { - PrevETag string // Stores the previous ETag of the ACL to allow -} - -// Save persists the cache to a given file. -func (c *Cache) Save(fname string) error { - os.Remove(fname) - fout, err := os.Create(fname) - if err != nil { - return err - } - defer fout.Close() - - return json.NewEncoder(fout).Encode(c) -} - -// LoadCache loads the cache from a given file. -func LoadCache(fname string) (*Cache, error) { - var result Cache - - fin, err := os.Open(fname) - if err != nil { - return nil, err - } - defer fin.Close() - - err = json.NewDecoder(fin).Decode(&result) - if err != nil { - return nil, err - } - - return &result, nil -} - -// Shuck removes the first and last character of a string, analogous to -// shucking off the husk of an ear of corn. -func Shuck(s string) string { - return s[1 : len(s)-1] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/json" + "os" +) + +// Cache contains cached information about the last time this tool was run. +// +// This is serialized to a JSON file that should NOT be checked into git. +// It should be managed with either CI cache tools or stored locally somehow. The +// exact mechanism is irrelevant as long as it is consistent. +// +// This allows gitops-pusher to detect external ACL changes. I'm not sure what to +// call this problem, so I've been calling it the "three version problem" in my +// notes. The basic problem is that at any given time we only have two versions +// of the ACL file at any given point. In order to check if there has been +// tampering of the ACL files in the admin panel, we need to have a _third_ version +// to compare against. +// +// In this case I am not storing the old ACL entirely (though that could be a +// reasonable thing to add in the future), but only its sha256sum. This allows +// us to detect if the shasum in control matches the shasum we expect, and if that +// expectation fails, then we can react accordingly. +type Cache struct { + PrevETag string // Stores the previous ETag of the ACL to allow +} + +// Save persists the cache to a given file. +func (c *Cache) Save(fname string) error { + os.Remove(fname) + fout, err := os.Create(fname) + if err != nil { + return err + } + defer fout.Close() + + return json.NewEncoder(fout).Encode(c) +} + +// LoadCache loads the cache from a given file. +func LoadCache(fname string) (*Cache, error) { + var result Cache + + fin, err := os.Open(fname) + if err != nil { + return nil, err + } + defer fin.Close() + + err = json.NewDecoder(fin).Decode(&result) + if err != nil { + return nil, err + } + + return &result, nil +} + +// Shuck removes the first and last character of a string, analogous to +// shucking off the husk of an ear of corn. +func Shuck(s string) string { + return s[1 : len(s)-1] +} diff --git a/cmd/gitops-pusher/gitops-pusher_test.go b/cmd/gitops-pusher/gitops-pusher_test.go index b050761d9832d..1beb049c67d5a 100644 --- a/cmd/gitops-pusher/gitops-pusher_test.go +++ b/cmd/gitops-pusher/gitops-pusher_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "encoding/json" - "strings" - "testing" - - "tailscale.com/client/tailscale" -) - -func TestEmbeddedTypeUnmarshal(t *testing.T) { - var gitopsErr ACLGitopsTestError - gitopsErr.Message = "gitops response error" - gitopsErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "GitopsError", - Errors: []string{"this was initially created as a gitops error"}, - }, - } - - var aclTestErr tailscale.ACLTestError - aclTestErr.Message = "native ACL response error" - aclTestErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "ACLError", - Errors: []string{"this was initially created as an ACL error"}, - }, - } - - t.Run("unmarshal gitops type from acl type", func(t *testing.T) { - b, _ := json.Marshal(aclTestErr) - var e ACLGitopsTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't - t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) - } - }) - t.Run("unmarshal acl type from gitops type", func(t *testing.T) { - b, _ := json.Marshal(gitopsErr) - var e tailscale.ACLTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` - if e.Error() != expectedErr { - t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "encoding/json" + "strings" + "testing" + + "tailscale.com/client/tailscale" +) + +func TestEmbeddedTypeUnmarshal(t *testing.T) { + var gitopsErr ACLGitopsTestError + gitopsErr.Message = "gitops response error" + gitopsErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "GitopsError", + Errors: []string{"this was initially created as a gitops error"}, + }, + } + + var aclTestErr tailscale.ACLTestError + aclTestErr.Message = "native ACL response error" + aclTestErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "ACLError", + Errors: []string{"this was initially created as an ACL error"}, + }, + } + + t.Run("unmarshal gitops type from acl type", func(t *testing.T) { + b, _ := json.Marshal(aclTestErr) + var e ACLGitopsTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't + t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) + } + }) + t.Run("unmarshal acl type from gitops type", func(t *testing.T) { + b, _ := json.Marshal(gitopsErr) + var e tailscale.ACLTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` + if e.Error() != expectedErr { + t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) + } + }) +} diff --git a/cmd/k8s-operator/deploy/chart/.helmignore b/cmd/k8s-operator/deploy/chart/.helmignore index 0e8a0eb36f4ca..f82e96d46779c 100644 --- a/cmd/k8s-operator/deploy/chart/.helmignore +++ b/cmd/k8s-operator/deploy/chart/.helmignore @@ -1,23 +1,23 @@ -# Patterns to ignore when building packages. -# This supports shell glob matching, relative path matching, and -# negation (prefixed with !). Only one pattern per line. -.DS_Store -# Common VCS dirs -.git/ -.gitignore -.bzr/ -.bzrignore -.hg/ -.hgignore -.svn/ -# Common backup files -*.swp -*.bak -*.tmp -*.orig -*~ -# Various IDEs -.project -.idea/ -*.tmproj -.vscode/ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/cmd/k8s-operator/deploy/chart/Chart.yaml b/cmd/k8s-operator/deploy/chart/Chart.yaml index 363d87d15954a..472850c415200 100644 --- a/cmd/k8s-operator/deploy/chart/Chart.yaml +++ b/cmd/k8s-operator/deploy/chart/Chart.yaml @@ -1,29 +1,29 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: v2 -name: tailscale-operator -description: A Helm chart for Tailscale Kubernetes operator -home: https://github.com/tailscale/tailscale - -keywords: - - "tailscale" - - "vpn" - - "ingress" - - "egress" - - "wireguard" - -sources: -- https://github.com/tailscale/tailscale - -type: application - -maintainers: - - name: tailscale-maintainers - url: https://tailscale.com/ - -# version will be set to Tailscale repo tag (without 'v') at release time. -version: 0.1.0 - -# appVersion will be set to Tailscale repo tag at release time. -appVersion: "unstable" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: v2 +name: tailscale-operator +description: A Helm chart for Tailscale Kubernetes operator +home: https://github.com/tailscale/tailscale + +keywords: + - "tailscale" + - "vpn" + - "ingress" + - "egress" + - "wireguard" + +sources: +- https://github.com/tailscale/tailscale + +type: application + +maintainers: + - name: tailscale-maintainers + url: https://tailscale.com/ + +# version will be set to Tailscale repo tag (without 'v') at release time. +version: 0.1.0 + +# appVersion will be set to Tailscale repo tag at release time. +appVersion: "unstable" diff --git a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml index 072ecf6d22e2f..488c87d8a09c5 100644 --- a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml @@ -1,26 +1,26 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if eq .Values.apiServerProxyConfig.mode "true" }} -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: {{ .Release.Namespace }} -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy - apiGroup: rbac.authorization.k8s.io -{{ end }} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if eq .Values.apiServerProxyConfig.mode "true" }} +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: {{ .Release.Namespace }} +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy + apiGroup: rbac.authorization.k8s.io +{{ end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index b44fde0a17b49..bde64b7f625eb 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if and .Values.oauth .Values.oauth.clientId -}} -apiVersion: v1 -kind: Secret -metadata: - name: operator-oauth - namespace: {{ .Release.Namespace }} -stringData: - client_id: {{ .Values.oauth.clientId }} - client_secret: {{ .Values.oauth.clientSecret }} -{{- end -}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if and .Values.oauth .Values.oauth.clientId -}} +apiVersion: v1 +kind: Secret +metadata: + name: operator-oauth + namespace: {{ .Release.Namespace }} +stringData: + client_id: {{ .Values.oauth.clientId }} + client_secret: {{ .Values.oauth.clientSecret }} +{{- end -}} diff --git a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml index ddbdda32e476e..d957260eb513f 100644 --- a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml @@ -1,24 +1,24 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: tailscale -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: tailscale +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/cmd/mkmanifest/main.go b/cmd/mkmanifest/main.go index fb3c729f12d21..22cd150262cbb 100644 --- a/cmd/mkmanifest/main.go +++ b/cmd/mkmanifest/main.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The mkmanifest command is a simple helper utility to create a '.syso' file -// that contains a Windows manifest file. -package main - -import ( - "log" - "os" - - "github.com/tc-hib/winres" -) - -func main() { - if len(os.Args) != 4 { - log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) - } - - arch := winres.Arch(os.Args[1]) - switch arch { - case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: - default: - log.Fatalf("unsupported arch: %s", arch) - } - - manifest, err := os.ReadFile(os.Args[2]) - if err != nil { - log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) - } - - out := os.Args[3] - - // Start by creating an empty resource set - rs := winres.ResourceSet{} - - // Add resources - rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) - - // Compile to a COFF object file - f, err := os.Create(out) - if err != nil { - log.Fatalf("error creating output file %q: %v", out, err) - } - if err := rs.WriteObject(f, arch); err != nil { - log.Fatalf("error writing object: %v", err) - } - if err := f.Close(); err != nil { - log.Fatalf("error writing output file %q: %v", out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The mkmanifest command is a simple helper utility to create a '.syso' file +// that contains a Windows manifest file. +package main + +import ( + "log" + "os" + + "github.com/tc-hib/winres" +) + +func main() { + if len(os.Args) != 4 { + log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) + } + + arch := winres.Arch(os.Args[1]) + switch arch { + case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: + default: + log.Fatalf("unsupported arch: %s", arch) + } + + manifest, err := os.ReadFile(os.Args[2]) + if err != nil { + log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) + } + + out := os.Args[3] + + // Start by creating an empty resource set + rs := winres.ResourceSet{} + + // Add resources + rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) + + // Compile to a COFF object file + f, err := os.Create(out) + if err != nil { + log.Fatalf("error creating output file %q: %v", out, err) + } + if err := rs.WriteObject(f, arch); err != nil { + log.Fatalf("error writing object: %v", err) + } + if err := f.Close(); err != nil { + log.Fatalf("error writing output file %q: %v", out, err) + } +} diff --git a/cmd/mkpkg/main.go b/cmd/mkpkg/main.go index 5e26b07f8f9f8..e942c0162a4fd 100644 --- a/cmd/mkpkg/main.go +++ b/cmd/mkpkg/main.go @@ -1,134 +1,134 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkpkg builds the Tailscale rpm and deb packages. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "strings" - - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" - "github.com/goreleaser/nfpm/v2/files" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -// parseFiles parses a comma-separated list of colon-separated pairs -// into files.Contents format. -func parseFiles(s string, typ string) (files.Contents, error) { - if len(s) == 0 { - return nil, nil - } - var contents files.Contents - for _, f := range strings.Split(s, ",") { - fs := strings.Split(f, ":") - if len(fs) != 2 { - return nil, fmt.Errorf("unparseable file field %q", f) - } - contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) - } - return contents, nil -} - -func parseEmptyDirs(s string) files.Contents { - // strings.Split("", ",") would return []string{""}, which is not suitable: - // this would create an empty dir record with path "", breaking the package - if s == "" { - return nil - } - var contents files.Contents - for _, d := range strings.Split(s, ",") { - contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) - } - return contents -} - -func main() { - out := flag.String("out", "", "output file to write") - name := flag.String("name", "tailscale", "package name") - description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") - goarch := flag.String("arch", "amd64", "GOARCH this package is for") - pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") - regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") - configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") - emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") - version := flag.String("version", "0.0.0", "version of the package") - postinst := flag.String("postinst", "", "debian postinst script path") - prerm := flag.String("prerm", "", "debian prerm script path") - postrm := flag.String("postrm", "", "debian postrm script path") - replaces := flag.String("replaces", "", "package which this package replaces, if any") - depends := flag.String("depends", "", "comma-separated list of packages this package depends on") - recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") - flag.Parse() - - filesList, err := parseFiles(*regularFiles, files.TypeFile) - if err != nil { - log.Fatalf("Parsing --files: %v", err) - } - configsList, err := parseFiles(*configFiles, files.TypeConfig) - if err != nil { - log.Fatalf("Parsing --configs: %v", err) - } - emptyDirList := parseEmptyDirs(*emptyDirs) - contents := append(filesList, append(configsList, emptyDirList...)...) - contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) - if err != nil { - log.Fatalf("Building package contents: %v", err) - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: *name, - Arch: *goarch, - Platform: "linux", - Version: *version, - Maintainer: "Tailscale Inc ", - Description: *description, - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: *postinst, - PreRemove: *prerm, - PostRemove: *postrm, - }, - }, - }) - - if len(*depends) != 0 { - info.Overridables.Depends = strings.Split(*depends, ",") - } - if len(*recommends) != 0 { - info.Overridables.Recommends = strings.Split(*recommends, ",") - } - if *replaces != "" { - info.Overridables.Replaces = []string{*replaces} - info.Overridables.Conflicts = []string{*replaces} - } - - switch *pkgType { - case "deb": - info.Section = "net" - info.Priority = "extra" - case "rpm": - info.Overridables.RPM.Group = "Network" - } - - pkg, err := nfpm.Get(*pkgType) - if err != nil { - log.Fatalf("Getting packager for %q: %v", *pkgType, err) - } - - f, err := os.Create(*out) - if err != nil { - log.Fatalf("Creating output file %q: %v", *out, err) - } - defer f.Close() - - if err := pkg.Package(info, f); err != nil { - log.Fatalf("Creating package %q: %v", *out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkpkg builds the Tailscale rpm and deb packages. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" + "github.com/goreleaser/nfpm/v2/files" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +// parseFiles parses a comma-separated list of colon-separated pairs +// into files.Contents format. +func parseFiles(s string, typ string) (files.Contents, error) { + if len(s) == 0 { + return nil, nil + } + var contents files.Contents + for _, f := range strings.Split(s, ",") { + fs := strings.Split(f, ":") + if len(fs) != 2 { + return nil, fmt.Errorf("unparseable file field %q", f) + } + contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) + } + return contents, nil +} + +func parseEmptyDirs(s string) files.Contents { + // strings.Split("", ",") would return []string{""}, which is not suitable: + // this would create an empty dir record with path "", breaking the package + if s == "" { + return nil + } + var contents files.Contents + for _, d := range strings.Split(s, ",") { + contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) + } + return contents +} + +func main() { + out := flag.String("out", "", "output file to write") + name := flag.String("name", "tailscale", "package name") + description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") + goarch := flag.String("arch", "amd64", "GOARCH this package is for") + pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") + regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") + configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") + emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") + version := flag.String("version", "0.0.0", "version of the package") + postinst := flag.String("postinst", "", "debian postinst script path") + prerm := flag.String("prerm", "", "debian prerm script path") + postrm := flag.String("postrm", "", "debian postrm script path") + replaces := flag.String("replaces", "", "package which this package replaces, if any") + depends := flag.String("depends", "", "comma-separated list of packages this package depends on") + recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") + flag.Parse() + + filesList, err := parseFiles(*regularFiles, files.TypeFile) + if err != nil { + log.Fatalf("Parsing --files: %v", err) + } + configsList, err := parseFiles(*configFiles, files.TypeConfig) + if err != nil { + log.Fatalf("Parsing --configs: %v", err) + } + emptyDirList := parseEmptyDirs(*emptyDirs) + contents := append(filesList, append(configsList, emptyDirList...)...) + contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) + if err != nil { + log.Fatalf("Building package contents: %v", err) + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: *name, + Arch: *goarch, + Platform: "linux", + Version: *version, + Maintainer: "Tailscale Inc ", + Description: *description, + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: *postinst, + PreRemove: *prerm, + PostRemove: *postrm, + }, + }, + }) + + if len(*depends) != 0 { + info.Overridables.Depends = strings.Split(*depends, ",") + } + if len(*recommends) != 0 { + info.Overridables.Recommends = strings.Split(*recommends, ",") + } + if *replaces != "" { + info.Overridables.Replaces = []string{*replaces} + info.Overridables.Conflicts = []string{*replaces} + } + + switch *pkgType { + case "deb": + info.Section = "net" + info.Priority = "extra" + case "rpm": + info.Overridables.RPM.Group = "Network" + } + + pkg, err := nfpm.Get(*pkgType) + if err != nil { + log.Fatalf("Getting packager for %q: %v", *pkgType, err) + } + + f, err := os.Create(*out) + if err != nil { + log.Fatalf("Creating output file %q: %v", *out, err) + } + defer f.Close() + + if err := pkg.Package(info, f); err != nil { + log.Fatalf("Creating package %q: %v", *out, err) + } +} diff --git a/cmd/mkversion/mkversion.go b/cmd/mkversion/mkversion.go index c8c8bf17930f6..6a6a18a50d090 100644 --- a/cmd/mkversion/mkversion.go +++ b/cmd/mkversion/mkversion.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkversion gets version info from git and outputs a bunch of shell variables -// that get used elsewhere in the build system to embed version numbers into -// binaries. -package main - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/version/mkversion" -) - -func main() { - prefix := "" - if len(os.Args) > 1 { - if os.Args[1] == "--export" { - prefix = "export " - } else { - fmt.Println("usage: mkversion [--export|-h|--help]") - os.Exit(1) - } - } - - var b bytes.Buffer - io.WriteString(&b, mkversion.Info().String()) - // Copyright and the client capability are not part of the version - // information, but similarly used in Xcode builds to embed in the metadata, - // thus generate them now. - copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) - fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) - fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) - s := bufio.NewScanner(&b) - for s.Scan() { - fmt.Println(prefix + s.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkversion gets version info from git and outputs a bunch of shell variables +// that get used elsewhere in the build system to embed version numbers into +// binaries. +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/version/mkversion" +) + +func main() { + prefix := "" + if len(os.Args) > 1 { + if os.Args[1] == "--export" { + prefix = "export " + } else { + fmt.Println("usage: mkversion [--export|-h|--help]") + os.Exit(1) + } + } + + var b bytes.Buffer + io.WriteString(&b, mkversion.Info().String()) + // Copyright and the client capability are not part of the version + // information, but similarly used in Xcode builds to embed in the metadata, + // thus generate them now. + copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) + fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) + fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) + s := bufio.NewScanner(&b) + for s.Scan() { + fmt.Println(prefix + s.Text()) + } +} diff --git a/cmd/nardump/README.md b/cmd/nardump/README.md index 6fa7fc2f1d345..6c73ff9b0f399 100644 --- a/cmd/nardump/README.md +++ b/cmd/nardump/README.md @@ -1,7 +1,7 @@ -# nardump - -nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, -but focused on being reproducible) to stdout or to a hash with the --sri flag. - -It lets us calculate the Nix sha256 in shell.nix without the person running -git-pull-oss.sh having Nix available. +# nardump + +nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, +but focused on being reproducible) to stdout or to a hash with the --sri flag. + +It lets us calculate the Nix sha256 in shell.nix without the person running +git-pull-oss.sh having Nix available. diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index 05be7b65a7e37..241475537c418 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// nardump is like nix-store --dump, but in Go, writing a NAR -// file (tar-like, but focused on being reproducible) to stdout -// or to a hash with the --sri flag. -// -// It lets us calculate a Nix sha256 without the person running -// git-pull-oss.sh having Nix available. -package main - -// For the format, see: -// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 - -import ( - "bufio" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "flag" - "fmt" - "io" - "io/fs" - "log" - "os" - "path" - "sort" -) - -var sri = flag.Bool("sri", false, "print SRI") - -func main() { - flag.Parse() - if flag.NArg() != 1 { - log.Fatal("usage: nardump ") - } - arg := flag.Arg(0) - if err := os.Chdir(arg); err != nil { - log.Fatal(err) - } - if *sri { - hash := sha256.New() - if err := writeNAR(hash, os.DirFS(".")); err != nil { - log.Fatal(err) - } - fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) - return - } - bw := bufio.NewWriter(os.Stdout) - if err := writeNAR(bw, os.DirFS(".")); err != nil { - log.Fatal(err) - } - bw.Flush() -} - -// writeNARError is a sentinel panic type that's recovered by writeNAR -// and converted into the wrapped error. -type writeNARError struct{ err error } - -// narWriter writes NAR files. -type narWriter struct { - w io.Writer - fs fs.FS -} - -// writeNAR writes a NAR file to w from the root of fs. -func writeNAR(w io.Writer, fs fs.FS) (err error) { - defer func() { - if e := recover(); e != nil { - if we, ok := e.(writeNARError); ok { - err = we.err - return - } - panic(e) - } - }() - nw := &narWriter{w: w, fs: fs} - nw.str("nix-archive-1") - return nw.writeDir(".") -} - -func (nw *narWriter) writeDir(dirPath string) error { - ents, err := fs.ReadDir(nw.fs, dirPath) - if err != nil { - return err - } - sort.Slice(ents, func(i, j int) bool { - return ents[i].Name() < ents[j].Name() - }) - nw.str("(") - nw.str("type") - nw.str("directory") - for _, ent := range ents { - nw.str("entry") - nw.str("(") - nw.str("name") - nw.str(ent.Name()) - nw.str("node") - mode := ent.Type() - sub := path.Join(dirPath, ent.Name()) - var err error - switch { - case mode.IsRegular(): - err = nw.writeRegular(sub) - case mode.IsDir(): - err = nw.writeDir(sub) - default: - // TODO(bradfitz): symlink, but requires fighting io/fs a bit - // to get at Readlink or the osFS via fs. But for now - // we don't need symlinks because they're not in Go's archive. - return fmt.Errorf("unsupported file type %v at %q", sub, mode) - } - if err != nil { - return err - } - nw.str(")") - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeRegular(path string) error { - nw.str("(") - nw.str("type") - nw.str("regular") - fi, err := fs.Stat(nw.fs, path) - if err != nil { - return err - } - if fi.Mode()&0111 != 0 { - nw.str("executable") - nw.str("") - } - contents, err := fs.ReadFile(nw.fs, path) - if err != nil { - return err - } - nw.str("contents") - if err := writeBytes(nw.w, contents); err != nil { - return err - } - nw.str(")") - return nil -} - -func (nw *narWriter) str(s string) { - if err := writeString(nw.w, s); err != nil { - panic(writeNARError{err}) - } -} - -func writeString(w io.Writer, s string) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := io.WriteString(w, s); err != nil { - return err - } - return writePad(w, len(s)) -} - -func writeBytes(w io.Writer, b []byte) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := w.Write(b); err != nil { - return err - } - return writePad(w, len(b)) -} - -func writePad(w io.Writer, n int) error { - pad := n % 8 - if pad == 0 { - return nil - } - var zeroes [8]byte - _, err := w.Write(zeroes[:8-pad]) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// nardump is like nix-store --dump, but in Go, writing a NAR +// file (tar-like, but focused on being reproducible) to stdout +// or to a hash with the --sri flag. +// +// It lets us calculate a Nix sha256 without the person running +// git-pull-oss.sh having Nix available. +package main + +// For the format, see: +// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 + +import ( + "bufio" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "flag" + "fmt" + "io" + "io/fs" + "log" + "os" + "path" + "sort" +) + +var sri = flag.Bool("sri", false, "print SRI") + +func main() { + flag.Parse() + if flag.NArg() != 1 { + log.Fatal("usage: nardump ") + } + arg := flag.Arg(0) + if err := os.Chdir(arg); err != nil { + log.Fatal(err) + } + if *sri { + hash := sha256.New() + if err := writeNAR(hash, os.DirFS(".")); err != nil { + log.Fatal(err) + } + fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) + return + } + bw := bufio.NewWriter(os.Stdout) + if err := writeNAR(bw, os.DirFS(".")); err != nil { + log.Fatal(err) + } + bw.Flush() +} + +// writeNARError is a sentinel panic type that's recovered by writeNAR +// and converted into the wrapped error. +type writeNARError struct{ err error } + +// narWriter writes NAR files. +type narWriter struct { + w io.Writer + fs fs.FS +} + +// writeNAR writes a NAR file to w from the root of fs. +func writeNAR(w io.Writer, fs fs.FS) (err error) { + defer func() { + if e := recover(); e != nil { + if we, ok := e.(writeNARError); ok { + err = we.err + return + } + panic(e) + } + }() + nw := &narWriter{w: w, fs: fs} + nw.str("nix-archive-1") + return nw.writeDir(".") +} + +func (nw *narWriter) writeDir(dirPath string) error { + ents, err := fs.ReadDir(nw.fs, dirPath) + if err != nil { + return err + } + sort.Slice(ents, func(i, j int) bool { + return ents[i].Name() < ents[j].Name() + }) + nw.str("(") + nw.str("type") + nw.str("directory") + for _, ent := range ents { + nw.str("entry") + nw.str("(") + nw.str("name") + nw.str(ent.Name()) + nw.str("node") + mode := ent.Type() + sub := path.Join(dirPath, ent.Name()) + var err error + switch { + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode.IsDir(): + err = nw.writeDir(sub) + default: + // TODO(bradfitz): symlink, but requires fighting io/fs a bit + // to get at Readlink or the osFS via fs. But for now + // we don't need symlinks because they're not in Go's archive. + return fmt.Errorf("unsupported file type %v at %q", sub, mode) + } + if err != nil { + return err + } + nw.str(")") + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeRegular(path string) error { + nw.str("(") + nw.str("type") + nw.str("regular") + fi, err := fs.Stat(nw.fs, path) + if err != nil { + return err + } + if fi.Mode()&0111 != 0 { + nw.str("executable") + nw.str("") + } + contents, err := fs.ReadFile(nw.fs, path) + if err != nil { + return err + } + nw.str("contents") + if err := writeBytes(nw.w, contents); err != nil { + return err + } + nw.str(")") + return nil +} + +func (nw *narWriter) str(s string) { + if err := writeString(nw.w, s); err != nil { + panic(writeNARError{err}) + } +} + +func writeString(w io.Writer, s string) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := io.WriteString(w, s); err != nil { + return err + } + return writePad(w, len(s)) +} + +func writeBytes(w io.Writer, b []byte) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return writePad(w, len(b)) +} + +func writePad(w io.Writer, n int) error { + pad := n % 8 + if pad == 0 { + return nil + } + var zeroes [8]byte + _, err := w.Write(zeroes[:8-pad]) + return err +} diff --git a/cmd/nginx-auth/.gitignore b/cmd/nginx-auth/.gitignore index 3c608aeb1eede..255276578b60d 100644 --- a/cmd/nginx-auth/.gitignore +++ b/cmd/nginx-auth/.gitignore @@ -1,4 +1,4 @@ -nga.sock -*.deb -*.rpm -tailscale.nginx-auth +nga.sock +*.deb +*.rpm +tailscale.nginx-auth diff --git a/cmd/nginx-auth/README.md b/cmd/nginx-auth/README.md index 858f9ab81a83e..869b1487bf57b 100644 --- a/cmd/nginx-auth/README.md +++ b/cmd/nginx-auth/README.md @@ -1,161 +1,161 @@ -# nginx-auth - -[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) - -This is a tool that allows users to use Tailscale Whois authentication with -NGINX as a reverse proxy. This allows users that already have a bunch of -services hosted on an internal NGINX server to point those domains to the -Tailscale IP of the NGINX server and then seamlessly use Tailscale for -authentication. - -Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on -Twitter for introducing the basic idea and offering some sample code. This -program is based on that sample code with security enhancements. Namely: - -* This listens over a UNIX socket instead of a TCP socket, to prevent - leakage to the network -* This uses systemd socket activation so that systemd owns the socket - and can then lock down the service to the bare minimum required to do - its job without having to worry about dropping permissions -* This provides additional information in HTTP response headers that can - be useful for integrating with various services - -## Configuration - -In order to protect a service with this tool, do the following in the respective -`server` block: - -Create an authentication location with the `internal` flag set: - -```nginx -location /auth { - internal; - - proxy_pass http://unix:/run/tailscale.nginx-auth.sock; - proxy_pass_request_body off; - - proxy_set_header Host $http_host; - proxy_set_header Remote-Addr $remote_addr; - proxy_set_header Remote-Port $remote_port; - proxy_set_header Original-URI $request_uri; -} -``` - -Then add the following to the `location /` block: - -``` -auth_request /auth; -auth_request_set $auth_user $upstream_http_tailscale_user; -auth_request_set $auth_name $upstream_http_tailscale_name; -auth_request_set $auth_login $upstream_http_tailscale_login; -auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; -auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; - -proxy_set_header X-Webauth-User "$auth_user"; -proxy_set_header X-Webauth-Name "$auth_name"; -proxy_set_header X-Webauth-Login "$auth_login"; -proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; -proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; -``` - -When this configuration is used with a Go HTTP handler such as this: - -```go -http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - e := json.NewEncoder(w) - e.SetIndent("", " ") - e.Encode(r.Header) -}) -``` - -You will get output like this: - -```json -{ - "Accept": [ - "*/*" - ], - "Connection": [ - "upgrade" - ], - "User-Agent": [ - "curl/7.82.0" - ], - "X-Webauth-Login": [ - "Xe" - ], - "X-Webauth-Name": [ - "Xe Iaso" - ], - "X-Webauth-Profile-Picture": [ - "https://avatars.githubusercontent.com/u/529003?v=4" - ], - "X-Webauth-Tailnet": [ - "cetacean.org.github" - ] - "X-Webauth-User": [ - "Xe@github" - ] -} -``` - -## Headers - -The authentication service provides the following headers to decorate your -proxied requests: - -| Header | Example Value | Description | -| :------ | :-------------- | :---------- | -| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | -| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | -| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | -| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | -| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | - -Most of the time you can set `X-Webauth-User` to the contents of the -`Tailscale-User` header, but some services may not accept a username with an `@` -symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` -header. - -The `Tailscale-Tailnet` header can help you identify which tailnet the session -is coming from. If you are using node sharing, this can help you make sure that -you aren't giving administrative access to people outside your tailnet. - -### Allow Requests From Only One Tailnet - -If you want to prevent node sharing from allowing users to access a service, add -the `Expected-Tailnet` header to your auth request: - -```nginx -location /auth { - # ... - proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; -} -``` - -If a user from a different tailnet tries to use that service, this will return a -generic "forbidden" error page: - -```html - -403 Forbidden - -

403 Forbidden

-
nginx/1.18.0 (Ubuntu)
- - -``` - -You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). - -## Building - -Install `cmd/mkpkg`: - -``` -cd .. && go install ./mkpkg -``` - -Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 -machines (Linux uname flag: `x86_64`). You can add these to your deployment -methods as you see fit. +# nginx-auth + +[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) + +This is a tool that allows users to use Tailscale Whois authentication with +NGINX as a reverse proxy. This allows users that already have a bunch of +services hosted on an internal NGINX server to point those domains to the +Tailscale IP of the NGINX server and then seamlessly use Tailscale for +authentication. + +Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on +Twitter for introducing the basic idea and offering some sample code. This +program is based on that sample code with security enhancements. Namely: + +* This listens over a UNIX socket instead of a TCP socket, to prevent + leakage to the network +* This uses systemd socket activation so that systemd owns the socket + and can then lock down the service to the bare minimum required to do + its job without having to worry about dropping permissions +* This provides additional information in HTTP response headers that can + be useful for integrating with various services + +## Configuration + +In order to protect a service with this tool, do the following in the respective +`server` block: + +Create an authentication location with the `internal` flag set: + +```nginx +location /auth { + internal; + + proxy_pass http://unix:/run/tailscale.nginx-auth.sock; + proxy_pass_request_body off; + + proxy_set_header Host $http_host; + proxy_set_header Remote-Addr $remote_addr; + proxy_set_header Remote-Port $remote_port; + proxy_set_header Original-URI $request_uri; +} +``` + +Then add the following to the `location /` block: + +``` +auth_request /auth; +auth_request_set $auth_user $upstream_http_tailscale_user; +auth_request_set $auth_name $upstream_http_tailscale_name; +auth_request_set $auth_login $upstream_http_tailscale_login; +auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; +auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; + +proxy_set_header X-Webauth-User "$auth_user"; +proxy_set_header X-Webauth-Name "$auth_name"; +proxy_set_header X-Webauth-Login "$auth_login"; +proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; +proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; +``` + +When this configuration is used with a Go HTTP handler such as this: + +```go +http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + e := json.NewEncoder(w) + e.SetIndent("", " ") + e.Encode(r.Header) +}) +``` + +You will get output like this: + +```json +{ + "Accept": [ + "*/*" + ], + "Connection": [ + "upgrade" + ], + "User-Agent": [ + "curl/7.82.0" + ], + "X-Webauth-Login": [ + "Xe" + ], + "X-Webauth-Name": [ + "Xe Iaso" + ], + "X-Webauth-Profile-Picture": [ + "https://avatars.githubusercontent.com/u/529003?v=4" + ], + "X-Webauth-Tailnet": [ + "cetacean.org.github" + ] + "X-Webauth-User": [ + "Xe@github" + ] +} +``` + +## Headers + +The authentication service provides the following headers to decorate your +proxied requests: + +| Header | Example Value | Description | +| :------ | :-------------- | :---------- | +| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | +| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | +| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | +| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | +| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | + +Most of the time you can set `X-Webauth-User` to the contents of the +`Tailscale-User` header, but some services may not accept a username with an `@` +symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` +header. + +The `Tailscale-Tailnet` header can help you identify which tailnet the session +is coming from. If you are using node sharing, this can help you make sure that +you aren't giving administrative access to people outside your tailnet. + +### Allow Requests From Only One Tailnet + +If you want to prevent node sharing from allowing users to access a service, add +the `Expected-Tailnet` header to your auth request: + +```nginx +location /auth { + # ... + proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; +} +``` + +If a user from a different tailnet tries to use that service, this will return a +generic "forbidden" error page: + +```html + +403 Forbidden + +

403 Forbidden

+
nginx/1.18.0 (Ubuntu)
+ + +``` + +You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). + +## Building + +Install `cmd/mkpkg`: + +``` +cd .. && go install ./mkpkg +``` + +Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 +machines (Linux uname flag: `x86_64`). You can add these to your deployment +methods as you see fit. diff --git a/cmd/nginx-auth/deb/postinst.sh b/cmd/nginx-auth/deb/postinst.sh index d352a84885403..e692ced0757e3 100755 --- a/cmd/nginx-auth/deb/postinst.sh +++ b/cmd/nginx-auth/deb/postinst.sh @@ -1,14 +1,14 @@ -if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then - deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true - else - deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true - fi - - if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then - systemctl --system daemon-reload >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then + deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true + else + deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true + fi + + if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then + systemctl --system daemon-reload >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/postrm.sh b/cmd/nginx-auth/deb/postrm.sh index 4bce86139c6c2..7870efd18fb39 100755 --- a/cmd/nginx-auth/deb/postrm.sh +++ b/cmd/nginx-auth/deb/postrm.sh @@ -1,19 +1,19 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/prerm.sh b/cmd/nginx-auth/deb/prerm.sh index e4becd17039ba..22be23387c37e 100755 --- a/cmd/nginx-auth/deb/prerm.sh +++ b/cmd/nginx-auth/deb/prerm.sh @@ -1,8 +1,8 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/mkdeb.sh b/cmd/nginx-auth/mkdeb.sh index 59f43230d0817..6a57210937f87 100755 --- a/cmd/nginx-auth/mkdeb.sh +++ b/cmd/nginx-auth/mkdeb.sh @@ -1,32 +1,32 @@ -#!/usr/bin/env bash - -set -e - -VERSION=0.1.3 -for ARCH in amd64 arm64; do - CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=deb \ - --arch=${ARCH} \ - --postinst=deb/postinst.sh \ - --postrm=deb/postrm.sh \ - --prerm=deb/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=rpm \ - --arch=${ARCH} \ - --postinst=rpm/postinst.sh \ - --postrm=rpm/postrm.sh \ - --prerm=rpm/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md -done +#!/usr/bin/env bash + +set -e + +VERSION=0.1.3 +for ARCH in amd64 arm64; do + CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=deb \ + --arch=${ARCH} \ + --postinst=deb/postinst.sh \ + --postrm=deb/postrm.sh \ + --prerm=deb/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=rpm \ + --arch=${ARCH} \ + --postinst=rpm/postinst.sh \ + --postrm=rpm/postrm.sh \ + --prerm=rpm/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md +done diff --git a/cmd/nginx-auth/nginx-auth.go b/cmd/nginx-auth/nginx-auth.go index 09da74da1d3c8..befcb6d6c0423 100644 --- a/cmd/nginx-auth/nginx-auth.go +++ b/cmd/nginx-auth/nginx-auth.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -// Command nginx-auth is a tool that allows users to use Tailscale Whois -// authentication with NGINX as a reverse proxy. This allows users that -// already have a bunch of services hosted on an internal NGINX server -// to point those domains to the Tailscale IP of the NGINX server and -// then seamlessly use Tailscale for authentication. -package main - -import ( - "flag" - "log" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "strings" - - "github.com/coreos/go-systemd/activation" - "tailscale.com/client/tailscale" -) - -var ( - sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") -) - -func main() { - flag.Parse() - - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - remoteHost := r.Header.Get("Remote-Addr") - remotePort := r.Header.Get("Remote-Port") - if remoteHost == "" || remotePort == "" { - w.WriteHeader(http.StatusBadRequest) - log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") - return - } - - remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) - remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("remote address and port are not valid: %v", err) - return - } - - info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't look up %s: %v", remoteAddr, err) - return - } - - if info.Node.IsTagged() { - w.WriteHeader(http.StatusForbidden) - log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) - return - } - - // tailnet of connected node. When accessing shared nodes, this - // will be empty because the tailnet of the sharee is not exposed. - var tailnet string - - if !info.Node.Hostinfo.ShareeNode() { - var ok bool - _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") - if !ok { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) - return - } - tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") - } - - if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { - w.WriteHeader(http.StatusForbidden) - log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) - return - } - - h := w.Header() - h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) - h.Set("Tailscale-User", info.UserProfile.LoginName) - h.Set("Tailscale-Name", info.UserProfile.DisplayName) - h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) - h.Set("Tailscale-Tailnet", tailnet) - w.WriteHeader(http.StatusNoContent) - }) - - if *sockPath != "" { - _ = os.Remove(*sockPath) // ignore error, this file may not already exist - ln, err := net.Listen("unix", *sockPath) - if err != nil { - log.Fatalf("can't listen on %s: %v", *sockPath, err) - } - defer ln.Close() - - log.Printf("listening on %s", *sockPath) - log.Fatal(http.Serve(ln, mux)) - } - - listeners, err := activation.Listeners() - if err != nil { - log.Fatalf("no sockets passed to this service with systemd: %v", err) - } - - // NOTE(Xe): normally you'd want to make a waitgroup here and then register - // each listener with it. In this case I want this to blow up horribly if - // any of the listeners stop working. systemd will restart it due to the - // socket activation at play. - // - // TL;DR: Let it crash, it will come back - for _, ln := range listeners { - go func(ln net.Listener) { - log.Printf("listening on %s", ln.Addr()) - log.Fatal(http.Serve(ln, mux)) - }(ln) - } - - for { - select {} - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +// Command nginx-auth is a tool that allows users to use Tailscale Whois +// authentication with NGINX as a reverse proxy. This allows users that +// already have a bunch of services hosted on an internal NGINX server +// to point those domains to the Tailscale IP of the NGINX server and +// then seamlessly use Tailscale for authentication. +package main + +import ( + "flag" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strings" + + "github.com/coreos/go-systemd/activation" + "tailscale.com/client/tailscale" +) + +var ( + sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") +) + +func main() { + flag.Parse() + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + remoteHost := r.Header.Get("Remote-Addr") + remotePort := r.Header.Get("Remote-Port") + if remoteHost == "" || remotePort == "" { + w.WriteHeader(http.StatusBadRequest) + log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") + return + } + + remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) + remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("remote address and port are not valid: %v", err) + return + } + + info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't look up %s: %v", remoteAddr, err) + return + } + + if info.Node.IsTagged() { + w.WriteHeader(http.StatusForbidden) + log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) + return + } + + // tailnet of connected node. When accessing shared nodes, this + // will be empty because the tailnet of the sharee is not exposed. + var tailnet string + + if !info.Node.Hostinfo.ShareeNode() { + var ok bool + _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") + if !ok { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) + return + } + tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") + } + + if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { + w.WriteHeader(http.StatusForbidden) + log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) + return + } + + h := w.Header() + h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) + h.Set("Tailscale-User", info.UserProfile.LoginName) + h.Set("Tailscale-Name", info.UserProfile.DisplayName) + h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) + h.Set("Tailscale-Tailnet", tailnet) + w.WriteHeader(http.StatusNoContent) + }) + + if *sockPath != "" { + _ = os.Remove(*sockPath) // ignore error, this file may not already exist + ln, err := net.Listen("unix", *sockPath) + if err != nil { + log.Fatalf("can't listen on %s: %v", *sockPath, err) + } + defer ln.Close() + + log.Printf("listening on %s", *sockPath) + log.Fatal(http.Serve(ln, mux)) + } + + listeners, err := activation.Listeners() + if err != nil { + log.Fatalf("no sockets passed to this service with systemd: %v", err) + } + + // NOTE(Xe): normally you'd want to make a waitgroup here and then register + // each listener with it. In this case I want this to blow up horribly if + // any of the listeners stop working. systemd will restart it due to the + // socket activation at play. + // + // TL;DR: Let it crash, it will come back + for _, ln := range listeners { + go func(ln net.Listener) { + log.Printf("listening on %s", ln.Addr()) + log.Fatal(http.Serve(ln, mux)) + }(ln) + } + + for { + select {} + } +} diff --git a/cmd/nginx-auth/rpm/postrm.sh b/cmd/nginx-auth/rpm/postrm.sh index 3d0abfb199137..d8d36893fd931 100755 --- a/cmd/nginx-auth/rpm/postrm.sh +++ b/cmd/nginx-auth/rpm/postrm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : - systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : + systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/rpm/prerm.sh b/cmd/nginx-auth/rpm/prerm.sh index 1f198d8292bc5..2e47a53ed9356 100755 --- a/cmd/nginx-auth/rpm/prerm.sh +++ b/cmd/nginx-auth/rpm/prerm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/tailscale.nginx-auth.service b/cmd/nginx-auth/tailscale.nginx-auth.service index 086f6c7741d88..8534e25c1048d 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.service +++ b/cmd/nginx-auth/tailscale.nginx-auth.service @@ -1,11 +1,11 @@ -[Unit] -Description=Tailscale NGINX Authentication service -After=nginx.service -Wants=nginx.service - -[Service] -ExecStart=/usr/sbin/tailscale.nginx-auth -DynamicUser=yes - -[Install] -WantedBy=default.target +[Unit] +Description=Tailscale NGINX Authentication service +After=nginx.service +Wants=nginx.service + +[Service] +ExecStart=/usr/sbin/tailscale.nginx-auth +DynamicUser=yes + +[Install] +WantedBy=default.target diff --git a/cmd/nginx-auth/tailscale.nginx-auth.socket b/cmd/nginx-auth/tailscale.nginx-auth.socket index 7e5641ff3a2f5..53e3e8d83edf3 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.socket +++ b/cmd/nginx-auth/tailscale.nginx-auth.socket @@ -1,9 +1,9 @@ -[Unit] -Description=Tailscale NGINX Authentication socket -PartOf=tailscale.nginx-auth.service - -[Socket] -ListenStream=/var/run/tailscale.nginx-auth.sock - -[Install] +[Unit] +Description=Tailscale NGINX Authentication socket +PartOf=tailscale.nginx-auth.service + +[Socket] +ListenStream=/var/run/tailscale.nginx-auth.sock + +[Install] WantedBy=sockets.target \ No newline at end of file diff --git a/cmd/pgproxy/README.md b/cmd/pgproxy/README.md index 2e013072a1900..a867ad8cad9de 100644 --- a/cmd/pgproxy/README.md +++ b/cmd/pgproxy/README.md @@ -1,42 +1,42 @@ -# pgproxy - -The pgproxy server is a proxy for the Postgres wire protocol. [Read -more in our blog -post](https://tailscale.com/blog/introducing-pgproxy/) about it! - -The proxy runs an in-process Tailscale instance, accepts postgres -client connections over Tailscale only, and proxies them to the -configured upstream postgres server. - -This proxy exists because postgres clients default to very insecure -connection settings: either they "prefer" but do not require TLS; or -they set sslmode=require, which merely requires that a TLS handshake -took place, but don't verify the server's TLS certificate or the -presented TLS hostname. In other words, sslmode=require enforces that -a TLS session is created, but that session can trivially be -machine-in-the-middled to steal credentials, data, inject malicious -queries, and so forth. - -Because this flaw is in the client's validation of the TLS session, -you have no way of reliably detecting the misconfiguration -server-side. You could fix the configuration of all the clients you -know of, but the default makes it very easy to accidentally regress. - -Instead of trying to verify client configuration over time, this proxy -removes the need for postgres clients to be configured correctly: the -upstream database is configured to only accept connections from the -proxy, and the proxy is only available to clients over Tailscale. - -Therefore, clients must use the proxy to connect to the database. The -client<>proxy connection is secured end-to-end by Tailscale, which the -proxy enforces by verifying that the connecting client is a known -current Tailscale peer. The proxy<>server connection is established by -the proxy itself, using strict TLS verification settings, and the -client is only allowed to communicate with the server once we've -established that the upstream connection is safe to use. - -A couple side benefits: because clients can only connect via -Tailscale, you can use Tailscale ACLs as an extra layer of defense on -top of the postgres user/password authentication. And, the proxy can -maintain an audit log of who connected to the database, complete with -the strongly authenticated Tailscale identity of the client. +# pgproxy + +The pgproxy server is a proxy for the Postgres wire protocol. [Read +more in our blog +post](https://tailscale.com/blog/introducing-pgproxy/) about it! + +The proxy runs an in-process Tailscale instance, accepts postgres +client connections over Tailscale only, and proxies them to the +configured upstream postgres server. + +This proxy exists because postgres clients default to very insecure +connection settings: either they "prefer" but do not require TLS; or +they set sslmode=require, which merely requires that a TLS handshake +took place, but don't verify the server's TLS certificate or the +presented TLS hostname. In other words, sslmode=require enforces that +a TLS session is created, but that session can trivially be +machine-in-the-middled to steal credentials, data, inject malicious +queries, and so forth. + +Because this flaw is in the client's validation of the TLS session, +you have no way of reliably detecting the misconfiguration +server-side. You could fix the configuration of all the clients you +know of, but the default makes it very easy to accidentally regress. + +Instead of trying to verify client configuration over time, this proxy +removes the need for postgres clients to be configured correctly: the +upstream database is configured to only accept connections from the +proxy, and the proxy is only available to clients over Tailscale. + +Therefore, clients must use the proxy to connect to the database. The +client<>proxy connection is secured end-to-end by Tailscale, which the +proxy enforces by verifying that the connecting client is a known +current Tailscale peer. The proxy<>server connection is established by +the proxy itself, using strict TLS verification settings, and the +client is only allowed to communicate with the server once we've +established that the upstream connection is safe to use. + +A couple side benefits: because clients can only connect via +Tailscale, you can use Tailscale ACLs as an extra layer of defense on +top of the postgres user/password authentication. And, the proxy can +maintain an audit log of who connected to the database, complete with +the strongly authenticated Tailscale identity of the client. diff --git a/cmd/printdep/printdep.go b/cmd/printdep/printdep.go index 044283209c08c..0790a8b813cc6 100644 --- a/cmd/printdep/printdep.go +++ b/cmd/printdep/printdep.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The printdep command is a build system tool for printing out information -// about dependencies. -package main - -import ( - "flag" - "fmt" - "log" - "runtime" - "strings" - - ts "tailscale.com" -) - -var ( - goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") - goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") - alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") -) - -func main() { - flag.Parse() - if *alpine { - fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) - return - } - if *goToolchain { - fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) - } - if *goToolchainURL { - switch runtime.GOOS { - case "linux", "darwin": - default: - log.Fatalf("unsupported GOOS %q", runtime.GOOS) - } - fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The printdep command is a build system tool for printing out information +// about dependencies. +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "strings" + + ts "tailscale.com" +) + +var ( + goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") + goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") + alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") +) + +func main() { + flag.Parse() + if *alpine { + fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) + return + } + if *goToolchain { + fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) + } + if *goToolchainURL { + switch runtime.GOOS { + case "linux", "darwin": + default: + log.Fatalf("unsupported GOOS %q", runtime.GOOS) + } + fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) + } +} diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore index b1399c88167d4..0bca339122774 100644 --- a/cmd/sniproxy/.gitignore +++ b/cmd/sniproxy/.gitignore @@ -1 +1 @@ -sniproxy +sniproxy diff --git a/cmd/sniproxy/handlers_test.go b/cmd/sniproxy/handlers_test.go index 4f9fc6a34b184..8ec5b097c9b3c 100644 --- a/cmd/sniproxy/handlers_test.go +++ b/cmd/sniproxy/handlers_test.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "bytes" - "context" - "encoding/hex" - "io" - "net" - "net/netip" - "strings" - "testing" - - "tailscale.com/net/memnet" -) - -func echoConnOnce(conn net.Conn) { - defer conn.Close() - - b := make([]byte, 256) - n, err := conn.Read(b) - if err != nil { - return - } - - if _, err := conn.Write(b[:n]); err != nil { - return - } -} - -func TestTCPRoundRobinHandler(t *testing.T) { - h := tcpRoundRobinHandler{ - To: []string{"yeet.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "yeet.com:22" { - t.Errorf("addr = %s, want %s", addr, "yeet.com:22") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) - h.Handle(sSock) - - // Test data write and read, the other end will echo back - // a single stanza - want := "hello" - if _, err := io.WriteString(cSock, want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if string(got) != want { - t.Errorf("got %q, want %q", got, want) - } - - // The other end closed the socket after the first echo, so - // any following read should error. - io.WriteString(cSock, "deadass heres some data on god fr") - if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { - t.Error("read succeeded on closed socket") - } -} - -// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com -const tlsStart = `45000239ff1840004006f9f5c0a801f2 -c726b5efcf9e01bbe803b21394e3b752 -801801f641dc00000101080ade3474f2 -2fb93ee71603010200010001fc030303 -c3acbd19d2624765bb19af4bce03365e -1d197f5bb939cdadeff26b0f8e7a0620 -295b04127b82bae46aac4ff58cffef25 -eba75a4b7a6de729532c411bd9dd0d2c -00203a3a130113021303c02bc02fc02c -c030cca9cca8c013c014009c009d002f -003501000193caca0000000a000a0008 -1a1a001d001700180010000e000c0268 -3208687474702f312e31002b0007062a -2a03040303ff01000100000d00120010 -04030804040105030805050108060601 -000b00020100002300000033002b0029 -1a1a000100001d0020d3c76bef062979 -a812ce935cfb4dbe6b3a84dc5ba9226f -23b0f34af9d1d03b4a001b0003020002 -00120000446900050003026832000000 -170015000012706b67732e7461696c73 -63616c652e636f6d002d000201010005 -00050100000000001700003a3a000100 -0015002d000000000000000000000000 -00000000000000000000000000000000 -00000000000000000000000000000000 -0000290094006f0069e76f2016f963ad -38c8632d1f240cd75e00e25fdef295d4 -7042b26f3a9a543b1c7dc74939d77803 -20527d423ff996997bda2c6383a14f49 -219eeef8a053e90a32228df37ddbe126 -eccf6b085c93890d08341d819aea6111 -0d909f4cd6b071d9ea40618e74588a33 -90d494bbb5c3002120d5a164a16c9724 -c9ef5e540d8d6f007789a7acf9f5f16f -bf6a1907a6782ed02b` - -func fakeSNIHeader() []byte { - b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) - if err != nil { - panic(err) - } - return b[0x34:] // trim IP + TCP header -} - -func TestTCPSNIHandler(t *testing.T) { - h := tcpSNIHandler{ - Allowlist: []string{"pkgs.tailscale.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "pkgs.tailscale.com:443" { - t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) - h.Handle(sSock) - - // Fake a TLS handshake record with an SNI in it. - if _, err := cSock.Write(fakeSNIHeader()); err != nil { - t.Fatal(err) - } - - // Test read, the other end will echo back - // a single stanza, which is at least the beginning of the SNI header. - want := fakeSNIHeader()[:5] - if _, err := cSock.Write(want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if !bytes.Equal(got, want) { - t.Errorf("got %q, want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "context" + "encoding/hex" + "io" + "net" + "net/netip" + "strings" + "testing" + + "tailscale.com/net/memnet" +) + +func echoConnOnce(conn net.Conn) { + defer conn.Close() + + b := make([]byte, 256) + n, err := conn.Read(b) + if err != nil { + return + } + + if _, err := conn.Write(b[:n]); err != nil { + return + } +} + +func TestTCPRoundRobinHandler(t *testing.T) { + h := tcpRoundRobinHandler{ + To: []string{"yeet.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "yeet.com:22" { + t.Errorf("addr = %s, want %s", addr, "yeet.com:22") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) + h.Handle(sSock) + + // Test data write and read, the other end will echo back + // a single stanza + want := "hello" + if _, err := io.WriteString(cSock, want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } + + // The other end closed the socket after the first echo, so + // any following read should error. + io.WriteString(cSock, "deadass heres some data on god fr") + if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { + t.Error("read succeeded on closed socket") + } +} + +// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com +const tlsStart = `45000239ff1840004006f9f5c0a801f2 +c726b5efcf9e01bbe803b21394e3b752 +801801f641dc00000101080ade3474f2 +2fb93ee71603010200010001fc030303 +c3acbd19d2624765bb19af4bce03365e +1d197f5bb939cdadeff26b0f8e7a0620 +295b04127b82bae46aac4ff58cffef25 +eba75a4b7a6de729532c411bd9dd0d2c +00203a3a130113021303c02bc02fc02c +c030cca9cca8c013c014009c009d002f +003501000193caca0000000a000a0008 +1a1a001d001700180010000e000c0268 +3208687474702f312e31002b0007062a +2a03040303ff01000100000d00120010 +04030804040105030805050108060601 +000b00020100002300000033002b0029 +1a1a000100001d0020d3c76bef062979 +a812ce935cfb4dbe6b3a84dc5ba9226f +23b0f34af9d1d03b4a001b0003020002 +00120000446900050003026832000000 +170015000012706b67732e7461696c73 +63616c652e636f6d002d000201010005 +00050100000000001700003a3a000100 +0015002d000000000000000000000000 +00000000000000000000000000000000 +00000000000000000000000000000000 +0000290094006f0069e76f2016f963ad +38c8632d1f240cd75e00e25fdef295d4 +7042b26f3a9a543b1c7dc74939d77803 +20527d423ff996997bda2c6383a14f49 +219eeef8a053e90a32228df37ddbe126 +eccf6b085c93890d08341d819aea6111 +0d909f4cd6b071d9ea40618e74588a33 +90d494bbb5c3002120d5a164a16c9724 +c9ef5e540d8d6f007789a7acf9f5f16f +bf6a1907a6782ed02b` + +func fakeSNIHeader() []byte { + b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) + if err != nil { + panic(err) + } + return b[0x34:] // trim IP + TCP header +} + +func TestTCPSNIHandler(t *testing.T) { + h := tcpSNIHandler{ + Allowlist: []string{"pkgs.tailscale.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "pkgs.tailscale.com:443" { + t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) + h.Handle(sSock) + + // Fake a TLS handshake record with an SNI in it. + if _, err := cSock.Write(fakeSNIHeader()); err != nil { + t.Fatal(err) + } + + // Test read, the other end will echo back + // a single stanza, which is at least the beginning of the SNI header. + want := fakeSNIHeader()[:5] + if _, err := cSock.Write(want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/cmd/sniproxy/server.go b/cmd/sniproxy/server.go index b322b6f4b1137..c894206613f4a 100644 --- a/cmd/sniproxy/server.go +++ b/cmd/sniproxy/server.go @@ -1,327 +1,327 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "expvar" - "log" - "net" - "net/netip" - "sync" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/metrics" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/clientmetric" - "tailscale.com/util/mak" -) - -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") - -// target describes the predicates which route some inbound -// traffic to the app connector to a specific handler. -type target struct { - Dest netip.Prefix - Matching tailcfg.ProtoPortRange -} - -// Server implements an App Connector as expressed in sniproxy. -type Server struct { - mu sync.RWMutex // mu guards following fields - connectors map[appctype.ConfigID]connector -} - -type appcMetrics struct { - dnsResponses expvar.Int - dnsFailures expvar.Int - tcpConns expvar.Int - sniConns expvar.Int - unhandledConns expvar.Int -} - -var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { - m := appcMetrics{} - - stats := new(metrics.Set) - stats.Set("tls_sessions", &m.sniConns) - clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) - stats.Set("tcp_sessions", &m.tcpConns) - clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) - stats.Set("dns_responses", &m.dnsResponses) - clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) - stats.Set("dns_failed", &m.dnsFailures) - clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) - expvar.Publish("sniproxy", stats) - - return &m -}) - -// Configure applies the provided configuration to the app connector. -func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { - s.mu.Lock() - defer s.mu.Unlock() - s.connectors = makeConnectorsFromConfig(cfg) - log.Printf("installed app connector config: %+v", s.connectors) -} - -// HandleTCPFlow implements tsnet.FallbackTCPHandler. -func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { - m := getMetrics() - s.mu.RLock() - defer s.mu.RUnlock() - - for _, c := range s.connectors { - if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { - return handler, intercept - } - } - - return nil, false -} - -// HandleDNS handles a DNS request to the app connector. -func (s *Server) HandleDNS(c nettype.ConnPacketConn) { - defer c.Close() - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - m := getMetrics() - - buf := make([]byte, 1500) - n, err := c.Read(buf) - if err != nil { - log.Printf("HandleDNS: read failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - addrPortStr := c.LocalAddr().String() - host, _, err := net.SplitHostPort(addrPortStr) - if err != nil { - log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) - m.dnsFailures.Add(1) - return - } - localAddr, err := netip.ParseAddr(host) - if err != nil { - log.Printf("HandleDNS: bogus local address %q", host) - m.dnsFailures.Add(1) - return - } - - var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) - if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - for _, connector := range s.connectors { - resp, err := connector.handleDNS(&msg, localAddr) - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - if len(resp) > 0 { - // This connector handled the DNS request - _, err = c.Write(resp) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - - m.dnsResponses.Add(1) - return - } - } -} - -// connector describes a logical collection of -// services which need to be proxied. -type connector struct { - Handlers map[target]handler -} - -// handleTCPFlow implements tsnet.FallbackTCPHandler. -func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { - for t, h := range c.Handlers { - if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { - continue - } - if !t.Dest.Contains(dst.Addr()) { - continue - } - if !t.Matching.Ports.Contains(dst.Port()) { - continue - } - - switch h.(type) { - case *tcpSNIHandler: - m.sniConns.Add(1) - case *tcpRoundRobinHandler: - m.tcpConns.Add(1) - default: - log.Printf("handleTCPFlow: unhandled handler type %T", h) - } - - return h.Handle, true - } - - m.unhandledConns.Add(1) - return nil, false -} - -// handleDNS returns the DNS response to the given query. If this -// connector is unable to handle the request, nil is returned. -func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { - for t, h := range c.Handlers { - if t.Dest.Contains(localAddr) { - return makeDNSResponse(req, h.ReachableOn()) - } - } - - // Did not match, signal 'not handled' to caller - return nil, nil -} - -func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { - resp := dnsmessage.NewBuilder(response, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, - }) - resp.EnableCompression() - - if len(req.Questions) == 0 { - response, _ = resp.Finish() - return response, nil - } - q := req.Questions[0] - err = resp.StartQuestions() - if err != nil { - return - } - resp.Question(q) - - err = resp.StartAnswers() - if err != nil { - return - } - - switch q.Type { - case dnsmessage.TypeAAAA: - for _, ip := range reachableIPs { - if ip.Is6() { - err = resp.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, - ) - } - } - - case dnsmessage.TypeA: - for _, ip := range reachableIPs { - if ip.Is4() { - err = resp.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AResource{A: ip.As4()}, - ) - } - } - - case dnsmessage.TypeSOA: - err = resp.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ) - case dnsmessage.TypeNS: - err = resp.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ) - } - - if err != nil { - return nil, err - } - return resp.Finish() -} - -type handler interface { - // Handle handles the given socket. - Handle(c net.Conn) - - // ReachableOn returns the IP addresses this handler is reachable on. - ReachableOn() []netip.Addr -} - -func installDNATHandler(d *appctype.DNATConfig, out *connector) { - // These handlers don't actually do DNAT, they just - // proxy the data over the connection. - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpRoundRobinHandler{ - To: d.To, - DialContext: dialer.DialContext, - ReachableIPs: d.Addrs, - } - - for _, addr := range d.Addrs { - for _, protoPort := range d.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpSNIHandler{ - Allowlist: c.AllowedDomains, - DialContext: dialer.DialContext, - ReachableIPs: c.Addrs, - } - - for _, addr := range c.Addrs { - for _, protoPort := range c.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { - var connectors map[appctype.ConfigID]connector - - for cID, d := range cfg.DNAT { - c := connectors[cID] - installDNATHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - for cID, d := range cfg.SNIProxy { - c := connectors[cID] - installSNIHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - - return connectors -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "expvar" + "log" + "net" + "net/netip" + "sync" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/metrics" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" +) + +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + +// target describes the predicates which route some inbound +// traffic to the app connector to a specific handler. +type target struct { + Dest netip.Prefix + Matching tailcfg.ProtoPortRange +} + +// Server implements an App Connector as expressed in sniproxy. +type Server struct { + mu sync.RWMutex // mu guards following fields + connectors map[appctype.ConfigID]connector +} + +type appcMetrics struct { + dnsResponses expvar.Int + dnsFailures expvar.Int + tcpConns expvar.Int + sniConns expvar.Int + unhandledConns expvar.Int +} + +var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { + m := appcMetrics{} + + stats := new(metrics.Set) + stats.Set("tls_sessions", &m.sniConns) + clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) + stats.Set("tcp_sessions", &m.tcpConns) + clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) + stats.Set("dns_responses", &m.dnsResponses) + clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) + stats.Set("dns_failed", &m.dnsFailures) + clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) + expvar.Publish("sniproxy", stats) + + return &m +}) + +// Configure applies the provided configuration to the app connector. +func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { + s.mu.Lock() + defer s.mu.Unlock() + s.connectors = makeConnectorsFromConfig(cfg) + log.Printf("installed app connector config: %+v", s.connectors) +} + +// HandleTCPFlow implements tsnet.FallbackTCPHandler. +func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + m := getMetrics() + s.mu.RLock() + defer s.mu.RUnlock() + + for _, c := range s.connectors { + if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { + return handler, intercept + } + } + + return nil, false +} + +// HandleDNS handles a DNS request to the app connector. +func (s *Server) HandleDNS(c nettype.ConnPacketConn) { + defer c.Close() + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + m := getMetrics() + + buf := make([]byte, 1500) + n, err := c.Read(buf) + if err != nil { + log.Printf("HandleDNS: read failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + addrPortStr := c.LocalAddr().String() + host, _, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) + m.dnsFailures.Add(1) + return + } + localAddr, err := netip.ParseAddr(host) + if err != nil { + log.Printf("HandleDNS: bogus local address %q", host) + m.dnsFailures.Add(1) + return + } + + var msg dnsmessage.Message + err = msg.Unpack(buf[:n]) + if err != nil { + log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, connector := range s.connectors { + resp, err := connector.handleDNS(&msg, localAddr) + if err != nil { + log.Printf("HandleDNS: connector handling failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + if len(resp) > 0 { + // This connector handled the DNS request + _, err = c.Write(resp) + if err != nil { + log.Printf("HandleDNS: write failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + + m.dnsResponses.Add(1) + return + } + } +} + +// connector describes a logical collection of +// services which need to be proxied. +type connector struct { + Handlers map[target]handler +} + +// handleTCPFlow implements tsnet.FallbackTCPHandler. +func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { + for t, h := range c.Handlers { + if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { + continue + } + if !t.Dest.Contains(dst.Addr()) { + continue + } + if !t.Matching.Ports.Contains(dst.Port()) { + continue + } + + switch h.(type) { + case *tcpSNIHandler: + m.sniConns.Add(1) + case *tcpRoundRobinHandler: + m.tcpConns.Add(1) + default: + log.Printf("handleTCPFlow: unhandled handler type %T", h) + } + + return h.Handle, true + } + + m.unhandledConns.Add(1) + return nil, false +} + +// handleDNS returns the DNS response to the given query. If this +// connector is unable to handle the request, nil is returned. +func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { + for t, h := range c.Handlers { + if t.Dest.Contains(localAddr) { + return makeDNSResponse(req, h.ReachableOn()) + } + } + + // Did not match, signal 'not handled' to caller + return nil, nil +} + +func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { + resp := dnsmessage.NewBuilder(response, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + response, _ = resp.Finish() + return response, nil + } + q := req.Questions[0] + err = resp.StartQuestions() + if err != nil { + return + } + resp.Question(q) + + err = resp.StartAnswers() + if err != nil { + return + } + + switch q.Type { + case dnsmessage.TypeAAAA: + for _, ip := range reachableIPs { + if ip.Is6() { + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip.As16()}, + ) + } + } + + case dnsmessage.TypeA: + for _, ip := range reachableIPs { + if ip.Is4() { + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip.As4()}, + ) + } + } + + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return nil, err + } + return resp.Finish() +} + +type handler interface { + // Handle handles the given socket. + Handle(c net.Conn) + + // ReachableOn returns the IP addresses this handler is reachable on. + ReachableOn() []netip.Addr +} + +func installDNATHandler(d *appctype.DNATConfig, out *connector) { + // These handlers don't actually do DNAT, they just + // proxy the data over the connection. + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpRoundRobinHandler{ + To: d.To, + DialContext: dialer.DialContext, + ReachableIPs: d.Addrs, + } + + for _, addr := range d.Addrs { + for _, protoPort := range d.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpSNIHandler{ + Allowlist: c.AllowedDomains, + DialContext: dialer.DialContext, + ReachableIPs: c.Addrs, + } + + for _, addr := range c.Addrs { + for _, protoPort := range c.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { + var connectors map[appctype.ConfigID]connector + + for cID, d := range cfg.DNAT { + c := connectors[cID] + installDNATHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + for cID, d := range cfg.SNIProxy { + c := connectors[cID] + installSNIHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + + return connectors +} diff --git a/cmd/sniproxy/server_test.go b/cmd/sniproxy/server_test.go index d56f2aa754f85..2a51c874c81b0 100644 --- a/cmd/sniproxy/server_test.go +++ b/cmd/sniproxy/server_test.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" -) - -func TestMakeConnectorsFromConfig(t *testing.T) { - tcs := []struct { - name string - input *appctype.AppConnectorConfig - want map[appctype.ConfigID]connector - }{ - { - "empty", - &appctype.AppConnectorConfig{}, - nil, - }, - { - "DNAT", - &appctype.AppConnectorConfig{ - DNAT: map[appctype.ConfigID]appctype.DNATConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - { - "SNIProxy", - &appctype.AppConnectorConfig{ - SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - AllowedDomains: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - connectors := makeConnectorsFromConfig(tc.input) - - if diff := cmp.Diff(connectors, tc.want, - cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), - cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), - cmp.Comparer(func(x, y netip.Addr) bool { - return x == y - })); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" +) + +func TestMakeConnectorsFromConfig(t *testing.T) { + tcs := []struct { + name string + input *appctype.AppConnectorConfig + want map[appctype.ConfigID]connector + }{ + { + "empty", + &appctype.AppConnectorConfig{}, + nil, + }, + { + "DNAT", + &appctype.AppConnectorConfig{ + DNAT: map[appctype.ConfigID]appctype.DNATConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + { + "SNIProxy", + &appctype.AppConnectorConfig{ + SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + AllowedDomains: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + connectors := makeConnectorsFromConfig(tc.input) + + if diff := cmp.Diff(connectors, tc.want, + cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), + cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), + cmp.Comparer(func(x, y netip.Addr) bool { + return x == y + })); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index fa83aaf4ab44e..c048c8e7e2792 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The sniproxy is an outbound SNI proxy. It receives TLS connections over -// Tailscale on one or more TCP ports and sends them out to the same SNI -// hostname & port on the internet. It can optionally forward one or more -// TCP ports to a specific destination. It only does TCP. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "sort" - "strconv" - "strings" - - "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" - "tailscale.com/hostinfo" - "tailscale.com/ipn" - "tailscale.com/tailcfg" - "tailscale.com/tsnet" - "tailscale.com/tsweb" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/mak" -) - -const configCapKey = "tailscale.com/sniproxy" - -// portForward is the state for a single port forwarding entry, as passed to the --forward flag. -type portForward struct { - Port int - Proto string - Destination string -} - -// parseForward takes a proto/port/destination tuple as an input, as would be passed -// to the --forward command line flag, and returns a *portForward struct of those parameters. -func parseForward(value string) (*portForward, error) { - parts := strings.Split(value, "/") - if len(parts) != 3 { - return nil, errors.New("cannot parse: " + value) - } - - proto := parts[0] - if proto != "tcp" { - return nil, errors.New("unsupported forwarding protocol: " + proto) - } - port, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return nil, errors.New("bad forwarding port: " + parts[1]) - } - host := parts[2] - if host == "" { - return nil, errors.New("bad destination: " + value) - } - - return &portForward{Port: int(port), Proto: proto, Destination: host}, nil -} - -func main() { - // Parse flags - fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) - var ( - ports = fs.String("ports", "443", "comma-separated list of ports to proxy") - forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") - wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") - promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") - debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") - hostname = fs.String("hostname", "", "Hostname to register the service under") - ) - err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) - if err != nil { - log.Fatal("ff.Parse") - } - - var ts tsnet.Server - defer ts.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) -} - -// run actually runs the sniproxy. Its separate from main() to assist in testing. -func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { - // Wire up Tailscale node + app connector server - hostinfo.SetApp("sniproxy") - var s sniproxy - s.ts = ts - - s.ts.Port = uint16(wgPort) - s.ts.Hostname = hostname - - lc, err := s.ts.LocalClient() - if err != nil { - log.Fatalf("LocalClient() failed: %v", err) - } - s.lc = lc - s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) - - // Start special-purpose listeners: dns, http promotion, debug server - ln, err := s.ts.Listen("udp", ":53") - if err != nil { - log.Fatalf("failed listening on port 53: %v", err) - } - defer ln.Close() - go s.serveDNS(ln) - if promoteHTTPS { - ln, err := s.ts.Listen("tcp", ":80") - if err != nil { - log.Fatalf("failed listening on port 80: %v", err) - } - defer ln.Close() - log.Printf("Promoting HTTP to HTTPS ...") - go s.promoteHTTPS(ln) - } - if debugPort != 0 { - mux := http.NewServeMux() - tsweb.Debugger(mux) - dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) - if err != nil { - log.Fatalf("failed listening on debug port: %v", err) - } - defer dln.Close() - go func() { - log.Fatalf("debug serve: %v", http.Serve(dln, mux)) - }() - } - - // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) - if err != nil { - log.Fatalf("watching IPN bus: %v", err) - } - defer bus.Close() - for { - msg, err := bus.Next() - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - log.Fatalf("reading IPN bus: %v", err) - } - - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } - - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } - } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) - } - } -} - -type sniproxy struct { - srv Server - ts *tsnet.Server - lc *tailscale.LocalClient -} - -func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { - // Collect the set of addresses to advertise, using a map - // to avoid duplicate entries. - addrs := map[netip.Addr]struct{}{} - for _, c := range c.SNIProxy { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - for _, c := range c.DNAT { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - - var routes []netip.Prefix - for a := range addrs { - routes = append(routes, netip.PrefixFrom(a, a.BitLen())) - } - sort.SliceStable(routes, func(i, j int) bool { - return routes[i].Addr().Less(routes[j].Addr()) // determinism r us - }) - - _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - AdvertiseRoutes: routes, - }, - AdvertiseRoutesSet: true, - }) - return err -} - -func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { - ip4, ip6 := s.ts.TailscaleIPs() - - sniConfigFromFlags := appctype.SNIProxyConfig{ - Addrs: []netip.Addr{ip4, ip6}, - } - if ports != "" { - for _, portStr := range strings.Split(ports, ",") { - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - log.Fatalf("invalid port: %s", portStr) - } - sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, - }) - } - } - - var forwardConfigFromFlags []appctype.DNATConfig - for _, forwStr := range strings.Split(forwards, ",") { - if forwStr == "" { - continue - } - forw, err := parseForward(forwStr) - if err != nil { - log.Printf("invalid forwarding spec: %v", err) - continue - } - - forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ - Addrs: []netip.Addr{ip4, ip6}, - To: []string{forw.Destination}, - IP: []tailcfg.ProtoPortRange{ - { - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, - }, - }, - }) - } - - if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { - return // no config specified on the command line - } - - mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) - for i, forward := range forwardConfigFromFlags { - mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) - } -} - -func (s *sniproxy) serveDNS(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Printf("serveDNS accept: %v", err) - return - } - go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) - } -} - -func (s *sniproxy) promoteHTTPS(ln net.Listener) { - err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) - })) - log.Fatalf("promoteHTTPS http.Serve: %v", err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The sniproxy is an outbound SNI proxy. It receives TLS connections over +// Tailscale on one or more TCP ports and sends them out to the same SNI +// hostname & port on the internet. It can optionally forward one or more +// TCP ports to a specific destination. It only does TCP. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "sort" + "strconv" + "strings" + + "github.com/peterbourgon/ff/v3" + "tailscale.com/client/tailscale" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tsweb" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/mak" +) + +const configCapKey = "tailscale.com/sniproxy" + +// portForward is the state for a single port forwarding entry, as passed to the --forward flag. +type portForward struct { + Port int + Proto string + Destination string +} + +// parseForward takes a proto/port/destination tuple as an input, as would be passed +// to the --forward command line flag, and returns a *portForward struct of those parameters. +func parseForward(value string) (*portForward, error) { + parts := strings.Split(value, "/") + if len(parts) != 3 { + return nil, errors.New("cannot parse: " + value) + } + + proto := parts[0] + if proto != "tcp" { + return nil, errors.New("unsupported forwarding protocol: " + proto) + } + port, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return nil, errors.New("bad forwarding port: " + parts[1]) + } + host := parts[2] + if host == "" { + return nil, errors.New("bad destination: " + value) + } + + return &portForward{Port: int(port), Proto: proto, Destination: host}, nil +} + +func main() { + // Parse flags + fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) + var ( + ports = fs.String("ports", "443", "comma-separated list of ports to proxy") + forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") + wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") + promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") + debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") + hostname = fs.String("hostname", "", "Hostname to register the service under") + ) + err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) + if err != nil { + log.Fatal("ff.Parse") + } + + var ts tsnet.Server + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) +} + +// run actually runs the sniproxy. Its separate from main() to assist in testing. +func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { + // Wire up Tailscale node + app connector server + hostinfo.SetApp("sniproxy") + var s sniproxy + s.ts = ts + + s.ts.Port = uint16(wgPort) + s.ts.Hostname = hostname + + lc, err := s.ts.LocalClient() + if err != nil { + log.Fatalf("LocalClient() failed: %v", err) + } + s.lc = lc + s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) + + // Start special-purpose listeners: dns, http promotion, debug server + ln, err := s.ts.Listen("udp", ":53") + if err != nil { + log.Fatalf("failed listening on port 53: %v", err) + } + defer ln.Close() + go s.serveDNS(ln) + if promoteHTTPS { + ln, err := s.ts.Listen("tcp", ":80") + if err != nil { + log.Fatalf("failed listening on port 80: %v", err) + } + defer ln.Close() + log.Printf("Promoting HTTP to HTTPS ...") + go s.promoteHTTPS(ln) + } + if debugPort != 0 { + mux := http.NewServeMux() + tsweb.Debugger(mux) + dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) + if err != nil { + log.Fatalf("failed listening on debug port: %v", err) + } + defer dln.Close() + go func() { + log.Fatalf("debug serve: %v", http.Serve(dln, mux)) + }() + } + + // Finally, start mainloop to configure app connector based on information + // in the netmap. + // We set the NotifyInitialNetMap flag so we will always get woken with the + // current netmap, before only being woken on changes. + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + if err != nil { + log.Fatalf("watching IPN bus: %v", err) + } + defer bus.Close() + for { + msg, err := bus.Next() + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + log.Fatalf("reading IPN bus: %v", err) + } + + // NetMap contains app-connector configuration + if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { + sn := nm.SelfNode.AsStruct() + + var c appctype.AppConnectorConfig + nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } + + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) + } + } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.srv.Configure(&c) + } + } +} + +type sniproxy struct { + srv Server + ts *tsnet.Server + lc *tailscale.LocalClient +} + +func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { + // Collect the set of addresses to advertise, using a map + // to avoid duplicate entries. + addrs := map[netip.Addr]struct{}{} + for _, c := range c.SNIProxy { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + for _, c := range c.DNAT { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + + var routes []netip.Prefix + for a := range addrs { + routes = append(routes, netip.PrefixFrom(a, a.BitLen())) + } + sort.SliceStable(routes, func(i, j int) bool { + return routes[i].Addr().Less(routes[j].Addr()) // determinism r us + }) + + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AdvertiseRoutes: routes, + }, + AdvertiseRoutesSet: true, + }) + return err +} + +func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { + ip4, ip6 := s.ts.TailscaleIPs() + + sniConfigFromFlags := appctype.SNIProxyConfig{ + Addrs: []netip.Addr{ip4, ip6}, + } + if ports != "" { + for _, portStr := range strings.Split(ports, ",") { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + log.Fatalf("invalid port: %s", portStr) + } + sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, + }) + } + } + + var forwardConfigFromFlags []appctype.DNATConfig + for _, forwStr := range strings.Split(forwards, ",") { + if forwStr == "" { + continue + } + forw, err := parseForward(forwStr) + if err != nil { + log.Printf("invalid forwarding spec: %v", err) + continue + } + + forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ + Addrs: []netip.Addr{ip4, ip6}, + To: []string{forw.Destination}, + IP: []tailcfg.ProtoPortRange{ + { + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, + }, + }, + }) + } + + if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { + return // no config specified on the command line + } + + mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) + for i, forward := range forwardConfigFromFlags { + mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) + } +} + +func (s *sniproxy) serveDNS(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + log.Printf("serveDNS accept: %v", err) + return + } + go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) + } +} + +func (s *sniproxy) promoteHTTPS(ln net.Listener) { + err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) + })) + log.Fatalf("promoteHTTPS http.Serve: %v", err) +} diff --git a/cmd/speedtest/speedtest.go b/cmd/speedtest/speedtest.go index 9a457ed6c7486..1555c0dcc0b7a 100644 --- a/cmd/speedtest/speedtest.go +++ b/cmd/speedtest/speedtest.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program speedtest provides the speedtest command. The reason to keep it separate from -// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. -// It will be included in the tailscale cli after it has been added to tailscaled. - -// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s -// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. -// Example usage for server command: go run cmd/speedtest -s -host :20333 -// This will start a speedtest server on port 20333. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "os" - "strconv" - "text/tabwriter" - "time" - - "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/net/speedtest" -) - -// Runs the speedtest command as a commandline program -func main() { - args := os.Args[1:] - if err := speedtestCmd.Parse(args); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } - - err := speedtestCmd.Run(context.Background()) - if errors.Is(err, flag.ErrHelp) { - fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) - os.Exit(2) - } - if err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } -} - -// speedtestCmd is the root command. It runs either the server or client depending on the -// flags passed to it. -var speedtestCmd = &ffcli.Command{ - Name: "speedtest", - ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", - ShortHelp: "Run a speed test", - FlagSet: (func() *flag.FlagSet { - fs := flag.NewFlagSet("speedtest", flag.ExitOnError) - fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") - fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") - fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") - fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") - return fs - })(), - Exec: runSpeedtest, -} - -var speedtestArgs struct { - host string - testDuration time.Duration - runServer bool - reverse bool -} - -func runSpeedtest(ctx context.Context, args []string) error { - - if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { - var addrErr *net.AddrError - if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { - // if no port is provided, append the default port - speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) - } - } - - if speedtestArgs.runServer { - listener, err := net.Listen("tcp", speedtestArgs.host) - if err != nil { - return err - } - - fmt.Printf("listening on %v\n", listener.Addr()) - - return speedtest.Serve(listener) - } - - // Ensure the duration is within the allowed range - if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { - return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) - } - - dir := speedtest.Download - if speedtestArgs.reverse { - dir = speedtest.Upload - } - - fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) - results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) - if err != nil { - return err - } - - w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) - fmt.Println("Results:") - fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") - startTime := results[0].IntervalStart - for _, r := range results { - if r.Total { - fmt.Fprintln(w, "-------------------------------------------------------------------------") - } - fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) - } - w.Flush() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program speedtest provides the speedtest command. The reason to keep it separate from +// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. +// It will be included in the tailscale cli after it has been added to tailscaled. + +// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s +// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. +// Example usage for server command: go run cmd/speedtest -s -host :20333 +// This will start a speedtest server on port 20333. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "os" + "strconv" + "text/tabwriter" + "time" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/net/speedtest" +) + +// Runs the speedtest command as a commandline program +func main() { + args := os.Args[1:] + if err := speedtestCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + err := speedtestCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) + os.Exit(2) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +// speedtestCmd is the root command. It runs either the server or client depending on the +// flags passed to it. +var speedtestCmd = &ffcli.Command{ + Name: "speedtest", + ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", + ShortHelp: "Run a speed test", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("speedtest", flag.ExitOnError) + fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") + fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") + fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") + fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") + return fs + })(), + Exec: runSpeedtest, +} + +var speedtestArgs struct { + host string + testDuration time.Duration + runServer bool + reverse bool +} + +func runSpeedtest(ctx context.Context, args []string) error { + + if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + // if no port is provided, append the default port + speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) + } + } + + if speedtestArgs.runServer { + listener, err := net.Listen("tcp", speedtestArgs.host) + if err != nil { + return err + } + + fmt.Printf("listening on %v\n", listener.Addr()) + + return speedtest.Serve(listener) + } + + // Ensure the duration is within the allowed range + if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { + return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) + } + + dir := speedtest.Download + if speedtestArgs.reverse { + dir = speedtest.Upload + } + + fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) + results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) + if err != nil { + return err + } + + w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) + fmt.Println("Results:") + fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") + startTime := results[0].IntervalStart + for _, r := range results { + if r.Total { + fmt.Fprintln(w, "-------------------------------------------------------------------------") + } + fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) + } + w.Flush() + return nil +} diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index ee929299a4273..ade272c4ba811 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// ssh-auth-none-demo is a demo SSH server that's meant to run on the -// public internet (at 188.166.70.128 port 2222) and -// highlight the unique parts of the Tailscale SSH server so SSH -// client authors can hit it easily and fix their SSH clients without -// needing to set up Tailscale and Tailscale SSH. -package main - -import ( - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "flag" - "fmt" - "io" - "log" - "os" - "path/filepath" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" - "tailscale.com/tempfork/gliderlabs/ssh" -) - -// keyTypes are the SSH key types that we either try to read from the -// system's OpenSSH keys. -var keyTypes = []string{"rsa", "ecdsa", "ed25519"} - -var ( - addr = flag.String("addr", ":2222", "address to listen on") -) - -func main() { - flag.Parse() - - cacheDir, err := os.UserCacheDir() - if err != nil { - log.Fatal(err) - } - dir := filepath.Join(cacheDir, "ssh-auth-none-demo") - if err := os.MkdirAll(dir, 0700); err != nil { - log.Fatal(err) - } - - keys, err := getHostKeys(dir) - if err != nil { - log.Fatal(err) - } - if len(keys) == 0 { - log.Fatal("no host keys") - } - - srv := &ssh.Server{ - Addr: *addr, - Version: "Tailscale", - Handler: handleSessionPostSSHAuth, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - start := time.Now() - return &gossh.ServerConfig{ - NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { - return []string{"tailscale"} - }, - NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { - cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) - - totalBanners := 2 - if cm.User() == "banners" { - totalBanners = 5 - } - for banner := 2; banner <= totalBanners; banner++ { - time.Sleep(time.Second) - if banner == totalBanners { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) - } else { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) - } - } - return nil, nil - }, - BannerCallback: func(cm gossh.ConnMetadata) string { - log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) - return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) - }, - } - }, - } - - for _, signer := range keys { - srv.AddHostKey(signer) - } - - log.Printf("Running on %s ...", srv.Addr) - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } - log.Printf("done") -} - -func handleSessionPostSSHAuth(s ssh.Session) { - log.Printf("Started session from user %q", s.User()) - fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) - - // Abort the session on Control-C or Control-D. - go func() { - buf := make([]byte, 1024) - for { - n, err := s.Read(buf) - for _, b := range buf[:n] { - if b <= 4 { // abort on Control-C (3) or Control-D (4) - io.WriteString(s, "bye\n") - s.Exit(1) - } - } - if err != nil { - return - } - } - }() - - for i := 10; i > 0; i-- { - fmt.Fprintf(s, "%v ...\n", i) - time.Sleep(time.Second) - } - s.Exit(0) -} - -func getHostKeys(dir string) (ret []ssh.Signer, err error) { - for _, typ := range keyTypes { - hostKey, err := hostKeyFileOrCreate(dir, typ) - if err != nil { - return nil, err - } - signer, err := gossh.ParsePrivateKey(hostKey) - if err != nil { - return nil, err - } - ret = append(ret, signer) - } - return ret, nil -} - -func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { - path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") - v, err := os.ReadFile(path) - if err == nil { - return v, nil - } - if !os.IsNotExist(err) { - return nil, err - } - var priv any - switch typ { - default: - return nil, fmt.Errorf("unsupported key type %q", typ) - case "ed25519": - _, priv, err = ed25519.GenerateKey(rand.Reader) - case "ecdsa": - // curve is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - curve := elliptic.P256() - priv, err = ecdsa.GenerateKey(curve, rand.Reader) - case "rsa": - // keySize is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - const keySize = 2048 - priv, err = rsa.GenerateKey(rand.Reader, keySize) - } - if err != nil { - return nil, err - } - mk, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - return nil, err - } - pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) - err = os.WriteFile(path, pemGen, 0700) - return pemGen, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ssh-auth-none-demo is a demo SSH server that's meant to run on the +// public internet (at 188.166.70.128 port 2222) and +// highlight the unique parts of the Tailscale SSH server so SSH +// client authors can hit it easily and fix their SSH clients without +// needing to set up Tailscale and Tailscale SSH. +package main + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" + "tailscale.com/tempfork/gliderlabs/ssh" +) + +// keyTypes are the SSH key types that we either try to read from the +// system's OpenSSH keys. +var keyTypes = []string{"rsa", "ecdsa", "ed25519"} + +var ( + addr = flag.String("addr", ":2222", "address to listen on") +) + +func main() { + flag.Parse() + + cacheDir, err := os.UserCacheDir() + if err != nil { + log.Fatal(err) + } + dir := filepath.Join(cacheDir, "ssh-auth-none-demo") + if err := os.MkdirAll(dir, 0700); err != nil { + log.Fatal(err) + } + + keys, err := getHostKeys(dir) + if err != nil { + log.Fatal(err) + } + if len(keys) == 0 { + log.Fatal("no host keys") + } + + srv := &ssh.Server{ + Addr: *addr, + Version: "Tailscale", + Handler: handleSessionPostSSHAuth, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + start := time.Now() + return &gossh.ServerConfig{ + NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { + return []string{"tailscale"} + }, + NoClientAuth: true, // required for the NoClientAuthCallback to run + NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + + totalBanners := 2 + if cm.User() == "banners" { + totalBanners = 5 + } + for banner := 2; banner <= totalBanners; banner++ { + time.Sleep(time.Second) + if banner == totalBanners { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) + } else { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) + } + } + return nil, nil + }, + BannerCallback: func(cm gossh.ConnMetadata) string { + log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) + return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) + }, + } + }, + } + + for _, signer := range keys { + srv.AddHostKey(signer) + } + + log.Printf("Running on %s ...", srv.Addr) + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } + log.Printf("done") +} + +func handleSessionPostSSHAuth(s ssh.Session) { + log.Printf("Started session from user %q", s.User()) + fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) + + // Abort the session on Control-C or Control-D. + go func() { + buf := make([]byte, 1024) + for { + n, err := s.Read(buf) + for _, b := range buf[:n] { + if b <= 4 { // abort on Control-C (3) or Control-D (4) + io.WriteString(s, "bye\n") + s.Exit(1) + } + } + if err != nil { + return + } + } + }() + + for i := 10; i > 0; i-- { + fmt.Fprintf(s, "%v ...\n", i) + time.Sleep(time.Second) + } + s.Exit(0) +} + +func getHostKeys(dir string) (ret []ssh.Signer, err error) { + for _, typ := range keyTypes { + hostKey, err := hostKeyFileOrCreate(dir, typ) + if err != nil { + return nil, err + } + signer, err := gossh.ParsePrivateKey(hostKey) + if err != nil { + return nil, err + } + ret = append(ret, signer) + } + return ret, nil +} + +func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { + path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") + v, err := os.ReadFile(path) + if err == nil { + return v, nil + } + if !os.IsNotExist(err) { + return nil, err + } + var priv any + switch typ { + default: + return nil, fmt.Errorf("unsupported key type %q", typ) + case "ed25519": + _, priv, err = ed25519.GenerateKey(rand.Reader) + case "ecdsa": + // curve is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + curve := elliptic.P256() + priv, err = ecdsa.GenerateKey(curve, rand.Reader) + case "rsa": + // keySize is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + const keySize = 2048 + priv, err = rsa.GenerateKey(rand.Reader, keySize) + } + if err != nil { + return nil, err + } + mk, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, err + } + pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) + err = os.WriteFile(path, pemGen, 0700) + return pemGen, err +} diff --git a/cmd/sync-containers/main.go b/cmd/sync-containers/main.go index 6317b4943ae82..68308cfeb3eda 100644 --- a/cmd/sync-containers/main.go +++ b/cmd/sync-containers/main.go @@ -1,214 +1,214 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The sync-containers command synchronizes container image tags from one -// registry to another. -// -// It is intended as a workaround for ghcr.io's lack of good push credentials: -// you can either authorize "classic" Personal Access Tokens in your org (which -// are a common vector of very bad compromise), or you can get a short-lived -// credential in a Github action. -// -// Since we publish to both Docker Hub and ghcr.io, we use this program in a -// Github action to effectively rsync from docker hub into ghcr.io, so that we -// can continue to forbid dangerous Personal Access Tokens in the tailscale org. -package main - -import ( - "context" - "flag" - "fmt" - "log" - "sort" - "strings" - - "github.com/google/go-containerregistry/pkg/authn" - "github.com/google/go-containerregistry/pkg/authn/github" - "github.com/google/go-containerregistry/pkg/name" - v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/types" -) - -var ( - src = flag.String("src", "", "Source image") - dst = flag.String("dst", "", "Destination image") - max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") - dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") -) - -func main() { - flag.Parse() - - if *src == "" { - log.Fatalf("--src is required") - } - if *dst == "" { - log.Fatalf("--dst is required") - } - - keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) - opts := []remote.Option{ - remote.WithAuthFromKeychain(keychain), - remote.WithContext(context.Background()), - } - - stags, err := listTags(*src, opts...) - if err != nil { - log.Fatalf("listing source tags: %v", err) - } - dtags, err := listTags(*dst, opts...) - if err != nil { - log.Fatalf("listing destination tags: %v", err) - } - - add, remove := diffTags(stags, dtags) - if l := len(add); l > 0 { - log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) - if *max > 0 && l > *max { - log.Printf("Limiting sync to %d tags", *max) - add = add[:*max] - } - } - for _, tag := range add { - if !*dryRun { - log.Printf("Syncing tag %q", tag) - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Syncing tag %q: progress error: %v", tag, err) - } - } else { - log.Printf("Dry run: would sync tag %q", tag) - } - } - - if len(remove) > 0 { - log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) - log.Printf("Not removing any tags for safety.\n") - } - - var wellKnown = [...]string{"latest", "stable"} - for _, tag := range wellKnown { - if needsUpdate(*src, *dst, tag) { - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Updating tag %q: progress error: %v", tag, err) - } - } - } -} - -func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return err - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return err - } - - desc, err := remote.Get(src) - if err != nil { - return err - } - - ch := make(chan v1.Update, 10) - opts = append(opts, remote.WithProgress(ch)) - progressDone := make(chan struct{}) - - go func() { - defer close(progressDone) - for p := range ch { - fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) - if p.Error != nil { - fmt.Printf("error: %v\n", p.Error) - } - } - }() - - switch desc.MediaType { - case types.OCIManifestSchema1, types.DockerManifestSchema2: - img, err := desc.Image() - if err != nil { - return err - } - if err := remote.Write(dst, img, opts...); err != nil { - return err - } - case types.OCIImageIndex, types.DockerManifestList: - idx, err := desc.ImageIndex() - if err != nil { - return err - } - if err := remote.WriteIndex(dst, idx, opts...); err != nil { - return err - } - } - - <-progressDone - return nil -} - -func listTags(repoStr string, opts ...remote.Option) ([]string, error) { - repo, err := name.NewRepository(repoStr) - if err != nil { - return nil, err - } - - tags, err := remote.List(repo, opts...) - if err != nil { - return nil, err - } - - sort.Strings(tags) - return tags, nil -} - -func diffTags(src, dst []string) (add, remove []string) { - srcd := make(map[string]bool) - for _, tag := range src { - srcd[tag] = true - } - dstd := make(map[string]bool) - for _, tag := range dst { - dstd[tag] = true - } - - for _, tag := range src { - if !dstd[tag] { - add = append(add, tag) - } - } - for _, tag := range dst { - if !srcd[tag] { - remove = append(remove, tag) - } - } - sort.Strings(add) - sort.Strings(remove) - return add, remove -} - -func needsUpdate(srcStr, dstStr, tag string) bool { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return false - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return false - } - - srcDesc, err := remote.Get(src) - if err != nil { - return false - } - - dstDesc, err := remote.Get(dst) - if err != nil { - return true - } - - return srcDesc.Digest != dstDesc.Digest -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The sync-containers command synchronizes container image tags from one +// registry to another. +// +// It is intended as a workaround for ghcr.io's lack of good push credentials: +// you can either authorize "classic" Personal Access Tokens in your org (which +// are a common vector of very bad compromise), or you can get a short-lived +// credential in a Github action. +// +// Since we publish to both Docker Hub and ghcr.io, we use this program in a +// Github action to effectively rsync from docker hub into ghcr.io, so that we +// can continue to forbid dangerous Personal Access Tokens in the tailscale org. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "sort" + "strings" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/authn/github" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/types" +) + +var ( + src = flag.String("src", "", "Source image") + dst = flag.String("dst", "", "Destination image") + max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") + dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") +) + +func main() { + flag.Parse() + + if *src == "" { + log.Fatalf("--src is required") + } + if *dst == "" { + log.Fatalf("--dst is required") + } + + keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) + opts := []remote.Option{ + remote.WithAuthFromKeychain(keychain), + remote.WithContext(context.Background()), + } + + stags, err := listTags(*src, opts...) + if err != nil { + log.Fatalf("listing source tags: %v", err) + } + dtags, err := listTags(*dst, opts...) + if err != nil { + log.Fatalf("listing destination tags: %v", err) + } + + add, remove := diffTags(stags, dtags) + if l := len(add); l > 0 { + log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) + if *max > 0 && l > *max { + log.Printf("Limiting sync to %d tags", *max) + add = add[:*max] + } + } + for _, tag := range add { + if !*dryRun { + log.Printf("Syncing tag %q", tag) + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Syncing tag %q: progress error: %v", tag, err) + } + } else { + log.Printf("Dry run: would sync tag %q", tag) + } + } + + if len(remove) > 0 { + log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) + log.Printf("Not removing any tags for safety.\n") + } + + var wellKnown = [...]string{"latest", "stable"} + for _, tag := range wellKnown { + if needsUpdate(*src, *dst, tag) { + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Updating tag %q: progress error: %v", tag, err) + } + } + } +} + +func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return err + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return err + } + + desc, err := remote.Get(src) + if err != nil { + return err + } + + ch := make(chan v1.Update, 10) + opts = append(opts, remote.WithProgress(ch)) + progressDone := make(chan struct{}) + + go func() { + defer close(progressDone) + for p := range ch { + fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) + if p.Error != nil { + fmt.Printf("error: %v\n", p.Error) + } + } + }() + + switch desc.MediaType { + case types.OCIManifestSchema1, types.DockerManifestSchema2: + img, err := desc.Image() + if err != nil { + return err + } + if err := remote.Write(dst, img, opts...); err != nil { + return err + } + case types.OCIImageIndex, types.DockerManifestList: + idx, err := desc.ImageIndex() + if err != nil { + return err + } + if err := remote.WriteIndex(dst, idx, opts...); err != nil { + return err + } + } + + <-progressDone + return nil +} + +func listTags(repoStr string, opts ...remote.Option) ([]string, error) { + repo, err := name.NewRepository(repoStr) + if err != nil { + return nil, err + } + + tags, err := remote.List(repo, opts...) + if err != nil { + return nil, err + } + + sort.Strings(tags) + return tags, nil +} + +func diffTags(src, dst []string) (add, remove []string) { + srcd := make(map[string]bool) + for _, tag := range src { + srcd[tag] = true + } + dstd := make(map[string]bool) + for _, tag := range dst { + dstd[tag] = true + } + + for _, tag := range src { + if !dstd[tag] { + add = append(add, tag) + } + } + for _, tag := range dst { + if !srcd[tag] { + remove = append(remove, tag) + } + } + sort.Strings(add) + sort.Strings(remove) + return add, remove +} + +func needsUpdate(srcStr, dstStr, tag string) bool { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return false + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return false + } + + srcDesc, err := remote.Get(src) + if err != nil { + return false + } + + dstDesc, err := remote.Get(dst) + if err != nil { + return true + } + + return srcDesc.Digest != dstDesc.Digest +} diff --git a/cmd/tailscale/cli/diag.go b/cmd/tailscale/cli/diag.go index ebf26985fe0bd..a1616f851e142 100644 --- a/cmd/tailscale/cli/diag.go +++ b/cmd/tailscale/cli/diag.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || windows || darwin - -package cli - -import ( - "fmt" - "os/exec" - "path/filepath" - "runtime" - "strings" - - ps "github.com/mitchellh/go-ps" - "tailscale.com/version/distro" -) - -// fixTailscaledConnectError is called when the local tailscaled has -// been determined unreachable due to the provided origErr value. It -// returns either the same error or a better one to help the user -// understand why tailscaled isn't running for their platform. -func fixTailscaledConnectError(origErr error) error { - procs, err := ps.Processes() - if err != nil { - return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") - } - var foundProc ps.Process - for _, proc := range procs { - base := filepath.Base(proc.Executable()) - if base == "tailscaled" { - foundProc = proc - break - } - if runtime.GOOS == "darwin" && base == "IPNExtension" { - foundProc = proc - break - } - if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { - foundProc = proc - break - } - } - if foundProc == nil { - switch runtime.GOOS { - case "windows": - return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") - case "darwin": - return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") - case "linux": - var hint string - if isSystemdSystem() { - hint = " (sudo systemctl start tailscaled ?)" - } - return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) - } - return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") - } - return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) -} - -// isSystemdSystem reports whether the current machine uses systemd -// and in particular whether the systemctl command is available. -func isSystemdSystem() bool { - if runtime.GOOS != "linux" { - return false - } - switch distro.Get() { - case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: - return false - } - _, err := exec.LookPath("systemctl") - return err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || windows || darwin + +package cli + +import ( + "fmt" + "os/exec" + "path/filepath" + "runtime" + "strings" + + ps "github.com/mitchellh/go-ps" + "tailscale.com/version/distro" +) + +// fixTailscaledConnectError is called when the local tailscaled has +// been determined unreachable due to the provided origErr value. It +// returns either the same error or a better one to help the user +// understand why tailscaled isn't running for their platform. +func fixTailscaledConnectError(origErr error) error { + procs, err := ps.Processes() + if err != nil { + return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") + } + var foundProc ps.Process + for _, proc := range procs { + base := filepath.Base(proc.Executable()) + if base == "tailscaled" { + foundProc = proc + break + } + if runtime.GOOS == "darwin" && base == "IPNExtension" { + foundProc = proc + break + } + if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { + foundProc = proc + break + } + } + if foundProc == nil { + switch runtime.GOOS { + case "windows": + return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") + case "darwin": + return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") + case "linux": + var hint string + if isSystemdSystem() { + hint = " (sudo systemctl start tailscaled ?)" + } + return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) + } + return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") + } + return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) +} + +// isSystemdSystem reports whether the current machine uses systemd +// and in particular whether the systemctl command is available. +func isSystemdSystem() bool { + if runtime.GOOS != "linux" { + return false + } + switch distro.Get() { + case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: + return false + } + _, err := exec.LookPath("systemctl") + return err == nil +} diff --git a/cmd/tailscale/cli/diag_other.go b/cmd/tailscale/cli/diag_other.go index ece10cc79a822..82058ef7a139c 100644 --- a/cmd/tailscale/cli/diag_other.go +++ b/cmd/tailscale/cli/diag_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package cli - -import "fmt" - -// The github.com/mitchellh/go-ps package doesn't work on all platforms, -// so just don't diagnose connect failures. - -func fixTailscaledConnectError(origErr error) error { - return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package cli + +import "fmt" + +// The github.com/mitchellh/go-ps package doesn't work on all platforms, +// so just don't diagnose connect failures. + +func fixTailscaledConnectError(origErr error) error { + return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) +} diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 15305c3ce3ed3..06ef8503f048e 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/ipn" - "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" -) - -func TestCalcAdvertiseRoutesForSet(t *testing.T) { - pfx := netip.MustParsePrefix - tests := []struct { - name string - setExit *bool - setRoutes *string - was []netip.Prefix - want []netip.Prefix - }{ - { - name: "empty", - }, - { - name: "advertise-exit", - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-exit/already-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setExit: ptr.To(true), - want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "stop-advertise-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(false), - want: nil, - }, - { - name: "stop-advertise-exit/with-routes", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setExit: ptr.To(false), - want: []netip.Prefix{pfx("34.0.0.0/16")}, - }, - { - name: "advertise-routes", - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "advertise-routes/already-exit", - was: tsaddr.ExitRoutes(), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes/already-diff-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "stop-advertise-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To(""), - want: nil, - }, - { - name: "stop-advertise-routes/already-exit", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setRoutes: ptr.To(""), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-routes-and-exit", - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-routes", - was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - curPrefs := &ipn.Prefs{ - AdvertiseRoutes: tc.was, - } - sa := setArgsT{} - if tc.setExit != nil { - sa.advertiseDefaultRoute = *tc.setExit - } - if tc.setRoutes != nil { - sa.advertiseRoutes = *tc.setRoutes - } - got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) - if err != nil { - t.Fatal(err) - } - tsaddr.SortPrefixes(got) - tsaddr.SortPrefixes(tc.want) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("got %v, want %v", got, tc.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/tsaddr" + "tailscale.com/types/ptr" +) + +func TestCalcAdvertiseRoutesForSet(t *testing.T) { + pfx := netip.MustParsePrefix + tests := []struct { + name string + setExit *bool + setRoutes *string + was []netip.Prefix + want []netip.Prefix + }{ + { + name: "empty", + }, + { + name: "advertise-exit", + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-exit/already-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setExit: ptr.To(true), + want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "stop-advertise-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(false), + want: nil, + }, + { + name: "stop-advertise-exit/with-routes", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setExit: ptr.To(false), + want: []netip.Prefix{pfx("34.0.0.0/16")}, + }, + { + name: "advertise-routes", + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "advertise-routes/already-exit", + was: tsaddr.ExitRoutes(), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes/already-diff-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "stop-advertise-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To(""), + want: nil, + }, + { + name: "stop-advertise-routes/already-exit", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setRoutes: ptr.To(""), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-routes-and-exit", + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-routes", + was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + curPrefs := &ipn.Prefs{ + AdvertiseRoutes: tc.was, + } + sa := setArgsT{} + if tc.setExit != nil { + sa.advertiseDefaultRoute = *tc.setExit + } + if tc.setRoutes != nil { + sa.advertiseRoutes = *tc.setRoutes + } + got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) + if err != nil { + t.Fatal(err) + } + tsaddr.SortPrefixes(got) + tsaddr.SortPrefixes(tc.want) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} diff --git a/cmd/tailscale/cli/ssh_exec.go b/cmd/tailscale/cli/ssh_exec.go index 10e52903dea64..7f7d2a4d5cfe0 100644 --- a/cmd/tailscale/cli/ssh_exec.go +++ b/cmd/tailscale/cli/ssh_exec.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package cli - -import ( - "errors" - "os" - "os/exec" - "syscall" -) - -func findSSH() (string, error) { - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { - return err - } - return errors.New("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package cli + +import ( + "errors" + "os" + "os/exec" + "syscall" +) + +func findSSH() (string, error) { + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { + return err + } + return errors.New("unreachable") +} diff --git a/cmd/tailscale/cli/ssh_exec_js.go b/cmd/tailscale/cli/ssh_exec_js.go index 40effc7cafc7e..aa0c09e89ab66 100644 --- a/cmd/tailscale/cli/ssh_exec_js.go +++ b/cmd/tailscale/cli/ssh_exec_js.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" -) - -func findSSH() (string, error) { - return "", errors.New("Not implemented") -} - -func execSSH(ssh string, argv []string) error { - return errors.New("Not implemented") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" +) + +func findSSH() (string, error) { + return "", errors.New("Not implemented") +} + +func execSSH(ssh string, argv []string) error { + return errors.New("Not implemented") +} diff --git a/cmd/tailscale/cli/ssh_exec_windows.go b/cmd/tailscale/cli/ssh_exec_windows.go index e249afe667401..30ab70d046dd4 100644 --- a/cmd/tailscale/cli/ssh_exec_windows.go +++ b/cmd/tailscale/cli/ssh_exec_windows.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" - "os" - "os/exec" - "path/filepath" -) - -func findSSH() (string, error) { - // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior - // occurred with ssh.exe provided by msys2/cygwin and other environments. - if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { - exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") - if st, err := os.Stat(exe); err == nil && !st.IsDir() { - return exe, nil - } - } - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - // Don't use syscall.Exec on Windows, it's not fully implemented. - cmd := exec.Command(ssh, argv[1:]...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - var ee *exec.ExitError - err := cmd.Run() - if errors.As(err, &ee) { - os.Exit(ee.ExitCode()) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" + "os" + "os/exec" + "path/filepath" +) + +func findSSH() (string, error) { + // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior + // occurred with ssh.exe provided by msys2/cygwin and other environments. + if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { + exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") + if st, err := os.Stat(exe); err == nil && !st.IsDir() { + return exe, nil + } + } + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + // Don't use syscall.Exec on Windows, it's not fully implemented. + cmd := exec.Command(ssh, argv[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + var ee *exec.ExitError + err := cmd.Run() + if errors.As(err, &ee) { + os.Exit(ee.ExitCode()) + } + return err +} diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go index 71c0caaa69ad5..07423b69fa9e6 100644 --- a/cmd/tailscale/cli/ssh_unix.go +++ b/cmd/tailscale/cli/ssh_unix.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !wasm && !windows && !plan9 - -package cli - -import ( - "bytes" - "os" - "path/filepath" - "runtime" - "strconv" - - "golang.org/x/sys/unix" -) - -func init() { - getSSHClientEnvVar = func() string { - if os.Getenv("SUDO_USER") == "" { - // No sudo, just check the env. - return os.Getenv("SSH_CLIENT") - } - if runtime.GOOS != "linux" { - // TODO(maisem): implement this for other platforms. It's not clear - // if there is a way to get the environment for a given process on - // darwin and bsd. - return "" - } - // SID is the session ID of the user's login session. - // It is also the process ID of the original shell that the user logged in with. - // We only need to check the environment of that process. - sid, err := unix.Getsid(os.Getpid()) - if err != nil { - return "" - } - b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) - if err != nil { - return "" - } - prefix := []byte("SSH_CLIENT=") - for _, env := range bytes.Split(b, []byte{0}) { - if bytes.HasPrefix(env, prefix) { - return string(env[len(prefix):]) - } - } - return "" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !wasm && !windows && !plan9 + +package cli + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func init() { + getSSHClientEnvVar = func() string { + if os.Getenv("SUDO_USER") == "" { + // No sudo, just check the env. + return os.Getenv("SSH_CLIENT") + } + if runtime.GOOS != "linux" { + // TODO(maisem): implement this for other platforms. It's not clear + // if there is a way to get the environment for a given process on + // darwin and bsd. + return "" + } + // SID is the session ID of the user's login session. + // It is also the process ID of the original shell that the user logged in with. + // We only need to check the environment of that process. + sid, err := unix.Getsid(os.Getpid()) + if err != nil { + return "" + } + b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) + if err != nil { + return "" + } + prefix := []byte("SSH_CLIENT=") + for _, env := range bytes.Split(b, []byte{0}) { + if bytes.HasPrefix(env, prefix) { + return string(env[len(prefix):]) + } + } + return "" + } +} diff --git a/cmd/tailscale/cli/web_test.go b/cmd/tailscale/cli/web_test.go index f2470b364c41e..f1880597e5c53 100644 --- a/cmd/tailscale/cli/web_test.go +++ b/cmd/tailscale/cli/web_test.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "testing" -) - -func TestUrlOfListenAddr(t *testing.T) { - tests := []struct { - name string - in, want string - }{ - { - name: "TestLocalhost", - in: "localhost:8088", - want: "http://localhost:8088", - }, - { - name: "TestNoHost", - in: ":8088", - want: "http://127.0.0.1:8088", - }, - { - name: "TestExplicitHost", - in: "127.0.0.2:8088", - want: "http://127.0.0.2:8088", - }, - { - name: "TestIPv6", - in: "[::1]:8088", - want: "http://[::1]:8088", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - u := urlOfListenAddr(tt.in) - if u != tt.want { - t.Errorf("expected url: %q, got: %q", tt.want, u) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "testing" +) + +func TestUrlOfListenAddr(t *testing.T) { + tests := []struct { + name string + in, want string + }{ + { + name: "TestLocalhost", + in: "localhost:8088", + want: "http://localhost:8088", + }, + { + name: "TestNoHost", + in: ":8088", + want: "http://127.0.0.1:8088", + }, + { + name: "TestExplicitHost", + in: "127.0.0.2:8088", + want: "http://127.0.0.2:8088", + }, + { + name: "TestIPv6", + in: "[::1]:8088", + want: "http://[::1]:8088", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := urlOfListenAddr(tt.in) + if u != tt.want { + t.Errorf("expected url: %q, got: %q", tt.want, u) + } + }) + } +} diff --git a/cmd/tailscale/generate.go b/cmd/tailscale/generate.go index 5c2e9be915980..fa38b370417aa 100644 --- a/cmd/tailscale/generate.go +++ b/cmd/tailscale/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscale/tailscale.go b/cmd/tailscale/tailscale.go index f6adb6c197071..1848d65088c3d 100644 --- a/cmd/tailscale/tailscale.go +++ b/cmd/tailscale/tailscale.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tailscale command is the Tailscale command-line client. It interacts -// with the tailscaled node agent. -package main // import "tailscale.com/cmd/tailscale" - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "tailscale.com/cmd/tailscale/cli" -) - -func main() { - args := os.Args[1:] - if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { - args = []string{"web", "-cgi"} - } - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tailscale command is the Tailscale command-line client. It interacts +// with the tailscaled node agent. +package main // import "tailscale.com/cmd/tailscale" + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "tailscale.com/cmd/tailscale/cli" +) + +func main() { + args := os.Args[1:] + if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { + args = []string{"web", "-cgi"} + } + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/tailscale/windows-manifest.xml b/cmd/tailscale/windows-manifest.xml index 6c5f46058387f..5eaa54fa514e3 100644 --- a/cmd/tailscale/windows-manifest.xml +++ b/cmd/tailscale/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/childproc/childproc.go b/cmd/tailscaled/childproc/childproc.go index cc83a06c6ee7c..068015c59f3eb 100644 --- a/cmd/tailscaled/childproc/childproc.go +++ b/cmd/tailscaled/childproc/childproc.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package childproc allows other packages to register "tailscaled be-child" -// child process hook code. This avoids duplicating build tags in the -// tailscaled package. Instead, the code that needs to fork/exec the self -// executable (when it's tailscaled) can instead register the code -// they want to run. -package childproc - -var Code = map[string]func([]string) error{} - -// Add registers code f to run as 'tailscaled be-child [args]'. -func Add(typ string, f func(args []string) error) { - if _, dup := Code[typ]; dup { - panic("dup hook " + typ) - } - Code[typ] = f -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package childproc allows other packages to register "tailscaled be-child" +// child process hook code. This avoids duplicating build tags in the +// tailscaled package. Instead, the code that needs to fork/exec the self +// executable (when it's tailscaled) can instead register the code +// they want to run. +package childproc + +var Code = map[string]func([]string) error{} + +// Add registers code f to run as 'tailscaled be-child [args]'. +func Add(typ string, f func(args []string) error) { + if _, dup := Code[typ]; dup { + panic("dup hook " + typ) + } + Code[typ] = f +} diff --git a/cmd/tailscaled/generate.go b/cmd/tailscaled/generate.go index 5c2e9be915980..fa38b370417aa 100644 --- a/cmd/tailscaled/generate.go +++ b/cmd/tailscaled/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscaled/install_darwin.go b/cmd/tailscaled/install_darwin.go index 05e5eaed8af90..9013b39ba3567 100644 --- a/cmd/tailscaled/install_darwin.go +++ b/cmd/tailscaled/install_darwin.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "errors" - "fmt" - "io" - "io/fs" - "os" - "os/exec" - "path/filepath" -) - -func init() { - installSystemDaemon = installSystemDaemonDarwin - uninstallSystemDaemon = uninstallSystemDaemonDarwin -} - -// darwinLaunchdPlist is the launchd.plist that's written to -// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the -// future) a user-specific location. -// -// See man launchd.plist. -const darwinLaunchdPlist = ` - - - - - - Label - com.tailscale.tailscaled - - ProgramArguments - - /usr/local/bin/tailscaled - - - RunAtLoad - - - - -` - -const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" -const targetBin = "/usr/local/bin/tailscaled" -const service = "com.tailscale.tailscaled" - -func uninstallSystemDaemonDarwin(args []string) (ret error) { - if len(args) > 0 { - return errors.New("uninstall subcommand takes no arguments") - } - - plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() - _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. - running := err == nil - - if running { - out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() - if err != nil { - fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) - ret = err - } - out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() - if err != nil { - fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) - if ret == nil { - ret = err - } - } - } - - if err := os.Remove(sysPlist); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - - // Do not delete targetBin if it's a symlink, which happens if it was installed via - // Homebrew. - if isSymlink(targetBin) { - return ret - } - - if err := os.Remove(targetBin); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - return ret -} - -func installSystemDaemonDarwin(args []string) (err error) { - if len(args) > 0 { - return errors.New("install subcommand takes no arguments") - } - defer func() { - if err != nil && os.Getuid() != 0 { - err = fmt.Errorf("%w; try running tailscaled with sudo", err) - } - }() - - // Best effort: - uninstallSystemDaemonDarwin(nil) - - exe, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to find our own executable path: %w", err) - } - - same, err := sameFile(exe, targetBin) - if err != nil { - return err - } - - // Do not overwrite targetBin with the binary file if it it's already - // pointing to it. This is primarily to handle Homebrew that writes - // /usr/local/bin/tailscaled is a symlink to the actual binary. - if !same { - if err := copyBinary(exe, targetBin); err != nil { - return err - } - } - if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { - return err - } - - if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) - } - - if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) - } - - return nil -} - -// copyBinary copies binary file `src` into `dst`. -func copyBinary(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - tmpBin := dst + ".tmp" - f, err := os.Create(tmpBin) - if err != nil { - return err - } - srcf, err := os.Open(src) - if err != nil { - f.Close() - return err - } - _, err = io.Copy(f, srcf) - srcf.Close() - if err != nil { - f.Close() - return err - } - if err := f.Close(); err != nil { - return err - } - if err := os.Chmod(tmpBin, 0755); err != nil { - return err - } - if err := os.Rename(tmpBin, dst); err != nil { - return err - } - - return nil -} - -func isSymlink(path string) bool { - fi, err := os.Lstat(path) - return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) -} - -// sameFile returns true if both file paths exist and resolve to the same file. -func sameFile(path1, path2 string) (bool, error) { - dst1, err := filepath.EvalSymlinks(path1) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) - } - dst2, err := filepath.EvalSymlinks(path2) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) - } - return dst1 == dst2, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "os/exec" + "path/filepath" +) + +func init() { + installSystemDaemon = installSystemDaemonDarwin + uninstallSystemDaemon = uninstallSystemDaemonDarwin +} + +// darwinLaunchdPlist is the launchd.plist that's written to +// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the +// future) a user-specific location. +// +// See man launchd.plist. +const darwinLaunchdPlist = ` + + + + + + Label + com.tailscale.tailscaled + + ProgramArguments + + /usr/local/bin/tailscaled + + + RunAtLoad + + + + +` + +const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" +const targetBin = "/usr/local/bin/tailscaled" +const service = "com.tailscale.tailscaled" + +func uninstallSystemDaemonDarwin(args []string) (ret error) { + if len(args) > 0 { + return errors.New("uninstall subcommand takes no arguments") + } + + plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() + _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. + running := err == nil + + if running { + out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() + if err != nil { + fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) + ret = err + } + out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() + if err != nil { + fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) + if ret == nil { + ret = err + } + } + } + + if err := os.Remove(sysPlist); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + + // Do not delete targetBin if it's a symlink, which happens if it was installed via + // Homebrew. + if isSymlink(targetBin) { + return ret + } + + if err := os.Remove(targetBin); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + return ret +} + +func installSystemDaemonDarwin(args []string) (err error) { + if len(args) > 0 { + return errors.New("install subcommand takes no arguments") + } + defer func() { + if err != nil && os.Getuid() != 0 { + err = fmt.Errorf("%w; try running tailscaled with sudo", err) + } + }() + + // Best effort: + uninstallSystemDaemonDarwin(nil) + + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to find our own executable path: %w", err) + } + + same, err := sameFile(exe, targetBin) + if err != nil { + return err + } + + // Do not overwrite targetBin with the binary file if it it's already + // pointing to it. This is primarily to handle Homebrew that writes + // /usr/local/bin/tailscaled is a symlink to the actual binary. + if !same { + if err := copyBinary(exe, targetBin); err != nil { + return err + } + } + if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { + return err + } + + if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) + } + + if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) + } + + return nil +} + +// copyBinary copies binary file `src` into `dst`. +func copyBinary(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + tmpBin := dst + ".tmp" + f, err := os.Create(tmpBin) + if err != nil { + return err + } + srcf, err := os.Open(src) + if err != nil { + f.Close() + return err + } + _, err = io.Copy(f, srcf) + srcf.Close() + if err != nil { + f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + if err := os.Chmod(tmpBin, 0755); err != nil { + return err + } + if err := os.Rename(tmpBin, dst); err != nil { + return err + } + + return nil +} + +func isSymlink(path string) bool { + fi, err := os.Lstat(path) + return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) +} + +// sameFile returns true if both file paths exist and resolve to the same file. +func sameFile(path1, path2 string) (bool, error) { + dst1, err := filepath.EvalSymlinks(path1) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) + } + dst2, err := filepath.EvalSymlinks(path2) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) + } + return dst1 == dst2, nil +} diff --git a/cmd/tailscaled/install_windows.go b/cmd/tailscaled/install_windows.go index c36418642d2b4..9e39c8ab37074 100644 --- a/cmd/tailscaled/install_windows.go +++ b/cmd/tailscaled/install_windows.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "context" - "errors" - "fmt" - "os" - "time" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/mgr" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/osshare" -) - -func init() { - installSystemDaemon = installSystemDaemonWindows - uninstallSystemDaemon = uninstallSystemDaemonWindows -} - -func installSystemDaemonWindows(args []string) (err error) { - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - - service, err := m.OpenService(serviceName) - if err == nil { - service.Close() - return fmt.Errorf("service %q is already installed", serviceName) - } - - // no such service; proceed to install the service. - - exe, err := os.Executable() - if err != nil { - return err - } - - c := mgr.Config{ - ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, - StartType: mgr.StartAutomatic, - ErrorControl: mgr.ErrorNormal, - DisplayName: serviceName, - Description: "Connects this computer to others on the Tailscale network.", - } - - service, err = m.CreateService(serviceName, exe, c) - if err != nil { - return fmt.Errorf("failed to create %q service: %v", serviceName, err) - } - defer service.Close() - - // Exponential backoff is often too aggressive, so use (mostly) - // squares instead. - ra := []mgr.RecoveryAction{ - {mgr.ServiceRestart, 1 * time.Second}, - {mgr.ServiceRestart, 2 * time.Second}, - {mgr.ServiceRestart, 4 * time.Second}, - {mgr.ServiceRestart, 9 * time.Second}, - {mgr.ServiceRestart, 16 * time.Second}, - {mgr.ServiceRestart, 25 * time.Second}, - {mgr.ServiceRestart, 36 * time.Second}, - {mgr.ServiceRestart, 49 * time.Second}, - {mgr.ServiceRestart, 64 * time.Second}, - } - const resetPeriodSecs = 60 - err = service.SetRecoveryActions(ra, resetPeriodSecs) - if err != nil { - return fmt.Errorf("failed to set service recovery actions: %v", err) - } - - return nil -} - -func uninstallSystemDaemonWindows(args []string) (ret error) { - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) - - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - defer m.Disconnect() - - service, err := m.OpenService(serviceName) - if err != nil { - return fmt.Errorf("failed to open %q service: %v", serviceName, err) - } - - st, err := service.Query() - if err != nil { - service.Close() - return fmt.Errorf("failed to query service state: %v", err) - } - if st.State != svc.Stopped { - service.Control(svc.Stop) - } - err = service.Delete() - service.Close() - if err != nil { - return fmt.Errorf("failed to delete service: %v", err) - } - - bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) - end := time.Now().Add(15 * time.Second) - for time.Until(end) > 0 { - service, err = m.OpenService(serviceName) - if err != nil { - // service is no longer openable; success! - break - } - service.Close() - bo.BackOff(context.Background(), errors.New("service not deleted")) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/osshare" +) + +func init() { + installSystemDaemon = installSystemDaemonWindows + uninstallSystemDaemon = uninstallSystemDaemonWindows +} + +func installSystemDaemonWindows(args []string) (err error) { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + + service, err := m.OpenService(serviceName) + if err == nil { + service.Close() + return fmt.Errorf("service %q is already installed", serviceName) + } + + // no such service; proceed to install the service. + + exe, err := os.Executable() + if err != nil { + return err + } + + c := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceName, + Description: "Connects this computer to others on the Tailscale network.", + } + + service, err = m.CreateService(serviceName, exe, c) + if err != nil { + return fmt.Errorf("failed to create %q service: %v", serviceName, err) + } + defer service.Close() + + // Exponential backoff is often too aggressive, so use (mostly) + // squares instead. + ra := []mgr.RecoveryAction{ + {mgr.ServiceRestart, 1 * time.Second}, + {mgr.ServiceRestart, 2 * time.Second}, + {mgr.ServiceRestart, 4 * time.Second}, + {mgr.ServiceRestart, 9 * time.Second}, + {mgr.ServiceRestart, 16 * time.Second}, + {mgr.ServiceRestart, 25 * time.Second}, + {mgr.ServiceRestart, 36 * time.Second}, + {mgr.ServiceRestart, 49 * time.Second}, + {mgr.ServiceRestart, 64 * time.Second}, + } + const resetPeriodSecs = 60 + err = service.SetRecoveryActions(ra, resetPeriodSecs) + if err != nil { + return fmt.Errorf("failed to set service recovery actions: %v", err) + } + + return nil +} + +func uninstallSystemDaemonWindows(args []string) (ret error) { + // Remove file sharing from Windows shell (noop in non-windows) + osshare.SetFileSharingEnabled(false, logger.Discard) + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + defer m.Disconnect() + + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("failed to open %q service: %v", serviceName, err) + } + + st, err := service.Query() + if err != nil { + service.Close() + return fmt.Errorf("failed to query service state: %v", err) + } + if st.State != svc.Stopped { + service.Control(svc.Stop) + } + err = service.Delete() + service.Close() + if err != nil { + return fmt.Errorf("failed to delete service: %v", err) + } + + bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) + end := time.Now().Add(15 * time.Second) + for time.Until(end) > 0 { + service, err = m.OpenService(serviceName) + if err != nil { + // service is no longer openable; success! + break + } + service.Close() + bo.BackOff(context.Background(), errors.New("service not deleted")) + } + return nil +} diff --git a/cmd/tailscaled/proxy.go b/cmd/tailscaled/proxy.go index a91c62bfa44ac..109ad029d3aaf 100644 --- a/cmd/tailscaled/proxy.go +++ b/cmd/tailscaled/proxy.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -// HTTP proxy code - -package main - -import ( - "context" - "io" - "net" - "net/http" - "net/http/httputil" - "strings" -) - -// httpProxyHandler returns an HTTP proxy http.Handler using the -// provided backend dialer. -func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { - rp := &httputil.ReverseProxy{ - Director: func(r *http.Request) {}, // no change - Transport: &http.Transport{ - DialContext: dialer, - }, - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - backURL := r.RequestURI - if strings.HasPrefix(backURL, "/") || backURL == "*" { - http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) - return - } - rp.ServeHTTP(w, r) - return - } - - // CONNECT support: - - dst := r.RequestURI - c, err := dialer(r.Context(), "tcp", dst) - if err != nil { - w.Header().Set("Tailscale-Connect-Error", err.Error()) - http.Error(w, err.Error(), 500) - return - } - defer c.Close() - - cc, ccbuf, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), 500) - return - } - defer cc.Close() - - io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") - - var clientSrc io.Reader = ccbuf - if ccbuf.Reader.Buffered() == 0 { - // In the common case (with no - // buffered data), read directly from - // the underlying client connection to - // save some memory, letting the - // bufio.Reader/Writer get GC'ed. - clientSrc = cc - } - - errc := make(chan error, 1) - go func() { - _, err := io.Copy(cc, c) - errc <- err - }() - go func() { - _, err := io.Copy(c, clientSrc) - errc <- err - }() - <-errc - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +// HTTP proxy code + +package main + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" +) + +// httpProxyHandler returns an HTTP proxy http.Handler using the +// provided backend dialer. +func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { + rp := &httputil.ReverseProxy{ + Director: func(r *http.Request) {}, // no change + Transport: &http.Transport{ + DialContext: dialer, + }, + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + backURL := r.RequestURI + if strings.HasPrefix(backURL, "/") || backURL == "*" { + http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) + return + } + rp.ServeHTTP(w, r) + return + } + + // CONNECT support: + + dst := r.RequestURI + c, err := dialer(r.Context(), "tcp", dst) + if err != nil { + w.Header().Set("Tailscale-Connect-Error", err.Error()) + http.Error(w, err.Error(), 500) + return + } + defer c.Close() + + cc, ccbuf, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer cc.Close() + + io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") + + var clientSrc io.Reader = ccbuf + if ccbuf.Reader.Buffered() == 0 { + // In the common case (with no + // buffered data), read directly from + // the underlying client connection to + // save some memory, letting the + // bufio.Reader/Writer get GC'ed. + clientSrc = cc + } + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(cc, c) + errc <- err + }() + go func() { + _, err := io.Copy(c, clientSrc) + errc <- err + }() + <-errc + }) +} diff --git a/cmd/tailscaled/sigpipe.go b/cmd/tailscaled/sigpipe.go index 2fcdab2a4660e..695a880248bc0 100644 --- a/cmd/tailscaled/sigpipe.go +++ b/cmd/tailscaled/sigpipe.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.21 && !plan9 - -package main - -import "syscall" - -func init() { - sigPipe = syscall.SIGPIPE -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.21 && !plan9 + +package main + +import "syscall" + +func init() { + sigPipe = syscall.SIGPIPE +} diff --git a/cmd/tailscaled/tailscaled.defaults b/cmd/tailscaled/tailscaled.defaults index e8384a4f82097..693a6190bfac8 100644 --- a/cmd/tailscaled/tailscaled.defaults +++ b/cmd/tailscaled/tailscaled.defaults @@ -1,8 +1,8 @@ -# Set the port to listen on for incoming VPN packets. -# Remote nodes will automatically be informed about the new port number, -# but you might want to configure this in order to set external firewall -# settings. -PORT="41641" - -# Extra flags you might want to pass to tailscaled. -FLAGS="" +# Set the port to listen on for incoming VPN packets. +# Remote nodes will automatically be informed about the new port number, +# but you might want to configure this in order to set external firewall +# settings. +PORT="41641" + +# Extra flags you might want to pass to tailscaled. +FLAGS="" diff --git a/cmd/tailscaled/tailscaled.openrc b/cmd/tailscaled/tailscaled.openrc index 309d70f23a26f..6193247ce3131 100755 --- a/cmd/tailscaled/tailscaled.openrc +++ b/cmd/tailscaled/tailscaled.openrc @@ -1,25 +1,25 @@ -#!/sbin/openrc-run - -set -a -source /etc/default/tailscaled -set +a - -command="/usr/sbin/tailscaled" -command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" -command_background=true -pidfile="/run/tailscaled.pid" -start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" - -depend() { - need net -} - -start_pre() { - mkdir -p /var/run/tailscale - mkdir -p /var/lib/tailscale - $command --cleanup -} - -stop_post() { - $command --cleanup -} +#!/sbin/openrc-run + +set -a +source /etc/default/tailscaled +set +a + +command="/usr/sbin/tailscaled" +command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" +command_background=true +pidfile="/run/tailscaled.pid" +start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" + +depend() { + need net +} + +start_pre() { + mkdir -p /var/run/tailscale + mkdir -p /var/lib/tailscale + $command --cleanup +} + +stop_post() { + $command --cleanup +} diff --git a/cmd/tailscaled/tailscaled_bird.go b/cmd/tailscaled/tailscaled_bird.go index c76f77bec6e36..885f552cb8f50 100644 --- a/cmd/tailscaled/tailscaled_bird.go +++ b/cmd/tailscaled/tailscaled_bird.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird - -package main - -import ( - "tailscale.com/chirp" - "tailscale.com/wgengine" -) - -func init() { - createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { - return chirp.New(ctlSocket) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird + +package main + +import ( + "tailscale.com/chirp" + "tailscale.com/wgengine" +) + +func init() { + createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { + return chirp.New(ctlSocket) + } +} diff --git a/cmd/tailscaled/tailscaled_notwindows.go b/cmd/tailscaled/tailscaled_notwindows.go index d5361cf286d3d..b0a7c159833f5 100644 --- a/cmd/tailscaled/tailscaled_notwindows.go +++ b/cmd/tailscaled/tailscaled_notwindows.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && go1.19 - -package main // import "tailscale.com/cmd/tailscaled" - -import "tailscale.com/logpolicy" - -func isWindowsService() bool { return false } - -func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } - -func beWindowsSubprocess() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && go1.19 + +package main // import "tailscale.com/cmd/tailscaled" + +import "tailscale.com/logpolicy" + +func isWindowsService() bool { return false } + +func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } + +func beWindowsSubprocess() bool { return false } diff --git a/cmd/tailscaled/windows-manifest.xml b/cmd/tailscaled/windows-manifest.xml index 6c5f46058387f..5eaa54fa514e3 100644 --- a/cmd/tailscaled/windows-manifest.xml +++ b/cmd/tailscaled/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/with_cli.go b/cmd/tailscaled/with_cli.go index a8554eb8ce9dc..f191fdb45b288 100644 --- a/cmd/tailscaled/with_cli.go +++ b/cmd/tailscaled/with_cli.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_include_cli - -package main - -import ( - "fmt" - "os" - - "tailscale.com/cmd/tailscale/cli" -) - -func init() { - beCLI = func() { - args := os.Args[1:] - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_include_cli + +package main + +import ( + "fmt" + "os" + + "tailscale.com/cmd/tailscale/cli" +) + +func init() { + beCLI = func() { + args := os.Args[1:] + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + } +} diff --git a/cmd/testwrapper/args_test.go b/cmd/testwrapper/args_test.go index 10063d7bcf6e1..f7f30a7eb2fa5 100644 --- a/cmd/testwrapper/args_test.go +++ b/cmd/testwrapper/args_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "slices" - "testing" -) - -func TestSplitArgs(t *testing.T) { - tests := []struct { - name string - in []string - pre, pkgs, post []string - }{ - { - name: "empty", - }, - { - name: "all", - in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, - pre: []string{"-v"}, - pkgs: []string{"pkg1", "pkg2"}, - post: []string{"-run", "TestFoo", "-timeout=20s"}, - }, - { - name: "only_pkgs", - in: []string{"./..."}, - pkgs: []string{"./..."}, - }, - { - name: "pkgs_and_post", - in: []string{"pkg1", "-run", "TestFoo"}, - pkgs: []string{"pkg1"}, - post: []string{"-run", "TestFoo"}, - }, - { - name: "pkgs_and_post", - in: []string{"-v", "pkg2"}, - pre: []string{"-v"}, - pkgs: []string{"pkg2"}, - }, - { - name: "only_args", - in: []string{"-v", "-run=TestFoo"}, - pre: []string{"-run", "TestFoo", "-v"}, // sorted - }, - { - name: "space_in_pre_arg", - in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, - pre: []string{"-run", "TestFoo"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "space_in_arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "test-arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - { - name: "dupe-args", - in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-race", "-v"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pre, pkgs, post, err := splitArgs(tt.in) - if err != nil { - t.Fatal(err) - } - if !slices.Equal(pre, tt.pre) { - t.Errorf("pre = %q; want %q", pre, tt.pre) - } - if !slices.Equal(pkgs, tt.pkgs) { - t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) - } - if !slices.Equal(post, tt.post) { - t.Errorf("post = %q; want %q", post, tt.post) - } - if t.Failed() { - t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "slices" + "testing" +) + +func TestSplitArgs(t *testing.T) { + tests := []struct { + name string + in []string + pre, pkgs, post []string + }{ + { + name: "empty", + }, + { + name: "all", + in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, + pre: []string{"-v"}, + pkgs: []string{"pkg1", "pkg2"}, + post: []string{"-run", "TestFoo", "-timeout=20s"}, + }, + { + name: "only_pkgs", + in: []string{"./..."}, + pkgs: []string{"./..."}, + }, + { + name: "pkgs_and_post", + in: []string{"pkg1", "-run", "TestFoo"}, + pkgs: []string{"pkg1"}, + post: []string{"-run", "TestFoo"}, + }, + { + name: "pkgs_and_post", + in: []string{"-v", "pkg2"}, + pre: []string{"-v"}, + pkgs: []string{"pkg2"}, + }, + { + name: "only_args", + in: []string{"-v", "-run=TestFoo"}, + pre: []string{"-run", "TestFoo", "-v"}, // sorted + }, + { + name: "space_in_pre_arg", + in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, + pre: []string{"-run", "TestFoo"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "space_in_arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "test-arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + { + name: "dupe-args", + in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-race", "-v"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pre, pkgs, post, err := splitArgs(tt.in) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(pre, tt.pre) { + t.Errorf("pre = %q; want %q", pre, tt.pre) + } + if !slices.Equal(pkgs, tt.pkgs) { + t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) + } + if !slices.Equal(post, tt.post) { + t.Errorf("post = %q; want %q", post, tt.post) + } + if t.Failed() { + t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) + } + }) + } +} diff --git a/cmd/testwrapper/flakytest/flakytest.go b/cmd/testwrapper/flakytest/flakytest.go index 494ed080b26a1..e5e21dd2159ba 100644 --- a/cmd/testwrapper/flakytest/flakytest.go +++ b/cmd/testwrapper/flakytest/flakytest.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flakytest contains test helpers for marking a test as flaky. For -// tests run using cmd/testwrapper, a failed flaky test will cause tests to be -// re-run a few time until they succeed or exceed our iteration limit. -package flakytest - -import ( - "fmt" - "os" - "regexp" - "testing" -) - -// FlakyTestLogMessage is a sentinel value that is printed to stderr when a -// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests -// and retry them. -const FlakyTestLogMessage = "flakytest: this is a known flaky test" - -// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper -// when a flaky test is being (re)tried. It contains the attempt number, -// starting at 1. -const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" - -var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) - -// Mark sets the current test as a flaky test, such that if it fails, it will -// be retried a few times on failure. issue must be a GitHub issue that tracks -// the status of the flaky test being marked, of the format: -// -// https://github.com/tailscale/myRepo-H3re/issues/12345 -func Mark(t testing.TB, issue string) { - if !issueRegexp.MatchString(issue) { - t.Fatalf("bad issue format: %q", issue) - } - if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { - // We're being run under cmd/testwrapper so send our sentinel message - // to stderr. (We avoid doing this when the env is absent to avoid - // spamming people running tests without the wrapper) - fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) - } - t.Logf("flakytest: issue tracking this flaky test: %s", issue) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flakytest contains test helpers for marking a test as flaky. For +// tests run using cmd/testwrapper, a failed flaky test will cause tests to be +// re-run a few time until they succeed or exceed our iteration limit. +package flakytest + +import ( + "fmt" + "os" + "regexp" + "testing" +) + +// FlakyTestLogMessage is a sentinel value that is printed to stderr when a +// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests +// and retry them. +const FlakyTestLogMessage = "flakytest: this is a known flaky test" + +// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper +// when a flaky test is being (re)tried. It contains the attempt number, +// starting at 1. +const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" + +var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) + +// Mark sets the current test as a flaky test, such that if it fails, it will +// be retried a few times on failure. issue must be a GitHub issue that tracks +// the status of the flaky test being marked, of the format: +// +// https://github.com/tailscale/myRepo-H3re/issues/12345 +func Mark(t testing.TB, issue string) { + if !issueRegexp.MatchString(issue) { + t.Fatalf("bad issue format: %q", issue) + } + if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { + // We're being run under cmd/testwrapper so send our sentinel message + // to stderr. (We avoid doing this when the env is absent to avoid + // spamming people running tests without the wrapper) + fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) + } + t.Logf("flakytest: issue tracking this flaky test: %s", issue) +} diff --git a/cmd/testwrapper/flakytest/flakytest_test.go b/cmd/testwrapper/flakytest/flakytest_test.go index 85e77a939c75d..551352f6ad8ea 100644 --- a/cmd/testwrapper/flakytest/flakytest_test.go +++ b/cmd/testwrapper/flakytest/flakytest_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package flakytest - -import ( - "os" - "testing" -) - -func TestIssueFormat(t *testing.T) { - testCases := []struct { - issue string - want bool - }{ - {"https://github.com/tailscale/cOrp/issues/1234", true}, - {"https://github.com/otherproject/corp/issues/1234", false}, - {"https://github.com/tailscale/corp/issues/", false}, - } - for _, testCase := range testCases { - if issueRegexp.MatchString(testCase.issue) != testCase.want { - ss := "" - if !testCase.want { - ss = " not" - } - t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) - } - } -} - -// TestFlakeRun is a test that fails when run in the testwrapper -// for the first time, but succeeds on the second run. -// It's used to test whether the testwrapper retries flaky tests. -func TestFlakeRun(t *testing.T) { - Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue - e := os.Getenv(FlakeAttemptEnv) - if e == "" { - t.Skip("not running in testwrapper") - } - if e == "1" { - t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package flakytest + +import ( + "os" + "testing" +) + +func TestIssueFormat(t *testing.T) { + testCases := []struct { + issue string + want bool + }{ + {"https://github.com/tailscale/cOrp/issues/1234", true}, + {"https://github.com/otherproject/corp/issues/1234", false}, + {"https://github.com/tailscale/corp/issues/", false}, + } + for _, testCase := range testCases { + if issueRegexp.MatchString(testCase.issue) != testCase.want { + ss := "" + if !testCase.want { + ss = " not" + } + t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) + } + } +} + +// TestFlakeRun is a test that fails when run in the testwrapper +// for the first time, but succeeds on the second run. +// It's used to test whether the testwrapper retries flaky tests. +func TestFlakeRun(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue + e := os.Getenv(FlakeAttemptEnv) + if e == "" { + t.Skip("not running in testwrapper") + } + if e == "1" { + t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") + } +} diff --git a/cmd/tsconnect/.gitignore b/cmd/tsconnect/.gitignore index 13615d1213d63..b791f8e64b14e 100644 --- a/cmd/tsconnect/.gitignore +++ b/cmd/tsconnect/.gitignore @@ -1,3 +1,3 @@ -node_modules/ -/dist -/pkg +node_modules/ +/dist +/pkg diff --git a/cmd/tsconnect/README.md b/cmd/tsconnect/README.md index 536cd7bbf562c..f518f932e07eb 100644 --- a/cmd/tsconnect/README.md +++ b/cmd/tsconnect/README.md @@ -1,49 +1,49 @@ -# tsconnect - -The tsconnect command builds and serves the static site that is generated for -the Tailscale Connect JS/WASM client. - -## Development - -To start the development server: - -``` -./tool/go run ./cmd/tsconnect dev -``` - -The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). - -## Deployment - -To build the static assets necessary for serving, run: - -``` -./tool/go run ./cmd/tsconnect build -``` - -To serve them, run: - -``` -./tool/go run ./cmd/tsconnect serve -``` - -By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. - -# Library / NPM Package - -The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: - -``` -./tool/go run ./cmd/tsconnect build-pkg -``` - -That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). - -To do two-sided development (on both the NPM package and code that uses it), run: - -``` -./tool/go run ./cmd/tsconnect dev-pkg - -``` - -This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. +# tsconnect + +The tsconnect command builds and serves the static site that is generated for +the Tailscale Connect JS/WASM client. + +## Development + +To start the development server: + +``` +./tool/go run ./cmd/tsconnect dev +``` + +The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). + +## Deployment + +To build the static assets necessary for serving, run: + +``` +./tool/go run ./cmd/tsconnect build +``` + +To serve them, run: + +``` +./tool/go run ./cmd/tsconnect serve +``` + +By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. + +# Library / NPM Package + +The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: + +``` +./tool/go run ./cmd/tsconnect build-pkg +``` + +That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). + +To do two-sided development (on both the NPM package and code that uses it), run: + +``` +./tool/go run ./cmd/tsconnect dev-pkg + +``` + +This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. diff --git a/cmd/tsconnect/README.pkg.md b/cmd/tsconnect/README.pkg.md index df8d66789894d..df5799578d5e7 100644 --- a/cmd/tsconnect/README.pkg.md +++ b/cmd/tsconnect/README.pkg.md @@ -1,3 +1,3 @@ -# @tailscale/connect - -NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. +# @tailscale/connect + +NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. diff --git a/cmd/tsconnect/build-pkg.go b/cmd/tsconnect/build-pkg.go index 047504858ae0c..2b6cc9b1fcbc9 100644 --- a/cmd/tsconnect/build-pkg.go +++ b/cmd/tsconnect/build-pkg.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path" - - "github.com/tailscale/hujson" - "tailscale.com/util/precompress" - "tailscale.com/version" -) - -func runBuildPkg() { - buildOptions, err := commonPkgSetup(prodMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - - log.Printf("Linting...\n") - if err := runYarn("lint"); err != nil { - log.Fatalf("Linting failed: %v", err) - } - - if err := cleanDir(*pkgDir); err != nil { - log.Fatalf("Cannot clean %s: %v", *pkgDir, err) - } - - buildOptions.Write = true - buildOptions.MinifyWhitespace = true - buildOptions.MinifyIdentifiers = true - buildOptions.MinifySyntax = true - - runEsbuild(*buildOptions) - - if err := precompressWasm(); err != nil { - log.Fatalf("Could not pre-recompress wasm: %v", err) - } - - log.Printf("Generating types...\n") - if err := runYarn("pkg-types"); err != nil { - log.Fatalf("Type generation failed: %v", err) - } - - if err := updateVersion(); err != nil { - log.Fatalf("Cannot update version: %v", err) - } - - if err := copyReadme(); err != nil { - log.Fatalf("Cannot copy readme: %v", err) - } - - log.Printf("Built package version %s", version.Long()) -} - -func precompressWasm() error { - log.Printf("Pre-compressing main.wasm...\n") - return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ - FastCompression: *fastCompression, - }) -} - -func updateVersion() error { - packageJSONBytes, err := os.ReadFile("package.json.tmpl") - if err != nil { - return fmt.Errorf("Could not read package.json: %w", err) - } - - var packageJSON map[string]any - packageJSONBytes, err = hujson.Standardize(packageJSONBytes) - if err != nil { - return fmt.Errorf("Could not standardize template package.json: %w", err) - } - if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { - return fmt.Errorf("Could not unmarshal package.json: %w", err) - } - packageJSON["version"] = version.Long() - - packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") - if err != nil { - return fmt.Errorf("Could not marshal package.json: %w", err) - } - - return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) -} - -func copyReadme() error { - readmeBytes, err := os.ReadFile("README.pkg.md") - if err != nil { - return fmt.Errorf("Could not read README.pkg.md: %w", err) - } - return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path" + + "github.com/tailscale/hujson" + "tailscale.com/util/precompress" + "tailscale.com/version" +) + +func runBuildPkg() { + buildOptions, err := commonPkgSetup(prodMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + + log.Printf("Linting...\n") + if err := runYarn("lint"); err != nil { + log.Fatalf("Linting failed: %v", err) + } + + if err := cleanDir(*pkgDir); err != nil { + log.Fatalf("Cannot clean %s: %v", *pkgDir, err) + } + + buildOptions.Write = true + buildOptions.MinifyWhitespace = true + buildOptions.MinifyIdentifiers = true + buildOptions.MinifySyntax = true + + runEsbuild(*buildOptions) + + if err := precompressWasm(); err != nil { + log.Fatalf("Could not pre-recompress wasm: %v", err) + } + + log.Printf("Generating types...\n") + if err := runYarn("pkg-types"); err != nil { + log.Fatalf("Type generation failed: %v", err) + } + + if err := updateVersion(); err != nil { + log.Fatalf("Cannot update version: %v", err) + } + + if err := copyReadme(); err != nil { + log.Fatalf("Cannot copy readme: %v", err) + } + + log.Printf("Built package version %s", version.Long()) +} + +func precompressWasm() error { + log.Printf("Pre-compressing main.wasm...\n") + return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ + FastCompression: *fastCompression, + }) +} + +func updateVersion() error { + packageJSONBytes, err := os.ReadFile("package.json.tmpl") + if err != nil { + return fmt.Errorf("Could not read package.json: %w", err) + } + + var packageJSON map[string]any + packageJSONBytes, err = hujson.Standardize(packageJSONBytes) + if err != nil { + return fmt.Errorf("Could not standardize template package.json: %w", err) + } + if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { + return fmt.Errorf("Could not unmarshal package.json: %w", err) + } + packageJSON["version"] = version.Long() + + packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") + if err != nil { + return fmt.Errorf("Could not marshal package.json: %w", err) + } + + return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) +} + +func copyReadme() error { + readmeBytes, err := os.ReadFile("README.pkg.md") + if err != nil { + return fmt.Errorf("Could not read README.pkg.md: %w", err) + } + return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) +} diff --git a/cmd/tsconnect/dev-pkg.go b/cmd/tsconnect/dev-pkg.go index de534c3b20625..cb5ebf39ef657 100644 --- a/cmd/tsconnect/dev-pkg.go +++ b/cmd/tsconnect/dev-pkg.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDevPkg() { - buildOptions, err := commonPkgSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDevPkg() { + buildOptions, err := commonPkgSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dev.go b/cmd/tsconnect/dev.go index 87b10adaf49c8..161eb3b866a00 100644 --- a/cmd/tsconnect/dev.go +++ b/cmd/tsconnect/dev.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDev() { - buildOptions, err := commonSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDev() { + buildOptions, err := commonSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dist/placeholder b/cmd/tsconnect/dist/placeholder index 4af99d997207f..dddaba4d76687 100644 --- a/cmd/tsconnect/dist/placeholder +++ b/cmd/tsconnect/dist/placeholder @@ -1,2 +1,2 @@ -This is here to make sure the dist/ directory exists for the go:embed command -in serve.go. +This is here to make sure the dist/ directory exists for the go:embed command +in serve.go. diff --git a/cmd/tsconnect/index.html b/cmd/tsconnect/index.html index 3db45fdef2bca..39aa7571add71 100644 --- a/cmd/tsconnect/index.html +++ b/cmd/tsconnect/index.html @@ -1,20 +1,20 @@ - - - - - - Tailscale Connect - - - - - -
-
-

Tailscale Connect

-
Loading…
-
-
- - + + + + + + Tailscale Connect + + + + + +
+
+

Tailscale Connect

+
Loading…
+
+
+ + diff --git a/cmd/tsconnect/package.json b/cmd/tsconnect/package.json index bf4eb7c099aac..8ea726cc670b8 100644 --- a/cmd/tsconnect/package.json +++ b/cmd/tsconnect/package.json @@ -1,25 +1,25 @@ -{ - "name": "tsconnect", - "version": "0.0.1", - "license": "BSD-3-Clause", - "devDependencies": { - "@types/golang-wasm-exec": "^1.15.0", - "@types/qrcode": "^1.4.2", - "dts-bundle-generator": "^6.12.0", - "preact": "^10.10.0", - "qrcode": "^1.5.0", - "tailwindcss": "^3.1.6", - "typescript": "^4.7.4", - "xterm": "^5.1.0", - "xterm-addon-fit": "^0.7.0", - "xterm-addon-web-links": "^0.8.0" - }, - "scripts": { - "lint": "tsc --noEmit", - "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" - }, - "prettier": { - "semi": false, - "printWidth": 80 - } -} +{ + "name": "tsconnect", + "version": "0.0.1", + "license": "BSD-3-Clause", + "devDependencies": { + "@types/golang-wasm-exec": "^1.15.0", + "@types/qrcode": "^1.4.2", + "dts-bundle-generator": "^6.12.0", + "preact": "^10.10.0", + "qrcode": "^1.5.0", + "tailwindcss": "^3.1.6", + "typescript": "^4.7.4", + "xterm": "^5.1.0", + "xterm-addon-fit": "^0.7.0", + "xterm-addon-web-links": "^0.8.0" + }, + "scripts": { + "lint": "tsc --noEmit", + "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" + }, + "prettier": { + "semi": false, + "printWidth": 80 + } +} diff --git a/cmd/tsconnect/package.json.tmpl b/cmd/tsconnect/package.json.tmpl index 404b896eaf89e..0263bf48118dd 100644 --- a/cmd/tsconnect/package.json.tmpl +++ b/cmd/tsconnect/package.json.tmpl @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Template for the package.json that is generated by the build-pkg command. -// The version number will be replaced by the current Tailscale client version -// number. -{ - "author": "Tailscale Inc.", - "description": "Tailscale Connect SDK", - "license": "BSD-3-Clause", - "name": "@tailscale/connect", - "type": "module", - "main": "./pkg.js", - "types": "./pkg.d.ts", - "version": "AUTO_GENERATED" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Template for the package.json that is generated by the build-pkg command. +// The version number will be replaced by the current Tailscale client version +// number. +{ + "author": "Tailscale Inc.", + "description": "Tailscale Connect SDK", + "license": "BSD-3-Clause", + "name": "@tailscale/connect", + "type": "module", + "main": "./pkg.js", + "types": "./pkg.d.ts", + "version": "AUTO_GENERATED" +} diff --git a/cmd/tsconnect/serve.go b/cmd/tsconnect/serve.go index d780bdd57c3e3..80844bea74b6e 100644 --- a/cmd/tsconnect/serve.go +++ b/cmd/tsconnect/serve.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "bytes" - "embed" - "encoding/json" - "fmt" - "io" - "io/fs" - "log" - "net/http" - "os" - "path" - "time" - - "tailscale.com/tsweb" - "tailscale.com/util/precompress" -) - -//go:embed index.html -var embeddedFS embed.FS - -//go:embed dist/* -var embeddedDistFS embed.FS - -var serveStartTime = time.Now() - -func runServe() { - mux := http.NewServeMux() - - var distFS fs.FS - if *distDir == "./dist" { - var err error - distFS, err = fs.Sub(embeddedDistFS, "dist") - if err != nil { - log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) - } - } else { - distFS = os.DirFS(*distDir) - } - - indexBytes, err := generateServeIndex(distFS) - if err != nil { - log.Fatalf("Could not generate index.html: %v", err) - } - mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) - })) - mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handleServeDist(w, r, distFS) - }))) - tsweb.Debugger(mux) - - log.Printf("Listening on %s", *addr) - err = http.ListenAndServe(*addr, mux) - if err != nil { - log.Fatal(err) - } -} - -func generateServeIndex(distFS fs.FS) ([]byte, error) { - log.Printf("Generating index.html...\n") - rawIndexBytes, err := embeddedFS.ReadFile("index.html") - if err != nil { - return nil, fmt.Errorf("Could not read index.html: %w", err) - } - - esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") - if err != nil { - return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) - } - defer esbuildMetadataFile.Close() - esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) - if err != nil { - return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) - } - var esbuildMetadata EsbuildMetadata - if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { - return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) - } - entryPointsToHashedDistPaths := make(map[string]string) - mainWasmPath := "" - for outputPath, output := range esbuildMetadata.Outputs { - if output.EntryPoint != "" { - entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) - } - if path.Ext(outputPath) == ".wasm" { - for input := range output.Inputs { - if input == "src/main.wasm" { - mainWasmPath = path.Join("dist", outputPath) - break - } - } - } - } - - indexBytes := rawIndexBytes - for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { - hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] - if hashedDistPath != "" { - indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) - } - } - if mainWasmPath != "" { - mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) - indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) - } - - return indexBytes, nil -} - -var entryPointsToDefaultDistPaths = map[string]string{ - "src/app/index.css": "dist/index.css", - "src/app/index.ts": "dist/index.js", -} - -func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { - path := r.URL.Path - f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - defer f.Close() - - // fs.File does not claim to implement Seeker, but in practice it does. - fSeeker, ok := f.(io.ReadSeeker) - if !ok { - http.Error(w, "Not seekable", http.StatusInternalServerError) - return - } - - // Aggressively cache static assets, since we cache-bust our assets with - // hashed filenames. - w.Header().Set("Cache-Control", "public, max-age=31535996") - w.Header().Set("Vary", "Accept-Encoding") - - http.ServeContent(w, r, path, serveStartTime, fSeeker) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "bytes" + "embed" + "encoding/json" + "fmt" + "io" + "io/fs" + "log" + "net/http" + "os" + "path" + "time" + + "tailscale.com/tsweb" + "tailscale.com/util/precompress" +) + +//go:embed index.html +var embeddedFS embed.FS + +//go:embed dist/* +var embeddedDistFS embed.FS + +var serveStartTime = time.Now() + +func runServe() { + mux := http.NewServeMux() + + var distFS fs.FS + if *distDir == "./dist" { + var err error + distFS, err = fs.Sub(embeddedDistFS, "dist") + if err != nil { + log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) + } + } else { + distFS = os.DirFS(*distDir) + } + + indexBytes, err := generateServeIndex(distFS) + if err != nil { + log.Fatalf("Could not generate index.html: %v", err) + } + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) + })) + mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleServeDist(w, r, distFS) + }))) + tsweb.Debugger(mux) + + log.Printf("Listening on %s", *addr) + err = http.ListenAndServe(*addr, mux) + if err != nil { + log.Fatal(err) + } +} + +func generateServeIndex(distFS fs.FS) ([]byte, error) { + log.Printf("Generating index.html...\n") + rawIndexBytes, err := embeddedFS.ReadFile("index.html") + if err != nil { + return nil, fmt.Errorf("Could not read index.html: %w", err) + } + + esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") + if err != nil { + return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) + } + defer esbuildMetadataFile.Close() + esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) + if err != nil { + return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) + } + var esbuildMetadata EsbuildMetadata + if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { + return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) + } + entryPointsToHashedDistPaths := make(map[string]string) + mainWasmPath := "" + for outputPath, output := range esbuildMetadata.Outputs { + if output.EntryPoint != "" { + entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) + } + if path.Ext(outputPath) == ".wasm" { + for input := range output.Inputs { + if input == "src/main.wasm" { + mainWasmPath = path.Join("dist", outputPath) + break + } + } + } + } + + indexBytes := rawIndexBytes + for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { + hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] + if hashedDistPath != "" { + indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) + } + } + if mainWasmPath != "" { + mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) + indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) + } + + return indexBytes, nil +} + +var entryPointsToDefaultDistPaths = map[string]string{ + "src/app/index.css": "dist/index.css", + "src/app/index.ts": "dist/index.js", +} + +func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { + path := r.URL.Path + f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + defer f.Close() + + // fs.File does not claim to implement Seeker, but in practice it does. + fSeeker, ok := f.(io.ReadSeeker) + if !ok { + http.Error(w, "Not seekable", http.StatusInternalServerError) + return + } + + // Aggressively cache static assets, since we cache-bust our assets with + // hashed filenames. + w.Header().Set("Cache-Control", "public, max-age=31535996") + w.Header().Set("Vary", "Accept-Encoding") + + http.ServeContent(w, r, path, serveStartTime, fSeeker) +} diff --git a/cmd/tsconnect/src/app/app.tsx b/cmd/tsconnect/src/app/app.tsx index ee538eaeac506..c0aa7a5e88f63 100644 --- a/cmd/tsconnect/src/app/app.tsx +++ b/cmd/tsconnect/src/app/app.tsx @@ -1,147 +1,147 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { render, Component } from "preact" -import { URLDisplay } from "./url-display" -import { Header } from "./header" -import { GoPanicDisplay } from "./go-panic-display" -import { SSH } from "./ssh" - -type AppState = { - ipn?: IPN - ipnState: IPNState - netMap?: IPNNetMap - browseToURL?: string - goPanicError?: string -} - -class App extends Component<{}, AppState> { - state: AppState = { ipnState: "NoState" } - #goPanicTimeout?: number - - render() { - const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state - - let goPanicDisplay - if (goPanicError) { - goPanicDisplay = ( - - ) - } - - let urlDisplay - if (browseToURL) { - urlDisplay = - } - - let machineAuthInstructions - if (ipnState === "NeedsMachineAuth") { - machineAuthInstructions = ( -
- An administrator needs to approve this device. -
- ) - } - - const lockedOut = netMap?.lockedOut - let lockedOutInstructions - if (lockedOut) { - lockedOutInstructions = ( -
-

This instance of Tailscale Connect needs to be signed, due to - {" "}tailnet lock{" "} - being enabled on this domain. -

- -

- Run the following command on a device with a trusted tailnet lock key: -

tailscale lock sign {netMap.self.nodeKey}
-

-
- ) - } - - let ssh - if (ipn && ipnState === "Running" && netMap && !lockedOut) { - ssh = - } - - return ( - <> -
- {goPanicDisplay} -
- {urlDisplay} - {machineAuthInstructions} - {lockedOutInstructions} - {ssh} -
- - ) - } - - runWithIPN(ipn: IPN) { - this.setState({ ipn }, () => { - ipn.run({ - notifyState: this.handleIPNState, - notifyNetMap: this.handleNetMap, - notifyBrowseToURL: this.handleBrowseToURL, - notifyPanicRecover: this.handleGoPanic, - }) - }) - } - - handleIPNState = (state: IPNState) => { - const { ipn } = this.state - this.setState({ ipnState: state }) - if (state === "NeedsLogin") { - ipn?.login() - } else if (["Running", "NeedsMachineAuth"].includes(state)) { - this.setState({ browseToURL: undefined }) - } - } - - handleNetMap = (netMapStr: string) => { - const netMap = JSON.parse(netMapStr) as IPNNetMap - if (DEBUG) { - console.log("Received net map: " + JSON.stringify(netMap, null, 2)) - } - this.setState({ netMap }) - } - - handleBrowseToURL = (url: string) => { - if (this.state.ipnState === "Running") { - // Ignore URL requests if we're already running -- it's most likely an - // SSH check mode trigger and we already linkify the displayed URL - // in the terminal. - return - } - this.setState({ browseToURL: url }) - } - - handleGoPanic = (error: string) => { - if (DEBUG) { - console.error("Go panic", error) - } - this.setState({ goPanicError: error }) - if (this.#goPanicTimeout) { - window.clearTimeout(this.#goPanicTimeout) - } - this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) - } - - clearGoPanic = () => { - window.clearTimeout(this.#goPanicTimeout) - this.#goPanicTimeout = undefined - this.setState({ goPanicError: undefined }) - } -} - -export function renderApp(): Promise { - return new Promise((resolve) => { - render( - (app ? resolve(app) : undefined)} />, - document.body - ) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { render, Component } from "preact" +import { URLDisplay } from "./url-display" +import { Header } from "./header" +import { GoPanicDisplay } from "./go-panic-display" +import { SSH } from "./ssh" + +type AppState = { + ipn?: IPN + ipnState: IPNState + netMap?: IPNNetMap + browseToURL?: string + goPanicError?: string +} + +class App extends Component<{}, AppState> { + state: AppState = { ipnState: "NoState" } + #goPanicTimeout?: number + + render() { + const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state + + let goPanicDisplay + if (goPanicError) { + goPanicDisplay = ( + + ) + } + + let urlDisplay + if (browseToURL) { + urlDisplay = + } + + let machineAuthInstructions + if (ipnState === "NeedsMachineAuth") { + machineAuthInstructions = ( +
+ An administrator needs to approve this device. +
+ ) + } + + const lockedOut = netMap?.lockedOut + let lockedOutInstructions + if (lockedOut) { + lockedOutInstructions = ( +
+

This instance of Tailscale Connect needs to be signed, due to + {" "}tailnet lock{" "} + being enabled on this domain. +

+ +

+ Run the following command on a device with a trusted tailnet lock key: +

tailscale lock sign {netMap.self.nodeKey}
+

+
+ ) + } + + let ssh + if (ipn && ipnState === "Running" && netMap && !lockedOut) { + ssh = + } + + return ( + <> +
+ {goPanicDisplay} +
+ {urlDisplay} + {machineAuthInstructions} + {lockedOutInstructions} + {ssh} +
+ + ) + } + + runWithIPN(ipn: IPN) { + this.setState({ ipn }, () => { + ipn.run({ + notifyState: this.handleIPNState, + notifyNetMap: this.handleNetMap, + notifyBrowseToURL: this.handleBrowseToURL, + notifyPanicRecover: this.handleGoPanic, + }) + }) + } + + handleIPNState = (state: IPNState) => { + const { ipn } = this.state + this.setState({ ipnState: state }) + if (state === "NeedsLogin") { + ipn?.login() + } else if (["Running", "NeedsMachineAuth"].includes(state)) { + this.setState({ browseToURL: undefined }) + } + } + + handleNetMap = (netMapStr: string) => { + const netMap = JSON.parse(netMapStr) as IPNNetMap + if (DEBUG) { + console.log("Received net map: " + JSON.stringify(netMap, null, 2)) + } + this.setState({ netMap }) + } + + handleBrowseToURL = (url: string) => { + if (this.state.ipnState === "Running") { + // Ignore URL requests if we're already running -- it's most likely an + // SSH check mode trigger and we already linkify the displayed URL + // in the terminal. + return + } + this.setState({ browseToURL: url }) + } + + handleGoPanic = (error: string) => { + if (DEBUG) { + console.error("Go panic", error) + } + this.setState({ goPanicError: error }) + if (this.#goPanicTimeout) { + window.clearTimeout(this.#goPanicTimeout) + } + this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) + } + + clearGoPanic = () => { + window.clearTimeout(this.#goPanicTimeout) + this.#goPanicTimeout = undefined + this.setState({ goPanicError: undefined }) + } +} + +export function renderApp(): Promise { + return new Promise((resolve) => { + render( + (app ? resolve(app) : undefined)} />, + document.body + ) + }) +} diff --git a/cmd/tsconnect/src/app/go-panic-display.tsx b/cmd/tsconnect/src/app/go-panic-display.tsx index 5dd7095a27c7d..aab35c4d55e9c 100644 --- a/cmd/tsconnect/src/app/go-panic-display.tsx +++ b/cmd/tsconnect/src/app/go-panic-display.tsx @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function GoPanicDisplay({ - error, - dismiss, -}: { - error: string - dismiss: () => void -}) { - return ( -
- Tailscale has encountered an error. -
Click to reload
-
- ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function GoPanicDisplay({ + error, + dismiss, +}: { + error: string + dismiss: () => void +}) { + return ( +
+ Tailscale has encountered an error. +
Click to reload
+
+ ) +} diff --git a/cmd/tsconnect/src/app/header.tsx b/cmd/tsconnect/src/app/header.tsx index 099ff2f8c2f7d..8449f4563689d 100644 --- a/cmd/tsconnect/src/app/header.tsx +++ b/cmd/tsconnect/src/app/header.tsx @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { - const stateText = STATE_LABELS[state] - - let logoutButton - if (state === "Running") { - logoutButton = ( - - ) - } - return ( -
-
-

Tailscale Connect

-
{stateText}
- {logoutButton} -
-
- ) -} - -const STATE_LABELS = { - NoState: "Initializing…", - InUseOtherUser: "In-use by another user", - NeedsLogin: "Needs login", - NeedsMachineAuth: "Needs approval", - Stopped: "Stopped", - Starting: "Starting…", - Running: "Running", -} as const +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { + const stateText = STATE_LABELS[state] + + let logoutButton + if (state === "Running") { + logoutButton = ( + + ) + } + return ( +
+
+

Tailscale Connect

+
{stateText}
+ {logoutButton} +
+
+ ) +} + +const STATE_LABELS = { + NoState: "Initializing…", + InUseOtherUser: "In-use by another user", + NeedsLogin: "Needs login", + NeedsMachineAuth: "Needs approval", + Stopped: "Stopped", + Starting: "Starting…", + Running: "Running", +} as const diff --git a/cmd/tsconnect/src/app/index.css b/cmd/tsconnect/src/app/index.css index 751b313d9f362..848b83d12b5c9 100644 --- a/cmd/tsconnect/src/app/index.css +++ b/cmd/tsconnect/src/app/index.css @@ -1,74 +1,74 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; - -.link { - @apply text-blue-600; -} - -.link:hover { - @apply underline; -} - -.button { - @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; - transition-property: background-color, border-color, color, box-shadow; - transition-duration: 120ms; - box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); - min-width: 80px; -} -.button:focus { - @apply outline-none ring; -} -.button:disabled { - @apply pointer-events-none select-none; -} - -.input { - @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; - height: 2.375rem; -} - -.input::placeholder { - @apply text-gray-400; -} - -.input:disabled { - @apply border-gray-200; - @apply bg-gray-50; - @apply cursor-not-allowed; -} - -.input:focus { - @apply outline-none ring border-transparent; -} - -.select { - @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; -} - -.select-with-arrow { - @apply relative; -} - -.select-with-arrow .select { - width: 100%; -} - -.select-with-arrow::after { - @apply absolute; - content: ""; - top: 50%; - right: 0.5rem; - transform: translate(-0.3em, -0.15em); - width: 0.6em; - height: 0.4em; - opacity: 0.6; - background-color: currentColor; - clip-path: polygon(100% 0%, 0 0%, 50% 100%); -} +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; + +.link { + @apply text-blue-600; +} + +.link:hover { + @apply underline; +} + +.button { + @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; + transition-property: background-color, border-color, color, box-shadow; + transition-duration: 120ms; + box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); + min-width: 80px; +} +.button:focus { + @apply outline-none ring; +} +.button:disabled { + @apply pointer-events-none select-none; +} + +.input { + @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; + height: 2.375rem; +} + +.input::placeholder { + @apply text-gray-400; +} + +.input:disabled { + @apply border-gray-200; + @apply bg-gray-50; + @apply cursor-not-allowed; +} + +.input:focus { + @apply outline-none ring border-transparent; +} + +.select { + @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; +} + +.select-with-arrow { + @apply relative; +} + +.select-with-arrow .select { + width: 100%; +} + +.select-with-arrow::after { + @apply absolute; + content: ""; + top: 50%; + right: 0.5rem; + transform: translate(-0.3em, -0.15em); + width: 0.6em; + height: 0.4em; + opacity: 0.6; + background-color: currentColor; + clip-path: polygon(100% 0%, 0 0%, 50% 100%); +} diff --git a/cmd/tsconnect/src/app/index.ts b/cmd/tsconnect/src/app/index.ts index 24ca4543921ae..1432188aec1a1 100644 --- a/cmd/tsconnect/src/app/index.ts +++ b/cmd/tsconnect/src/app/index.ts @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import "../wasm_exec" -import wasmUrl from "./main.wasm" -import { sessionStateStorage } from "../lib/js-state-store" -import { renderApp } from "./app" - -async function main() { - const app = await renderApp() - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(`./dist/${wasmUrl}`), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - app.handleGoPanic("Unexpected shutdown") - ) - - const params = new URLSearchParams(window.location.search) - const authKey = params.get("authkey") ?? undefined - - const ipn = newIPN({ - // Persist IPN state in sessionStorage in development, so that we don't need - // to re-authorize every time we reload the page. - stateStorage: DEBUG ? sessionStateStorage : undefined, - // authKey allows for an auth key to be - // specified as a url param which automatically - // authorizes the client for use. - authKey: DEBUG ? authKey : undefined, - }) - app.runWithIPN(ipn) -} - -main() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import "../wasm_exec" +import wasmUrl from "./main.wasm" +import { sessionStateStorage } from "../lib/js-state-store" +import { renderApp } from "./app" + +async function main() { + const app = await renderApp() + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(`./dist/${wasmUrl}`), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + app.handleGoPanic("Unexpected shutdown") + ) + + const params = new URLSearchParams(window.location.search) + const authKey = params.get("authkey") ?? undefined + + const ipn = newIPN({ + // Persist IPN state in sessionStorage in development, so that we don't need + // to re-authorize every time we reload the page. + stateStorage: DEBUG ? sessionStateStorage : undefined, + // authKey allows for an auth key to be + // specified as a url param which automatically + // authorizes the client for use. + authKey: DEBUG ? authKey : undefined, + }) + app.runWithIPN(ipn) +} + +main() diff --git a/cmd/tsconnect/src/app/ssh.tsx b/cmd/tsconnect/src/app/ssh.tsx index df81745bd3fd7..1534fd5db643f 100644 --- a/cmd/tsconnect/src/app/ssh.tsx +++ b/cmd/tsconnect/src/app/ssh.tsx @@ -1,157 +1,157 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" -import { createPortal } from "preact/compat" -import type { VNode } from "preact" -import { runSSHSession, SSHSessionDef } from "../lib/ssh" - -export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { - const [sshSessionDef, setSSHSessionDef] = useState( - null - ) - const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) - if (sshSessionDef) { - const sshSession = ( - - ) - if (sshSessionDef.newWindow) { - return {sshSession} - } - return sshSession - } - const sshPeers = netMap.peers.filter( - (p) => p.tailscaleSSHEnabled && p.online !== false - ) - - if (sshPeers.length == 0) { - return - } - - return -} - -type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } - -function SSHSession({ - def, - ipn, - onDone, -}: { - def: SSHSessionDef - ipn: IPN - onDone: () => void -}) { - const ref = useRef(null) - useEffect(() => { - if (ref.current) { - runSSHSession(ref.current, def, ipn, { - onConnectionProgress: (p) => console.log("Connection progress", p), - onConnected() {}, - onError: (err) => console.error(err), - onDone, - }) - } - }, [ref]) - - return
-} - -function NoSSHPeers() { - return ( -
- None of your machines have{" "} - - Tailscale SSH - - {" "}enabled. Give it a try! -
- ) -} - -function SSHForm({ - sshPeers, - onSubmit, -}: { - sshPeers: IPNNetMapPeerNode[] - onSubmit: (def: SSHFormSessionDef) => void -}) { - sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) - const [username, setUsername] = useState("") - const [hostname, setHostname] = useState(sshPeers[0].name) - return ( -
{ - e.preventDefault() - onSubmit({ username, hostname }) - }} - > - setUsername(e.currentTarget.value)} - /> -
- -
- { - if (e.altKey) { - e.preventDefault() - e.stopPropagation() - onSubmit({ username, hostname, newWindow: true }) - } - }} - /> -
- ) -} - -const NewWindow = ({ - children, - close, -}: { - children: VNode - close: () => void -}) => { - const newWindow = useMemo(() => { - const newWindow = window.open(undefined, undefined, "width=600,height=400") - if (newWindow) { - const containerNode = newWindow.document.createElement("div") - containerNode.className = "h-screen flex flex-col overflow-hidden" - newWindow.document.body.appendChild(containerNode) - - for (const linkNode of document.querySelectorAll( - "head link[rel=stylesheet]" - )) { - const newLink = document.createElement("link") - newLink.rel = "stylesheet" - newLink.href = (linkNode as HTMLLinkElement).href - newWindow.document.head.appendChild(newLink) - } - } - return newWindow - }, []) - if (!newWindow) { - console.error("Could not open window") - return null - } - newWindow.onbeforeunload = () => { - close() - } - - useEffect(() => () => newWindow.close(), []) - return createPortal(children, newWindow.document.body.lastChild as Element) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" +import { createPortal } from "preact/compat" +import type { VNode } from "preact" +import { runSSHSession, SSHSessionDef } from "../lib/ssh" + +export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { + const [sshSessionDef, setSSHSessionDef] = useState( + null + ) + const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) + if (sshSessionDef) { + const sshSession = ( + + ) + if (sshSessionDef.newWindow) { + return {sshSession} + } + return sshSession + } + const sshPeers = netMap.peers.filter( + (p) => p.tailscaleSSHEnabled && p.online !== false + ) + + if (sshPeers.length == 0) { + return + } + + return +} + +type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } + +function SSHSession({ + def, + ipn, + onDone, +}: { + def: SSHSessionDef + ipn: IPN + onDone: () => void +}) { + const ref = useRef(null) + useEffect(() => { + if (ref.current) { + runSSHSession(ref.current, def, ipn, { + onConnectionProgress: (p) => console.log("Connection progress", p), + onConnected() {}, + onError: (err) => console.error(err), + onDone, + }) + } + }, [ref]) + + return
+} + +function NoSSHPeers() { + return ( +
+ None of your machines have{" "} + + Tailscale SSH + + {" "}enabled. Give it a try! +
+ ) +} + +function SSHForm({ + sshPeers, + onSubmit, +}: { + sshPeers: IPNNetMapPeerNode[] + onSubmit: (def: SSHFormSessionDef) => void +}) { + sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) + const [username, setUsername] = useState("") + const [hostname, setHostname] = useState(sshPeers[0].name) + return ( +
{ + e.preventDefault() + onSubmit({ username, hostname }) + }} + > + setUsername(e.currentTarget.value)} + /> +
+ +
+ { + if (e.altKey) { + e.preventDefault() + e.stopPropagation() + onSubmit({ username, hostname, newWindow: true }) + } + }} + /> +
+ ) +} + +const NewWindow = ({ + children, + close, +}: { + children: VNode + close: () => void +}) => { + const newWindow = useMemo(() => { + const newWindow = window.open(undefined, undefined, "width=600,height=400") + if (newWindow) { + const containerNode = newWindow.document.createElement("div") + containerNode.className = "h-screen flex flex-col overflow-hidden" + newWindow.document.body.appendChild(containerNode) + + for (const linkNode of document.querySelectorAll( + "head link[rel=stylesheet]" + )) { + const newLink = document.createElement("link") + newLink.rel = "stylesheet" + newLink.href = (linkNode as HTMLLinkElement).href + newWindow.document.head.appendChild(newLink) + } + } + return newWindow + }, []) + if (!newWindow) { + console.error("Could not open window") + return null + } + newWindow.onbeforeunload = () => { + close() + } + + useEffect(() => () => newWindow.close(), []) + return createPortal(children, newWindow.document.body.lastChild as Element) +} diff --git a/cmd/tsconnect/src/app/url-display.tsx b/cmd/tsconnect/src/app/url-display.tsx index fc82c7fb91b3c..c9b59018108bc 100644 --- a/cmd/tsconnect/src/app/url-display.tsx +++ b/cmd/tsconnect/src/app/url-display.tsx @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState } from "preact/hooks" -import * as qrcode from "qrcode" - -export function URLDisplay({ url }: { url: string }) { - const [dataURL, setDataURL] = useState("") - qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { - if (err) { - console.error("Error generating QR code", err) - } else { - setDataURL(dataURL) - } - }) - - return ( - - ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState } from "preact/hooks" +import * as qrcode from "qrcode" + +export function URLDisplay({ url }: { url: string }) { + const [dataURL, setDataURL] = useState("") + qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { + if (err) { + console.error("Error generating QR code", err) + } else { + setDataURL(dataURL) + } + }) + + return ( + + ) +} diff --git a/cmd/tsconnect/src/lib/js-state-store.ts b/cmd/tsconnect/src/lib/js-state-store.ts index e57dfd98efabd..7685e28a9de7c 100644 --- a/cmd/tsconnect/src/lib/js-state-store.ts +++ b/cmd/tsconnect/src/lib/js-state-store.ts @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ - -export const sessionStateStorage: IPNStateStorage = { - setState(id, value) { - window.sessionStorage[`ipn-state-${id}`] = value - }, - getState(id) { - return window.sessionStorage[`ipn-state-${id}`] || "" - }, -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ + +export const sessionStateStorage: IPNStateStorage = { + setState(id, value) { + window.sessionStorage[`ipn-state-${id}`] = value + }, + getState(id) { + return window.sessionStorage[`ipn-state-${id}`] || "" + }, +} diff --git a/cmd/tsconnect/src/pkg/pkg.css b/cmd/tsconnect/src/pkg/pkg.css index 76ea21f5b53b2..60146d5b7cca9 100644 --- a/cmd/tsconnect/src/pkg/pkg.css +++ b/cmd/tsconnect/src/pkg/pkg.css @@ -1,8 +1,8 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/cmd/tsconnect/src/pkg/pkg.ts b/cmd/tsconnect/src/pkg/pkg.ts index 4d535cb404015..c0dcb5652ec62 100644 --- a/cmd/tsconnect/src/pkg/pkg.ts +++ b/cmd/tsconnect/src/pkg/pkg.ts @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Type definitions need to be manually imported for dts-bundle-generator to -// discover them. -/// -/// - -import "../wasm_exec" -import wasmURL from "./main.wasm" - -/** - * Superset of the IPNConfig type, with additional configuration that is - * needed for the package to function. - */ -type IPNPackageConfig = IPNConfig & { - // Auth key used to initialize the Tailscale client (required) - authKey: string - // URL of the main.wasm file that is included in the page, if it is not - // accessible via a relative URL. - wasmURL?: string - // Function invoked if the Go process panics or unexpectedly exits. - panicHandler: (err: string) => void -} - -export async function createIPN(config: IPNPackageConfig): Promise { - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(config.wasmURL ?? wasmURL), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - config.panicHandler("Unexpected shutdown") - ) - - return newIPN(config) -} - -export { runSSHSession } from "../lib/ssh" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Type definitions need to be manually imported for dts-bundle-generator to +// discover them. +/// +/// + +import "../wasm_exec" +import wasmURL from "./main.wasm" + +/** + * Superset of the IPNConfig type, with additional configuration that is + * needed for the package to function. + */ +type IPNPackageConfig = IPNConfig & { + // Auth key used to initialize the Tailscale client (required) + authKey: string + // URL of the main.wasm file that is included in the page, if it is not + // accessible via a relative URL. + wasmURL?: string + // Function invoked if the Go process panics or unexpectedly exits. + panicHandler: (err: string) => void +} + +export async function createIPN(config: IPNPackageConfig): Promise { + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(config.wasmURL ?? wasmURL), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + config.panicHandler("Unexpected shutdown") + ) + + return newIPN(config) +} + +export { runSSHSession } from "../lib/ssh" diff --git a/cmd/tsconnect/src/types/esbuild.d.ts b/cmd/tsconnect/src/types/esbuild.d.ts index ef28f7b1cf556..7153b4244e7c5 100644 --- a/cmd/tsconnect/src/types/esbuild.d.ts +++ b/cmd/tsconnect/src/types/esbuild.d.ts @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types generated by the esbuild build - * process. - */ - -declare module "*.wasm" { - const path: string - export default path -} - -declare const DEBUG: boolean +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types generated by the esbuild build + * process. + */ + +declare module "*.wasm" { + const path: string + export default path +} + +declare const DEBUG: boolean diff --git a/cmd/tsconnect/src/types/wasm_js.d.ts b/cmd/tsconnect/src/types/wasm_js.d.ts index 492197ccb1a9b..82822c508040e 100644 --- a/cmd/tsconnect/src/types/wasm_js.d.ts +++ b/cmd/tsconnect/src/types/wasm_js.d.ts @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types exported by the wasm_js.go Go - * module. - */ - -declare global { - function newIPN(config: IPNConfig): IPN - - interface IPN { - run(callbacks: IPNCallbacks): void - login(): void - logout(): void - ssh( - host: string, - username: string, - termConfig: { - writeFn: (data: string) => void - writeErrorFn: (err: string) => void - setReadFn: (readFn: (data: string) => void) => void - rows: number - cols: number - /** Defaults to 5 seconds */ - timeoutSeconds?: number - onConnectionProgress: (message: string) => void - onConnected: () => void - onDone: () => void - } - ): IPNSSHSession - fetch(url: string): Promise<{ - status: number - statusText: string - text: () => Promise - }> - } - - interface IPNSSHSession { - resize(rows: number, cols: number): boolean - close(): boolean - } - - interface IPNStateStorage { - setState(id: string, value: string): void - getState(id: string): string - } - - type IPNConfig = { - stateStorage?: IPNStateStorage - authKey?: string - controlURL?: string - hostname?: string - } - - type IPNCallbacks = { - notifyState: (state: IPNState) => void - notifyNetMap: (netMapStr: string) => void - notifyBrowseToURL: (url: string) => void - notifyPanicRecover: (err: string) => void - } - - type IPNNetMap = { - self: IPNNetMapSelfNode - peers: IPNNetMapPeerNode[] - lockedOut: boolean - } - - type IPNNetMapNode = { - name: string - addresses: string[] - machineKey: string - nodeKey: string - } - - type IPNNetMapSelfNode = IPNNetMapNode & { - machineStatus: IPNMachineStatus - } - - type IPNNetMapPeerNode = IPNNetMapNode & { - online?: boolean - tailscaleSSHEnabled: boolean - } - - /** Mirrors values from ipn/backend.go */ - type IPNState = - | "NoState" - | "InUseOtherUser" - | "NeedsLogin" - | "NeedsMachineAuth" - | "Stopped" - | "Starting" - | "Running" - - /** Mirrors values from MachineStatus in tailcfg.go */ - type IPNMachineStatus = - | "MachineUnknown" - | "MachineUnauthorized" - | "MachineAuthorized" - | "MachineInvalid" -} - -export {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types exported by the wasm_js.go Go + * module. + */ + +declare global { + function newIPN(config: IPNConfig): IPN + + interface IPN { + run(callbacks: IPNCallbacks): void + login(): void + logout(): void + ssh( + host: string, + username: string, + termConfig: { + writeFn: (data: string) => void + writeErrorFn: (err: string) => void + setReadFn: (readFn: (data: string) => void) => void + rows: number + cols: number + /** Defaults to 5 seconds */ + timeoutSeconds?: number + onConnectionProgress: (message: string) => void + onConnected: () => void + onDone: () => void + } + ): IPNSSHSession + fetch(url: string): Promise<{ + status: number + statusText: string + text: () => Promise + }> + } + + interface IPNSSHSession { + resize(rows: number, cols: number): boolean + close(): boolean + } + + interface IPNStateStorage { + setState(id: string, value: string): void + getState(id: string): string + } + + type IPNConfig = { + stateStorage?: IPNStateStorage + authKey?: string + controlURL?: string + hostname?: string + } + + type IPNCallbacks = { + notifyState: (state: IPNState) => void + notifyNetMap: (netMapStr: string) => void + notifyBrowseToURL: (url: string) => void + notifyPanicRecover: (err: string) => void + } + + type IPNNetMap = { + self: IPNNetMapSelfNode + peers: IPNNetMapPeerNode[] + lockedOut: boolean + } + + type IPNNetMapNode = { + name: string + addresses: string[] + machineKey: string + nodeKey: string + } + + type IPNNetMapSelfNode = IPNNetMapNode & { + machineStatus: IPNMachineStatus + } + + type IPNNetMapPeerNode = IPNNetMapNode & { + online?: boolean + tailscaleSSHEnabled: boolean + } + + /** Mirrors values from ipn/backend.go */ + type IPNState = + | "NoState" + | "InUseOtherUser" + | "NeedsLogin" + | "NeedsMachineAuth" + | "Stopped" + | "Starting" + | "Running" + + /** Mirrors values from MachineStatus in tailcfg.go */ + type IPNMachineStatus = + | "MachineUnknown" + | "MachineUnauthorized" + | "MachineAuthorized" + | "MachineInvalid" +} + +export {} diff --git a/cmd/tsconnect/tailwind.config.js b/cmd/tsconnect/tailwind.config.js index 31823000b6139..38bc5b97b714e 100644 --- a/cmd/tsconnect/tailwind.config.js +++ b/cmd/tsconnect/tailwind.config.js @@ -1,8 +1,8 @@ -/** @type {import('tailwindcss').Config} */ -module.exports = { - content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], - theme: { - extend: {}, - }, - plugins: [], -} +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], + theme: { + extend: {}, + }, + plugins: [], +} diff --git a/cmd/tsconnect/tsconfig.json b/cmd/tsconnect/tsconfig.json index 52c25c7271f7c..1148e2ef0c43a 100644 --- a/cmd/tsconnect/tsconfig.json +++ b/cmd/tsconnect/tsconfig.json @@ -1,15 +1,15 @@ -{ - "compilerOptions": { - "target": "ES2017", - "module": "ES2020", - "moduleResolution": "node", - "isolatedModules": true, - "strict": true, - "forceConsistentCasingInFileNames": true, - "sourceMap": true, - "jsx": "react-jsx", - "jsxImportSource": "preact" - }, - "include": ["src/**/*"], - "exclude": ["node_modules"] -} +{ + "compilerOptions": { + "target": "ES2017", + "module": "ES2020", + "moduleResolution": "node", + "isolatedModules": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "sourceMap": true, + "jsx": "react-jsx", + "jsxImportSource": "preact" + }, + "include": ["src/**/*"], + "exclude": ["node_modules"] +} diff --git a/cmd/tsconnect/tsconnect.go b/cmd/tsconnect/tsconnect.go index 4c8a0a52ece34..60ea6ef822d99 100644 --- a/cmd/tsconnect/tsconnect.go +++ b/cmd/tsconnect/tsconnect.go @@ -1,71 +1,71 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The tsconnect command builds and serves the static site that is generated for -// the Tailscale Connect JS/WASM client. Can be run in 3 modes: -// - dev: builds the site and serves it. JS and CSS changes can be picked up -// with a reload. -// - build: builds the site and writes it to dist/ -// - serve: serves the site from dist/ (embedded in the binary) -package main // import "tailscale.com/cmd/tsconnect" - -import ( - "flag" - "fmt" - "log" - "os" -) - -var ( - addr = flag.String("addr", ":9090", "address to listen on") - distDir = flag.String("distdir", "./dist", "path of directory to place build output in") - pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") - yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") - fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") - devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") - rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") -) - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) != 1 { - flag.Usage() - } - - switch flag.Arg(0) { - case "dev": - runDev() - case "dev-pkg": - runDevPkg() - case "build": - runBuild() - case "build-pkg": - runBuildPkg() - case "serve": - runServe() - default: - log.Printf("Unknown command: %s", flag.Arg(0)) - flag.Usage() - } -} - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: tsconnect {dev|build|serve} -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` - -tsconnect implements development/build/serving workflows for Tailscale Connect. -It can be invoked with one of three subcommands: - -- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. -- build: Run in production build mode (generating static assets) -- serve: Run in production serve mode (serving static assets) -`[1:]) - os.Exit(2) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The tsconnect command builds and serves the static site that is generated for +// the Tailscale Connect JS/WASM client. Can be run in 3 modes: +// - dev: builds the site and serves it. JS and CSS changes can be picked up +// with a reload. +// - build: builds the site and writes it to dist/ +// - serve: serves the site from dist/ (embedded in the binary) +package main // import "tailscale.com/cmd/tsconnect" + +import ( + "flag" + "fmt" + "log" + "os" +) + +var ( + addr = flag.String("addr", ":9090", "address to listen on") + distDir = flag.String("distdir", "./dist", "path of directory to place build output in") + pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") + yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") + fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") + devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") + rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") +) + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) != 1 { + flag.Usage() + } + + switch flag.Arg(0) { + case "dev": + runDev() + case "dev-pkg": + runDevPkg() + case "build": + runBuild() + case "build-pkg": + runBuildPkg() + case "serve": + runServe() + default: + log.Printf("Unknown command: %s", flag.Arg(0)) + flag.Usage() + } +} + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: tsconnect {dev|build|serve} +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` + +tsconnect implements development/build/serving workflows for Tailscale Connect. +It can be invoked with one of three subcommands: + +- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. +- build: Run in production build mode (generating static assets) +- serve: Run in production serve mode (serving static assets) +`[1:]) + os.Exit(2) +} diff --git a/cmd/tsconnect/yarn.lock b/cmd/tsconnect/yarn.lock index 663a1244ebf69..914b4e6d041f7 100644 --- a/cmd/tsconnect/yarn.lock +++ b/cmd/tsconnect/yarn.lock @@ -1,713 +1,713 @@ -# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. -# yarn lockfile v1 - - -"@nodelib/fs.scandir@2.1.5": - version "2.1.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" - integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== - dependencies: - "@nodelib/fs.stat" "2.0.5" - run-parallel "^1.1.9" - -"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": - version "2.0.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" - integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== - -"@nodelib/fs.walk@^1.2.3": - version "1.2.8" - resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" - integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== - dependencies: - "@nodelib/fs.scandir" "2.1.5" - fastq "^1.6.0" - -"@types/golang-wasm-exec@^1.15.0": - version "1.15.0" - resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" - integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== - -"@types/node@*": - version "18.6.1" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" - integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== - -"@types/qrcode@^1.4.2": - version "1.4.2" - resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" - integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== - dependencies: - "@types/node" "*" - -acorn-node@^1.8.2: - version "1.8.2" - resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" - integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== - dependencies: - acorn "^7.0.0" - acorn-walk "^7.0.0" - xtend "^4.0.2" - -acorn-walk@^7.0.0: - version "7.2.0" - resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" - integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== - -acorn@^7.0.0: - version "7.4.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" - integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== - -ansi-regex@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" - integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== - -ansi-styles@^4.0.0: - version "4.3.0" - resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" - integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== - dependencies: - color-convert "^2.0.1" - -anymatch@~3.1.2: - version "3.1.2" - resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" - integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== - dependencies: - normalize-path "^3.0.0" - picomatch "^2.0.4" - -arg@^5.0.2: - version "5.0.2" - resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" - integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== - -binary-extensions@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" - integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== - -braces@^3.0.2, braces@~3.0.2: - version "3.0.2" - resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" - integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== - dependencies: - fill-range "^7.0.1" - -camelcase-css@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" - integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== - -camelcase@^5.0.0: - version "5.3.1" - resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" - integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== - -chokidar@^3.5.3: - version "3.5.3" - resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" - integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== - dependencies: - anymatch "~3.1.2" - braces "~3.0.2" - glob-parent "~5.1.2" - is-binary-path "~2.1.0" - is-glob "~4.0.1" - normalize-path "~3.0.0" - readdirp "~3.6.0" - optionalDependencies: - fsevents "~2.3.2" - -cliui@^6.0.0: - version "6.0.0" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" - integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^6.2.0" - -cliui@^7.0.2: - version "7.0.4" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" - integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^7.0.0" - -color-convert@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" - integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== - dependencies: - color-name "~1.1.4" - -color-name@^1.1.4, color-name@~1.1.4: - version "1.1.4" - resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" - integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== - -cssesc@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" - integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== - -decamelize@^1.2.0: - version "1.2.0" - resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" - integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= - -defined@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" - integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== - -detective@^5.2.1: - version "5.2.1" - resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" - integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== - dependencies: - acorn-node "^1.8.2" - defined "^1.0.0" - minimist "^1.2.6" - -didyoumean@^1.2.2: - version "1.2.2" - resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" - integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== - -dijkstrajs@^1.0.1: - version "1.0.2" - resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" - integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== - -dlv@^1.1.3: - version "1.1.3" - resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" - integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== - -dts-bundle-generator@^6.12.0: - version "6.12.0" - resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" - integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== - dependencies: - typescript ">=3.0.1" - yargs "^17.2.1" - -emoji-regex@^8.0.0: - version "8.0.0" - resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" - integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== - -encode-utf8@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" - integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== - -escalade@^3.1.1: - version "3.1.1" - resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" - integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== - -fast-glob@^3.2.11: - version "3.2.11" - resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" - integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== - dependencies: - "@nodelib/fs.stat" "^2.0.2" - "@nodelib/fs.walk" "^1.2.3" - glob-parent "^5.1.2" - merge2 "^1.3.0" - micromatch "^4.0.4" - -fastq@^1.6.0: - version "1.13.0" - resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" - integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== - dependencies: - reusify "^1.0.4" - -fill-range@^7.0.1: - version "7.0.1" - resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" - integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== - dependencies: - to-regex-range "^5.0.1" - -find-up@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" - integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== - dependencies: - locate-path "^5.0.0" - path-exists "^4.0.0" - -fsevents@~2.3.2: - version "2.3.2" - resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" - integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== - -function-bind@^1.1.1: - version "1.1.1" - resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" - integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== - -get-caller-file@^2.0.1, get-caller-file@^2.0.5: - version "2.0.5" - resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" - integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== - -glob-parent@^5.1.2, glob-parent@~5.1.2: - version "5.1.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" - integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== - dependencies: - is-glob "^4.0.1" - -glob-parent@^6.0.2: - version "6.0.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" - integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== - dependencies: - is-glob "^4.0.3" - -has@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" - integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== - dependencies: - function-bind "^1.1.1" - -is-binary-path@~2.1.0: - version "2.1.0" - resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" - integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== - dependencies: - binary-extensions "^2.0.0" - -is-core-module@^2.9.0: - version "2.9.0" - resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" - integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== - dependencies: - has "^1.0.3" - -is-extglob@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" - integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== - -is-fullwidth-code-point@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" - integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== - -is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: - version "4.0.3" - resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" - integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== - dependencies: - is-extglob "^2.1.1" - -is-number@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" - integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== - -lilconfig@^2.0.5: - version "2.0.6" - resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" - integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== - -locate-path@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" - integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== - dependencies: - p-locate "^4.1.0" - -merge2@^1.3.0: - version "1.4.1" - resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" - integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== - -micromatch@^4.0.4: - version "4.0.5" - resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" - integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== - dependencies: - braces "^3.0.2" - picomatch "^2.3.1" - -minimist@^1.2.6: - version "1.2.6" - resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" - integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== - -nanoid@^3.3.4: - version "3.3.4" - resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" - integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== - -normalize-path@^3.0.0, normalize-path@~3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" - integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== - -object-hash@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" - integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== - -p-limit@^2.2.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" - integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== - dependencies: - p-try "^2.0.0" - -p-locate@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" - integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== - dependencies: - p-limit "^2.2.0" - -p-try@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" - integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== - -path-exists@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" - integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== - -path-parse@^1.0.7: - version "1.0.7" - resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" - integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== - -picocolors@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" - integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== - -picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: - version "2.3.1" - resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" - integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== - -pify@^2.3.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" - integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== - -pngjs@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" - integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== - -postcss-import@^14.1.0: - version "14.1.0" - resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" - integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== - dependencies: - postcss-value-parser "^4.0.0" - read-cache "^1.0.0" - resolve "^1.1.7" - -postcss-js@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" - integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== - dependencies: - camelcase-css "^2.0.1" - -postcss-load-config@^3.1.4: - version "3.1.4" - resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" - integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== - dependencies: - lilconfig "^2.0.5" - yaml "^1.10.2" - -postcss-nested@5.0.6: - version "5.0.6" - resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" - integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== - dependencies: - postcss-selector-parser "^6.0.6" - -postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: - version "6.0.10" - resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" - integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== - dependencies: - cssesc "^3.0.0" - util-deprecate "^1.0.2" - -postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: - version "4.2.0" - resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" - integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== - -postcss@^8.4.14: - version "8.4.14" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" - integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== - dependencies: - nanoid "^3.3.4" - picocolors "^1.0.0" - source-map-js "^1.0.2" - -preact@^10.10.0: - version "10.10.0" - resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" - integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== - -qrcode@^1.5.0: - version "1.5.0" - resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" - integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== - dependencies: - dijkstrajs "^1.0.1" - encode-utf8 "^1.0.3" - pngjs "^5.0.0" - yargs "^15.3.1" - -queue-microtask@^1.2.2: - version "1.2.3" - resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" - integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== - -quick-lru@^5.1.1: - version "5.1.1" - resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" - integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== - -read-cache@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" - integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== - dependencies: - pify "^2.3.0" - -readdirp@~3.6.0: - version "3.6.0" - resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" - integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== - dependencies: - picomatch "^2.2.1" - -require-directory@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" - integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= - -require-main-filename@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" - integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== - -resolve@^1.1.7, resolve@^1.22.1: - version "1.22.1" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" - integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== - dependencies: - is-core-module "^2.9.0" - path-parse "^1.0.7" - supports-preserve-symlinks-flag "^1.0.0" - -reusify@^1.0.4: - version "1.0.4" - resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" - integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== - -run-parallel@^1.1.9: - version "1.2.0" - resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" - integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== - dependencies: - queue-microtask "^1.2.2" - -set-blocking@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" - integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= - -source-map-js@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" - integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== - -string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: - version "4.2.3" - resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" - integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== - dependencies: - emoji-regex "^8.0.0" - is-fullwidth-code-point "^3.0.0" - strip-ansi "^6.0.1" - -strip-ansi@^6.0.0, strip-ansi@^6.0.1: - version "6.0.1" - resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" - integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== - dependencies: - ansi-regex "^5.0.1" - -supports-preserve-symlinks-flag@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" - integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== - -tailwindcss@^3.1.6: - version "3.1.6" - resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" - integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== - dependencies: - arg "^5.0.2" - chokidar "^3.5.3" - color-name "^1.1.4" - detective "^5.2.1" - didyoumean "^1.2.2" - dlv "^1.1.3" - fast-glob "^3.2.11" - glob-parent "^6.0.2" - is-glob "^4.0.3" - lilconfig "^2.0.5" - normalize-path "^3.0.0" - object-hash "^3.0.0" - picocolors "^1.0.0" - postcss "^8.4.14" - postcss-import "^14.1.0" - postcss-js "^4.0.0" - postcss-load-config "^3.1.4" - postcss-nested "5.0.6" - postcss-selector-parser "^6.0.10" - postcss-value-parser "^4.2.0" - quick-lru "^5.1.1" - resolve "^1.22.1" - -to-regex-range@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" - integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== - dependencies: - is-number "^7.0.0" - -typescript@>=3.0.1, typescript@^4.7.4: - version "4.7.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" - integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== - -util-deprecate@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" - integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== - -which-module@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" - integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= - -wrap-ansi@^6.2.0: - version "6.2.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" - integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -wrap-ansi@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" - integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -xtend@^4.0.2: - version "4.0.2" - resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" - integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== - -xterm-addon-fit@^0.7.0: - version "0.7.0" - resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" - integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== - -xterm-addon-web-links@^0.8.0: - version "0.8.0" - resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" - integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== - -xterm@^5.1.0: - version "5.1.0" - resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" - integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== - -y18n@^4.0.0: - version "4.0.3" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" - integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== - -y18n@^5.0.5: - version "5.0.8" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" - integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== - -yaml@^1.10.2: - version "1.10.2" - resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" - integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== - -yargs-parser@^18.1.2: - version "18.1.3" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" - integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== - dependencies: - camelcase "^5.0.0" - decamelize "^1.2.0" - -yargs-parser@^21.0.0: - version "21.1.1" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" - integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== - -yargs@^15.3.1: - version "15.4.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" - integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== - dependencies: - cliui "^6.0.0" - decamelize "^1.2.0" - find-up "^4.1.0" - get-caller-file "^2.0.1" - require-directory "^2.1.1" - require-main-filename "^2.0.0" - set-blocking "^2.0.0" - string-width "^4.2.0" - which-module "^2.0.0" - y18n "^4.0.0" - yargs-parser "^18.1.2" - -yargs@^17.2.1: - version "17.5.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" - integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== - dependencies: - cliui "^7.0.2" - escalade "^3.1.1" - get-caller-file "^2.0.5" - require-directory "^2.1.1" - string-width "^4.2.3" - y18n "^5.0.5" - yargs-parser "^21.0.0" +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@nodelib/fs.scandir@2.1.5": + version "2.1.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" + integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== + dependencies: + "@nodelib/fs.stat" "2.0.5" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": + version "2.0.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" + integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.8" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" + integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== + dependencies: + "@nodelib/fs.scandir" "2.1.5" + fastq "^1.6.0" + +"@types/golang-wasm-exec@^1.15.0": + version "1.15.0" + resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" + integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== + +"@types/node@*": + version "18.6.1" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" + integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== + +"@types/qrcode@^1.4.2": + version "1.4.2" + resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" + integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== + dependencies: + "@types/node" "*" + +acorn-node@^1.8.2: + version "1.8.2" + resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" + integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== + dependencies: + acorn "^7.0.0" + acorn-walk "^7.0.0" + xtend "^4.0.2" + +acorn-walk@^7.0.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" + integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== + +acorn@^7.0.0: + version "7.4.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" + integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== + +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + +anymatch@~3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" + integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + +arg@^5.0.2: + version "5.0.2" + resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" + integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== + +binary-extensions@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" + integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== + +braces@^3.0.2, braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +camelcase-css@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" + integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== + +camelcase@^5.0.0: + version "5.3.1" + resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" + integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== + +chokidar@^3.5.3: + version "3.5.3" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" + integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== + dependencies: + anymatch "~3.1.2" + braces "~3.0.2" + glob-parent "~5.1.2" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.6.0" + optionalDependencies: + fsevents "~2.3.2" + +cliui@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" + integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^6.2.0" + +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@^1.1.4, color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +cssesc@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" + integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== + +decamelize@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" + integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= + +defined@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" + integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== + +detective@^5.2.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" + integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== + dependencies: + acorn-node "^1.8.2" + defined "^1.0.0" + minimist "^1.2.6" + +didyoumean@^1.2.2: + version "1.2.2" + resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" + integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== + +dijkstrajs@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" + integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== + +dlv@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" + integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== + +dts-bundle-generator@^6.12.0: + version "6.12.0" + resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" + integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== + dependencies: + typescript ">=3.0.1" + yargs "^17.2.1" + +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + +encode-utf8@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" + integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== + +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + +fast-glob@^3.2.11: + version "3.2.11" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" + integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.2" + merge2 "^1.3.0" + micromatch "^4.0.4" + +fastq@^1.6.0: + version "1.13.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" + integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== + dependencies: + reusify "^1.0.4" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + +find-up@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" + integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== + dependencies: + locate-path "^5.0.0" + path-exists "^4.0.0" + +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +get-caller-file@^2.0.1, get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob-parent@^5.1.2, glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +glob-parent@^6.0.2: + version "6.0.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" + integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== + dependencies: + is-glob "^4.0.3" + +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-core-module@^2.9.0: + version "2.9.0" + resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" + integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== + dependencies: + has "^1.0.3" + +is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== + +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + +is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: + version "4.0.3" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" + integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== + dependencies: + is-extglob "^2.1.1" + +is-number@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" + integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== + +lilconfig@^2.0.5: + version "2.0.6" + resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" + integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== + +locate-path@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" + integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== + dependencies: + p-locate "^4.1.0" + +merge2@^1.3.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + +micromatch@^4.0.4: + version "4.0.5" + resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" + integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== + dependencies: + braces "^3.0.2" + picomatch "^2.3.1" + +minimist@^1.2.6: + version "1.2.6" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" + integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== + +nanoid@^3.3.4: + version "3.3.4" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" + integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== + +normalize-path@^3.0.0, normalize-path@~3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" + integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== + +object-hash@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" + integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== + +p-limit@^2.2.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" + integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== + dependencies: + p-try "^2.0.0" + +p-locate@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" + integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== + dependencies: + p-limit "^2.2.0" + +p-try@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" + integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + +path-parse@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" + integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== + +picocolors@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" + integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== + +picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: + version "2.3.1" + resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" + integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== + +pify@^2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" + integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== + +pngjs@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" + integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== + +postcss-import@^14.1.0: + version "14.1.0" + resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" + integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== + dependencies: + postcss-value-parser "^4.0.0" + read-cache "^1.0.0" + resolve "^1.1.7" + +postcss-js@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" + integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== + dependencies: + camelcase-css "^2.0.1" + +postcss-load-config@^3.1.4: + version "3.1.4" + resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" + integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== + dependencies: + lilconfig "^2.0.5" + yaml "^1.10.2" + +postcss-nested@5.0.6: + version "5.0.6" + resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" + integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== + dependencies: + postcss-selector-parser "^6.0.6" + +postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: + version "6.0.10" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" + integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== + dependencies: + cssesc "^3.0.0" + util-deprecate "^1.0.2" + +postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" + integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== + +postcss@^8.4.14: + version "8.4.14" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" + integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== + dependencies: + nanoid "^3.3.4" + picocolors "^1.0.0" + source-map-js "^1.0.2" + +preact@^10.10.0: + version "10.10.0" + resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" + integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== + +qrcode@^1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" + integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== + dependencies: + dijkstrajs "^1.0.1" + encode-utf8 "^1.0.3" + pngjs "^5.0.0" + yargs "^15.3.1" + +queue-microtask@^1.2.2: + version "1.2.3" + resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" + integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== + +quick-lru@^5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" + integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== + +read-cache@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" + integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== + dependencies: + pify "^2.3.0" + +readdirp@~3.6.0: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" + integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== + dependencies: + picomatch "^2.2.1" + +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= + +require-main-filename@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" + integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== + +resolve@^1.1.7, resolve@^1.22.1: + version "1.22.1" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" + integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== + dependencies: + is-core-module "^2.9.0" + path-parse "^1.0.7" + supports-preserve-symlinks-flag "^1.0.0" + +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + +run-parallel@^1.1.9: + version "1.2.0" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" + integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== + dependencies: + queue-microtask "^1.2.2" + +set-blocking@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" + integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= + +source-map-js@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" + integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== + +string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +supports-preserve-symlinks-flag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" + integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== + +tailwindcss@^3.1.6: + version "3.1.6" + resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" + integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== + dependencies: + arg "^5.0.2" + chokidar "^3.5.3" + color-name "^1.1.4" + detective "^5.2.1" + didyoumean "^1.2.2" + dlv "^1.1.3" + fast-glob "^3.2.11" + glob-parent "^6.0.2" + is-glob "^4.0.3" + lilconfig "^2.0.5" + normalize-path "^3.0.0" + object-hash "^3.0.0" + picocolors "^1.0.0" + postcss "^8.4.14" + postcss-import "^14.1.0" + postcss-js "^4.0.0" + postcss-load-config "^3.1.4" + postcss-nested "5.0.6" + postcss-selector-parser "^6.0.10" + postcss-value-parser "^4.2.0" + quick-lru "^5.1.1" + resolve "^1.22.1" + +to-regex-range@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" + integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== + dependencies: + is-number "^7.0.0" + +typescript@>=3.0.1, typescript@^4.7.4: + version "4.7.4" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" + integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== + +util-deprecate@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== + +which-module@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" + integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= + +wrap-ansi@^6.2.0: + version "6.2.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" + integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +xtend@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" + integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== + +xterm-addon-fit@^0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" + integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== + +xterm-addon-web-links@^0.8.0: + version "0.8.0" + resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" + integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== + +xterm@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" + integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== + +y18n@^4.0.0: + version "4.0.3" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" + integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== + +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + +yaml@^1.10.2: + version "1.10.2" + resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" + integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== + +yargs-parser@^18.1.2: + version "18.1.3" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" + integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== + dependencies: + camelcase "^5.0.0" + decamelize "^1.2.0" + +yargs-parser@^21.0.0: + version "21.1.1" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" + integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== + +yargs@^15.3.1: + version "15.4.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" + integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== + dependencies: + cliui "^6.0.0" + decamelize "^1.2.0" + find-up "^4.1.0" + get-caller-file "^2.0.1" + require-directory "^2.1.1" + require-main-filename "^2.0.0" + set-blocking "^2.0.0" + string-width "^4.2.0" + which-module "^2.0.0" + y18n "^4.0.0" + yargs-parser "^18.1.2" + +yargs@^17.2.1: + version "17.5.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" + integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.3" + y18n "^5.0.5" + yargs-parser "^21.0.0" diff --git a/cmd/tsshd/tsshd.go b/cmd/tsshd/tsshd.go index 950eb661cdb23..1ec09a0d47611 100644 --- a/cmd/tsshd/tsshd.go +++ b/cmd/tsshd/tsshd.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// The tsshd binary was an experimental SSH server that accepts connections -// from anybody on the same Tailscale network. -// -// Its functionality moved into tailscaled. -// -// See https://github.com/tailscale/tailscale/issues/3802 -package main +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// The tsshd binary was an experimental SSH server that accepts connections +// from anybody on the same Tailscale network. +// +// Its functionality moved into tailscaled. +// +// See https://github.com/tailscale/tailscale/issues/3802 +package main diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index dc22212e887cb..b6fc53b3a40f3 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -1,408 +1,408 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package controlbase implements the base transport of the Tailscale -// 2021 control protocol. -// -// The base transport implements Noise IK, instantiated with -// Curve25519, ChaCha20Poly1305 and BLAKE2s. -package controlbase - -import ( - "crypto/cipher" - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "tailscale.com/types/key" -) - -const ( - // maxMessageSize is the maximum size of a protocol frame on the - // wire, including header and payload. - maxMessageSize = 4096 - // maxCiphertextSize is the maximum amount of ciphertext bytes - // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - 3 - // maxPlaintextSize is the maximum amount of plaintext bytes that - // one protocol frame can carry, after encryption and framing. - maxPlaintextSize = maxCiphertextSize - chp.Overhead -) - -// A Conn is a secured Noise connection. It implements the net.Conn -// interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) causes all future writes to -// fail. -type Conn struct { - conn net.Conn - version uint16 - peer key.MachinePublic - handshakeHash [blake2s.Size]byte - rx rxState - tx txState -} - -// rxState is all the Conn state that Read uses. -type rxState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - buf *maxMsgBuffer // or nil when reads exhausted - n int // number of valid bytes in buf - next int // offset of next undecrypted packet - plaintext []byte // slice into buf of decrypted bytes - hdrBuf [headerLen]byte // small buffer used when buf is nil -} - -// txState is all the Conn state that Write uses. -type txState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - err error // records the first partial write error for all future calls -} - -// ProtocolVersion returns the protocol version that was used to -// establish this Conn. -func (c *Conn) ProtocolVersion() int { - return int(c.version) -} - -// HandshakeHash returns the Noise handshake hash for the connection, -// which can be used to bind other messages to this connection -// (i.e. to ensure that the message wasn't replayed from a different -// connection). -func (c *Conn) HandshakeHash() [blake2s.Size]byte { - return c.handshakeHash -} - -// Peer returns the peer's long-term public key. -func (c *Conn) Peer() key.MachinePublic { - return c.peer -} - -// readNLocked reads into c.rx.buf until buf contains at least total -// bytes. Returns a slice of the total bytes in rxBuf, or an -// error if fewer than total bytes are available. -// -// It may be called with a nil c.rx.buf only if total == headerLen. -// -// On success, c.rx.buf will be non-nil. -func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxMessageSize { - return nil, errReadTooBig{total} - } - for { - if total <= c.rx.n { - return c.rx.buf[:total], nil - } - var n int - var err error - if c.rx.buf == nil { - if c.rx.n != 0 || total != headerLen { - panic("unexpected") - } - // Optimization to reduce memory usage. - // Most connections are blocked forever waiting for - // a read, so we don't want c.rx.buf to be allocated until - // we know there's data to read. Instead, when we're - // waiting for data to arrive here, read into the - // 3 byte hdrBuf: - n, err = c.conn.Read(c.rx.hdrBuf[:]) - if n > 0 { - c.rx.buf = getMaxMsgBuffer() - copy(c.rx.buf[:], c.rx.hdrBuf[:n]) - } - } else { - n, err = c.conn.Read(c.rx.buf[c.rx.n:]) - } - c.rx.n += n - if err != nil { - return nil, err - } - } -} - -// decryptLocked decrypts msg (which is header+ciphertext) in-place -// and sets c.rx.plaintext to the decrypted bytes. -func (c *Conn) decryptLocked(msg []byte) (err error) { - if msgType := msg[0]; msgType != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) - } - // We don't check the length field here, because the caller - // already did in order to figure out how big the msg slice should - // be. - ciphertext := msg[headerLen:] - - if !c.rx.nonce.Valid() { - return errCipherExhausted{} - } - - c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) - c.rx.nonce.Increment() - - if err != nil { - // Once a decryption has failed, our Conn is no longer - // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. Future - // read attempts will return net.ErrClosed. - c.rx.cipher = nil - } - return err -} - -// encryptLocked encrypts plaintext into buf (including the -// packet header) and returns a slice of the ciphertext, or an error -// if the cipher is exhausted (i.e. can no longer be used safely). -func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { - if !c.tx.nonce.Valid() { - // Received 2^64-1 messages on this cipher state. Connection - // is no longer usable. - return nil, errCipherExhausted{} - } - - buf[0] = msgTypeRecord - binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) - ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) - c.tx.nonce.Increment() - - return ret, nil -} - -// wholeMessageLocked returns a slice of one whole Noise transport -// message from c.rx.buf, if one whole message is available, and -// advances the read state to the next Noise message in the -// buffer. Returns nil without advancing read state if there isn't one -// whole message in c.rx.buf. -func (c *Conn) wholeMessageLocked() []byte { - available := c.rx.n - c.rx.next - if available < headerLen { - return nil - } - bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - if len(bs) < totalSize { - return nil - } - c.rx.next += totalSize - return bs[:totalSize] -} - -// decryptOneLocked decrypts one Noise transport message, reading from -// c.conn as needed, and sets c.rx.plaintext to point to the decrypted -// bytes. c.rx.plaintext is only valid if err == nil. -func (c *Conn) decryptOneLocked() error { - c.rx.plaintext = nil - - // Fast path: do we have one whole ciphertext frame buffered - // already? - if bs := c.wholeMessageLocked(); bs != nil { - return c.decryptLocked(bs) - } - - if c.rx.next != 0 { - // To simplify the read logic, move the remainder of the - // buffered bytes back to the head of the buffer, so we can - // grow it without worrying about wraparound. - c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.next = 0 - } - - // Return our buffer to the pool if it's empty, lest we be - // blocked in a long Read call, reading the 3 byte header. We - // don't to keep that buffer unnecessarily alive. - if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { - bufPool.Put(c.rx.buf) - c.rx.buf = nil - } - - bs, err := c.readNLocked(headerLen) - if err != nil { - return err - } - // The rest of the header (besides the length field) gets verified - // in decryptLocked, not here. - messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - bs, err = c.readNLocked(messageLen) - if err != nil { - return err - } - - c.rx.next = len(bs) - - return c.decryptLocked(bs) -} - -// Read implements io.Reader. -func (c *Conn) Read(bs []byte) (int, error) { - c.rx.Lock() - defer c.rx.Unlock() - - if c.rx.cipher == nil { - return 0, net.ErrClosed - } - // If no plaintext is buffered, decrypt incoming frames until we - // have some plaintext. Zero-byte Noise frames are allowed in this - // protocol, which is why we have to loop here rather than decrypt - // a single additional frame. - for len(c.rx.plaintext) == 0 { - if err := c.decryptOneLocked(); err != nil { - return 0, err - } - } - n := copy(bs, c.rx.plaintext) - c.rx.plaintext = c.rx.plaintext[n:] - - // Lose slice's underlying array pointer to unneeded memory so - // GC can collect more. - if len(c.rx.plaintext) == 0 { - c.rx.plaintext = nil - } - return n, nil -} - -// Write implements io.Writer. -func (c *Conn) Write(bs []byte) (n int, err error) { - c.tx.Lock() - defer c.tx.Unlock() - - if c.tx.err != nil { - return 0, c.tx.err - } - defer func() { - if err != nil { - // All write errors are fatal for this conn, so clear the - // cipher state whenever an error happens. - c.tx.cipher = nil - } - if c.tx.err == nil { - // Only set c.tx.err if not nil so that we can return one - // error on the first failure, and a different one for - // subsequent calls. See the error handling around Write - // below for why. - c.tx.err = err - } - }() - - if c.tx.cipher == nil { - return 0, net.ErrClosed - } - - buf := getMaxMsgBuffer() - defer bufPool.Put(buf) - - var sent int - for len(bs) > 0 { - toSend := bs - if len(toSend) > maxPlaintextSize { - toSend = bs[:maxPlaintextSize] - } - bs = bs[len(toSend):] - - ciphertext, err := c.encryptLocked(toSend, buf) - if err != nil { - return sent, err - } - if _, err := c.conn.Write(ciphertext); err != nil { - // Return the raw error on the Write that actually - // failed. For future writes, return that error wrapped in - // a desync error. - c.tx.err = errPartialWrite{err} - return sent, err - } - sent += len(toSend) - } - return sent, nil -} - -// Close implements io.Closer. -func (c *Conn) Close() error { - closeErr := c.conn.Close() // unblocks any waiting reads or writes - - // Remove references to live cipher state. Strictly speaking this - // is unnecessary, but we want to try and hand the active cipher - // state to the garbage collector promptly, to preserve perfect - // forward secrecy as much as we can. - c.rx.Lock() - c.rx.cipher = nil - c.rx.Unlock() - c.tx.Lock() - c.tx.cipher = nil - c.tx.Unlock() - return closeErr -} - -func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } - -// errCipherExhausted is the error returned when we run out of nonces -// on a cipher. -type errCipherExhausted struct{} - -func (errCipherExhausted) Error() string { - return "cipher exhausted, no more nonces available for current key" -} -func (errCipherExhausted) Timeout() bool { return false } -func (errCipherExhausted) Temporary() bool { return false } - -// errPartialWrite is the error returned when the cipher state has -// become unusable due to a past partial write. -type errPartialWrite struct { - err error -} - -func (e errPartialWrite) Error() string { - return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) -} -func (e errPartialWrite) Unwrap() error { return e.err } -func (e errPartialWrite) Temporary() bool { return false } -func (e errPartialWrite) Timeout() bool { return false } - -// errReadTooBig is the error returned when the peer sent an -// unacceptably large Noise frame. -type errReadTooBig struct { - requested int -} - -func (e errReadTooBig) Error() string { - return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) -} -func (e errReadTooBig) Temporary() bool { - // permanent error because this error only occurs when our peer - // sends us a frame so large we're unwilling to ever decode it. - return false -} -func (e errReadTooBig) Timeout() bool { return false } - -type nonce [chp.NonceSize]byte - -func (n *nonce) Valid() bool { - return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce -} - -func (n *nonce) Increment() { - if !n.Valid() { - panic("increment of invalid nonce") - } - binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) -} - -type maxMsgBuffer [maxMessageSize]byte - -// bufPool holds the temporary buffers for Conn.Read & Write. -var bufPool = &sync.Pool{ - New: func() any { - return new(maxMsgBuffer) - }, -} - -func getMaxMsgBuffer() *maxMsgBuffer { - return bufPool.Get().(*maxMsgBuffer) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlbase implements the base transport of the Tailscale +// 2021 control protocol. +// +// The base transport implements Noise IK, instantiated with +// Curve25519, ChaCha20Poly1305 and BLAKE2s. +package controlbase + +import ( + "crypto/cipher" + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "tailscale.com/types/key" +) + +const ( + // maxMessageSize is the maximum size of a protocol frame on the + // wire, including header and payload. + maxMessageSize = 4096 + // maxCiphertextSize is the maximum amount of ciphertext bytes + // that one protocol frame can carry, after framing. + maxCiphertextSize = maxMessageSize - 3 + // maxPlaintextSize is the maximum amount of plaintext bytes that + // one protocol frame can carry, after encryption and framing. + maxPlaintextSize = maxCiphertextSize - chp.Overhead +) + +// A Conn is a secured Noise connection. It implements the net.Conn +// interface, with the unusual trait that any write error (including a +// SetWriteDeadline induced i/o timeout) causes all future writes to +// fail. +type Conn struct { + conn net.Conn + version uint16 + peer key.MachinePublic + handshakeHash [blake2s.Size]byte + rx rxState + tx txState +} + +// rxState is all the Conn state that Read uses. +type rxState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + buf *maxMsgBuffer // or nil when reads exhausted + n int // number of valid bytes in buf + next int // offset of next undecrypted packet + plaintext []byte // slice into buf of decrypted bytes + hdrBuf [headerLen]byte // small buffer used when buf is nil +} + +// txState is all the Conn state that Write uses. +type txState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + err error // records the first partial write error for all future calls +} + +// ProtocolVersion returns the protocol version that was used to +// establish this Conn. +func (c *Conn) ProtocolVersion() int { + return int(c.version) +} + +// HandshakeHash returns the Noise handshake hash for the connection, +// which can be used to bind other messages to this connection +// (i.e. to ensure that the message wasn't replayed from a different +// connection). +func (c *Conn) HandshakeHash() [blake2s.Size]byte { + return c.handshakeHash +} + +// Peer returns the peer's long-term public key. +func (c *Conn) Peer() key.MachinePublic { + return c.peer +} + +// readNLocked reads into c.rx.buf until buf contains at least total +// bytes. Returns a slice of the total bytes in rxBuf, or an +// error if fewer than total bytes are available. +// +// It may be called with a nil c.rx.buf only if total == headerLen. +// +// On success, c.rx.buf will be non-nil. +func (c *Conn) readNLocked(total int) ([]byte, error) { + if total > maxMessageSize { + return nil, errReadTooBig{total} + } + for { + if total <= c.rx.n { + return c.rx.buf[:total], nil + } + var n int + var err error + if c.rx.buf == nil { + if c.rx.n != 0 || total != headerLen { + panic("unexpected") + } + // Optimization to reduce memory usage. + // Most connections are blocked forever waiting for + // a read, so we don't want c.rx.buf to be allocated until + // we know there's data to read. Instead, when we're + // waiting for data to arrive here, read into the + // 3 byte hdrBuf: + n, err = c.conn.Read(c.rx.hdrBuf[:]) + if n > 0 { + c.rx.buf = getMaxMsgBuffer() + copy(c.rx.buf[:], c.rx.hdrBuf[:n]) + } + } else { + n, err = c.conn.Read(c.rx.buf[c.rx.n:]) + } + c.rx.n += n + if err != nil { + return nil, err + } + } +} + +// decryptLocked decrypts msg (which is header+ciphertext) in-place +// and sets c.rx.plaintext to the decrypted bytes. +func (c *Conn) decryptLocked(msg []byte) (err error) { + if msgType := msg[0]; msgType != msgTypeRecord { + return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) + } + // We don't check the length field here, because the caller + // already did in order to figure out how big the msg slice should + // be. + ciphertext := msg[headerLen:] + + if !c.rx.nonce.Valid() { + return errCipherExhausted{} + } + + c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) + c.rx.nonce.Increment() + + if err != nil { + // Once a decryption has failed, our Conn is no longer + // synchronized with our peer. Nuke the cipher state to be + // safe, so that no further decryptions are attempted. Future + // read attempts will return net.ErrClosed. + c.rx.cipher = nil + } + return err +} + +// encryptLocked encrypts plaintext into buf (including the +// packet header) and returns a slice of the ciphertext, or an error +// if the cipher is exhausted (i.e. can no longer be used safely). +func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { + if !c.tx.nonce.Valid() { + // Received 2^64-1 messages on this cipher state. Connection + // is no longer usable. + return nil, errCipherExhausted{} + } + + buf[0] = msgTypeRecord + binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) + ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) + c.tx.nonce.Increment() + + return ret, nil +} + +// wholeMessageLocked returns a slice of one whole Noise transport +// message from c.rx.buf, if one whole message is available, and +// advances the read state to the next Noise message in the +// buffer. Returns nil without advancing read state if there isn't one +// whole message in c.rx.buf. +func (c *Conn) wholeMessageLocked() []byte { + available := c.rx.n - c.rx.next + if available < headerLen { + return nil + } + bs := c.rx.buf[c.rx.next:c.rx.n] + totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + if len(bs) < totalSize { + return nil + } + c.rx.next += totalSize + return bs[:totalSize] +} + +// decryptOneLocked decrypts one Noise transport message, reading from +// c.conn as needed, and sets c.rx.plaintext to point to the decrypted +// bytes. c.rx.plaintext is only valid if err == nil. +func (c *Conn) decryptOneLocked() error { + c.rx.plaintext = nil + + // Fast path: do we have one whole ciphertext frame buffered + // already? + if bs := c.wholeMessageLocked(); bs != nil { + return c.decryptLocked(bs) + } + + if c.rx.next != 0 { + // To simplify the read logic, move the remainder of the + // buffered bytes back to the head of the buffer, so we can + // grow it without worrying about wraparound. + c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) + c.rx.next = 0 + } + + // Return our buffer to the pool if it's empty, lest we be + // blocked in a long Read call, reading the 3 byte header. We + // don't to keep that buffer unnecessarily alive. + if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { + bufPool.Put(c.rx.buf) + c.rx.buf = nil + } + + bs, err := c.readNLocked(headerLen) + if err != nil { + return err + } + // The rest of the header (besides the length field) gets verified + // in decryptLocked, not here. + messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + bs, err = c.readNLocked(messageLen) + if err != nil { + return err + } + + c.rx.next = len(bs) + + return c.decryptLocked(bs) +} + +// Read implements io.Reader. +func (c *Conn) Read(bs []byte) (int, error) { + c.rx.Lock() + defer c.rx.Unlock() + + if c.rx.cipher == nil { + return 0, net.ErrClosed + } + // If no plaintext is buffered, decrypt incoming frames until we + // have some plaintext. Zero-byte Noise frames are allowed in this + // protocol, which is why we have to loop here rather than decrypt + // a single additional frame. + for len(c.rx.plaintext) == 0 { + if err := c.decryptOneLocked(); err != nil { + return 0, err + } + } + n := copy(bs, c.rx.plaintext) + c.rx.plaintext = c.rx.plaintext[n:] + + // Lose slice's underlying array pointer to unneeded memory so + // GC can collect more. + if len(c.rx.plaintext) == 0 { + c.rx.plaintext = nil + } + return n, nil +} + +// Write implements io.Writer. +func (c *Conn) Write(bs []byte) (n int, err error) { + c.tx.Lock() + defer c.tx.Unlock() + + if c.tx.err != nil { + return 0, c.tx.err + } + defer func() { + if err != nil { + // All write errors are fatal for this conn, so clear the + // cipher state whenever an error happens. + c.tx.cipher = nil + } + if c.tx.err == nil { + // Only set c.tx.err if not nil so that we can return one + // error on the first failure, and a different one for + // subsequent calls. See the error handling around Write + // below for why. + c.tx.err = err + } + }() + + if c.tx.cipher == nil { + return 0, net.ErrClosed + } + + buf := getMaxMsgBuffer() + defer bufPool.Put(buf) + + var sent int + for len(bs) > 0 { + toSend := bs + if len(toSend) > maxPlaintextSize { + toSend = bs[:maxPlaintextSize] + } + bs = bs[len(toSend):] + + ciphertext, err := c.encryptLocked(toSend, buf) + if err != nil { + return sent, err + } + if _, err := c.conn.Write(ciphertext); err != nil { + // Return the raw error on the Write that actually + // failed. For future writes, return that error wrapped in + // a desync error. + c.tx.err = errPartialWrite{err} + return sent, err + } + sent += len(toSend) + } + return sent, nil +} + +// Close implements io.Closer. +func (c *Conn) Close() error { + closeErr := c.conn.Close() // unblocks any waiting reads or writes + + // Remove references to live cipher state. Strictly speaking this + // is unnecessary, but we want to try and hand the active cipher + // state to the garbage collector promptly, to preserve perfect + // forward secrecy as much as we can. + c.rx.Lock() + c.rx.cipher = nil + c.rx.Unlock() + c.tx.Lock() + c.tx.cipher = nil + c.tx.Unlock() + return closeErr +} + +func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +// errCipherExhausted is the error returned when we run out of nonces +// on a cipher. +type errCipherExhausted struct{} + +func (errCipherExhausted) Error() string { + return "cipher exhausted, no more nonces available for current key" +} +func (errCipherExhausted) Timeout() bool { return false } +func (errCipherExhausted) Temporary() bool { return false } + +// errPartialWrite is the error returned when the cipher state has +// become unusable due to a past partial write. +type errPartialWrite struct { + err error +} + +func (e errPartialWrite) Error() string { + return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) +} +func (e errPartialWrite) Unwrap() error { return e.err } +func (e errPartialWrite) Temporary() bool { return false } +func (e errPartialWrite) Timeout() bool { return false } + +// errReadTooBig is the error returned when the peer sent an +// unacceptably large Noise frame. +type errReadTooBig struct { + requested int +} + +func (e errReadTooBig) Error() string { + return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) +} +func (e errReadTooBig) Temporary() bool { + // permanent error because this error only occurs when our peer + // sends us a frame so large we're unwilling to ever decode it. + return false +} +func (e errReadTooBig) Timeout() bool { return false } + +type nonce [chp.NonceSize]byte + +func (n *nonce) Valid() bool { + return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce +} + +func (n *nonce) Increment() { + if !n.Valid() { + panic("increment of invalid nonce") + } + binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) +} + +type maxMsgBuffer [maxMessageSize]byte + +// bufPool holds the temporary buffers for Conn.Read & Write. +var bufPool = &sync.Pool{ + New: func() any { + return new(maxMsgBuffer) + }, +} + +func getMaxMsgBuffer() *maxMsgBuffer { + return bufPool.Get().(*maxMsgBuffer) +} diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 765a4620b876f..937969a3078a8 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "net" + "strconv" + "time" + + "go4.org/mem" + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" + "tailscale.com/types/key" +) + +const ( + // protocolName is the name of the specific instantiation of Noise + // that the control protocol uses. This string's value is fixed by + // the Noise spec, and shouldn't be changed unless we're updating + // the control protocol to use a different Noise instance. + protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" + // protocolVersion is the version of the control protocol that + // Client will use when initiating a handshake. + //protocolVersion uint16 = 1 + // protocolVersionPrefix is the name portion of the protocol + // name+version string that gets mixed into the handshake as a + // prologue. + // + // This mixing verifies that both clients agree that they're + // executing the control protocol at a specific version that + // matches the advertised version in the cleartext packet header. + protocolVersionPrefix = "Tailscale Control Protocol v" + invalidNonce = ^uint64(0) +) + +func protocolVersionPrologue(version uint16) []byte { + ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. + ret = append(ret, protocolVersionPrefix...) + return strconv.AppendUint(ret, uint64(version), 10) +} + +// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn +// is assumed to have already sent the client>server handshake +// initiation message. +type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) + +// ClientDeferred initiates a control client handshake, returning the +// initial message to send to the server and a continuation to +// finalize the handshake. +// +// ClientDeferred is split in this way for RTT reduction: we run this +// protocol after negotiating a protocol switch from HTTP/HTTPS. If we +// completely serialized the negotiation followed by the handshake, +// we'd pay an extra RTT to transmit the handshake initiation after +// protocol switching. By splitting the handshake into an initial +// message and a continuation, we can embed the handshake initiation +// into the HTTP protocol switching request and avoid a bit of delay. +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { + var s symmetricState + s.Initialize() + + // prologue + s.MixHash(protocolVersionPrologue(protocolVersion)) + + // <- s + // ... + s.MixHash(controlKey.UntypedBytes()) + + // -> e, es, s, ss + init := mkInitiationMessage(protocolVersion) + machineEphemeral := key.NewMachine() + machineEphemeralPub := machineEphemeral.Public() + copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(machineEphemeral, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing es: %w", err) + } + machineKeyPub := machineKey.Public() + s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) + cipher, err = s.MixDH(machineKey, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing ss: %w", err) + } + s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload + + cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) + } + return init[:], cont, nil +} + +// Client wraps ClientDeferred and immediately invokes the returned +// continuation with conn. +// +// This is a helper for when you don't need the fancy +// continuation-style handshake, and just want to synchronously +// upgrade a net.Conn to a secure transport. +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) + if err != nil { + return nil, err + } + if _, err := conn.Write(init); err != nil { + return nil, err + } + return cont(ctx, conn) +} + +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + // No matter what, this function can only run once per s. Ensure + // attempted reuse causes a panic. + defer func() { + s.finished = true + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Read in the payload and look for errors/protocol violations from the server. + var resp responseMessage + if _, err := io.ReadFull(conn, resp.Header()); err != nil { + return nil, fmt.Errorf("reading response header: %w", err) + } + if resp.Type() != msgTypeResponse { + if resp.Type() != msgTypeError { + return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) + } + msg := make([]byte, resp.Length()) + if _, err := io.ReadFull(conn, msg); err != nil { + return nil, err + } + return nil, fmt.Errorf("server error: %q", msg) + } + if resp.Length() != len(resp.Payload()) { + return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) + } + if _, err := io.ReadFull(conn, resp.Payload()); err != nil { + return nil, err + } + + // <- e, ee, se + controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err := s.MixDH(machineKey, controlEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { + return nil, fmt.Errorf("decrypting payload: %w", err) + } + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + c := &Conn{ + conn: conn, + version: protocolVersion, + peer: controlKey, + handshakeHash: s.h, + tx: txState{ + cipher: c1, + }, + rx: rxState{ + cipher: c2, + }, + } + return c, nil +} + +// Server initiates a control server handshake, returning the resulting +// control connection. +// +// optionalInit can be the client's initial handshake message as +// returned by ClientDeferred, or nil in which case the initial +// message is read from conn. +// +// The context deadline, if any, covers the entire handshaking +// process. +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Deliberately does not support formatting, so that we don't echo + // attacker-controlled input back to them. + sendErr := func(msg string) error { + if len(msg) >= 1<<16 { + msg = msg[:1<<16] + } + var hdr [headerLen]byte + hdr[0] = msgTypeError + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) + if _, err := conn.Write(hdr[:]); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + if _, err := io.WriteString(conn, msg); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + return fmt.Errorf("refused client handshake: %q", msg) + } + + var s symmetricState + s.Initialize() + + var init initiationMessage + if optionalInit != nil { + if len(optionalInit) != len(init) { + return nil, sendErr("wrong handshake initiation size") + } + copy(init[:], optionalInit) + } else if _, err := io.ReadFull(conn, init.Header()); err != nil { + return nil, err + } + // Just a rename to make it more obvious what the value is. In the + // current implementation we don't need to block any protocol + // versions at this layer, it's safe to let the handshake proceed + // and then let the caller make decisions based on the agreed-upon + // protocol version. + clientVersion := init.Version() + if init.Type() != msgTypeInitiation { + return nil, sendErr("unexpected handshake message type") + } + if init.Length() != len(init.Payload()) { + return nil, sendErr("wrong handshake initiation length") + } + // if optionalInit was provided, we have the payload already. + if optionalInit == nil { + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } + } + + // prologue. Can only do this once we at least think the client is + // handshaking using a supported version. + s.MixHash(protocolVersionPrologue(clientVersion)) + + // <- s + // ... + controlKeyPub := controlKey.Public() + s.MixHash(controlKeyPub.UntypedBytes()) + + // -> e, es, s, ss + machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(controlKey, machineEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing es: %w", err) + } + var machineKeyBytes [32]byte + if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { + return nil, fmt.Errorf("decrypting machine key: %w", err) + } + machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) + cipher, err = s.MixDH(controlKey, machineKey) + if err != nil { + return nil, fmt.Errorf("computing ss: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { + return nil, fmt.Errorf("decrypting initiation tag: %w", err) + } + + // <- e, ee, se + resp := mkResponseMessage() + controlEphemeral := key.NewMachine() + controlEphemeralPub := controlEphemeral.Public() + copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err = s.MixDH(controlEphemeral, machineKey) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + if _, err := conn.Write(resp[:]); err != nil { + return nil, err + } + + c := &Conn{ + conn: conn, + version: clientVersion, + peer: machineKey, + handshakeHash: s.h, + tx: txState{ + cipher: c2, + }, + rx: rxState{ + cipher: c1, + }, + } + return c, nil +} + +// symmetricState contains the state of an in-flight handshake. +type symmetricState struct { + finished bool + + h [blake2s.Size]byte // hash of currently-processed handshake state + ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake +} + +func (s *symmetricState) checkFinished() { + if s.finished { + panic("attempted to use symmetricState after Split was called") + } +} + +// Initialize sets s to the initial handshake state, prior to +// processing any handshake messages. +func (s *symmetricState) Initialize() { + s.checkFinished() + s.h = blake2s.Sum256([]byte(protocolName)) + s.ck = s.h +} + +// MixHash updates s.h to be BLAKE2s(s.h || data), where || is +// concatenation. +func (s *symmetricState) MixHash(data []byte) { + s.checkFinished() + h := newBLAKE2s() + h.Write(s.h[:]) + h.Write(data) + h.Sum(s.h[:0]) +} + +// MixDH updates s.ck with the result of X25519(priv, pub) and returns +// a singleUseCHP that can be used to encrypt or decrypt handshake +// data. +// +// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing +// it as a single function allows for strongly-typed arguments that +// reduce the risk of error in the caller (e.g. invoking X25519 with +// two private keys, or two public keys), and thus producing the wrong +// calculation. +func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { + s.checkFinished() + keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) + if err != nil { + return nil, fmt.Errorf("computing X25519: %w", err) + } + + r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) + if _, err := io.ReadFull(r, s.ck[:]); err != nil { + return nil, fmt.Errorf("extracting ck: %w", err) + } + var k [chp.KeySize]byte + if _, err := io.ReadFull(r, k[:]); err != nil { + return nil, fmt.Errorf("extracting k: %w", err) + } + return newSingleUseCHP(k), nil +} + +// EncryptAndHash encrypts plaintext into ciphertext (which must be +// the correct size to hold the encrypted plaintext) using cipher, +// mixes the ciphertext into s.h, and returns the ciphertext. +func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + panic("ciphertext is wrong size for given plaintext") + } + ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) + s.MixHash(ret) +} + +// DecryptAndHash decrypts the given ciphertext into plaintext (which +// must be the correct size to hold the decrypted ciphertext) using +// cipher. If decryption is successful, it mixes the ciphertext into +// s.h. +func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + return errors.New("plaintext is wrong size for given ciphertext") + } + if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { + return err + } + s.MixHash(ciphertext) + return nil +} + +// Split returns two ChaCha20Poly1305 ciphers with keys derived from +// the current handshake state. Methods on s cannot be used again +// after calling Split. +func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { + s.finished = true + + var k1, k2 [chp.KeySize]byte + r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) + if _, err := io.ReadFull(r, k1[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k1: %w", err) + } + if _, err := io.ReadFull(r, k2[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k2: %w", err) + } + c1, err = chp.New(k1[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) + } + c2, err = chp.New(k2[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) + } + return c1, c2, nil +} + +// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on +// error. +func newBLAKE2s() hash.Hash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen, errors only happen when using BLAKE2s + // in MAC mode with a key. + panic(err) + } + return h +} + +// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or +// panics on error. +func newCHP(key [chp.KeySize]byte) cipher.AEAD { + aead, err := chp.New(key[:]) + if err != nil { + // Can only happen if we passed a key of the wrong length. The + // function signature prevents that. + panic(err) + } + return aead +} + +// singleUseCHP is an instance of ChaCha20Poly1305 that can be used +// only once, either for encrypting or decrypting, but not both. The +// chosen operation is always executed with an all-zeros +// nonce. Subsequent calls to either Seal or Open panic. +type singleUseCHP struct { + c cipher.AEAD +} + +func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { + return &singleUseCHP{newCHP(key)} +} + +func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Seal(dst, nonce[:], plaintext, additionalData) +} + +func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Open(dst, nonce[:], ciphertext, additionalData) +} diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index c41fbf4dd4950..d11c0414911f3 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -1,256 +1,256 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - - "tailscale.com/net/memnet" - "tailscale.com/types/key" -) - -// Can a reference Noise IK client talk to our server? -func TestInteropClient(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - serverErr = make(chan error, 2) - serverBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - server, err := Server(context.Background(), s2, controlKey, nil) - serverErr <- err - if err != nil { - return - } - var buf [1024]byte - _, err = io.ReadFull(server, buf[:len(c2s)]) - serverBytes <- buf[:len(c2s)] - if err != nil { - serverErr <- err - return - } - _, err = server.Write([]byte(s2c)) - serverErr <- err - }() - - gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) - if err != nil { - t.Fatalf("failed client interop: %v", err) - } - if string(gotS2C) != s2c { - t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) - } - - if err := <-serverErr; err != nil { - t.Fatalf("server handshake failed: %v", err) - } - if err := <-serverErr; err != nil { - t.Fatalf("server read/write failed: %v", err) - } - if got := string(<-serverBytes); got != c2s { - t.Fatalf("server received %q, want %q", got, c2s) - } -} - -// Can our client talk to a reference Noise IK server? -func TestInteropServer(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - clientErr = make(chan error, 2) - clientBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) - clientErr <- err - if err != nil { - return - } - _, err = client.Write([]byte(c2s)) - if err != nil { - clientErr <- err - return - } - var buf [1024]byte - _, err = io.ReadFull(client, buf[:len(s2c)]) - clientBytes <- buf[:len(s2c)] - clientErr <- err - }() - - gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) - if err != nil { - t.Fatalf("failed server interop: %v", err) - } - if string(gotC2S) != c2s { - t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) - } - - if err := <-clientErr; err != nil { - t.Fatalf("client handshake failed: %v", err) - } - if err := <-clientErr; err != nil { - t.Fatalf("client read/write failed: %v", err) - } - if got := string(<-clientBytes); got != s2c { - t.Fatalf("client received %q, want %q", got, s2c) - } -} - -// noiseExplorerClient uses the Noise Explorer implementation of Noise -// IK to handshake as a Noise client on conn, transmit payload, and -// read+return a payload from the peer. -func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], machineKey.UntypedBytes()) - copy(mk.public_key[:], machineKey.Public().UntypedBytes()) - var peerKey [32]byte - copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) - - _, msg1 := SendMessage(&session, nil) - var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) - hdr[2] = msgTypeInitiation - binary.BigEndian.PutUint16(hdr[3:5], 96) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ns); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ciphertext); err != nil { - return nil, err - } - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:51]); err != nil { - return nil, err - } - // ignore the header for this test, we're only checking the noise - // implementation. - msg2 := messagebuffer{ - ciphertext: buf[35:51], - } - copy(msg2.ne[:], buf[3:35]) - _, p, valid := RecvMessage(&session, &msg2) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg3 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(hdr[:3]); err != nil { - return nil, err - } - if _, err := conn.Write(msg3.ciphertext); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - // Ignore all of the header except the payload length - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg4 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg4) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - return p, nil -} - -func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], controlKey.UntypedBytes()) - copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:101]); err != nil { - return nil, err - } - // Ignore the header, we're just checking the noise implementation. - msg1 := messagebuffer{ - ns: buf[37:85], - ciphertext: buf[85:101], - } - copy(msg1.ne[:], buf[5:37]) - _, p, valid := RecvMessage(&session, &msg1) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg2 := SendMessage(&session, nil) - var hdr [headerLen]byte - hdr[0] = msgTypeResponse - binary.BigEndian.PutUint16(hdr[1:3], 48) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ciphertext[:]); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg3 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg3) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - _, msg4 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg4.ciphertext); err != nil { - return nil, err - } - - return p, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + + "tailscale.com/net/memnet" + "tailscale.com/types/key" +) + +// Can a reference Noise IK client talk to our server? +func TestInteropClient(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + serverErr = make(chan error, 2) + serverBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + server, err := Server(context.Background(), s2, controlKey, nil) + serverErr <- err + if err != nil { + return + } + var buf [1024]byte + _, err = io.ReadFull(server, buf[:len(c2s)]) + serverBytes <- buf[:len(c2s)] + if err != nil { + serverErr <- err + return + } + _, err = server.Write([]byte(s2c)) + serverErr <- err + }() + + gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) + if err != nil { + t.Fatalf("failed client interop: %v", err) + } + if string(gotS2C) != s2c { + t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) + } + + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server read/write failed: %v", err) + } + if got := string(<-serverBytes); got != c2s { + t.Fatalf("server received %q, want %q", got, c2s) + } +} + +// Can our client talk to a reference Noise IK server? +func TestInteropServer(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + clientErr = make(chan error, 2) + clientBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) + clientErr <- err + if err != nil { + return + } + _, err = client.Write([]byte(c2s)) + if err != nil { + clientErr <- err + return + } + var buf [1024]byte + _, err = io.ReadFull(client, buf[:len(s2c)]) + clientBytes <- buf[:len(s2c)] + clientErr <- err + }() + + gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) + if err != nil { + t.Fatalf("failed server interop: %v", err) + } + if string(gotC2S) != c2s { + t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) + } + + if err := <-clientErr; err != nil { + t.Fatalf("client handshake failed: %v", err) + } + if err := <-clientErr; err != nil { + t.Fatalf("client read/write failed: %v", err) + } + if got := string(<-clientBytes); got != s2c { + t.Fatalf("client received %q, want %q", got, s2c) + } +} + +// noiseExplorerClient uses the Noise Explorer implementation of Noise +// IK to handshake as a Noise client on conn, transmit payload, and +// read+return a payload from the peer. +func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], machineKey.UntypedBytes()) + copy(mk.public_key[:], machineKey.Public().UntypedBytes()) + var peerKey [32]byte + copy(peerKey[:], controlKey.UntypedBytes()) + session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) + + _, msg1 := SendMessage(&session, nil) + var hdr [initiationHeaderLen]byte + binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) + hdr[2] = msgTypeInitiation + binary.BigEndian.PutUint16(hdr[3:5], 96) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ns); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ciphertext); err != nil { + return nil, err + } + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:51]); err != nil { + return nil, err + } + // ignore the header for this test, we're only checking the noise + // implementation. + msg2 := messagebuffer{ + ciphertext: buf[35:51], + } + copy(msg2.ne[:], buf[3:35]) + _, p, valid := RecvMessage(&session, &msg2) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg3 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) + if _, err := conn.Write(hdr[:3]); err != nil { + return nil, err + } + if _, err := conn.Write(msg3.ciphertext); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + // Ignore all of the header except the payload length + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg4 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg4) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + return p, nil +} + +func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], controlKey.UntypedBytes()) + copy(mk.public_key[:], controlKey.Public().UntypedBytes()) + session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:101]); err != nil { + return nil, err + } + // Ignore the header, we're just checking the noise implementation. + msg1 := messagebuffer{ + ns: buf[37:85], + ciphertext: buf[85:101], + } + copy(msg1.ne[:], buf[5:37]) + _, p, valid := RecvMessage(&session, &msg1) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg2 := SendMessage(&session, nil) + var hdr [headerLen]byte + hdr[0] = msgTypeResponse + binary.BigEndian.PutUint16(hdr[1:3], 48) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ciphertext[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg3 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg3) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + _, msg4 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg4.ciphertext); err != nil { + return nil, err + } + + return p, nil +} diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 59073088f5e81..8993786819b6c 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import "encoding/binary" - -const ( - // msgTypeInitiation frames carry a Noise IK handshake initiation message. - msgTypeInitiation = 1 - // msgTypeResponse frames carry a Noise IK handshake response message. - msgTypeResponse = 2 - // msgTypeError frames carry an unauthenticated human-readable - // error message. - // - // Errors reported in this message type must be treated as public - // hints only. They are not encrypted or authenticated, and so can - // be seen and tampered with on the wire. - msgTypeError = 3 - // msgTypeRecord frames carry session data bytes. - msgTypeRecord = 4 - - // headerLen is the size of the header on all messages except msgTypeInitiation. - headerLen = 3 - // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. - initiationHeaderLen = 5 -) - -// initiationMessage is the protocol message sent from a client -// machine to a control server. -// -// 2b: protocol version -// 1b: message type (0x01) -// 2b: payload length (96) -// 5b: header (see headerLen for fields) -// 32b: client ephemeral public key (cleartext) -// 48b: client machine public key (encrypted) -// 16b: message tag (authenticates the whole message) -type initiationMessage [101]byte - -func mkInitiationMessage(protocolVersion uint16) initiationMessage { - var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) - return ret -} - -func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } -func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } - -func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } -func (m *initiationMessage) Type() byte { return m[2] } -func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } - -func (m *initiationMessage) EphemeralPub() []byte { - return m[initiationHeaderLen : initiationHeaderLen+32] -} -func (m *initiationMessage) MachinePub() []byte { - return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] -} -func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } - -// responseMessage is the protocol message sent from a control server -// to a client machine. -// -// 1b: message type (0x02) -// 2b: payload length (48) -// 32b: control ephemeral public key (cleartext) -// 16b: message tag (authenticates the whole message) -type responseMessage [51]byte - -func mkResponseMessage() responseMessage { - var ret responseMessage - ret[0] = msgTypeResponse - binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) - return ret -} - -func (m *responseMessage) Header() []byte { return m[:headerLen] } -func (m *responseMessage) Payload() []byte { return m[headerLen:] } - -func (m *responseMessage) Type() byte { return m[0] } -func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } - -func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import "encoding/binary" + +const ( + // msgTypeInitiation frames carry a Noise IK handshake initiation message. + msgTypeInitiation = 1 + // msgTypeResponse frames carry a Noise IK handshake response message. + msgTypeResponse = 2 + // msgTypeError frames carry an unauthenticated human-readable + // error message. + // + // Errors reported in this message type must be treated as public + // hints only. They are not encrypted or authenticated, and so can + // be seen and tampered with on the wire. + msgTypeError = 3 + // msgTypeRecord frames carry session data bytes. + msgTypeRecord = 4 + + // headerLen is the size of the header on all messages except msgTypeInitiation. + headerLen = 3 + // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. + initiationHeaderLen = 5 +) + +// initiationMessage is the protocol message sent from a client +// machine to a control server. +// +// 2b: protocol version +// 1b: message type (0x01) +// 2b: payload length (96) +// 5b: header (see headerLen for fields) +// 32b: client ephemeral public key (cleartext) +// 48b: client machine public key (encrypted) +// 16b: message tag (authenticates the whole message) +type initiationMessage [101]byte + +func mkInitiationMessage(protocolVersion uint16) initiationMessage { + var ret initiationMessage + binary.BigEndian.PutUint16(ret[:2], protocolVersion) + ret[2] = msgTypeInitiation + binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) + return ret +} + +func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } +func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } + +func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } +func (m *initiationMessage) Type() byte { return m[2] } +func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } + +func (m *initiationMessage) EphemeralPub() []byte { + return m[initiationHeaderLen : initiationHeaderLen+32] +} +func (m *initiationMessage) MachinePub() []byte { + return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] +} +func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } + +// responseMessage is the protocol message sent from a control server +// to a client machine. +// +// 1b: message type (0x02) +// 2b: payload length (48) +// 32b: control ephemeral public key (cleartext) +// 16b: message tag (authenticates the whole message) +type responseMessage [51]byte + +func mkResponseMessage() responseMessage { + var ret responseMessage + ret[0] = msgTypeResponse + binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) + return ret +} + +func (m *responseMessage) Header() []byte { return m[:headerLen] } +func (m *responseMessage) Payload() []byte { return m[headerLen:] } + +func (m *responseMessage) Type() byte { return m[0] } +func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } + +func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } +func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } diff --git a/control/controlclient/sign.go b/control/controlclient/sign.go index e3a479c283c62..5e72f1cf4b2b6 100644 --- a/control/controlclient/sign.go +++ b/control/controlclient/sign.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "crypto" - "errors" - "fmt" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -var ( - errNoCertStore = errors.New("no certificate store") - errCertificateNotConfigured = errors.New("no certificate subject configured") - errUnsupportedSignatureVersion = errors.New("unsupported signature version") -) - -// HashRegisterRequest generates the hash required sign or verify a -// tailcfg.RegisterRequest. -func HashRegisterRequest( - version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, - serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { - h := crypto.SHA256.New() - - // hash.Hash.Write never returns an error, so we don't check for one here. - switch version { - case tailcfg.SignatureV1: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) - case tailcfg.SignatureV2: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) - default: - return nil, errUnsupportedSignatureVersion - } - - return h.Sum(nil), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "crypto" + "errors" + "fmt" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var ( + errNoCertStore = errors.New("no certificate store") + errCertificateNotConfigured = errors.New("no certificate subject configured") + errUnsupportedSignatureVersion = errors.New("unsupported signature version") +) + +// HashRegisterRequest generates the hash required sign or verify a +// tailcfg.RegisterRequest. +func HashRegisterRequest( + version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, + serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { + h := crypto.SHA256.New() + + // hash.Hash.Write never returns an error, so we don't check for one here. + switch version { + case tailcfg.SignatureV1: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) + case tailcfg.SignatureV2: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) + default: + return nil, errUnsupportedSignatureVersion + } + + return h.Sum(nil), nil +} diff --git a/control/controlclient/sign_supported_test.go b/control/controlclient/sign_supported_test.go index e20349a4e82c3..ca41794d11775 100644 --- a/control/controlclient/sign_supported_test.go +++ b/control/controlclient/sign_supported_test.go @@ -1,236 +1,236 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows && cgo - -package controlclient - -import ( - "crypto" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "reflect" - "testing" - "time" - - "github.com/tailscale/certstore" -) - -const ( - testRootCommonName = "testroot" - testRootSubject = "CN=testroot" -) - -type testIdentity struct { - chain []*x509.Certificate -} - -func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { - return []*x509.Certificate{ - { - NotBefore: notBefore, - NotAfter: notAfter, - PublicKeyAlgorithm: x509.RSA, - }, - { - Subject: pkix.Name{ - CommonName: rootCommonName, - }, - PublicKeyAlgorithm: x509.RSA, - }, - } -} - -func (t *testIdentity) Certificate() (*x509.Certificate, error) { - return t.chain[0], nil -} - -func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { - return t.chain, nil -} - -func (t *testIdentity) Signer() (crypto.Signer, error) { - return nil, errors.New("not implemented") -} - -func (t *testIdentity) Delete() error { - return errors.New("not implemented") -} - -func (t *testIdentity) Close() {} - -func TestSelectIdentityFromSlice(t *testing.T) { - var times []time.Time - for _, ts := range []string{ - "2000-01-01T00:00:00Z", - "2001-01-01T00:00:00Z", - "2002-01-01T00:00:00Z", - "2003-01-01T00:00:00Z", - } { - tm, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatal(err) - } - times = append(times, tm) - } - - tests := []struct { - name string - subject string - ids []certstore.Identity - now time.Time - // wantIndex is an index into ids, or -1 for nil. - wantIndex int - }{ - { - name: "single unexpired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 0, - }, - { - name: "single expired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[2]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 1, - }, - { - name: "expired with unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[3]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two certs both unexpired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two unexpired one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) - - if gotId == nil && gotChain != nil { - t.Error("id is nil: got non-nil chain, want nil chain") - return - } - if gotId != nil && gotChain == nil { - t.Error("id is not nil: got nil chain, want non-nil chain") - return - } - if tt.wantIndex == -1 { - if gotId != nil { - t.Error("got non-nil id, want nil id") - } - return - } - if gotId == nil { - t.Error("got nil id, want non-nil id") - return - } - if gotId != tt.ids[tt.wantIndex] { - found := -1 - for i := range tt.ids { - if tt.ids[i] == gotId { - found = i - break - } - } - if found == -1 { - t.Errorf("got unknown id, want id at index %v", tt.wantIndex) - } else { - t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) - } - } - - tid, ok := tt.ids[tt.wantIndex].(*testIdentity) - if !ok { - t.Error("got non-testIdentity, want testIdentity") - return - } - - if !reflect.DeepEqual(tid.chain, gotChain) { - t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows && cgo + +package controlclient + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "reflect" + "testing" + "time" + + "github.com/tailscale/certstore" +) + +const ( + testRootCommonName = "testroot" + testRootSubject = "CN=testroot" +) + +type testIdentity struct { + chain []*x509.Certificate +} + +func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { + return []*x509.Certificate{ + { + NotBefore: notBefore, + NotAfter: notAfter, + PublicKeyAlgorithm: x509.RSA, + }, + { + Subject: pkix.Name{ + CommonName: rootCommonName, + }, + PublicKeyAlgorithm: x509.RSA, + }, + } +} + +func (t *testIdentity) Certificate() (*x509.Certificate, error) { + return t.chain[0], nil +} + +func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { + return t.chain, nil +} + +func (t *testIdentity) Signer() (crypto.Signer, error) { + return nil, errors.New("not implemented") +} + +func (t *testIdentity) Delete() error { + return errors.New("not implemented") +} + +func (t *testIdentity) Close() {} + +func TestSelectIdentityFromSlice(t *testing.T) { + var times []time.Time + for _, ts := range []string{ + "2000-01-01T00:00:00Z", + "2001-01-01T00:00:00Z", + "2002-01-01T00:00:00Z", + "2003-01-01T00:00:00Z", + } { + tm, err := time.Parse(time.RFC3339, ts) + if err != nil { + t.Fatal(err) + } + times = append(times, tm) + } + + tests := []struct { + name string + subject string + ids []certstore.Identity + now time.Time + // wantIndex is an index into ids, or -1 for nil. + wantIndex int + }{ + { + name: "single unexpired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 0, + }, + { + name: "single expired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[2]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 1, + }, + { + name: "expired with unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[3]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two certs both unexpired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two unexpired one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) + + if gotId == nil && gotChain != nil { + t.Error("id is nil: got non-nil chain, want nil chain") + return + } + if gotId != nil && gotChain == nil { + t.Error("id is not nil: got nil chain, want non-nil chain") + return + } + if tt.wantIndex == -1 { + if gotId != nil { + t.Error("got non-nil id, want nil id") + } + return + } + if gotId == nil { + t.Error("got nil id, want non-nil id") + return + } + if gotId != tt.ids[tt.wantIndex] { + found := -1 + for i := range tt.ids { + if tt.ids[i] == gotId { + found = i + break + } + } + if found == -1 { + t.Errorf("got unknown id, want id at index %v", tt.wantIndex) + } else { + t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) + } + } + + tid, ok := tt.ids[tt.wantIndex].(*testIdentity) + if !ok { + t.Error("got non-testIdentity, want testIdentity") + return + } + + if !reflect.DeepEqual(tid.chain, gotChain) { + t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) + } + }) + } +} diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index 5e161dcbce453..4ec40d502773f 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package controlclient - -import ( - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// signRegisterRequest on non-supported platforms always returns errNoCertStore. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { - return errNoCertStore -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package controlclient + +import ( + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// signRegisterRequest on non-supported platforms always returns errNoCertStore. +func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { + return errNoCertStore +} diff --git a/control/controlclient/status.go b/control/controlclient/status.go index d0fdf80d745e3..7dba14d3f5015 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "encoding/json" - "fmt" - "reflect" - - "tailscale.com/types/netmap" - "tailscale.com/types/persist" - "tailscale.com/types/structs" -) - -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. -type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - -type Status struct { - _ structs.Incomparable - - // Err, if non-nil, is an error that occurred while logging in. - // - // If it's of type UserVisibleError then it's meant to be shown to users in - // their Tailscale client. Otherwise it's just logged to tailscaled's logs. - Err error - - // URL, if non-empty, is the interactive URL to visit to finish logging in. - URL string - - // NetMap is the latest server-pushed state of the tailnet network. - NetMap *netmap.NetworkMap - - // Persist, when Valid, is the locally persisted configuration. - // - // TODO(bradfitz,maisem): clarify this. - Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State -} - -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - -// Equal reports whether s and s2 are equal. -func (s *Status) Equal(s2 *Status) bool { - if s == nil && s2 == nil { - return true - } - return s != nil && s2 != nil && - s.Err == s2.Err && - s.URL == s2.URL && - s.state == s2.state && - reflect.DeepEqual(s.Persist, s2.Persist) && - reflect.DeepEqual(s.NetMap, s2.NetMap) -} - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "encoding/json" + "fmt" + "reflect" + + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/types/structs" +) + +// State is the high-level state of the client. It is used only in +// unit tests for proper sequencing, don't depend on it anywhere else. +// +// TODO(apenwarr): eliminate the state, as it's now obsolete. +// +// apenwarr: Historical note: controlclient.Auto was originally +// intended to be the state machine for the whole tailscale client, but that +// turned out to not be the right abstraction layer, and it moved to +// ipn.Backend. Since ipn.Backend now has a state machine, it would be +// much better if controlclient could be a simple stateless API. But the +// current server-side API (two interlocking polling https calls) makes that +// very hard to implement. A server side API change could untangle this and +// remove all the statefulness. +type State int + +const ( + StateNew = State(iota) + StateNotAuthenticated + StateAuthenticating + StateURLVisitRequired + StateAuthenticated + StateSynchronized // connected and received map update +) + +func (s State) AppendText(b []byte) ([]byte, error) { + return append(b, s.String()...), nil +} + +func (s State) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s State) String() string { + switch s { + case StateNew: + return "state:new" + case StateNotAuthenticated: + return "state:not-authenticated" + case StateAuthenticating: + return "state:authenticating" + case StateURLVisitRequired: + return "state:url-visit-required" + case StateAuthenticated: + return "state:authenticated" + case StateSynchronized: + return "state:synchronized" + default: + return fmt.Sprintf("state:unknown:%d", int(s)) + } +} + +type Status struct { + _ structs.Incomparable + + // Err, if non-nil, is an error that occurred while logging in. + // + // If it's of type UserVisibleError then it's meant to be shown to users in + // their Tailscale client. Otherwise it's just logged to tailscaled's logs. + Err error + + // URL, if non-empty, is the interactive URL to visit to finish logging in. + URL string + + // NetMap is the latest server-pushed state of the tailnet network. + NetMap *netmap.NetworkMap + + // Persist, when Valid, is the locally persisted configuration. + // + // TODO(bradfitz,maisem): clarify this. + Persist persist.PersistView + + // state is the internal state. It should not be exposed outside this + // package, but we have some automated tests elsewhere that need to + // use it via the StateForTest accessor. + // TODO(apenwarr): Unexport or remove these. + state State +} + +// LoginFinished reports whether the controlclient is in its "StateAuthenticated" +// state where it's in a happy register state but not yet in a map poll. +// +// TODO(bradfitz): delete this and everything around Status.state. +func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } + +// StateForTest returns the internal state of s for tests only. +func (s *Status) StateForTest() State { return s.state } + +// SetStateForTest sets the internal state of s for tests only. +func (s *Status) SetStateForTest(state State) { s.state = state } + +// Equal reports whether s and s2 are equal. +func (s *Status) Equal(s2 *Status) bool { + if s == nil && s2 == nil { + return true + } + return s != nil && s2 != nil && + s.Err == s2.Err && + s.URL == s2.URL && + s.state == s2.state && + reflect.DeepEqual(s.Persist, s2.Persist) && + reflect.DeepEqual(s.NetMap, s2.NetMap) +} + +func (s Status) String() string { + b, err := json.MarshalIndent(s, "", "\t") + if err != nil { + panic(err) + } + return s.state.String() + " " + string(b) +} diff --git a/control/controlhttp/client_common.go b/control/controlhttp/client_common.go index dd94e93cdc3cf..72a89e3cdbbed 100644 --- a/control/controlhttp/client_common.go +++ b/control/controlhttp/client_common.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlhttp - -import ( - "tailscale.com/control/controlbase" -) - -// ClientConn is a Tailscale control client as returned by the Dialer. -// -// It's effectively just a *controlbase.Conn (which it embeds) with -// optional metadata. -type ClientConn struct { - // Conn is the noise connection. - *controlbase.Conn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlhttp + +import ( + "tailscale.com/control/controlbase" +) + +// ClientConn is a Tailscale control client as returned by the Dialer. +// +// It's effectively just a *controlbase.Conn (which it embeds) with +// optional metadata. +type ClientConn struct { + // Conn is the noise connection. + *controlbase.Conn +} diff --git a/derp/README.md b/derp/README.md index 16877020d465e..acd986ea9cf08 100644 --- a/derp/README.md +++ b/derp/README.md @@ -1,61 +1,61 @@ -# DERP - -This directory (and subdirectories) contain the DERP code. The server itself is -in `../cmd/derper`. - -DERP is a packet relay system (client and servers) where peers are addressed -using WireGuard public keys instead of IP addresses. - -It relays two types of packets: - -* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT - traversal](https://tailscale.com/blog/how-nat-traversal-works/). - -* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked - or NAT traversal fails. - -## DERP Map - -Each client receives a "[DERP -Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination -server describing the DERP servers the client should try to use. - -The client picks its home "DERP home" based on latency. This is done to keep -costs low by avoid using cloud load balancers (pricey) or anycast, which would -necessarily require server-side routing between DERP regions. - -Clients pick their DERP home and report it to the coordination server which -shares it to all the peers in the tailnet. When a peer wants to send a packet -and it doesn't already have a WireGuard session open, it sends disco messages -(some direct, and some over DERP), trying to do the NAT traversal. The client -will make connections to multiple DERP regions as needed. Only the DERP home -region connection needs to be alive forever. - -## DERP Regions - -Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various -geographic regions to make sure users have low latency to their DERP home. - -Regions generally have multiple nodes per region "meshed" (routing to each -other) together for redundancy: it allows for cloud failures or upgrades without -kicking users out to a higher latency region. Instead, clients will reconnect to -the next node in the region. Each node in the region is required to to be meshed -with every other node in the region and forward packets to the other nodes in -the region. Packets are forwarded only one hop within the region. There is no -routing between regions. The assumption is that the mesh TCP connections are -over a VPC that's very fast, low latency, and not charged per byte. The -coordination server assigns the list of nodes in a region as a function of the -tailnet, so all nodes within a tailnet should generally be on the same node and -not require forwarding. Only after a failure do clients of a particular tailnet -get split between nodes in a region and require inter-node forwarding. But over -time it balances back out. There's also an admin-only DERP frame type to force -close the TCP connection of a particular client to force them to reconnect to -their primary if the operator wants to force things to balance out sooner. -(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's -internal rarely-used `cmd/derpprune` maintenance tool) - -We generally run a minimum of three nodes in a region not for quorum reasons -(there's no voting) but just because two is too uncomfortably few for cascading -failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and -then one fails, that makes the second one fail. With three or more nodes, you +# DERP + +This directory (and subdirectories) contain the DERP code. The server itself is +in `../cmd/derper`. + +DERP is a packet relay system (client and servers) where peers are addressed +using WireGuard public keys instead of IP addresses. + +It relays two types of packets: + +* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT + traversal](https://tailscale.com/blog/how-nat-traversal-works/). + +* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked + or NAT traversal fails. + +## DERP Map + +Each client receives a "[DERP +Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination +server describing the DERP servers the client should try to use. + +The client picks its home "DERP home" based on latency. This is done to keep +costs low by avoid using cloud load balancers (pricey) or anycast, which would +necessarily require server-side routing between DERP regions. + +Clients pick their DERP home and report it to the coordination server which +shares it to all the peers in the tailnet. When a peer wants to send a packet +and it doesn't already have a WireGuard session open, it sends disco messages +(some direct, and some over DERP), trying to do the NAT traversal. The client +will make connections to multiple DERP regions as needed. Only the DERP home +region connection needs to be alive forever. + +## DERP Regions + +Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various +geographic regions to make sure users have low latency to their DERP home. + +Regions generally have multiple nodes per region "meshed" (routing to each +other) together for redundancy: it allows for cloud failures or upgrades without +kicking users out to a higher latency region. Instead, clients will reconnect to +the next node in the region. Each node in the region is required to to be meshed +with every other node in the region and forward packets to the other nodes in +the region. Packets are forwarded only one hop within the region. There is no +routing between regions. The assumption is that the mesh TCP connections are +over a VPC that's very fast, low latency, and not charged per byte. The +coordination server assigns the list of nodes in a region as a function of the +tailnet, so all nodes within a tailnet should generally be on the same node and +not require forwarding. Only after a failure do clients of a particular tailnet +get split between nodes in a region and require inter-node forwarding. But over +time it balances back out. There's also an admin-only DERP frame type to force +close the TCP connection of a particular client to force them to reconnect to +their primary if the operator wants to force things to balance out sooner. +(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's +internal rarely-used `cmd/derpprune` maintenance tool) + +We generally run a minimum of three nodes in a region not for quorum reasons +(there's no voting) but just because two is too uncomfortably few for cascading +failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and +then one fails, that makes the second one fail. With three or more nodes, you can run each node a bit hotter. \ No newline at end of file diff --git a/derp/testdata/example_ss.txt b/derp/testdata/example_ss.txt index 2885f1bc15a16..ae25003b22856 100644 --- a/derp/testdata/example_ss.txt +++ b/derp/testdata/example_ss.txt @@ -1,8 +1,8 @@ -ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https - cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 -ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 - cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 -ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https - cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 -ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https - cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 +ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https + cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 +ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 + cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 +ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https + cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 +ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https + cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 diff --git a/disco/disco_fuzzer.go b/disco/disco_fuzzer.go index b9ffabfb00906..0deede05018d3 100644 --- a/disco/disco_fuzzer.go +++ b/disco/disco_fuzzer.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package disco - -func Fuzz(data []byte) int { - m, _ := Parse(data) - - newBytes := m.AppendMarshal(data) - parsedMarshall, _ := Parse(newBytes) - - if m != parsedMarshall { - panic("Parsing error") - } - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package disco + +func Fuzz(data []byte) int { + m, _ := Parse(data) + + newBytes := m.AppendMarshal(data) + parsedMarshall, _ := Parse(newBytes) + + if m != parsedMarshall { + panic("Parsing error") + } + return 1 +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 1a56324a5a423..045425eb722df 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -1,118 +1,118 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "fmt" - "net/netip" - "reflect" - "strings" - "testing" - - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestMarshalAndParse(t *testing.T) { - tests := []struct { - name string - want string - m Message - }{ - { - name: "ping", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", - }, - { - name: "ping_with_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", - }, - { - name: "ping_with_padding", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", - }, - { - name: "ping_with_padding_and_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", - }, - { - name: "pong", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("2.3.4.5:1234"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", - }, - { - name: "pongv6", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("[fed0::12]:6666"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", - }, - { - name: "call_me_maybe", - m: &CallMeMaybe{}, - want: "03 00", - }, - { - name: "call_me_maybe_endpoints", - m: &CallMeMaybe{ - MyNumber: []netip.AddrPort{ - netip.MustParseAddrPort("1.2.3.4:567"), - netip.MustParseAddrPort("[2001::3456]:789"), - }, - }, - want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - foo := []byte("foo") - got := string(tt.m.AppendMarshal(foo)) - got, ok := strings.CutPrefix(got, "foo") - if !ok { - t.Fatalf("didn't start with foo: got %q", got) - } - - gotHex := fmt.Sprintf("% x", got) - if gotHex != tt.want { - t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) - } - - back, err := Parse([]byte(got)) - if err != nil { - t.Fatalf("parse back: %v", err) - } - if !reflect.DeepEqual(back, tt.m) { - t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) - } - }) - } -} - -func mustIPPort(s string) netip.AddrPort { - ipp, err := netip.ParseAddrPort(s) - if err != nil { - panic(err) - } - return ipp -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "fmt" + "net/netip" + "reflect" + "strings" + "testing" + + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestMarshalAndParse(t *testing.T) { + tests := []struct { + name string + want string + m Message + }{ + { + name: "ping", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", + }, + { + name: "ping_with_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + }, + { + name: "ping_with_padding", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", + }, + { + name: "ping_with_padding_and_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", + }, + { + name: "pong", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("2.3.4.5:1234"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", + }, + { + name: "pongv6", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("[fed0::12]:6666"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", + }, + { + name: "call_me_maybe", + m: &CallMeMaybe{}, + want: "03 00", + }, + { + name: "call_me_maybe_endpoints", + m: &CallMeMaybe{ + MyNumber: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + }, + want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + foo := []byte("foo") + got := string(tt.m.AppendMarshal(foo)) + got, ok := strings.CutPrefix(got, "foo") + if !ok { + t.Fatalf("didn't start with foo: got %q", got) + } + + gotHex := fmt.Sprintf("% x", got) + if gotHex != tt.want { + t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) + } + + back, err := Parse([]byte(got)) + if err != nil { + t.Fatalf("parse back: %v", err) + } + if !reflect.DeepEqual(back, tt.m) { + t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) + } + }) + } +} + +func mustIPPort(s string) netip.AddrPort { + ipp, err := netip.ParseAddrPort(s) + if err != nil { + panic(err) + } + return ipp +} diff --git a/disco/pcap.go b/disco/pcap.go index 71035424868e8..5d60ceb28eeef 100644 --- a/disco/pcap.go +++ b/disco/pcap.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "bytes" - "encoding/binary" - "net/netip" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. -// -// Warning: Alloc garbage. Acceptable while capturing. -func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { - var ( - b bytes.Buffer - flag uint8 - ) - b.Grow(128) // Most disco frames will probably be smaller than this. - - if src.Addr() == tailcfg.DerpMagicIPAddr { - flag |= 0x01 - } - b.WriteByte(flag) // 1b: flag - - derpSrc := derpNodeSrc.Raw32() - b.Write(derpSrc[:]) // 32b: derp public key - binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port - addr, _ := src.Addr().MarshalBinary() - binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) - b.Write(addr) // Xb: addr - binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) - b.Write(payload) // Xb: payload - - return b.Bytes() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "bytes" + "encoding/binary" + "net/netip" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. +// +// Warning: Alloc garbage. Acceptable while capturing. +func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { + var ( + b bytes.Buffer + flag uint8 + ) + b.Grow(128) // Most disco frames will probably be smaller than this. + + if src.Addr() == tailcfg.DerpMagicIPAddr { + flag |= 0x01 + } + b.WriteByte(flag) // 1b: flag + + derpSrc := derpNodeSrc.Raw32() + b.Write(derpSrc[:]) // 32b: derp public key + binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port + addr, _ := src.Addr().MarshalBinary() + binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) + b.Write(addr) // Xb: addr + binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) + b.Write(payload) // Xb: payload + + return b.Bytes() +} diff --git a/docs/bird/sample_bird.conf b/docs/bird/sample_bird.conf index ed38e66c5c0a2..87222c59af0e6 100644 --- a/docs/bird/sample_bird.conf +++ b/docs/bird/sample_bird.conf @@ -1,16 +1,16 @@ -log syslog all; - -protocol device { - scan time 10; -} - -protocol bgp { - local as 64001; - neighbor 10.40.2.101 as 64002; - ipv4 { - import none; - export all; - }; -} - -include "tailscale_bird.conf"; +log syslog all; + +protocol device { + scan time 10; +} + +protocol bgp { + local as 64001; + neighbor 10.40.2.101 as 64002; + ipv4 { + import none; + export all; + }; +} + +include "tailscale_bird.conf"; diff --git a/docs/bird/tailscale_bird.conf b/docs/bird/tailscale_bird.conf index 8211a50a3c58e..a5f4307479b79 100644 --- a/docs/bird/tailscale_bird.conf +++ b/docs/bird/tailscale_bird.conf @@ -1,4 +1,4 @@ -protocol static tailscale { - ipv4; - route 100.64.0.0/10 via "tailscale0"; -} +protocol static tailscale { + ipv4; + route 100.64.0.0/10 via "tailscale0"; +} diff --git a/docs/k8s/Makefile b/docs/k8s/Makefile index 55804c857c049..107c1c1361c61 100644 --- a/docs/k8s/Makefile +++ b/docs/k8s/Makefile @@ -1,25 +1,25 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -TS_ROUTES ?= "" -SA_NAME ?= tailscale -TS_KUBE_SECRET ?= tailscale - -rbac: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml - -sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -userspace-sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -proxy: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" - -subnet-router: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +TS_ROUTES ?= "" +SA_NAME ?= tailscale +TS_KUBE_SECRET ?= tailscale + +rbac: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml + +sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +userspace-sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +proxy: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" + +subnet-router: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" diff --git a/docs/k8s/rolebinding.yaml b/docs/k8s/rolebinding.yaml index 3b18ba8d35e57..b32e66b984510 100644 --- a/docs/k8s/rolebinding.yaml +++ b/docs/k8s/rolebinding.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - name: tailscale -subjects: -- kind: ServiceAccount - name: "{{SA_NAME}}" -roleRef: - kind: Role - name: tailscale - apiGroup: rbac.authorization.k8s.io +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: tailscale +subjects: +- kind: ServiceAccount + name: "{{SA_NAME}}" +roleRef: + kind: Role + name: tailscale + apiGroup: rbac.authorization.k8s.io diff --git a/docs/k8s/sa.yaml b/docs/k8s/sa.yaml index edd3944ba8987..85b56bd24a7fe 100644 --- a/docs/k8s/sa.yaml +++ b/docs/k8s/sa.yaml @@ -1,6 +1,6 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: v1 -kind: ServiceAccount -metadata: - name: {{SA_NAME}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{SA_NAME}} diff --git a/docs/sysv/tailscale.init b/docs/sysv/tailscale.init index ca21033df7b27..fc22088b16a5b 100755 --- a/docs/sysv/tailscale.init +++ b/docs/sysv/tailscale.init @@ -1,63 +1,63 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -### BEGIN INIT INFO -# Provides: tailscaled -# Required-Start: -# Required-Stop: -# Default-Start: -# Default-Stop: -# Short-Description: Tailscale Mesh Wireguard VPN -### END INIT INFO - -set -e - -# /etc/init.d/tailscale: start and stop the Tailscale VPN service - -test -x /usr/sbin/tailscaled || exit 0 - -umask 022 - -. /lib/lsb/init-functions - -# Are we running from init? -run_by_init() { - ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] -} - -export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" - -case "$1" in - start) - log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true - if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ - --exec /usr/sbin/tailscaled -- \ - --state=/var/lib/tailscale/tailscaled.state \ - --socket=/run/tailscale/tailscaled.sock \ - --port 41641; - then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - stop) - log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true - if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - - status) - status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? - ;; - - *) - log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true - exit 1 -esac - -exit 0 +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +### BEGIN INIT INFO +# Provides: tailscaled +# Required-Start: +# Required-Stop: +# Default-Start: +# Default-Stop: +# Short-Description: Tailscale Mesh Wireguard VPN +### END INIT INFO + +set -e + +# /etc/init.d/tailscale: start and stop the Tailscale VPN service + +test -x /usr/sbin/tailscaled || exit 0 + +umask 022 + +. /lib/lsb/init-functions + +# Are we running from init? +run_by_init() { + ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] +} + +export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" + +case "$1" in + start) + log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true + if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ + --exec /usr/sbin/tailscaled -- \ + --state=/var/lib/tailscale/tailscaled.state \ + --socket=/run/tailscale/tailscaled.sock \ + --port 41641; + then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + stop) + log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true + if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + + status) + status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? + ;; + + *) + log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true + exit 1 +esac + +exit 0 diff --git a/doctor/doctor.go b/doctor/doctor.go index 7c3047e12b62d..96af39f5f3eb9 100644 --- a/doctor/doctor.go +++ b/doctor/doctor.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package doctor contains more in-depth healthchecks that can be run to aid in -// diagnosing Tailscale issues. -package doctor - -import ( - "context" - "sync" - - "tailscale.com/types/logger" -) - -// Check is the interface defining a singular check. -// -// A check should log information that it gathers using the provided log -// function, and should attempt to make as much progress as possible in error -// conditions. -type Check interface { - // Name should return a name describing this check, in lower-kebab-case - // (i.e. "my-check", not "MyCheck" or "my_check"). - Name() string - // Run executes the check, logging diagnostic information to the - // provided logger function. - Run(context.Context, logger.Logf) error -} - -// RunChecks runs a list of checks in parallel, and logs any returned errors -// after all checks have returned. -func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { - if len(checks) == 0 { - return - } - - type namedErr struct { - name string - err error - } - errs := make(chan namedErr, len(checks)) - - var wg sync.WaitGroup - wg.Add(len(checks)) - for _, check := range checks { - go func(c Check) { - defer wg.Done() - - plog := logger.WithPrefix(log, c.Name()+": ") - errs <- namedErr{ - name: c.Name(), - err: c.Run(ctx, plog), - } - }(check) - } - - wg.Wait() - close(errs) - - for n := range errs { - if n.err == nil { - continue - } - - log("check %s: %v", n.name, n.err) - } -} - -// CheckFunc creates a Check from a name and a function. -func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { - return checkFunc{name, run} -} - -type checkFunc struct { - name string - run func(context.Context, logger.Logf) error -} - -func (c checkFunc) Name() string { return c.name } -func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package doctor contains more in-depth healthchecks that can be run to aid in +// diagnosing Tailscale issues. +package doctor + +import ( + "context" + "sync" + + "tailscale.com/types/logger" +) + +// Check is the interface defining a singular check. +// +// A check should log information that it gathers using the provided log +// function, and should attempt to make as much progress as possible in error +// conditions. +type Check interface { + // Name should return a name describing this check, in lower-kebab-case + // (i.e. "my-check", not "MyCheck" or "my_check"). + Name() string + // Run executes the check, logging diagnostic information to the + // provided logger function. + Run(context.Context, logger.Logf) error +} + +// RunChecks runs a list of checks in parallel, and logs any returned errors +// after all checks have returned. +func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { + if len(checks) == 0 { + return + } + + type namedErr struct { + name string + err error + } + errs := make(chan namedErr, len(checks)) + + var wg sync.WaitGroup + wg.Add(len(checks)) + for _, check := range checks { + go func(c Check) { + defer wg.Done() + + plog := logger.WithPrefix(log, c.Name()+": ") + errs <- namedErr{ + name: c.Name(), + err: c.Run(ctx, plog), + } + }(check) + } + + wg.Wait() + close(errs) + + for n := range errs { + if n.err == nil { + continue + } + + log("check %s: %v", n.name, n.err) + } +} + +// CheckFunc creates a Check from a name and a function. +func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { + return checkFunc{name, run} +} + +type checkFunc struct { + name string + run func(context.Context, logger.Logf) error +} + +func (c checkFunc) Name() string { return c.name } +func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } diff --git a/doctor/doctor_test.go b/doctor/doctor_test.go index 87250f10ed00a..dab7afa38a5fc 100644 --- a/doctor/doctor_test.go +++ b/doctor/doctor_test.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package doctor - -import ( - "context" - "fmt" - "sync" - "testing" - - qt "github.com/frankban/quicktest" - "tailscale.com/types/logger" -) - -func TestRunChecks(t *testing.T) { - c := qt.New(t) - var ( - mu sync.Mutex - lines []string - ) - logf := func(format string, args ...any) { - mu.Lock() - defer mu.Unlock() - lines = append(lines, fmt.Sprintf(format, args...)) - } - - ctx := context.Background() - RunChecks(ctx, logf, - testCheck1{}, - CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { - log("check 2") - return nil - }), - ) - - mu.Lock() - defer mu.Unlock() - c.Assert(lines, qt.Contains, "testcheck1: check 1") - c.Assert(lines, qt.Contains, "testcheck2: check 2") -} - -type testCheck1 struct{} - -func (t testCheck1) Name() string { return "testcheck1" } -func (t testCheck1) Run(_ context.Context, log logger.Logf) error { - log("check 1") - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package doctor + +import ( + "context" + "fmt" + "sync" + "testing" + + qt "github.com/frankban/quicktest" + "tailscale.com/types/logger" +) + +func TestRunChecks(t *testing.T) { + c := qt.New(t) + var ( + mu sync.Mutex + lines []string + ) + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + lines = append(lines, fmt.Sprintf(format, args...)) + } + + ctx := context.Background() + RunChecks(ctx, logf, + testCheck1{}, + CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { + log("check 2") + return nil + }), + ) + + mu.Lock() + defer mu.Unlock() + c.Assert(lines, qt.Contains, "testcheck1: check 1") + c.Assert(lines, qt.Contains, "testcheck2: check 2") +} + +type testCheck1 struct{} + +func (t testCheck1) Name() string { return "testcheck1" } +func (t testCheck1) Run(_ context.Context, log logger.Logf) error { + log("check 1") + return nil +} diff --git a/doctor/permissions/permissions_bsd.go b/doctor/permissions/permissions_bsd.go index 8b034cfff1af3..4031af7221cd5 100644 --- a/doctor/permissions/permissions_bsd.go +++ b/doctor/permissions/permissions_bsd.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin || freebsd || openbsd - -package permissions - -import ( - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - groups, _ := unix.Getgroups() - logf("uid=%s euid=%s gid=%s egid=%s groups=%s", - formatUserID(unix.Getuid()), - formatUserID(unix.Geteuid()), - formatGroupID(unix.Getgid()), - formatGroupID(unix.Getegid()), - formatGroups(groups), - ) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin || freebsd || openbsd + +package permissions + +import ( + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + groups, _ := unix.Getgroups() + logf("uid=%s euid=%s gid=%s egid=%s groups=%s", + formatUserID(unix.Getuid()), + formatUserID(unix.Geteuid()), + formatGroupID(unix.Getgid()), + formatGroupID(unix.Getegid()), + formatGroups(groups), + ) + return nil +} diff --git a/doctor/permissions/permissions_linux.go b/doctor/permissions/permissions_linux.go index 12bb393d53383..ef0a97056f411 100644 --- a/doctor/permissions/permissions_linux.go +++ b/doctor/permissions/permissions_linux.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package permissions - -import ( - "fmt" - "strings" - "unsafe" - - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - // NOTE: getresuid and getresgid never fail unless passed an - // invalid address. - var ruid, euid, suid uint64 - unix.Syscall(unix.SYS_GETRESUID, - uintptr(unsafe.Pointer(&ruid)), - uintptr(unsafe.Pointer(&euid)), - uintptr(unsafe.Pointer(&suid)), - ) - - var rgid, egid, sgid uint64 - unix.Syscall(unix.SYS_GETRESGID, - uintptr(unsafe.Pointer(&rgid)), - uintptr(unsafe.Pointer(&egid)), - uintptr(unsafe.Pointer(&sgid)), - ) - - groups, _ := unix.Getgroups() - - var buf strings.Builder - fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", - formatUserID(ruid), formatUserID(euid), formatUserID(suid), - formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), - formatGroups(groups), - ) - - // Get process capabilities - var ( - capHeader = unix.CapUserHeader{ - Version: unix.LINUX_CAPABILITY_VERSION_3, - Pid: 0, // 0 means 'ourselves' - } - capData unix.CapUserData - ) - - if err := unix.Capget(&capHeader, &capData); err != nil { - fmt.Fprintf(&buf, " caperr=%v", err) - } else { - fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", - capData.Effective, capData.Permitted, capData.Inheritable, - ) - } - - logf("%s", buf.String()) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package permissions + +import ( + "fmt" + "strings" + "unsafe" + + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + // NOTE: getresuid and getresgid never fail unless passed an + // invalid address. + var ruid, euid, suid uint64 + unix.Syscall(unix.SYS_GETRESUID, + uintptr(unsafe.Pointer(&ruid)), + uintptr(unsafe.Pointer(&euid)), + uintptr(unsafe.Pointer(&suid)), + ) + + var rgid, egid, sgid uint64 + unix.Syscall(unix.SYS_GETRESGID, + uintptr(unsafe.Pointer(&rgid)), + uintptr(unsafe.Pointer(&egid)), + uintptr(unsafe.Pointer(&sgid)), + ) + + groups, _ := unix.Getgroups() + + var buf strings.Builder + fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", + formatUserID(ruid), formatUserID(euid), formatUserID(suid), + formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), + formatGroups(groups), + ) + + // Get process capabilities + var ( + capHeader = unix.CapUserHeader{ + Version: unix.LINUX_CAPABILITY_VERSION_3, + Pid: 0, // 0 means 'ourselves' + } + capData unix.CapUserData + ) + + if err := unix.Capget(&capHeader, &capData); err != nil { + fmt.Fprintf(&buf, " caperr=%v", err) + } else { + fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", + capData.Effective, capData.Permitted, capData.Inheritable, + ) + } + + logf("%s", buf.String()) + return nil +} diff --git a/doctor/permissions/permissions_other.go b/doctor/permissions/permissions_other.go index 7e6912b4928cf..5e310b98e361e 100644 --- a/doctor/permissions/permissions_other.go +++ b/doctor/permissions/permissions_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd) - -package permissions - -import ( - "runtime" - - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd) + +package permissions + +import ( + "runtime" + + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) + return nil +} diff --git a/doctor/permissions/permissions_test.go b/doctor/permissions/permissions_test.go index 941d406ef8318..9b71c3be1cfe3 100644 --- a/doctor/permissions/permissions_test.go +++ b/doctor/permissions/permissions_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package permissions - -import "testing" - -func TestPermissionsImpl(t *testing.T) { - if err := permissionsImpl(t.Logf); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package permissions + +import "testing" + +func TestPermissionsImpl(t *testing.T) { + if err := permissionsImpl(t.Logf); err != nil { + t.Error(err) + } +} diff --git a/doctor/routetable/routetable.go b/doctor/routetable/routetable.go index 76e4ef949b9af..1ebf294ce1474 100644 --- a/doctor/routetable/routetable.go +++ b/doctor/routetable/routetable.go @@ -1,34 +1,34 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package routetable provides a doctor.Check that dumps the current system's -// route table to the log. -package routetable - -import ( - "context" - - "tailscale.com/net/routetable" - "tailscale.com/types/logger" -) - -// MaxRoutes is the maximum number of routes that will be displayed. -const MaxRoutes = 1000 - -// Check implements the doctor.Check interface. -type Check struct{} - -func (Check) Name() string { - return "routetable" -} - -func (Check) Run(_ context.Context, logf logger.Logf) error { - rs, err := routetable.Get(MaxRoutes) - if err != nil { - return err - } - for _, r := range rs { - logf("%s", r) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package routetable provides a doctor.Check that dumps the current system's +// route table to the log. +package routetable + +import ( + "context" + + "tailscale.com/net/routetable" + "tailscale.com/types/logger" +) + +// MaxRoutes is the maximum number of routes that will be displayed. +const MaxRoutes = 1000 + +// Check implements the doctor.Check interface. +type Check struct{} + +func (Check) Name() string { + return "routetable" +} + +func (Check) Run(_ context.Context, logf logger.Logf) error { + rs, err := routetable.Get(MaxRoutes) + if err != nil { + return err + } + for _, r := range rs { + logf("%s", r) + } + return nil +} diff --git a/envknob/envknob_nottest.go b/envknob/envknob_nottest.go index 0dd900cc8104e..b21266f1377ca 100644 --- a/envknob/envknob_nottest.go +++ b/envknob/envknob_nottest.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_not_in_tests - -package envknob - -import "runtime" - -func GOOS() string { - // When the "ts_not_in_tests" build tag is used, we define this func to just - // return a simple constant so callers optimize just as if the knob were not - // present. We can then build production/optimized builds with the - // "ts_not_in_tests" build tag. - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_not_in_tests + +package envknob + +import "runtime" + +func GOOS() string { + // When the "ts_not_in_tests" build tag is used, we define this func to just + // return a simple constant so callers optimize just as if the knob were not + // present. We can then build production/optimized builds with the + // "ts_not_in_tests" build tag. + return runtime.GOOS +} diff --git a/envknob/envknob_testable.go b/envknob/envknob_testable.go index e7f038336c4f3..53687d732d493 100644 --- a/envknob/envknob_testable.go +++ b/envknob/envknob_testable.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ts_not_in_tests - -package envknob - -import "runtime" - -// GOOS reports the effective runtime.GOOS to run as. -// -// In practice this returns just runtime.GOOS, unless overridden by -// test TS_DEBUG_FAKE_GOOS. -// -// This allows changing OS-specific stuff like the IPN server behavior -// for tests so we can e.g. test Windows-specific behaviors on Linux. -// This isn't universally used. -func GOOS() string { - if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { - return v - } - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_not_in_tests + +package envknob + +import "runtime" + +// GOOS reports the effective runtime.GOOS to run as. +// +// In practice this returns just runtime.GOOS, unless overridden by +// test TS_DEBUG_FAKE_GOOS. +// +// This allows changing OS-specific stuff like the IPN server behavior +// for tests so we can e.g. test Windows-specific behaviors on Linux. +// This isn't universally used. +func GOOS() string { + if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { + return v + } + return runtime.GOOS +} diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go index 350384b8626e3..a7b0a05e8b1b8 100644 --- a/envknob/logknob/logknob.go +++ b/envknob/logknob/logknob.go @@ -1,85 +1,85 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package logknob provides a helpful wrapper that allows enabling logging -// based on either an envknob or other methods of enablement. -package logknob - -import ( - "sync/atomic" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/logger" - "tailscale.com/types/views" -) - -// TODO(andrew-d): should we have a package-global registry of logknobs? It -// would allow us to update from a netmap in a central location, which might be -// reason enough to do it... - -// LogKnob allows configuring verbose logging, with multiple ways to enable. It -// supports enabling logging via envknob, via atomic boolean (for use in e.g. -// c2n log level changes), and via capabilities from a NetMap (so users can -// enable logging via the ACL JSON). -type LogKnob struct { - capName tailcfg.NodeCapability - cap atomic.Bool - env func() bool - manual atomic.Bool -} - -// NewLogKnob creates a new LogKnob, with the provided environment variable -// name and/or NetMap capability. -func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { - if env == "" && cap == "" { - panic("must provide either an environment variable or capability") - } - - lk := &LogKnob{ - capName: cap, - } - if env != "" { - lk.env = envknob.RegisterBool(env) - } else { - lk.env = func() bool { return false } - } - return lk -} - -// Set will cause logs to be printed when called with Set(true). When called -// with Set(false), logs will not be printed due to an earlier call of -// Set(true), but may be printed due to either the envknob and/or capability of -// this LogKnob. -func (lk *LogKnob) Set(v bool) { - lk.manual.Store(v) -} - -// NetMap is an interface for the parts of netmap.NetworkMap that we care -// about; we use this rather than a concrete type to avoid a circular -// dependency. -type NetMap interface { - SelfCapabilities() views.Slice[tailcfg.NodeCapability] -} - -// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap -// contains the capability provided for this LogKnob. -func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { - if lk.capName == "" { - return - } - - lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) -} - -// Do will call log with the provided format and arguments if any of the -// configured methods for enabling logging are true. -func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { - if lk.shouldLog() { - log(format, args...) - } -} - -func (lk *LogKnob) shouldLog() bool { - return lk.manual.Load() || lk.env() || lk.cap.Load() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package logknob provides a helpful wrapper that allows enabling logging +// based on either an envknob or other methods of enablement. +package logknob + +import ( + "sync/atomic" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/views" +) + +// TODO(andrew-d): should we have a package-global registry of logknobs? It +// would allow us to update from a netmap in a central location, which might be +// reason enough to do it... + +// LogKnob allows configuring verbose logging, with multiple ways to enable. It +// supports enabling logging via envknob, via atomic boolean (for use in e.g. +// c2n log level changes), and via capabilities from a NetMap (so users can +// enable logging via the ACL JSON). +type LogKnob struct { + capName tailcfg.NodeCapability + cap atomic.Bool + env func() bool + manual atomic.Bool +} + +// NewLogKnob creates a new LogKnob, with the provided environment variable +// name and/or NetMap capability. +func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { + if env == "" && cap == "" { + panic("must provide either an environment variable or capability") + } + + lk := &LogKnob{ + capName: cap, + } + if env != "" { + lk.env = envknob.RegisterBool(env) + } else { + lk.env = func() bool { return false } + } + return lk +} + +// Set will cause logs to be printed when called with Set(true). When called +// with Set(false), logs will not be printed due to an earlier call of +// Set(true), but may be printed due to either the envknob and/or capability of +// this LogKnob. +func (lk *LogKnob) Set(v bool) { + lk.manual.Store(v) +} + +// NetMap is an interface for the parts of netmap.NetworkMap that we care +// about; we use this rather than a concrete type to avoid a circular +// dependency. +type NetMap interface { + SelfCapabilities() views.Slice[tailcfg.NodeCapability] +} + +// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap +// contains the capability provided for this LogKnob. +func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { + if lk.capName == "" { + return + } + + lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) +} + +// Do will call log with the provided format and arguments if any of the +// configured methods for enabling logging are true. +func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { + if lk.shouldLog() { + log(format, args...) + } +} + +func (lk *LogKnob) shouldLog() bool { + return lk.manual.Load() || lk.env() || lk.cap.Load() +} diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go index b2a376a25b371..c9eed5612379a 100644 --- a/envknob/logknob/logknob_test.go +++ b/envknob/logknob/logknob_test.go @@ -1,102 +1,102 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logknob - -import ( - "bytes" - "fmt" - "testing" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -var testKnob = NewLogKnob( - "TS_TEST_LOGKNOB", - "https://tailscale.com/cap/testing", -) - -// Static type assertion for our interface type. -var _ NetMap = &netmap.NetworkMap{} - -func TestLogKnob(t *testing.T) { - t.Run("Default", func(t *testing.T) { - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - assertNoLogs(t) - }) - t.Run("Manual", func(t *testing.T) { - t.Cleanup(func() { testKnob.Set(false) }) - - assertNoLogs(t) - testKnob.Set(true) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("Env", func(t *testing.T) { - t.Cleanup(func() { - envknob.Setenv("TS_TEST_LOGKNOB", "") - }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - envknob.Setenv("TS_TEST_LOGKNOB", "true") - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("NetMap", func(t *testing.T) { - t.Cleanup(func() { testKnob.cap.Store(false) }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/testing", - }, - }).View(), - }) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) -} - -func assertLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - const want = "hello world" - if got := buf.String(); got != want { - t.Errorf("got %q, want %q", got, want) - } -} - -func assertNoLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - if got := buf.String(); got != "" { - t.Errorf("expected no logs, but got: %q", got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logknob + +import ( + "bytes" + "fmt" + "testing" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +var testKnob = NewLogKnob( + "TS_TEST_LOGKNOB", + "https://tailscale.com/cap/testing", +) + +// Static type assertion for our interface type. +var _ NetMap = &netmap.NetworkMap{} + +func TestLogKnob(t *testing.T) { + t.Run("Default", func(t *testing.T) { + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + assertNoLogs(t) + }) + t.Run("Manual", func(t *testing.T) { + t.Cleanup(func() { testKnob.Set(false) }) + + assertNoLogs(t) + testKnob.Set(true) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("Env", func(t *testing.T) { + t.Cleanup(func() { + envknob.Setenv("TS_TEST_LOGKNOB", "") + }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + envknob.Setenv("TS_TEST_LOGKNOB", "true") + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("NetMap", func(t *testing.T) { + t.Cleanup(func() { testKnob.cap.Store(false) }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + testKnob.UpdateFromNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Capabilities: []tailcfg.NodeCapability{ + "https://tailscale.com/cap/testing", + }, + }).View(), + }) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) +} + +func assertLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + const want = "hello world" + if got := buf.String(); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func assertNoLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + if got := buf.String(); got != "" { + t.Errorf("expected no logs, but got: %q", got) + } +} diff --git a/gomod_test.go b/gomod_test.go index f984b5d6f27a5..52fdd463910c4 100644 --- a/gomod_test.go +++ b/gomod_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscaleroot - -import ( - "os" - "testing" - - "golang.org/x/mod/modfile" -) - -func TestGoMod(t *testing.T) { - goMod, err := os.ReadFile("go.mod") - if err != nil { - t.Fatal(err) - } - f, err := modfile.Parse("go.mod", goMod, nil) - if err != nil { - t.Fatal(err) - } - if len(f.Replace) > 0 { - t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "os" + "testing" + + "golang.org/x/mod/modfile" +) + +func TestGoMod(t *testing.T) { + goMod, err := os.ReadFile("go.mod") + if err != nil { + t.Fatal(err) + } + f, err := modfile.Parse("go.mod", goMod, nil) + if err != nil { + t.Fatal(err) + } + if len(f.Replace) > 0 { + t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) + } +} diff --git a/hostinfo/hostinfo_darwin.go b/hostinfo/hostinfo_darwin.go index 0b1774e7712d7..a61d95b32c907 100644 --- a/hostinfo/hostinfo_darwin.go +++ b/hostinfo/hostinfo_darwin.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package hostinfo - -import ( - "os" - "path/filepath" -) - -func init() { - packageType = packageTypeDarwin -} - -func packageTypeDarwin() string { - // Using tailscaled or IPNExtension? - exe, _ := os.Executable() - return filepath.Base(exe) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package hostinfo + +import ( + "os" + "path/filepath" +) + +func init() { + packageType = packageTypeDarwin +} + +func packageTypeDarwin() string { + // Using tailscaled or IPNExtension? + exe, _ := os.Executable() + return filepath.Base(exe) +} diff --git a/hostinfo/hostinfo_freebsd.go b/hostinfo/hostinfo_freebsd.go index 3661b13229ac5..15c7783aa4e4c 100644 --- a/hostinfo/hostinfo_freebsd.go +++ b/hostinfo/hostinfo_freebsd.go @@ -1,64 +1,64 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package hostinfo - -import ( - "bytes" - "os" - "os/exec" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" - "tailscale.com/version/distro" -) - -func init() { - osVersion = lazyOSVersion.Get - distroName = distroNameFreeBSD - distroVersion = distroVersionFreeBSD -} - -var ( - lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} -) - -func distroNameFreeBSD() string { - return lazyVersionMeta.Get().DistroName -} - -func distroVersionFreeBSD() string { - return lazyVersionMeta.Get().DistroVersion -} - -type versionMeta struct { - DistroName string - DistroVersion string - DistroCodeName string -} - -func osVersionFreeBSD() string { - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Release[:]) -} - -func freebsdVersionMeta() (meta versionMeta) { - d := distro.Get() - meta.DistroName = string(d) - switch d { - case distro.Pfsense: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.OPNsense: - b, _ := exec.Command("opnsense-version").Output() - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.TrueNAS: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - } - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package hostinfo + +import ( + "bytes" + "os" + "os/exec" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" + "tailscale.com/version/distro" +) + +func init() { + osVersion = lazyOSVersion.Get + distroName = distroNameFreeBSD + distroVersion = distroVersionFreeBSD +} + +var ( + lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} + lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} +) + +func distroNameFreeBSD() string { + return lazyVersionMeta.Get().DistroName +} + +func distroVersionFreeBSD() string { + return lazyVersionMeta.Get().DistroVersion +} + +type versionMeta struct { + DistroName string + DistroVersion string + DistroCodeName string +} + +func osVersionFreeBSD() string { + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Release[:]) +} + +func freebsdVersionMeta() (meta versionMeta) { + d := distro.Get() + meta.DistroName = string(d) + switch d { + case distro.Pfsense: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.OPNsense: + b, _ := exec.Command("opnsense-version").Output() + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.TrueNAS: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + } + return +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 9fe32e0449be1..76282ebf56733 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestNew(t *testing.T) { - hi := New() - if hi == nil { - t.Fatal("no Hostinfo") - } - j, err := json.MarshalIndent(hi, " ", "") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %s", j) -} - -func TestOSVersion(t *testing.T) { - if osVersion == nil { - t.Skip("not available for OS") - } - t.Logf("Got: %#q", osVersion()) -} - -func TestEtcAptSourceFileIsDisabled(t *testing.T) { - tests := []struct { - name string - in string - want bool - }{ - {"empty", "", false}, - {"normal", "deb foo\n", false}, - {"normal-commented", "# deb foo\n", false}, - {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, - {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestNew(t *testing.T) { + hi := New() + if hi == nil { + t.Fatal("no Hostinfo") + } + j, err := json.MarshalIndent(hi, " ", "") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %s", j) +} + +func TestOSVersion(t *testing.T) { + if osVersion == nil { + t.Skip("not available for OS") + } + t.Logf("Got: %#q", osVersion()) +} + +func TestEtcAptSourceFileIsDisabled(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + {"empty", "", false}, + {"normal", "deb foo\n", false}, + {"normal-commented", "# deb foo\n", false}, + {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, + {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/hostinfo/hostinfo_uname.go b/hostinfo/hostinfo_uname.go index 32b733a03bcb3..10995c1c78652 100644 --- a/hostinfo/hostinfo_uname.go +++ b/hostinfo/hostinfo_uname.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd || darwin - -package hostinfo - -import ( - "runtime" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" -) - -func init() { - unameMachine = lazyUnameMachine.Get -} - -var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} - -func unameMachineUnix() string { - switch runtime.GOOS { - case "android": - // Don't call on Android for now. We're late in the 1.36 release cycle - // and don't want to test syscall filters on various Android versions to - // see what's permitted. Notably, the hostinfo_linux.go file has build - // tag !android, so maybe Uname is verboten. - return "" - case "ios": - // For similar reasons, don't call on iOS. There aren't many iOS devices - // and we know their CPU properties so calling this is only risk and no - // reward. - return "" - } - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Machine[:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd || darwin + +package hostinfo + +import ( + "runtime" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" +) + +func init() { + unameMachine = lazyUnameMachine.Get +} + +var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} + +func unameMachineUnix() string { + switch runtime.GOOS { + case "android": + // Don't call on Android for now. We're late in the 1.36 release cycle + // and don't want to test syscall filters on various Android versions to + // see what's permitted. Notably, the hostinfo_linux.go file has build + // tag !android, so maybe Uname is verboten. + return "" + case "ios": + // For similar reasons, don't call on iOS. There aren't many iOS devices + // and we know their CPU properties so calling this is only risk and no + // reward. + return "" + } + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Machine[:]) +} diff --git a/hostinfo/wol.go b/hostinfo/wol.go index 3a30af2fe3a37..b6fc81a8b2482 100644 --- a/hostinfo/wol.go +++ b/hostinfo/wol.go @@ -1,106 +1,106 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "log" - "net" - "runtime" - "strings" - "unicode" - - "tailscale.com/envknob" -) - -// TODO(bradfitz): this is all too simplistic and static. It needs to run -// continuously in response to netmon events (USB ethernet adapaters might get -// plugged in) and look for the media type/status/etc. Right now on macOS it -// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't -// have any media. We should only report the one that's actually connected. -// But it works for now (2023-10-05) for fleshing out the rest. - -var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 - -// getWoLMACs returns up to 10 MAC address of the local machine to send -// wake-on-LAN packets to in order to wake it up. The returned MACs are in -// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). -// -// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS -// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is -// set to a MAC address, that sole MAC address is returned. -func getWoLMACs() (macs []string) { - switch runtime.GOOS { - case "ios", "android": - return nil - } - if s := wakeMAC(); s != "" { - switch s { - case "auto": - ifs, _ := net.Interfaces() - for _, iface := range ifs { - if iface.Flags&net.FlagLoopback != 0 { - continue - } - if iface.Flags&net.FlagBroadcast == 0 || - iface.Flags&net.FlagRunning == 0 || - iface.Flags&net.FlagUp == 0 { - continue - } - if keepMAC(iface.Name, iface.HardwareAddr) { - macs = append(macs, iface.HardwareAddr.String()) - } - if len(macs) == 10 { - break - } - } - return macs - case "false", "off": // fast path before ParseMAC error - return nil - } - mac, err := net.ParseMAC(s) - if err != nil { - log.Printf("invalid MAC %q", s) - return nil - } - return []string{mac.String()} - } - return nil -} - -var ignoreWakeOUI = map[[3]byte]bool{ - {0x00, 0x15, 0x5d}: true, // Hyper-V - {0x00, 0x50, 0x56}: true, // VMware - {0x00, 0x1c, 0x14}: true, // VMware - {0x00, 0x05, 0x69}: true, // VMware - {0x00, 0x0c, 0x29}: true, // VMware - {0x00, 0x1c, 0x42}: true, // Parallels - {0x08, 0x00, 0x27}: true, // VirtualBox - {0x00, 0x21, 0xf6}: true, // VirtualBox - {0x00, 0x14, 0x4f}: true, // VirtualBox - {0x00, 0x0f, 0x4b}: true, // VirtualBox - {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant -} - -func keepMAC(ifName string, mac []byte) bool { - if len(mac) != 6 { - return false - } - base := strings.TrimRightFunc(ifName, unicode.IsNumber) - switch runtime.GOOS { - case "darwin": - switch base { - case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": - return false - } - } - if mac[0] == 0x02 && mac[1] == 0x42 { - // Docker container. - return false - } - oui := [3]byte{mac[0], mac[1], mac[2]} - if ignoreWakeOUI[oui] { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "log" + "net" + "runtime" + "strings" + "unicode" + + "tailscale.com/envknob" +) + +// TODO(bradfitz): this is all too simplistic and static. It needs to run +// continuously in response to netmon events (USB ethernet adapaters might get +// plugged in) and look for the media type/status/etc. Right now on macOS it +// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't +// have any media. We should only report the one that's actually connected. +// But it works for now (2023-10-05) for fleshing out the rest. + +var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 + +// getWoLMACs returns up to 10 MAC address of the local machine to send +// wake-on-LAN packets to in order to wake it up. The returned MACs are in +// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). +// +// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS +// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is +// set to a MAC address, that sole MAC address is returned. +func getWoLMACs() (macs []string) { + switch runtime.GOOS { + case "ios", "android": + return nil + } + if s := wakeMAC(); s != "" { + switch s { + case "auto": + ifs, _ := net.Interfaces() + for _, iface := range ifs { + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagBroadcast == 0 || + iface.Flags&net.FlagRunning == 0 || + iface.Flags&net.FlagUp == 0 { + continue + } + if keepMAC(iface.Name, iface.HardwareAddr) { + macs = append(macs, iface.HardwareAddr.String()) + } + if len(macs) == 10 { + break + } + } + return macs + case "false", "off": // fast path before ParseMAC error + return nil + } + mac, err := net.ParseMAC(s) + if err != nil { + log.Printf("invalid MAC %q", s) + return nil + } + return []string{mac.String()} + } + return nil +} + +var ignoreWakeOUI = map[[3]byte]bool{ + {0x00, 0x15, 0x5d}: true, // Hyper-V + {0x00, 0x50, 0x56}: true, // VMware + {0x00, 0x1c, 0x14}: true, // VMware + {0x00, 0x05, 0x69}: true, // VMware + {0x00, 0x0c, 0x29}: true, // VMware + {0x00, 0x1c, 0x42}: true, // Parallels + {0x08, 0x00, 0x27}: true, // VirtualBox + {0x00, 0x21, 0xf6}: true, // VirtualBox + {0x00, 0x14, 0x4f}: true, // VirtualBox + {0x00, 0x0f, 0x4b}: true, // VirtualBox + {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant +} + +func keepMAC(ifName string, mac []byte) bool { + if len(mac) != 6 { + return false + } + base := strings.TrimRightFunc(ifName, unicode.IsNumber) + switch runtime.GOOS { + case "darwin": + switch base { + case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": + return false + } + } + if mac[0] == 0x02 && mac[1] == 0x42 { + // Docker container. + return false + } + oui := [3]byte{mac[0], mac[1], mac[2]} + if ignoreWakeOUI[oui] { + return false + } + return true +} diff --git a/ipn/ipnlocal/breaktcp_darwin.go b/ipn/ipnlocal/breaktcp_darwin.go index 13566198ce9fc..289e760e194a4 100644 --- a/ipn/ipnlocal/breaktcp_darwin.go +++ b/ipn/ipnlocal/breaktcp_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsDarwin -} - -func breakTCPConnsDarwin() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsDarwin +} + +func breakTCPConnsDarwin() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/breaktcp_linux.go b/ipn/ipnlocal/breaktcp_linux.go index b82f6521246f0..d078103cf5388 100644 --- a/ipn/ipnlocal/breaktcp_linux.go +++ b/ipn/ipnlocal/breaktcp_linux.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsLinux -} - -func breakTCPConnsLinux() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsLinux +} + +func breakTCPConnsLinux() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index af1aa337bbe0c..efc18133f556d 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -1,301 +1,301 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "fmt" - "reflect" - "strings" - "testing" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/types/key" - "tailscale.com/types/netmap" -) - -func TestFlagExpiredPeers(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1673373129, 0) - - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - - timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) - if now.Before(timeBeforeEpoch) { - panic("current time in test cannot be before epoch") - } - - var expiredKey key.NodePublic - if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { - panic(err) - } - - tests := []struct { - name string - controlTime *time.Time - netmap *netmap.NetworkMap - want []tailcfg.NodeView - }{ - { - name: "no_expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - { - name: "expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast, func(n *tailcfg.Node) { - n.Expired = true - n.Key = expiredKey - }), - }), - }, - { - name: "bad_ControlTime", - // controlTime here is intentionally before our hardcoded epoch - controlTime: &timeBeforeEpoch, - - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch - }), - }, - { - name: "tagged_node", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // tagged node; zero expiry - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // not expired - }), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - if tt.controlTime != nil { - em.onControlTime(*tt.controlTime) - } - em.flagExpiredPeers(tt.netmap, now) - if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { - t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) - } - }) - } -} - -func TestNextPeerExpiry(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1675725516, 0) - - noExpiry := time.Time{} - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - timeInMoreFuture := now.Add(2 * time.Hour) - - tests := []struct { - name string - netmap *netmap.NetworkMap - want time.Time - }{ - { - name: "no_expiry", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: noExpiry, - }, - { - name: "future_expiry_from_peer", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", timeInFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_multiple_peers", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInMoreFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_peer_and_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInMoreFuture), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "only_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{}), - SelfNode: n(1, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "peer_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "self_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: timeInFuture, - }, - { - name: "all_nodes_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: noExpiry, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - got := em.nextPeerExpiry(tt.netmap, now) - if !got.Equal(tt.want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) - } else if !got.IsZero() && got.Before(now) { - t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) - } - }) - } - - t.Run("ClockSkew", func(t *testing.T) { - t.Logf("local time: %q", now.Format(time.RFC3339)) - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - - // The local clock is "running fast"; our clock skew is -2h - em.clockDelta.Store(-2 * time.Hour) - t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) - - // If we don't adjust for the local time, this would return a - // time in the past. - nm := &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - } - got := em.nextPeerExpiry(nm, now) - want := now.Add(30 * time.Second) - if !got.Equal(want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) - } - }) -} - -func formatNodes(nodes []tailcfg.NodeView) string { - var sb strings.Builder - for i, n := range nodes { - if i > 0 { - sb.WriteString(", ") - } - fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) - } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) - } - if n.Key() != (key.NodePublic{}) { - fmt.Fprintf(&sb, ", key=%v", n.Key().String()) - } - if n.Expired() { - fmt.Fprintf(&sb, ", expired=true") - } - sb.WriteString(")") - } - return sb.String() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/netmap" +) + +func TestFlagExpiredPeers(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1673373129, 0) + + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + + timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) + if now.Before(timeBeforeEpoch) { + panic("current time in test cannot be before epoch") + } + + var expiredKey key.NodePublic + if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { + panic(err) + } + + tests := []struct { + name string + controlTime *time.Time + netmap *netmap.NetworkMap + want []tailcfg.NodeView + }{ + { + name: "no_expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + { + name: "expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast, func(n *tailcfg.Node) { + n.Expired = true + n.Key = expiredKey + }), + }), + }, + { + name: "bad_ControlTime", + // controlTime here is intentionally before our hardcoded epoch + controlTime: &timeBeforeEpoch, + + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch + }), + }, + { + name: "tagged_node", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // tagged node; zero expiry + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // not expired + }), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + if tt.controlTime != nil { + em.onControlTime(*tt.controlTime) + } + em.flagExpiredPeers(tt.netmap, now) + if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { + t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) + } + }) + } +} + +func TestNextPeerExpiry(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1675725516, 0) + + noExpiry := time.Time{} + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + timeInMoreFuture := now.Add(2 * time.Hour) + + tests := []struct { + name string + netmap *netmap.NetworkMap + want time.Time + }{ + { + name: "no_expiry", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: noExpiry, + }, + { + name: "future_expiry_from_peer", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", timeInFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_multiple_peers", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInMoreFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_peer_and_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInMoreFuture), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "only_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{}), + SelfNode: n(1, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "peer_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "self_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: timeInFuture, + }, + { + name: "all_nodes_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: noExpiry, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + got := em.nextPeerExpiry(tt.netmap, now) + if !got.Equal(tt.want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) + } else if !got.IsZero() && got.Before(now) { + t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) + } + }) + } + + t.Run("ClockSkew", func(t *testing.T) { + t.Logf("local time: %q", now.Format(time.RFC3339)) + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + + // The local clock is "running fast"; our clock skew is -2h + em.clockDelta.Store(-2 * time.Hour) + t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) + + // If we don't adjust for the local time, this would return a + // time in the past. + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + } + got := em.nextPeerExpiry(nm, now) + want := now.Add(30 * time.Second) + if !got.Equal(want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) + } + }) +} + +func formatNodes(nodes []tailcfg.NodeView) string { + var sb strings.Builder + for i, n := range nodes { + if i > 0 { + sb.WriteString(", ") + } + fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) + + if n.Online() != nil { + fmt.Fprintf(&sb, ", online=%v", *n.Online()) + } + if n.LastSeen() != nil { + fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + } + if n.Key() != (key.NodePublic{}) { + fmt.Fprintf(&sb, ", key=%v", n.Key().String()) + } + if n.Expired() { + fmt.Fprintf(&sb, ", expired=true") + } + sb.WriteString(")") + } + return sb.String() +} diff --git a/ipn/ipnlocal/peerapi_h2c.go b/ipn/ipnlocal/peerapi_h2c.go index fbfa8639808ae..e6335fe2be5b6 100644 --- a/ipn/ipnlocal/peerapi_h2c.go +++ b/ipn/ipnlocal/peerapi_h2c.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -package ipnlocal - -import ( - "net/http" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" -) - -func init() { - addH2C = func(s *http.Server) { - h2s := &http2.Server{} - s.Handler = h2c.NewHandler(s.Handler, h2s) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +package ipnlocal + +import ( + "net/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func init() { + addH2C = func(s *http.Server) { + h2s := &http2.Server{} + s.Handler = h2c.NewHandler(s.Handler, h2s) + } +} diff --git a/ipn/ipnlocal/testdata/example.com-key.pem b/ipn/ipnlocal/testdata/example.com-key.pem index 06902f4c9c314..9020553f1829b 100644 --- a/ipn/ipnlocal/testdata/example.com-key.pem +++ b/ipn/ipnlocal/testdata/example.com-key.pem @@ -1,28 +1,28 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE -QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX -xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn -ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 -Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg -w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE -K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W -7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq -0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D -Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf -TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l -kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S -PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO -+W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD -Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he -1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m -m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq -FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox -KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf -TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J -Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u -tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL -knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme -uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF -ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP -c/bJnTXEEVsWjLZTwYaq0Us= +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE +QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX +xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn +ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 +Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg +w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE +K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W +7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq +0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D +Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf +TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l +kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S +PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO ++W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD +Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he +1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m +m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq +FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox +KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf +TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J +Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u +tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL +knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme +uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF +ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP +c/bJnTXEEVsWjLZTwYaq0Us= -----END PRIVATE KEY----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/example.com.pem b/ipn/ipnlocal/testdata/example.com.pem index 588850813b102..65e7110a8d1ae 100644 --- a/ipn/ipnlocal/testdata/example.com.pem +++ b/ipn/ipnlocal/testdata/example.com.pem @@ -1,26 +1,26 @@ ------BEGIN CERTIFICATE----- -MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 -WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 -BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv -bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr -JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV -YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL -pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G -M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp -MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN -NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr -BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE -DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB -CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv -bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M -NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN -deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 -ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA -xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW -pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI -/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo +-----BEGIN CERTIFICATE----- +MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 +WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 +BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv +bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr +JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV +YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL +pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G +M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp +MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN +NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr +BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE +DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB +CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv +bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M +NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN +deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 +ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA +xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW +pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI +/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/rootCA.pem b/ipn/ipnlocal/testdata/rootCA.pem index 88a16f47a8ac9..28bd25467f07f 100644 --- a/ipn/ipnlocal/testdata/rootCA.pem +++ b/ipn/ipnlocal/testdata/rootCA.pem @@ -1,30 +1,30 @@ ------BEGIN CERTIFICATE----- -MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 -WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm -cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp -MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj -aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC -ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk -Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv -/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl -AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB -ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF -zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk -cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA -RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s -ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB -ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD -ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW -ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 -ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh -8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi -Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP -YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K -va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se -vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I -MHctBg== +-----BEGIN CERTIFICATE----- +MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 +WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm +cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp +MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj +aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC +ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk +Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv +/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl +AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB +ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF +zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk +cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA +RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s +ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB +ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD +ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW +ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 +ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh +8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi +Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP +YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K +va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se +vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I +MHctBg== -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnserver/proxyconnect_js.go b/ipn/ipnserver/proxyconnect_js.go index 368221e2269c8..27448fa0dcce6 100644 --- a/ipn/ipnserver/proxyconnect_js.go +++ b/ipn/ipnserver/proxyconnect_js.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import "net/http" - -func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { - panic("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import "net/http" + +func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { + panic("unreachable") +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index b7d5ea144c408..49fb4d01f3ae0 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import ( - "context" - "sync" - "testing" -) - -func TestWaiterSet(t *testing.T) { - var s waiterSet - - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) - } - } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) - - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") - - select { - case <-ready: - t.Fatal("should not be ready") - default: - } - s.wakeAll() - <-ready - - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") - - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/localapi/disabled_stubs.go b/ipn/localapi/disabled_stubs.go index c744f34d5f5c5..230553c145840 100644 --- a/ipn/localapi/disabled_stubs.go +++ b/ipn/localapi/disabled_stubs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios || android || js - -package localapi - -import ( - "net/http" - "runtime" -) - -func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { - http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios || android || js + +package localapi + +import ( + "net/http" + "runtime" +) + +func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { + http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) +} diff --git a/ipn/localapi/pprof.go b/ipn/localapi/pprof.go index 8c9429b31385a..5cc4daca1cf39 100644 --- a/ipn/localapi/pprof.go +++ b/ipn/localapi/pprof.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -// We don't include it on mobile where we're more memory constrained and -// there's no CLI to get at the results anyway. - -package localapi - -import ( - "net/http" - "net/http/pprof" -) - -func init() { - servePprofFunc = servePprof -} - -func servePprof(w http.ResponseWriter, r *http.Request) { - name := r.FormValue("name") - switch name { - case "profile": - pprof.Profile(w, r) - default: - pprof.Handler(name).ServeHTTP(w, r) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +// We don't include it on mobile where we're more memory constrained and +// there's no CLI to get at the results anyway. + +package localapi + +import ( + "net/http" + "net/http/pprof" +) + +func init() { + servePprofFunc = servePprof +} + +func servePprof(w http.ResponseWriter, r *http.Request) { + name := r.FormValue("name") + switch name { + case "profile": + pprof.Profile(w, r) + default: + pprof.Handler(name).ServeHTTP(w, r) + } +} diff --git a/ipn/policy/policy.go b/ipn/policy/policy.go index 494a0dc408819..834706f31a389 100644 --- a/ipn/policy/policy.go +++ b/ipn/policy/policy.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains various policy decisions that need to be -// shared between the node client & control server. -package policy - -import ( - "tailscale.com/tailcfg" -) - -// IsInterestingService reports whether service s on the given operating -// system (a version.OS value) is an interesting enough port to report -// to our peer nodes for discovery purposes. -func IsInterestingService(s tailcfg.Service, os string) bool { - switch s.Proto { - case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: - return true - } - if s.Proto != tailcfg.TCP { - return false - } - if os != "windows" { - // For non-Windows machines, assume all TCP listeners - // are interesting enough. We don't see listener spam - // there. - return true - } - // Windows has tons of TCP listeners. We need to move to a denylist - // model later, but for now we just allow some common ones: - switch s.Port { - case 22, // ssh - 80, // http - 443, // https (but no hostname, so little useless) - 3389, // rdp - 5900, // vnc - 32400, // plex - - // And now some arbitrary HTTP dev server ports: - // Eventually we'll remove this and make all ports - // work, once we nicely filter away noisy system - // ports. - 8000, 8080, 8443, 8888: - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains various policy decisions that need to be +// shared between the node client & control server. +package policy + +import ( + "tailscale.com/tailcfg" +) + +// IsInterestingService reports whether service s on the given operating +// system (a version.OS value) is an interesting enough port to report +// to our peer nodes for discovery purposes. +func IsInterestingService(s tailcfg.Service, os string) bool { + switch s.Proto { + case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: + return true + } + if s.Proto != tailcfg.TCP { + return false + } + if os != "windows" { + // For non-Windows machines, assume all TCP listeners + // are interesting enough. We don't see listener spam + // there. + return true + } + // Windows has tons of TCP listeners. We need to move to a denylist + // model later, but for now we just allow some common ones: + switch s.Port { + case 22, // ssh + 80, // http + 443, // https (but no hostname, so little useless) + 3389, // rdp + 5900, // vnc + 32400, // plex + + // And now some arbitrary HTTP dev server ports: + // Eventually we'll remove this and make all ports + // work, once we nicely filter away noisy system + // ports. + 8000, 8080, 8443, 8888: + return true + } + return false +} diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 0fb78d45a6a53..84059af67c57d 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !ts_omit_aws - -// Package awsstore contains an ipn.StateStore implementation using AWS SSM. -package awsstore - -import ( - "context" - "errors" - "fmt" - "regexp" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/types/logger" -) - -const ( - parameterNameRxStr = `^parameter(/.*)` -) - -var parameterNameRx = regexp.MustCompile(parameterNameRxStr) - -// awsSSMClient is an interface allowing us to mock the couple of -// API calls we are leveraging with the AWSStore provider -type awsSSMClient interface { - GetParameter(ctx context.Context, - params *ssm.GetParameterInput, - optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) - - PutParameter(ctx context.Context, - params *ssm.PutParameterInput, - optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) -} - -// store is a store which leverages AWS SSM parameter store -// to persist the state -type awsStore struct { - ssmClient awsSSMClient - ssmARN arn.ARN - - memory mem.Store -} - -// New returns a new ipn.StateStore using the AWS SSM storage -// location given by ssmARN. -// -// Note that we store the entire store in a single parameter -// key, therefore if the state is above 8kb, it can cause -// Tailscaled to only only store new state in-memory and -// restarting Tailscaled can fail until you delete your state -// from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) -} - -// newStore is NewStore, but for tests. If client is non-nil, it's -// used instead of making one. -func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { - s := &awsStore{ - ssmClient: client, - } - - var err error - - // Parse the ARN - if s.ssmARN, err = arn.Parse(ssmARN); err != nil { - return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) - } - - // Validate the ARN corresponds to the SSM service - if s.ssmARN.Service != "ssm" { - return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) - } - - // Validate the ARN corresponds to a parameter store resource - if !parameterNameRx.MatchString(s.ssmARN.Resource) { - return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) - } - - if s.ssmClient == nil { - var cfg aws.Config - if cfg, err = config.LoadDefaultConfig( - context.TODO(), - config.WithRegion(s.ssmARN.Region), - ); err != nil { - return nil, err - } - s.ssmClient = ssm.NewFromConfig(cfg) - } - - // Hydrate cache with the potentially current state - if err := s.LoadState(); err != nil { - return nil, err - } - return s, nil - -} - -// LoadState attempts to read the state from AWS SSM parameter store key. -func (s *awsStore) LoadState() error { - param, err := s.ssmClient.GetParameter( - context.TODO(), - &ssm.GetParameterInput{ - Name: aws.String(s.ParameterName()), - WithDecryption: aws.Bool(true), - }, - ) - - if err != nil { - var pnf *ssmTypes.ParameterNotFound - if errors.As(err, &pnf) { - // Create the parameter as it does not exist yet - // and return directly as it is defacto empty - return s.persistState() - } - return err - } - - // Load the content in-memory - return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) -} - -// ParameterName returns the parameter name extracted from -// the provided ARN -func (s *awsStore) ParameterName() (name string) { - values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) - if len(values) == 2 { - name = values[1] - } - return -} - -// String returns the awsStore and the ARN of the SSM parameter store -// configured to store the state -func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } - -// ReadState implements the Store interface. -func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { - return s.memory.ReadState(id) -} - -// WriteState implements the Store interface. -func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { - // Write the state in-memory - if err = s.memory.WriteState(id, bs); err != nil { - return - } - - // Persist the state in AWS SSM parameter store - return s.persistState() -} - -// PersistState saves the states into the AWS SSM parameter store -func (s *awsStore) persistState() error { - // Generate JSON from in-memory cache - bs, err := s.memory.ExportToJSON() - if err != nil { - return err - } - - // Store in AWS SSM parameter store. - // - // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering - // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering - // doubling the capacity to 8kb per the following docs: - // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_aws + +// Package awsstore contains an ipn.StateStore implementation using AWS SSM. +package awsstore + +import ( + "context" + "errors" + "fmt" + "regexp" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logger" +) + +const ( + parameterNameRxStr = `^parameter(/.*)` +) + +var parameterNameRx = regexp.MustCompile(parameterNameRxStr) + +// awsSSMClient is an interface allowing us to mock the couple of +// API calls we are leveraging with the AWSStore provider +type awsSSMClient interface { + GetParameter(ctx context.Context, + params *ssm.GetParameterInput, + optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) + + PutParameter(ctx context.Context, + params *ssm.PutParameterInput, + optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) +} + +// store is a store which leverages AWS SSM parameter store +// to persist the state +type awsStore struct { + ssmClient awsSSMClient + ssmARN arn.ARN + + memory mem.Store +} + +// New returns a new ipn.StateStore using the AWS SSM storage +// location given by ssmARN. +// +// Note that we store the entire store in a single parameter +// key, therefore if the state is above 8kb, it can cause +// Tailscaled to only only store new state in-memory and +// restarting Tailscaled can fail until you delete your state +// from the AWS Parameter Store. +func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { + return newStore(ssmARN, nil) +} + +// newStore is NewStore, but for tests. If client is non-nil, it's +// used instead of making one. +func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { + s := &awsStore{ + ssmClient: client, + } + + var err error + + // Parse the ARN + if s.ssmARN, err = arn.Parse(ssmARN); err != nil { + return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) + } + + // Validate the ARN corresponds to the SSM service + if s.ssmARN.Service != "ssm" { + return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) + } + + // Validate the ARN corresponds to a parameter store resource + if !parameterNameRx.MatchString(s.ssmARN.Resource) { + return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) + } + + if s.ssmClient == nil { + var cfg aws.Config + if cfg, err = config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(s.ssmARN.Region), + ); err != nil { + return nil, err + } + s.ssmClient = ssm.NewFromConfig(cfg) + } + + // Hydrate cache with the potentially current state + if err := s.LoadState(); err != nil { + return nil, err + } + return s, nil + +} + +// LoadState attempts to read the state from AWS SSM parameter store key. +func (s *awsStore) LoadState() error { + param, err := s.ssmClient.GetParameter( + context.TODO(), + &ssm.GetParameterInput{ + Name: aws.String(s.ParameterName()), + WithDecryption: aws.Bool(true), + }, + ) + + if err != nil { + var pnf *ssmTypes.ParameterNotFound + if errors.As(err, &pnf) { + // Create the parameter as it does not exist yet + // and return directly as it is defacto empty + return s.persistState() + } + return err + } + + // Load the content in-memory + return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) +} + +// ParameterName returns the parameter name extracted from +// the provided ARN +func (s *awsStore) ParameterName() (name string) { + values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) + if len(values) == 2 { + name = values[1] + } + return +} + +// String returns the awsStore and the ARN of the SSM parameter store +// configured to store the state +func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } + +// ReadState implements the Store interface. +func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { + return s.memory.ReadState(id) +} + +// WriteState implements the Store interface. +func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { + // Write the state in-memory + if err = s.memory.WriteState(id, bs); err != nil { + return + } + + // Persist the state in AWS SSM parameter store + return s.persistState() +} + +// PersistState saves the states into the AWS SSM parameter store +func (s *awsStore) persistState() error { + // Generate JSON from in-memory cache + bs, err := s.memory.ExportToJSON() + if err != nil { + return err + } + + // Store in AWS SSM parameter store. + // + // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering + // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering + // doubling the capacity to 8kb per the following docs: + // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ + _, err = s.ssmClient.PutParameter( + context.TODO(), + &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + }, + ) + return err +} diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go index 8d2156ce948d5..7be8b858d752f 100644 --- a/ipn/store/awsstore/store_aws_stub.go +++ b/ipn/store/awsstore/store_aws_stub.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux || ts_omit_aws + +package awsstore + +import ( + "fmt" + "runtime" + + "tailscale.com/ipn" + "tailscale.com/types/logger" +) + +func New(logger.Logf, string) (ipn.StateStore, error) { + return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) +} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index f6c8fedb32dc9..54e6e18cb4115 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,164 +1,164 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package awsstore - -import ( - "context" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/tstest" -) - -type mockedAWSSSMClient struct { - value string -} - -func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { - output := new(ssm.GetParameterOutput) - if sp.value == "" { - return output, &ssmTypes.ParameterNotFound{} - } - - output.Parameter = &ssmTypes.Parameter{ - Value: aws.String(sp.value), - } - - return output, nil -} - -func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { - sp.value = *input.Value - return new(ssm.PutParameterOutput), nil -} - -func TestAWSStoreString(t *testing.T) { - store := &awsStore{ - ssmARN: arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - }, - } - want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" - if got := store.String(); got != want { - t.Errorf("AWSStore.String = %q; want %q", got, want) - } -} - -func TestNewAWSStore(t *testing.T) { - tstest.PanicOnLog() - - mc := &mockedAWSSSMClient{} - storeParameterARN := arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - } - - s, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating aws store failed: %v", err) - } - testStoreSemantics(t, s) - - // Build a brand new file store and check that both IDs written - // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating second aws store failed: %v", err) - } - store2 := s.(*awsStore) - - // This is specific to the test, with the non-mocked API, LoadState() should - // have been already called and successful as no err is returned from NewAWSStore() - s2.(*awsStore).LoadState() - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for id, want := range expected { - bs, err := store2.ReadState(id) - if err != nil { - t.Errorf("reading %q (2nd store): %v", id, err) - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) - } - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package awsstore + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/tstest" +) + +type mockedAWSSSMClient struct { + value string +} + +func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + output := new(ssm.GetParameterOutput) + if sp.value == "" { + return output, &ssmTypes.ParameterNotFound{} + } + + output.Parameter = &ssmTypes.Parameter{ + Value: aws.String(sp.value), + } + + return output, nil +} + +func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + sp.value = *input.Value + return new(ssm.PutParameterOutput), nil +} + +func TestAWSStoreString(t *testing.T) { + store := &awsStore{ + ssmARN: arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + }, + } + want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" + if got := store.String(); got != want { + t.Errorf("AWSStore.String = %q; want %q", got, want) + } +} + +func TestNewAWSStore(t *testing.T) { + tstest.PanicOnLog() + + mc := &mockedAWSSSMClient{} + storeParameterARN := arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + } + + s, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating aws store failed: %v", err) + } + testStoreSemantics(t, s) + + // Build a brand new file store and check that both IDs written + // above are still there. + s2, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating second aws store failed: %v", err) + } + store2 := s.(*awsStore) + + // This is specific to the test, with the non-mocked API, LoadState() should + // have been already called and successful as no err is returned from NewAWSStore() + s2.(*awsStore).LoadState() + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for id, want := range expected { + bs, err := store2.ReadState(id) + if err != nil { + t.Errorf("reading %q (2nd store): %v", id, err) + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) + } + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} diff --git a/ipn/store/stores_test.go b/ipn/store/stores_test.go index ea09e6ea63ae4..69aa791938747 100644 --- a/ipn/store/stores_test.go +++ b/ipn/store/stores_test.go @@ -1,179 +1,179 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package store - -import ( - "path/filepath" - "testing" - - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/types/logger" -) - -func TestNewStore(t *testing.T) { - regOnce.Do(registerDefaultStores) - t.Cleanup(func() { - knownStores = map[string]Provider{} - registerDefaultStores() - }) - knownStores = map[string]Provider{} - - type store1 struct { - ipn.StateStore - path string - } - - type store2 struct { - ipn.StateStore - path string - } - - Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store1{new(mem.Store), path}, nil - }) - Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store2{new(mem.Store), path}, nil - }) - Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return new(mem.Store), nil - }) - - path := "mem:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*mem.Store); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) - } - - path = "arn:foo" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store1); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) - } - - path = "kube:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store2); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) - } - - path = filepath.Join(t.TempDir(), "state") - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*FileStore); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} - -func TestMemoryStore(t *testing.T) { - tstest.PanicOnLog() - - store := new(mem.Store) - testStoreSemantics(t, store) -} - -func TestFileStore(t *testing.T) { - tstest.PanicOnLog() - - dir := t.TempDir() - path := filepath.Join(dir, "test-file-store.conf") - - store, err := NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating file store failed: %v", err) - } - - testStoreSemantics(t, store) - - // Build a brand new file store and check that both IDs written - // above are still there. - store, err = NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating second file store failed: %v", err) - } - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for key, want := range expected { - bs, err := store.ReadState(key) - if err != nil { - t.Errorf("reading %q (2nd store): %v", key, err) - continue - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package store + +import ( + "path/filepath" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" +) + +func TestNewStore(t *testing.T) { + regOnce.Do(registerDefaultStores) + t.Cleanup(func() { + knownStores = map[string]Provider{} + registerDefaultStores() + }) + knownStores = map[string]Provider{} + + type store1 struct { + ipn.StateStore + path string + } + + type store2 struct { + ipn.StateStore + path string + } + + Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store1{new(mem.Store), path}, nil + }) + Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store2{new(mem.Store), path}, nil + }) + Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return new(mem.Store), nil + }) + + path := "mem:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*mem.Store); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) + } + + path = "arn:foo" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store1); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) + } + + path = "kube:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store2); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) + } + + path = filepath.Join(t.TempDir(), "state") + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*FileStore); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} + +func TestMemoryStore(t *testing.T) { + tstest.PanicOnLog() + + store := new(mem.Store) + testStoreSemantics(t, store) +} + +func TestFileStore(t *testing.T) { + tstest.PanicOnLog() + + dir := t.TempDir() + path := filepath.Join(dir, "test-file-store.conf") + + store, err := NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating file store failed: %v", err) + } + + testStoreSemantics(t, store) + + // Build a brand new file store and check that both IDs written + // above are still there. + store, err = NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating second file store failed: %v", err) + } + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for key, want := range expected { + bs, err := store.ReadState(key) + if err != nil { + t.Errorf("reading %q (2nd store): %v", key, err) + continue + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) + } + } +} diff --git a/ipn/store_test.go b/ipn/store_test.go index fcc082d8a8a87..330f67969085b 100644 --- a/ipn/store_test.go +++ b/ipn/store_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipn - -import ( - "bytes" - "sync" - "testing" - - "tailscale.com/util/mak" -) - -type memStore struct { - mu sync.Mutex - writes int - m map[StateKey][]byte -} - -func (s *memStore) ReadState(k StateKey) ([]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - return bytes.Clone(s.m[k]), nil -} - -func (s *memStore) WriteState(k StateKey, v []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - mak.Set(&s.m, k, bytes.Clone(v)) - s.writes++ - return nil -} - -func TestWriteState(t *testing.T) { - var ss StateStore = new(memStore) - WriteState(ss, "foo", []byte("bar")) - WriteState(ss, "foo", []byte("bar")) - got, err := ss.ReadState("foo") - if err != nil { - t.Fatal(err) - } - if want := []byte("bar"); !bytes.Equal(got, want) { - t.Errorf("got %q; want %q", got, want) - } - if got, want := ss.(*memStore).writes, 1; got != want { - t.Errorf("got %d writes; want %d", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "bytes" + "sync" + "testing" + + "tailscale.com/util/mak" +) + +type memStore struct { + mu sync.Mutex + writes int + m map[StateKey][]byte +} + +func (s *memStore) ReadState(k StateKey) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + return bytes.Clone(s.m[k]), nil +} + +func (s *memStore) WriteState(k StateKey, v []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.m, k, bytes.Clone(v)) + s.writes++ + return nil +} + +func TestWriteState(t *testing.T) { + var ss StateStore = new(memStore) + WriteState(ss, "foo", []byte("bar")) + WriteState(ss, "foo", []byte("bar")) + got, err := ss.ReadState("foo") + if err != nil { + t.Fatal(err) + } + if want := []byte("bar"); !bytes.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } + if got, want := ss.(*memStore).writes, 1; got != want { + t.Errorf("got %d writes; want %d", got, want) + } +} diff --git a/jsondb/db.go b/jsondb/db.go index 68bb05af45e8e..c45c1f819ca05 100644 --- a/jsondb/db.go +++ b/jsondb/db.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsondb provides a trivial "database": a Go object saved to -// disk as JSON. -package jsondb - -import ( - "encoding/json" - "errors" - "io/fs" - "os" - - "tailscale.com/atomicfile" -) - -// DB is a database backed by a JSON file. -type DB[T any] struct { - // Data is the contents of the database. - Data *T - - path string -} - -// Open opens the database at path, creating it with a zero value if -// necessary. -func Open[T any](path string) (*DB[T], error) { - bs, err := os.ReadFile(path) - if errors.Is(err, fs.ErrNotExist) { - return &DB[T]{ - Data: new(T), - path: path, - }, nil - } else if err != nil { - return nil, err - } - - var val T - if err := json.Unmarshal(bs, &val); err != nil { - return nil, err - } - - return &DB[T]{ - Data: &val, - path: path, - }, nil -} - -// Save writes db.Data back to disk. -func (db *DB[T]) Save() error { - bs, err := json.Marshal(db.Data) - if err != nil { - return err - } - - return atomicfile.WriteFile(db.path, bs, 0600) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsondb provides a trivial "database": a Go object saved to +// disk as JSON. +package jsondb + +import ( + "encoding/json" + "errors" + "io/fs" + "os" + + "tailscale.com/atomicfile" +) + +// DB is a database backed by a JSON file. +type DB[T any] struct { + // Data is the contents of the database. + Data *T + + path string +} + +// Open opens the database at path, creating it with a zero value if +// necessary. +func Open[T any](path string) (*DB[T], error) { + bs, err := os.ReadFile(path) + if errors.Is(err, fs.ErrNotExist) { + return &DB[T]{ + Data: new(T), + path: path, + }, nil + } else if err != nil { + return nil, err + } + + var val T + if err := json.Unmarshal(bs, &val); err != nil { + return nil, err + } + + return &DB[T]{ + Data: &val, + path: path, + }, nil +} + +// Save writes db.Data back to disk. +func (db *DB[T]) Save() error { + bs, err := json.Marshal(db.Data) + if err != nil { + return err + } + + return atomicfile.WriteFile(db.path, bs, 0600) +} diff --git a/jsondb/db_test.go b/jsondb/db_test.go index 655754f38e1a9..a78b15b4f32c7 100644 --- a/jsondb/db_test.go +++ b/jsondb/db_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsondb - -import ( - "log" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestDB(t *testing.T) { - dir, err := os.MkdirTemp("", "db-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - - path := filepath.Join(dir, "db.json") - db, err := Open[testDB](path) - if err != nil { - t.Fatalf("creating empty DB: %v", err) - } - - if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) - } - db.Data.MyString = "test" - db.Data.unexported = "don't keep" - db.Data.AnInt = 42 - if err := db.Save(); err != nil { - t.Fatalf("saving database: %v", err) - } - - db2, err := Open[testDB](path) - if err != nil { - log.Fatalf("opening DB again: %v", err) - } - want := &testDB{ - MyString: "test", - AnInt: 42, - } - if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) - } -} - -type testDB struct { - MyString string - unexported string - AnInt int64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsondb + +import ( + "log" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestDB(t *testing.T) { + dir, err := os.MkdirTemp("", "db-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + path := filepath.Join(dir, "db.json") + db, err := Open[testDB](path) + if err != nil { + t.Fatalf("creating empty DB: %v", err) + } + + if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) + } + db.Data.MyString = "test" + db.Data.unexported = "don't keep" + db.Data.AnInt = 42 + if err := db.Save(); err != nil { + t.Fatalf("saving database: %v", err) + } + + db2, err := Open[testDB](path) + if err != nil { + log.Fatalf("opening DB again: %v", err) + } + want := &testDB{ + MyString: "test", + AnInt: 42, + } + if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) + } +} + +type testDB struct { + MyString string + unexported string + AnInt int64 +} diff --git a/licenses/licenses.go b/licenses/licenses.go index 5e59edb9f7b75..3ec7013214bb5 100644 --- a/licenses/licenses.go +++ b/licenses/licenses.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package licenses provides utilities for working with open source licenses. -package licenses - -import "runtime" - -// LicensesURL returns the absolute URL containing open source license information for the current platform. -func LicensesURL() string { - switch runtime.GOOS { - case "android": - return "https://tailscale.com/licenses/android" - case "darwin", "ios": - return "https://tailscale.com/licenses/apple" - case "windows": - return "https://tailscale.com/licenses/windows" - default: - return "https://tailscale.com/licenses/tailscale" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package licenses provides utilities for working with open source licenses. +package licenses + +import "runtime" + +// LicensesURL returns the absolute URL containing open source license information for the current platform. +func LicensesURL() string { + switch runtime.GOOS { + case "android": + return "https://tailscale.com/licenses/android" + case "darwin", "ios": + return "https://tailscale.com/licenses/apple" + case "windows": + return "https://tailscale.com/licenses/windows" + default: + return "https://tailscale.com/licenses/tailscale" + } +} diff --git a/log/filelogger/log.go b/log/filelogger/log.go index 599e5237b3e22..9d7097eb83e84 100644 --- a/log/filelogger/log.go +++ b/log/filelogger/log.go @@ -1,228 +1,228 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filelogger provides localdisk log writing & rotation, primarily for Windows -// clients. (We get this for free on other platforms.) -package filelogger - -import ( - "bytes" - "fmt" - "log" - "os" - "path/filepath" - "runtime" - "strings" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const ( - maxSize = 100 << 20 - maxFiles = 50 -) - -// New returns a logf wrapper that appends to local disk log -// files on Windows, rotating old log files as needed to stay under -// file count & byte limits. -func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { - if runtime.GOOS != "windows" { - panic("not yet supported on any platform except Windows") - } - if logf == nil { - panic("nil logf") - } - dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") - - if err := os.MkdirAll(dir, 0700); err != nil { - log.Printf("failed to create local log directory; not writing logs to disk: %v", err) - return logf - } - logf("local disk logdir: %v", dir) - lfw := &logFileWriter{ - fileBasePrefix: fileBasePrefix, - logID: logID, - dir: dir, - wrappedLogf: logf, - } - return lfw.Logf -} - -// logFileWriter is the state for the log writer & rotator. -type logFileWriter struct { - dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` - logID string // hex logID - fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" - wrappedLogf logger.Logf // underlying logger to send to - - mu sync.Mutex // guards following - buf bytes.Buffer // scratch buffer to avoid allocs - fday civilDay // day that f was opened; zero means no file yet open - f *os.File // file currently opened for append -} - -// civilDay is a year, month, and day in the local timezone. -// It's a comparable value type. -type civilDay struct { - year int - month time.Month - day int -} - -func dayOf(t time.Time) civilDay { - return civilDay{t.Year(), t.Month(), t.Day()} -} - -func (w *logFileWriter) Logf(format string, a ...any) { - w.mu.Lock() - defer w.mu.Unlock() - - w.buf.Reset() - fmt.Fprintf(&w.buf, format, a...) - if w.buf.Len() == 0 { - return - } - out := w.buf.Bytes() - w.wrappedLogf("%s", out) - - // Make sure there's a final newline before we write to the log file. - if out[len(out)-1] != '\n' { - w.buf.WriteByte('\n') - out = w.buf.Bytes() - } - - w.appendToFileLocked(out) -} - -// out should end in a newline. -// w.mu must be held. -func (w *logFileWriter) appendToFileLocked(out []byte) { - now := time.Now() - day := dayOf(now) - if w.fday != day { - w.startNewFileLocked() - } - out = removeDatePrefix(out) - if w.f != nil { - // RFC3339Nano but with a fixed number (3) of nanosecond digits: - const formatPre = "2006-01-02T15:04:05" - const formatPost = "Z07:00" - fmt.Fprintf(w.f, "%s.%03d%s: %s", - now.Format(formatPre), - now.Nanosecond()/int(time.Millisecond/time.Nanosecond), - now.Format(formatPost), - out) - } -} - -func isNum(b byte) bool { return '0' <= b && b <= '9' } - -// removeDatePrefix returns a subslice of v with the log package's -// standard datetime prefix format removed, if present. -func removeDatePrefix(v []byte) []byte { - const format = "2009/01/23 01:23:23 " - if len(v) < len(format) { - return v - } - for i, b := range v[:len(format)] { - fb := format[i] - if isNum(fb) { - if !isNum(b) { - return v - } - continue - } - if b != fb { - return v - } - } - return v[len(format):] -} - -// startNewFileLocked opens a new log file for writing -// and also cleans up any old files. -// -// w.mu must be held. -func (w *logFileWriter) startNewFileLocked() { - var oldName string - if w.f != nil { - oldName = filepath.Base(w.f.Name()) - w.f.Close() - w.f = nil - w.fday = civilDay{} - } - w.cleanLocked() - - now := time.Now() - day := dayOf(now) - name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", - w.fileBasePrefix, - day.year, - day.month, - day.day, - now.Hour(), - now.Minute(), - now.Second(), - now.Unix())) - var err error - w.f, err = os.Create(name) - if err != nil { - w.wrappedLogf("failed to create log file: %v", err) - return - } - if oldName != "" { - fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) - } else { - fmt.Fprintf(w.f, "(logID %q)\n", w.logID) - } - w.fday = day -} - -// cleanLocked cleans up old log files. -// -// w.mu must be held. -func (w *logFileWriter) cleanLocked() { - entries, _ := os.ReadDir(w.dir) - prefix := w.fileBasePrefix + "-" - fileSize := map[string]int64{} - var files []string - var sumSize int64 - for _, entry := range entries { - fi, err := entry.Info() - if err != nil { - w.wrappedLogf("error getting log file info: %v", err) - continue - } - - baseName := filepath.Base(fi.Name()) - if !strings.HasPrefix(baseName, prefix) { - continue - } - size := fi.Size() - fileSize[baseName] = size - sumSize += size - files = append(files, baseName) - } - if sumSize > maxSize { - w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) - } - if len(files) > maxFiles { - w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) - } - for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { - target := files[0] - files = files[1:] - - targetSize := fileSize[target] - targetFull := filepath.Join(w.dir, target) - err := os.Remove(targetFull) - if err != nil { - w.wrappedLogf("error cleaning log file: %v", err) - } else { - sumSize -= targetSize - w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filelogger provides localdisk log writing & rotation, primarily for Windows +// clients. (We get this for free on other platforms.) +package filelogger + +import ( + "bytes" + "fmt" + "log" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const ( + maxSize = 100 << 20 + maxFiles = 50 +) + +// New returns a logf wrapper that appends to local disk log +// files on Windows, rotating old log files as needed to stay under +// file count & byte limits. +func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { + if runtime.GOOS != "windows" { + panic("not yet supported on any platform except Windows") + } + if logf == nil { + panic("nil logf") + } + dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") + + if err := os.MkdirAll(dir, 0700); err != nil { + log.Printf("failed to create local log directory; not writing logs to disk: %v", err) + return logf + } + logf("local disk logdir: %v", dir) + lfw := &logFileWriter{ + fileBasePrefix: fileBasePrefix, + logID: logID, + dir: dir, + wrappedLogf: logf, + } + return lfw.Logf +} + +// logFileWriter is the state for the log writer & rotator. +type logFileWriter struct { + dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` + logID string // hex logID + fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" + wrappedLogf logger.Logf // underlying logger to send to + + mu sync.Mutex // guards following + buf bytes.Buffer // scratch buffer to avoid allocs + fday civilDay // day that f was opened; zero means no file yet open + f *os.File // file currently opened for append +} + +// civilDay is a year, month, and day in the local timezone. +// It's a comparable value type. +type civilDay struct { + year int + month time.Month + day int +} + +func dayOf(t time.Time) civilDay { + return civilDay{t.Year(), t.Month(), t.Day()} +} + +func (w *logFileWriter) Logf(format string, a ...any) { + w.mu.Lock() + defer w.mu.Unlock() + + w.buf.Reset() + fmt.Fprintf(&w.buf, format, a...) + if w.buf.Len() == 0 { + return + } + out := w.buf.Bytes() + w.wrappedLogf("%s", out) + + // Make sure there's a final newline before we write to the log file. + if out[len(out)-1] != '\n' { + w.buf.WriteByte('\n') + out = w.buf.Bytes() + } + + w.appendToFileLocked(out) +} + +// out should end in a newline. +// w.mu must be held. +func (w *logFileWriter) appendToFileLocked(out []byte) { + now := time.Now() + day := dayOf(now) + if w.fday != day { + w.startNewFileLocked() + } + out = removeDatePrefix(out) + if w.f != nil { + // RFC3339Nano but with a fixed number (3) of nanosecond digits: + const formatPre = "2006-01-02T15:04:05" + const formatPost = "Z07:00" + fmt.Fprintf(w.f, "%s.%03d%s: %s", + now.Format(formatPre), + now.Nanosecond()/int(time.Millisecond/time.Nanosecond), + now.Format(formatPost), + out) + } +} + +func isNum(b byte) bool { return '0' <= b && b <= '9' } + +// removeDatePrefix returns a subslice of v with the log package's +// standard datetime prefix format removed, if present. +func removeDatePrefix(v []byte) []byte { + const format = "2009/01/23 01:23:23 " + if len(v) < len(format) { + return v + } + for i, b := range v[:len(format)] { + fb := format[i] + if isNum(fb) { + if !isNum(b) { + return v + } + continue + } + if b != fb { + return v + } + } + return v[len(format):] +} + +// startNewFileLocked opens a new log file for writing +// and also cleans up any old files. +// +// w.mu must be held. +func (w *logFileWriter) startNewFileLocked() { + var oldName string + if w.f != nil { + oldName = filepath.Base(w.f.Name()) + w.f.Close() + w.f = nil + w.fday = civilDay{} + } + w.cleanLocked() + + now := time.Now() + day := dayOf(now) + name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", + w.fileBasePrefix, + day.year, + day.month, + day.day, + now.Hour(), + now.Minute(), + now.Second(), + now.Unix())) + var err error + w.f, err = os.Create(name) + if err != nil { + w.wrappedLogf("failed to create log file: %v", err) + return + } + if oldName != "" { + fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) + } else { + fmt.Fprintf(w.f, "(logID %q)\n", w.logID) + } + w.fday = day +} + +// cleanLocked cleans up old log files. +// +// w.mu must be held. +func (w *logFileWriter) cleanLocked() { + entries, _ := os.ReadDir(w.dir) + prefix := w.fileBasePrefix + "-" + fileSize := map[string]int64{} + var files []string + var sumSize int64 + for _, entry := range entries { + fi, err := entry.Info() + if err != nil { + w.wrappedLogf("error getting log file info: %v", err) + continue + } + + baseName := filepath.Base(fi.Name()) + if !strings.HasPrefix(baseName, prefix) { + continue + } + size := fi.Size() + fileSize[baseName] = size + sumSize += size + files = append(files, baseName) + } + if sumSize > maxSize { + w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) + } + if len(files) > maxFiles { + w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) + } + for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { + target := files[0] + files = files[1:] + + targetSize := fileSize[target] + targetFull := filepath.Join(w.dir, target) + err := os.Remove(targetFull) + if err != nil { + w.wrappedLogf("error cleaning log file: %v", err) + } else { + sumSize -= targetSize + w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) + } + } +} diff --git a/log/filelogger/log_test.go b/log/filelogger/log_test.go index dfa489637f720..27f80ab0ae37a 100644 --- a/log/filelogger/log_test.go +++ b/log/filelogger/log_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filelogger - -import "testing" - -func TestRemoveDatePrefix(t *testing.T) { - tests := []struct { - in, want string - }{ - {"", ""}, - {"\n", "\n"}, - {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, - {"2009/01/23 01:23:23 \n", "\n"}, - {"2009/01/23 01:23:23 foo\n", "foo\n"}, - {"9999/01/23 01:23:23 foo\n", "foo\n"}, - {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, - } - for i, tt := range tests { - got := removeDatePrefix([]byte(tt.in)) - if string(got) != tt.want { - t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) - } - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filelogger + +import "testing" + +func TestRemoveDatePrefix(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"\n", "\n"}, + {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, + {"2009/01/23 01:23:23 \n", "\n"}, + {"2009/01/23 01:23:23 foo\n", "foo\n"}, + {"9999/01/23 01:23:23 foo\n", "foo\n"}, + {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, + } + for i, tt := range tests { + got := removeDatePrefix([]byte(tt.in)) + if string(got) != tt.want { + t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) + } + } + +} diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index fdbfe4506e038..c0cdfb965c80e 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logpolicy - -import ( - "os" - "reflect" - "testing" -) - -func TestLogHost(t *testing.T) { - v := reflect.ValueOf(&getLogTargetOnce).Elem() - reset := func() { - v.Set(reflect.Zero(v.Type())) - } - defer reset() - - tests := []struct { - env string - want string - }{ - {"", "log.tailscale.io"}, - {"http://foo.com", "foo.com"}, - {"https://foo.com", "foo.com"}, - {"https://foo.com/", "foo.com"}, - {"https://foo.com:123/", "foo.com"}, - } - for _, tt := range tests { - reset() - os.Setenv("TS_LOG_TARGET", tt.env) - if got := LogHost(); got != tt.want { - t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logpolicy + +import ( + "os" + "reflect" + "testing" +) + +func TestLogHost(t *testing.T) { + v := reflect.ValueOf(&getLogTargetOnce).Elem() + reset := func() { + v.Set(reflect.Zero(v.Type())) + } + defer reset() + + tests := []struct { + env string + want string + }{ + {"", "log.tailscale.io"}, + {"http://foo.com", "foo.com"}, + {"https://foo.com", "foo.com"}, + {"https://foo.com/", "foo.com"}, + {"https://foo.com:123/", "foo.com"}, + } + for _, tt := range tests { + reset() + os.Setenv("TS_LOG_TARGET", tt.env) + if got := LogHost(); got != tt.want { + t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) + } + } +} diff --git a/logtail/.gitignore b/logtail/.gitignore index 0b29b4aca8ef3..b262949a827d0 100644 --- a/logtail/.gitignore +++ b/logtail/.gitignore @@ -1,6 +1,6 @@ -*~ -*.out -/example/logadopt/logadopt -/example/logreprocess/logreprocess -/example/logtail/logtail -/logtail +*~ +*.out +/example/logadopt/logadopt +/example/logreprocess/logreprocess +/example/logtail/logtail +/logtail diff --git a/logtail/README.md b/logtail/README.md index 20d22c3501432..b7b2ada34e985 100644 --- a/logtail/README.md +++ b/logtail/README.md @@ -1,10 +1,10 @@ -# Tailscale Logs Service - -This github repository contains libraries, documentation, and examples -for working with the public API of the tailscale logs service. - -For a very quick introduction to the core features, read the -[API docs](api.md) and peruse the -[logs reprocessing](./example/logreprocess/demo.sh) example. - +# Tailscale Logs Service + +This github repository contains libraries, documentation, and examples +for working with the public API of the tailscale logs service. + +For a very quick introduction to the core features, read the +[API docs](api.md) and peruse the +[logs reprocessing](./example/logreprocess/demo.sh) example. + For more information, write to info@tailscale.io. \ No newline at end of file diff --git a/logtail/api.md b/logtail/api.md index 8ec0b69c0f331..296913ce4985b 100644 --- a/logtail/api.md +++ b/logtail/api.md @@ -1,195 +1,195 @@ -# Tailscale Logs Service - -The Tailscale Logs Service defines a REST interface for configuring, storing, -retrieving, and processing log entries. - -# Overview - -HTTP requests are received at the service **base URL** -[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded -responses using standard HTTP response codes. - -Authorization for the configuration and retrieval APIs is done with a secret -API key passed as the HTTP basic auth username. Secret keys are generated via -the web UI at base URL. An example of using basic auth with curl: - - curl -u : https://log.tailscale.io/collections - -In the future, an HTTP header will allow using MessagePack instead of JSON. - -## Collections - -Logs are organized into collections. Inside each collection is any number of -instances. - -A collection is a domain name. It is a grouping of related logs. As a -guideline, create one collection per product using subdomains of your -company's domain name. Collections must be registered with the logs service -before any attempt is made to store logs. - -## Instances - -Each collection is a set of instances. There is one instance per machine -writing logs. - -An instance has a name and a number. An instance has a **private** and -**public** ID. The private ID is a 32-byte random number encoded as hex. -The public ID is the SHA-256 hash of the private ID, encoded as hex. - -The private ID is used to write logs. The only copy of the private ID -should be on the machine sending logs. Ideally it is generated on the -machine. Logs can be written as soon as a private ID is generated. - -The public ID is used to read and adopt logs. It is designed to be sent -to a service that also holds a logs service API key. - -The tailscale logs service will store any logs for a short period of time. -To enable logs retention, the log can be **adopted** using the public ID -and a logs service API key. -Once this is done, logs will be retained long-term (for the configured -retention period). - -Unadopted instance logs are stored temporarily to help with debugging: -a misconfigured machine writing logs with a bad ID can be spotted by -reading the logs. -If a public ID is not adopted, storage is tightly capped and logs are -deleted after 12 hours. - -# APIs - -## Storage - -### `POST /c//` — send a log - -The body of the request is JSON. - -A **single message** is an object with properties: - -`{ }` - -The client may send any properties it wants in the JSON message, except -for the `logtail` property which has special meaning. Inside the logtail -object the client may only set the following properties: - -- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" - -A future version of the logs service API will also support: - -- `client_time_offset` a integer of nanoseconds since the client was reset -- `client_time_reset` a boolean if set to true resets the time offset counter - -On receipt by the server the `client_time_offset` is transformed into a -`client_time` based on the `server_time` when the first (or -client_time_reset) event was received. - -If any other properties are set in the logtail object they are moved into -the "error" field, the message is saved and a 4xx status code is returned. - -A **batch of messages** is a JSON array filled with single message objects: - -`[ { }, { }, ... ]` - -If any of the array entries are not objects, the content is converted -into a message with a `"logtail": { "error": ...}` property, saved, and -a 4xx status code is returned. - -Similarly any other request content not matching one of these formats is -saved in a logtail error field, and a 4xx status code is returned. - -An invalid collection name returns `{"error": "invalid collection name"}` -along with a 403 status code. - -Clients are encouraged to: - -- POST as rapidly as possible (if not battery constrained). This minimizes - both the time necessary to see logs in a log viewer and the chance of - losing logs. -- Use HTTP/2 when streaming logs, as it does a much better job of - maintaining a TLS connection to minimize overhead for subsequent posts. - -A future version of logs service API will support sending requests with -`Content-Encoding: zstd`. - -## Retrieval - -### `GET /collections` — query the set of collections and instances - -Returns a JSON object listing all of the named collections. - -The caller can query-encode the following fields: - -- `collection-name` — limit the results to one collection - - ``` - { - "collections": { - "collection1.yourcompany.com": { - "instances": { - "" :{ - "first-seen": "timestamp", - "size": 4096 - }, - "" :{ - "first-seen": "timestamp", - "size": 512000, - "orphan": true, - } - } - } - } - } - ``` - -### `GET /c/` — query stored logs - -The caller can query-encode the following fields: - -- `instances` — zero or more log collection instances to limit results to -- `time-start` — the earliest log to include -- One of: - - `time-end` — the latest log to include - - `max-count` — maximum number of logs to return, allows paging - - `stream` — boolean that keeps the response dangling, streaming in - logs like `tail -f`. Incompatible with logtail-time-end. - -In **stream=false** mode, the response is a single JSON object: - - { - // TODO: header fields - "logs": [ {}, {}, ... ] - } - -In **stream=true** mode, the response begins with a JSON header object -similar to the storage format, and then is a sequence of JSON log -objects, `{...}`, one per line. The server continues to send these until -the client closes the connection. - -## Configuration - -For organizations with a small number of instances writing logs, the -Configuration API are best used by a trusted human operator, usually -through a GUI. Organizations with many instances will need to automate -the creation of tokens. - -### `POST /collections` — create or delete a collection - -The caller must set the `collection` property and `action=create` or -`action=delete`, either form encoded or JSON encoded. Its character set -is restricted to the mundane: [a-zA-Z0-9-_.]+ - -Collection names are a global space. Typically they are a domain name. - -### `POST /instances` — adopt an instance into a collection - -The caller must send the following properties, form encoded or JSON encoded: - -- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) -- `instances` an instance public ID encoded as hex - -The collection name must be claimed by a group the caller belongs to. -The pair (collection-name, instance-public-ID) may or may not already have -logs associated with it. - -On failure, an error message is returned with a 4xx or 5xx status code: - +# Tailscale Logs Service + +The Tailscale Logs Service defines a REST interface for configuring, storing, +retrieving, and processing log entries. + +# Overview + +HTTP requests are received at the service **base URL** +[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +responses using standard HTTP response codes. + +Authorization for the configuration and retrieval APIs is done with a secret +API key passed as the HTTP basic auth username. Secret keys are generated via +the web UI at base URL. An example of using basic auth with curl: + + curl -u : https://log.tailscale.io/collections + +In the future, an HTTP header will allow using MessagePack instead of JSON. + +## Collections + +Logs are organized into collections. Inside each collection is any number of +instances. + +A collection is a domain name. It is a grouping of related logs. As a +guideline, create one collection per product using subdomains of your +company's domain name. Collections must be registered with the logs service +before any attempt is made to store logs. + +## Instances + +Each collection is a set of instances. There is one instance per machine +writing logs. + +An instance has a name and a number. An instance has a **private** and +**public** ID. The private ID is a 32-byte random number encoded as hex. +The public ID is the SHA-256 hash of the private ID, encoded as hex. + +The private ID is used to write logs. The only copy of the private ID +should be on the machine sending logs. Ideally it is generated on the +machine. Logs can be written as soon as a private ID is generated. + +The public ID is used to read and adopt logs. It is designed to be sent +to a service that also holds a logs service API key. + +The tailscale logs service will store any logs for a short period of time. +To enable logs retention, the log can be **adopted** using the public ID +and a logs service API key. +Once this is done, logs will be retained long-term (for the configured +retention period). + +Unadopted instance logs are stored temporarily to help with debugging: +a misconfigured machine writing logs with a bad ID can be spotted by +reading the logs. +If a public ID is not adopted, storage is tightly capped and logs are +deleted after 12 hours. + +# APIs + +## Storage + +### `POST /c//` — send a log + +The body of the request is JSON. + +A **single message** is an object with properties: + +`{ }` + +The client may send any properties it wants in the JSON message, except +for the `logtail` property which has special meaning. Inside the logtail +object the client may only set the following properties: + +- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" + +A future version of the logs service API will also support: + +- `client_time_offset` a integer of nanoseconds since the client was reset +- `client_time_reset` a boolean if set to true resets the time offset counter + +On receipt by the server the `client_time_offset` is transformed into a +`client_time` based on the `server_time` when the first (or +client_time_reset) event was received. + +If any other properties are set in the logtail object they are moved into +the "error" field, the message is saved and a 4xx status code is returned. + +A **batch of messages** is a JSON array filled with single message objects: + +`[ { }, { }, ... ]` + +If any of the array entries are not objects, the content is converted +into a message with a `"logtail": { "error": ...}` property, saved, and +a 4xx status code is returned. + +Similarly any other request content not matching one of these formats is +saved in a logtail error field, and a 4xx status code is returned. + +An invalid collection name returns `{"error": "invalid collection name"}` +along with a 403 status code. + +Clients are encouraged to: + +- POST as rapidly as possible (if not battery constrained). This minimizes + both the time necessary to see logs in a log viewer and the chance of + losing logs. +- Use HTTP/2 when streaming logs, as it does a much better job of + maintaining a TLS connection to minimize overhead for subsequent posts. + +A future version of logs service API will support sending requests with +`Content-Encoding: zstd`. + +## Retrieval + +### `GET /collections` — query the set of collections and instances + +Returns a JSON object listing all of the named collections. + +The caller can query-encode the following fields: + +- `collection-name` — limit the results to one collection + + ``` + { + "collections": { + "collection1.yourcompany.com": { + "instances": { + "" :{ + "first-seen": "timestamp", + "size": 4096 + }, + "" :{ + "first-seen": "timestamp", + "size": 512000, + "orphan": true, + } + } + } + } + } + ``` + +### `GET /c/` — query stored logs + +The caller can query-encode the following fields: + +- `instances` — zero or more log collection instances to limit results to +- `time-start` — the earliest log to include +- One of: + - `time-end` — the latest log to include + - `max-count` — maximum number of logs to return, allows paging + - `stream` — boolean that keeps the response dangling, streaming in + logs like `tail -f`. Incompatible with logtail-time-end. + +In **stream=false** mode, the response is a single JSON object: + + { + // TODO: header fields + "logs": [ {}, {}, ... ] + } + +In **stream=true** mode, the response begins with a JSON header object +similar to the storage format, and then is a sequence of JSON log +objects, `{...}`, one per line. The server continues to send these until +the client closes the connection. + +## Configuration + +For organizations with a small number of instances writing logs, the +Configuration API are best used by a trusted human operator, usually +through a GUI. Organizations with many instances will need to automate +the creation of tokens. + +### `POST /collections` — create or delete a collection + +The caller must set the `collection` property and `action=create` or +`action=delete`, either form encoded or JSON encoded. Its character set +is restricted to the mundane: [a-zA-Z0-9-_.]+ + +Collection names are a global space. Typically they are a domain name. + +### `POST /instances` — adopt an instance into a collection + +The caller must send the following properties, form encoded or JSON encoded: + +- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) +- `instances` an instance public ID encoded as hex + +The collection name must be claimed by a group the caller belongs to. +The pair (collection-name, instance-public-ID) may or may not already have +logs associated with it. + +On failure, an error message is returned with a 4xx or 5xx status code: + `{"error": "what went wrong"}` \ No newline at end of file diff --git a/logtail/example/logreprocess/demo.sh b/logtail/example/logreprocess/demo.sh index 4ec819a67450d..eaec706a38718 100755 --- a/logtail/example/logreprocess/demo.sh +++ b/logtail/example/logreprocess/demo.sh @@ -1,86 +1,86 @@ -#!/bin/bash -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -# -# This shell script demonstrates writing logs from machines -# and then reprocessing those logs to amalgamate python tracebacks -# into a single log entry in a new collection. -# -# To run this demo, first install the example applications: -# -# go install tailscale.com/logtail/example/... -# -# Then generate a LOGTAIL_API_KEY and two test collections by visiting: -# -# https://log.tailscale.io -# -# Then set the three variables below. -trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT -set -e - -LOG_TEXT='server starting -config file loaded -answering queries -Traceback (most recent call last): - File "/Users/crawshaw/junk.py", line 6, in - main() - File "/Users/crawshaw/junk.py", line 4, in main - raise Exception("oops") -Exception: oops' - -die() { - echo "$0: $*" >&2 - exit 1 -} - -msg() { - echo "-- $*" >&2 -} - -if [ -z "$LOGTAIL_API_KEY" ]; then - die "LOGTAIL_API_KEY is not set" -fi - -if [ -z "$COLLECTION_IN" ]; then - die "COLLECTION_IN is not set" -fi - -if [ -z "$COLLECTION_OUT" ]; then - die "COLLECTION_OUT is not set" -fi - -# Private IDs are 32-bytes of random hex. -# Normally you'd keep the same private IDs from one run to the next, but -# this is just an example. -msg "Generating keys..." -privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) - -# Public IDs are the SHA-256 of the private ID. -publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') - -# Write the machine logs to the input collection. -# Notice that this doesn't require an API key. -msg "Producing new logs..." -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null - -# Adopt the logs, so they will be kept and are readable. -msg "Adopting logs..." -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 - -# Reprocess the logs, amalgamating python tracebacks. -# -# We'll take that reprocessed output and write it to a separate collection, -# again via logtail. -# -# Time out quickly because all our "interesting" logs (generated -# above) have already been processed. -msg "Reprocessing logs..." -logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | - logtail -c "$COLLECTION_OUT" -k $privateid3 +#!/bin/bash +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +# +# This shell script demonstrates writing logs from machines +# and then reprocessing those logs to amalgamate python tracebacks +# into a single log entry in a new collection. +# +# To run this demo, first install the example applications: +# +# go install tailscale.com/logtail/example/... +# +# Then generate a LOGTAIL_API_KEY and two test collections by visiting: +# +# https://log.tailscale.io +# +# Then set the three variables below. +trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT +set -e + +LOG_TEXT='server starting +config file loaded +answering queries +Traceback (most recent call last): + File "/Users/crawshaw/junk.py", line 6, in + main() + File "/Users/crawshaw/junk.py", line 4, in main + raise Exception("oops") +Exception: oops' + +die() { + echo "$0: $*" >&2 + exit 1 +} + +msg() { + echo "-- $*" >&2 +} + +if [ -z "$LOGTAIL_API_KEY" ]; then + die "LOGTAIL_API_KEY is not set" +fi + +if [ -z "$COLLECTION_IN" ]; then + die "COLLECTION_IN is not set" +fi + +if [ -z "$COLLECTION_OUT" ]; then + die "COLLECTION_OUT is not set" +fi + +# Private IDs are 32-bytes of random hex. +# Normally you'd keep the same private IDs from one run to the next, but +# this is just an example. +msg "Generating keys..." +privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) + +# Public IDs are the SHA-256 of the private ID. +publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') + +# Write the machine logs to the input collection. +# Notice that this doesn't require an API key. +msg "Producing new logs..." +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null + +# Adopt the logs, so they will be kept and are readable. +msg "Adopting logs..." +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 + +# Reprocess the logs, amalgamating python tracebacks. +# +# We'll take that reprocessed output and write it to a separate collection, +# again via logtail. +# +# Time out quickly because all our "interesting" logs (generated +# above) have already been processed. +msg "Reprocessing logs..." +logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | + logtail -c "$COLLECTION_OUT" -k $privateid3 diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go index 5dbf765788165..e88d5b4856700 100644 --- a/logtail/example/logreprocess/logreprocess.go +++ b/logtail/example/logreprocess/logreprocess.go @@ -1,115 +1,115 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logreprocess program tails a log and reprocesses it. -package main - -import ( - "bufio" - "encoding/json" - "flag" - "io" - "log" - "net/http" - "os" - "strings" - "time" - - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name to read") - apiKey := flag.String("p", "", "logtail API key") - timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - log.SetFlags(0) - - if *timeout != 0 { - go func() { - <-time.After(*timeout) - log.Printf("logreprocess: timeout reached, quitting") - os.Exit(1) - }() - } - - req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) - if err != nil { - log.Fatal(err) - } - req.SetBasicAuth(*apiKey, "") - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - b, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) - } - log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) - } - - tracebackCache := make(map[logid.PublicID]*ProcessedMsg) - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - var msg Msg - if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { - log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) - } - var pMsg *ProcessedMsg - if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { - pMsg.Text += "\n" + msg.Text - if strings.HasPrefix(msg.Text, "Exception: ") { - delete(tracebackCache, msg.Logtail.Instance) - } else { - continue // write later - } - } else { - pMsg = &ProcessedMsg{ - OrigInstance: msg.Logtail.Instance, - Text: msg.Text, - } - pMsg.Logtail.ClientTime = msg.Logtail.ClientTime - } - - if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { - tracebackCache[msg.Logtail.Instance] = pMsg - continue // write later - } - - b, err := json.Marshal(pMsg) - if err != nil { - log.Fatal(err) - } - log.Printf("%s", b) - } - if err := scanner.Err(); err != nil { - log.Fatal(err) - } -} - -type Msg struct { - Logtail struct { - Instance logid.PublicID `json:"instance"` - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - Text string `json:"text"` -} - -type ProcessedMsg struct { - Logtail struct { - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - OrigInstance logid.PublicID `json:"orig_instance"` - Text string `json:"text"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logreprocess program tails a log and reprocesses it. +package main + +import ( + "bufio" + "encoding/json" + "flag" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name to read") + apiKey := flag.String("p", "", "logtail API key") + timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + log.SetFlags(0) + + if *timeout != 0 { + go func() { + <-time.After(*timeout) + log.Printf("logreprocess: timeout reached, quitting") + os.Exit(1) + }() + } + + req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + if err != nil { + log.Fatal(err) + } + req.SetBasicAuth(*apiKey, "") + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) + } + log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) + } + + tracebackCache := make(map[logid.PublicID]*ProcessedMsg) + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var msg Msg + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) + } + var pMsg *ProcessedMsg + if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { + pMsg.Text += "\n" + msg.Text + if strings.HasPrefix(msg.Text, "Exception: ") { + delete(tracebackCache, msg.Logtail.Instance) + } else { + continue // write later + } + } else { + pMsg = &ProcessedMsg{ + OrigInstance: msg.Logtail.Instance, + Text: msg.Text, + } + pMsg.Logtail.ClientTime = msg.Logtail.ClientTime + } + + if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { + tracebackCache[msg.Logtail.Instance] = pMsg + continue // write later + } + + b, err := json.Marshal(pMsg) + if err != nil { + log.Fatal(err) + } + log.Printf("%s", b) + } + if err := scanner.Err(); err != nil { + log.Fatal(err) + } +} + +type Msg struct { + Logtail struct { + Instance logid.PublicID `json:"instance"` + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + Text string `json:"text"` +} + +type ProcessedMsg struct { + Logtail struct { + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + OrigInstance logid.PublicID `json:"orig_instance"` + Text string `json:"text"` +} diff --git a/logtail/example/logtail/logtail.go b/logtail/example/logtail/logtail.go index 0c9e442584410..e777055133904 100644 --- a/logtail/example/logtail/logtail.go +++ b/logtail/example/logtail/logtail.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logtail program logs stdin. -package main - -import ( - "bufio" - "flag" - "io" - "log" - "os" - - "tailscale.com/logtail" - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name") - privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - - log.SetFlags(0) - - var id logid.PrivateID - if err := id.UnmarshalText([]byte(*privateID)); err != nil { - log.Fatalf("logtail: bad -privateid: %v", err) - } - - logger := logtail.NewLogger(logtail.Config{ - Collection: *collection, - PrivateID: id, - }, log.Printf) - log.SetOutput(io.MultiWriter(logger, os.Stdout)) - defer logger.Flush() - defer log.Printf("logtail exited") - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - log.Println(scanner.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logtail program logs stdin. +package main + +import ( + "bufio" + "flag" + "io" + "log" + "os" + + "tailscale.com/logtail" + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name") + privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + + log.SetFlags(0) + + var id logid.PrivateID + if err := id.UnmarshalText([]byte(*privateID)); err != nil { + log.Fatalf("logtail: bad -privateid: %v", err) + } + + logger := logtail.NewLogger(logtail.Config{ + Collection: *collection, + PrivateID: id, + }, log.Printf) + log.SetOutput(io.MultiWriter(logger, os.Stdout)) + defer logger.Flush() + defer log.Printf("logtail exited") + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + log.Println(scanner.Text()) + } +} diff --git a/logtail/filch/filch.go b/logtail/filch/filch.go index d00206dd51487..886fe239c71b8 100644 --- a/logtail/filch/filch.go +++ b/logtail/filch/filch.go @@ -1,284 +1,284 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filch is a file system queue that pilfers your stderr. -// (A FILe CHannel that filches.) -package filch - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "sync" -) - -var stderrFD = 2 // a variable for testing - -const defaultMaxFileSize = 50 << 20 - -type Options struct { - ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here - MaxFileSize int -} - -// A Filch uses two alternating files as a simplistic ring buffer. -type Filch struct { - OrigStderr *os.File - - mu sync.Mutex - cur *os.File - alt *os.File - altscan *bufio.Scanner - recovered int64 - - maxFileSize int64 - writeCounter int - - // buf is an initial buffer for altscan. - // As of August 2021, 99.96% of all log lines - // are below 4096 bytes in length. - // Since this cutoff is arbitrary, instead of using 4096, - // we subtract off the size of the rest of the struct - // so that the whole struct takes 4096 bytes - // (less on 32 bit platforms). - // This reduces allocation waste. - buf [4096 - 64]byte -} - -// TryReadline implements the logtail.Buffer interface. -func (f *Filch) TryReadLine() ([]byte, error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.altscan != nil { - if b, err := f.scan(); b != nil || err != nil { - return b, err - } - } - - f.cur, f.alt = f.alt, f.cur - if f.OrigStderr != nil { - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - if _, err := f.alt.Seek(0, io.SeekStart); err != nil { - return nil, err - } - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - return f.scan() -} - -func (f *Filch) scan() ([]byte, error) { - if f.altscan.Scan() { - return f.altscan.Bytes(), nil - } - err := f.altscan.Err() - err2 := f.alt.Truncate(0) - _, err3 := f.alt.Seek(0, io.SeekStart) - f.altscan = nil - if err != nil { - return nil, err - } - if err2 != nil { - return nil, err2 - } - if err3 != nil { - return nil, err3 - } - return nil, nil -} - -// Write implements the logtail.Buffer interface. -func (f *Filch) Write(b []byte) (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - if f.writeCounter == 100 { - // Check the file size every 100 writes. - f.writeCounter = 0 - fi, err := f.cur.Stat() - if err != nil { - return 0, err - } - if fi.Size() >= f.maxFileSize { - // This most likely means we are not draining. - // To limit the amount of space we use, throw away the old logs. - if err := moveContents(f.alt, f.cur); err != nil { - return 0, err - } - } - } - f.writeCounter++ - - if len(b) == 0 || b[len(b)-1] != '\n' { - bnl := make([]byte, len(b)+1) - copy(bnl, b) - bnl[len(bnl)-1] = '\n' - return f.cur.Write(bnl) - } - return f.cur.Write(b) -} - -// Close closes the Filch, releasing all os resources. -func (f *Filch) Close() (err error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.OrigStderr != nil { - if err2 := unsaveStderr(f.OrigStderr); err == nil { - err = err2 - } - f.OrigStderr = nil - } - - if err2 := f.cur.Close(); err == nil { - err = err2 - } - if err2 := f.alt.Close(); err == nil { - err = err2 - } - - return err -} - -// New creates a new filch around two log files, each starting with filePrefix. -func New(filePrefix string, opts Options) (f *Filch, err error) { - var f1, f2 *os.File - defer func() { - if err != nil { - if f1 != nil { - f1.Close() - } - if f2 != nil { - f2.Close() - } - err = fmt.Errorf("filch: %s", err) - } - }() - - path1 := filePrefix + ".log1.txt" - path2 := filePrefix + ".log2.txt" - - f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - - fi1, err := f1.Stat() - if err != nil { - return nil, err - } - fi2, err := f2.Stat() - if err != nil { - return nil, err - } - - mfs := defaultMaxFileSize - if opts.MaxFileSize > 0 { - mfs = opts.MaxFileSize - } - f = &Filch{ - OrigStderr: os.Stderr, // temporary, for past logs recovery - maxFileSize: int64(mfs), - } - - // Neither, either, or both files may exist and contain logs from - // the last time the process ran. The three cases are: - // - // - neither: all logs were read out and files were truncated - // - either: logs were being written into one of the files - // - both: the files were swapped and were starting to be - // read out, while new logs streamed into the other - // file, but the read out did not complete - if n := fi1.Size() + fi2.Size(); n > 0 { - f.recovered = n - } - switch { - case fi1.Size() > 0 && fi2.Size() == 0: - f.cur, f.alt = f2, f1 - case fi2.Size() > 0 && fi1.Size() == 0: - f.cur, f.alt = f1, f2 - case fi1.Size() > 0 && fi2.Size() > 0: // both - // We need to pick one of the files to be the elder, - // which we do using the mtime. - var older, newer *os.File - if fi1.ModTime().Before(fi2.ModTime()) { - older, newer = f1, f2 - } else { - older, newer = f2, f1 - } - if err := moveContents(older, newer); err != nil { - fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) - fmt.Fprintf(older, "filch: recover move failed: %v\n", err) - } - f.cur, f.alt = newer, older - default: - f.cur, f.alt = f1, f2 // does not matter - } - if f.recovered > 0 { - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - } - - f.OrigStderr = nil - if opts.ReplaceStderr { - f.OrigStderr, err = saveStderr() - if err != nil { - return nil, err - } - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - - return f, nil -} - -func moveContents(dst, src *os.File) (err error) { - defer func() { - _, err2 := src.Seek(0, io.SeekStart) - err3 := src.Truncate(0) - _, err4 := dst.Seek(0, io.SeekStart) - if err == nil { - err = err2 - } - if err == nil { - err = err3 - } - if err == nil { - err = err4 - } - }() - if _, err := src.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := dst.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := io.Copy(dst, src); err != nil { - return err - } - return nil -} - -func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0 : i+1], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filch is a file system queue that pilfers your stderr. +// (A FILe CHannel that filches.) +package filch + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "sync" +) + +var stderrFD = 2 // a variable for testing + +const defaultMaxFileSize = 50 << 20 + +type Options struct { + ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here + MaxFileSize int +} + +// A Filch uses two alternating files as a simplistic ring buffer. +type Filch struct { + OrigStderr *os.File + + mu sync.Mutex + cur *os.File + alt *os.File + altscan *bufio.Scanner + recovered int64 + + maxFileSize int64 + writeCounter int + + // buf is an initial buffer for altscan. + // As of August 2021, 99.96% of all log lines + // are below 4096 bytes in length. + // Since this cutoff is arbitrary, instead of using 4096, + // we subtract off the size of the rest of the struct + // so that the whole struct takes 4096 bytes + // (less on 32 bit platforms). + // This reduces allocation waste. + buf [4096 - 64]byte +} + +// TryReadline implements the logtail.Buffer interface. +func (f *Filch) TryReadLine() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.altscan != nil { + if b, err := f.scan(); b != nil || err != nil { + return b, err + } + } + + f.cur, f.alt = f.alt, f.cur + if f.OrigStderr != nil { + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + if _, err := f.alt.Seek(0, io.SeekStart); err != nil { + return nil, err + } + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + return f.scan() +} + +func (f *Filch) scan() ([]byte, error) { + if f.altscan.Scan() { + return f.altscan.Bytes(), nil + } + err := f.altscan.Err() + err2 := f.alt.Truncate(0) + _, err3 := f.alt.Seek(0, io.SeekStart) + f.altscan = nil + if err != nil { + return nil, err + } + if err2 != nil { + return nil, err2 + } + if err3 != nil { + return nil, err3 + } + return nil, nil +} + +// Write implements the logtail.Buffer interface. +func (f *Filch) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.writeCounter == 100 { + // Check the file size every 100 writes. + f.writeCounter = 0 + fi, err := f.cur.Stat() + if err != nil { + return 0, err + } + if fi.Size() >= f.maxFileSize { + // This most likely means we are not draining. + // To limit the amount of space we use, throw away the old logs. + if err := moveContents(f.alt, f.cur); err != nil { + return 0, err + } + } + } + f.writeCounter++ + + if len(b) == 0 || b[len(b)-1] != '\n' { + bnl := make([]byte, len(b)+1) + copy(bnl, b) + bnl[len(bnl)-1] = '\n' + return f.cur.Write(bnl) + } + return f.cur.Write(b) +} + +// Close closes the Filch, releasing all os resources. +func (f *Filch) Close() (err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.OrigStderr != nil { + if err2 := unsaveStderr(f.OrigStderr); err == nil { + err = err2 + } + f.OrigStderr = nil + } + + if err2 := f.cur.Close(); err == nil { + err = err2 + } + if err2 := f.alt.Close(); err == nil { + err = err2 + } + + return err +} + +// New creates a new filch around two log files, each starting with filePrefix. +func New(filePrefix string, opts Options) (f *Filch, err error) { + var f1, f2 *os.File + defer func() { + if err != nil { + if f1 != nil { + f1.Close() + } + if f2 != nil { + f2.Close() + } + err = fmt.Errorf("filch: %s", err) + } + }() + + path1 := filePrefix + ".log1.txt" + path2 := filePrefix + ".log2.txt" + + f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + + fi1, err := f1.Stat() + if err != nil { + return nil, err + } + fi2, err := f2.Stat() + if err != nil { + return nil, err + } + + mfs := defaultMaxFileSize + if opts.MaxFileSize > 0 { + mfs = opts.MaxFileSize + } + f = &Filch{ + OrigStderr: os.Stderr, // temporary, for past logs recovery + maxFileSize: int64(mfs), + } + + // Neither, either, or both files may exist and contain logs from + // the last time the process ran. The three cases are: + // + // - neither: all logs were read out and files were truncated + // - either: logs were being written into one of the files + // - both: the files were swapped and were starting to be + // read out, while new logs streamed into the other + // file, but the read out did not complete + if n := fi1.Size() + fi2.Size(); n > 0 { + f.recovered = n + } + switch { + case fi1.Size() > 0 && fi2.Size() == 0: + f.cur, f.alt = f2, f1 + case fi2.Size() > 0 && fi1.Size() == 0: + f.cur, f.alt = f1, f2 + case fi1.Size() > 0 && fi2.Size() > 0: // both + // We need to pick one of the files to be the elder, + // which we do using the mtime. + var older, newer *os.File + if fi1.ModTime().Before(fi2.ModTime()) { + older, newer = f1, f2 + } else { + older, newer = f2, f1 + } + if err := moveContents(older, newer); err != nil { + fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) + fmt.Fprintf(older, "filch: recover move failed: %v\n", err) + } + f.cur, f.alt = newer, older + default: + f.cur, f.alt = f1, f2 // does not matter + } + if f.recovered > 0 { + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + } + + f.OrigStderr = nil + if opts.ReplaceStderr { + f.OrigStderr, err = saveStderr() + if err != nil { + return nil, err + } + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + + return f, nil +} + +func moveContents(dst, src *os.File) (err error) { + defer func() { + _, err2 := src.Seek(0, io.SeekStart) + err3 := src.Truncate(0) + _, err4 := dst.Seek(0, io.SeekStart) + if err == nil { + err = err2 + } + if err == nil { + err = err3 + } + if err == nil { + err = err4 + } + }() + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := dst.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := io.Copy(dst, src); err != nil { + return err + } + return nil +} + +func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0 : i+1], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} diff --git a/logtail/filch/filch_stub.go b/logtail/filch/filch_stub.go index 3bb82b1906f17..fe718d150d0b8 100644 --- a/logtail/filch/filch_stub.go +++ b/logtail/filch/filch_stub.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 || tamago - -package filch - -import ( - "os" -) - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 || tamago + +package filch + +import ( + "os" +) + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + return nil +} diff --git a/logtail/filch/filch_unix.go b/logtail/filch/filch_unix.go index 2eae70aceb187..b06ef6afde99f 100644 --- a/logtail/filch/filch_unix.go +++ b/logtail/filch/filch_unix.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package filch - -import ( - "os" - - "golang.org/x/sys/unix" -) - -func saveStderr() (*os.File, error) { - fd, err := unix.Dup(stderrFD) - if err != nil { - return nil, err - } - return os.NewFile(uintptr(fd), "stderr"), nil -} - -func unsaveStderr(f *os.File) error { - err := dup2Stderr(f) - f.Close() - return err -} - -func dup2Stderr(f *os.File) error { - return unix.Dup2(int(f.Fd()), stderrFD) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package filch + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func saveStderr() (*os.File, error) { + fd, err := unix.Dup(stderrFD) + if err != nil { + return nil, err + } + return os.NewFile(uintptr(fd), "stderr"), nil +} + +func unsaveStderr(f *os.File) error { + err := dup2Stderr(f) + f.Close() + return err +} + +func dup2Stderr(f *os.File) error { + return unix.Dup2(int(f.Fd()), stderrFD) +} diff --git a/logtail/filch/filch_windows.go b/logtail/filch/filch_windows.go index d60514bf00abe..1419d660689ce 100644 --- a/logtail/filch/filch_windows.go +++ b/logtail/filch/filch_windows.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filch - -import ( - "fmt" - "os" - "syscall" -) - -var kernel32 = syscall.MustLoadDLL("kernel32.dll") -var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") - -func setStdHandle(stdHandle int32, handle syscall.Handle) error { - r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) - if r == 0 { - if e != 0 { - return error(e) - } - return syscall.EINVAL - } - return nil -} - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - fd := int(f.Fd()) - err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) - if err != nil { - return fmt.Errorf("dup2Stderr: %w", err) - } - os.Stderr = f - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filch + +import ( + "fmt" + "os" + "syscall" +) + +var kernel32 = syscall.MustLoadDLL("kernel32.dll") +var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") + +func setStdHandle(stdHandle int32, handle syscall.Handle) error { + r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) + if r == 0 { + if e != 0 { + return error(e) + } + return syscall.EINVAL + } + return nil +} + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + fd := int(f.Fd()) + err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) + if err != nil { + return fmt.Errorf("dup2Stderr: %w", err) + } + os.Stderr = f + return nil +} diff --git a/metrics/fds_linux.go b/metrics/fds_linux.go index 34740c2bb1c74..66ebb419d787c 100644 --- a/metrics/fds_linux.go +++ b/metrics/fds_linux.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package metrics - -import ( - "io/fs" - "sync" - - "go4.org/mem" - "tailscale.com/util/dirwalk" -) - -// counter is a reusable counter for counting file descriptors. -type counter struct { - n int - - // cb is the (*counter).count method value. Creating it allocates, - // so we have to save it away and use a sync.Pool to keep currentFDs - // amortized alloc-free. - cb func(name mem.RO, de fs.DirEntry) error -} - -var counterPool = &sync.Pool{New: func() any { - c := new(counter) - c.cb = c.count - return c -}} - -func (c *counter) count(name mem.RO, de fs.DirEntry) error { - c.n++ - return nil -} - -func currentFDs() int { - c := counterPool.Get().(*counter) - defer counterPool.Put(c) - c.n = 0 - dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) - return c.n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "io/fs" + "sync" + + "go4.org/mem" + "tailscale.com/util/dirwalk" +) + +// counter is a reusable counter for counting file descriptors. +type counter struct { + n int + + // cb is the (*counter).count method value. Creating it allocates, + // so we have to save it away and use a sync.Pool to keep currentFDs + // amortized alloc-free. + cb func(name mem.RO, de fs.DirEntry) error +} + +var counterPool = &sync.Pool{New: func() any { + c := new(counter) + c.cb = c.count + return c +}} + +func (c *counter) count(name mem.RO, de fs.DirEntry) error { + c.n++ + return nil +} + +func currentFDs() int { + c := counterPool.Get().(*counter) + defer counterPool.Put(c) + c.n = 0 + dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) + return c.n +} diff --git a/metrics/fds_notlinux.go b/metrics/fds_notlinux.go index 2dae97cad86b9..5a59d4de9d8bf 100644 --- a/metrics/fds_notlinux.go +++ b/metrics/fds_notlinux.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package metrics - -func currentFDs() int { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package metrics + +func currentFDs() int { return 0 } diff --git a/metrics/metrics.go b/metrics/metrics.go index a07ddccae5107..0f67ffa305e7c 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package metrics contains expvar & Prometheus types and code used by -// Tailscale for monitoring. -package metrics - -import ( - "expvar" - "fmt" - "io" - "slices" - "strings" -) - -// Set is a string-to-Var map variable that satisfies the expvar.Var -// interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of unrelated variables exported with a common prefix. -// -// This lets us have tsweb recognize *expvar.Map for different -// purposes in the future. (Or perhaps all uses of expvar.Map will -// require explicit types like this one, declaring how we want tsweb -// to export it to Prometheus.) -type Set struct { - expvar.Map -} - -// LabelMap is a string-to-Var map variable that satisfies the -// expvar.Var interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of variables with the same name, with a varying label -// value. Use this to export things that are intuitively breakdowns -// into different buckets. -type LabelMap struct { - Label string - expvar.Map -} - -// SetInt64 sets the *Int value stored under the given map key. -func (m *LabelMap) SetInt64(key string, v int64) { - m.Get(key).Set(v) -} - -// Get returns a direct pointer to the expvar.Int for key, creating it -// if necessary. -func (m *LabelMap) Get(key string) *expvar.Int { - m.Add(key, 0) - return m.Map.Get(key).(*expvar.Int) -} - -// GetIncrFunc returns a function that increments the expvar.Int named by key. -// -// Most callers should not need this; it exists to satisfy an -// interface elsewhere. -func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { - return m.Get(key).Add -} - -// GetFloat returns a direct pointer to the expvar.Float for key, creating it -// if necessary. -func (m *LabelMap) GetFloat(key string) *expvar.Float { - m.AddFloat(key, 0.0) - return m.Map.Get(key).(*expvar.Float) -} - -// CurrentFDs reports how many file descriptors are currently open. -// -// It only works on Linux. It returns zero otherwise. -func CurrentFDs() int { - return currentFDs() -} - -// Histogram is a histogram of values. -// It should be created with NewHistogram. -type Histogram struct { - // buckets is a list of bucket boundaries, in increasing order. - buckets []float64 - - // bucketStrings is a list of the same buckets, but as strings. - // This are allocated once at creation time by NewHistogram. - bucketStrings []string - - bucketVars []expvar.Int - sum expvar.Float - count expvar.Int -} - -// NewHistogram returns a new histogram that reports to the given -// expvar map under the given name. -// -// The buckets are the boundaries of the histogram buckets, in -// increasing order. The last bucket is +Inf. -func NewHistogram(buckets []float64) *Histogram { - if !slices.IsSorted(buckets) { - panic("buckets must be sorted") - } - labels := make([]string, len(buckets)) - for i, b := range buckets { - labels[i] = fmt.Sprintf("%v", b) - } - h := &Histogram{ - buckets: buckets, - bucketStrings: labels, - bucketVars: make([]expvar.Int, len(buckets)), - } - return h -} - -// Observe records a new observation in the histogram. -func (h *Histogram) Observe(v float64) { - h.sum.Add(v) - h.count.Add(1) - for i, b := range h.buckets { - if v <= b { - h.bucketVars[i].Add(1) - } - } -} - -// String returns a JSON representation of the histogram. -// This is used to satisfy the expvar.Var interface. -func (h *Histogram) String() string { - var b strings.Builder - fmt.Fprintf(&b, "{") - first := true - h.Do(func(kv expvar.KeyValue) { - if !first { - fmt.Fprintf(&b, ",") - } - fmt.Fprintf(&b, "%q: ", kv.Key) - if kv.Value != nil { - fmt.Fprintf(&b, "%v", kv.Value) - } else { - fmt.Fprint(&b, "null") - } - first = false - }) - fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) - fmt.Fprintf(&b, ",\"count\": %v", &h.count) - fmt.Fprintf(&b, "}") - return b.String() -} - -// Do calls f for each bucket in the histogram. -func (h *Histogram) Do(f func(expvar.KeyValue)) { - for i := range h.bucketVars { - f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) - } - f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) -} - -// PromExport writes the histogram to w in Prometheus exposition format. -func (h *Histogram) PromExport(w io.Writer, name string) { - fmt.Fprintf(w, "# TYPE %s histogram\n", name) - h.Do(func(kv expvar.KeyValue) { - fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) - }) - fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) - fmt.Fprintf(w, "%s_count %v\n", name, &h.count) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package metrics contains expvar & Prometheus types and code used by +// Tailscale for monitoring. +package metrics + +import ( + "expvar" + "fmt" + "io" + "slices" + "strings" +) + +// Set is a string-to-Var map variable that satisfies the expvar.Var +// interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of unrelated variables exported with a common prefix. +// +// This lets us have tsweb recognize *expvar.Map for different +// purposes in the future. (Or perhaps all uses of expvar.Map will +// require explicit types like this one, declaring how we want tsweb +// to export it to Prometheus.) +type Set struct { + expvar.Map +} + +// LabelMap is a string-to-Var map variable that satisfies the +// expvar.Var interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of variables with the same name, with a varying label +// value. Use this to export things that are intuitively breakdowns +// into different buckets. +type LabelMap struct { + Label string + expvar.Map +} + +// SetInt64 sets the *Int value stored under the given map key. +func (m *LabelMap) SetInt64(key string, v int64) { + m.Get(key).Set(v) +} + +// Get returns a direct pointer to the expvar.Int for key, creating it +// if necessary. +func (m *LabelMap) Get(key string) *expvar.Int { + m.Add(key, 0) + return m.Map.Get(key).(*expvar.Int) +} + +// GetIncrFunc returns a function that increments the expvar.Int named by key. +// +// Most callers should not need this; it exists to satisfy an +// interface elsewhere. +func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { + return m.Get(key).Add +} + +// GetFloat returns a direct pointer to the expvar.Float for key, creating it +// if necessary. +func (m *LabelMap) GetFloat(key string) *expvar.Float { + m.AddFloat(key, 0.0) + return m.Map.Get(key).(*expvar.Float) +} + +// CurrentFDs reports how many file descriptors are currently open. +// +// It only works on Linux. It returns zero otherwise. +func CurrentFDs() int { + return currentFDs() +} + +// Histogram is a histogram of values. +// It should be created with NewHistogram. +type Histogram struct { + // buckets is a list of bucket boundaries, in increasing order. + buckets []float64 + + // bucketStrings is a list of the same buckets, but as strings. + // This are allocated once at creation time by NewHistogram. + bucketStrings []string + + bucketVars []expvar.Int + sum expvar.Float + count expvar.Int +} + +// NewHistogram returns a new histogram that reports to the given +// expvar map under the given name. +// +// The buckets are the boundaries of the histogram buckets, in +// increasing order. The last bucket is +Inf. +func NewHistogram(buckets []float64) *Histogram { + if !slices.IsSorted(buckets) { + panic("buckets must be sorted") + } + labels := make([]string, len(buckets)) + for i, b := range buckets { + labels[i] = fmt.Sprintf("%v", b) + } + h := &Histogram{ + buckets: buckets, + bucketStrings: labels, + bucketVars: make([]expvar.Int, len(buckets)), + } + return h +} + +// Observe records a new observation in the histogram. +func (h *Histogram) Observe(v float64) { + h.sum.Add(v) + h.count.Add(1) + for i, b := range h.buckets { + if v <= b { + h.bucketVars[i].Add(1) + } + } +} + +// String returns a JSON representation of the histogram. +// This is used to satisfy the expvar.Var interface. +func (h *Histogram) String() string { + var b strings.Builder + fmt.Fprintf(&b, "{") + first := true + h.Do(func(kv expvar.KeyValue) { + if !first { + fmt.Fprintf(&b, ",") + } + fmt.Fprintf(&b, "%q: ", kv.Key) + if kv.Value != nil { + fmt.Fprintf(&b, "%v", kv.Value) + } else { + fmt.Fprint(&b, "null") + } + first = false + }) + fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) + fmt.Fprintf(&b, ",\"count\": %v", &h.count) + fmt.Fprintf(&b, "}") + return b.String() +} + +// Do calls f for each bucket in the histogram. +func (h *Histogram) Do(f func(expvar.KeyValue)) { + for i := range h.bucketVars { + f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) + } + f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) +} + +// PromExport writes the histogram to w in Prometheus exposition format. +func (h *Histogram) PromExport(w io.Writer, name string) { + fmt.Fprintf(w, "# TYPE %s histogram\n", name) + h.Do(func(kv expvar.KeyValue) { + fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) + }) + fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) + fmt.Fprintf(w, "%s_count %v\n", name, &h.count) +} diff --git a/net/art/art_test.go b/net/art/art_test.go index daf8553ca020d..e3a427107e69b 100644 --- a/net/art/art_test.go +++ b/net/art/art_test.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package art - -import ( - "os" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestMain(m *testing.M) { - if cibuild.On() { - // Skip CI on GitHub for now - // TODO: https://github.com/tailscale/tailscale/issues/7866 - os.Exit(0) - } - os.Exit(m.Run()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package art + +import ( + "os" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestMain(m *testing.M) { + if cibuild.On() { + // Skip CI on GitHub for now + // TODO: https://github.com/tailscale/tailscale/issues/7866 + os.Exit(0) + } + os.Exit(m.Run()) +} diff --git a/net/art/table.go b/net/art/table.go index fa397577868a8..2e130d82f78a1 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -1,641 +1,641 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package art provides a routing table that implements the Allotment Routing -// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi -// Hariguchi. -// -// ART outperforms the traditional radix tree implementations for route lookups, -// insertions, and deletions. -// -// For more information, see Yoichi Hariguchi's paper: -// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf -package art - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "math/bits" - "net/netip" - "strings" - "sync" -) - -const ( - debugInsert = false - debugDelete = false -) - -// Table is an IPv4 and IPv6 routing table. -type Table[T any] struct { - v4 strideTable[T] - v6 strideTable[T] - initOnce sync.Once -} - -func (t *Table[T]) init() { - t.initOnce.Do(func() { - t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - }) -} - -func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { - if addr.Is6() { - return &t.v6 - } - return &t.v4 -} - -// Get does a route lookup for addr and returns the associated value, or nil if -// no route matched. -func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { - t.init() - - // Ideally we would use addr.AsSlice here, but AsSlice is just - // barely complex enough that it can't be inlined, and that in - // turn causes the slice to escape to the heap. Using As16 and - // manual slicing here helps the compiler keep Get alloc-free. - st := t.tableForAddr(addr) - rawAddr := addr.As16() - bs := rawAddr[:] - if addr.Is4() { - bs = bs[12:] - } - - i := 0 - // With path compression, we might skip over some address bits while walking - // to a strideTable leaf. This means the leaf answer we find might not be - // correct, because path compression took us down the wrong subtree. When - // that happens, we have to backtrack and figure out which most specific - // route further up the tree is relevant to addr, and return that. - // - // So, as we walk down the stride tables, each time we find a non-nil route - // result, we have to remember it and the associated strideTable prefix. - // - // We could also deal with this edge case of path compression by checking - // the strideTable prefix on each table as we descend, but that means we - // have to pay N prefix.Contains checks on every route lookup (where N is - // the number of strideTables in the path), rather than only paying M prefix - // comparisons in the edge case (where M is the number of strideTables in - // the path with a non-nil route of their own). - const maxDepth = 16 - type prefixAndRoute struct { - prefix netip.Prefix - route T - } - strideMatch := make([]prefixAndRoute, 0, maxDepth) -findLeaf: - for { - rt, rtOK, child := st.getValAndChild(bs[i]) - if rtOK { - // This strideTable contains a route that may be relevant to our - // search, remember it. - strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) - } - if child == nil { - // No sub-routes further down, the last thing we recorded - // in strideRoutes is tentatively the result, barring - // misdirection from path compression. - break findLeaf - } - st = child - // Path compression means we may be skipping over some intermediate - // tables. We have to skip forward to whatever depth st now references. - i = st.prefix.Bits() / 8 - } - - // Walk backwards through the hits we recorded in strideRoutes and - // stridePrefixes, returning the first one whose subtree matches addr. - // - // In the common case where path compression did not mislead us, we'll - // return on the first loop iteration because the last route we recorded was - // the correct most-specific route. - for i := len(strideMatch) - 1; i >= 0; i-- { - if m := strideMatch[i]; m.prefix.Contains(addr) { - return m.route, true - } - } - - // We either found no route hits at all (both previous loops terminated - // immediately), or we went on a wild goose chase down a compressed path for - // the wrong prefix, and also found no usable routes on the way back up to - // the root. This is a miss. - return ret, false -} - -// Insert adds pfx to the table, with value val. -// If pfx is already present in the table, its value is set to val. -func (t *Table[T]) Insert(pfx netip.Prefix, val T) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugInsert { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ninsert: start pfx=%s\n", pfx) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches that boil down - // to the fact that pfx.Bits() has (2^n)+1 values, rather than - // just 2^n. For example, an IPv4 prefix length can be 0 through - // 32, which is 33 values. - // - // This extra possible value creates a lot of problems as we do - // bits and bytes math to traverse strideTables below. So, we - // treat the default route 0/0 specially here, that way the rest - // of the logic goes back to having 2^n values to reason about, - // which can be done in a nice and regular fashion with no edge - // cases. - if pfx.Bits() == 0 { - if debugInsert { - fmt.Printf("insert: default route\n") - } - st.insert(0, 0, val) - return - } - - // No matter what we do as we traverse strideTables, our final - // action will be to insert the last 1-8 bits of pfx into a - // strideTable somewhere. - // - // We calculate upfront the byte position of the end of the - // prefix; the number of bits within that byte that contain prefix - // data; and the prefix of the strideTable into which we'll - // eventually insert. - // - // We need this in a couple different branches of the code below, - // and because the possible values are 1-indexed (1 through 32 for - // ipv4, 1 through 128 for ipv6), the math is very slightly - // unusual to account for the off-by-one indexing. Do it once up - // here, with this large comment, rather than reproduce the subtle - // math in multiple places further down. - finalByteIdx := (pfx.Bits() - 1) / 8 - finalBits := pfx.Bits() - (finalByteIdx * 8) - finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) - if err != nil { - panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) - } - if debugInsert { - fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) - } - - // The strideTable we want to insert into is potentially at the - // end of a chain of strideTables, each one encoding 8 bits of the - // prefix. - // - // We're expecting to walk down a path of tables, although with - // prefix compression we may end up skipping some links in the - // chain, or taking wrong turns and having to course correct. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for { - if debugInsert { - fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - if numBits <= 8 { - if debugInsert { - fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - // We've reached the end of the prefix, whichever - // strideTable we're looking at now is the place where we - // need to insert. - st.insert(bs[finalByteIdx], finalBits, val) - return - } - - // Otherwise, we need to go down at least one more level of - // strideTables. With prefix compression, each level of - // descent can have one of three outcomes: we find a place - // where prefix compression is possible; a place where prefix - // compression made us take a "wrong turn"; or a point along - // our intended path that we have to keep following. - child, created := st.getOrCreateChild(bs[byteIdx]) - switch { - case created: - // The subtree we need for pfx doesn't exist yet. The rest - // of the path, if we were to create it, will consist of a - // bunch of strideTables with a single child each. We can - // use path compression to elide those intermediates, and - // jump straight to the final strideTable that hosts this - // prefix. - child.prefix = finalStridePrefix - child.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) - } - return - case !prefixStrictlyContains(child.prefix, pfx): - // child already exists, but its prefix does not contain - // our destination. This means that the path between st - // and child was compressed by a previous insertion, and - // somewhere in the (implicit) compressed path we took a - // wrong turn, into the wrong part of st's subtree. - // - // This is okay, because pfx and child.prefix must have a - // common ancestor node somewhere between st and child. We - // can figure out what node that is, and materialize it. - // - // Once we've done that, we can immediately complete the - // remainder of the insertion in one of two ways, without - // further traversal. See a little further down for what - // those are. - if debugInsert { - fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) - } - intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) - intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? - st.setChild(bs[byteIdx], intermediate) - intermediate.setChild(addrOfExisting, child) - - if debugInsert { - fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) - } - - // Now, we have a chain of st -> intermediate -> child. - // - // pfx either lives in a different child of intermediate, - // or in intermediate itself. For example, if we created - // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have - // to go into a new child of intermediate, but - // pfx=1.2.0.0/18 would go into intermediate directly. - if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { - // pfx lives in intermediate. - if debugInsert { - fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) - } - intermediate.insert(bs[finalByteIdx], finalBits, val) - } else { - // pfx lives in a different child subtree of - // intermediate. By definition this subtree doesn't - // exist at all, otherwise we'd never have entered - // this entire "wrong turn" codepath in the first - // place. - // - // This means we can apply prefix compression as we - // create this new child, and we're done. - st, created = intermediate.getOrCreateChild(addrOfNew) - if !created { - panic("new child path unexpectedly exists during path decompression") - } - st.prefix = finalStridePrefix - st.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - } - - return - default: - // An expected child table exists along pfx's - // path. Continue traversing downwards. - st = child - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - if debugInsert { - fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) - } - } - } -} - -// Delete removes pfx from the table, if it is present. -func (t *Table[T]) Delete(pfx netip.Prefix) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugDelete { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches, just like - // Insert. See the comment in Insert for more details. Bottom - // line: we handle the default route as a special case, and that - // simplifies the rest of the code slightly. - if pfx.Bits() == 0 { - if debugDelete { - fmt.Printf("delete: default route\n") - } - st.delete(0, 0) - return - } - - // Deletion may drive the refcount of some strideTables down to - // zero. We need to clean up these dangling tables, so we have to - // keep track of which tables we touch on the way down, and which - // strideEntry index each child is registered in. - // - // Note that the strideIndex and strideTables entries are off-by-one. - // The child table pointer is recorded at i+1, but it is referenced by a - // particular index in the parent table, at index i. - // - // In other words: entry number strideIndexes[0] in - // strideTables[0] is the same pointer as strideTables[1]. - // - // This results in some slightly odd array accesses further down - // in this code, because in a single loop iteration we have to - // write to strideTables[N] and strideIndexes[N-1]. - strideIdx := 0 - strideTables := [16]*strideTable[T]{st} - strideIndexes := [15]uint8{} - - // Similar to Insert, navigate down the tree of strideTables, - // looking for the one that houses this prefix. This part is - // easier than with insertion, since we can bail if the path ends - // early or takes an unexpected detour. However, unlike - // insertion, there's a whole post-deletion cleanup phase later - // on. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for numBits > 8 { - if debugDelete { - fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - child := st.getChild(bs[byteIdx]) - if child == nil { - // Prefix can't exist in the table, because one of the - // necessary strideTables doesn't exist. - if debugDelete { - fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) - } - return - } - strideIndexes[strideIdx] = bs[byteIdx] - strideTables[strideIdx+1] = child - strideIdx++ - - // Path compression means byteIdx can jump forwards - // unpredictably. Recompute the next byte to look at from the - // child we just found. - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - st = child - - if debugDelete { - fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) - } - } - - // We reached a leaf stride table that seems to be in the right - // spot. But path compression might have led us to the wrong - // table. - if !prefixStrictlyContains(st.prefix, pfx) { - // Wrong table, the requested prefix can't exist since its - // path led us to the wrong place. - if debugDelete { - fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) - } - return - } - if debugDelete { - fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) - } - if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { - // We're in the right strideTable, but pfx wasn't in - // it. Refcounts haven't changed, so we can skip cleanup. - if debugDelete { - fmt.Printf("delete: prefix not present pfx=%s\n", pfx) - } - return - } - - // st.delete reduced st's refcount by one. This table may now be - // reclaimable, and depending on how we can reclaim it, the parent - // tables may also need to be reclaimed. This loop ends as soon as - // an iteration takes no action, or takes an action that doesn't - // alter the parent table's refcounts. - // - // We start our walk back at strideTables[strideIdx], which - // contains st. - for strideIdx > 0 { - cur := strideTables[strideIdx] - if debugDelete { - fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) - } - if cur.routeRefs > 0 { - // the strideTable has other route entries, it cannot be - // deleted or compacted. - if debugDelete { - fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) - } - return - } - switch cur.childRefs { - case 0: - // no routeRefs and no childRefs, this table can be - // deleted. This will alter the parent table's refcount, - // so we'll have to look at it as well (in the next loop - // iteration). - if debugDelete { - fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) - } - strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) - strideIdx-- - case 1: - // This table has no routes, and a single child. Compact - // this table out of existence by making the parent point - // directly at the one child. This does not affect the - // parent's refcounts, so the parent can't be eligible for - // deletion or compaction, and we can stop. - child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition - parent := strideTables[strideIdx-1] - if debugDelete { - fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) - } - strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) - return - default: - // This table has two or more children, so it's acting as a "fork in - // the road" between two prefix subtrees. It cannot be deleted, and - // thus no further cleanups are possible. - if debugDelete { - fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) - } - return - } - } -} - -// debugSummary prints the tree of allocated strideTables in t, with each -// strideTable's refcount. -func (t *Table[T]) debugSummary() string { - t.init() - var ret bytes.Buffer - fmt.Fprintf(&ret, "v4: ") - strideSummary(&ret, &t.v4, 4) - fmt.Fprintf(&ret, "v6: ") - strideSummary(&ret, &t.v6, 4) - return ret.String() -} - -func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { - fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) - indent += 4 - st.treeDebugStringRec(w, 1, indent) - for addr, child := range st.children { - if child == nil { - continue - } - fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) - strideSummary(w, child, indent) - } -} - -// prefixStrictlyContains reports whether child is a prefix within -// parent, but not parent itself. -func prefixStrictlyContains(parent, child netip.Prefix) bool { - return parent.Overlaps(child) && parent.Bits() < child.Bits() -} - -// computePrefixSplit returns the smallest common prefix that contains -// both a and b. lastCommon is 8-bit aligned, with aStride and bStride -// indicating the value of the 8-bit stride immediately following -// lastCommon. -// -// computePrefixSplit is used in constructing an intermediate -// strideTable when a new prefix needs to be inserted in a compressed -// table. It can be read as: given that a is already in the table, and -// b is being inserted, what is the prefix of the new intermediate -// strideTable that needs to be created, and at what addresses in that -// new strideTable should a and b's subsequent strideTables be -// attached? -// -// Note as a special case, this can be called with a==b. An example of -// when this happens: -// - We want to insert the prefix 1.2.0.0/16 -// - A strideTable exists for 1.2.0.0/16, because another child -// prefix already exists (e.g. 1.2.3.4/32) -// - The 1.0.0.0/8 strideTable does not exist, because path -// compression removed it. -// -// In this scenario, the caller of computePrefixSplit ends up making a -// "wrong turn" while traversing strideTables: it was looking for the -// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this -// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), -// and we return 1.0.0.0/8 as the missing intermediate. -func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { - a = a.Masked() - b = b.Masked() - if a.Bits() == 0 || b.Bits() == 0 { - panic("computePrefixSplit called with a default route") - } - if a.Addr().Is4() != b.Addr().Is4() { - panic("computePrefixSplit called with mismatched address families") - } - - minPrefixLen := a.Bits() - if b.Bits() < minPrefixLen { - minPrefixLen = b.Bits() - } - - commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) - // We want to know how many 8-bit strides are shared between a and - // b. Naively, this would be commonBits/8, but this introduces an - // off-by-one error. This is due to the way our ART stores - // prefixes whose length falls exactly on a stride boundary. - // - // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits - // correctly reports that these prefixes have their first 16 bits - // in common. However, in the ART they only share 1 common stride: - // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 - // is stored as 168/8 within that table, and not as 0/0 in the - // 192.168.0.0/16 table. - // - // So, when commonBits matches the length of one of the inputs and - // falls on a boundary between strides, the strideTable one - // further up from commonBits/8 is the one we need to create, - // which means we have to adjust the stride count down by one. - if commonBits == minPrefixLen { - commonBits-- - } - commonStrides := commonBits / 8 - lastCommon, err := a.Addr().Prefix(commonStrides * 8) - if err != nil { - panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) - } - if a.Addr().Is4() { - aStride = a.Addr().As4()[commonStrides] - bStride = b.Addr().As4()[commonStrides] - } else { - aStride = a.Addr().As16()[commonStrides] - bStride = b.Addr().As16()[commonStrides] - } - return lastCommon, aStride, bStride -} - -// commonBits returns the number of common leading bits of a and b. -// If the number of common bits exceeds maxBits, it returns maxBits -// instead. -func commonBits(a, b netip.Addr, maxBits int) int { - if a.Is4() != b.Is4() { - panic("commonStrides called with mismatched address families") - } - var common int - // The following implements an old bit-twiddling trick to compute - // the number of common leading bits: if you XOR two numbers - // together, equal bits become 0 and unequal bits become 1. You - // can then count the number of leading zeros (which is a single - // instruction on modern CPUs) to get the answer. - // - // This code is a little more complex than just XOR + count - // leading zeros, because IPv4 and IPv6 are different sizes, and - // for IPv6 we have to do the math in two 64-bit chunks because Go - // lacks a uint128 type. - if a.Is4() { - aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) - common = bits.LeadingZeros32(aNum ^ bNum) - } else { - aNumHi, aNumLo := ipv6AsUint(a) - bNumHi, bNumLo := ipv6AsUint(b) - common = bits.LeadingZeros64(aNumHi ^ bNumHi) - if common == 64 { - common += bits.LeadingZeros64(aNumLo ^ bNumLo) - } - } - if common > maxBits { - common = maxBits - } - return common -} - -// ipv4AsUint returns ip as a uint32. -func ipv4AsUint(ip netip.Addr) uint32 { - bs := ip.As4() - return binary.BigEndian.Uint32(bs[:]) -} - -// ipv6AsUint returns ip as a pair of uint64s. -func ipv6AsUint(ip netip.Addr) (uint64, uint64) { - bs := ip.As16() - return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package art provides a routing table that implements the Allotment Routing +// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi +// Hariguchi. +// +// ART outperforms the traditional radix tree implementations for route lookups, +// insertions, and deletions. +// +// For more information, see Yoichi Hariguchi's paper: +// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf +package art + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/bits" + "net/netip" + "strings" + "sync" +) + +const ( + debugInsert = false + debugDelete = false +) + +// Table is an IPv4 and IPv6 routing table. +type Table[T any] struct { + v4 strideTable[T] + v6 strideTable[T] + initOnce sync.Once +} + +func (t *Table[T]) init() { + t.initOnce.Do(func() { + t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + }) +} + +func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { + if addr.Is6() { + return &t.v6 + } + return &t.v4 +} + +// Get does a route lookup for addr and returns the associated value, or nil if +// no route matched. +func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { + t.init() + + // Ideally we would use addr.AsSlice here, but AsSlice is just + // barely complex enough that it can't be inlined, and that in + // turn causes the slice to escape to the heap. Using As16 and + // manual slicing here helps the compiler keep Get alloc-free. + st := t.tableForAddr(addr) + rawAddr := addr.As16() + bs := rawAddr[:] + if addr.Is4() { + bs = bs[12:] + } + + i := 0 + // With path compression, we might skip over some address bits while walking + // to a strideTable leaf. This means the leaf answer we find might not be + // correct, because path compression took us down the wrong subtree. When + // that happens, we have to backtrack and figure out which most specific + // route further up the tree is relevant to addr, and return that. + // + // So, as we walk down the stride tables, each time we find a non-nil route + // result, we have to remember it and the associated strideTable prefix. + // + // We could also deal with this edge case of path compression by checking + // the strideTable prefix on each table as we descend, but that means we + // have to pay N prefix.Contains checks on every route lookup (where N is + // the number of strideTables in the path), rather than only paying M prefix + // comparisons in the edge case (where M is the number of strideTables in + // the path with a non-nil route of their own). + const maxDepth = 16 + type prefixAndRoute struct { + prefix netip.Prefix + route T + } + strideMatch := make([]prefixAndRoute, 0, maxDepth) +findLeaf: + for { + rt, rtOK, child := st.getValAndChild(bs[i]) + if rtOK { + // This strideTable contains a route that may be relevant to our + // search, remember it. + strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) + } + if child == nil { + // No sub-routes further down, the last thing we recorded + // in strideRoutes is tentatively the result, barring + // misdirection from path compression. + break findLeaf + } + st = child + // Path compression means we may be skipping over some intermediate + // tables. We have to skip forward to whatever depth st now references. + i = st.prefix.Bits() / 8 + } + + // Walk backwards through the hits we recorded in strideRoutes and + // stridePrefixes, returning the first one whose subtree matches addr. + // + // In the common case where path compression did not mislead us, we'll + // return on the first loop iteration because the last route we recorded was + // the correct most-specific route. + for i := len(strideMatch) - 1; i >= 0; i-- { + if m := strideMatch[i]; m.prefix.Contains(addr) { + return m.route, true + } + } + + // We either found no route hits at all (both previous loops terminated + // immediately), or we went on a wild goose chase down a compressed path for + // the wrong prefix, and also found no usable routes on the way back up to + // the root. This is a miss. + return ret, false +} + +// Insert adds pfx to the table, with value val. +// If pfx is already present in the table, its value is set to val. +func (t *Table[T]) Insert(pfx netip.Prefix, val T) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugInsert { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ninsert: start pfx=%s\n", pfx) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches that boil down + // to the fact that pfx.Bits() has (2^n)+1 values, rather than + // just 2^n. For example, an IPv4 prefix length can be 0 through + // 32, which is 33 values. + // + // This extra possible value creates a lot of problems as we do + // bits and bytes math to traverse strideTables below. So, we + // treat the default route 0/0 specially here, that way the rest + // of the logic goes back to having 2^n values to reason about, + // which can be done in a nice and regular fashion with no edge + // cases. + if pfx.Bits() == 0 { + if debugInsert { + fmt.Printf("insert: default route\n") + } + st.insert(0, 0, val) + return + } + + // No matter what we do as we traverse strideTables, our final + // action will be to insert the last 1-8 bits of pfx into a + // strideTable somewhere. + // + // We calculate upfront the byte position of the end of the + // prefix; the number of bits within that byte that contain prefix + // data; and the prefix of the strideTable into which we'll + // eventually insert. + // + // We need this in a couple different branches of the code below, + // and because the possible values are 1-indexed (1 through 32 for + // ipv4, 1 through 128 for ipv6), the math is very slightly + // unusual to account for the off-by-one indexing. Do it once up + // here, with this large comment, rather than reproduce the subtle + // math in multiple places further down. + finalByteIdx := (pfx.Bits() - 1) / 8 + finalBits := pfx.Bits() - (finalByteIdx * 8) + finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) + if err != nil { + panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) + } + if debugInsert { + fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) + } + + // The strideTable we want to insert into is potentially at the + // end of a chain of strideTables, each one encoding 8 bits of the + // prefix. + // + // We're expecting to walk down a path of tables, although with + // prefix compression we may end up skipping some links in the + // chain, or taking wrong turns and having to course correct. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for { + if debugInsert { + fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + if numBits <= 8 { + if debugInsert { + fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + // We've reached the end of the prefix, whichever + // strideTable we're looking at now is the place where we + // need to insert. + st.insert(bs[finalByteIdx], finalBits, val) + return + } + + // Otherwise, we need to go down at least one more level of + // strideTables. With prefix compression, each level of + // descent can have one of three outcomes: we find a place + // where prefix compression is possible; a place where prefix + // compression made us take a "wrong turn"; or a point along + // our intended path that we have to keep following. + child, created := st.getOrCreateChild(bs[byteIdx]) + switch { + case created: + // The subtree we need for pfx doesn't exist yet. The rest + // of the path, if we were to create it, will consist of a + // bunch of strideTables with a single child each. We can + // use path compression to elide those intermediates, and + // jump straight to the final strideTable that hosts this + // prefix. + child.prefix = finalStridePrefix + child.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) + } + return + case !prefixStrictlyContains(child.prefix, pfx): + // child already exists, but its prefix does not contain + // our destination. This means that the path between st + // and child was compressed by a previous insertion, and + // somewhere in the (implicit) compressed path we took a + // wrong turn, into the wrong part of st's subtree. + // + // This is okay, because pfx and child.prefix must have a + // common ancestor node somewhere between st and child. We + // can figure out what node that is, and materialize it. + // + // Once we've done that, we can immediately complete the + // remainder of the insertion in one of two ways, without + // further traversal. See a little further down for what + // those are. + if debugInsert { + fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) + } + intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) + intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? + st.setChild(bs[byteIdx], intermediate) + intermediate.setChild(addrOfExisting, child) + + if debugInsert { + fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) + } + + // Now, we have a chain of st -> intermediate -> child. + // + // pfx either lives in a different child of intermediate, + // or in intermediate itself. For example, if we created + // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have + // to go into a new child of intermediate, but + // pfx=1.2.0.0/18 would go into intermediate directly. + if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { + // pfx lives in intermediate. + if debugInsert { + fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) + } + intermediate.insert(bs[finalByteIdx], finalBits, val) + } else { + // pfx lives in a different child subtree of + // intermediate. By definition this subtree doesn't + // exist at all, otherwise we'd never have entered + // this entire "wrong turn" codepath in the first + // place. + // + // This means we can apply prefix compression as we + // create this new child, and we're done. + st, created = intermediate.getOrCreateChild(addrOfNew) + if !created { + panic("new child path unexpectedly exists during path decompression") + } + st.prefix = finalStridePrefix + st.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + } + + return + default: + // An expected child table exists along pfx's + // path. Continue traversing downwards. + st = child + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + if debugInsert { + fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) + } + } + } +} + +// Delete removes pfx from the table, if it is present. +func (t *Table[T]) Delete(pfx netip.Prefix) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugDelete { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches, just like + // Insert. See the comment in Insert for more details. Bottom + // line: we handle the default route as a special case, and that + // simplifies the rest of the code slightly. + if pfx.Bits() == 0 { + if debugDelete { + fmt.Printf("delete: default route\n") + } + st.delete(0, 0) + return + } + + // Deletion may drive the refcount of some strideTables down to + // zero. We need to clean up these dangling tables, so we have to + // keep track of which tables we touch on the way down, and which + // strideEntry index each child is registered in. + // + // Note that the strideIndex and strideTables entries are off-by-one. + // The child table pointer is recorded at i+1, but it is referenced by a + // particular index in the parent table, at index i. + // + // In other words: entry number strideIndexes[0] in + // strideTables[0] is the same pointer as strideTables[1]. + // + // This results in some slightly odd array accesses further down + // in this code, because in a single loop iteration we have to + // write to strideTables[N] and strideIndexes[N-1]. + strideIdx := 0 + strideTables := [16]*strideTable[T]{st} + strideIndexes := [15]uint8{} + + // Similar to Insert, navigate down the tree of strideTables, + // looking for the one that houses this prefix. This part is + // easier than with insertion, since we can bail if the path ends + // early or takes an unexpected detour. However, unlike + // insertion, there's a whole post-deletion cleanup phase later + // on. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for numBits > 8 { + if debugDelete { + fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + child := st.getChild(bs[byteIdx]) + if child == nil { + // Prefix can't exist in the table, because one of the + // necessary strideTables doesn't exist. + if debugDelete { + fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) + } + return + } + strideIndexes[strideIdx] = bs[byteIdx] + strideTables[strideIdx+1] = child + strideIdx++ + + // Path compression means byteIdx can jump forwards + // unpredictably. Recompute the next byte to look at from the + // child we just found. + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + st = child + + if debugDelete { + fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) + } + } + + // We reached a leaf stride table that seems to be in the right + // spot. But path compression might have led us to the wrong + // table. + if !prefixStrictlyContains(st.prefix, pfx) { + // Wrong table, the requested prefix can't exist since its + // path led us to the wrong place. + if debugDelete { + fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) + } + return + } + if debugDelete { + fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) + } + if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { + // We're in the right strideTable, but pfx wasn't in + // it. Refcounts haven't changed, so we can skip cleanup. + if debugDelete { + fmt.Printf("delete: prefix not present pfx=%s\n", pfx) + } + return + } + + // st.delete reduced st's refcount by one. This table may now be + // reclaimable, and depending on how we can reclaim it, the parent + // tables may also need to be reclaimed. This loop ends as soon as + // an iteration takes no action, or takes an action that doesn't + // alter the parent table's refcounts. + // + // We start our walk back at strideTables[strideIdx], which + // contains st. + for strideIdx > 0 { + cur := strideTables[strideIdx] + if debugDelete { + fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) + } + if cur.routeRefs > 0 { + // the strideTable has other route entries, it cannot be + // deleted or compacted. + if debugDelete { + fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) + } + return + } + switch cur.childRefs { + case 0: + // no routeRefs and no childRefs, this table can be + // deleted. This will alter the parent table's refcount, + // so we'll have to look at it as well (in the next loop + // iteration). + if debugDelete { + fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) + } + strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) + strideIdx-- + case 1: + // This table has no routes, and a single child. Compact + // this table out of existence by making the parent point + // directly at the one child. This does not affect the + // parent's refcounts, so the parent can't be eligible for + // deletion or compaction, and we can stop. + child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition + parent := strideTables[strideIdx-1] + if debugDelete { + fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) + } + strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) + return + default: + // This table has two or more children, so it's acting as a "fork in + // the road" between two prefix subtrees. It cannot be deleted, and + // thus no further cleanups are possible. + if debugDelete { + fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) + } + return + } + } +} + +// debugSummary prints the tree of allocated strideTables in t, with each +// strideTable's refcount. +func (t *Table[T]) debugSummary() string { + t.init() + var ret bytes.Buffer + fmt.Fprintf(&ret, "v4: ") + strideSummary(&ret, &t.v4, 4) + fmt.Fprintf(&ret, "v6: ") + strideSummary(&ret, &t.v6, 4) + return ret.String() +} + +func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { + fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) + indent += 4 + st.treeDebugStringRec(w, 1, indent) + for addr, child := range st.children { + if child == nil { + continue + } + fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) + strideSummary(w, child, indent) + } +} + +// prefixStrictlyContains reports whether child is a prefix within +// parent, but not parent itself. +func prefixStrictlyContains(parent, child netip.Prefix) bool { + return parent.Overlaps(child) && parent.Bits() < child.Bits() +} + +// computePrefixSplit returns the smallest common prefix that contains +// both a and b. lastCommon is 8-bit aligned, with aStride and bStride +// indicating the value of the 8-bit stride immediately following +// lastCommon. +// +// computePrefixSplit is used in constructing an intermediate +// strideTable when a new prefix needs to be inserted in a compressed +// table. It can be read as: given that a is already in the table, and +// b is being inserted, what is the prefix of the new intermediate +// strideTable that needs to be created, and at what addresses in that +// new strideTable should a and b's subsequent strideTables be +// attached? +// +// Note as a special case, this can be called with a==b. An example of +// when this happens: +// - We want to insert the prefix 1.2.0.0/16 +// - A strideTable exists for 1.2.0.0/16, because another child +// prefix already exists (e.g. 1.2.3.4/32) +// - The 1.0.0.0/8 strideTable does not exist, because path +// compression removed it. +// +// In this scenario, the caller of computePrefixSplit ends up making a +// "wrong turn" while traversing strideTables: it was looking for the +// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this +// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), +// and we return 1.0.0.0/8 as the missing intermediate. +func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { + a = a.Masked() + b = b.Masked() + if a.Bits() == 0 || b.Bits() == 0 { + panic("computePrefixSplit called with a default route") + } + if a.Addr().Is4() != b.Addr().Is4() { + panic("computePrefixSplit called with mismatched address families") + } + + minPrefixLen := a.Bits() + if b.Bits() < minPrefixLen { + minPrefixLen = b.Bits() + } + + commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) + // We want to know how many 8-bit strides are shared between a and + // b. Naively, this would be commonBits/8, but this introduces an + // off-by-one error. This is due to the way our ART stores + // prefixes whose length falls exactly on a stride boundary. + // + // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits + // correctly reports that these prefixes have their first 16 bits + // in common. However, in the ART they only share 1 common stride: + // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 + // is stored as 168/8 within that table, and not as 0/0 in the + // 192.168.0.0/16 table. + // + // So, when commonBits matches the length of one of the inputs and + // falls on a boundary between strides, the strideTable one + // further up from commonBits/8 is the one we need to create, + // which means we have to adjust the stride count down by one. + if commonBits == minPrefixLen { + commonBits-- + } + commonStrides := commonBits / 8 + lastCommon, err := a.Addr().Prefix(commonStrides * 8) + if err != nil { + panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) + } + if a.Addr().Is4() { + aStride = a.Addr().As4()[commonStrides] + bStride = b.Addr().As4()[commonStrides] + } else { + aStride = a.Addr().As16()[commonStrides] + bStride = b.Addr().As16()[commonStrides] + } + return lastCommon, aStride, bStride +} + +// commonBits returns the number of common leading bits of a and b. +// If the number of common bits exceeds maxBits, it returns maxBits +// instead. +func commonBits(a, b netip.Addr, maxBits int) int { + if a.Is4() != b.Is4() { + panic("commonStrides called with mismatched address families") + } + var common int + // The following implements an old bit-twiddling trick to compute + // the number of common leading bits: if you XOR two numbers + // together, equal bits become 0 and unequal bits become 1. You + // can then count the number of leading zeros (which is a single + // instruction on modern CPUs) to get the answer. + // + // This code is a little more complex than just XOR + count + // leading zeros, because IPv4 and IPv6 are different sizes, and + // for IPv6 we have to do the math in two 64-bit chunks because Go + // lacks a uint128 type. + if a.Is4() { + aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) + common = bits.LeadingZeros32(aNum ^ bNum) + } else { + aNumHi, aNumLo := ipv6AsUint(a) + bNumHi, bNumLo := ipv6AsUint(b) + common = bits.LeadingZeros64(aNumHi ^ bNumHi) + if common == 64 { + common += bits.LeadingZeros64(aNumLo ^ bNumLo) + } + } + if common > maxBits { + common = maxBits + } + return common +} + +// ipv4AsUint returns ip as a uint32. +func ipv4AsUint(ip netip.Addr) uint32 { + bs := ip.As4() + return binary.BigEndian.Uint32(bs[:]) +} + +// ipv6AsUint returns ip as a pair of uint64s. +func ipv6AsUint(ip netip.Addr) (uint64, uint64) { + bs := ip.As16() + return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) +} diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 3ffc796e06d1b..2a1fb18de967f 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bufio" - "bytes" - _ "embed" - "fmt" - "os" - "os/exec" - "path/filepath" - - "tailscale.com/atomicfile" - "tailscale.com/types/logger" -) - -//go:embed resolvconf-workaround.sh -var workaroundScript []byte - -// resolvconfConfigName is the name of the config submitted to -// resolvconf. -// The name starts with 'tun' in order to match the hardcoded -// interface order in debian resolvconf, which will place this -// configuration ahead of regular network links. In theory, this -// doesn't matter because we then fix things up to ensure our config -// is the only one in use, but in case that fails, this will make our -// configuration slightly preferred. -// The 'inet' suffix has no specific meaning, but conventionally -// resolvconf implementations encourage adding a suffix roughly -// indicating where the config came from, and "inet" is the "none of -// the above" value (rather than, say, "ppp" or "dhcp"). -const resolvconfConfigName = "tun-tailscale.inet" - -// resolvconfLibcHookPath is the directory containing libc update -// scripts, which are run by Debian resolvconf when /etc/resolv.conf -// has been updated. -const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" - -// resolvconfHookPath is the name of the libc hook script we install -// to force Tailscale's DNS config to take effect. -var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") - -// resolvconfManager manages DNS configuration using the Debian -// implementation of the `resolvconf` program, written by Thomas Hood. -type resolvconfManager struct { - logf logger.Logf - listRecordsPath string - interfacesDir string - scriptInstalled bool // libc update script has been installed -} - -func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { - ret := &resolvconfManager{ - logf: logf, - listRecordsPath: "/lib/resolvconf/list-records", - interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work - } - - if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { - // This might be a Debian system from before the big /usr - // merge, try /usr instead. - ret.listRecordsPath = "/usr" + ret.listRecordsPath - } - // The runtime directory is currently (2020-04) canonically - // /etc/resolvconf/run, but the manpage is making noise about - // switching to /run/resolvconf and dropping the /etc path. So, - // let's probe the possible directories and use the first one - // that works. - for _, path := range []string{ - "/etc/resolvconf/run/interface", - "/run/resolvconf/interface", - "/var/run/resolvconf/interface", - } { - if _, err := os.Stat(path); err == nil { - ret.interfacesDir = path - break - } - } - if ret.interfacesDir == "" { - // None of the paths seem to work, use the canonical location - // that the current manpage says to use. - ret.interfacesDir = "/etc/resolvconf/run/interfaces" - } - - return ret, nil -} - -func (m *resolvconfManager) deleteTailscaleConfig() error { - cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - return nil -} - -func (m *resolvconfManager) SetDNS(config OSConfig) error { - if !m.scriptInstalled { - m.logf("injecting resolvconf workaround script") - if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { - return err - } - if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { - return err - } - m.scriptInstalled = true - } - - if config.IsZero() { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - } else { - stdin := new(bytes.Buffer) - writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go - - // This resolvconf implementation doesn't support exclusive - // mode or interface priorities, so it will end up blending - // our configuration with other sources. However, this will - // get fixed up by the script we injected above. - cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) - cmd.Stdin = stdin - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - } - - return nil -} - -func (m *resolvconfManager) SupportsSplitDNS() bool { - return false -} - -func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { - var bs bytes.Buffer - - cmd := exec.Command(m.listRecordsPath) - // list-records assumes it's being run with CWD set to the - // interfaces runtime dir, and returns nonsense otherwise. - cmd.Dir = m.interfacesDir - cmd.Stdout = &bs - if err := cmd.Run(); err != nil { - return OSConfig{}, err - } - - var conf bytes.Buffer - sc := bufio.NewScanner(&bs) - for sc.Scan() { - if sc.Text() == resolvconfConfigName { - continue - } - bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) - if err != nil { - if os.IsNotExist(err) { - // Probably raced with a deletion, that's okay. - continue - } - return OSConfig{}, err - } - conf.Write(bs) - conf.WriteByte('\n') - } - - return readResolv(&conf) -} - -func (m *resolvconfManager) Close() error { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - - if m.scriptInstalled { - m.logf("removing resolvconf workaround script") - os.Remove(resolvconfHookPath) // Best-effort - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bufio" + "bytes" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + + "tailscale.com/atomicfile" + "tailscale.com/types/logger" +) + +//go:embed resolvconf-workaround.sh +var workaroundScript []byte + +// resolvconfConfigName is the name of the config submitted to +// resolvconf. +// The name starts with 'tun' in order to match the hardcoded +// interface order in debian resolvconf, which will place this +// configuration ahead of regular network links. In theory, this +// doesn't matter because we then fix things up to ensure our config +// is the only one in use, but in case that fails, this will make our +// configuration slightly preferred. +// The 'inet' suffix has no specific meaning, but conventionally +// resolvconf implementations encourage adding a suffix roughly +// indicating where the config came from, and "inet" is the "none of +// the above" value (rather than, say, "ppp" or "dhcp"). +const resolvconfConfigName = "tun-tailscale.inet" + +// resolvconfLibcHookPath is the directory containing libc update +// scripts, which are run by Debian resolvconf when /etc/resolv.conf +// has been updated. +const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" + +// resolvconfHookPath is the name of the libc hook script we install +// to force Tailscale's DNS config to take effect. +var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") + +// resolvconfManager manages DNS configuration using the Debian +// implementation of the `resolvconf` program, written by Thomas Hood. +type resolvconfManager struct { + logf logger.Logf + listRecordsPath string + interfacesDir string + scriptInstalled bool // libc update script has been installed +} + +func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { + ret := &resolvconfManager{ + logf: logf, + listRecordsPath: "/lib/resolvconf/list-records", + interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work + } + + if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { + // This might be a Debian system from before the big /usr + // merge, try /usr instead. + ret.listRecordsPath = "/usr" + ret.listRecordsPath + } + // The runtime directory is currently (2020-04) canonically + // /etc/resolvconf/run, but the manpage is making noise about + // switching to /run/resolvconf and dropping the /etc path. So, + // let's probe the possible directories and use the first one + // that works. + for _, path := range []string{ + "/etc/resolvconf/run/interface", + "/run/resolvconf/interface", + "/var/run/resolvconf/interface", + } { + if _, err := os.Stat(path); err == nil { + ret.interfacesDir = path + break + } + } + if ret.interfacesDir == "" { + // None of the paths seem to work, use the canonical location + // that the current manpage says to use. + ret.interfacesDir = "/etc/resolvconf/run/interfaces" + } + + return ret, nil +} + +func (m *resolvconfManager) deleteTailscaleConfig() error { + cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + return nil +} + +func (m *resolvconfManager) SetDNS(config OSConfig) error { + if !m.scriptInstalled { + m.logf("injecting resolvconf workaround script") + if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { + return err + } + if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { + return err + } + m.scriptInstalled = true + } + + if config.IsZero() { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + } else { + stdin := new(bytes.Buffer) + writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go + + // This resolvconf implementation doesn't support exclusive + // mode or interface priorities, so it will end up blending + // our configuration with other sources. However, this will + // get fixed up by the script we injected above. + cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) + cmd.Stdin = stdin + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + } + + return nil +} + +func (m *resolvconfManager) SupportsSplitDNS() bool { + return false +} + +func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { + var bs bytes.Buffer + + cmd := exec.Command(m.listRecordsPath) + // list-records assumes it's being run with CWD set to the + // interfaces runtime dir, and returns nonsense otherwise. + cmd.Dir = m.interfacesDir + cmd.Stdout = &bs + if err := cmd.Run(); err != nil { + return OSConfig{}, err + } + + var conf bytes.Buffer + sc := bufio.NewScanner(&bs) + for sc.Scan() { + if sc.Text() == resolvconfConfigName { + continue + } + bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) + if err != nil { + if os.IsNotExist(err) { + // Probably raced with a deletion, that's okay. + continue + } + return OSConfig{}, err + } + conf.Write(bs) + conf.WriteByte('\n') + } + + return readResolv(&conf) +} + +func (m *resolvconfManager) Close() error { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + + if m.scriptInstalled { + m.logf("removing resolvconf workaround script") + os.Remove(resolvconfHookPath) // Best-effort + } + + return nil +} diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go index c221ca1beaa59..5bd8093d65b7b 100644 --- a/net/dns/direct_notlinux.go +++ b/net/dns/direct_notlinux.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package dns - -func (m *directManager) runFileWatcher() { - // Not implemented on other platforms. Maybe it could resort to polling. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package dns + +func (m *directManager) runFileWatcher() { + // Not implemented on other platforms. Maybe it could resort to polling. +} diff --git a/net/dns/flush_default.go b/net/dns/flush_default.go index eb6d9da417104..73e446389e2c7 100644 --- a/net/dns/flush_default.go +++ b/net/dns/flush_default.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package dns - -func flushCaches() error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package dns + +func flushCaches() error { + return nil +} diff --git a/net/dns/ini.go b/net/dns/ini.go index 1e47d606e970f..deec04019560f 100644 --- a/net/dns/ini.go +++ b/net/dns/ini.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "regexp" - "strings" -) - -// parseIni parses a basic .ini file, used for wsl.conf. -func parseIni(data string) map[string]map[string]string { - sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) - kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) - - ini := map[string]map[string]string{} - var section string - for _, line := range strings.Split(data, "\n") { - if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { - section = res[1] - ini[section] = map[string]string{} - } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { - k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) - ini[section][k] = v - } - } - return ini -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "regexp" + "strings" +) + +// parseIni parses a basic .ini file, used for wsl.conf. +func parseIni(data string) map[string]map[string]string { + sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) + kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) + + ini := map[string]map[string]string{} + var section string + for _, line := range strings.Split(data, "\n") { + if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { + section = res[1] + ini[section] = map[string]string{} + } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { + k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) + ini[section][k] = v + } + } + return ini +} diff --git a/net/dns/ini_test.go b/net/dns/ini_test.go index 3afe7009caa27..0e9eaa6727bbe 100644 --- a/net/dns/ini_test.go +++ b/net/dns/ini_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "reflect" - "testing" -) - -func TestParseIni(t *testing.T) { - var tests = []struct { - src string - want map[string]map[string]string - }{ - { - src: `# appended wsl.conf file -[automount] - enabled = true - root=/mnt/ -# added by tailscale -[network] # trailing comment -generateResolvConf = false # trailing comment`, - want: map[string]map[string]string{ - "automount": {"enabled": "true", "root": "/mnt/"}, - "network": {"generateResolvConf": "false"}, - }, - }, - } - for _, test := range tests { - got := parseIni(test.src) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "reflect" + "testing" +) + +func TestParseIni(t *testing.T) { + var tests = []struct { + src string + want map[string]map[string]string + }{ + { + src: `# appended wsl.conf file +[automount] + enabled = true + root=/mnt/ +# added by tailscale +[network] # trailing comment +generateResolvConf = false # trailing comment`, + want: map[string]map[string]string{ + "automount": {"enabled": "true", "root": "/mnt/"}, + "network": {"generateResolvConf": "false"}, + }, + }, + } + for _, test := range tests { + got := parseIni(test.src) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) + } + } +} diff --git a/net/dns/noop.go b/net/dns/noop.go index 9466b57a0f477..c90162668e85d 100644 --- a/net/dns/noop.go +++ b/net/dns/noop.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -type noopManager struct{} - -func (m noopManager) SetDNS(OSConfig) error { return nil } -func (m noopManager) SupportsSplitDNS() bool { return false } -func (m noopManager) Close() error { return nil } -func (m noopManager) GetBaseConfig() (OSConfig, error) { - return OSConfig{}, ErrGetBaseConfigNotSupported -} - -func NewNoopManager() (noopManager, error) { - return noopManager{}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +type noopManager struct{} + +func (m noopManager) SetDNS(OSConfig) error { return nil } +func (m noopManager) SupportsSplitDNS() bool { return false } +func (m noopManager) Close() error { return nil } +func (m noopManager) GetBaseConfig() (OSConfig, error) { + return OSConfig{}, ErrGetBaseConfigNotSupported +} + +func NewNoopManager() (noopManager, error) { + return noopManager{}, nil +} diff --git a/net/dns/resolvconf-workaround.sh b/net/dns/resolvconf-workaround.sh index aec6708a06da1..254b3949b1930 100644 --- a/net/dns/resolvconf-workaround.sh +++ b/net/dns/resolvconf-workaround.sh @@ -1,62 +1,62 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -# -# This script is a workaround for a vpn-unfriendly behavior of the -# original resolvconf by Thomas Hood. Unlike the `openresolv` -# implementation (whose binary is also called resolvconf, -# confusingly), the original resolvconf lacks a way to specify -# "exclusive mode" for a provider configuration. In practice, this -# means that if Tailscale wants to install a DNS configuration, that -# config will get "blended" with the configs from other sources, -# rather than override those other sources. -# -# This script gets installed at /etc/resolvconf/update-libc.d, which -# is a directory of hook scripts that get run after resolvconf's libc -# helper has finished rewriting /etc/resolv.conf. It's meant to notify -# consumers of resolv.conf of a new configuration. -# -# Instead, we use that hook mechanism to reach into resolvconf's -# stuff, and rewrite the libc-generated resolv.conf to exclusively -# contain Tailscale's configuration - effectively implementing -# exclusive mode ourselves in post-production. - -set -e - -if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then - # Hook script being invoked by itself, skip. - exit 0 -fi - -if [ ! -f tun-tailscale.inet ]; then - # Tailscale isn't trying to manage DNS, do nothing. - exit 0 -fi - -if ! grep resolvconf /etc/resolv.conf >/dev/null; then - # resolvconf isn't managing /etc/resolv.conf, do nothing. - exit 0 -fi - -# Write out a modified /etc/resolv.conf containing just our config. -( - if [ -f /etc/resolvconf/resolv.conf.d/head ]; then - cat /etc/resolvconf/resolv.conf.d/head - fi - echo "# Tailscale workaround applied to set exclusive DNS configuration." - cat tun-tailscale.inet - if [ -f /etc/resolvconf/resolv.conf.d/base ]; then - # Keep options and sortlist, discard other base things since - # they're the things we're trying to override. - grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true - fi - if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then - cat /etc/resolvconf/resolv.conf.d/tail - fi -) >/etc/resolv.conf - -if [ -d /etc/resolvconf/update-libc.d ] ; then - # Re-notify libc watchers that we've changed resolv.conf again. - export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 - exec run-parts /etc/resolvconf/update-libc.d -fi +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +# +# This script is a workaround for a vpn-unfriendly behavior of the +# original resolvconf by Thomas Hood. Unlike the `openresolv` +# implementation (whose binary is also called resolvconf, +# confusingly), the original resolvconf lacks a way to specify +# "exclusive mode" for a provider configuration. In practice, this +# means that if Tailscale wants to install a DNS configuration, that +# config will get "blended" with the configs from other sources, +# rather than override those other sources. +# +# This script gets installed at /etc/resolvconf/update-libc.d, which +# is a directory of hook scripts that get run after resolvconf's libc +# helper has finished rewriting /etc/resolv.conf. It's meant to notify +# consumers of resolv.conf of a new configuration. +# +# Instead, we use that hook mechanism to reach into resolvconf's +# stuff, and rewrite the libc-generated resolv.conf to exclusively +# contain Tailscale's configuration - effectively implementing +# exclusive mode ourselves in post-production. + +set -e + +if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then + # Hook script being invoked by itself, skip. + exit 0 +fi + +if [ ! -f tun-tailscale.inet ]; then + # Tailscale isn't trying to manage DNS, do nothing. + exit 0 +fi + +if ! grep resolvconf /etc/resolv.conf >/dev/null; then + # resolvconf isn't managing /etc/resolv.conf, do nothing. + exit 0 +fi + +# Write out a modified /etc/resolv.conf containing just our config. +( + if [ -f /etc/resolvconf/resolv.conf.d/head ]; then + cat /etc/resolvconf/resolv.conf.d/head + fi + echo "# Tailscale workaround applied to set exclusive DNS configuration." + cat tun-tailscale.inet + if [ -f /etc/resolvconf/resolv.conf.d/base ]; then + # Keep options and sortlist, discard other base things since + # they're the things we're trying to override. + grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true + fi + if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then + cat /etc/resolvconf/resolv.conf.d/tail + fi +) >/etc/resolv.conf + +if [ -d /etc/resolvconf/update-libc.d ] ; then + # Re-notify libc watchers that we've changed resolv.conf again. + export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 + exec run-parts /etc/resolvconf/update-libc.d +fi diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go index ca584ffcc5f1f..9e2a41c4ac45b 100644 --- a/net/dns/resolvconf.go +++ b/net/dns/resolvconf.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bytes" - "os/exec" -) - -func resolvconfStyle() string { - if _, err := exec.LookPath("resolvconf"); err != nil { - return "" - } - output, err := exec.Command("resolvconf", "--version").CombinedOutput() - if err != nil { - // Debian resolvconf doesn't understand --version, and - // exits with a specific error code. - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { - return "debian" - } - } - if bytes.HasPrefix(output, []byte("Debian resolvconf")) { - return "debian" - } - // Treat everything else as openresolv, by far the more popular implementation. - return "openresolv" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bytes" + "os/exec" +) + +func resolvconfStyle() string { + if _, err := exec.LookPath("resolvconf"); err != nil { + return "" + } + output, err := exec.Command("resolvconf", "--version").CombinedOutput() + if err != nil { + // Debian resolvconf doesn't understand --version, and + // exits with a specific error code. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { + return "debian" + } + } + if bytes.HasPrefix(output, []byte("Debian resolvconf")) { + return "debian" + } + // Treat everything else as openresolv, by far the more popular implementation. + return "openresolv" +} diff --git a/net/dns/resolvconffile/resolvconffile.go b/net/dns/resolvconffile/resolvconffile.go index 753000f6d33da..66c1600d8ecba 100644 --- a/net/dns/resolvconffile/resolvconffile.go +++ b/net/dns/resolvconffile/resolvconffile.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package resolvconffile parses & serializes /etc/resolv.conf-style files. -// -// It's a leaf package so both net/dns and net/dns/resolver can depend -// on it and we can unify a handful of implementations. -// -// The package is verbosely named to disambiguate it from resolvconf -// the daemon, which Tailscale also supports. -package resolvconffile - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "os" - "strings" - - "tailscale.com/util/dnsname" -) - -// Path is the canonical location of resolv.conf. -const Path = "/etc/resolv.conf" - -// Config represents a resolv.conf(5) file. -type Config struct { - // Nameservers are the IP addresses of the nameservers to use. - Nameservers []netip.Addr - - // SearchDomains are the domain suffixes to use when expanding - // single-label name queries. SearchDomains is additive to - // whatever non-Tailscale search domains the OS has. - SearchDomains []dnsname.FQDN -} - -// Write writes c to w. It does so in one Write call. -func (c *Config) Write(w io.Writer) error { - buf := new(bytes.Buffer) - io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") - io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") - io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") - for _, ns := range c.Nameservers { - io.WriteString(buf, "nameserver ") - io.WriteString(buf, ns.String()) - io.WriteString(buf, "\n") - } - if len(c.SearchDomains) > 0 { - io.WriteString(buf, "search") - for _, domain := range c.SearchDomains { - io.WriteString(buf, " ") - io.WriteString(buf, domain.WithoutTrailingDot()) - } - io.WriteString(buf, "\n") - } - _, err := w.Write(buf.Bytes()) - return err -} - -// Parse parses a resolv.conf file from r. -func Parse(r io.Reader) (*Config, error) { - config := new(Config) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - line, _, _ = strings.Cut(line, "#") // remove any comments - line = strings.TrimSpace(line) - - if s, ok := strings.CutPrefix(line, "nameserver"); ok { - nameserver := strings.TrimSpace(s) - if len(nameserver) == len(s) { - return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) - } - ip, err := netip.ParseAddr(nameserver) - if err != nil { - return nil, err - } - config.Nameservers = append(config.Nameservers, ip) - continue - } - - if s, ok := strings.CutPrefix(line, "search"); ok { - domains := strings.TrimSpace(s) - if len(domains) == len(s) { - // No leading space?! - return nil, fmt.Errorf("missing space after \"search\" in %q", line) - } - for len(domains) > 0 { - domain := domains - i := strings.IndexAny(domain, " \t") - if i != -1 { - domain = domain[:i] - domains = strings.TrimSpace(domains[i+1:]) - } else { - domains = "" - } - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) - } - config.SearchDomains = append(config.SearchDomains, fqdn) - } - } - } - return config, nil -} - -// ParseFile parses the named resolv.conf file. -func ParseFile(name string) (*Config, error) { - fi, err := os.Stat(name) - if err != nil { - return nil, err - } - if n := fi.Size(); n > 10<<10 { - return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) - } - all, err := os.ReadFile(name) - if err != nil { - return nil, err - } - return Parse(bytes.NewReader(all)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package resolvconffile parses & serializes /etc/resolv.conf-style files. +// +// It's a leaf package so both net/dns and net/dns/resolver can depend +// on it and we can unify a handful of implementations. +// +// The package is verbosely named to disambiguate it from resolvconf +// the daemon, which Tailscale also supports. +package resolvconffile + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "os" + "strings" + + "tailscale.com/util/dnsname" +) + +// Path is the canonical location of resolv.conf. +const Path = "/etc/resolv.conf" + +// Config represents a resolv.conf(5) file. +type Config struct { + // Nameservers are the IP addresses of the nameservers to use. + Nameservers []netip.Addr + + // SearchDomains are the domain suffixes to use when expanding + // single-label name queries. SearchDomains is additive to + // whatever non-Tailscale search domains the OS has. + SearchDomains []dnsname.FQDN +} + +// Write writes c to w. It does so in one Write call. +func (c *Config) Write(w io.Writer) error { + buf := new(bytes.Buffer) + io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") + io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") + io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") + for _, ns := range c.Nameservers { + io.WriteString(buf, "nameserver ") + io.WriteString(buf, ns.String()) + io.WriteString(buf, "\n") + } + if len(c.SearchDomains) > 0 { + io.WriteString(buf, "search") + for _, domain := range c.SearchDomains { + io.WriteString(buf, " ") + io.WriteString(buf, domain.WithoutTrailingDot()) + } + io.WriteString(buf, "\n") + } + _, err := w.Write(buf.Bytes()) + return err +} + +// Parse parses a resolv.conf file from r. +func Parse(r io.Reader) (*Config, error) { + config := new(Config) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + line, _, _ = strings.Cut(line, "#") // remove any comments + line = strings.TrimSpace(line) + + if s, ok := strings.CutPrefix(line, "nameserver"); ok { + nameserver := strings.TrimSpace(s) + if len(nameserver) == len(s) { + return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) + } + ip, err := netip.ParseAddr(nameserver) + if err != nil { + return nil, err + } + config.Nameservers = append(config.Nameservers, ip) + continue + } + + if s, ok := strings.CutPrefix(line, "search"); ok { + domains := strings.TrimSpace(s) + if len(domains) == len(s) { + // No leading space?! + return nil, fmt.Errorf("missing space after \"search\" in %q", line) + } + for len(domains) > 0 { + domain := domains + i := strings.IndexAny(domain, " \t") + if i != -1 { + domain = domain[:i] + domains = strings.TrimSpace(domains[i+1:]) + } else { + domains = "" + } + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) + } + config.SearchDomains = append(config.SearchDomains, fqdn) + } + } + } + return config, nil +} + +// ParseFile parses the named resolv.conf file. +func ParseFile(name string) (*Config, error) { + fi, err := os.Stat(name) + if err != nil { + return nil, err + } + if n := fi.Size(); n > 10<<10 { + return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) + } + all, err := os.ReadFile(name) + if err != nil { + return nil, err + } + return Parse(bytes.NewReader(all)) +} diff --git a/net/dns/resolvconfpath_default.go b/net/dns/resolvconfpath_default.go index 57e82c4c773ea..02f24a0cfa535 100644 --- a/net/dns/resolvconfpath_default.go +++ b/net/dns/resolvconfpath_default.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !gokrazy - -package dns - -const ( - resolvConf = "/etc/resolv.conf" - backupConf = "/etc/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !gokrazy + +package dns + +const ( + resolvConf = "/etc/resolv.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolvconfpath_gokrazy.go b/net/dns/resolvconfpath_gokrazy.go index f0759b0e31a0f..6315596d20efa 100644 --- a/net/dns/resolvconfpath_gokrazy.go +++ b/net/dns/resolvconfpath_gokrazy.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build gokrazy - -package dns - -const ( - resolvConf = "/tmp/resolv.conf" - backupConf = "/tmp/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build gokrazy + +package dns + +const ( + resolvConf = "/tmp/resolv.conf" + backupConf = "/tmp/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go index a9c28476166fc..d9ef970c224f2 100644 --- a/net/dns/resolver/doh_test.go +++ b/net/dns/resolver/doh_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "context" - "flag" - "net/http" - "testing" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/net/dns/publicdns" -) - -var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") - -const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 - -func someDNSQuestion(t testing.TB) []byte { - b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - OpCode: 0, // query - RecursionDesired: true, - ID: someDNSID, - }) - b.StartQuestions() // err - b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName("tailscale.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - msg, err := b.Finish() - if err != nil { - t.Fatal(err) - } - return msg -} - -func TestDoH(t *testing.T) { - if !*testDoH { - t.Skip("skipping manual test without --test-doh flag") - } - prefixes := publicdns.KnownDoHPrefixes() - if len(prefixes) == 0 { - t.Fatal("no known DoH") - } - - f := &forwarder{} - - for _, urlBase := range prefixes { - t.Run(urlBase, func(t *testing.T) { - c, ok := f.getKnownDoHClientForProvider(urlBase) - if !ok { - t.Fatal("expected DoH") - } - res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) - if err != nil { - t.Fatal(err) - } - c.Transport.(*http.Transport).CloseIdleConnections() - - var p dnsmessage.Parser - h, err := p.Start(res) - if err != nil { - t.Fatal(err) - } - if h.ID != someDNSID { - t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) - } - - p.SkipAllQuestions() - aa, err := p.AllAnswers() - if err != nil { - t.Fatal(err) - } - if len(aa) == 0 { - t.Fatal("no answers") - } - for _, r := range aa { - t.Logf("got: %v", r.GoString()) - } - }) - } -} - -func TestDoHV6Fallback(t *testing.T) { - for _, base := range publicdns.KnownDoHPrefixes() { - for _, ip := range publicdns.DoHIPsOfBase(base) { - if ip.Is4() { - ip6, ok := publicdns.DoHV6(base) - if !ok { - t.Errorf("no v6 DoH known for %v", ip) - } else if !ip6.Is6() { - t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) - } - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "context" + "flag" + "net/http" + "testing" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/net/dns/publicdns" +) + +var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") + +const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 + +func someDNSQuestion(t testing.TB) []byte { + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ + OpCode: 0, // query + RecursionDesired: true, + ID: someDNSID, + }) + b.StartQuestions() // err + b.Question(dnsmessage.Question{ + Name: dnsmessage.MustNewName("tailscale.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + msg, err := b.Finish() + if err != nil { + t.Fatal(err) + } + return msg +} + +func TestDoH(t *testing.T) { + if !*testDoH { + t.Skip("skipping manual test without --test-doh flag") + } + prefixes := publicdns.KnownDoHPrefixes() + if len(prefixes) == 0 { + t.Fatal("no known DoH") + } + + f := &forwarder{} + + for _, urlBase := range prefixes { + t.Run(urlBase, func(t *testing.T) { + c, ok := f.getKnownDoHClientForProvider(urlBase) + if !ok { + t.Fatal("expected DoH") + } + res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) + if err != nil { + t.Fatal(err) + } + c.Transport.(*http.Transport).CloseIdleConnections() + + var p dnsmessage.Parser + h, err := p.Start(res) + if err != nil { + t.Fatal(err) + } + if h.ID != someDNSID { + t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) + } + + p.SkipAllQuestions() + aa, err := p.AllAnswers() + if err != nil { + t.Fatal(err) + } + if len(aa) == 0 { + t.Fatal("no answers") + } + for _, r := range aa { + t.Logf("got: %v", r.GoString()) + } + }) + } +} + +func TestDoHV6Fallback(t *testing.T) { + for _, base := range publicdns.KnownDoHPrefixes() { + for _, ip := range publicdns.DoHIPsOfBase(base) { + if ip.Is4() { + ip6, ok := publicdns.DoHV6(base) + if !ok { + t.Errorf("no v6 DoH known for %v", ip) + } else if !ip6.Is6() { + t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) + } + } + } + } +} diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go index e3f979c194d91..37cccc7f0c7ba 100644 --- a/net/dns/resolver/macios_ext.go +++ b/net/dns/resolver/macios_ext.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_macext && (darwin || ios) - -package resolver - -import ( - "errors" - "net" - - "tailscale.com/net/netmon" - "tailscale.com/net/netns" -) - -func init() { - initListenConfig = initListenConfigNetworkExtension -} - -func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { - nif, ok := netMon.InterfaceState().Interface[tunName] - if !ok { - return errors.New("utun not found") - } - return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_macext && (darwin || ios) + +package resolver + +import ( + "errors" + "net" + + "tailscale.com/net/netmon" + "tailscale.com/net/netns" +) + +func init() { + initListenConfig = initListenConfigNetworkExtension +} + +func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { + nif, ok := netMon.InterfaceState().Interface[tunName] + if !ok { + return errors.New("utun not found") + } + return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) +} diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index 82fd3bebf232c..be47cdfbcf913 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -1,333 +1,333 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "fmt" - "net" - "net/netip" - "strings" - "testing" - - "github.com/miekg/dns" -) - -// This file exists to isolate the test infrastructure -// that depends on github.com/miekg/dns -// from the rest, which only depends on dnsmessage. - -// resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToIPLowercase returns a handler function which canonicalizes responses -// by lowercasing the question and answer names, and responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.Question[0].Name = strings.ToLower(m.Question[0].Name) - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToTXT returns a handler function which responds to queries of type TXT -// it receives with the strings in txts. -func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - if question.Qtype != dns.TypeTXT { - w.WriteMsg(m) - return - } - - ans := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - Txt: txts, - } - - m.Answer = append(m.Answer, ans) - - queryInfo := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: "query-info.test.", - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - } - - if edns := req.IsEdns0(); edns == nil { - queryInfo.Txt = []string{"EDNS=false"} - } else { - queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} - } - - m.Extra = append(m.Extra, queryInfo) - - if ednsMaxSize > 0 { - m.SetEdns0(ednsMaxSize, false) - } - - if err := w.WriteMsg(m); err != nil { - panic(err) - } - } -} - -var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNameError) - w.WriteMsg(m) -}) - -// weirdoGoCNAMEHandler returns a DNS handler that satisfies -// Go's weird Resolver.LookupCNAME (read its godoc carefully!). -// -// This doesn't even return a CNAME record, because that's not -// what Go looks for. -func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - question := req.Question[0] - - switch question.Qtype { - case dns.TypeA: - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - }, - Target: target, - }) - case dns.TypeAAAA: - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 600, - }, - AAAA: net.ParseIP("1::2"), - }) - } - w.WriteMsg(m) - } -} - -// dnsHandler returns a handler that replies with the answers/options -// provided. -// -// Types supported: netip.Addr. -func dnsHandler(answers ...any) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies - - question := req.Question[0] - for _, a := range answers { - switch a := a.(type) { - default: - panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) - case netip.Addr: - ip := a - if ip.Is4() { - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ip.AsSlice(), - }) - } else if ip.Is6() { - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ip.AsSlice(), - }) - } - case dns.PTR: - ptr := a - ptr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &ptr) - case dns.CNAME: - c := a - c.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - } - m.Answer = append(m.Answer, &c) - case dns.TXT: - txt := a - txt.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &txt) - case dns.SRV: - srv := a - srv.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &srv) - case dns.NS: - rr := a - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &rr) - } - } - w.WriteMsg(m) - } -} - -func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { - if len(records)%2 != 0 { - panic("must have an even number of record values") - } - mux := dns.NewServeMux() - for i := 0; i < len(records); i += 2 { - name := records[i].(string) - handler := records[i+1].(dns.Handler) - mux.Handle(name, handler) - } - waitch := make(chan struct{}) - server := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - NotifyStartedFunc: func() { close(waitch) }, - ReusePort: true, - } - - go func() { - err := server.ListenAndServe() - if err != nil { - panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) - } - }() - - <-waitch - return server -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "fmt" + "net" + "net/netip" + "strings" + "testing" + + "github.com/miekg/dns" +) + +// This file exists to isolate the test infrastructure +// that depends on github.com/miekg/dns +// from the rest, which only depends on dnsmessage. + +// resolveToIP returns a handler function which responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToIPLowercase returns a handler function which canonicalizes responses +// by lowercasing the question and answer names, and responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.Question[0].Name = strings.ToLower(m.Question[0].Name) + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToTXT returns a handler function which responds to queries of type TXT +// it receives with the strings in txts. +func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + if question.Qtype != dns.TypeTXT { + w.WriteMsg(m) + return + } + + ans := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + Txt: txts, + } + + m.Answer = append(m.Answer, ans) + + queryInfo := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: "query-info.test.", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + } + + if edns := req.IsEdns0(); edns == nil { + queryInfo.Txt = []string{"EDNS=false"} + } else { + queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} + } + + m.Extra = append(m.Extra, queryInfo) + + if ednsMaxSize > 0 { + m.SetEdns0(ednsMaxSize, false) + } + + if err := w.WriteMsg(m); err != nil { + panic(err) + } + } +} + +var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeNameError) + w.WriteMsg(m) +}) + +// weirdoGoCNAMEHandler returns a DNS handler that satisfies +// Go's weird Resolver.LookupCNAME (read its godoc carefully!). +// +// This doesn't even return a CNAME record, because that's not +// what Go looks for. +func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + question := req.Question[0] + + switch question.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + }, + Target: target, + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 600, + }, + AAAA: net.ParseIP("1::2"), + }) + } + w.WriteMsg(m) + } +} + +// dnsHandler returns a handler that replies with the answers/options +// provided. +// +// Types supported: netip.Addr. +func dnsHandler(answers ...any) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies + + question := req.Question[0] + for _, a := range answers { + switch a := a.(type) { + default: + panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) + case netip.Addr: + ip := a + if ip.Is4() { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ip.AsSlice(), + }) + } else if ip.Is6() { + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ip.AsSlice(), + }) + } + case dns.PTR: + ptr := a + ptr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &ptr) + case dns.CNAME: + c := a + c.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + } + m.Answer = append(m.Answer, &c) + case dns.TXT: + txt := a + txt.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &txt) + case dns.SRV: + srv := a + srv.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &srv) + case dns.NS: + rr := a + rr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &rr) + } + } + w.WriteMsg(m) + } +} + +func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { + if len(records)%2 != 0 { + panic("must have an even number of record values") + } + mux := dns.NewServeMux() + for i := 0; i < len(records); i += 2 { + name := records[i].(string) + handler := records[i+1].(dns.Handler) + mux.Handle(name, handler) + } + waitch := make(chan struct{}) + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: mux, + NotifyStartedFunc: func() { close(waitch) }, + ReusePort: true, + } + + go func() { + err := server.ListenAndServe() + if err != nil { + panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) + } + }() + + <-waitch + return server +} diff --git a/net/dns/utf.go b/net/dns/utf.go index 0c1db69acb33b..267829c05fbfa 100644 --- a/net/dns/utf.go +++ b/net/dns/utf.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -// This code is only used in Windows builds, but is in an -// OS-independent file so tests can run all the time. - -import ( - "bytes" - "encoding/binary" - "unicode/utf16" -) - -// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so -// translates it to regular UTF-8. -// -// Some of wsl.exe's output get printed as UTF-16, which breaks a -// bunch of things. Try to detect this by looking for a zero byte in -// the first few bytes of output (which will appear if any of those -// codepoints are basic ASCII - very likely). From that we can infer -// that UTF-16 is being printed, and the byte order in use, and we -// decode that back to UTF-8. -// -// https://github.com/microsoft/WSL/issues/4607 -func maybeUnUTF16(bs []byte) []byte { - if len(bs)%2 != 0 { - // Can't be complete UTF-16. - return bs - } - checkLen := 20 - if len(bs) < checkLen { - checkLen = len(bs) - } - zeroOff := bytes.IndexByte(bs[:checkLen], 0) - if zeroOff == -1 { - return bs - } - - // We assume wsl.exe is trying to print an ASCII codepoint, - // meaning the zero byte is in the upper 8 bits of the - // codepoint. That means we can use the zero's byte offset to - // work out if we're seeing little-endian or big-endian - // UTF-16. - var endian binary.ByteOrder = binary.LittleEndian - if zeroOff%2 == 0 { - endian = binary.BigEndian - } - - var u16 []uint16 - for i := 0; i < len(bs); i += 2 { - u16 = append(u16, endian.Uint16(bs[i:])) - } - return []byte(string(utf16.Decode(u16))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +// This code is only used in Windows builds, but is in an +// OS-independent file so tests can run all the time. + +import ( + "bytes" + "encoding/binary" + "unicode/utf16" +) + +// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so +// translates it to regular UTF-8. +// +// Some of wsl.exe's output get printed as UTF-16, which breaks a +// bunch of things. Try to detect this by looking for a zero byte in +// the first few bytes of output (which will appear if any of those +// codepoints are basic ASCII - very likely). From that we can infer +// that UTF-16 is being printed, and the byte order in use, and we +// decode that back to UTF-8. +// +// https://github.com/microsoft/WSL/issues/4607 +func maybeUnUTF16(bs []byte) []byte { + if len(bs)%2 != 0 { + // Can't be complete UTF-16. + return bs + } + checkLen := 20 + if len(bs) < checkLen { + checkLen = len(bs) + } + zeroOff := bytes.IndexByte(bs[:checkLen], 0) + if zeroOff == -1 { + return bs + } + + // We assume wsl.exe is trying to print an ASCII codepoint, + // meaning the zero byte is in the upper 8 bits of the + // codepoint. That means we can use the zero's byte offset to + // work out if we're seeing little-endian or big-endian + // UTF-16. + var endian binary.ByteOrder = binary.LittleEndian + if zeroOff%2 == 0 { + endian = binary.BigEndian + } + + var u16 []uint16 + for i := 0; i < len(bs); i += 2 { + u16 = append(u16, endian.Uint16(bs[i:])) + } + return []byte(string(utf16.Decode(u16))) +} diff --git a/net/dns/utf_test.go b/net/dns/utf_test.go index b5fd372622519..fcf593497e08b 100644 --- a/net/dns/utf_test.go +++ b/net/dns/utf_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -import "testing" - -func TestMaybeUnUTF16(t *testing.T) { - tests := []struct { - in string - want string - }{ - {"abc", "abc"}, // UTF-8 - {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE - {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE - } - - for _, test := range tests { - got := string(maybeUnUTF16([]byte(test.in))) - if got != test.want { - t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import "testing" + +func TestMaybeUnUTF16(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"abc", "abc"}, // UTF-8 + {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE + {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE + } + + for _, test := range tests { + got := string(maybeUnUTF16([]byte(test.in))) + if got != test.want { + t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index ef4249b7401f3..6a4b969315050 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "net/netip" - "reflect" - "testing" - "time" - - "tailscale.com/tstest" -) - -var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") - -func TestDialer(t *testing.T) { - if *dialTest == "" { - t.Skip("skipping; --dial-test is blank") - } - r := &Resolver{Logf: t.Logf} - var std net.Dialer - dialer := Dialer(std.DialContext, r) - t0 := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - c, err := dialer(ctx, "tcp", *dialTest) - if err != nil { - t.Fatal(err) - } - t.Logf("dialed in %v", time.Since(t0)) - c.Close() -} - -func TestDialCall_DNSWasTrustworthy(t *testing.T) { - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - tests := []struct { - name string - steps []step - want bool - }{ - { - name: "no-info", - want: false, - }, - { - name: "previous-dial", - steps: []step{ - {mustIP("2003::1"), nil}, - {mustIP("2003::1"), errFail}, - }, - want: true, - }, - { - name: "no-previous-dial", - steps: []step{ - {mustIP("2003::1"), errFail}, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - dc := &dialCall{ - d: d, - } - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := dc.dnsWasTrustworthy() - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} - -func TestDialCall_uniqueIPs(t *testing.T) { - dc := &dialCall{} - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - dc.noteDialResult(mustIP("2003::1"), errFail) - dc.noteDialResult(mustIP("2003::2"), errFail) - got := dc.uniqueIPs([]netip.Addr{ - mustIP("2003::1"), - mustIP("2003::2"), - mustIP("2003::2"), - mustIP("2003::3"), - mustIP("2003::3"), - mustIP("2003::4"), - mustIP("2003::4"), - }) - want := []netip.Addr{ - mustIP("2003::3"), - mustIP("2003::4"), - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } -} - -func TestResolverAllHostStaticResult(t *testing.T) { - r := &Resolver{ - Logf: t.Logf, - SingleHost: "foo.bar", - SingleHostStaticResult: []netip.Addr{ - netip.MustParseAddr("2001:4860:4860::8888"), - netip.MustParseAddr("2001:4860:4860::8844"), - netip.MustParseAddr("8.8.8.8"), - netip.MustParseAddr("8.8.4.4"), - }, - } - ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") - if err != nil { - t.Fatal(err) - } - if got, want := ip4.String(), "8.8.8.8"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { - t.Errorf("allIPs got %q; want %q", got, want) - } - - _, _, _, err = r.LookupIP(context.Background(), "bad") - if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { - t.Errorf("bad dial error got %q; want %q", got, want) - } -} - -func TestShouldTryBootstrap(t *testing.T) { - tstest.Replace(t, &debug, func() bool { return true }) - - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - - canceled, cancel := context.WithCancel(context.Background()) - cancel() - - deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - - ctx := context.Background() - errFailed := errors.New("some failure") - - cacheWithFallback := &Resolver{ - Logf: t.Logf, - LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { - panic("unimplemented") - }, - } - cacheNoFallback := &Resolver{Logf: t.Logf} - - testCases := []struct { - name string - steps []step - ctx context.Context - err error - noFallback bool - want bool - }{ - { - name: "no-error", - ctx: ctx, - err: nil, - want: false, - }, - { - name: "canceled", - ctx: canceled, - err: errFailed, - want: false, - }, - { - name: "deadline-exceeded", - ctx: deadlineExceeded, - err: errFailed, - want: false, - }, - { - name: "no-fallback", - ctx: ctx, - err: errFailed, - noFallback: true, - want: false, - }, - { - name: "dns-was-trustworthy", - ctx: ctx, - err: errFailed, - steps: []step{ - {netip.MustParseAddr("2003::1"), nil}, - {netip.MustParseAddr("2003::1"), errFailed}, - }, - want: false, - }, - { - name: "should-bootstrap", - ctx: ctx, - err: errFailed, - want: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - if tt.noFallback { - d.dnsCache = cacheNoFallback - } else { - d.dnsCache = cacheWithFallback - } - dc := &dialCall{d: d} - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "net/netip" + "reflect" + "testing" + "time" + + "tailscale.com/tstest" +) + +var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") + +func TestDialer(t *testing.T) { + if *dialTest == "" { + t.Skip("skipping; --dial-test is blank") + } + r := &Resolver{Logf: t.Logf} + var std net.Dialer + dialer := Dialer(std.DialContext, r) + t0 := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + c, err := dialer(ctx, "tcp", *dialTest) + if err != nil { + t.Fatal(err) + } + t.Logf("dialed in %v", time.Since(t0)) + c.Close() +} + +func TestDialCall_DNSWasTrustworthy(t *testing.T) { + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + tests := []struct { + name string + steps []step + want bool + }{ + { + name: "no-info", + want: false, + }, + { + name: "previous-dial", + steps: []step{ + {mustIP("2003::1"), nil}, + {mustIP("2003::1"), errFail}, + }, + want: true, + }, + { + name: "no-previous-dial", + steps: []step{ + {mustIP("2003::1"), errFail}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + dc := &dialCall{ + d: d, + } + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := dc.dnsWasTrustworthy() + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func TestDialCall_uniqueIPs(t *testing.T) { + dc := &dialCall{} + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + dc.noteDialResult(mustIP("2003::1"), errFail) + dc.noteDialResult(mustIP("2003::2"), errFail) + got := dc.uniqueIPs([]netip.Addr{ + mustIP("2003::1"), + mustIP("2003::2"), + mustIP("2003::2"), + mustIP("2003::3"), + mustIP("2003::3"), + mustIP("2003::4"), + mustIP("2003::4"), + }) + want := []netip.Addr{ + mustIP("2003::3"), + mustIP("2003::4"), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + +func TestResolverAllHostStaticResult(t *testing.T) { + r := &Resolver{ + Logf: t.Logf, + SingleHost: "foo.bar", + SingleHostStaticResult: []netip.Addr{ + netip.MustParseAddr("2001:4860:4860::8888"), + netip.MustParseAddr("2001:4860:4860::8844"), + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("8.8.4.4"), + }, + } + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") + if err != nil { + t.Fatal(err) + } + if got, want := ip4.String(), "8.8.8.8"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { + t.Errorf("allIPs got %q; want %q", got, want) + } + + _, _, _, err = r.LookupIP(context.Background(), "bad") + if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { + t.Errorf("bad dial error got %q; want %q", got, want) + } +} + +func TestShouldTryBootstrap(t *testing.T) { + tstest.Replace(t, &debug, func() bool { return true }) + + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + + canceled, cancel := context.WithCancel(context.Background()) + cancel() + + deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + ctx := context.Background() + errFailed := errors.New("some failure") + + cacheWithFallback := &Resolver{ + Logf: t.Logf, + LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { + panic("unimplemented") + }, + } + cacheNoFallback := &Resolver{Logf: t.Logf} + + testCases := []struct { + name string + steps []step + ctx context.Context + err error + noFallback bool + want bool + }{ + { + name: "no-error", + ctx: ctx, + err: nil, + want: false, + }, + { + name: "canceled", + ctx: canceled, + err: errFailed, + want: false, + }, + { + name: "deadline-exceeded", + ctx: deadlineExceeded, + err: errFailed, + want: false, + }, + { + name: "no-fallback", + ctx: ctx, + err: errFailed, + noFallback: true, + want: false, + }, + { + name: "dns-was-trustworthy", + ctx: ctx, + err: errFailed, + steps: []step{ + {netip.MustParseAddr("2003::1"), nil}, + {netip.MustParseAddr("2003::1"), errFailed}, + }, + want: false, + }, + { + name: "should-bootstrap", + ctx: ctx, + err: errFailed, + want: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + if tt.noFallback { + d.dnsCache = cacheNoFallback + } else { + d.dnsCache = cacheWithFallback + } + dc := &dialCall{d: d} + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go index 41fc334483f78..18af324597a43 100644 --- a/net/dnscache/messagecache_test.go +++ b/net/dnscache/messagecache_test.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "runtime" - "testing" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/tstest" -) - -func TestMessageCache(t *testing.T) { - clock := tstest.NewClock(tstest.ClockOpts{ - Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), - }) - mc := &MessageCache{Clock: clock.Now} - mc.SetMaxCacheSize(2) - clock.Advance(time.Second) - - var out bytes.Buffer - if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("unexpected error: %v", err) - } - - if err := mc.AddCacheEntry( - makeQ(2, "foo.com."), - makeRes(2, "FOO.COM.", ttlOpt(10), - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { - t.Fatal(err) - } - - // Expect cache hit, with 10 seconds remaining. - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { - t.Errorf("TxID = %v; want %v", p.TxID, 3) - } else if p.TTL != 10 { - t.Errorf("TTL = %v; want 10", p.TTL) - } - - // One second elapses, expect a cache hit, with 9 seconds - // remaining. - clock.Advance(time.Second) - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { - t.Errorf("TxID = %v; want %v", p.TxID, 4) - } else if p.TTL != 9 { - t.Errorf("TTL = %v; want 9", p.TTL) - } - - // Expect cache miss on MX record. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on MX; got: %v", err) - } - // Expect cache miss on CHAOS class. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on CHAOS; got: %v", err) - } - - // Ten seconds elapses; expect a cache miss. - clock.Advance(10 * time.Second) - if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("expected cache miss, got: %v", err) - } -} - -type parsedMeta struct { - TxID uint16 - TTL uint32 -} - -func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { - t.Helper() - var p dnsmessage.Parser - h, err := p.Start(r) - if err != nil { - t.Fatal(err) - } - ret.TxID = h.ID - qq, err := p.AllQuestions() - if err != nil { - t.Fatalf("AllQuestions: %v", err) - } - if len(qq) != 1 { - t.Fatalf("num questions = %v; want 1", len(qq)) - } - aa, err := p.AllAnswers() - if err != nil { - t.Fatalf("AllAnswers: %v", err) - } - for _, r := range aa { - if ret.TTL == 0 { - ret.TTL = r.Header.TTL - } - if ret.TTL != r.Header.TTL { - t.Fatal("mixed TTLs") - } - } - return ret -} - -type responseOpt bool - -type ttlOpt uint32 - -func makeQ(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(false)) - return makeDNSPkt(txID, name, opt...) -} - -func makeRes(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(true)) - return makeDNSPkt(txID, name, opt...) -} - -func makeDNSPkt(txID uint16, name string, opt ...any) []byte { - typ := dnsmessage.TypeA - class := dnsmessage.ClassINET - var response bool - var answers []dnsmessage.ResourceBody - var ttl uint32 = 1 // one second by default - for _, o := range opt { - switch o := o.(type) { - case dnsmessage.Type: - typ = o - case dnsmessage.Class: - class = o - case responseOpt: - response = bool(o) - case dnsmessage.ResourceBody: - answers = append(answers, o) - case ttlOpt: - ttl = uint32(o) - default: - panic(fmt.Sprintf("unknown opt type %T", o)) - } - } - qname := dnsmessage.MustNewName(name) - msg := dnsmessage.Message{ - Header: dnsmessage.Header{ID: txID, Response: response}, - Questions: []dnsmessage.Question{ - { - Name: qname, - Type: typ, - Class: class, - }, - }, - } - for _, rb := range answers { - msg.Answers = append(msg.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: qname, - Type: typ, - Class: class, - TTL: ttl, - }, - Body: rb, - }) - } - buf, err := msg.Pack() - if err != nil { - panic(err) - } - return buf -} - -func TestASCIILowerName(t *testing.T) { - n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) - if got, want := n.String(), "foo.com."; got != want { - t.Errorf("got = %q; want %q", got, want) - } -} - -func TestGetDNSQueryCacheKey(t *testing.T) { - tests := []struct { - name string - pkt []byte - want msgQ - txID uint16 - anyTX bool - }{ - { - name: "empty", - }, - { - name: "a", - pkt: makeQ(123, "foo.com."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "aaaa", - pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), - want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, - txID: 6, - }, - { - name: "normalize_case", - pkt: makeQ(123, "FoO.CoM."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "ignore_response", - pkt: makeRes(123, "foo.com."), - }, - { - name: "ignore_question_with_answers", - pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), - }, - { - name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle - pkt: getGoNetPacketDNSQuery("from-go.foo."), - want: msgQ{"from-go.foo.", dnsmessage.TypeA}, - anyTX: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) - if !ok { - if tt.txID == 0 && got == (msgQ{}) { - return - } - t.Fatal("failed") - } - if got != tt.want { - t.Errorf("got %+v, want %+v", got, tt.want) - } - if gotTX != tt.txID && !tt.anyTX { - t.Errorf("got tx %v, want %v", gotTX, tt.txID) - } - }) - } -} - -func getGoNetPacketDNSQuery(name string) []byte { - if runtime.GOOS == "windows" { - // On Windows, Go's net.Resolver doesn't use the DNS client. - // See https://github.com/golang/go/issues/33097 which - // was approved but not yet implemented. - // For now just pretend it's implemented to make this test - // pass on Windows with complicated the caller. - return makeQ(123, name) - } - res := make(chan []byte, 1) - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return goResolverConn(res), nil - }, - } - r.LookupIP(context.Background(), "ip4", name) - return <-res -} - -type goResolverConn chan<- []byte - -func (goResolverConn) Close() error { return nil } -func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } -func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } -func (goResolverConn) SetDeadline(t time.Time) error { return nil } -func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } -func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } -func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } -func (c goResolverConn) Write(p []byte) (int, error) { - select { - case c <- p[2:]: // skip 2 byte length for TCP mode DNS query - default: - } - return 0, errors.New("boom") -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "runtime" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/tstest" +) + +func TestMessageCache(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), + }) + mc := &MessageCache{Clock: clock.Now} + mc.SetMaxCacheSize(2) + clock.Advance(time.Second) + + var out bytes.Buffer + if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("unexpected error: %v", err) + } + + if err := mc.AddCacheEntry( + makeQ(2, "foo.com."), + makeRes(2, "FOO.COM.", ttlOpt(10), + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { + t.Fatal(err) + } + + // Expect cache hit, with 10 seconds remaining. + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { + t.Errorf("TxID = %v; want %v", p.TxID, 3) + } else if p.TTL != 10 { + t.Errorf("TTL = %v; want 10", p.TTL) + } + + // One second elapses, expect a cache hit, with 9 seconds + // remaining. + clock.Advance(time.Second) + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { + t.Errorf("TxID = %v; want %v", p.TxID, 4) + } else if p.TTL != 9 { + t.Errorf("TTL = %v; want 9", p.TTL) + } + + // Expect cache miss on MX record. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on MX; got: %v", err) + } + // Expect cache miss on CHAOS class. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on CHAOS; got: %v", err) + } + + // Ten seconds elapses; expect a cache miss. + clock.Advance(10 * time.Second) + if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("expected cache miss, got: %v", err) + } +} + +type parsedMeta struct { + TxID uint16 + TTL uint32 +} + +func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { + t.Helper() + var p dnsmessage.Parser + h, err := p.Start(r) + if err != nil { + t.Fatal(err) + } + ret.TxID = h.ID + qq, err := p.AllQuestions() + if err != nil { + t.Fatalf("AllQuestions: %v", err) + } + if len(qq) != 1 { + t.Fatalf("num questions = %v; want 1", len(qq)) + } + aa, err := p.AllAnswers() + if err != nil { + t.Fatalf("AllAnswers: %v", err) + } + for _, r := range aa { + if ret.TTL == 0 { + ret.TTL = r.Header.TTL + } + if ret.TTL != r.Header.TTL { + t.Fatal("mixed TTLs") + } + } + return ret +} + +type responseOpt bool + +type ttlOpt uint32 + +func makeQ(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(false)) + return makeDNSPkt(txID, name, opt...) +} + +func makeRes(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(true)) + return makeDNSPkt(txID, name, opt...) +} + +func makeDNSPkt(txID uint16, name string, opt ...any) []byte { + typ := dnsmessage.TypeA + class := dnsmessage.ClassINET + var response bool + var answers []dnsmessage.ResourceBody + var ttl uint32 = 1 // one second by default + for _, o := range opt { + switch o := o.(type) { + case dnsmessage.Type: + typ = o + case dnsmessage.Class: + class = o + case responseOpt: + response = bool(o) + case dnsmessage.ResourceBody: + answers = append(answers, o) + case ttlOpt: + ttl = uint32(o) + default: + panic(fmt.Sprintf("unknown opt type %T", o)) + } + } + qname := dnsmessage.MustNewName(name) + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ID: txID, Response: response}, + Questions: []dnsmessage.Question{ + { + Name: qname, + Type: typ, + Class: class, + }, + }, + } + for _, rb := range answers { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: qname, + Type: typ, + Class: class, + TTL: ttl, + }, + Body: rb, + }) + } + buf, err := msg.Pack() + if err != nil { + panic(err) + } + return buf +} + +func TestASCIILowerName(t *testing.T) { + n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) + if got, want := n.String(), "foo.com."; got != want { + t.Errorf("got = %q; want %q", got, want) + } +} + +func TestGetDNSQueryCacheKey(t *testing.T) { + tests := []struct { + name string + pkt []byte + want msgQ + txID uint16 + anyTX bool + }{ + { + name: "empty", + }, + { + name: "a", + pkt: makeQ(123, "foo.com."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "aaaa", + pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), + want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, + txID: 6, + }, + { + name: "normalize_case", + pkt: makeQ(123, "FoO.CoM."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "ignore_response", + pkt: makeRes(123, "foo.com."), + }, + { + name: "ignore_question_with_answers", + pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), + }, + { + name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle + pkt: getGoNetPacketDNSQuery("from-go.foo."), + want: msgQ{"from-go.foo.", dnsmessage.TypeA}, + anyTX: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) + if !ok { + if tt.txID == 0 && got == (msgQ{}) { + return + } + t.Fatal("failed") + } + if got != tt.want { + t.Errorf("got %+v, want %+v", got, tt.want) + } + if gotTX != tt.txID && !tt.anyTX { + t.Errorf("got tx %v, want %v", gotTX, tt.txID) + } + }) + } +} + +func getGoNetPacketDNSQuery(name string) []byte { + if runtime.GOOS == "windows" { + // On Windows, Go's net.Resolver doesn't use the DNS client. + // See https://github.com/golang/go/issues/33097 which + // was approved but not yet implemented. + // For now just pretend it's implemented to make this test + // pass on Windows with complicated the caller. + return makeQ(123, name) + } + res := make(chan []byte, 1) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return goResolverConn(res), nil + }, + } + r.LookupIP(context.Background(), "ip4", name) + return <-res +} + +type goResolverConn chan<- []byte + +func (goResolverConn) Close() error { return nil } +func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } +func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } +func (goResolverConn) SetDeadline(t time.Time) error { return nil } +func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } +func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } +func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } +func (c goResolverConn) Write(p []byte) (int, error) { + select { + case c <- p[2:]: // skip 2 byte length for TCP mode DNS query + default: + } + return 0, errors.New("boom") +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/dnsfallback/update-dns-fallbacks.go b/net/dnsfallback/update-dns-fallbacks.go index 384e77e104cdc..ebbfc2ad17409 100644 --- a/net/dnsfallback/update-dns-fallbacks.go +++ b/net/dnsfallback/update-dns-fallbacks.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - - "tailscale.com/tailcfg" -) - -func main() { - res, err := http.Get("https://login.tailscale.com/derpmap/default") - if err != nil { - log.Fatal(err) - } - if res.StatusCode != 200 { - res.Write(os.Stderr) - os.Exit(1) - } - dm := new(tailcfg.DERPMap) - if err := json.NewDecoder(res.Body).Decode(dm); err != nil { - log.Fatal(err) - } - for rid, r := range dm.Regions { - // Names misleading to check into git, as this is a - // static snapshot and doesn't reflect the live DERP - // map. - r.RegionCode = fmt.Sprintf("r%d", rid) - r.RegionName = r.RegionCode - } - out, err := json.MarshalIndent(dm, "", "\t") - if err != nil { - log.Fatal(err) - } - if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + + "tailscale.com/tailcfg" +) + +func main() { + res, err := http.Get("https://login.tailscale.com/derpmap/default") + if err != nil { + log.Fatal(err) + } + if res.StatusCode != 200 { + res.Write(os.Stderr) + os.Exit(1) + } + dm := new(tailcfg.DERPMap) + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + log.Fatal(err) + } + for rid, r := range dm.Regions { + // Names misleading to check into git, as this is a + // static snapshot and doesn't reflect the live DERP + // map. + r.RegionCode = fmt.Sprintf("r%d", rid) + r.RegionName = r.RegionCode + } + out, err := json.MarshalIndent(dm, "", "\t") + if err != nil { + log.Fatal(err) + } + if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { + log.Fatal(err) + } +} diff --git a/net/memnet/conn.go b/net/memnet/conn.go index a9e1fd39901a0..f599612d93553 100644 --- a/net/memnet/conn.go +++ b/net/memnet/conn.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "net/netip" - "time" -) - -// NetworkName is the network name returned by [net.Addr.Network] -// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. -const NetworkName = "mem" - -// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. -type Conn interface { - net.Conn - - // SetReadBlock blocks or unblocks the Read method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetReadBlock(bool) error - - // SetWriteBlock blocks or unblocks the Write method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetWriteBlock(bool) error -} - -// NewConn creates a pair of Conns that are wired together by pipes. -func NewConn(name string, maxBuf int) (Conn, Conn) { - r := NewPipe(name+"|0", maxBuf) - w := NewPipe(name+"|1", maxBuf) - - return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} -} - -// NewTCPConn creates a pair of Conns that are wired together by pipes. -func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { - r := NewPipe(src.String(), maxBuf) - w := NewPipe(dst.String(), maxBuf) - - lAddr := net.TCPAddrFromAddrPort(src) - rAddr := net.TCPAddrFromAddrPort(dst) - - return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} -} - -type connAddr string - -func (a connAddr) Network() string { return NetworkName } -func (a connAddr) String() string { return string(a) } - -type connHalf struct { - local, remote net.Addr - r, w *Pipe -} - -func (c *connHalf) LocalAddr() net.Addr { - if c.local != nil { - return c.local - } - return connAddr(c.r.name) -} - -func (c *connHalf) RemoteAddr() net.Addr { - if c.remote != nil { - return c.remote - } - return connAddr(c.w.name) -} - -func (c *connHalf) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} -func (c *connHalf) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *connHalf) Close() error { - if err := c.w.Close(); err != nil { - return err - } - return c.r.Close() -} - -func (c *connHalf) SetDeadline(t time.Time) error { - err1 := c.SetReadDeadline(t) - err2 := c.SetWriteDeadline(t) - if err1 != nil { - return err1 - } - return err2 -} -func (c *connHalf) SetReadDeadline(t time.Time) error { - return c.r.SetReadDeadline(t) -} -func (c *connHalf) SetWriteDeadline(t time.Time) error { - return c.w.SetWriteDeadline(t) -} - -func (c *connHalf) SetReadBlock(b bool) error { - if b { - return c.r.Block() - } - return c.r.Unblock() -} -func (c *connHalf) SetWriteBlock(b bool) error { - if b { - return c.w.Block() - } - return c.w.Unblock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "net/netip" + "time" +) + +// NetworkName is the network name returned by [net.Addr.Network] +// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. +const NetworkName = "mem" + +// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. +type Conn interface { + net.Conn + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +// NewTCPConn creates a pair of Conns that are wired together by pipes. +func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { + r := NewPipe(src.String(), maxBuf) + w := NewPipe(dst.String(), maxBuf) + + lAddr := net.TCPAddrFromAddrPort(src) + rAddr := net.TCPAddrFromAddrPort(dst) + + return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} +} + +type connAddr string + +func (a connAddr) Network() string { return NetworkName } +func (a connAddr) String() string { return string(a) } + +type connHalf struct { + local, remote net.Addr + r, w *Pipe +} + +func (c *connHalf) LocalAddr() net.Addr { + if c.local != nil { + return c.local + } + return connAddr(c.r.name) +} + +func (c *connHalf) RemoteAddr() net.Addr { + if c.remote != nil { + return c.remote + } + return connAddr(c.w.name) +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *connHalf) Close() error { + if err := c.w.Close(); err != nil { + return err + } + return c.r.Close() +} + +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} + +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/net/memnet/conn_test.go b/net/memnet/conn_test.go index 743ce5248cb9d..3eec80bc6a583 100644 --- a/net/memnet/conn_test.go +++ b/net/memnet/conn_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "testing" - - "golang.org/x/net/nettest" -) - -func TestConn(t *testing.T) { - nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { - c1, c2 = NewConn("test", bufferSize) - return c1, c2, func() { - c1.Close() - c2.Close() - }, nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "testing" + + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + c1, c2 = NewConn("test", bufferSize) + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} diff --git a/net/memnet/listener.go b/net/memnet/listener.go index d84a2e443cbff..d1364d7903d15 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "net" - "strings" - "sync" -) - -const ( - bufferSize = 256 * 1024 -) - -// Listener is a net.Listener using NewConn to create pairs of network -// connections connected in memory using a buffered pipe. It also provides a -// Dial method to establish new connections. -type Listener struct { - addr connAddr - ch chan Conn - closeOnce sync.Once - closed chan struct{} - - // NewConn, if non-nil, is called to create a new pair of connections - // when dialing. If nil, NewConn is used. - NewConn func(network, addr string, maxBuf int) (Conn, Conn) -} - -// Listen returns a new Listener for the provided address. -func Listen(addr string) *Listener { - return &Listener{ - addr: connAddr(addr), - ch: make(chan Conn), - closed: make(chan struct{}), - } -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr -} - -// Close closes the pipe listener. -func (l *Listener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - }) - return nil -} - -// Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { - select { - case c := <-l.ch: - return c, nil - case <-l.closed: - return nil, net.ErrClosed - } -} - -// Dial connects to the listener using the provided context. -// The provided Context must be non-nil. If the context expires before the -// connection is complete, an error is returned. Once successfully connected -// any expiration of the context will not affect the connection. -func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { - if !strings.HasSuffix(network, "tcp") { - return nil, net.UnknownNetworkError(network) - } - if connAddr(addr) != l.addr { - return nil, &net.AddrError{ - Err: "invalid address", - Addr: addr, - } - } - - newConn := l.NewConn - if newConn == nil { - newConn = func(network, addr string, maxBuf int) (Conn, Conn) { - return NewConn(addr, maxBuf) - } - } - c, s := newConn(network, addr, bufferSize) - defer func() { - if err != nil { - c.Close() - s.Close() - } - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-l.closed: - return nil, net.ErrClosed - case l.ch <- s: - return c, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "net" + "strings" + "sync" +) + +const ( + bufferSize = 256 * 1024 +) + +// Listener is a net.Listener using NewConn to create pairs of network +// connections connected in memory using a buffered pipe. It also provides a +// Dial method to establish new connections. +type Listener struct { + addr connAddr + ch chan Conn + closeOnce sync.Once + closed chan struct{} + + // NewConn, if non-nil, is called to create a new pair of connections + // when dialing. If nil, NewConn is used. + NewConn func(network, addr string, maxBuf int) (Conn, Conn) +} + +// Listen returns a new Listener for the provided address. +func Listen(addr string) *Listener { + return &Listener{ + addr: connAddr(addr), + ch: make(chan Conn), + closed: make(chan struct{}), + } +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + return l.addr +} + +// Close closes the pipe listener. +func (l *Listener) Close() error { + l.closeOnce.Do(func() { + close(l.closed) + }) + return nil +} + +// Accept blocks until a new connection is available or the listener is closed. +func (l *Listener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Dial connects to the listener using the provided context. +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected +// any expiration of the context will not affect the connection. +func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { + if !strings.HasSuffix(network, "tcp") { + return nil, net.UnknownNetworkError(network) + } + if connAddr(addr) != l.addr { + return nil, &net.AddrError{ + Err: "invalid address", + Addr: addr, + } + } + + newConn := l.NewConn + if newConn == nil { + newConn = func(network, addr string, maxBuf int) (Conn, Conn) { + return NewConn(addr, maxBuf) + } + } + c, s := newConn(network, addr, bufferSize) + defer func() { + if err != nil { + c.Close() + s.Close() + } + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.closed: + return nil, net.ErrClosed + case l.ch <- s: + return c, nil + } +} diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 73b67841ad08c..989d5e9e4bb2b 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "testing" -) - -func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() - go func() { - c, err := l.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - }() - - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { - c.Close() - t.Fatalf("dial to invalid address succeeded") - } - c, err := l.Dial(context.Background(), "tcp", "srv.local") - if err != nil { - t.Fatalf("dial failed: %v", err) - return - } - c.Close() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "testing" +) + +func TestListener(t *testing.T) { + l := Listen("srv.local") + defer l.Close() + go func() { + c, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + }() + + if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + c.Close() + t.Fatalf("dial to invalid address succeeded") + } + c, err := l.Dial(context.Background(), "tcp", "srv.local") + if err != nil { + t.Fatalf("dial failed: %v", err) + return + } + c.Close() +} diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index c8799bc17035e..2fc13b4b2436f 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package memnet implements an in-memory network implementation. -// It is useful for dialing and listening on in-memory addresses -// in tests and other situations where you don't want to use the -// network. -package memnet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package memnet implements an in-memory network implementation. +// It is useful for dialing and listening on in-memory addresses +// in tests and other situations where you don't want to use the +// network. +package memnet diff --git a/net/memnet/pipe.go b/net/memnet/pipe.go index 47163508353a6..51bee109024d0 100644 --- a/net/memnet/pipe.go +++ b/net/memnet/pipe.go @@ -1,244 +1,244 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" -) - -const debugPipe = false - -// Pipe implements an in-memory FIFO with timeouts. -type Pipe struct { - name string - maxBuf int - mu sync.Mutex - cnd *sync.Cond - - blocked bool - closed bool - buf bytes.Buffer - readTimeout time.Time - writeTimeout time.Time - cancelReadTimer func() - cancelWriteTimer func() -} - -// NewPipe creates a Pipe with a buffer size fixed at maxBuf. -func NewPipe(name string, maxBuf int) *Pipe { - p := &Pipe{ - name: name, - maxBuf: maxBuf, - } - p.cnd = sync.NewCond(&p.mu) - return p -} - -// readOrBlock attempts to read from the buffer, if the buffer is empty and -// the connection hasn't been closed it will block until there is a change. -func (p *Pipe) readOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - n, err := p.buf.Read(b) - // err will either be nil or io.EOF. - if err == io.EOF { - if p.closed { - return n, err - } - // Wait for something to change. - p.cnd.Wait() - } - return n, nil -} - -// Read implements io.Reader. -// Once the buffer is drained (i.e. after Close), subsequent calls will -// return io.EOF. -func (p *Pipe) Read(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) - }() - } - for n == 0 { - n2, err := p.readOrBlock(b) - if err != nil { - return n2, err - } - n += n2 - } - p.cnd.Signal() - return n, nil -} - -// writeOrBlock attempts to write to the buffer, if the buffer is full it will -// block until there is a change. -func (p *Pipe) writeOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.closed { - return 0, net.ErrClosed - } - if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - // Optimistically we want to write the entire slice. - n := len(b) - if limit := p.maxBuf - p.buf.Len(); limit < n { - // However, we don't have enough capacity to write everything. - n = limit - } - if n == 0 { - // Wait for something to change. - p.cnd.Wait() - return 0, nil - } - - p.buf.Write(b[:n]) - p.cnd.Signal() - return n, nil -} - -// Write implements io.Writer. -func (p *Pipe) Write(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) - }() - } - for len(b) > 0 { - n2, err := p.writeOrBlock(b) - if err != nil { - return n + n2, err - } - n += n2 - b = b[n2:] - } - return n, nil -} - -// Close closes the pipe. -func (p *Pipe) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - p.closed = true - p.blocked = false - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cnd.Broadcast() - - return nil -} - -func (p *Pipe) deadlineTimer(t time.Time) func() { - if t.IsZero() { - return nil - } - if t.Before(time.Now()) { - p.cnd.Broadcast() - return nil - } - ctx, cancel := context.WithDeadline(context.Background(), t) - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - p.cnd.Broadcast() - } - }() - return cancel -} - -// SetReadDeadline sets the deadline for future Read calls. -func (p *Pipe) SetReadDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.readTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cancelReadTimer = p.deadlineTimer(t) - return nil -} - -// SetWriteDeadline sets the deadline for future Write calls. -func (p *Pipe) SetWriteDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.writeTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - p.cancelWriteTimer = p.deadlineTimer(t) - return nil -} - -// Block will cause all calls to Read and Write to block until they either -// timeout, are unblocked or the pipe is closed. -func (p *Pipe) Block() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = true - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) - } - p.cnd.Broadcast() - return nil -} - -// Unblock will cause all blocked Read/Write calls to continue execution. -func (p *Pipe) Unblock() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = false - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if !blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) - } - p.cnd.Broadcast() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + mu sync.Mutex + cnd *sync.Cond + + blocked bool + closed bool + buf bytes.Buffer + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + p := &Pipe{ + name: name, + maxBuf: maxBuf, + } + p.cnd = sync.NewCond(&p.mu) + return p +} + +// readOrBlock attempts to read from the buffer, if the buffer is empty and +// the connection hasn't been closed it will block until there is a change. +func (p *Pipe) readOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + n, err := p.buf.Read(b) + // err will either be nil or io.EOF. + if err == io.EOF { + if p.closed { + return n, err + } + // Wait for something to change. + p.cnd.Wait() + } + return n, nil +} + +// Read implements io.Reader. +// Once the buffer is drained (i.e. after Close), subsequent calls will +// return io.EOF. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for n == 0 { + n2, err := p.readOrBlock(b) + if err != nil { + return n2, err + } + n += n2 + } + p.cnd.Signal() + return n, nil +} + +// writeOrBlock attempts to write to the buffer, if the buffer is full it will +// block until there is a change. +func (p *Pipe) writeOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, net.ErrClosed + } + if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + // Optimistically we want to write the entire slice. + n := len(b) + if limit := p.maxBuf - p.buf.Len(); limit < n { + // However, we don't have enough capacity to write everything. + n = limit + } + if n == 0 { + // Wait for something to change. + p.cnd.Wait() + return 0, nil + } + + p.buf.Write(b[:n]) + p.cnd.Signal() + return n, nil +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for len(b) > 0 { + n2, err := p.writeOrBlock(b) + if err != nil { + return n + n2, err + } + n += n2 + b = b[n2:] + } + return n, nil +} + +// Close closes the pipe. +func (p *Pipe) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + p.blocked = false + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cnd.Broadcast() + + return nil +} + +func (p *Pipe) deadlineTimer(t time.Time) func() { + if t.IsZero() { + return nil + } + if t.Before(time.Now()) { + p.cnd.Broadcast() + return nil + } + ctx, cancel := context.WithDeadline(context.Background(), t) + go func() { + <-ctx.Done() + if ctx.Err() == context.DeadlineExceeded { + p.cnd.Broadcast() + } + }() + return cancel +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.readTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cancelReadTimer = p.deadlineTimer(t) + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.writeTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + p.cancelWriteTimer = p.deadlineTimer(t) + return nil +} + +// Block will cause all calls to Read and Write to block until they either +// timeout, are unblocked or the pipe is closed. +func (p *Pipe) Block() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = true + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) + } + p.cnd.Broadcast() + return nil +} + +// Unblock will cause all blocked Read/Write calls to continue execution. +func (p *Pipe) Unblock() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = false + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) + } + p.cnd.Broadcast() + return nil +} diff --git a/net/memnet/pipe_test.go b/net/memnet/pipe_test.go index a86d65388e27d..b3775cf7f9130 100644 --- a/net/memnet/pipe_test.go +++ b/net/memnet/pipe_test.go @@ -1,117 +1,117 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "errors" - "fmt" - "os" - "testing" - "time" -) - -func TestPipeHello(t *testing.T) { - p := NewPipe("p1", 1<<16) - msg := "Hello, World!" - if n, err := p.Write([]byte(msg)); err != nil { - t.Fatal(err) - } else if n != len(msg) { - t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) - } - b := make([]byte, len(msg)) - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != len(b) { - t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) - } - if got := string(b); got != msg { - t.Errorf("p.Read: %q, want %q", got, msg) - } -} - -func TestPipeTimeout(t *testing.T) { - t.Run("write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) - n, err := p.Write([]byte{'h'}) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing write timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h'}) - - p.SetReadDeadline(time.Now().Add(-1 * time.Second)) - b := make([]byte, 1) - n, err := p.Read(b) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing read timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("block-write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want write timeout got: %v", err) - } - }) - t.Run("block-read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h', 'i'}) - p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) - b := make([]byte, 1) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want read timeout got: %v", err) - } - }) -} - -func TestLimit(t *testing.T) { - p := NewPipe("p1", 1) - errCh := make(chan error) - go func() { - n, err := p.Write([]byte{'a', 'b', 'c'}) - if err != nil { - errCh <- err - } else if n != 3 { - errCh <- fmt.Errorf("p.Write n=%d, want 3", n) - } else { - errCh <- nil - } - }() - b := make([]byte, 3) - - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - - if err := <-errCh; err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "errors" + "fmt" + "os" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want read timeout got: %v", err) + } + }) +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + + if err := <-errCh; err != nil { + t.Error(err) + } +} diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 1ab6c053a523e..6f85a52b7c550 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr -// to Go 1.18's net/netip. -// -// TODO(bradfitz): delete this package eventually. Tracking bug is -// https://github.com/tailscale/tailscale/issues/5162 -package netaddr - -import ( - "net" - "net/netip" -) - -// IPv4 returns the IP of the IPv4 address a.b.c.d. -func IPv4(a, b, c, d uint8) netip.Addr { - return netip.AddrFrom4([4]byte{a, b, c, d}) -} - -// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. -// -// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 -func Unmap(ap netip.AddrPort) netip.AddrPort { - return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) -} - -// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. -// If std is invalid, ok is false. -func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { - ip, ok := netip.AddrFromSlice(std.IP) - if !ok { - return netip.Prefix{}, false - } - ip = ip.Unmap() - - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { - // Invalid mask. - return netip.Prefix{}, false - } - - ones, bits := std.Mask.Size() - if ones == 0 && bits == 0 { - // IPPrefix does not support non-contiguous masks. - return netip.Prefix{}, false - } - - return netip.PrefixFrom(ip, ones), true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr +// to Go 1.18's net/netip. +// +// TODO(bradfitz): delete this package eventually. Tracking bug is +// https://github.com/tailscale/tailscale/issues/5162 +package netaddr + +import ( + "net" + "net/netip" +) + +// IPv4 returns the IP of the IPv4 address a.b.c.d. +func IPv4(a, b, c, d uint8) netip.Addr { + return netip.AddrFrom4([4]byte{a, b, c, d}) +} + +// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. +// +// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 +func Unmap(ap netip.AddrPort) netip.AddrPort { + return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) +} + +// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. +// If std is invalid, ok is false. +func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { + ip, ok := netip.AddrFromSlice(std.IP) + if !ok { + return netip.Prefix{}, false + } + ip = ip.Unmap() + + if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { + // Invalid mask. + return netip.Prefix{}, false + } + + ones, bits := std.Mask.Size() + if ones == 0 && bits == 0 { + // IPPrefix does not support non-contiguous masks. + return netip.Prefix{}, false + } + + return netip.PrefixFrom(ip, ones), true +} diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index e2387440d33d5..f570b89302a1b 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package neterror classifies network errors. -package neterror - -import ( - "errors" - "fmt" - "runtime" - "syscall" -) - -var errEPERM error = syscall.EPERM // box it into interface just once - -// TreatAsLostUDP reports whether err is an error from a UDP send -// operation that should be treated as a UDP packet that just got -// lost. -// -// Notably, on Linux this reports true for EPERM errors (from outbound -// firewall blocks) which aren't really send errors; they're just -// sends that are never going to make it because the local OS blocked -// it. -func TreatAsLostUDP(err error) bool { - if err == nil { - return false - } - switch runtime.GOOS { - case "linux": - // Linux, while not documented in the man page, - // returns EPERM when there's an OUTPUT rule with -j - // DROP or -j REJECT. We use this very specific - // Linux+EPERM check rather than something super broad - // like net.Error.Temporary which could be anything. - // - // For now we only do this on Linux, as such outgoing - // firewall violations mapping to syscall errors - // hasn't yet been observed on other OSes. - return errors.Is(err, errEPERM) - } - return false -} - -var packetWasTruncated func(error) bool // non-nil on Windows at least - -// PacketWasTruncated reports whether err indicates truncation but the RecvFrom -// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom -// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received -// datagram is larger than the provided buffer. When that happens, both a valid -// size and an error are returned (as per the partial fix for golang/go#14074). -// If the WSAEMSGSIZE error is returned, then we ignore the error to get -// semantics similar to the POSIX operating systems. One caveat is that it -// appears that the source address is not returned when WSAEMSGSIZE occurs, but -// we do not currently look at the source address. -func PacketWasTruncated(err error) bool { - if packetWasTruncated == nil { - return false - } - return packetWasTruncated(err) -} - -var shouldDisableUDPGSO func(error) bool // non-nil on Linux - -func ShouldDisableUDPGSO(err error) bool { - if shouldDisableUDPGSO == nil { - return false - } - return shouldDisableUDPGSO(err) -} - -type ErrUDPGSODisabled struct { - OnLaddr string - RetryErr error -} - -func (e ErrUDPGSODisabled) Error() string { - return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) -} - -func (e ErrUDPGSODisabled) Unwrap() error { - return e.RetryErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package neterror classifies network errors. +package neterror + +import ( + "errors" + "fmt" + "runtime" + "syscall" +) + +var errEPERM error = syscall.EPERM // box it into interface just once + +// TreatAsLostUDP reports whether err is an error from a UDP send +// operation that should be treated as a UDP packet that just got +// lost. +// +// Notably, on Linux this reports true for EPERM errors (from outbound +// firewall blocks) which aren't really send errors; they're just +// sends that are never going to make it because the local OS blocked +// it. +func TreatAsLostUDP(err error) bool { + if err == nil { + return false + } + switch runtime.GOOS { + case "linux": + // Linux, while not documented in the man page, + // returns EPERM when there's an OUTPUT rule with -j + // DROP or -j REJECT. We use this very specific + // Linux+EPERM check rather than something super broad + // like net.Error.Temporary which could be anything. + // + // For now we only do this on Linux, as such outgoing + // firewall violations mapping to syscall errors + // hasn't yet been observed on other OSes. + return errors.Is(err, errEPERM) + } + return false +} + +var packetWasTruncated func(error) bool // non-nil on Windows at least + +// PacketWasTruncated reports whether err indicates truncation but the RecvFrom +// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom +// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received +// datagram is larger than the provided buffer. When that happens, both a valid +// size and an error are returned (as per the partial fix for golang/go#14074). +// If the WSAEMSGSIZE error is returned, then we ignore the error to get +// semantics similar to the POSIX operating systems. One caveat is that it +// appears that the source address is not returned when WSAEMSGSIZE occurs, but +// we do not currently look at the source address. +func PacketWasTruncated(err error) bool { + if packetWasTruncated == nil { + return false + } + return packetWasTruncated(err) +} + +var shouldDisableUDPGSO func(error) bool // non-nil on Linux + +func ShouldDisableUDPGSO(err error) bool { + if shouldDisableUDPGSO == nil { + return false + } + return shouldDisableUDPGSO(err) +} + +type ErrUDPGSODisabled struct { + OnLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go index 857367fe8ebb5..3f402dd30d236 100644 --- a/net/neterror/neterror_linux.go +++ b/net/neterror/neterror_linux.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "os" - - "golang.org/x/sys/unix" -) - -func init() { - shouldDisableUDPGSO = func(err error) bool { - var serr *os.SyscallError - if errors.As(err, &serr) { - // EIO is returned by udp_send_skb() if the device driver does not - // have tx checksumming enabled, which is a hard requirement of - // UDP_SEGMENT. See: - // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 - // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - return serr.Err == unix.EIO - } - return false - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func init() { + shouldDisableUDPGSO = func(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not + // have tx checksumming enabled, which is a hard requirement of + // UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false + } +} diff --git a/net/neterror/neterror_linux_test.go b/net/neterror/neterror_linux_test.go index 5b99060741351..1d600d6b6e073 100644 --- a/net/neterror/neterror_linux_test.go +++ b/net/neterror/neterror_linux_test.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "net" - "os" - "syscall" - "testing" -) - -func TestTreatAsLostUDP(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {"nil", nil, false}, - {"non-nil", errors.New("foo"), false}, - {"eperm", syscall.EPERM, true}, - { - name: "operror", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EPERM, - }, - }, - want: true, - }, - { - name: "host_unreach", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EHOSTUNREACH, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := TreatAsLostUDP(tt.err); got != tt.want { - t.Errorf("got = %v; want %v", got, tt.want) - } - }) - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "net" + "os" + "syscall" + "testing" +) + +func TestTreatAsLostUDP(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"non-nil", errors.New("foo"), false}, + {"eperm", syscall.EPERM, true}, + { + name: "operror", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EPERM, + }, + }, + want: true, + }, + { + name: "host_unreach", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EHOSTUNREACH, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TreatAsLostUDP(tt.err); got != tt.want { + t.Errorf("got = %v; want %v", got, tt.want) + } + }) + } + +} diff --git a/net/neterror/neterror_windows.go b/net/neterror/neterror_windows.go index bf112f5ed7ab7..c293ec4a96295 100644 --- a/net/neterror/neterror_windows.go +++ b/net/neterror/neterror_windows.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - - "golang.org/x/sys/windows" -) - -func init() { - packetWasTruncated = func(err error) bool { - return errors.Is(err, windows.WSAEMSGSIZE) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +func init() { + packetWasTruncated = func(err error) bool { + return errors.Is(err, windows.WSAEMSGSIZE) + } +} diff --git a/net/netkernelconf/netkernelconf.go b/net/netkernelconf/netkernelconf.go index 3ea502b377fdf..23ec9c5b69f19 100644 --- a/net/netkernelconf/netkernelconf.go +++ b/net/netkernelconf/netkernelconf.go @@ -1,5 +1,5 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netkernelconf contains code for checking kernel netdev config. -package netkernelconf +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netkernelconf contains code for checking kernel netdev config. +package netkernelconf diff --git a/net/netknob/netknob.go b/net/netknob/netknob.go index 53171f4243f8d..0b271fc95b720 100644 --- a/net/netknob/netknob.go +++ b/net/netknob/netknob.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netknob has Tailscale network knobs. -package netknob - -import ( - "runtime" - "time" -) - -// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive -// value for the current runtime.GOOS. -func PlatformTCPKeepAlive() time.Duration { - switch runtime.GOOS { - case "ios", "android": - // Disable TCP keep-alives on mobile platforms. - // See https://github.com/golang/go/issues/48622. - // - // TODO(bradfitz): in 1.17.x, try disabling TCP - // keep-alives on for all platforms. - return -1 - } - - // Otherwise, default to 30 seconds, which is mostly what we - // used to do. In some places we used the zero value, which Go - // defaults to 15 seconds. But 30 seconds is fine. - return 30 * time.Second -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netknob has Tailscale network knobs. +package netknob + +import ( + "runtime" + "time" +) + +// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive +// value for the current runtime.GOOS. +func PlatformTCPKeepAlive() time.Duration { + switch runtime.GOOS { + case "ios", "android": + // Disable TCP keep-alives on mobile platforms. + // See https://github.com/golang/go/issues/48622. + // + // TODO(bradfitz): in 1.17.x, try disabling TCP + // keep-alives on for all platforms. + return -1 + } + + // Otherwise, default to 30 seconds, which is mostly what we + // used to do. In some places we used the zero value, which Go + // defaults to 15 seconds. But 30 seconds is fine. + return 30 * time.Second +} diff --git a/net/netmon/netmon_darwin_test.go b/net/netmon/netmon_darwin_test.go index 84c67cf6fa3e2..77a212683e035 100644 --- a/net/netmon/netmon_darwin_test.go +++ b/net/netmon/netmon_darwin_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "encoding/hex" - "strings" - "testing" - - "golang.org/x/net/route" -) - -func TestIssue1416RIB(t *testing.T) { - const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` - rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) - if err != nil { - t.Fatal(err) - } - msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) - if err != nil { - t.Logf("ParseRIB: %v", err) - t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") - t.Fatal(err) - } - t.Logf("Got: %#v", msgs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "encoding/hex" + "strings" + "testing" + + "golang.org/x/net/route" +) + +func TestIssue1416RIB(t *testing.T) { + const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` + rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) + if err != nil { + t.Fatal(err) + } + msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) + if err != nil { + t.Logf("ParseRIB: %v", err) + t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") + t.Fatal(err) + } + t.Logf("Got: %#v", msgs) +} diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 30480a1d3387e..724f964c98747 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "bufio" - "fmt" - "net" - "strings" - - "tailscale.com/types/logger" -) - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// devdConn implements osMon using devd(8). -type devdConn struct { - conn net.Conn -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") - if err != nil { - logf("devd dial error: %v, falling back to polling method", err) - return newPollingMon(logf, m) - } - return &devdConn{conn}, nil -} - -func (c *devdConn) IsInterestingInterface(iface string) bool { return true } - -func (c *devdConn) Close() error { - return c.conn.Close() -} - -func (c *devdConn) Receive() (message, error) { - for { - msg, err := bufio.NewReader(c.conn).ReadString('\n') - if err != nil { - return nil, fmt.Errorf("reading devd socket: %v", err) - } - // Only return messages related to the network subsystem. - if !strings.Contains(msg, "system=IFNET") { - continue - } - // TODO: this is where the devd-specific message would - // get converted into a "standard" event message and returned. - return unspecifiedMessage{}, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "bufio" + "fmt" + "net" + "strings" + + "tailscale.com/types/logger" +) + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// devdConn implements osMon using devd(8). +type devdConn struct { + conn net.Conn +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") + if err != nil { + logf("devd dial error: %v, falling back to polling method", err) + return newPollingMon(logf, m) + } + return &devdConn{conn}, nil +} + +func (c *devdConn) IsInterestingInterface(iface string) bool { return true } + +func (c *devdConn) Close() error { + return c.conn.Close() +} + +func (c *devdConn) Receive() (message, error) { + for { + msg, err := bufio.NewReader(c.conn).ReadString('\n') + if err != nil { + return nil, fmt.Errorf("reading devd socket: %v", err) + } + // Only return messages related to the network subsystem. + if !strings.Contains(msg, "system=IFNET") { + continue + } + // TODO: this is where the devd-specific message would + // get converted into a "standard" event message and returned. + return unspecifiedMessage{}, nil + } +} diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index dd23dd34263c5..888afa92d7612 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -1,290 +1,290 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !android - -package netmon - -import ( - "net" - "net/netip" - "time" - - "github.com/jsimonetti/rtnetlink" - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" - "tailscale.com/envknob" - "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" -) - -var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// nlConn wraps a *netlink.Conn and returns a monitor.Message -// instead of a netlink.Message. Currently, messages are discarded, -// but down the line, when messages trigger different logic depending -// on the type of event, this provides the capability of handling -// each architecture-specific message in a generic fashion. -type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message - - // addrCache maps interface indices to a set of addresses, and is - // used to suppress duplicate RTM_NEWADDR messages. It is populated - // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See - // issue #4282. - addrCache map[uint32]map[netip.Addr]bool -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - // Routes get us most of the events of interest, but we need - // address as well to cover things like DHCP deciding to give - // us a new address upon renewal - routing wouldn't change, - // but all reachability would. - Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | - unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | - unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix - }) - if err != nil { - // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support - logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") - return newPollingMon(logf, m) - } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil -} - -func (c *nlConn) IsInterestingInterface(iface string) bool { return true } - -func (c *nlConn) Close() error { return c.conn.Close() } - -func (c *nlConn) Receive() (message, error) { - if len(c.buffered) == 0 { - var err error - c.buffered, err = c.conn.Receive() - if err != nil { - return nil, err - } - if len(c.buffered) == 0 { - // Unexpected. Not seen in wild, but sleep defensively. - time.Sleep(time.Second) - return ignoreMessage{}, nil - } - } - msg := c.buffered[0] - c.buffered = c.buffered[1:] - - // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h - // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html - switch msg.Header.Type { - case unix.RTM_NEWADDR, unix.RTM_DELADDR: - var rmsg rtnetlink.AddressMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("failed to parse type %v: %v", msg.Header.Type, err) - return unspecifiedMessage{}, nil - } - - nip := netaddrIP(rmsg.Attributes.Address) - - if debugNetlinkMessages() { - typ := "RTM_NEWADDR" - if msg.Header.Type == unix.RTM_DELADDR { - typ = "RTM_DELADDR" - } - - // label attributes are seemingly only populated for IPv4 addresses in the wild. - label := rmsg.Attributes.Label - if label == "" { - itf, err := net.InterfaceByIndex(int(rmsg.Index)) - if err == nil { - label = itf.Name - } - } - - c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) - } - - addrs := c.addrCache[rmsg.Index] - - // Ignore duplicate RTM_NEWADDR messages using c.addrCache to - // detect them. See nlConn.addrcache and issue #4282. - if msg.Header.Type == unix.RTM_NEWADDR { - if addrs == nil { - addrs = make(map[netip.Addr]bool) - c.addrCache[rmsg.Index] = addrs - } - - if addrs[nip] { - if debugNetlinkMessages() { - c.logf("ignored duplicate RTM_NEWADDR for %s", nip) - } - return ignoreMessage{}, nil - } - - addrs[nip] = true - } else { // msg.Header.Type == unix.RTM_DELADDR - if addrs != nil { - delete(addrs, nip) - } - - if len(addrs) == 0 { - delete(c.addrCache, rmsg.Index) - } - } - - nam := &newAddrMessage{ - IfIndex: rmsg.Index, - Addr: nip, - Delete: msg.Header.Type == unix.RTM_DELADDR, - } - if debugNetlinkMessages() { - c.logf("%+v", nam) - } - return nam, nil - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - typeStr := "RTM_NEWROUTE" - if msg.Header.Type == unix.RTM_DELROUTE { - typeStr = "RTM_DELROUTE" - } - var rmsg rtnetlink.RouteMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("%s: failed to parse: %v", typeStr, err) - return unspecifiedMessage{}, nil - } - src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) - dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) - gw := netaddrIP(rmsg.Attributes.Gateway) - - if msg.Header.Type == unix.RTM_NEWROUTE && - (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && - (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { - - if debugNetlinkMessages() { - c.logf("%s ignored", typeStr) - } - - // Normal Linux route changes on new interface coming up; don't log or react. - return ignoreMessage{}, nil - } - - if rmsg.Table == tsTable && dst.IsSingleIP() { - // Don't log. Spammy and normal to see a bunch of these on start-up, - // which we make ourselves. - } else if tsaddr.IsTailscaleIP(dst.Addr()) { - // Verbose only. - c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } else { - c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } - if msg.Header.Type == unix.RTM_DELROUTE { - // Just logging it for now. - // (Debugging https://github.com/tailscale/tailscale/issues/643) - return unspecifiedMessage{}, nil - } - - nrm := &newRouteMessage{ - Table: rmsg.Table, - Src: src, - Dst: dst, - Gateway: gw, - } - if debugNetlinkMessages() { - c.logf("%+v", nrm) - } - return nrm, nil - case unix.RTM_NEWRULE: - // Probably ourselves adding it. - return ignoreMessage{}, nil - case unix.RTM_DELRULE: - // For https://github.com/tailscale/tailscale/issues/1591 where - // systemd-networkd deletes our rules. - var rmsg rtnetlink.RouteMessage - err := rmsg.UnmarshalBinary(msg.Data) - if err != nil { - c.logf("ip rule deleted; failed to parse netlink message: %v", err) - } else { - c.logf("ip rule deleted: %+v", rmsg) - // On `ip -4 rule del pref 5210 table main`, logs: - // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} - } - rdm := ipRuleDeletedMessage{ - table: rmsg.Table, - priority: rmsg.Attributes.Priority, - } - if debugNetlinkMessages() { - c.logf("%+v", rdm) - } - return rdm, nil - case unix.RTM_NEWLINK, unix.RTM_DELLINK: - // This is an unhandled message, but don't print an error. - // See https://github.com/tailscale/tailscale/issues/6806 - return unspecifiedMessage{}, nil - default: - c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) - return unspecifiedMessage{}, nil - } -} - -func netaddrIP(std net.IP) netip.Addr { - ip, _ := netip.AddrFromSlice(std) - return ip.Unmap() -} - -func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { - ip, _ := netip.AddrFromSlice(std) - return netip.PrefixFrom(ip.Unmap(), int(bits)) -} - -func condNetAddrPrefix(ipp netip.Prefix) string { - if !ipp.Addr().IsValid() { - return "" - } - return ipp.String() -} - -func condNetAddrIP(ip netip.Addr) string { - if !ip.IsValid() { - return "" - } - return ip.String() -} - -// newRouteMessage is a message for a new route being added. -type newRouteMessage struct { - Src, Dst netip.Prefix - Gateway netip.Addr - Table uint8 -} - -const tsTable = 52 - -func (m *newRouteMessage) ignore() bool { - return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) -} - -// newAddrMessage is a message for a new address being added. -type newAddrMessage struct { - Delete bool - Addr netip.Addr - IfIndex uint32 // interface index -} - -func (m *newAddrMessage) ignore() bool { - return tsaddr.IsTailscaleIP(m.Addr) -} - -type ignoreMessage struct{} - -func (ignoreMessage) ignore() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !android + +package netmon + +import ( + "net" + "net/netip" + "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" + "tailscale.com/envknob" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// nlConn wraps a *netlink.Conn and returns a monitor.Message +// instead of a netlink.Message. Currently, messages are discarded, +// but down the line, when messages trigger different logic depending +// on the type of event, this provides the capability of handling +// each architecture-specific message in a generic fashion. +type nlConn struct { + logf logger.Logf + conn *netlink.Conn + buffered []netlink.Message + + // addrCache maps interface indices to a set of addresses, and is + // used to suppress duplicate RTM_NEWADDR messages. It is populated + // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See + // issue #4282. + addrCache map[uint32]map[netip.Addr]bool +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + // Routes get us most of the events of interest, but we need + // address as well to cover things like DHCP deciding to give + // us a new address upon renewal - routing wouldn't change, + // but all reachability would. + Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | + unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | + unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix + }) + if err != nil { + // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support + logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") + return newPollingMon(logf, m) + } + return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil +} + +func (c *nlConn) IsInterestingInterface(iface string) bool { return true } + +func (c *nlConn) Close() error { return c.conn.Close() } + +func (c *nlConn) Receive() (message, error) { + if len(c.buffered) == 0 { + var err error + c.buffered, err = c.conn.Receive() + if err != nil { + return nil, err + } + if len(c.buffered) == 0 { + // Unexpected. Not seen in wild, but sleep defensively. + time.Sleep(time.Second) + return ignoreMessage{}, nil + } + } + msg := c.buffered[0] + c.buffered = c.buffered[1:] + + // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h + // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html + switch msg.Header.Type { + case unix.RTM_NEWADDR, unix.RTM_DELADDR: + var rmsg rtnetlink.AddressMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("failed to parse type %v: %v", msg.Header.Type, err) + return unspecifiedMessage{}, nil + } + + nip := netaddrIP(rmsg.Attributes.Address) + + if debugNetlinkMessages() { + typ := "RTM_NEWADDR" + if msg.Header.Type == unix.RTM_DELADDR { + typ = "RTM_DELADDR" + } + + // label attributes are seemingly only populated for IPv4 addresses in the wild. + label := rmsg.Attributes.Label + if label == "" { + itf, err := net.InterfaceByIndex(int(rmsg.Index)) + if err == nil { + label = itf.Name + } + } + + c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) + } + + addrs := c.addrCache[rmsg.Index] + + // Ignore duplicate RTM_NEWADDR messages using c.addrCache to + // detect them. See nlConn.addrcache and issue #4282. + if msg.Header.Type == unix.RTM_NEWADDR { + if addrs == nil { + addrs = make(map[netip.Addr]bool) + c.addrCache[rmsg.Index] = addrs + } + + if addrs[nip] { + if debugNetlinkMessages() { + c.logf("ignored duplicate RTM_NEWADDR for %s", nip) + } + return ignoreMessage{}, nil + } + + addrs[nip] = true + } else { // msg.Header.Type == unix.RTM_DELADDR + if addrs != nil { + delete(addrs, nip) + } + + if len(addrs) == 0 { + delete(c.addrCache, rmsg.Index) + } + } + + nam := &newAddrMessage{ + IfIndex: rmsg.Index, + Addr: nip, + Delete: msg.Header.Type == unix.RTM_DELADDR, + } + if debugNetlinkMessages() { + c.logf("%+v", nam) + } + return nam, nil + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + typeStr := "RTM_NEWROUTE" + if msg.Header.Type == unix.RTM_DELROUTE { + typeStr = "RTM_DELROUTE" + } + var rmsg rtnetlink.RouteMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("%s: failed to parse: %v", typeStr, err) + return unspecifiedMessage{}, nil + } + src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) + dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) + gw := netaddrIP(rmsg.Attributes.Gateway) + + if msg.Header.Type == unix.RTM_NEWROUTE && + (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && + (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { + + if debugNetlinkMessages() { + c.logf("%s ignored", typeStr) + } + + // Normal Linux route changes on new interface coming up; don't log or react. + return ignoreMessage{}, nil + } + + if rmsg.Table == tsTable && dst.IsSingleIP() { + // Don't log. Spammy and normal to see a bunch of these on start-up, + // which we make ourselves. + } else if tsaddr.IsTailscaleIP(dst.Addr()) { + // Verbose only. + c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } else { + c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } + if msg.Header.Type == unix.RTM_DELROUTE { + // Just logging it for now. + // (Debugging https://github.com/tailscale/tailscale/issues/643) + return unspecifiedMessage{}, nil + } + + nrm := &newRouteMessage{ + Table: rmsg.Table, + Src: src, + Dst: dst, + Gateway: gw, + } + if debugNetlinkMessages() { + c.logf("%+v", nrm) + } + return nrm, nil + case unix.RTM_NEWRULE: + // Probably ourselves adding it. + return ignoreMessage{}, nil + case unix.RTM_DELRULE: + // For https://github.com/tailscale/tailscale/issues/1591 where + // systemd-networkd deletes our rules. + var rmsg rtnetlink.RouteMessage + err := rmsg.UnmarshalBinary(msg.Data) + if err != nil { + c.logf("ip rule deleted; failed to parse netlink message: %v", err) + } else { + c.logf("ip rule deleted: %+v", rmsg) + // On `ip -4 rule del pref 5210 table main`, logs: + // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} + } + rdm := ipRuleDeletedMessage{ + table: rmsg.Table, + priority: rmsg.Attributes.Priority, + } + if debugNetlinkMessages() { + c.logf("%+v", rdm) + } + return rdm, nil + case unix.RTM_NEWLINK, unix.RTM_DELLINK: + // This is an unhandled message, but don't print an error. + // See https://github.com/tailscale/tailscale/issues/6806 + return unspecifiedMessage{}, nil + default: + c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) + return unspecifiedMessage{}, nil + } +} + +func netaddrIP(std net.IP) netip.Addr { + ip, _ := netip.AddrFromSlice(std) + return ip.Unmap() +} + +func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { + ip, _ := netip.AddrFromSlice(std) + return netip.PrefixFrom(ip.Unmap(), int(bits)) +} + +func condNetAddrPrefix(ipp netip.Prefix) string { + if !ipp.Addr().IsValid() { + return "" + } + return ipp.String() +} + +func condNetAddrIP(ip netip.Addr) string { + if !ip.IsValid() { + return "" + } + return ip.String() +} + +// newRouteMessage is a message for a new route being added. +type newRouteMessage struct { + Src, Dst netip.Prefix + Gateway netip.Addr + Table uint8 +} + +const tsTable = 52 + +func (m *newRouteMessage) ignore() bool { + return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) +} + +// newAddrMessage is a message for a new address being added. +type newAddrMessage struct { + Delete bool + Addr netip.Addr + IfIndex uint32 // interface index +} + +func (m *newAddrMessage) ignore() bool { + return tsaddr.IsTailscaleIP(m.Addr) +} + +type ignoreMessage struct{} + +func (ignoreMessage) ignore() bool { return true } diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 3d6f94731077a..1ce4a51deadc4 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (!linux && !freebsd && !windows && !darwin) || android - -package netmon - -import ( - "tailscale.com/types/logger" -) - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - return newPollingMon(logf, m) -} - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (!linux && !freebsd && !windows && !darwin) || android + +package netmon + +import ( + "tailscale.com/types/logger" +) + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + return newPollingMon(logf, m) +} + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } diff --git a/net/netmon/polling.go b/net/netmon/polling.go index ce1618ed6c987..bb7210b94ed62 100644 --- a/net/netmon/polling.go +++ b/net/netmon/polling.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !darwin - -package netmon - -import ( - "bytes" - "errors" - "os" - "runtime" - "sync" - "time" - - "tailscale.com/types/logger" -) - -func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { - return &pollingMon{ - logf: logf, - m: m, - stop: make(chan struct{}), - }, nil -} - -// pollingMon is a bad but portable implementation of the link monitor -// that works by polling the interface state every 10 seconds, in lieu -// of anything to subscribe to. -type pollingMon struct { - logf logger.Logf - m *Monitor - - closeOnce sync.Once - stop chan struct{} -} - -func (pm *pollingMon) IsInterestingInterface(iface string) bool { - return true -} - -func (pm *pollingMon) Close() error { - pm.closeOnce.Do(func() { - close(pm.stop) - }) - return nil -} - -func (pm *pollingMon) isCloudRun() bool { - // https: //cloud.google.com/run/docs/reference/container-contract#env-vars - if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || - os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { - return false - } - vers, err := os.ReadFile("/proc/version") - if err != nil { - pm.logf("Failed to read /proc/version: %v", err) - return false - } - return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" -} - -func (pm *pollingMon) Receive() (message, error) { - d := 10 * time.Second - if runtime.GOOS == "android" { - // We'll have Android notify the link monitor to wake up earlier, - // so this can go very slowly there, to save battery. - // https://github.com/tailscale/tailscale/issues/1427 - d = 10 * time.Minute - } else if pm.isCloudRun() { - // Cloud Run routes never change at runtime. the containers are killed within - // 15 minutes by default, set the interval long enough to be effectively infinite. - pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") - d = 24 * time.Hour - } - timer := time.NewTimer(d) - defer timer.Stop() - for { - select { - case <-timer.C: - return unspecifiedMessage{}, nil - case <-pm.stop: - return nil, errors.New("stopped") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !darwin + +package netmon + +import ( + "bytes" + "errors" + "os" + "runtime" + "sync" + "time" + + "tailscale.com/types/logger" +) + +func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { + return &pollingMon{ + logf: logf, + m: m, + stop: make(chan struct{}), + }, nil +} + +// pollingMon is a bad but portable implementation of the link monitor +// that works by polling the interface state every 10 seconds, in lieu +// of anything to subscribe to. +type pollingMon struct { + logf logger.Logf + m *Monitor + + closeOnce sync.Once + stop chan struct{} +} + +func (pm *pollingMon) IsInterestingInterface(iface string) bool { + return true +} + +func (pm *pollingMon) Close() error { + pm.closeOnce.Do(func() { + close(pm.stop) + }) + return nil +} + +func (pm *pollingMon) isCloudRun() bool { + // https: //cloud.google.com/run/docs/reference/container-contract#env-vars + if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || + os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { + return false + } + vers, err := os.ReadFile("/proc/version") + if err != nil { + pm.logf("Failed to read /proc/version: %v", err) + return false + } + return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" +} + +func (pm *pollingMon) Receive() (message, error) { + d := 10 * time.Second + if runtime.GOOS == "android" { + // We'll have Android notify the link monitor to wake up earlier, + // so this can go very slowly there, to save battery. + // https://github.com/tailscale/tailscale/issues/1427 + d = 10 * time.Minute + } else if pm.isCloudRun() { + // Cloud Run routes never change at runtime. the containers are killed within + // 15 minutes by default, set the interval long enough to be effectively infinite. + pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") + d = 24 * time.Hour + } + timer := time.NewTimer(d) + defer timer.Stop() + for { + select { + case <-timer.C: + return unspecifiedMessage{}, nil + case <-pm.stop: + return nil, errors.New("stopped") + } + } +} diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index 162e5c79a62fa..79084ff11f521 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -1,75 +1,75 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build android - -package netns - -import ( - "fmt" - "sync" - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -var ( - androidProtectFuncMu sync.Mutex - androidProtectFunc func(fd int) error -) - -// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. -func UseSocketMark() bool { - return false -} - -// SetAndroidProtectFunc register a func that Android provides that JNI calls into -// https://developer.android.com/reference/android/net/VpnService#protect(int) -// which is documented as: -// -// "Protect a socket from VPN connections. After protecting, data sent -// through this socket will go directly to the underlying network, so -// its traffic will not be forwarded through the VPN. This method is -// useful if some connections need to be kept outside of VPN. For -// example, a VPN tunnel should protect itself if its destination is -// covered by VPN routes. Otherwise its outgoing packets will be sent -// back to the VPN interface and cause an infinite loop. This method -// will fail if the application is not prepared or is revoked." -// -// A nil func disables the use the hook. -// -// This indirection is necessary because this is the supported, stable -// interface to use on Android, and doing the sockopts to set the -// fwmark return errors on Android. The actual implementation of -// VpnService.protect ends up doing an IPC to another process on -// Android, asking for the fwmark to be set. -func SetAndroidProtectFunc(f func(fd int) error) { - androidProtectFuncMu.Lock() - defer androidProtectFuncMu.Unlock() - androidProtectFunc = f -} - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC marks c as necessary to dial in a separate network namespace. -// -// It's intentionally the same signature as net.Dialer.Control -// and net.ListenConfig.Control. -func controlC(network, address string, c syscall.RawConn) error { - var sockErr error - err := c.Control(func(fd uintptr) { - androidProtectFuncMu.Lock() - f := androidProtectFunc - androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) - } - }) - if err != nil { - return fmt.Errorf("RawConn.Control on %T: %w", c, err) - } - return sockErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build android + +package netns + +import ( + "fmt" + "sync" + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +var ( + androidProtectFuncMu sync.Mutex + androidProtectFunc func(fd int) error +) + +// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. +func UseSocketMark() bool { + return false +} + +// SetAndroidProtectFunc register a func that Android provides that JNI calls into +// https://developer.android.com/reference/android/net/VpnService#protect(int) +// which is documented as: +// +// "Protect a socket from VPN connections. After protecting, data sent +// through this socket will go directly to the underlying network, so +// its traffic will not be forwarded through the VPN. This method is +// useful if some connections need to be kept outside of VPN. For +// example, a VPN tunnel should protect itself if its destination is +// covered by VPN routes. Otherwise its outgoing packets will be sent +// back to the VPN interface and cause an infinite loop. This method +// will fail if the application is not prepared or is revoked." +// +// A nil func disables the use the hook. +// +// This indirection is necessary because this is the supported, stable +// interface to use on Android, and doing the sockopts to set the +// fwmark return errors on Android. The actual implementation of +// VpnService.protect ends up doing an IPC to another process on +// Android, asking for the fwmark to be set. +func SetAndroidProtectFunc(f func(fd int) error) { + androidProtectFuncMu.Lock() + defer androidProtectFuncMu.Unlock() + androidProtectFunc = f +} + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC marks c as necessary to dial in a separate network namespace. +// +// It's intentionally the same signature as net.Dialer.Control +// and net.ListenConfig.Control. +func controlC(network, address string, c syscall.RawConn) error { + var sockErr error + err := c.Control(func(fd uintptr) { + androidProtectFuncMu.Lock() + f := androidProtectFunc + androidProtectFuncMu.Unlock() + if f != nil { + sockErr = f(int(fd)) + } + }) + if err != nil { + return fmt.Errorf("RawConn.Control on %T: %w", c, err) + } + return sockErr +} diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go index 94f24d8fa4e19..02db19e7566fa 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package netns - -import ( - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC does nothing to c. -func controlC(network, address string, c syscall.RawConn) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package netns + +import ( + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC does nothing to c. +func controlC(network, address string, c syscall.RawConn) error { + return nil +} diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go index a5000f37f0a44..cc221bcb1712c 100644 --- a/net/netns/netns_linux_test.go +++ b/net/netns/netns_linux_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netns - -import ( - "testing" -) - -func TestSocketMarkWorks(t *testing.T) { - _ = socketMarkWorks() - // we cannot actually assert whether the test runner has SO_MARK available - // or not, as we don't know. We're just checking that it doesn't panic. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netns + +import ( + "testing" +) + +func TestSocketMarkWorks(t *testing.T) { + _ = socketMarkWorks() + // we cannot actually assert whether the test runner has SO_MARK available + // or not, as we don't know. We're just checking that it doesn't panic. +} diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 82f919b946d4a..1c6d699ac88aa 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netns contains the common code for using the Go net package -// in a logical "network namespace" to avoid routing loops where -// Tailscale-created packets would otherwise loop back through -// Tailscale routes. -// -// Despite the name netns, the exact mechanism used differs by -// operating system, and perhaps even by version of the OS. -// -// The netns package also handles connecting via SOCKS proxies when -// configured by the environment. -package netns - -import ( - "flag" - "testing" -) - -var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") - -func TestDial(t *testing.T) { - if !*extNetwork { - t.Skip("skipping test without --use-external-network") - } - d := NewDialer(t.Logf, nil) - c, err := d.Dial("tcp", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) - - c, err = d.Dial("tcp4", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) -} - -func TestIsLocalhost(t *testing.T) { - tests := []struct { - name string - host string - want bool - }{ - {"IPv4 loopback", "127.0.0.1", true}, - {"IPv4 !loopback", "192.168.0.1", false}, - {"IPv4 loopback with port", "127.0.0.1:1", true}, - {"IPv4 !loopback with port", "192.168.0.1:1", false}, - {"IPv4 unspecified", "0.0.0.0", false}, - {"IPv4 unspecified with port", "0.0.0.0:1", false}, - {"IPv6 loopback", "::1", true}, - {"IPv6 !loopback", "2001:4860:4860::8888", false}, - {"IPv6 loopback with port", "[::1]:1", true}, - {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, - {"IPv6 unspecified", "::", false}, - {"IPv6 unspecified with port", "[::]:1", false}, - {"empty", "", false}, - {"hostname", "example.com", false}, - {"localhost", "localhost", true}, - {"localhost6", "localhost6", true}, - {"localhost with port", "localhost:1", true}, - {"localhost6 with port", "localhost6:1", true}, - {"ip6-localhost", "ip6-localhost", true}, - {"ip6-localhost with port", "ip6-localhost:1", true}, - {"ip6-loopback", "ip6-loopback", true}, - {"ip6-loopback with port", "ip6-loopback:1", true}, - } - - for _, test := range tests { - if got := isLocalhost(test.host); got != test.want { - t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netns contains the common code for using the Go net package +// in a logical "network namespace" to avoid routing loops where +// Tailscale-created packets would otherwise loop back through +// Tailscale routes. +// +// Despite the name netns, the exact mechanism used differs by +// operating system, and perhaps even by version of the OS. +// +// The netns package also handles connecting via SOCKS proxies when +// configured by the environment. +package netns + +import ( + "flag" + "testing" +) + +var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") + +func TestDial(t *testing.T) { + if !*extNetwork { + t.Skip("skipping test without --use-external-network") + } + d := NewDialer(t.Logf, nil) + c, err := d.Dial("tcp", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) + + c, err = d.Dial("tcp4", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + host string + want bool + }{ + {"IPv4 loopback", "127.0.0.1", true}, + {"IPv4 !loopback", "192.168.0.1", false}, + {"IPv4 loopback with port", "127.0.0.1:1", true}, + {"IPv4 !loopback with port", "192.168.0.1:1", false}, + {"IPv4 unspecified", "0.0.0.0", false}, + {"IPv4 unspecified with port", "0.0.0.0:1", false}, + {"IPv6 loopback", "::1", true}, + {"IPv6 !loopback", "2001:4860:4860::8888", false}, + {"IPv6 loopback with port", "[::1]:1", true}, + {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, + {"IPv6 unspecified", "::", false}, + {"IPv6 unspecified with port", "[::]:1", false}, + {"empty", "", false}, + {"hostname", "example.com", false}, + {"localhost", "localhost", true}, + {"localhost6", "localhost6", true}, + {"localhost with port", "localhost:1", true}, + {"localhost6 with port", "localhost6:1", true}, + {"ip6-localhost", "ip6-localhost", true}, + {"ip6-localhost with port", "ip6-localhost:1", true}, + {"ip6-loopback", "ip6-loopback", true}, + {"ip6-loopback with port", "ip6-loopback:1", true}, + } + + for _, test := range tests { + if got := isLocalhost(test.host); got != test.want { + t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) + } + } +} diff --git a/net/netns/socks.go b/net/netns/socks.go index eea69d8651eda..a3d10d3ae80c5 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !js - -package netns - -import "golang.org/x/net/proxy" - -func init() { - wrapDialer = wrapSocks -} - -func wrapSocks(d Dialer) Dialer { - if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { - return cd - } - return d -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !js + +package netns + +import "golang.org/x/net/proxy" + +func init() { + wrapDialer = wrapSocks +} + +func wrapSocks(d Dialer) Dialer { + if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { + return cd + } + return d +} diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go index 53c7d7757eac6..53121dc52e202 100644 --- a/net/netstat/netstat.go +++ b/net/netstat/netstat.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netstat returns the local machine's network connection table. -package netstat - -import ( - "errors" - "net/netip" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -type Entry struct { - Local, Remote netip.AddrPort - Pid int - State string // TODO: type? - OSMetadata OSMetadata -} - -// Table contains local machine's TCP connection entries. -// -// Currently only TCP (IPv4 and IPv6) are included. -type Table struct { - Entries []Entry -} - -// Get returns the connection table. -// -// It returns ErrNotImplemented if the table is not available for the -// current operating system. -func Get() (*Table, error) { - return get() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netstat returns the local machine's network connection table. +package netstat + +import ( + "errors" + "net/netip" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +type Entry struct { + Local, Remote netip.AddrPort + Pid int + State string // TODO: type? + OSMetadata OSMetadata +} + +// Table contains local machine's TCP connection entries. +// +// Currently only TCP (IPv4 and IPv6) are included. +type Table struct { + Entries []Entry +} + +// Get returns the connection table. +// +// It returns ErrNotImplemented if the table is not available for the +// current operating system. +func Get() (*Table, error) { + return get() +} diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go index e455c8ce931de..608b1a617bc5d 100644 --- a/net/netstat/netstat_noimpl.go +++ b/net/netstat/netstat_noimpl.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package netstat - -// OSMetadata includes any additional OS-specific information that may be -// obtained during the retrieval of a given Entry. -type OSMetadata struct{} - -func get() (*Table, error) { - return nil, ErrNotImplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package netstat + +// OSMetadata includes any additional OS-specific information that may be +// obtained during the retrieval of a given Entry. +type OSMetadata struct{} + +func get() (*Table, error) { + return nil, ErrNotImplemented +} diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go index 38827df5ef65a..74f4fcec02616 100644 --- a/net/netstat/netstat_test.go +++ b/net/netstat/netstat_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstat - -import ( - "testing" -) - -func TestGet(t *testing.T) { - nt, err := Get() - if err == ErrNotImplemented { - t.Skip("TODO: not implemented") - } - if err != nil { - t.Fatal(err) - } - for _, e := range nt.Entries { - t.Logf("Entry: %+v", e) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstat + +import ( + "testing" +) + +func TestGet(t *testing.T) { + nt, err := Get() + if err == ErrNotImplemented { + t.Skip("TODO: not implemented") + } + if err != nil { + t.Fatal(err) + } + for _, e := range nt.Entries { + t.Logf("Entry: %+v", e) + } +} diff --git a/net/packet/doc.go b/net/packet/doc.go index ce6c0c30716c6..f3cb93db87e03 100644 --- a/net/packet/doc.go +++ b/net/packet/doc.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package packet contains packet parsing and marshaling utilities. -// -// Parsed provides allocation-free minimal packet header decoding, for -// use in packet filtering. The other types in the package are for -// constructing and marshaling packets into []bytes. -// -// To support allocation-free parsing, this package defines IPv4 and -// IPv6 address types. You should prefer to use netaddr's types, -// except where you absolutely need allocation-free IP handling -// (i.e. in the tunnel datapath) and are willing to implement all -// codepaths and data structures twice, once per IP family. -package packet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package packet contains packet parsing and marshaling utilities. +// +// Parsed provides allocation-free minimal packet header decoding, for +// use in packet filtering. The other types in the package are for +// constructing and marshaling packets into []bytes. +// +// To support allocation-free parsing, this package defines IPv4 and +// IPv6 address types. You should prefer to use netaddr's types, +// except where you absolutely need allocation-free IP handling +// (i.e. in the tunnel datapath) and are willing to implement all +// codepaths and data structures twice, once per IP family. +package packet diff --git a/net/packet/header.go b/net/packet/header.go index dbe84429adbd8..0b1947c0abc36 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "errors" - "math" -) - -const tcpHeaderLength = 20 -const sctpHeaderLength = 12 - -// maxPacketLength is the largest length that all headers support. -// IPv4 headers using uint16 for this forces an upper bound of 64KB. -const maxPacketLength = math.MaxUint16 - -var ( - // errSmallBuffer is returned when Marshal receives a buffer - // too small to contain the header to marshal. - errSmallBuffer = errors.New("buffer too small") - // errLargePacket is returned when Marshal receives a payload - // larger than the maximum representable size in header - // fields. - errLargePacket = errors.New("packet too large") -) - -// Header is a packet header capable of marshaling itself into a byte -// buffer. -type Header interface { - // Len returns the length of the marshaled packet. - Len() int - // Marshal serializes the header into buf, which must be at - // least Len() bytes long. Implementations of Marshal assume - // that bytes after the first Len() are payload bytes for the - // purpose of computing length and checksum fields. Marshal - // implementations must not allocate memory. - Marshal(buf []byte) error -} - -// HeaderChecksummer is implemented by Header implementations that -// need to do a checksum over their payloads. -type HeaderChecksummer interface { - Header - - // WriteCheck writes the correct checksum into buf, which should - // be be the already-marshalled header and payload. - WriteChecksum(buf []byte) -} - -// Generate generates a new packet with the given Header and -// payload. This function allocates memory, see Header.Marshal for an -// allocation-free option. -func Generate(h Header, payload []byte) []byte { - hlen := h.Len() - buf := make([]byte, hlen+len(payload)) - - copy(buf[hlen:], payload) - h.Marshal(buf) - - if hc, ok := h.(HeaderChecksummer); ok { - hc.WriteChecksum(buf) - } - - return buf -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "errors" + "math" +) + +const tcpHeaderLength = 20 +const sctpHeaderLength = 12 + +// maxPacketLength is the largest length that all headers support. +// IPv4 headers using uint16 for this forces an upper bound of 64KB. +const maxPacketLength = math.MaxUint16 + +var ( + // errSmallBuffer is returned when Marshal receives a buffer + // too small to contain the header to marshal. + errSmallBuffer = errors.New("buffer too small") + // errLargePacket is returned when Marshal receives a payload + // larger than the maximum representable size in header + // fields. + errLargePacket = errors.New("packet too large") +) + +// Header is a packet header capable of marshaling itself into a byte +// buffer. +type Header interface { + // Len returns the length of the marshaled packet. + Len() int + // Marshal serializes the header into buf, which must be at + // least Len() bytes long. Implementations of Marshal assume + // that bytes after the first Len() are payload bytes for the + // purpose of computing length and checksum fields. Marshal + // implementations must not allocate memory. + Marshal(buf []byte) error +} + +// HeaderChecksummer is implemented by Header implementations that +// need to do a checksum over their payloads. +type HeaderChecksummer interface { + Header + + // WriteCheck writes the correct checksum into buf, which should + // be be the already-marshalled header and payload. + WriteChecksum(buf []byte) +} + +// Generate generates a new packet with the given Header and +// payload. This function allocates memory, see Header.Marshal for an +// allocation-free option. +func Generate(h Header, payload []byte) []byte { + hlen := h.Len() + buf := make([]byte, hlen+len(payload)) + + copy(buf[hlen:], payload) + h.Marshal(buf) + + if hc, ok := h.(HeaderChecksummer); ok { + hc.WriteChecksum(buf) + } + + return buf +} diff --git a/net/packet/icmp.go b/net/packet/icmp.go index 89a7aaa32bec4..7b86edd815384 100644 --- a/net/packet/icmp.go +++ b/net/packet/icmp.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - crand "crypto/rand" - - "encoding/binary" -) - -// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 -// derived from them, along with the id, sequence and given payload in a buffer. -// It returns an error if the random source could not be read. -func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { - buf = make([]byte, len(payload)+4) - - // make a completely random id/sequence combo, which is very unlikely to - // collide with a running ping sequence on the host system. Errors are - // ignored, that would result in collisions, but errors reading from the - // random device are rare, and will cause this process universe to soon end. - crand.Read(buf[:4]) - - idSeq = binary.LittleEndian.Uint32(buf) - copy(buf[4:], payload) - - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + crand "crypto/rand" + + "encoding/binary" +) + +// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 +// derived from them, along with the id, sequence and given payload in a buffer. +// It returns an error if the random source could not be read. +func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { + buf = make([]byte, len(payload)+4) + + // make a completely random id/sequence combo, which is very unlikely to + // collide with a running ping sequence on the host system. Errors are + // ignored, that would result in collisions, but errors reading from the + // random device are rare, and will cause this process universe to soon end. + crand.Read(buf[:4]) + + idSeq = binary.LittleEndian.Uint32(buf) + copy(buf[4:], payload) + + return +} diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go index f34883ca41e7e..c2fab353a582d 100644 --- a/net/packet/icmp6_test.go +++ b/net/packet/icmp6_test.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" - - "tailscale.com/types/ipproto" -) - -func TestICMPv6PingResponse(t *testing.T) { - pingHdr := ICMP6Header{ - IP6Header: IP6Header{ - Src: netip.MustParseAddr("1::1"), - Dst: netip.MustParseAddr("2::2"), - IPProto: ipproto.ICMPv6, - }, - Type: ICMP6EchoRequest, - Code: ICMP6NoCode, - } - - // echoReqLen is 2 bytes identifier + 2 bytes seq number. - // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 - // Packet.IsEchoRequest verifies that these 4 bytes are present. - const echoReqLen = 4 - buf := make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - var p Parsed - p.Decode(buf) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - pingHdr.ToResponse() - buf = make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - p.Decode(buf) - if p.IsEchoRequest() { - t.Fatalf("unexpectedly still an echo request: %+v", p) - } - if !p.IsEchoResponse() { - t.Fatalf("not an echo response: %+v", p) - } -} - -func TestICMPv6Checksum(t *testing.T) { - const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - // The packet that we'd originally generated incorrectly, but with the checksum - // bytes fixed per WireShark's correct calculation: - const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - - var p Parsed - p.Decode([]byte(req)) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - h := p.ICMP6Header() - h.ToResponse() - pong := Generate(&h, p.Payload()) - - if string(pong) != wantRes { - t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" + + "tailscale.com/types/ipproto" +) + +func TestICMPv6PingResponse(t *testing.T) { + pingHdr := ICMP6Header{ + IP6Header: IP6Header{ + Src: netip.MustParseAddr("1::1"), + Dst: netip.MustParseAddr("2::2"), + IPProto: ipproto.ICMPv6, + }, + Type: ICMP6EchoRequest, + Code: ICMP6NoCode, + } + + // echoReqLen is 2 bytes identifier + 2 bytes seq number. + // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 + // Packet.IsEchoRequest verifies that these 4 bytes are present. + const echoReqLen = 4 + buf := make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(buf) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + pingHdr.ToResponse() + buf = make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + p.Decode(buf) + if p.IsEchoRequest() { + t.Fatalf("unexpectedly still an echo request: %+v", p) + } + if !p.IsEchoResponse() { + t.Fatalf("not an echo response: %+v", p) + } +} + +func TestICMPv6Checksum(t *testing.T) { + const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + // The packet that we'd originally generated incorrectly, but with the checksum + // bytes fixed per WireShark's correct calculation: + const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + + var p Parsed + p.Decode([]byte(req)) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + h := p.ICMP6Header() + h.ToResponse() + pong := Generate(&h, p.Payload()) + + if string(pong) != wantRes { + t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) + } +} diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 967a8dba7f57b..596bc766d9a17 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -1,116 +1,116 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "errors" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip4HeaderLength is the length of an IPv4 header with no IP options. -const ip4HeaderLength = 20 - -// IP4Header represents an IPv4 packet header. -type IP4Header struct { - IPProto ipproto.Proto - IPID uint16 - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP4Header) Len() int { - return ip4HeaderLength -} - -var errWrongFamily = errors.New("wrong address family for src/dst IP") - -// Marshal implements Header. -func (h IP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - if !h.Src.Is4() || !h.Dst.Is4() { - return errWrongFamily - } - - buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL - buf[1] = 0x00 // DSCP + ECN - binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length - binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID - binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset - buf[8] = 64 // TTL - buf[9] = uint8(h.IPProto) // Inner protocol - // Blank checksum. This is necessary even though we overwrite - // it later, because the checksum computation runs over these - // bytes and expects them to be zero. - binary.BigEndian.PutUint16(buf[10:12], 0) - src := h.Src.As4() - dst := h.Dst.As4() - copy(buf[12:16], src[:]) - copy(buf[16:20], dst[:]) - - binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum - - return nil -} - -// ToResponse implements Header. -func (h *IP4Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = ^h.IPID -} - -// ip4Checksum computes an IPv4 checksum, as specified in -// https://tools.ietf.org/html/rfc1071 -func ip4Checksum(b []byte) uint16 { - var ac uint32 - i := 0 - n := len(b) - for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint32(b[i]) << 8 - } - for (ac >> 16) > 0 { - ac = (ac >> 16) + (ac & 0xffff) - } - return uint16(^ac) -} - -// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP -// pseudo-header is smaller than the real IPv4 header. -const ip4PseudoHeaderOffset = 8 - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. The pseudo-header starts -// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP -// header, while leaving enough space in buf for a full IPv4 header. -func (h IP4Header) marshalPseudo(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - length := len(buf) - h.Len() - src, dst := h.Src.As4(), h.Dst.As4() - copy(buf[8:12], src[:]) - copy(buf[12:16], dst[:]) - buf[16] = 0x0 - buf[17] = uint8(h.IPProto) - binary.BigEndian.PutUint16(buf[18:20], uint16(length)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "errors" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip4HeaderLength is the length of an IPv4 header with no IP options. +const ip4HeaderLength = 20 + +// IP4Header represents an IPv4 packet header. +type IP4Header struct { + IPProto ipproto.Proto + IPID uint16 + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP4Header) Len() int { + return ip4HeaderLength +} + +var errWrongFamily = errors.New("wrong address family for src/dst IP") + +// Marshal implements Header. +func (h IP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + if !h.Src.Is4() || !h.Dst.Is4() { + return errWrongFamily + } + + buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL + buf[1] = 0x00 // DSCP + ECN + binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length + binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID + binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset + buf[8] = 64 // TTL + buf[9] = uint8(h.IPProto) // Inner protocol + // Blank checksum. This is necessary even though we overwrite + // it later, because the checksum computation runs over these + // bytes and expects them to be zero. + binary.BigEndian.PutUint16(buf[10:12], 0) + src := h.Src.As4() + dst := h.Dst.As4() + copy(buf[12:16], src[:]) + copy(buf[16:20], dst[:]) + + binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum + + return nil +} + +// ToResponse implements Header. +func (h *IP4Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = ^h.IPID +} + +// ip4Checksum computes an IPv4 checksum, as specified in +// https://tools.ietf.org/html/rfc1071 +func ip4Checksum(b []byte) uint16 { + var ac uint32 + i := 0 + n := len(b) + for n >= 2 { + ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint32(b[i]) << 8 + } + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(^ac) +} + +// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP +// pseudo-header is smaller than the real IPv4 header. +const ip4PseudoHeaderOffset = 8 + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. The pseudo-header starts +// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP +// header, while leaving enough space in buf for a full IPv4 header. +func (h IP4Header) marshalPseudo(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + length := len(buf) - h.Len() + src, dst := h.Src.As4(), h.Dst.As4() + copy(buf[8:12], src[:]) + copy(buf[12:16], dst[:]) + buf[16] = 0x0 + buf[17] = uint8(h.IPProto) + binary.BigEndian.PutUint16(buf[18:20], uint16(length)) + return nil +} diff --git a/net/packet/ip6.go b/net/packet/ip6.go index d26b9a1619b31..cebc46c534c04 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -1,76 +1,76 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip6HeaderLength is the length of an IPv6 header with no IP options. -const ip6HeaderLength = 40 - -// IP6Header represents an IPv6 packet header. -type IP6Header struct { - IPProto ipproto.Proto - IPID uint32 // only lower 20 bits used - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP6Header) Len() int { - return ip6HeaderLength -} - -// Marshal implements Header. -func (h IP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) - buf[0] = 0x60 - binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length - buf[6] = uint8(h.IPProto) // Inner protocol - buf[7] = 64 // TTL - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[8:24], src[:]) - copy(buf[24:40], dst[:]) - - return nil -} - -// ToResponse implements Header. -func (h *IP6Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = (^h.IPID) & 0x000FFFFF -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. -func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[:16], src[:]) - copy(buf[16:32], dst[:]) - binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) - buf[36] = 0 - buf[37] = 0 - buf[38] = 0 - buf[39] = byte(proto) // NextProto - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip6HeaderLength is the length of an IPv6 header with no IP options. +const ip6HeaderLength = 40 + +// IP6Header represents an IPv6 packet header. +type IP6Header struct { + IPProto ipproto.Proto + IPID uint32 // only lower 20 bits used + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP6Header) Len() int { + return ip6HeaderLength +} + +// Marshal implements Header. +func (h IP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) + buf[0] = 0x60 + binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length + buf[6] = uint8(h.IPProto) // Inner protocol + buf[7] = 64 // TTL + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[8:24], src[:]) + copy(buf[24:40], dst[:]) + + return nil +} + +// ToResponse implements Header. +func (h *IP6Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = (^h.IPID) & 0x000FFFFF +} + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. +func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[:16], src[:]) + copy(buf[16:32], dst[:]) + binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) + buf[36] = 0 + buf[37] = 0 + buf[38] = 0 + buf[39] = byte(proto) // NextProto + return nil +} diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index e261e6a4199b3..4ec24e1ea0a4c 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" -) - -func TestTailscaleRejectedHeader(t *testing.T) { - tests := []struct { - h TailscaleRejectedHeader - wantStr string - }{ - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("5.5.5.5"), - IPDst: netip.MustParseAddr("1.2.3.4"), - Src: netip.MustParseAddrPort("1.2.3.4:567"), - Dst: netip.MustParseAddrPort("5.5.5.5:443"), - Proto: TCP, - Reason: RejectedDueToACLs, - }, - wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToShieldsUp, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToIPForwarding, - MaybeBroken: true, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", - }, - } - for i, tt := range tests { - gotStr := tt.h.String() - if gotStr != tt.wantStr { - t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) - continue - } - pkt := make([]byte, tt.h.Len()) - tt.h.Marshal(pkt) - - var p Parsed - p.Decode(pkt) - t.Logf("Parsed: %+v", p) - t.Logf("Parsed: %s", p.String()) - back, ok := p.AsTailscaleRejectedHeader() - if !ok { - t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) - continue - } - if back != tt.h { - t.Errorf("%v. %q parsed back as %q", i, tt.h, back) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" +) + +func TestTailscaleRejectedHeader(t *testing.T) { + tests := []struct { + h TailscaleRejectedHeader + wantStr string + }{ + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("5.5.5.5"), + IPDst: netip.MustParseAddr("1.2.3.4"), + Src: netip.MustParseAddrPort("1.2.3.4:567"), + Dst: netip.MustParseAddrPort("5.5.5.5:443"), + Proto: TCP, + Reason: RejectedDueToACLs, + }, + wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToShieldsUp, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToIPForwarding, + MaybeBroken: true, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", + }, + } + for i, tt := range tests { + gotStr := tt.h.String() + if gotStr != tt.wantStr { + t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) + continue + } + pkt := make([]byte, tt.h.Len()) + tt.h.Marshal(pkt) + + var p Parsed + p.Decode(pkt) + t.Logf("Parsed: %+v", p) + t.Logf("Parsed: %s", p.String()) + back, ok := p.AsTailscaleRejectedHeader() + if !ok { + t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) + continue + } + if back != tt.h { + t.Errorf("%v. %q parsed back as %q", i, tt.h, back) + } + } +} diff --git a/net/packet/udp4.go b/net/packet/udp4.go index 0d5bca73e8c89..c8761baef2d36 100644 --- a/net/packet/udp4.go +++ b/net/packet/udp4.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// udpHeaderLength is the size of the UDP packet header, not including -// the outer IP header. -const udpHeaderLength = 8 - -// UDP4Header is an IPv4+UDP header. -type UDP4Header struct { - IP4Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP4Header) Len() int { - return h.IP4Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP4Header.Len() - binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) - binary.BigEndian.PutUint16(buf[22:24], h.DstPort) - binary.BigEndian.PutUint16(buf[24:26], uint16(length)) - binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP4Header.marshalPseudo(buf) - binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) - - h.IP4Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP4Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP4Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// udpHeaderLength is the size of the UDP packet header, not including +// the outer IP header. +const udpHeaderLength = 8 + +// UDP4Header is an IPv4+UDP header. +type UDP4Header struct { + IP4Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP4Header) Len() int { + return h.IP4Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP4Header.Len() + binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) + binary.BigEndian.PutUint16(buf[22:24], h.DstPort) + binary.BigEndian.PutUint16(buf[24:26], uint16(length)) + binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP4Header.marshalPseudo(buf) + binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) + + h.IP4Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP4Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP4Header.ToResponse() +} diff --git a/net/packet/udp6.go b/net/packet/udp6.go index 10fdcb99e525c..c8634b5080aea 100644 --- a/net/packet/udp6.go +++ b/net/packet/udp6.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// UDP6Header is an IPv6+UDP header. -type UDP6Header struct { - IP6Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP6Header) Len() int { - return h.IP6Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP6Header.Len() - binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) - binary.BigEndian.PutUint16(buf[42:44], h.DstPort) - binary.BigEndian.PutUint16(buf[44:46], uint16(length)) - binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP6Header.marshalPseudo(buf, ipproto.UDP) - binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) - - h.IP6Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP6Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP6Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// UDP6Header is an IPv6+UDP header. +type UDP6Header struct { + IP6Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP6Header) Len() int { + return h.IP6Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP6Header.Len() + binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) + binary.BigEndian.PutUint16(buf[42:44], h.DstPort) + binary.BigEndian.PutUint16(buf[44:46], uint16(length)) + binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP6Header.marshalPseudo(buf, ipproto.UDP) + binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) + + h.IP6Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP6Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP6Header.ToResponse() +} diff --git a/net/ping/ping.go b/net/ping/ping.go index 01f3dcf2c4976..f2093292a7a2c 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -1,343 +1,343 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ping allows sending ICMP echo requests to a host in order to -// determine network latency. -package ping - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/types/logger" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" -) - -const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" -) - -type response struct { - t time.Time - err error -} - -type outstanding struct { - ch chan response - data []byte -} - -// PacketListener defines the interface required to listen to packages -// on an address. -type ListenPacketer interface { - ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) -} - -// Pinger represents a set of ICMP echo requests to be sent at a single time. -// -// A new instance should be created for each concurrent set of ping requests; -// this type should not be reused. -type Pinger struct { - lp ListenPacketer - - // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool - timeNow func() time.Time - id uint16 // uint16 per RFC 792 - wg sync.WaitGroup - - // Following fields protected by mu - mu sync.Mutex - // conns is a map of "type" to net.PacketConn, type is either - // "ip4:icmp" or "ip6:icmp" - conns map[string]net.PacketConn - seq uint16 // uint16 per RFC 792 - pings map[uint16]outstanding -} - -// New creates a new Pinger. The Context provided will be used to create -// network listeners, and to set an absolute deadline (if any) on the net.Conn -func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { - var id [2]byte - if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { - panic("net/ping: New:" + err.Error()) - } - - return &Pinger{ - lp: lp, - Logf: logf, - timeNow: time.Now, - id: binary.LittleEndian.Uint16(id[:]), - pings: make(map[uint16]outstanding), - } -} - -func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { - if p.closed.Load() { - return nil, net.ErrClosed - } - - c, err := p.lp.ListenPacket(ctx, typ, addr) - if err != nil { - return nil, err - } - - // Start by setting the deadline from the context; note that this - // applies to all future I/O, so we only need to do it once. - deadline, ok := ctx.Deadline() - if ok { - if err := c.SetReadDeadline(deadline); err != nil { - return nil, err - } - } - - p.wg.Add(1) - go p.run(ctx, c, typ) - - return c, err -} - -// getConn creates or returns a conn matching typ which is ip4:icmp -// or ip6:icmp. -func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if c, ok := p.conns[typ]; ok { - return c, nil - } - - var addr = "0.0.0.0" - if typ == v6Type { - addr = "::" - } - c, err := p.mkconn(ctx, typ, addr) - if err != nil { - return nil, err - } - mak.Set(&p.conns, typ, c) - return c, nil -} - -func (p *Pinger) logf(format string, a ...any) { - if p.Logf != nil { - p.Logf(format, a...) - } else { - log.Printf(format, a...) - } -} - -func (p *Pinger) vlogf(format string, a ...any) { - if p.Verbose { - p.logf(format, a...) - } -} - -func (p *Pinger) Close() error { - p.closed.Store(true) - - p.mu.Lock() - conns := p.conns - p.conns = nil - p.mu.Unlock() - - var errors []error - for _, c := range conns { - if err := c.Close(); err != nil { - errors = append(errors, err) - } - } - - p.wg.Wait() - p.cleanupOutstanding() - - return multierr.New(errors...) -} - -func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { - defer p.wg.Done() - defer func() { - conn.Close() - p.mu.Lock() - delete(p.conns, typ) - p.mu.Unlock() - }() - buf := make([]byte, 1500) - -loop: - for { - select { - case <-ctx.Done(): - break loop - default: - } - - n, _, err := conn.ReadFrom(buf) - if err != nil { - // Ignore temporary errors; everything else is fatal - if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { - break - } - continue - } - - p.handleResponse(buf[:n], p.timeNow(), typ) - } -} - -func (p *Pinger) cleanupOutstanding() { - // Complete outstanding requests - p.mu.Lock() - defer p.mu.Unlock() - for _, o := range p.pings { - o.ch <- response{err: net.ErrClosed} - } -} - -func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { - // We need to handle responding to both IPv4 - // and IPv6. - var icmpType icmp.Type - switch typ { - case v4Type: - icmpType = ipv4.ICMPTypeEchoReply - case v6Type: - icmpType = ipv6.ICMPTypeEchoReply - default: - p.vlogf("handleResponse: unknown icmp.Type") - return - } - - m, err := icmp.ParseMessage(icmpType.Protocol(), buf) - if err != nil { - p.vlogf("handleResponse: invalid packet: %v", err) - return - } - - if m.Type != icmpType { - p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) - return - } - - resp, ok := m.Body.(*icmp.Echo) - if !ok || resp == nil { - p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) - return - } - - // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { - p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) - return - } - - // Search for existing running echo request - var o outstanding - p.mu.Lock() - if o, ok = p.pings[uint16(resp.Seq)]; ok { - // Ensure that the data matches before we delete from our map, - // so a future correct packet will be handled correctly. - if bytes.Equal(resp.Data, o.data) { - delete(p.pings, uint16(resp.Seq)) - } else { - p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) - ok = false - } - } else { - p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) - } - p.mu.Unlock() - - if ok { - o.ch <- response{t: now} - } -} - -// Send sends an ICMP Echo Request packet to the destination, waits for a -// response, and returns the duration between when the request was sent and -// when the reply was received. -// -// If provided, "data" is sent with the packet and is compared upon receiving a -// reply. -func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { - // Use sequential sequence numbers on the assumption that we will not - // wrap around when using a single Pinger instance - p.mu.Lock() - p.seq++ - seq := p.seq - p.mu.Unlock() - - // Check whether the address is IPv4 or IPv6 to - // determine the icmp.Type and conn to use. - var conn net.PacketConn - var icmpType icmp.Type = ipv4.ICMPTypeEcho - ap, err := netip.ParseAddr(dest.String()) - if err != nil { - return 0, err - } - if ap.Is6() { - icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) - } else { - conn, err = p.getConn(ctx, v4Type) - } - if err != nil { - return 0, err - } - - m := icmp.Message{ - Type: icmpType, - Code: 0, - Body: &icmp.Echo{ - ID: int(p.id), - Seq: int(seq), - Data: data, - }, - } - b, err := m.Marshal(nil) - if err != nil { - return 0, err - } - - // Register our response before sending since we could otherwise race a - // quick reply. - ch := make(chan response, 1) - p.mu.Lock() - p.pings[seq] = outstanding{ch: ch, data: data} - p.mu.Unlock() - - start := p.timeNow() - n, err := conn.WriteTo(b, dest) - if err != nil { - return 0, err - } else if n != len(b) { - return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) - } - - select { - case resp := <-ch: - if resp.err != nil { - return 0, resp.err - } - return resp.t.Sub(start), nil - - case <-ctx.Done(): - return 0, ctx.Err() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ping allows sending ICMP echo requests to a host in order to +// determine network latency. +package ping + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" +) + +const ( + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" +) + +type response struct { + t time.Time + err error +} + +type outstanding struct { + ch chan response + data []byte +} + +// PacketListener defines the interface required to listen to packages +// on an address. +type ListenPacketer interface { + ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) +} + +// Pinger represents a set of ICMP echo requests to be sent at a single time. +// +// A new instance should be created for each concurrent set of ping requests; +// this type should not be reused. +type Pinger struct { + lp ListenPacketer + + // closed guards against send incrementing the waitgroup concurrently with close. + closed atomic.Bool + Logf logger.Logf + Verbose bool + timeNow func() time.Time + id uint16 // uint16 per RFC 792 + wg sync.WaitGroup + + // Following fields protected by mu + mu sync.Mutex + // conns is a map of "type" to net.PacketConn, type is either + // "ip4:icmp" or "ip6:icmp" + conns map[string]net.PacketConn + seq uint16 // uint16 per RFC 792 + pings map[uint16]outstanding +} + +// New creates a new Pinger. The Context provided will be used to create +// network listeners, and to set an absolute deadline (if any) on the net.Conn +func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { + var id [2]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + panic("net/ping: New:" + err.Error()) + } + + return &Pinger{ + lp: lp, + Logf: logf, + timeNow: time.Now, + id: binary.LittleEndian.Uint16(id[:]), + pings: make(map[uint16]outstanding), + } +} + +func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { + if p.closed.Load() { + return nil, net.ErrClosed + } + + c, err := p.lp.ListenPacket(ctx, typ, addr) + if err != nil { + return nil, err + } + + // Start by setting the deadline from the context; note that this + // applies to all future I/O, so we only need to do it once. + deadline, ok := ctx.Deadline() + if ok { + if err := c.SetReadDeadline(deadline); err != nil { + return nil, err + } + } + + p.wg.Add(1) + go p.run(ctx, c, typ) + + return c, err +} + +// getConn creates or returns a conn matching typ which is ip4:icmp +// or ip6:icmp. +func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if c, ok := p.conns[typ]; ok { + return c, nil + } + + var addr = "0.0.0.0" + if typ == v6Type { + addr = "::" + } + c, err := p.mkconn(ctx, typ, addr) + if err != nil { + return nil, err + } + mak.Set(&p.conns, typ, c) + return c, nil +} + +func (p *Pinger) logf(format string, a ...any) { + if p.Logf != nil { + p.Logf(format, a...) + } else { + log.Printf(format, a...) + } +} + +func (p *Pinger) vlogf(format string, a ...any) { + if p.Verbose { + p.logf(format, a...) + } +} + +func (p *Pinger) Close() error { + p.closed.Store(true) + + p.mu.Lock() + conns := p.conns + p.conns = nil + p.mu.Unlock() + + var errors []error + for _, c := range conns { + if err := c.Close(); err != nil { + errors = append(errors, err) + } + } + + p.wg.Wait() + p.cleanupOutstanding() + + return multierr.New(errors...) +} + +func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { + defer p.wg.Done() + defer func() { + conn.Close() + p.mu.Lock() + delete(p.conns, typ) + p.mu.Unlock() + }() + buf := make([]byte, 1500) + +loop: + for { + select { + case <-ctx.Done(): + break loop + default: + } + + n, _, err := conn.ReadFrom(buf) + if err != nil { + // Ignore temporary errors; everything else is fatal + if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { + break + } + continue + } + + p.handleResponse(buf[:n], p.timeNow(), typ) + } +} + +func (p *Pinger) cleanupOutstanding() { + // Complete outstanding requests + p.mu.Lock() + defer p.mu.Unlock() + for _, o := range p.pings { + o.ch <- response{err: net.ErrClosed} + } +} + +func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { + // We need to handle responding to both IPv4 + // and IPv6. + var icmpType icmp.Type + switch typ { + case v4Type: + icmpType = ipv4.ICMPTypeEchoReply + case v6Type: + icmpType = ipv6.ICMPTypeEchoReply + default: + p.vlogf("handleResponse: unknown icmp.Type") + return + } + + m, err := icmp.ParseMessage(icmpType.Protocol(), buf) + if err != nil { + p.vlogf("handleResponse: invalid packet: %v", err) + return + } + + if m.Type != icmpType { + p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) + return + } + + resp, ok := m.Body.(*icmp.Echo) + if !ok || resp == nil { + p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) + return + } + + // We assume we sent this if the ID in the response is ours. + if uint16(resp.ID) != p.id { + p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) + return + } + + // Search for existing running echo request + var o outstanding + p.mu.Lock() + if o, ok = p.pings[uint16(resp.Seq)]; ok { + // Ensure that the data matches before we delete from our map, + // so a future correct packet will be handled correctly. + if bytes.Equal(resp.Data, o.data) { + delete(p.pings, uint16(resp.Seq)) + } else { + p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) + ok = false + } + } else { + p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) + } + p.mu.Unlock() + + if ok { + o.ch <- response{t: now} + } +} + +// Send sends an ICMP Echo Request packet to the destination, waits for a +// response, and returns the duration between when the request was sent and +// when the reply was received. +// +// If provided, "data" is sent with the packet and is compared upon receiving a +// reply. +func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { + // Use sequential sequence numbers on the assumption that we will not + // wrap around when using a single Pinger instance + p.mu.Lock() + p.seq++ + seq := p.seq + p.mu.Unlock() + + // Check whether the address is IPv4 or IPv6 to + // determine the icmp.Type and conn to use. + var conn net.PacketConn + var icmpType icmp.Type = ipv4.ICMPTypeEcho + ap, err := netip.ParseAddr(dest.String()) + if err != nil { + return 0, err + } + if ap.Is6() { + icmpType = ipv6.ICMPTypeEchoRequest + conn, err = p.getConn(ctx, v6Type) + } else { + conn, err = p.getConn(ctx, v4Type) + } + if err != nil { + return 0, err + } + + m := icmp.Message{ + Type: icmpType, + Code: 0, + Body: &icmp.Echo{ + ID: int(p.id), + Seq: int(seq), + Data: data, + }, + } + b, err := m.Marshal(nil) + if err != nil { + return 0, err + } + + // Register our response before sending since we could otherwise race a + // quick reply. + ch := make(chan response, 1) + p.mu.Lock() + p.pings[seq] = outstanding{ch: ch, data: data} + p.mu.Unlock() + + start := p.timeNow() + n, err := conn.WriteTo(b, dest) + if err != nil { + return 0, err + } else if n != len(b) { + return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) + } + + select { + case resp := <-ch: + if resp.err != nil { + return 0, resp.err + } + return resp.t.Sub(start), nil + + case <-ctx.Done(): + return 0, ctx.Err() + } +} diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go index bbedbcad80e44..5232f6ada85e0 100644 --- a/net/ping/ping_test.go +++ b/net/ping/ping_test.go @@ -1,350 +1,350 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ping - -import ( - "context" - "errors" - "fmt" - "net" - "testing" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/tstest" - "tailscale.com/util/mak" -) - -var ( - localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} -) - -func TestPinger(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: ipv4.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestV6Pinger(t *testing.T) { - if c, err := net.ListenPacket("udp6", "::1"); err != nil { - // skip test if we can't use IPv6. - t.Skipf("IPv6 not supported: %s", err) - } else { - c.Close() - } - - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv6.ICMPTypeEchoReply, - Code: ipv6.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestPingerTimeout(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - clock := &tstest.Clock{} - p, closeP := mockPinger(t, clock) - defer closeP() - - // Send a ping in the background - r := make(chan error, 1) - go func() { - _, err := p.Send(ctx, localhost, []byte("data goes here")) - r <- err - }() - - // Wait until we're blocking - p.waitOutstanding(t, ctx, 1) - - // Close everything down - p.cleanupOutstanding() - - // Should have got an error from the ping - err := <-r - if !errors.Is(err, net.ErrClosed) { - t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) - } -} - -func TestPingerMismatch(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // "Receive" a bunch of intentionally malformed packets that should not - // result in the Send call above returning - badPackets := []struct { - name string - pkt *icmp.Message - }{ - { - name: "wrong type", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeDestinationUnreachable, - Code: 0, - Body: &icmp.DstUnreach{}, - }, - }, - { - name: "wrong id", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 9999, - Seq: 1, - Data: bodyData, - }, - }, - }, - { - name: "wrong seq", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 5, - Data: bodyData, - }, - }, - }, - { - name: "bad body", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - - // Intentionally missing first byte - Data: bodyData[1:], - }, - }, - }, - } - - const fakeDuration = 100 * time.Millisecond - tm := clock.Now().Add(fakeDuration) - - for _, tt := range badPackets { - fakeResponse := mustMarshal(t, tt.pkt) - p.handleResponse(fakeResponse, tm, v4Type) - } - - // Also "receive" a packet that does not unmarshal as an ICMP packet - p.handleResponse([]byte("foo"), tm, v4Type) - - select { - case <-r: - t.Fatal("wanted timeout") - case <-ctx.Done(): - t.Logf("test correctly timed out") - } -} - -// udpingPacketConn will convert potentially ICMP destination addrs to UDP -// destination addrs in WriteTo so that a test that is intending to send ICMP -// traffic will instead send UDP traffic, without the higher level Pinger being -// aware of this difference. -type udpingPacketConn struct { - net.PacketConn - // destPort will be configured by the test to be the peer expected to respond to a ping. - destPort uint16 -} - -func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { - switch d := dest.(type) { - case *net.IPAddr: - udpAddr := &net.UDPAddr{ - IP: d.IP, - Port: int(u.destPort), - Zone: d.Zone, - } - return u.PacketConn.WriteTo(body, udpAddr) - } - return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) -} - -func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { - p := New(context.Background(), t.Logf, nil) - p.timeNow = clock.Now - p.Verbose = true - p.id = 1234 - - // In tests, we use UDP so that we can test without being root; this - // doesn't matter because we mock out the ICMP reply below to be a real - // ICMP echo reply packet. - conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn6, err := net.ListenPacket("udp6", "[::]:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn4 = &udpingPacketConn{ - destPort: 12345, - PacketConn: conn4, - } - conn6 = &udpingPacketConn{ - PacketConn: conn6, - destPort: 12345, - } - - mak.Set(&p.conns, v4Type, conn4) - mak.Set(&p.conns, v6Type, conn6) - done := func() { - if err := p.Close(); err != nil { - t.Errorf("error on close: %v", err) - } - } - return p, done -} - -func mustMarshal(t *testing.T, m *icmp.Message) []byte { - t.Helper() - - b, err := m.Marshal(nil) - if err != nil { - t.Fatal(err) - } - return b -} - -func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { - // This is a bit janky, but... we busy-loop to wait for the Send call - // to write to our map so we know that a response will be handled. - var haveMapEntry bool - for !haveMapEntry { - time.Sleep(10 * time.Millisecond) - select { - case <-ctx.Done(): - t.Error("no entry in ping map before timeout") - return - default: - } - - p.mu.Lock() - haveMapEntry = len(p.pings) == count - p.mu.Unlock() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ping + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/tstest" + "tailscale.com/util/mak" +) + +var ( + localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} +) + +func TestPinger(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: ipv4.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestV6Pinger(t *testing.T) { + if c, err := net.ListenPacket("udp6", "::1"); err != nil { + // skip test if we can't use IPv6. + t.Skipf("IPv6 not supported: %s", err) + } else { + c.Close() + } + + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: ipv6.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestPingerTimeout(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + clock := &tstest.Clock{} + p, closeP := mockPinger(t, clock) + defer closeP() + + // Send a ping in the background + r := make(chan error, 1) + go func() { + _, err := p.Send(ctx, localhost, []byte("data goes here")) + r <- err + }() + + // Wait until we're blocking + p.waitOutstanding(t, ctx, 1) + + // Close everything down + p.cleanupOutstanding() + + // Should have got an error from the ping + err := <-r + if !errors.Is(err, net.ErrClosed) { + t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) + } +} + +func TestPingerMismatch(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // "Receive" a bunch of intentionally malformed packets that should not + // result in the Send call above returning + badPackets := []struct { + name string + pkt *icmp.Message + }{ + { + name: "wrong type", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeDestinationUnreachable, + Code: 0, + Body: &icmp.DstUnreach{}, + }, + }, + { + name: "wrong id", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 9999, + Seq: 1, + Data: bodyData, + }, + }, + }, + { + name: "wrong seq", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 5, + Data: bodyData, + }, + }, + }, + { + name: "bad body", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + + // Intentionally missing first byte + Data: bodyData[1:], + }, + }, + }, + } + + const fakeDuration = 100 * time.Millisecond + tm := clock.Now().Add(fakeDuration) + + for _, tt := range badPackets { + fakeResponse := mustMarshal(t, tt.pkt) + p.handleResponse(fakeResponse, tm, v4Type) + } + + // Also "receive" a packet that does not unmarshal as an ICMP packet + p.handleResponse([]byte("foo"), tm, v4Type) + + select { + case <-r: + t.Fatal("wanted timeout") + case <-ctx.Done(): + t.Logf("test correctly timed out") + } +} + +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + +func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { + p := New(context.Background(), t.Logf, nil) + p.timeNow = clock.Now + p.Verbose = true + p.id = 1234 + + // In tests, we use UDP so that we can test without being root; this + // doesn't matter because we mock out the ICMP reply below to be a real + // ICMP echo reply packet. + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn4 = &udpingPacketConn{ + destPort: 12345, + PacketConn: conn4, + } + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: 12345, + } + + mak.Set(&p.conns, v4Type, conn4) + mak.Set(&p.conns, v6Type, conn6) + done := func() { + if err := p.Close(); err != nil { + t.Errorf("error on close: %v", err) + } + } + return p, done +} + +func mustMarshal(t *testing.T, m *icmp.Message) []byte { + t.Helper() + + b, err := m.Marshal(nil) + if err != nil { + t.Fatal(err) + } + return b +} + +func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { + // This is a bit janky, but... we busy-loop to wait for the Send call + // to write to our map so we know that a response will be handled. + var haveMapEntry bool + for !haveMapEntry { + time.Sleep(10 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("no entry in ping map before timeout") + return + default: + } + + p.mu.Lock() + haveMapEntry = len(p.pings) == count + p.mu.Unlock() + } +} diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go index 8f8eef3ef8399..3dece72367423 100644 --- a/net/portmapper/pcp_test.go +++ b/net/portmapper/pcp_test.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portmapper - -import ( - "encoding/binary" - "net/netip" - "testing" - - "tailscale.com/net/netaddr" -) - -var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} - -func TestParsePCPMapResponse(t *testing.T) { - mapping, err := parsePCPMapResponse(examplePCPMapResponse) - if err != nil { - t.Fatalf("failed to parse PCP Map Response: %v", err) - } - if mapping == nil { - t.Fatalf("got nil mapping when expected non-nil") - } - expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") - if mapping.external != expectedAddr { - t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) - } -} - -const ( - serverResponseBit = 1 << 7 - fakeLifetimeSec = 1<<31 - 1 -) - -func buildPCPDiscoResponse(req []byte) []byte { - out := make([]byte, 24) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - return out -} - -func buildPCPMapResponse(req []byte) []byte { - out := make([]byte, 24+36) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - binary.BigEndian.PutUint32(out[4:8], 1<<30) - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - mapResp := out[24:] - mapReq := req[24:] - // copy nonce, protocol and internal port - copy(mapResp[:13], mapReq[:13]) - copy(mapResp[16:18], mapReq[16:18]) - // assign external port - binary.BigEndian.PutUint16(mapResp[18:20], 4242) - assignedIP := netaddr.IPv4(127, 0, 0, 1) - assignedIP16 := assignedIP.As16() - copy(mapResp[20:36], assignedIP16[:]) - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portmapper + +import ( + "encoding/binary" + "net/netip" + "testing" + + "tailscale.com/net/netaddr" +) + +var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} + +func TestParsePCPMapResponse(t *testing.T) { + mapping, err := parsePCPMapResponse(examplePCPMapResponse) + if err != nil { + t.Fatalf("failed to parse PCP Map Response: %v", err) + } + if mapping == nil { + t.Fatalf("got nil mapping when expected non-nil") + } + expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") + if mapping.external != expectedAddr { + t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) + } +} + +const ( + serverResponseBit = 1 << 7 + fakeLifetimeSec = 1<<31 - 1 +) + +func buildPCPDiscoResponse(req []byte) []byte { + out := make([]byte, 24) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + return out +} + +func buildPCPMapResponse(req []byte) []byte { + out := make([]byte, 24+36) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + binary.BigEndian.PutUint32(out[4:8], 1<<30) + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + mapResp := out[24:] + mapReq := req[24:] + // copy nonce, protocol and internal port + copy(mapResp[:13], mapReq[:13]) + copy(mapResp[16:18], mapReq[16:18]) + // assign external port + binary.BigEndian.PutUint16(mapResp[18:20], 4242) + assignedIP := netaddr.IPv4(127, 0, 0, 1) + assignedIP16 := assignedIP.As16() + copy(mapResp[20:36], assignedIP16[:]) + return out +} diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go index ff5aaff3b975f..12c3107de8339 100644 --- a/net/proxymux/mux.go +++ b/net/proxymux/mux.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package proxymux splits a net.Listener in two, routing SOCKS5 -// connections to one and HTTP requests to the other. -// -// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the -// same listener. -package proxymux - -import ( - "io" - "net" - "sync" - "time" -) - -// SplitSOCKSAndHTTP accepts connections on ln and passes connections -// through to either socksListener or httpListener, depending the -// first byte sent by the client. -func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { - sl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - hl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - - go splitSOCKSAndHTTPListener(ln, sl, hl) - - return sl, hl -} - -func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { - for { - conn, err := ln.Accept() - if err != nil { - sl.Close() - hl.Close() - return - } - go routeConn(conn, sl, hl) - } -} - -func routeConn(c net.Conn, socksListener, httpListener *listener) { - if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { - c.Close() - return - } - - var b [1]byte - if _, err := io.ReadFull(c, b[:]); err != nil { - c.Close() - return - } - - if err := c.SetReadDeadline(time.Time{}); err != nil { - c.Close() - return - } - - conn := &connWithOneByte{ - Conn: c, - b: b[0], - } - - // First byte of a SOCKS5 session is a version byte set to 5. - var ln *listener - if b[0] == 5 { - ln = socksListener - } else { - ln = httpListener - } - select { - case ln.c <- conn: - case <-ln.closed: - c.Close() - } -} - -type listener struct { - addr net.Addr - c chan net.Conn - mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. - closed chan struct{} -} - -func (ln *listener) Accept() (net.Conn, error) { - // Once closed, reliably stay closed, don't race with attempts at - // further connections. - select { - case <-ln.closed: - return nil, net.ErrClosed - default: - } - select { - case ret := <-ln.c: - return ret, nil - case <-ln.closed: - return nil, net.ErrClosed - } -} - -func (ln *listener) Close() error { - ln.mu.Lock() - defer ln.mu.Unlock() - select { - case <-ln.closed: - // Already closed - default: - close(ln.closed) - } - return nil -} - -func (ln *listener) Addr() net.Addr { - return ln.addr -} - -// connWithOneByte is a net.Conn that returns b for the first read -// request, then forwards everything else to Conn. -type connWithOneByte struct { - net.Conn - - b byte - bRead bool -} - -func (c *connWithOneByte) Read(bs []byte) (int, error) { - if c.bRead { - return c.Conn.Read(bs) - } - if len(bs) == 0 { - return 0, nil - } - c.bRead = true - bs[0] = c.b - return 1, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package proxymux splits a net.Listener in two, routing SOCKS5 +// connections to one and HTTP requests to the other. +// +// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the +// same listener. +package proxymux + +import ( + "io" + "net" + "sync" + "time" +) + +// SplitSOCKSAndHTTP accepts connections on ln and passes connections +// through to either socksListener or httpListener, depending the +// first byte sent by the client. +func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { + sl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + hl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + + go splitSOCKSAndHTTPListener(ln, sl, hl) + + return sl, hl +} + +func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { + for { + conn, err := ln.Accept() + if err != nil { + sl.Close() + hl.Close() + return + } + go routeConn(conn, sl, hl) + } +} + +func routeConn(c net.Conn, socksListener, httpListener *listener) { + if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { + c.Close() + return + } + + var b [1]byte + if _, err := io.ReadFull(c, b[:]); err != nil { + c.Close() + return + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + c.Close() + return + } + + conn := &connWithOneByte{ + Conn: c, + b: b[0], + } + + // First byte of a SOCKS5 session is a version byte set to 5. + var ln *listener + if b[0] == 5 { + ln = socksListener + } else { + ln = httpListener + } + select { + case ln.c <- conn: + case <-ln.closed: + c.Close() + } +} + +type listener struct { + addr net.Addr + c chan net.Conn + mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. + closed chan struct{} +} + +func (ln *listener) Accept() (net.Conn, error) { + // Once closed, reliably stay closed, don't race with attempts at + // further connections. + select { + case <-ln.closed: + return nil, net.ErrClosed + default: + } + select { + case ret := <-ln.c: + return ret, nil + case <-ln.closed: + return nil, net.ErrClosed + } +} + +func (ln *listener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + select { + case <-ln.closed: + // Already closed + default: + close(ln.closed) + } + return nil +} + +func (ln *listener) Addr() net.Addr { + return ln.addr +} + +// connWithOneByte is a net.Conn that returns b for the first read +// request, then forwards everything else to Conn. +type connWithOneByte struct { + net.Conn + + b byte + bRead bool +} + +func (c *connWithOneByte) Read(bs []byte) (int, error) { + if c.bRead { + return c.Conn.Read(bs) + } + if len(bs) == 0 { + return 0, nil + } + c.bRead = true + bs[0] = c.b + return 1, nil +} diff --git a/net/routetable/routetable_darwin.go b/net/routetable/routetable_darwin.go index 7f525ae32807a..7de80a66229e9 100644 --- a/net/routetable/routetable_darwin.go +++ b/net/routetable/routetable_darwin.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP2 - parseType = unix.NET_RT_IFLIST2 - rmExpectedType = unix.RTM_GET2 - - // Skip routes that were cloned from a parent - skipFlags = unix.RTF_WASCLONED -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_GLOBAL: "global", - unix.RTF_HOST: "host", - unix.RTF_IFSCOPE: "ifscope", - unix.RTF_LOCAL: "local", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_ROUTER: "router", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", - // More obscure flags, just to have full coverage. - unix.RTF_LLINFO: "{RTF_LLINFO}", - unix.RTF_PRCLONING: "{RTF_PRCLONING}", - unix.RTF_CLONING: "{RTF_CLONING}", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP2 + parseType = unix.NET_RT_IFLIST2 + rmExpectedType = unix.RTM_GET2 + + // Skip routes that were cloned from a parent + skipFlags = unix.RTF_WASCLONED +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_GLOBAL: "global", + unix.RTF_HOST: "host", + unix.RTF_IFSCOPE: "ifscope", + unix.RTF_LOCAL: "local", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_ROUTER: "router", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", + // More obscure flags, just to have full coverage. + unix.RTF_LLINFO: "{RTF_LLINFO}", + unix.RTF_PRCLONING: "{RTF_PRCLONING}", + unix.RTF_CLONING: "{RTF_CLONING}", +} diff --git a/net/routetable/routetable_freebsd.go b/net/routetable/routetable_freebsd.go index 8e57a330246ed..aa4e03c41236a 100644 --- a/net/routetable/routetable_freebsd.go +++ b/net/routetable/routetable_freebsd.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP - parseType = unix.NET_RT_IFLIST - rmExpectedType = unix.RTM_GET - - // Nothing to skip - skipFlags = 0 -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_HOST: "host", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP + parseType = unix.NET_RT_IFLIST + rmExpectedType = unix.RTM_GET + + // Nothing to skip + skipFlags = 0 +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_HOST: "host", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", +} diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 35c83e374564f..521fe1911aaa5 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin && !freebsd - -package routetable - -import ( - "errors" - "runtime" -) - -var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) - -func Get(max int) ([]RouteEntry, error) { - return nil, errUnsupported -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin && !freebsd + +package routetable + +import ( + "errors" + "runtime" +) + +var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) + +func Get(max int) ([]RouteEntry, error) { + return nil, errUnsupported +} diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go index 715c1ee06e9a9..fb524a5c53684 100644 --- a/net/sockstats/sockstats.go +++ b/net/sockstats/sockstats.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sockstats collects statistics about network sockets used by -// the Tailscale client. The context where sockets are used must be -// instrumented with the WithSockStats() function. -// -// Only available on POSIX platforms when built with Tailscale's fork of Go. -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -// SockStats contains statistics for sockets instrumented with the -// WithSockStats() function -type SockStats struct { - Stats map[Label]SockStat - CurrentInterfaceCellular bool -} - -// SockStat contains the sent and received bytes for a socket instrumented with -// the WithSockStats() function. -type SockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// Label is an identifier for a socket that stats are collected for. A finite -// set of values that may be used to label a socket to encourage grouping and -// to make storage more efficient. -type Label uint8 - -//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label - -// Labels are named after the package and function/struct that uses the socket. -// Values may be persisted and thus existing entries should not be re-numbered. -const ( - LabelControlClientAuto Label = 0 // control/controlclient/auto.go - LabelControlClientDialer Label = 1 // control/controlhttp/client.go - LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go - LabelLogtailLogger Label = 3 // logtail/logtail.go - LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go - LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go - LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go - LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go - LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go - LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go - LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go - LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go - LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go -) - -// WithSockStats instruments a context so that sockets created with it will -// have their statistics collected. -func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return withSockStats(ctx, label, logf) -} - -// Get returns the current socket statistics. -func Get() *SockStats { - return get() -} - -// InterfaceSockStats contains statistics for sockets instrumented with the -// WithSockStats() function, broken down by interface. The statistics may be a -// subset of the total if interfaces were added after the instrumented socket -// was created. -type InterfaceSockStats struct { - Stats map[Label]InterfaceSockStat - Interfaces []string -} - -// InterfaceSockStat contains the per-interface sent and received bytes for a -// socket instrumented with the WithSockStats() function. -type InterfaceSockStat struct { - TxBytesByInterface map[string]uint64 - RxBytesByInterface map[string]uint64 -} - -// GetWithInterfaces is a variant of Get that returns the current socket -// statistics broken down by interface. It is slightly more expensive than Get. -func GetInterfaces() *InterfaceSockStats { - return getInterfaces() -} - -// ValidationSockStats contains external validation numbers for sockets -// instrumented with WithSockStats. It may be a subset of the all sockets, -// depending on what externa measurement mechanisms the platform supports. -type ValidationSockStats struct { - Stats map[Label]ValidationSockStat -} - -// ValidationSockStat contains the validation bytes for a socket instrumented -// with WithSockStats. -type ValidationSockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// GetValidation is a variant of Get that returns external validation numbers -// for stats. It is more expensive than Get and should be used in debug -// interfaces only. -func GetValidation() *ValidationSockStats { - return getValidation() -} - -// SetNetMon configures the sockstats package to monitor the active -// interface, so that per-interface stats can be collected. -func SetNetMon(netMon *netmon.Monitor) { - setNetMon(netMon) -} - -// DebugInfo returns a string containing debug information about the tracked -// statistics. -func DebugInfo() string { - return debugInfo() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sockstats collects statistics about network sockets used by +// the Tailscale client. The context where sockets are used must be +// instrumented with the WithSockStats() function. +// +// Only available on POSIX platforms when built with Tailscale's fork of Go. +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +// SockStats contains statistics for sockets instrumented with the +// WithSockStats() function +type SockStats struct { + Stats map[Label]SockStat + CurrentInterfaceCellular bool +} + +// SockStat contains the sent and received bytes for a socket instrumented with +// the WithSockStats() function. +type SockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// Label is an identifier for a socket that stats are collected for. A finite +// set of values that may be used to label a socket to encourage grouping and +// to make storage more efficient. +type Label uint8 + +//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label + +// Labels are named after the package and function/struct that uses the socket. +// Values may be persisted and thus existing entries should not be re-numbered. +const ( + LabelControlClientAuto Label = 0 // control/controlclient/auto.go + LabelControlClientDialer Label = 1 // control/controlhttp/client.go + LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go + LabelLogtailLogger Label = 3 // logtail/logtail.go + LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go + LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go + LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go + LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go + LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go + LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go + LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go + LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go + LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go +) + +// WithSockStats instruments a context so that sockets created with it will +// have their statistics collected. +func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return withSockStats(ctx, label, logf) +} + +// Get returns the current socket statistics. +func Get() *SockStats { + return get() +} + +// InterfaceSockStats contains statistics for sockets instrumented with the +// WithSockStats() function, broken down by interface. The statistics may be a +// subset of the total if interfaces were added after the instrumented socket +// was created. +type InterfaceSockStats struct { + Stats map[Label]InterfaceSockStat + Interfaces []string +} + +// InterfaceSockStat contains the per-interface sent and received bytes for a +// socket instrumented with the WithSockStats() function. +type InterfaceSockStat struct { + TxBytesByInterface map[string]uint64 + RxBytesByInterface map[string]uint64 +} + +// GetWithInterfaces is a variant of Get that returns the current socket +// statistics broken down by interface. It is slightly more expensive than Get. +func GetInterfaces() *InterfaceSockStats { + return getInterfaces() +} + +// ValidationSockStats contains external validation numbers for sockets +// instrumented with WithSockStats. It may be a subset of the all sockets, +// depending on what externa measurement mechanisms the platform supports. +type ValidationSockStats struct { + Stats map[Label]ValidationSockStat +} + +// ValidationSockStat contains the validation bytes for a socket instrumented +// with WithSockStats. +type ValidationSockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// GetValidation is a variant of Get that returns external validation numbers +// for stats. It is more expensive than Get and should be used in debug +// interfaces only. +func GetValidation() *ValidationSockStats { + return getValidation() +} + +// SetNetMon configures the sockstats package to monitor the active +// interface, so that per-interface stats can be collected. +func SetNetMon(netMon *netmon.Monitor) { + setNetMon(netMon) +} + +// DebugInfo returns a string containing debug information about the tracked +// statistics. +func DebugInfo() string { + return debugInfo() +} diff --git a/net/sockstats/sockstats_noop.go b/net/sockstats/sockstats_noop.go index 96723111ade7a..797fdc42bde18 100644 --- a/net/sockstats/sockstats_noop.go +++ b/net/sockstats/sockstats_noop.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) - -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -const IsAvailable = false - -func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return ctx -} - -func get() *SockStats { - return nil -} - -func getInterfaces() *InterfaceSockStats { - return nil -} - -func getValidation() *ValidationSockStats { - return nil -} - -func setNetMon(netMon *netmon.Monitor) { -} - -func debugInfo() string { - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) + +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +const IsAvailable = false + +func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return ctx +} + +func get() *SockStats { + return nil +} + +func getInterfaces() *InterfaceSockStats { + return nil +} + +func getValidation() *ValidationSockStats { + return nil +} + +func setNetMon(netMon *netmon.Monitor) { +} + +func debugInfo() string { + return "" +} diff --git a/net/sockstats/sockstats_tsgo_darwin.go b/net/sockstats/sockstats_tsgo_darwin.go index 321d32e04e5f0..4b03ed6162965 100644 --- a/net/sockstats/sockstats_tsgo_darwin.go +++ b/net/sockstats/sockstats_tsgo_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build tailscale_go && (darwin || ios) - -package sockstats - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - tcpConnStats = darwinTcpConnStats -} - -func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { - c.Control(func(fd uintptr) { - if rawInfo, err := unix.GetsockoptTCPConnectionInfo( - int(fd), - unix.IPPROTO_TCP, - unix.TCP_CONNECTION_INFO, - ); err == nil { - tx = uint64(rawInfo.Txbytes) - rx = uint64(rawInfo.Rxbytes) - } - }) - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go && (darwin || ios) + +package sockstats + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + tcpConnStats = darwinTcpConnStats +} + +func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { + c.Control(func(fd uintptr) { + if rawInfo, err := unix.GetsockoptTCPConnectionInfo( + int(fd), + unix.IPPROTO_TCP, + unix.TCP_CONNECTION_INFO, + ); err == nil { + tx = uint64(rawInfo.Txbytes) + rx = uint64(rawInfo.Rxbytes) + } + }) + return +} diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go index 7ab0881cc22f9..89639c12d5fc2 100644 --- a/net/speedtest/speedtest.go +++ b/net/speedtest/speedtest.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package speedtest contains both server and client code for -// running speedtests between tailscale nodes. -package speedtest - -import ( - "time" -) - -const ( - blockSize = 2 * 1024 * 1024 // size of the block of data to send - MinDuration = 5 * time.Second // minimum duration for a test - DefaultDuration = MinDuration // default duration for a test - MaxDuration = 30 * time.Second // maximum duration for a test - version = 2 // value used when comparing client and server versions - increment = time.Second // increment to display results for, in seconds - minInterval = 10 * time.Millisecond // minimum interval length for a result to be included - DefaultPort = 20333 -) - -// config is the initial message sent to the server, that contains information on how to -// conduct the test. -type config struct { - Version int `json:"version"` - TestDuration time.Duration `json:"time"` - Direction Direction `json:"direction"` -} - -// configResponse is the response to the testConfig message. If the server has an -// error with the config, the Error variable will hold that error value. -type configResponse struct { - Error string `json:"error,omitempty"` -} - -// This represents the Result of a speedtest within a specific interval -type Result struct { - Bytes int // number of bytes sent/received during the interval - IntervalStart time.Time // start of the interval - IntervalEnd time.Time // end of the interval - Total bool // if true, this result struct represents the entire test, rather than a segment of the test -} - -func (r Result) MBitsPerSecond() float64 { - return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() -} - -func (r Result) MegaBytes() float64 { - return float64(r.Bytes) / 1000000.0 -} - -func (r Result) MegaBits() float64 { - return r.MegaBytes() * 8.0 -} - -func (r Result) Interval() time.Duration { - return r.IntervalEnd.Sub(r.IntervalStart) -} - -type Direction int - -const ( - Download Direction = iota - Upload -) - -func (d Direction) String() string { - switch d { - case Upload: - return "upload" - case Download: - return "download" - default: - return "" - } -} - -func (d *Direction) Reverse() { - switch *d { - case Upload: - *d = Download - case Download: - *d = Upload - default: - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package speedtest contains both server and client code for +// running speedtests between tailscale nodes. +package speedtest + +import ( + "time" +) + +const ( + blockSize = 2 * 1024 * 1024 // size of the block of data to send + MinDuration = 5 * time.Second // minimum duration for a test + DefaultDuration = MinDuration // default duration for a test + MaxDuration = 30 * time.Second // maximum duration for a test + version = 2 // value used when comparing client and server versions + increment = time.Second // increment to display results for, in seconds + minInterval = 10 * time.Millisecond // minimum interval length for a result to be included + DefaultPort = 20333 +) + +// config is the initial message sent to the server, that contains information on how to +// conduct the test. +type config struct { + Version int `json:"version"` + TestDuration time.Duration `json:"time"` + Direction Direction `json:"direction"` +} + +// configResponse is the response to the testConfig message. If the server has an +// error with the config, the Error variable will hold that error value. +type configResponse struct { + Error string `json:"error,omitempty"` +} + +// This represents the Result of a speedtest within a specific interval +type Result struct { + Bytes int // number of bytes sent/received during the interval + IntervalStart time.Time // start of the interval + IntervalEnd time.Time // end of the interval + Total bool // if true, this result struct represents the entire test, rather than a segment of the test +} + +func (r Result) MBitsPerSecond() float64 { + return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() +} + +func (r Result) MegaBytes() float64 { + return float64(r.Bytes) / 1000000.0 +} + +func (r Result) MegaBits() float64 { + return r.MegaBytes() * 8.0 +} + +func (r Result) Interval() time.Duration { + return r.IntervalEnd.Sub(r.IntervalStart) +} + +type Direction int + +const ( + Download Direction = iota + Upload +) + +func (d Direction) String() string { + switch d { + case Upload: + return "upload" + case Download: + return "download" + default: + return "" + } +} + +func (d *Direction) Reverse() { + switch *d { + case Upload: + *d = Download + case Download: + *d = Upload + default: + } +} diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go index 299a12a8dfaec..cc34c468c22c0 100644 --- a/net/speedtest/speedtest_client.go +++ b/net/speedtest/speedtest_client.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "encoding/json" - "errors" - "net" - "time" -) - -// RunClient dials the given address and starts a speedtest. -// It returns any errors that come up in the tests. -// If there are no errors in the test, it returns a slice of results. -func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { - conn, err := net.Dial("tcp", host) - if err != nil { - return nil, err - } - - conf := config{TestDuration: duration, Version: version, Direction: direction} - - defer conn.Close() - encoder := json.NewEncoder(conn) - - if err = encoder.Encode(conf); err != nil { - return nil, err - } - - var response configResponse - decoder := json.NewDecoder(conn) - if err = decoder.Decode(&response); err != nil { - return nil, err - } - if response.Error != "" { - return nil, errors.New(response.Error) - } - - return doTest(conn, conf) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "encoding/json" + "errors" + "net" + "time" +) + +// RunClient dials the given address and starts a speedtest. +// It returns any errors that come up in the tests. +// If there are no errors in the test, it returns a slice of results. +func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + + conf := config{TestDuration: duration, Version: version, Direction: direction} + + defer conn.Close() + encoder := json.NewEncoder(conn) + + if err = encoder.Encode(conf); err != nil { + return nil, err + } + + var response configResponse + decoder := json.NewDecoder(conn) + if err = decoder.Decode(&response); err != nil { + return nil, err + } + if response.Error != "" { + return nil, errors.New(response.Error) + } + + return doTest(conn, conf) +} diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index 9dd78b195fff4..d2673464e3132 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -1,146 +1,146 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" -) - -// Serve starts up the server on a given host and port pair. It starts to listen for -// connections and handles each one in a goroutine. Because it runs in an infinite loop, -// this function only returns if any of the speedtests return with errors, or if the -// listener is closed. -func Serve(l net.Listener) error { - for { - conn, err := l.Accept() - if errors.Is(err, net.ErrClosed) { - return nil - } - if err != nil { - return err - } - err = handleConnection(conn) - if err != nil { - return err - } - } -} - -// handleConnection handles the initial exchange between the server and the client. -// It reads the testconfig message into a config struct. If any errors occur with -// the testconfig (specifically, if there is a version mismatch), it will return those -// errors to the client with a configResponse. After the exchange, it will start -// the speed test. -func handleConnection(conn net.Conn) error { - defer conn.Close() - var conf config - - decoder := json.NewDecoder(conn) - err := decoder.Decode(&conf) - encoder := json.NewEncoder(conn) - - // Both return and encode errors that occurred before the test started. - if err != nil { - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // The server should always be doing the opposite of what the client is doing. - conf.Direction.Reverse() - - if conf.Version != version { - err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // Start the test - encoder.Encode(configResponse{}) - _, err = doTest(conn, conf) - return err -} - -// TODO include code to detect whether the code is direct vs DERP - -// doTest contains the code to run both the upload and download speedtest. -// the direction value in the config parameter determines which test to run. -func doTest(conn net.Conn, conf config) ([]Result, error) { - bufferData := make([]byte, blockSize) - - intervalBytes := 0 - totalBytes := 0 - - var currentTime time.Time - var results []Result - - if conf.Direction == Download { - conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) - } else { - _, err := rand.Read(bufferData) - if err != nil { - return nil, err - } - - } - - startTime := time.Now() - lastCalculated := startTime - -SpeedTestLoop: - for { - var n int - var err error - - if conf.Direction == Download { - n, err = io.ReadFull(conn, bufferData) - switch err { - case io.EOF, io.ErrUnexpectedEOF: - break SpeedTestLoop - case nil: - // successful read - default: - return nil, fmt.Errorf("unexpected error has occurred: %w", err) - } - } else { - n, err = conn.Write(bufferData) - if err != nil { - // If the write failed, there is most likely something wrong with the connection. - return nil, fmt.Errorf("upload failed: %w", err) - } - } - intervalBytes += n - - currentTime = time.Now() - // checks if the current time is more or equal to the lastCalculated time plus the increment - if currentTime.Sub(lastCalculated) >= increment { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - lastCalculated = currentTime - totalBytes += intervalBytes - intervalBytes = 0 - } - - if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { - break SpeedTestLoop - } - } - - // get last segment - if currentTime.Sub(lastCalculated) > minInterval { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - } - - // get total - totalBytes += intervalBytes - if currentTime.Sub(startTime) > minInterval { - results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) - } - - return results, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Serve starts up the server on a given host and port pair. It starts to listen for +// connections and handles each one in a goroutine. Because it runs in an infinite loop, +// this function only returns if any of the speedtests return with errors, or if the +// listener is closed. +func Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if errors.Is(err, net.ErrClosed) { + return nil + } + if err != nil { + return err + } + err = handleConnection(conn) + if err != nil { + return err + } + } +} + +// handleConnection handles the initial exchange between the server and the client. +// It reads the testconfig message into a config struct. If any errors occur with +// the testconfig (specifically, if there is a version mismatch), it will return those +// errors to the client with a configResponse. After the exchange, it will start +// the speed test. +func handleConnection(conn net.Conn) error { + defer conn.Close() + var conf config + + decoder := json.NewDecoder(conn) + err := decoder.Decode(&conf) + encoder := json.NewEncoder(conn) + + // Both return and encode errors that occurred before the test started. + if err != nil { + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // The server should always be doing the opposite of what the client is doing. + conf.Direction.Reverse() + + if conf.Version != version { + err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // Start the test + encoder.Encode(configResponse{}) + _, err = doTest(conn, conf) + return err +} + +// TODO include code to detect whether the code is direct vs DERP + +// doTest contains the code to run both the upload and download speedtest. +// the direction value in the config parameter determines which test to run. +func doTest(conn net.Conn, conf config) ([]Result, error) { + bufferData := make([]byte, blockSize) + + intervalBytes := 0 + totalBytes := 0 + + var currentTime time.Time + var results []Result + + if conf.Direction == Download { + conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) + } else { + _, err := rand.Read(bufferData) + if err != nil { + return nil, err + } + + } + + startTime := time.Now() + lastCalculated := startTime + +SpeedTestLoop: + for { + var n int + var err error + + if conf.Direction == Download { + n, err = io.ReadFull(conn, bufferData) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + break SpeedTestLoop + case nil: + // successful read + default: + return nil, fmt.Errorf("unexpected error has occurred: %w", err) + } + } else { + n, err = conn.Write(bufferData) + if err != nil { + // If the write failed, there is most likely something wrong with the connection. + return nil, fmt.Errorf("upload failed: %w", err) + } + } + intervalBytes += n + + currentTime = time.Now() + // checks if the current time is more or equal to the lastCalculated time plus the increment + if currentTime.Sub(lastCalculated) >= increment { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + lastCalculated = currentTime + totalBytes += intervalBytes + intervalBytes = 0 + } + + if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { + break SpeedTestLoop + } + } + + // get last segment + if currentTime.Sub(lastCalculated) > minInterval { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + } + + // get total + totalBytes += intervalBytes + if currentTime.Sub(startTime) > minInterval { + results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) + } + + return results, nil +} diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index 55dcbeea1abdf..a413e9efafcd4 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "net" - "testing" - "time" -) - -func TestDownload(t *testing.T) { - // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { l.Close() }) - - serverIP := l.Addr().String() - t.Log("server IP found:", serverIP) - - type state struct { - err error - } - displayResult := func(t *testing.T, r Result, start time.Time) { - t.Helper() - t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) - } - stateChan := make(chan state, 1) - - go func() { - err := Serve(l) - stateChan <- state{err: err} - }() - - // ensure that the test returns an appropriate number of Result structs - expectedLen := int(DefaultDuration.Seconds()) + 1 - - t.Run("download test", func(t *testing.T) { - // conduct a download test - results, err := RunClient(Download, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("download test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - t.Run("upload test", func(t *testing.T) { - // conduct an upload test - results, err := RunClient(Upload, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("upload test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - // causes the server goroutine to finish - l.Close() - - testState := <-stateChan - if testState.err != nil { - t.Error("server error:", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "net" + "testing" + "time" +) + +func TestDownload(t *testing.T) { + // start a listener and find the port where the server will be listening. + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { l.Close() }) + + serverIP := l.Addr().String() + t.Log("server IP found:", serverIP) + + type state struct { + err error + } + displayResult := func(t *testing.T, r Result, start time.Time) { + t.Helper() + t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) + } + stateChan := make(chan state, 1) + + go func() { + err := Serve(l) + stateChan <- state{err: err} + }() + + // ensure that the test returns an appropriate number of Result structs + expectedLen := int(DefaultDuration.Seconds()) + 1 + + t.Run("download test", func(t *testing.T) { + // conduct a download test + results, err := RunClient(Download, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("download test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + t.Run("upload test", func(t *testing.T) { + // conduct an upload test + results, err := RunClient(Upload, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("upload test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + // causes the server goroutine to finish + l.Close() + + testState := <-stateChan + if testState.err != nil { + t.Error("server error:", err) + } +} diff --git a/net/stun/stun.go b/net/stun/stun.go index eeac23cbbd45d..81cf9b6080d26 100644 --- a/net/stun/stun.go +++ b/net/stun/stun.go @@ -1,312 +1,312 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package STUN generates STUN request packets and parses response packets. -package stun - -import ( - "bytes" - crand "crypto/rand" - "encoding/binary" - "errors" - "hash/crc32" - "net" - "net/netip" -) - -const ( - attrNumSoftware = 0x8022 - attrNumFingerprint = 0x8028 - attrMappedAddress = 0x0001 - attrXorMappedAddress = 0x0020 - // This alternative attribute type is not - // mentioned in the RFC, but the shift into - // the "comprehension-optional" range seems - // like an easy mistake for a server to make. - // And servers appear to send it. - attrXorMappedAddressAlt = 0x8020 - - software = "tailnode" // notably: 8 bytes long, so no padding - bindingRequest = "\x00\x01" - magicCookie = "\x21\x12\xa4\x42" - lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 - headerLen = 20 -) - -// TxID is a transaction ID. -type TxID [12]byte - -// NewTxID returns a new random TxID. -func NewTxID() TxID { - var tx TxID - if _, err := crand.Read(tx[:]); err != nil { - panic(err) - } - return tx -} - -// Request generates a binding request STUN packet. -// The transaction ID, tID, should be a random sequence of bytes. -func Request(tID TxID) []byte { - // STUN header, RFC5389 Section 6. - const lenAttrSoftware = 4 + len(software) - b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) - b = append(b, bindingRequest...) - b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header - b = append(b, magicCookie...) - b = append(b, tID[:]...) - - // Attribute SOFTWARE, RFC5389 Section 15.5. - b = appendU16(b, attrNumSoftware) - b = appendU16(b, uint16(len(software))) - b = append(b, software...) - - // Attribute FINGERPRINT, RFC5389 Section 15.5. - fp := fingerPrint(b) - b = appendU16(b, attrNumFingerprint) - b = appendU16(b, 4) - b = appendU32(b, fp) - - return b -} - -func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } - -func appendU16(b []byte, v uint16) []byte { - return append(b, byte(v>>8), byte(v)) -} - -func appendU32(b []byte, v uint32) []byte { - return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// ParseBindingRequest parses a STUN binding request. -// -// It returns an error unless it advertises that it came from -// Tailscale. -func ParseBindingRequest(b []byte) (TxID, error) { - if !Is(b) { - return TxID{}, ErrNotSTUN - } - if string(b[:len(bindingRequest)]) != bindingRequest { - return TxID{}, ErrNotBindingRequest - } - var txID TxID - copy(txID[:], b[8:8+len(txID)]) - var softwareOK bool - var lastAttr uint16 - var gotFP uint32 - if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { - lastAttr = attrType - if attrType == attrNumSoftware && string(a) == software { - softwareOK = true - } - if attrType == attrNumFingerprint && len(a) == 4 { - gotFP = binary.BigEndian.Uint32(a) - } - return nil - }); err != nil { - return TxID{}, err - } - if !softwareOK { - return TxID{}, ErrWrongSoftware - } - if lastAttr != attrNumFingerprint { - return TxID{}, ErrNoFingerprint - } - wantFP := fingerPrint(b[:len(b)-lenFingerprint]) - if gotFP != wantFP { - return TxID{}, ErrWrongFingerprint - } - return txID, nil -} - -var ( - ErrNotSTUN = errors.New("response is not a STUN packet") - ErrNotSuccessResponse = errors.New("STUN packet is not a response") - ErrMalformedAttrs = errors.New("STUN response has malformed attributes") - ErrNotBindingRequest = errors.New("STUN request not a binding request") - ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") - ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") - ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") -) - -func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { - for len(b) > 0 { - if len(b) < 4 { - return ErrMalformedAttrs - } - attrType := binary.BigEndian.Uint16(b[:2]) - attrLen := int(binary.BigEndian.Uint16(b[2:4])) - attrLenWithPad := (attrLen + 3) &^ 3 - b = b[4:] - if attrLenWithPad > len(b) { - return ErrMalformedAttrs - } - if err := fn(attrType, b[:attrLen]); err != nil { - return err - } - b = b[attrLenWithPad:] - } - return nil -} - -// Response generates a binding response. -func Response(txID TxID, addrPort netip.AddrPort) []byte { - addr := addrPort.Addr() - - var fam byte - if addr.Is4() { - fam = 1 - } else if addr.Is6() { - fam = 2 - } else { - return nil - } - attrsLen := 8 + addr.BitLen()/8 - b := make([]byte, 0, headerLen+attrsLen) - - // Header - b = append(b, 0x01, 0x01) // success - b = appendU16(b, uint16(attrsLen)) - b = append(b, magicCookie...) - b = append(b, txID[:]...) - - // Attributes (well, one) - b = appendU16(b, attrXorMappedAddress) - b = appendU16(b, uint16(4+addr.BitLen()/8)) - b = append(b, - 0, // unused byte - fam) - b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie - ipa := addr.As16() - for i, o := range ipa[16-addr.BitLen()/8:] { - if i < 4 { - b = append(b, o^magicCookie[i]) - } else { - b = append(b, o^txID[i-len(magicCookie)]) - } - } - return b -} - -// ParseResponse parses a successful binding response STUN packet. -// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. -func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { - if !Is(b) { - return tID, netip.AddrPort{}, ErrNotSTUN - } - copy(tID[:], b[8:8+len(tID)]) - if b[0] != 0x01 || b[1] != 0x01 { - return tID, netip.AddrPort{}, ErrNotSuccessResponse - } - attrsLen := int(binary.BigEndian.Uint16(b[2:4])) - b = b[headerLen:] // remove STUN header - if attrsLen > len(b) { - return tID, netip.AddrPort{}, ErrMalformedAttrs - } else if len(b) > attrsLen { - b = b[:attrsLen] // trim trailing packet bytes - } - - var fallbackAddr netip.AddrPort - - // Read through the attributes. - // The the addr+port reported by XOR-MAPPED-ADDRESS - // as the canonical value. If the attribute is not - // present but the STUN server responds with - // MAPPED-ADDRESS we fall back to it. - if err := foreachAttr(b, func(attrType uint16, attr []byte) error { - switch attrType { - case attrXorMappedAddress, attrXorMappedAddressAlt: - ipSlice, port, err := xorMappedAddress(tID, attr) - if err != nil { - return err - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - addr = netip.AddrPortFrom(ip.Unmap(), port) - } - case attrMappedAddress: - ipSlice, port, err := mappedAddress(attr) - if err != nil { - return ErrMalformedAttrs - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) - } - } - return nil - - }); err != nil { - return TxID{}, netip.AddrPort{}, err - } - - if addr.IsValid() { - return tID, addr, nil - } - if fallbackAddr.IsValid() { - return tID, fallbackAddr, nil - } - return tID, netip.AddrPort{}, ErrMalformedAttrs -} - -func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { - // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - xorPort := binary.BigEndian.Uint16(b[2:4]) - addrField := b[4:] - port = xorPort ^ 0x2112 // first half of magicCookie - - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - xorAddr := addrField[:addrLen] - addr = make([]byte, addrLen) - for i := range xorAddr { - if i < len(magicCookie) { - addr[i] = xorAddr[i] ^ magicCookie[i] - } else { - addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] - } - } - return addr, port, nil -} - -func familyAddrLen(fam byte) int { - switch fam { - case 0x01: // IPv4 - return net.IPv4len - case 0x02: // IPv6 - return net.IPv6len - default: - return 0 - } -} - -func mappedAddress(b []byte) (addr []byte, port uint16, err error) { - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - port = uint16(b[2])<<8 | uint16(b[3]) - addrField := b[4:] - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - return bytes.Clone(addrField[:addrLen]), port, nil -} - -// Is reports whether b is a STUN message. -func Is(b []byte) bool { - return len(b) >= headerLen && - b[0]&0b11000000 == 0 && // top two bits must be zero - string(b[4:8]) == magicCookie -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package STUN generates STUN request packets and parses response packets. +package stun + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "errors" + "hash/crc32" + "net" + "net/netip" +) + +const ( + attrNumSoftware = 0x8022 + attrNumFingerprint = 0x8028 + attrMappedAddress = 0x0001 + attrXorMappedAddress = 0x0020 + // This alternative attribute type is not + // mentioned in the RFC, but the shift into + // the "comprehension-optional" range seems + // like an easy mistake for a server to make. + // And servers appear to send it. + attrXorMappedAddressAlt = 0x8020 + + software = "tailnode" // notably: 8 bytes long, so no padding + bindingRequest = "\x00\x01" + magicCookie = "\x21\x12\xa4\x42" + lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 + headerLen = 20 +) + +// TxID is a transaction ID. +type TxID [12]byte + +// NewTxID returns a new random TxID. +func NewTxID() TxID { + var tx TxID + if _, err := crand.Read(tx[:]); err != nil { + panic(err) + } + return tx +} + +// Request generates a binding request STUN packet. +// The transaction ID, tID, should be a random sequence of bytes. +func Request(tID TxID) []byte { + // STUN header, RFC5389 Section 6. + const lenAttrSoftware = 4 + len(software) + b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) + b = append(b, bindingRequest...) + b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header + b = append(b, magicCookie...) + b = append(b, tID[:]...) + + // Attribute SOFTWARE, RFC5389 Section 15.5. + b = appendU16(b, attrNumSoftware) + b = appendU16(b, uint16(len(software))) + b = append(b, software...) + + // Attribute FINGERPRINT, RFC5389 Section 15.5. + fp := fingerPrint(b) + b = appendU16(b, attrNumFingerprint) + b = appendU16(b, 4) + b = appendU32(b, fp) + + return b +} + +func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } + +func appendU16(b []byte, v uint16) []byte { + return append(b, byte(v>>8), byte(v)) +} + +func appendU32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// ParseBindingRequest parses a STUN binding request. +// +// It returns an error unless it advertises that it came from +// Tailscale. +func ParseBindingRequest(b []byte) (TxID, error) { + if !Is(b) { + return TxID{}, ErrNotSTUN + } + if string(b[:len(bindingRequest)]) != bindingRequest { + return TxID{}, ErrNotBindingRequest + } + var txID TxID + copy(txID[:], b[8:8+len(txID)]) + var softwareOK bool + var lastAttr uint16 + var gotFP uint32 + if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { + lastAttr = attrType + if attrType == attrNumSoftware && string(a) == software { + softwareOK = true + } + if attrType == attrNumFingerprint && len(a) == 4 { + gotFP = binary.BigEndian.Uint32(a) + } + return nil + }); err != nil { + return TxID{}, err + } + if !softwareOK { + return TxID{}, ErrWrongSoftware + } + if lastAttr != attrNumFingerprint { + return TxID{}, ErrNoFingerprint + } + wantFP := fingerPrint(b[:len(b)-lenFingerprint]) + if gotFP != wantFP { + return TxID{}, ErrWrongFingerprint + } + return txID, nil +} + +var ( + ErrNotSTUN = errors.New("response is not a STUN packet") + ErrNotSuccessResponse = errors.New("STUN packet is not a response") + ErrMalformedAttrs = errors.New("STUN response has malformed attributes") + ErrNotBindingRequest = errors.New("STUN request not a binding request") + ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") + ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") + ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") +) + +func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { + for len(b) > 0 { + if len(b) < 4 { + return ErrMalformedAttrs + } + attrType := binary.BigEndian.Uint16(b[:2]) + attrLen := int(binary.BigEndian.Uint16(b[2:4])) + attrLenWithPad := (attrLen + 3) &^ 3 + b = b[4:] + if attrLenWithPad > len(b) { + return ErrMalformedAttrs + } + if err := fn(attrType, b[:attrLen]); err != nil { + return err + } + b = b[attrLenWithPad:] + } + return nil +} + +// Response generates a binding response. +func Response(txID TxID, addrPort netip.AddrPort) []byte { + addr := addrPort.Addr() + + var fam byte + if addr.Is4() { + fam = 1 + } else if addr.Is6() { + fam = 2 + } else { + return nil + } + attrsLen := 8 + addr.BitLen()/8 + b := make([]byte, 0, headerLen+attrsLen) + + // Header + b = append(b, 0x01, 0x01) // success + b = appendU16(b, uint16(attrsLen)) + b = append(b, magicCookie...) + b = append(b, txID[:]...) + + // Attributes (well, one) + b = appendU16(b, attrXorMappedAddress) + b = appendU16(b, uint16(4+addr.BitLen()/8)) + b = append(b, + 0, // unused byte + fam) + b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie + ipa := addr.As16() + for i, o := range ipa[16-addr.BitLen()/8:] { + if i < 4 { + b = append(b, o^magicCookie[i]) + } else { + b = append(b, o^txID[i-len(magicCookie)]) + } + } + return b +} + +// ParseResponse parses a successful binding response STUN packet. +// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. +func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { + if !Is(b) { + return tID, netip.AddrPort{}, ErrNotSTUN + } + copy(tID[:], b[8:8+len(tID)]) + if b[0] != 0x01 || b[1] != 0x01 { + return tID, netip.AddrPort{}, ErrNotSuccessResponse + } + attrsLen := int(binary.BigEndian.Uint16(b[2:4])) + b = b[headerLen:] // remove STUN header + if attrsLen > len(b) { + return tID, netip.AddrPort{}, ErrMalformedAttrs + } else if len(b) > attrsLen { + b = b[:attrsLen] // trim trailing packet bytes + } + + var fallbackAddr netip.AddrPort + + // Read through the attributes. + // The the addr+port reported by XOR-MAPPED-ADDRESS + // as the canonical value. If the attribute is not + // present but the STUN server responds with + // MAPPED-ADDRESS we fall back to it. + if err := foreachAttr(b, func(attrType uint16, attr []byte) error { + switch attrType { + case attrXorMappedAddress, attrXorMappedAddressAlt: + ipSlice, port, err := xorMappedAddress(tID, attr) + if err != nil { + return err + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + addr = netip.AddrPortFrom(ip.Unmap(), port) + } + case attrMappedAddress: + ipSlice, port, err := mappedAddress(attr) + if err != nil { + return ErrMalformedAttrs + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) + } + } + return nil + + }); err != nil { + return TxID{}, netip.AddrPort{}, err + } + + if addr.IsValid() { + return tID, addr, nil + } + if fallbackAddr.IsValid() { + return tID, fallbackAddr, nil + } + return tID, netip.AddrPort{}, ErrMalformedAttrs +} + +func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { + // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + xorPort := binary.BigEndian.Uint16(b[2:4]) + addrField := b[4:] + port = xorPort ^ 0x2112 // first half of magicCookie + + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + xorAddr := addrField[:addrLen] + addr = make([]byte, addrLen) + for i := range xorAddr { + if i < len(magicCookie) { + addr[i] = xorAddr[i] ^ magicCookie[i] + } else { + addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] + } + } + return addr, port, nil +} + +func familyAddrLen(fam byte) int { + switch fam { + case 0x01: // IPv4 + return net.IPv4len + case 0x02: // IPv6 + return net.IPv6len + default: + return 0 + } +} + +func mappedAddress(b []byte) (addr []byte, port uint16, err error) { + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + port = uint16(b[2])<<8 | uint16(b[3]) + addrField := b[4:] + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + return bytes.Clone(addrField[:addrLen]), port, nil +} + +// Is reports whether b is a STUN message. +func Is(b []byte) bool { + return len(b) >= headerLen && + b[0]&0b11000000 == 0 && // top two bits must be zero + string(b[4:8]) == magicCookie +} diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go index 6f0c9e3b0beae..9ddb418950b39 100644 --- a/net/stun/stun_fuzzer.go +++ b/net/stun/stun_fuzzer.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package stun - -func FuzzStunParser(data []byte) int { - _, _, _ = ParseResponse(data) - - _, _ = ParseBindingRequest(data) - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package stun + +func FuzzStunParser(data []byte) int { + _, _, _ = ParseResponse(data) + + _, _ = ParseBindingRequest(data) + return 1 +} diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go index a757add9f8f46..adc40ca372cf5 100644 --- a/net/tcpinfo/tcpinfo.go +++ b/net/tcpinfo/tcpinfo.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tcpinfo provides platform-agnostic accessors to information about a -// TCP connection (e.g. RTT, MSS, etc.). -package tcpinfo - -import ( - "errors" - "net" - "time" -) - -var ( - ErrNotTCP = errors.New("tcpinfo: not a TCP conn") - ErrUnimplemented = errors.New("tcpinfo: unimplemented") -) - -// RTT returns the RTT for the given net.Conn. -// -// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then -// ErrNotTCP will be returned. If retrieving the RTT is not supported on the -// current platform, ErrUnimplemented will be returned. -func RTT(conn net.Conn) (time.Duration, error) { - tcpConn, err := unwrap(conn) - if err != nil { - return 0, err - } - - return rttImpl(tcpConn) -} - -// netConner is implemented by crypto/tls.Conn to unwrap into an underlying -// net.Conn. -type netConner interface { - NetConn() net.Conn -} - -// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn -func unwrap(nc net.Conn) (*net.TCPConn, error) { - for { - switch v := nc.(type) { - case *net.TCPConn: - return v, nil - case netConner: - nc = v.NetConn() - default: - return nil, ErrNotTCP - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tcpinfo provides platform-agnostic accessors to information about a +// TCP connection (e.g. RTT, MSS, etc.). +package tcpinfo + +import ( + "errors" + "net" + "time" +) + +var ( + ErrNotTCP = errors.New("tcpinfo: not a TCP conn") + ErrUnimplemented = errors.New("tcpinfo: unimplemented") +) + +// RTT returns the RTT for the given net.Conn. +// +// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then +// ErrNotTCP will be returned. If retrieving the RTT is not supported on the +// current platform, ErrUnimplemented will be returned. +func RTT(conn net.Conn) (time.Duration, error) { + tcpConn, err := unwrap(conn) + if err != nil { + return 0, err + } + + return rttImpl(tcpConn) +} + +// netConner is implemented by crypto/tls.Conn to unwrap into an underlying +// net.Conn. +type netConner interface { + NetConn() net.Conn +} + +// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn +func unwrap(nc net.Conn) (*net.TCPConn, error) { + for { + switch v := nc.(type) { + case *net.TCPConn: + return v, nil + case netConner: + nc = v.NetConn() + default: + return nil, ErrNotTCP + } + } +} diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go index 53fa22fbf5bed..bc4ac08b38b04 100644 --- a/net/tcpinfo/tcpinfo_darwin.go +++ b/net/tcpinfo/tcpinfo_darwin.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPConnectionInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPConnectionInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil +} diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go index 885d462c95e35..5d86055bb8499 100644 --- a/net/tcpinfo/tcpinfo_linux.go +++ b/net/tcpinfo/tcpinfo_linux.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil +} diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go index be45523aeb00d..f219cda1bd4a0 100644 --- a/net/tcpinfo/tcpinfo_other.go +++ b/net/tcpinfo/tcpinfo_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin - -package tcpinfo - -import ( - "net" - "time" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - return 0, ErrUnimplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package tcpinfo + +import ( + "net" + "time" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + return 0, ErrUnimplemented +} diff --git a/net/tlsdial/deps_test.go b/net/tlsdial/deps_test.go index 7a93899c2f126..750cb300ae5eb 100644 --- a/net/tlsdial/deps_test.go +++ b/net/tlsdial/deps_test.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build for_go_mod_tidy_only - -package tlsdial - -import _ "filippo.io/mkcert" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build for_go_mod_tidy_only + +package tlsdial + +import _ "filippo.io/mkcert" diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index 43461a135e1c5..f846b853e1432 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func TestDNSMapFromNetworkMap(t *testing.T) { - pfx := netip.MustParsePrefix - ip := netip.MustParseAddr - tests := []struct { - name string - nm *netmap.NetworkMap - want dnsMap - }{ - { - name: "self", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - }, - }, - { - name: "self_and_peers", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - "a": ip("100.0.0.201"), - "a.tailnet": ip("100.0.0.201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - { - name: "self_has_v6_only", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100::123/128"), - }, - }).View(), - Peers: nodeViews([]*tailcfg.Node{ - { - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }, - { - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }, - }), - }, - want: dnsMap{ - "foo": ip("100::123"), - "foo.tailnet": ip("100::123"), - "a": ip("100::201"), - "a.tailnet": ip("100::201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := dnsMapFromNetworkMap(tt.nm) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func TestDNSMapFromNetworkMap(t *testing.T) { + pfx := netip.MustParsePrefix + ip := netip.MustParseAddr + tests := []struct { + name string + nm *netmap.NetworkMap + want dnsMap + }{ + { + name: "self", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + }, + }, + { + name: "self_and_peers", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }).View(), + (&tailcfg.Node{ + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }).View(), + }, + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + "a": ip("100.0.0.201"), + "a.tailnet": ip("100.0.0.201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + { + name: "self_has_v6_only", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100::123/128"), + }, + }).View(), + Peers: nodeViews([]*tailcfg.Node{ + { + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }, + { + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }, + }), + }, + want: dnsMap{ + "foo": ip("100::123"), + "foo.tailnet": ip("100::123"), + "a": ip("100::201"), + "a.tailnet": ip("100::201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := dnsMapFromNetworkMap(tt.nm) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) + } + }) + } +} diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go index d830398cdfb9c..64c127fd3270a 100644 --- a/net/tsdial/dohclient.go +++ b/net/tsdial/dohclient.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "time" - - "tailscale.com/net/dnscache" -) - -// dohConn is a net.PacketConn suitable for returning from -// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' -// ExitDNS DoH proxy service. -type dohConn struct { - ctx context.Context - baseURL string - hc *http.Client // if nil, default is used - dnsCache *dnscache.MessageCache - - rbuf bytes.Buffer -} - -var ( - _ net.Conn = (*dohConn)(nil) - _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics -) - -func (*dohConn) Close() error { return nil } -func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } -func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } -func (*dohConn) SetDeadline(t time.Time) error { return nil } -func (*dohConn) SetReadDeadline(t time.Time) error { return nil } -func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } - -func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Write(p) -} - -func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(p) - return n, todoAddr{}, err -} - -func (c *dohConn) Read(p []byte) (n int, err error) { - return c.rbuf.Read(p) -} - -func (c *dohConn) Write(packet []byte) (n int, err error) { - if c.dnsCache != nil { - err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) - if err == nil { - // Cache hit. - // TODO(bradfitz): add clientmetric - return len(packet), nil - } - c.rbuf.Reset() - } - req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) - if err != nil { - return 0, err - } - const dohType = "application/dns-message" - req.Header.Set("Content-Type", dohType) - hc := c.hc - if hc == nil { - hc = http.DefaultClient - } - hres, err := hc.Do(req) - if err != nil { - return 0, err - } - defer hres.Body.Close() - if hres.StatusCode != 200 { - return 0, errors.New(hres.Status) - } - if ct := hres.Header.Get("Content-Type"); ct != dohType { - return 0, fmt.Errorf("unexpected response Content-Type %q", ct) - } - _, err = io.Copy(&c.rbuf, hres.Body) - if err != nil { - return 0, err - } - if c.dnsCache != nil { - c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) - } - return len(packet), nil -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "time" + + "tailscale.com/net/dnscache" +) + +// dohConn is a net.PacketConn suitable for returning from +// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' +// ExitDNS DoH proxy service. +type dohConn struct { + ctx context.Context + baseURL string + hc *http.Client // if nil, default is used + dnsCache *dnscache.MessageCache + + rbuf bytes.Buffer +} + +var ( + _ net.Conn = (*dohConn)(nil) + _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics +) + +func (*dohConn) Close() error { return nil } +func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } +func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } +func (*dohConn) SetDeadline(t time.Time) error { return nil } +func (*dohConn) SetReadDeadline(t time.Time) error { return nil } +func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.Write(p) +} + +func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + return n, todoAddr{}, err +} + +func (c *dohConn) Read(p []byte) (n int, err error) { + return c.rbuf.Read(p) +} + +func (c *dohConn) Write(packet []byte) (n int, err error) { + if c.dnsCache != nil { + err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) + if err == nil { + // Cache hit. + // TODO(bradfitz): add clientmetric + return len(packet), nil + } + c.rbuf.Reset() + } + req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) + if err != nil { + return 0, err + } + const dohType = "application/dns-message" + req.Header.Set("Content-Type", dohType) + hc := c.hc + if hc == nil { + hc = http.DefaultClient + } + hres, err := hc.Do(req) + if err != nil { + return 0, err + } + defer hres.Body.Close() + if hres.StatusCode != 200 { + return 0, errors.New(hres.Status) + } + if ct := hres.Header.Get("Content-Type"); ct != dohType { + return 0, fmt.Errorf("unexpected response Content-Type %q", ct) + } + _, err = io.Copy(&c.rbuf, hres.Body) + if err != nil { + return 0, err + } + if c.dnsCache != nil { + c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) + } + return len(packet), nil +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/tsdial/dohclient_test.go b/net/tsdial/dohclient_test.go index 23255769f4847..41a66f8f71edd 100644 --- a/net/tsdial/dohclient_test.go +++ b/net/tsdial/dohclient_test.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "context" - "flag" - "net" - "testing" - "time" -) - -var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") - -func TestDoHResolve(t *testing.T) { - if *dohBase == "" { - t.Skip("skipping manual test without --doh-base= set") - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - var r net.Resolver - r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - return &dohConn{ctx: ctx, baseURL: *dohBase}, nil - } - addrs, err := r.LookupIP(ctx, "ip4", "google.com.") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %q", addrs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "flag" + "net" + "testing" + "time" +) + +var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") + +func TestDoHResolve(t *testing.T) { + if *dohBase == "" { + t.Skip("skipping manual test without --doh-base= set") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var r net.Resolver + r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return &dohConn{ctx: ctx, baseURL: *dohBase}, nil + } + addrs, err := r.LookupIP(ctx, "ip4", "google.com.") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %q", addrs) +} diff --git a/net/tshttpproxy/mksyscall.go b/net/tshttpproxy/mksyscall.go index f8fdae89b55f0..467dc49170092 100644 --- a/net/tshttpproxy/mksyscall.go +++ b/net/tshttpproxy/mksyscall.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go - -//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree -//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle -//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl -//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree +//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle +//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl +//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go index b241c256d4798..09019893ade8c 100644 --- a/net/tshttpproxy/tshttpproxy_linux.go +++ b/net/tshttpproxy/tshttpproxy_linux.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "net/http" - "net/url" - - "tailscale.com/version/distro" -) - -func init() { - sysProxyFromEnv = linuxSysProxyFromEnv -} - -func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { - if distro.Get() == distro.Synology { - return synologyProxyFromConfigCached(req) - } - return nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "net/http" + "net/url" + + "tailscale.com/version/distro" +) + +func init() { + sysProxyFromEnv = linuxSysProxyFromEnv +} + +func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { + if distro.Get() == distro.Synology { + return synologyProxyFromConfigCached(req) + } + return nil, nil +} diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index 3061740f3beff..e11c9d05996ed 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -1,376 +1,376 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestSynologyProxyFromConfigCached(t *testing.T) { - req, err := http.NewRequest("GET", "http://example.org/", nil) - if err != nil { - t.Fatal(err) - } - - tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) - - t.Run("no config file", func(t *testing.T) { - if _, err := os.Stat(synologyProxyConfigPath); err == nil { - t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) - } - - cache.updated = time.Time{} - cache.httpProxy = nil - cache.httpsProxy = nil - - if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { - t.Fatalf("got %s, %v; want nil, nil", val, err) - } - - if got, want := cache.updated, time.Unix(0, 0); got != want { - t.Fatalf("got %s, want %s", got, want) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("config file updated", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - - if cache.httpProxy == nil { - t.Fatal("http proxy was not cached") - } - if cache.httpsProxy == nil { - t.Fatal("https proxy was not cached") - } - - if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { - t.Fatalf("got %s; want %s", val, want) - } - }) - - t.Run("config file removed", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = urlMustParse("http://127.0.0.1/") - cache.httpsProxy = urlMustParse("http://127.0.0.1/") - - if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - if val != nil { - t.Fatalf("got %s; want nil", val) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("picks proxy from request scheme", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - httpReq, err := http.NewRequest("GET", "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err := synologyProxyFromConfigCached(httpReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.55:80"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - - httpsReq, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err = synologyProxyFromConfigCached(httpsReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.66:443"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - }) -} - -func TestSynologyProxiesFromConfig(t *testing.T) { - var ( - openReader io.ReadCloser - openErr error - ) - tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { - return openReader, openErr - }) - - t.Run("with config", func(t *testing.T) { - mc := &mustCloser{Reader: strings.NewReader(` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 - `)} - defer mc.check(t) - openReader = mc - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - }) - - t.Run("nonexistent config", func(t *testing.T) { - openReader = nil - openErr = os.ErrNotExist - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != nil { - t.Fatalf("expected no error, got %s", err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - - t.Run("error opening config", func(t *testing.T) { - openReader = nil - openErr = errors.New("example error") - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != openErr { - t.Fatalf("expected %s, got %s", openErr, err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - -} - -func TestParseSynologyConfig(t *testing.T) { - cases := map[string]struct { - input string - httpProxy *url.URL - httpsProxy *url.URL - err error - }{ - "populated": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), - err: nil, - }, - "no-auth": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=no -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://10.0.0.55:80"), - httpsProxy: urlMustParse("http://10.0.0.66:8443"), - err: nil, - }, - "http-only": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host= -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: nil, - err: nil, - }, - "empty": { - input: ` -proxy_user= -proxy_pwd= -proxy_enabled= -adv_enabled= -bypass_enabled= -auth_enabled= -https_host= -https_port= -http_host= -http_port= -`, - httpProxy: nil, - httpsProxy: nil, - err: nil, - }, - } - - for name, example := range cases { - t.Run(name, func(t *testing.T) { - httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) - if err != example.err { - t.Fatal(err) - } - if example.err != nil { - return - } - - if example.httpProxy == nil && httpProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpProxy != nil { - if httpProxy == nil { - t.Fatalf("got nil, want %s", example.httpProxy) - } - - if got, want := example.httpProxy.String(), httpProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - - if example.httpsProxy == nil && httpsProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpsProxy != nil { - if httpsProxy == nil { - t.Fatalf("got nil, want %s", example.httpsProxy) - } - - if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - }) - } -} -func urlMustParse(u string) *url.URL { - r, err := url.Parse(u) - if err != nil { - panic(fmt.Sprintf("urlMustParse: %s", err)) - } - return r -} - -type mustCloser struct { - io.Reader - closed bool -} - -func (m *mustCloser) Close() error { - m.closed = true - return nil -} - -func (m *mustCloser) check(t *testing.T) { - if !m.closed { - t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestSynologyProxyFromConfigCached(t *testing.T) { + req, err := http.NewRequest("GET", "http://example.org/", nil) + if err != nil { + t.Fatal(err) + } + + tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) + + t.Run("no config file", func(t *testing.T) { + if _, err := os.Stat(synologyProxyConfigPath); err == nil { + t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) + } + + cache.updated = time.Time{} + cache.httpProxy = nil + cache.httpsProxy = nil + + if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { + t.Fatalf("got %s, %v; want nil, nil", val, err) + } + + if got, want := cache.updated, time.Unix(0, 0); got != want { + t.Fatalf("got %s, want %s", got, want) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("config file updated", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + + if cache.httpProxy == nil { + t.Fatal("http proxy was not cached") + } + if cache.httpsProxy == nil { + t.Fatal("https proxy was not cached") + } + + if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { + t.Fatalf("got %s; want %s", val, want) + } + }) + + t.Run("config file removed", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = urlMustParse("http://127.0.0.1/") + cache.httpsProxy = urlMustParse("http://127.0.0.1/") + + if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + if val != nil { + t.Fatalf("got %s; want nil", val) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("picks proxy from request scheme", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + httpReq, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err := synologyProxyFromConfigCached(httpReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.55:80"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + + httpsReq, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err = synologyProxyFromConfigCached(httpsReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.66:443"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + }) +} + +func TestSynologyProxiesFromConfig(t *testing.T) { + var ( + openReader io.ReadCloser + openErr error + ) + tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { + return openReader, openErr + }) + + t.Run("with config", func(t *testing.T) { + mc := &mustCloser{Reader: strings.NewReader(` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 + `)} + defer mc.check(t) + openReader = mc + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + }) + + t.Run("nonexistent config", func(t *testing.T) { + openReader = nil + openErr = os.ErrNotExist + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != nil { + t.Fatalf("expected no error, got %s", err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + + t.Run("error opening config", func(t *testing.T) { + openReader = nil + openErr = errors.New("example error") + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != openErr { + t.Fatalf("expected %s, got %s", openErr, err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + +} + +func TestParseSynologyConfig(t *testing.T) { + cases := map[string]struct { + input string + httpProxy *url.URL + httpsProxy *url.URL + err error + }{ + "populated": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), + err: nil, + }, + "no-auth": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=no +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://10.0.0.55:80"), + httpsProxy: urlMustParse("http://10.0.0.66:8443"), + err: nil, + }, + "http-only": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host= +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: nil, + err: nil, + }, + "empty": { + input: ` +proxy_user= +proxy_pwd= +proxy_enabled= +adv_enabled= +bypass_enabled= +auth_enabled= +https_host= +https_port= +http_host= +http_port= +`, + httpProxy: nil, + httpsProxy: nil, + err: nil, + }, + } + + for name, example := range cases { + t.Run(name, func(t *testing.T) { + httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) + if err != example.err { + t.Fatal(err) + } + if example.err != nil { + return + } + + if example.httpProxy == nil && httpProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpProxy != nil { + if httpProxy == nil { + t.Fatalf("got nil, want %s", example.httpProxy) + } + + if got, want := example.httpProxy.String(), httpProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + + if example.httpsProxy == nil && httpsProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpsProxy != nil { + if httpsProxy == nil { + t.Fatalf("got nil, want %s", example.httpsProxy) + } + + if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + }) + } +} +func urlMustParse(u string) *url.URL { + r, err := url.Parse(u) + if err != nil { + panic(fmt.Sprintf("urlMustParse: %s", err)) + } + return r +} + +type mustCloser struct { + io.Reader + closed bool +} + +func (m *mustCloser) Close() error { + m.closed = true + return nil +} + +func (m *mustCloser) check(t *testing.T) { + if !m.closed { + t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) + } +} diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index 06a1f5ae445d0..cb6b24c8355e8 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -1,276 +1,276 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -import ( - "context" - "encoding/base64" - "fmt" - "log" - "net/http" - "net/url" - "runtime" - "strings" - "sync" - "syscall" - "time" - "unsafe" - - "github.com/alexbrainman/sspi/negotiate" - "golang.org/x/sys/windows" - "tailscale.com/hostinfo" - "tailscale.com/syncs" - "tailscale.com/types/logger" - "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpver" -) - -func init() { - sysProxyFromEnv = proxyFromWinHTTPOrCache - sysAuthHeader = sysAuthHeaderWindows -} - -var cachedProxy struct { - sync.Mutex - val *url.URL -} - -// proxyErrorf is a rate-limited logger specifically for errors asking -// WinHTTP for the proxy information. We don't want to log about -// errors often, otherwise the log message itself will generate a new -// HTTP request which ultimately will call back into us to log again, -// forever. So for errors, we only log a bit. -var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) - -var ( - metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") - metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") - metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") - metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") - metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") - metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") -) - -func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { - if req.URL == nil { - return nil, nil - } - urlStr := req.URL.String() - - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() - - type result struct { - proxy *url.URL - err error - } - resc := make(chan result, 1) - go func() { - proxy, err := proxyFromWinHTTP(ctx, urlStr) - resc <- result{proxy, err} - }() - - select { - case res := <-resc: - err := res.err - if err == nil { - metricSuccess.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { - log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) - } - cachedProxy.val = res.proxy - return res.proxy, nil - } - - // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages - const ( - ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 - ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 - ) - if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { - metricErrDetectionFailed.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - if err == windows.ERROR_INVALID_PARAMETER { - metricErrInvalidParameters.Add(1) - // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) - // TODO(bradfitz): figure this out. - setNoProxyUntil(time.Hour) - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) - return nil, nil - } - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) - if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { - metricErrDownloadScript.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - metricErrOther.Add(1) - return nil, err - case <-ctx.Done(): - metricErrTimeout.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) - return cachedProxy.val, nil - } -} - -func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - whi, err := httpOpen() - if err != nil { - proxyErrorf("winhttp: Open: %v", err) - return nil, err - } - defer whi.Close() - - t0 := time.Now() - v, err := whi.GetProxyForURL(urlStr) - td := time.Since(t0).Round(time.Millisecond) - if err := ctx.Err(); err != nil { - log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) - return nil, err - } - if err != nil { - return nil, err - } - if v == "" { - return nil, nil - } - // Discard all but first proxy value for now. - if i := strings.Index(v, ";"); i != -1 { - v = v[:i] - } - if !strings.HasPrefix(v, "https://") { - v = "http://" + v - } - return url.Parse(v) -} - -var userAgent = windows.StringToUTF16Ptr("Tailscale") - -const ( - winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 - winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 - winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 - winHTTP_AUTOPROXY_AUTO_DETECT = 1 - winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 - winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 -) - -// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! -const win8dot1Ver = "6.3" - -// accessType is the flag we must pass to WinHttpOpen for proxy resolution -// depending on whether or not we're running Windows < 8.1 -var accessType syncs.AtomicValue[uint32] - -func getAccessFlag() uint32 { - if flag, ok := accessType.LoadOk(); ok { - return flag - } - var flag uint32 - if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { - flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY - } else { - flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY - } - accessType.Store(flag) - return flag -} - -func httpOpen() (winHTTPInternet, error) { - return winHTTPOpen( - userAgent, - getAccessFlag(), - nil, /* WINHTTP_NO_PROXY_NAME */ - nil, /* WINHTTP_NO_PROXY_BYPASS */ - 0, - ) -} - -type winHTTPInternet windows.Handle - -func (hi winHTTPInternet) Close() error { - return winHTTPCloseHandle(hi) -} - -// WINHTTP_AUTOPROXY_OPTIONS -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options -type winHTTPAutoProxyOptions struct { - DwFlags uint32 - DwAutoDetectFlags uint32 - AutoConfigUrl *uint16 - _ uintptr - _ uint32 - FAutoLogonIfChallenged int32 // BOOL -} - -// WINHTTP_PROXY_INFO -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info -type winHTTPProxyInfo struct { - AccessType uint32 - Proxy *uint16 - ProxyBypass *uint16 -} - -type winHGlobal windows.Handle - -func globalFreeUTF16Ptr(p *uint16) error { - return globalFree((winHGlobal)(unsafe.Pointer(p))) -} - -func (pi *winHTTPProxyInfo) free() { - if pi.Proxy != nil { - globalFreeUTF16Ptr(pi.Proxy) - pi.Proxy = nil - } - if pi.ProxyBypass != nil { - globalFreeUTF16Ptr(pi.ProxyBypass) - pi.ProxyBypass = nil - } -} - -var proxyForURLOpts = &winHTTPAutoProxyOptions{ - DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, - DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, -} - -func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { - var out winHTTPProxyInfo - err := winHTTPGetProxyForURL( - hi, - windows.StringToUTF16Ptr(urlStr), - proxyForURLOpts, - &out, - ) - if err != nil { - return "", err - } - defer out.free() - return windows.UTF16PtrToString(out.Proxy), nil -} - -func sysAuthHeaderWindows(u *url.URL) (string, error) { - spn := "HTTP/" + u.Hostname() - creds, err := negotiate.AcquireCurrentUserCredentials() - if err != nil { - return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) - } - defer creds.Release() - - secCtx, token, err := negotiate.NewClientContext(creds, spn) - if err != nil { - return "", fmt.Errorf("negotiate.NewClientContext: %w", err) - } - defer secCtx.Release() - - return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +import ( + "context" + "encoding/base64" + "fmt" + "log" + "net/http" + "net/url" + "runtime" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "github.com/alexbrainman/sspi/negotiate" + "golang.org/x/sys/windows" + "tailscale.com/hostinfo" + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" + "tailscale.com/util/cmpver" +) + +func init() { + sysProxyFromEnv = proxyFromWinHTTPOrCache + sysAuthHeader = sysAuthHeaderWindows +} + +var cachedProxy struct { + sync.Mutex + val *url.URL +} + +// proxyErrorf is a rate-limited logger specifically for errors asking +// WinHTTP for the proxy information. We don't want to log about +// errors often, otherwise the log message itself will generate a new +// HTTP request which ultimately will call back into us to log again, +// forever. So for errors, we only log a bit. +var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) + +var ( + metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") + metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") + metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") + metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") + metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") + metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") +) + +func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { + if req.URL == nil { + return nil, nil + } + urlStr := req.URL.String() + + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + + type result struct { + proxy *url.URL + err error + } + resc := make(chan result, 1) + go func() { + proxy, err := proxyFromWinHTTP(ctx, urlStr) + resc <- result{proxy, err} + }() + + select { + case res := <-resc: + err := res.err + if err == nil { + metricSuccess.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { + log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) + } + cachedProxy.val = res.proxy + return res.proxy, nil + } + + // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages + const ( + ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 + ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 + ) + if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { + metricErrDetectionFailed.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + if err == windows.ERROR_INVALID_PARAMETER { + metricErrInvalidParameters.Add(1) + // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) + // TODO(bradfitz): figure this out. + setNoProxyUntil(time.Hour) + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) + return nil, nil + } + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) + if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { + metricErrDownloadScript.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + metricErrOther.Add(1) + return nil, err + case <-ctx.Done(): + metricErrTimeout.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) + return cachedProxy.val, nil + } +} + +func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + whi, err := httpOpen() + if err != nil { + proxyErrorf("winhttp: Open: %v", err) + return nil, err + } + defer whi.Close() + + t0 := time.Now() + v, err := whi.GetProxyForURL(urlStr) + td := time.Since(t0).Round(time.Millisecond) + if err := ctx.Err(); err != nil { + log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) + return nil, err + } + if err != nil { + return nil, err + } + if v == "" { + return nil, nil + } + // Discard all but first proxy value for now. + if i := strings.Index(v, ";"); i != -1 { + v = v[:i] + } + if !strings.HasPrefix(v, "https://") { + v = "http://" + v + } + return url.Parse(v) +} + +var userAgent = windows.StringToUTF16Ptr("Tailscale") + +const ( + winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 + winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 + winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 + winHTTP_AUTOPROXY_AUTO_DETECT = 1 + winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 + winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 +) + +// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! +const win8dot1Ver = "6.3" + +// accessType is the flag we must pass to WinHttpOpen for proxy resolution +// depending on whether or not we're running Windows < 8.1 +var accessType syncs.AtomicValue[uint32] + +func getAccessFlag() uint32 { + if flag, ok := accessType.LoadOk(); ok { + return flag + } + var flag uint32 + if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { + flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY + } else { + flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY + } + accessType.Store(flag) + return flag +} + +func httpOpen() (winHTTPInternet, error) { + return winHTTPOpen( + userAgent, + getAccessFlag(), + nil, /* WINHTTP_NO_PROXY_NAME */ + nil, /* WINHTTP_NO_PROXY_BYPASS */ + 0, + ) +} + +type winHTTPInternet windows.Handle + +func (hi winHTTPInternet) Close() error { + return winHTTPCloseHandle(hi) +} + +// WINHTTP_AUTOPROXY_OPTIONS +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options +type winHTTPAutoProxyOptions struct { + DwFlags uint32 + DwAutoDetectFlags uint32 + AutoConfigUrl *uint16 + _ uintptr + _ uint32 + FAutoLogonIfChallenged int32 // BOOL +} + +// WINHTTP_PROXY_INFO +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info +type winHTTPProxyInfo struct { + AccessType uint32 + Proxy *uint16 + ProxyBypass *uint16 +} + +type winHGlobal windows.Handle + +func globalFreeUTF16Ptr(p *uint16) error { + return globalFree((winHGlobal)(unsafe.Pointer(p))) +} + +func (pi *winHTTPProxyInfo) free() { + if pi.Proxy != nil { + globalFreeUTF16Ptr(pi.Proxy) + pi.Proxy = nil + } + if pi.ProxyBypass != nil { + globalFreeUTF16Ptr(pi.ProxyBypass) + pi.ProxyBypass = nil + } +} + +var proxyForURLOpts = &winHTTPAutoProxyOptions{ + DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, + DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, +} + +func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { + var out winHTTPProxyInfo + err := winHTTPGetProxyForURL( + hi, + windows.StringToUTF16Ptr(urlStr), + proxyForURLOpts, + &out, + ) + if err != nil { + return "", err + } + defer out.free() + return windows.UTF16PtrToString(out.Proxy), nil +} + +func sysAuthHeaderWindows(u *url.URL) (string, error) { + spn := "HTTP/" + u.Hostname() + creds, err := negotiate.AcquireCurrentUserCredentials() + if err != nil { + return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) + } + defer creds.Release() + + secCtx, token, err := negotiate.NewClientContext(creds, spn) + if err != nil { + return "", fmt.Errorf("negotiate.NewClientContext: %w", err) + } + defer secCtx.Release() + + return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil +} diff --git a/net/tstun/fake.go b/net/tstun/fake.go index 3d86bb3df4ca9..a002952a3eef5 100644 --- a/net/tstun/fake.go +++ b/net/tstun/fake.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTUN struct { - evchan chan tun.Event - closechan chan struct{} -} - -// NewFake returns a tun.Device that does nothing. -func NewFake() tun.Device { - return &fakeTUN{ - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTUN) File() *os.File { - panic("fakeTUN.File() called, which makes no sense") -} - -func (t *fakeTUN) Close() error { - close(t.closechan) - close(t.evchan) - return nil -} - -func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { - <-t.closechan - return 0, io.EOF -} - -func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { - select { - case <-t.closechan: - return 0, ErrClosed - default: - } - return 1, nil -} - -// FakeTUNName is the name of the fake TUN device. -const FakeTUNName = "FakeTUN" - -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } -func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } -func (t *fakeTUN) BatchSize() int { return 1 } -func (t *fakeTUN) IsFakeTun() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "io" + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +type fakeTUN struct { + evchan chan tun.Event + closechan chan struct{} +} + +// NewFake returns a tun.Device that does nothing. +func NewFake() tun.Device { + return &fakeTUN{ + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTUN) File() *os.File { + panic("fakeTUN.File() called, which makes no sense") +} + +func (t *fakeTUN) Close() error { + close(t.closechan) + close(t.evchan) + return nil +} + +func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { + <-t.closechan + return 0, io.EOF +} + +func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { + select { + case <-t.closechan: + return 0, ErrClosed + default: + } + return 1, nil +} + +// FakeTUNName is the name of the fake TUN device. +const FakeTUNName = "FakeTUN" + +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } +func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } +func (t *fakeTUN) BatchSize() int { return 1 } +func (t *fakeTUN) IsFakeTun() bool { return true } diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go index 8cf569f982010..4d453b72c83bc 100644 --- a/net/tstun/ifstatus_noop.go +++ b/net/tstun/ifstatus_noop.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import ( - "time" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" -) - -// Dummy implementation that does nothing. -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import ( + "time" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" +) + +// Dummy implementation that does nothing. +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + return nil +} diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go index fd9fc2112524c..6c6377bb40fb6 100644 --- a/net/tstun/ifstatus_windows.go +++ b/net/tstun/ifstatus_windows.go @@ -1,109 +1,109 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "fmt" - "sync" - "time" - - "github.com/tailscale/wireguard-go/tun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "tailscale.com/types/logger" -) - -// ifaceWatcher waits for an interface to be up. -type ifaceWatcher struct { - logf logger.Logf - luid winipcfg.LUID - - mu sync.Mutex // guards following - done bool - sig chan bool -} - -// callback is the callback we register with Windows to call when IP interface changes. -func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. - if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { - // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. - go iw.isUp() - } -} - -func (iw *ifaceWatcher) isUp() bool { - iw.mu.Lock() - defer iw.mu.Unlock() - - if iw.done { - // We already know that it's up - return true - } - - if iw.getOperStatus() != winipcfg.IfOperStatusUp { - return false - } - - iw.done = true - iw.sig <- true - return true -} - -func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { - ifc, err := iw.luid.Interface() - if err != nil { - iw.logf("iw.luid.Interface error: %v", err) - return 0 - } - return ifc.OperStatus -} - -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - iw := &ifaceWatcher{ - luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), - logf: logger.WithPrefix(logf, "waitInterfaceUp: "), - } - - // Just in case check the status first - if iw.getOperStatus() == winipcfg.IfOperStatusUp { - iw.logf("TUN interface already up; no need to wait") - return nil - } - - iw.sig = make(chan bool, 1) - cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) - if err != nil { - iw.logf("RegisterInterfaceChangeCallback error: %v", err) - return err - } - defer cb.Unregister() - - t0 := time.Now() - expires := t0.Add(timeout) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - iw.logf("waiting for TUN interface to come up...") - - select { - case <-iw.sig: - iw.logf("TUN interface is up after %v", time.Since(t0)) - return nil - case <-ticker.C: - } - - if iw.isUp() { - // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work - // or it came up in the same moment as tick. Indicate this in the log message. - iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) - return nil - } - - if expires.Before(time.Now()) { - iw.logf("timeout waiting %v for TUN interface to come up", timeout) - return fmt.Errorf("timeout waiting for TUN interface to come up") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "fmt" + "sync" + "time" + + "github.com/tailscale/wireguard-go/tun" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/types/logger" +) + +// ifaceWatcher waits for an interface to be up. +type ifaceWatcher struct { + logf logger.Logf + luid winipcfg.LUID + + mu sync.Mutex // guards following + done bool + sig chan bool +} + +// callback is the callback we register with Windows to call when IP interface changes. +func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. + if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { + // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. + go iw.isUp() + } +} + +func (iw *ifaceWatcher) isUp() bool { + iw.mu.Lock() + defer iw.mu.Unlock() + + if iw.done { + // We already know that it's up + return true + } + + if iw.getOperStatus() != winipcfg.IfOperStatusUp { + return false + } + + iw.done = true + iw.sig <- true + return true +} + +func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { + ifc, err := iw.luid.Interface() + if err != nil { + iw.logf("iw.luid.Interface error: %v", err) + return 0 + } + return ifc.OperStatus +} + +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + iw := &ifaceWatcher{ + luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), + logf: logger.WithPrefix(logf, "waitInterfaceUp: "), + } + + // Just in case check the status first + if iw.getOperStatus() == winipcfg.IfOperStatusUp { + iw.logf("TUN interface already up; no need to wait") + return nil + } + + iw.sig = make(chan bool, 1) + cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) + if err != nil { + iw.logf("RegisterInterfaceChangeCallback error: %v", err) + return err + } + defer cb.Unregister() + + t0 := time.Now() + expires := t0.Add(timeout) + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + iw.logf("waiting for TUN interface to come up...") + + select { + case <-iw.sig: + iw.logf("TUN interface is up after %v", time.Since(t0)) + return nil + case <-ticker.C: + } + + if iw.isUp() { + // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work + // or it came up in the same moment as tick. Indicate this in the log message. + iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) + return nil + } + + if expires.Before(time.Now()) { + iw.logf("timeout waiting %v for TUN interface to come up", timeout) + return fmt.Errorf("timeout waiting for TUN interface to come up") + } + } +} diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go index 681e79269f75f..7f546110995ee 100644 --- a/net/tstun/linkattrs_linux.go +++ b/net/tstun/linkattrs_linux.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "github.com/mdlayher/genetlink" - "github.com/mdlayher/netlink" - "github.com/tailscale/wireguard-go/tun" - "golang.org/x/sys/unix" -) - -// setLinkSpeed sets the advertised link speed of the TUN interface. -func setLinkSpeed(iface tun.Device, mbps int) error { - name, err := iface.Name() - if err != nil { - return err - } - - conn, err := genetlink.Dial(&netlink.Config{Strict: true}) - if err != nil { - return err - } - - defer conn.Close() - - f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) - if err != nil { - return err - } - - ae := netlink.NewAttributeEncoder() - ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { - nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) - return nil - }) - ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) - - b, err := ae.Encode() - if err != nil { - return err - } - - _, err = conn.Execute( - genetlink.Message{ - Header: genetlink.Header{ - Command: unix.ETHTOOL_MSG_LINKMODES_SET, - Version: unix.ETHTOOL_GENL_VERSION, - }, - Data: b, - }, - f.ID, - netlink.Request|netlink.Acknowledge, - ) - return err -} - -// setLinkAttrs sets up link attributes that can be queried by external tools. -// Its failure is non-fatal to interface bringup. -func setLinkAttrs(iface tun.Device) error { - // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). - return setLinkSpeed(iface, unix.SPEED_UNKNOWN) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "github.com/mdlayher/genetlink" + "github.com/mdlayher/netlink" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" +) + +// setLinkSpeed sets the advertised link speed of the TUN interface. +func setLinkSpeed(iface tun.Device, mbps int) error { + name, err := iface.Name() + if err != nil { + return err + } + + conn, err := genetlink.Dial(&netlink.Config{Strict: true}) + if err != nil { + return err + } + + defer conn.Close() + + f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) + if err != nil { + return err + } + + ae := netlink.NewAttributeEncoder() + ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { + nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) + return nil + }) + ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) + + b, err := ae.Encode() + if err != nil { + return err + } + + _, err = conn.Execute( + genetlink.Message{ + Header: genetlink.Header{ + Command: unix.ETHTOOL_MSG_LINKMODES_SET, + Version: unix.ETHTOOL_GENL_VERSION, + }, + Data: b, + }, + f.ID, + netlink.Request|netlink.Acknowledge, + ) + return err +} + +// setLinkAttrs sets up link attributes that can be queried by external tools. +// Its failure is non-fatal to interface bringup. +func setLinkAttrs(iface tun.Device) error { + // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). + return setLinkSpeed(iface, unix.SPEED_UNKNOWN) +} diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go index 7a7b40fc2652b..45dd000b3d7d4 100644 --- a/net/tstun/linkattrs_notlinux.go +++ b/net/tstun/linkattrs_notlinux.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func setLinkAttrs(iface tun.Device) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func setLinkAttrs(iface tun.Device) error { + return nil +} diff --git a/net/tstun/mtu.go b/net/tstun/mtu.go index 004529c205f9e..b72a19bdebe6e 100644 --- a/net/tstun/mtu.go +++ b/net/tstun/mtu.go @@ -1,161 +1,161 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "tailscale.com/envknob" -) - -// The MTU (Maximum Transmission Unit) of a network interface is the largest -// packet that can be sent or received through that interface, including all -// headers above the link layer (e.g. IP headers, UDP headers, Wireguard -// headers, etc.). We have to think about several different values of MTU: -// -// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an -// Ethernet network card will default to a 1500 byte MTU. The user may change -// this MTU at any time. -// -// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward -// to make room for the wireguard/tailscale headers. For example, if the -// underlying network interface's MTU is 1500 bytes, the maximum size of a -// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU -// at any time via the OS's tools (ifconfig, ip, etc.). -// -// User configured initial MTU: The MTU the tailscale TUN should be created -// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the -// underlying interface MTU by 80 bytes to make room for the wireguard -// headers. This envknob is mostly for debugging. This value is used once at TUN -// creation and ignored thereafter. -// -// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, -// etc.). This MTU can change at any time. Setting the MTU this way goes through -// the MTU() method of tailscale's TUN wrapper. -// -// Maximum probed MTU: This is the largest MTU size that we send probe packets -// for. -// -// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets -// will get to their destination. Tailscale defaults to this MTU in the absence -// of path MTU probe information or user MTU configuration. We may occasionally -// find a path that needs a smaller MTU but it is very rare. -// -// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults -// to the Safe MTU unless we have path MTU probe results that tell us otherwise. -// -// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of -// priority, it is: -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -// -// Current MTU: This the MTU of the tailscale TUN at any given moment -// after TUN creation. In order of priority, it is: -// -// 1. The MTU set by the user via the OS, if it has ever been set -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU - -// TUNMTU is the MTU for the tailscale TUN. -type TUNMTU uint32 - -// WireMTU is the MTU for the underlying network devices. -type WireMTU uint32 - -const ( - // maxTUNMTU is the largest MTU we will consider for the Tailscale - // TUN. This is inherited from wireguard-go and can be surprisingly - // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 - // - 32 bytes. - // TODO(val,raggi): On Windows this seems to derive from RIO driver - // constraints in Wireguard but we don't use RIO so could probably make - // this bigger. - maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) - // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we - // use in the absence of other information such as path MTU probes. - safeTUNMTU TUNMTU = 1280 -) - -// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time -// magicsock discovery begins, it will send a set of pings, one of each size -// listed below. -var WireMTUsToProbe = []WireMTU{ - WireMTU(safeTUNMTU), // Tailscale over Tailscale :) - TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default - 1400, // Most common MTU minus a few bytes for tunnels - 1500, // Most common MTU - 8000, // Should fit inside all jumbo frame sizes - 9000, // Most jumbo frames are this size or larger -} - -// wgHeaderLen is the length of all the headers Wireguard adds to a packet -// in the worst case (IPv6). This constant is for use when we can't or -// shouldn't use information about the IP version of a specific packet -// (e.g., calculating the MTU for the Tailscale interface. -// -// A Wireguard header includes: -// -// - 20-byte IPv4 header or 40-byte IPv6 header -// - 8-byte UDP header -// - 4-byte type -// - 4-byte key index -// - 8-byte nonce -// - 16-byte authentication tag -const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 - -// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and -// returns the on-the-wire MTU necessary to transmit the largest packet that -// will fit through the TUN, given that we have to add wireguard headers. -func TUNToWireMTU(t TUNMTU) WireMTU { - return WireMTU(t + wgHeaderLen) -} - -// WireToTUNMTU takes the MTU of an underlying network device and returns the -// largest possible MTU for a Tailscale TUN operating on top of that device, -// given that we have to add wireguard headers. -func WireToTUNMTU(w WireMTU) TUNMTU { - if w < wgHeaderLen { - return 0 - } - return TUNMTU(w - wgHeaderLen) -} - -// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN -// MTU. It is also the path MTU that we default to if we have no -// information about the path to a peer. -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -func DefaultTUNMTU() TUNMTU { - if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { - return min(TUNMTU(m), maxTUNMTU) - } - - debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") - if debugPMTUD { - // TODO: While we are just probing MTU but not generating PTB, - // this has to continue to return the safe MTU. When we add the - // code to generate PTB, this will be: - // - // return WireToTUNMTU(maxProbedWireMTU) - return safeTUNMTU - } - - return safeTUNMTU -} - -// SafeWireMTU returns the wire MTU that is safe to use if we have no -// information about the path MTU to this peer. -func SafeWireMTU() WireMTU { - return TUNToWireMTU(safeTUNMTU) -} - -// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard -// overhead. -func DefaultWireMTU() WireMTU { - return TUNToWireMTU(DefaultTUNMTU()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "tailscale.com/envknob" +) + +// The MTU (Maximum Transmission Unit) of a network interface is the largest +// packet that can be sent or received through that interface, including all +// headers above the link layer (e.g. IP headers, UDP headers, Wireguard +// headers, etc.). We have to think about several different values of MTU: +// +// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an +// Ethernet network card will default to a 1500 byte MTU. The user may change +// this MTU at any time. +// +// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward +// to make room for the wireguard/tailscale headers. For example, if the +// underlying network interface's MTU is 1500 bytes, the maximum size of a +// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU +// at any time via the OS's tools (ifconfig, ip, etc.). +// +// User configured initial MTU: The MTU the tailscale TUN should be created +// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the +// underlying interface MTU by 80 bytes to make room for the wireguard +// headers. This envknob is mostly for debugging. This value is used once at TUN +// creation and ignored thereafter. +// +// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, +// etc.). This MTU can change at any time. Setting the MTU this way goes through +// the MTU() method of tailscale's TUN wrapper. +// +// Maximum probed MTU: This is the largest MTU size that we send probe packets +// for. +// +// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets +// will get to their destination. Tailscale defaults to this MTU in the absence +// of path MTU probe information or user MTU configuration. We may occasionally +// find a path that needs a smaller MTU but it is very rare. +// +// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults +// to the Safe MTU unless we have path MTU probe results that tell us otherwise. +// +// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of +// priority, it is: +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +// +// Current MTU: This the MTU of the tailscale TUN at any given moment +// after TUN creation. In order of priority, it is: +// +// 1. The MTU set by the user via the OS, if it has ever been set +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU + +// TUNMTU is the MTU for the tailscale TUN. +type TUNMTU uint32 + +// WireMTU is the MTU for the underlying network devices. +type WireMTU uint32 + +const ( + // maxTUNMTU is the largest MTU we will consider for the Tailscale + // TUN. This is inherited from wireguard-go and can be surprisingly + // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 + // - 32 bytes. + // TODO(val,raggi): On Windows this seems to derive from RIO driver + // constraints in Wireguard but we don't use RIO so could probably make + // this bigger. + maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) + // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we + // use in the absence of other information such as path MTU probes. + safeTUNMTU TUNMTU = 1280 +) + +// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time +// magicsock discovery begins, it will send a set of pings, one of each size +// listed below. +var WireMTUsToProbe = []WireMTU{ + WireMTU(safeTUNMTU), // Tailscale over Tailscale :) + TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default + 1400, // Most common MTU minus a few bytes for tunnels + 1500, // Most common MTU + 8000, // Should fit inside all jumbo frame sizes + 9000, // Most jumbo frames are this size or larger +} + +// wgHeaderLen is the length of all the headers Wireguard adds to a packet +// in the worst case (IPv6). This constant is for use when we can't or +// shouldn't use information about the IP version of a specific packet +// (e.g., calculating the MTU for the Tailscale interface. +// +// A Wireguard header includes: +// +// - 20-byte IPv4 header or 40-byte IPv6 header +// - 8-byte UDP header +// - 4-byte type +// - 4-byte key index +// - 8-byte nonce +// - 16-byte authentication tag +const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 + +// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and +// returns the on-the-wire MTU necessary to transmit the largest packet that +// will fit through the TUN, given that we have to add wireguard headers. +func TUNToWireMTU(t TUNMTU) WireMTU { + return WireMTU(t + wgHeaderLen) +} + +// WireToTUNMTU takes the MTU of an underlying network device and returns the +// largest possible MTU for a Tailscale TUN operating on top of that device, +// given that we have to add wireguard headers. +func WireToTUNMTU(w WireMTU) TUNMTU { + if w < wgHeaderLen { + return 0 + } + return TUNMTU(w - wgHeaderLen) +} + +// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN +// MTU. It is also the path MTU that we default to if we have no +// information about the path to a peer. +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +func DefaultTUNMTU() TUNMTU { + if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { + return min(TUNMTU(m), maxTUNMTU) + } + + debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") + if debugPMTUD { + // TODO: While we are just probing MTU but not generating PTB, + // this has to continue to return the safe MTU. When we add the + // code to generate PTB, this will be: + // + // return WireToTUNMTU(maxProbedWireMTU) + return safeTUNMTU + } + + return safeTUNMTU +} + +// SafeWireMTU returns the wire MTU that is safe to use if we have no +// information about the path MTU to this peer. +func SafeWireMTU() WireMTU { + return TUNToWireMTU(safeTUNMTU) +} + +// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard +// overhead. +func DefaultWireMTU() WireMTU { + return TUNToWireMTU(DefaultTUNMTU()) +} diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go index 8d165bfd341a9..fc5274ae1037c 100644 --- a/net/tstun/mtu_test.go +++ b/net/tstun/mtu_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package tstun - -import ( - "os" - "strconv" - "testing" -) - -// Test the default MTU in the presence of various envknobs. -func TestDefaultTunMTU(t *testing.T) { - // Save and restore the envknobs we will be changing. - - // TS_DEBUG_MTU sets the MTU to a specific value. - defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) - os.Setenv("TS_DEBUG_MTU", "") - - // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. - defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") - - // With no MTU envknobs set, we should get the conservative MTU. - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - - // If set, TS_DEBUG_MTU should set the MTU. - mtu := maxTUNMTU - 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) - } - - // MTU should be clamped to maxTunMTU. - mtu = maxTUNMTU + 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != maxTUNMTU { - t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) - } - - // If PMTUD is enabled, the MTU should default to the safe MTU, but only - // if the user hasn't requested a specific MTU. - // - // TODO: When PMTUD is generating PTB responses, this will become the - // largest MTU we probe. - os.Setenv("TS_DEBUG_MTU", "") - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. - mtu = WireToTUNMTU(MaxPacketSize - 1) - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) - } -} - -// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. -func TestMTUConversion(t *testing.T) { - tests := []struct { - w WireMTU - t TUNMTU - }{ - {w: 0, t: 0}, - {w: wgHeaderLen - 1, t: 0}, - {w: wgHeaderLen, t: 0}, - {w: wgHeaderLen + 1, t: 1}, - {w: 1360, t: 1280}, - {w: 1500, t: 1420}, - {w: 9000, t: 8920}, - } - - for _, tt := range tests { - m := WireToTUNMTU(tt.w) - if m != tt.t { - t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) - } - } - - tests2 := []struct { - t TUNMTU - w WireMTU - }{ - {t: 0, w: wgHeaderLen}, - {t: 1, w: wgHeaderLen + 1}, - {t: 1280, w: 1360}, - {t: 1420, w: 1500}, - {t: 8920, w: 9000}, - } - - for _, tt := range tests2 { - m := TUNToWireMTU(tt.t) - if m != tt.w { - t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package tstun + +import ( + "os" + "strconv" + "testing" +) + +// Test the default MTU in the presence of various envknobs. +func TestDefaultTunMTU(t *testing.T) { + // Save and restore the envknobs we will be changing. + + // TS_DEBUG_MTU sets the MTU to a specific value. + defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) + os.Setenv("TS_DEBUG_MTU", "") + + // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. + defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") + + // With no MTU envknobs set, we should get the conservative MTU. + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + + // If set, TS_DEBUG_MTU should set the MTU. + mtu := maxTUNMTU - 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) + } + + // MTU should be clamped to maxTunMTU. + mtu = maxTUNMTU + 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != maxTUNMTU { + t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) + } + + // If PMTUD is enabled, the MTU should default to the safe MTU, but only + // if the user hasn't requested a specific MTU. + // + // TODO: When PMTUD is generating PTB responses, this will become the + // largest MTU we probe. + os.Setenv("TS_DEBUG_MTU", "") + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. + mtu = WireToTUNMTU(MaxPacketSize - 1) + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) + } +} + +// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. +func TestMTUConversion(t *testing.T) { + tests := []struct { + w WireMTU + t TUNMTU + }{ + {w: 0, t: 0}, + {w: wgHeaderLen - 1, t: 0}, + {w: wgHeaderLen, t: 0}, + {w: wgHeaderLen + 1, t: 1}, + {w: 1360, t: 1280}, + {w: 1500, t: 1420}, + {w: 9000, t: 8920}, + } + + for _, tt := range tests { + m := WireToTUNMTU(tt.w) + if m != tt.t { + t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) + } + } + + tests2 := []struct { + t TUNMTU + w WireMTU + }{ + {t: 0, w: wgHeaderLen}, + {t: 1, w: wgHeaderLen + 1}, + {t: 1280, w: 1360}, + {t: 1420, w: 1500}, + {t: 8920, w: 9000}, + } + + for _, tt := range tests2 { + m := TUNToWireMTU(tt.t) + if m != tt.w { + t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) + } + } +} diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go index 9600ceb77328f..e08f12bc14129 100644 --- a/net/tstun/tun_linux.go +++ b/net/tstun/tun_linux.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "bytes" - "errors" - "os" - "os/exec" - "strings" - "syscall" - - "tailscale.com/types/logger" - "tailscale.com/version/distro" -) - -func init() { - tunDiagnoseFailure = diagnoseLinuxTUNFailure -} - -func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { - if errors.Is(createErr, syscall.EBUSY) { - logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) - logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") - logf("... and then kill those PID(s)") - return - } - - var un syscall.Utsname - err := syscall.Uname(&un) - if err != nil { - logf("no TUN, and failed to look up kernel version: %v", err) - return - } - kernel := utsReleaseField(&un) - logf("Linux kernel version: %s", kernel) - - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() - if err == nil { - logf("'modprobe tun' successful") - // Either tun is currently loaded, or it's statically - // compiled into the kernel (which modprobe checks - // with /lib/modules/$(uname -r)/modules.builtin) - // - // So if there's a problem at this point, it's - // probably because /dev/net/tun doesn't exist. - const dev = "/dev/net/tun" - if fi, err := os.Stat(dev); err != nil { - logf("tun module loaded in kernel, but %s does not exist", dev) - } else { - logf("%s: %v", dev, fi.Mode()) - } - - // We failed to find why it failed. Just let our - // caller report the error it got from wireguard-go. - return - } - logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) - - switch distro.Get() { - case distro.Debian: - dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() - if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(dpkgOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) - } - case distro.Arch: - findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() - if len(bytes.TrimSpace(findOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(findOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) - } - case distro.OpenWrt: - out, err := exec.Command("opkg", "list-installed").CombinedOutput() - if err != nil { - logf("error querying OpenWrt installed packages: %s", out) - return - } - for _, pkg := range []string{"kmod-tun", "ca-bundle"} { - if !bytes.Contains(out, []byte(pkg+" - ")) { - logf("Missing required package %s; run: opkg install %s", pkg, pkg) - } - } - } -} - -func utsReleaseField(u *syscall.Utsname) string { - var sb strings.Builder - for _, v := range u.Release { - if v == 0 { - break - } - sb.WriteByte(byte(v)) - } - return strings.TrimSpace(sb.String()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "bytes" + "errors" + "os" + "os/exec" + "strings" + "syscall" + + "tailscale.com/types/logger" + "tailscale.com/version/distro" +) + +func init() { + tunDiagnoseFailure = diagnoseLinuxTUNFailure +} + +func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { + if errors.Is(createErr, syscall.EBUSY) { + logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) + logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") + logf("... and then kill those PID(s)") + return + } + + var un syscall.Utsname + err := syscall.Uname(&un) + if err != nil { + logf("no TUN, and failed to look up kernel version: %v", err) + return + } + kernel := utsReleaseField(&un) + logf("Linux kernel version: %s", kernel) + + modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() + if err == nil { + logf("'modprobe tun' successful") + // Either tun is currently loaded, or it's statically + // compiled into the kernel (which modprobe checks + // with /lib/modules/$(uname -r)/modules.builtin) + // + // So if there's a problem at this point, it's + // probably because /dev/net/tun doesn't exist. + const dev = "/dev/net/tun" + if fi, err := os.Stat(dev); err != nil { + logf("tun module loaded in kernel, but %s does not exist", dev) + } else { + logf("%s: %v", dev, fi.Mode()) + } + + // We failed to find why it failed. Just let our + // caller report the error it got from wireguard-go. + return + } + logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) + + switch distro.Get() { + case distro.Debian: + dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() + if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(dpkgOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) + } + case distro.Arch: + findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() + if len(bytes.TrimSpace(findOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(findOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) + } + case distro.OpenWrt: + out, err := exec.Command("opkg", "list-installed").CombinedOutput() + if err != nil { + logf("error querying OpenWrt installed packages: %s", out) + return + } + for _, pkg := range []string{"kmod-tun", "ca-bundle"} { + if !bytes.Contains(out, []byte(pkg+" - ")) { + logf("Missing required package %s; run: opkg install %s", pkg, pkg) + } + } + } +} + +func utsReleaseField(u *syscall.Utsname) string { + var sb strings.Builder + for _, v := range u.Release { + if v == 0 { + break + } + sb.WriteByte(byte(v)) + } + return strings.TrimSpace(sb.String()) +} diff --git a/net/tstun/tun_macos.go b/net/tstun/tun_macos.go index 3506f05b1e4c9..f71494f0b91b6 100644 --- a/net/tstun/tun_macos.go +++ b/net/tstun/tun_macos.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package tstun - -import ( - "os" - - "tailscale.com/types/logger" -) - -func init() { - tunDiagnoseFailure = diagnoseDarwinTUNFailure -} - -func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { - if os.Getuid() != 0 { - logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") - } - if tunName != "utun" { - logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package tstun + +import ( + "os" + + "tailscale.com/types/logger" +) + +func init() { + tunDiagnoseFailure = diagnoseDarwinTUNFailure +} + +func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { + if os.Getuid() != 0 { + logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") + } + if tunName != "utun" { + logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) + } +} diff --git a/net/tstun/tun_notwindows.go b/net/tstun/tun_notwindows.go index 087fcd4eec784..60f1c62bacaab 100644 --- a/net/tstun/tun_notwindows.go +++ b/net/tstun/tun_notwindows.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func interfaceName(dev tun.Device) (string, error) { - return dev.Name() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func interfaceName(dev tun.Device) (string, error) { + return dev.Name() +} diff --git a/packages/deb/deb.go b/packages/deb/deb.go index 30e3f2b4d360c..1be7f96526d1e 100644 --- a/packages/deb/deb.go +++ b/packages/deb/deb.go @@ -1,182 +1,182 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package deb extracts metadata from Debian packages. -package deb - -import ( - "archive/tar" - "bufio" - "bytes" - "compress/gzip" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strconv" - "strings" -) - -// Info is the Debian package metadata needed to integrate the package -// into a repository. -type Info struct { - // Version is the version of the package, as reported by dpkg. - Version string - // Arch is the Debian CPU architecture the package is for. - Arch string - // Control is the entire contents of the package's control file, - // with leading and trailing whitespace removed. - Control []byte - // MD5 is the MD5 hash of the package file. - MD5 []byte - // SHA1 is the SHA1 hash of the package file. - SHA1 []byte - // SHA256 is the SHA256 hash of the package file. - SHA256 []byte -} - -// ReadFile returns Debian package metadata from the .deb file at path. -func ReadFile(path string) (*Info, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - return Read(f) -} - -// Read returns Debian package metadata from the .deb file in r. -func Read(r io.Reader) (*Info, error) { - b := bufio.NewReader(r) - - m5, s1, s256 := md5.New(), sha1.New(), sha256.New() - summers := io.MultiWriter(m5, s1, s256) - r = io.TeeReader(b, summers) - - t, err := findControlTar(r) - if err != nil { - return nil, fmt.Errorf("searching for control.tar.gz: %w", err) - } - - control, err := findControlFile(t) - if err != nil { - return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) - } - - arch, version, err := findArchAndVersion(control) - if err != nil { - return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) - } - - // Exhaust the remainder of r, so that the summers see the entire file. - if _, err := io.Copy(io.Discard, r); err != nil { - return nil, fmt.Errorf("hashing file: %w", err) - } - - return &Info{ - Version: version, - Arch: arch, - Control: control, - MD5: m5.Sum(nil), - SHA1: s1.Sum(nil), - SHA256: s256.Sum(nil), - }, nil -} - -// findControlTar reads r as an `ar` archive, finds a tarball named -// `control.tar.gz` within, and returns a reader for that file. -func findControlTar(r io.Reader) (tarReader io.Reader, err error) { - var magic [8]byte - if _, err := io.ReadFull(r, magic[:]); err != nil { - return nil, fmt.Errorf("reading ar magic: %w", err) - } - if string(magic[:]) != "!\n" { - return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) - } - - for { - var hdr [60]byte - if _, err := io.ReadFull(r, hdr[:]); err != nil { - return nil, fmt.Errorf("reading file header: %w", err) - } - filename := strings.TrimSpace(string(hdr[:16])) - size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) - if err != nil { - return nil, fmt.Errorf("reading size of file %q: %w", filename, err) - } - if filename == "control.tar.gz" { - return io.LimitReader(r, size), nil - } - - // files in ar are padded out to 2 bytes. - if size%2 == 1 { - size++ - } - if _, err := io.CopyN(io.Discard, r, size); err != nil { - return nil, fmt.Errorf("seeking past file %q: %w", filename, err) - } - } -} - -// findControlFile reads r as a tar.gz archive, finds a file named -// `control` within, and returns its contents. -func findControlFile(r io.Reader) (control []byte, err error) { - gz, err := gzip.NewReader(r) - if err != nil { - return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) - } - defer gz.Close() - - tr := tar.NewReader(gz) - for { - hdr, err := tr.Next() - if err != nil { - if errors.Is(err, io.EOF) { - return nil, errors.New("EOF while looking for control file in control.tar.gz") - } - return nil, fmt.Errorf("reading tar header: %w", err) - } - - if filepath.Clean(hdr.Name) != "control" { - continue - } - - // Found control file - break - } - - bs, err := io.ReadAll(tr) - if err != nil { - return nil, fmt.Errorf("reading control file: %w", err) - } - - return bytes.TrimSpace(bs), nil -} - -var ( - archKey = []byte("Architecture:") - versionKey = []byte("Version:") -) - -// findArchAndVersion extracts the architecture and version strings -// from the given control file. -func findArchAndVersion(control []byte) (arch string, version string, err error) { - b := bytes.NewBuffer(control) - for { - l, err := b.ReadBytes('\n') - if err != nil { - return "", "", err - } - if bytes.HasPrefix(l, archKey) { - arch = string(bytes.TrimSpace(l[len(archKey):])) - } else if bytes.HasPrefix(l, versionKey) { - version = string(bytes.TrimSpace(l[len(versionKey):])) - } - if arch != "" && version != "" { - return arch, version, nil - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package deb extracts metadata from Debian packages. +package deb + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" +) + +// Info is the Debian package metadata needed to integrate the package +// into a repository. +type Info struct { + // Version is the version of the package, as reported by dpkg. + Version string + // Arch is the Debian CPU architecture the package is for. + Arch string + // Control is the entire contents of the package's control file, + // with leading and trailing whitespace removed. + Control []byte + // MD5 is the MD5 hash of the package file. + MD5 []byte + // SHA1 is the SHA1 hash of the package file. + SHA1 []byte + // SHA256 is the SHA256 hash of the package file. + SHA256 []byte +} + +// ReadFile returns Debian package metadata from the .deb file at path. +func ReadFile(path string) (*Info, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return Read(f) +} + +// Read returns Debian package metadata from the .deb file in r. +func Read(r io.Reader) (*Info, error) { + b := bufio.NewReader(r) + + m5, s1, s256 := md5.New(), sha1.New(), sha256.New() + summers := io.MultiWriter(m5, s1, s256) + r = io.TeeReader(b, summers) + + t, err := findControlTar(r) + if err != nil { + return nil, fmt.Errorf("searching for control.tar.gz: %w", err) + } + + control, err := findControlFile(t) + if err != nil { + return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) + } + + arch, version, err := findArchAndVersion(control) + if err != nil { + return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) + } + + // Exhaust the remainder of r, so that the summers see the entire file. + if _, err := io.Copy(io.Discard, r); err != nil { + return nil, fmt.Errorf("hashing file: %w", err) + } + + return &Info{ + Version: version, + Arch: arch, + Control: control, + MD5: m5.Sum(nil), + SHA1: s1.Sum(nil), + SHA256: s256.Sum(nil), + }, nil +} + +// findControlTar reads r as an `ar` archive, finds a tarball named +// `control.tar.gz` within, and returns a reader for that file. +func findControlTar(r io.Reader) (tarReader io.Reader, err error) { + var magic [8]byte + if _, err := io.ReadFull(r, magic[:]); err != nil { + return nil, fmt.Errorf("reading ar magic: %w", err) + } + if string(magic[:]) != "!\n" { + return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) + } + + for { + var hdr [60]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("reading file header: %w", err) + } + filename := strings.TrimSpace(string(hdr[:16])) + size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) + if err != nil { + return nil, fmt.Errorf("reading size of file %q: %w", filename, err) + } + if filename == "control.tar.gz" { + return io.LimitReader(r, size), nil + } + + // files in ar are padded out to 2 bytes. + if size%2 == 1 { + size++ + } + if _, err := io.CopyN(io.Discard, r, size); err != nil { + return nil, fmt.Errorf("seeking past file %q: %w", filename, err) + } + } +} + +// findControlFile reads r as a tar.gz archive, finds a file named +// `control` within, and returns its contents. +func findControlFile(r io.Reader) (control []byte, err error) { + gz, err := gzip.NewReader(r) + if err != nil { + return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.New("EOF while looking for control file in control.tar.gz") + } + return nil, fmt.Errorf("reading tar header: %w", err) + } + + if filepath.Clean(hdr.Name) != "control" { + continue + } + + // Found control file + break + } + + bs, err := io.ReadAll(tr) + if err != nil { + return nil, fmt.Errorf("reading control file: %w", err) + } + + return bytes.TrimSpace(bs), nil +} + +var ( + archKey = []byte("Architecture:") + versionKey = []byte("Version:") +) + +// findArchAndVersion extracts the architecture and version strings +// from the given control file. +func findArchAndVersion(control []byte) (arch string, version string, err error) { + b := bytes.NewBuffer(control) + for { + l, err := b.ReadBytes('\n') + if err != nil { + return "", "", err + } + if bytes.HasPrefix(l, archKey) { + arch = string(bytes.TrimSpace(l[len(archKey):])) + } else if bytes.HasPrefix(l, versionKey) { + version = string(bytes.TrimSpace(l[len(versionKey):])) + } + if arch != "" && version != "" { + return arch, version, nil + } + } +} diff --git a/packages/deb/deb_test.go b/packages/deb/deb_test.go index 1a25f67ad4875..0ff43da21d151 100644 --- a/packages/deb/deb_test.go +++ b/packages/deb/deb_test.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deb - -import ( - "bytes" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "encoding/hex" - "fmt" - "hash" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" -) - -func TestDebInfo(t *testing.T) { - tests := []struct { - name string - in []byte - want *Info - wantErr bool - }{ - { - name: "simple", - in: mkTestDeb("1.2.3", "amd64"), - want: &Info{ - Version: "1.2.3", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "arm64", - in: mkTestDeb("1.2.3", "arm64"), - want: &Info{ - Version: "1.2.3", - Arch: "arm64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "arm64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "unstable", - in: mkTestDeb("1.7.25", "amd64"), - want: &Info{ - Version: "1.7.25", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.7.25", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - - // These truncation tests assume the structure of a .deb - // package, which is as follows: - // magic: 8 bytes - // file header: 60 bytes, before each file blob - // - // The first file in a .deb ar is "debian-binary", which is 4 - // bytes long and consists of "2.0\n". - // The second file is control.tar.gz, which is what we care - // about introspecting for metadata. - // The final file is data.tar.gz, which we don't care about. - // - // The first file in control.tar.gz is the "control" file we - // want to read for metadata. - { - name: "truncated_ar_magic", - in: mkTestDeb("1.7.25", "amd64")[:4], - wantErr: true, - }, - { - name: "truncated_ar_header", - in: mkTestDeb("1.7.25", "amd64")[:30], - wantErr: true, - }, - { - name: "missing_control_tgz", - // Truncate right after the "debian-binary" file, which - // makes the file a valid 1-file archive that's missing - // control.tar.gz. - in: mkTestDeb("1.7.25", "amd64")[:72], - wantErr: true, - }, - { - name: "truncated_tgz", - in: mkTestDeb("1.7.25", "amd64")[:172], - wantErr: true, - }, - } - - for _, test := range tests { - // mkTestDeb returns non-deterministic output due to - // timestamps embedded in the package file, so compute the - // wanted hashes on the fly here. - if test.want != nil { - test.want.MD5 = mkHash(test.in, md5.New) - test.want.SHA1 = mkHash(test.in, sha1.New) - test.want.SHA256 = mkHash(test.in, sha256.New) - } - - t.Run(test.name, func(t *testing.T) { - b := bytes.NewBuffer(test.in) - got, err := Read(b) - if err != nil { - if test.wantErr { - t.Logf("got expected error: %v", err) - return - } - t.Fatalf("reading deb info: %v", err) - } - if diff := diff(got, test.want); diff != "" { - t.Fatalf("parsed info diff (-got+want):\n%s", diff) - } - }) - } -} - -func diff(got, want any) string { - matchField := func(name string) func(p cmp.Path) bool { - return func(p cmp.Path) bool { - if len(p) != 3 { - return false - } - return p[2].String() == "."+name - } - } - toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) - toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) - return cmp.Diff(got, want, - cmp.FilterPath(matchField("Control"), toLines), - cmp.FilterPath(matchField("MD5"), toHex), - cmp.FilterPath(matchField("SHA1"), toHex), - cmp.FilterPath(matchField("SHA256"), toHex)) -} - -func mkTestDeb(version, arch string) []byte { - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Description: "test package", - Arch: arch, - Platform: "linux", - Version: version, - Section: "net", - Priority: "extra", - Maintainer: "Tail Scalar", - }) - - pkg, err := nfpm.Get("deb") - if err != nil { - panic(fmt.Sprintf("getting deb packager: %v", err)) - } - - var b bytes.Buffer - if err := pkg.Package(info, &b); err != nil { - panic(fmt.Sprintf("creating deb package: %v", err)) - } - - return b.Bytes() -} - -func mkControl(fs ...string) []byte { - if len(fs)%2 != 0 { - panic("odd number of control file fields") - } - var b bytes.Buffer - for i := 0; i < len(fs); i = i + 2 { - k, v := fs[i], fs[i+1] - fmt.Fprintf(&b, "%s: %s\n", k, v) - } - return bytes.TrimSpace(b.Bytes()) -} - -func mkHash(b []byte, hasher func() hash.Hash) []byte { - h := hasher() - h.Write(b) - return h.Sum(nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deb + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "hash" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" +) + +func TestDebInfo(t *testing.T) { + tests := []struct { + name string + in []byte + want *Info + wantErr bool + }{ + { + name: "simple", + in: mkTestDeb("1.2.3", "amd64"), + want: &Info{ + Version: "1.2.3", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "arm64", + in: mkTestDeb("1.2.3", "arm64"), + want: &Info{ + Version: "1.2.3", + Arch: "arm64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "arm64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "unstable", + in: mkTestDeb("1.7.25", "amd64"), + want: &Info{ + Version: "1.7.25", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.7.25", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + + // These truncation tests assume the structure of a .deb + // package, which is as follows: + // magic: 8 bytes + // file header: 60 bytes, before each file blob + // + // The first file in a .deb ar is "debian-binary", which is 4 + // bytes long and consists of "2.0\n". + // The second file is control.tar.gz, which is what we care + // about introspecting for metadata. + // The final file is data.tar.gz, which we don't care about. + // + // The first file in control.tar.gz is the "control" file we + // want to read for metadata. + { + name: "truncated_ar_magic", + in: mkTestDeb("1.7.25", "amd64")[:4], + wantErr: true, + }, + { + name: "truncated_ar_header", + in: mkTestDeb("1.7.25", "amd64")[:30], + wantErr: true, + }, + { + name: "missing_control_tgz", + // Truncate right after the "debian-binary" file, which + // makes the file a valid 1-file archive that's missing + // control.tar.gz. + in: mkTestDeb("1.7.25", "amd64")[:72], + wantErr: true, + }, + { + name: "truncated_tgz", + in: mkTestDeb("1.7.25", "amd64")[:172], + wantErr: true, + }, + } + + for _, test := range tests { + // mkTestDeb returns non-deterministic output due to + // timestamps embedded in the package file, so compute the + // wanted hashes on the fly here. + if test.want != nil { + test.want.MD5 = mkHash(test.in, md5.New) + test.want.SHA1 = mkHash(test.in, sha1.New) + test.want.SHA256 = mkHash(test.in, sha256.New) + } + + t.Run(test.name, func(t *testing.T) { + b := bytes.NewBuffer(test.in) + got, err := Read(b) + if err != nil { + if test.wantErr { + t.Logf("got expected error: %v", err) + return + } + t.Fatalf("reading deb info: %v", err) + } + if diff := diff(got, test.want); diff != "" { + t.Fatalf("parsed info diff (-got+want):\n%s", diff) + } + }) + } +} + +func diff(got, want any) string { + matchField := func(name string) func(p cmp.Path) bool { + return func(p cmp.Path) bool { + if len(p) != 3 { + return false + } + return p[2].String() == "."+name + } + } + toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) + toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) + return cmp.Diff(got, want, + cmp.FilterPath(matchField("Control"), toLines), + cmp.FilterPath(matchField("MD5"), toHex), + cmp.FilterPath(matchField("SHA1"), toHex), + cmp.FilterPath(matchField("SHA256"), toHex)) +} + +func mkTestDeb(version, arch string) []byte { + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Description: "test package", + Arch: arch, + Platform: "linux", + Version: version, + Section: "net", + Priority: "extra", + Maintainer: "Tail Scalar", + }) + + pkg, err := nfpm.Get("deb") + if err != nil { + panic(fmt.Sprintf("getting deb packager: %v", err)) + } + + var b bytes.Buffer + if err := pkg.Package(info, &b); err != nil { + panic(fmt.Sprintf("creating deb package: %v", err)) + } + + return b.Bytes() +} + +func mkControl(fs ...string) []byte { + if len(fs)%2 != 0 { + panic("odd number of control file fields") + } + var b bytes.Buffer + for i := 0; i < len(fs); i = i + 2 { + k, v := fs[i], fs[i+1] + fmt.Fprintf(&b, "%s: %s\n", k, v) + } + return bytes.TrimSpace(b.Bytes()) +} + +func mkHash(b []byte, hasher func() hash.Hash) []byte { + h := hasher() + h.Write(b) + return h.Sum(nil) +} diff --git a/paths/migrate.go b/paths/migrate.go index 3a23ecca34fdc..11d90a6272a65 100644 --- a/paths/migrate.go +++ b/paths/migrate.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - - "tailscale.com/types/logger" -) - -// TryConfigFileMigration carefully copies the contents of oldFile to -// newFile, returning the path which should be used to read the config. -// - if newFile already exists, don't modify it just return its path -// - if neither oldFile nor newFile exist, return newFile for a fresh -// default config to be written to. -// - if oldFile exists but copying to newFile fails, return oldFile so -// there will at least be some config to work with. -func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { - _, err := os.Stat(newFile) - if err == nil { - // Common case for a system which has already been migrated. - return newFile - } - if !os.IsNotExist(err) { - logf("TryConfigFileMigration failed; new file: %v", err) - return newFile - } - - contents, err := os.ReadFile(oldFile) - if err != nil { - // Common case for a new user. - return newFile - } - - if err = MkStateDir(filepath.Dir(newFile)); err != nil { - logf("TryConfigFileMigration failed; MkStateDir: %v", err) - return oldFile - } - - err = os.WriteFile(newFile, contents, 0600) - if err != nil { - removeErr := os.Remove(newFile) - if removeErr != nil { - logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", - err, removeErr) - return oldFile - } - logf("TryConfigFileMigration failed; write newFile: %v", err) - return oldFile - } - - logf("TryConfigFileMigration: successfully migrated: from %v to %v", - oldFile, newFile) - - return newFile -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + + "tailscale.com/types/logger" +) + +// TryConfigFileMigration carefully copies the contents of oldFile to +// newFile, returning the path which should be used to read the config. +// - if newFile already exists, don't modify it just return its path +// - if neither oldFile nor newFile exist, return newFile for a fresh +// default config to be written to. +// - if oldFile exists but copying to newFile fails, return oldFile so +// there will at least be some config to work with. +func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { + _, err := os.Stat(newFile) + if err == nil { + // Common case for a system which has already been migrated. + return newFile + } + if !os.IsNotExist(err) { + logf("TryConfigFileMigration failed; new file: %v", err) + return newFile + } + + contents, err := os.ReadFile(oldFile) + if err != nil { + // Common case for a new user. + return newFile + } + + if err = MkStateDir(filepath.Dir(newFile)); err != nil { + logf("TryConfigFileMigration failed; MkStateDir: %v", err) + return oldFile + } + + err = os.WriteFile(newFile, contents, 0600) + if err != nil { + removeErr := os.Remove(newFile) + if removeErr != nil { + logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", + err, removeErr) + return oldFile + } + logf("TryConfigFileMigration failed; write newFile: %v", err) + return oldFile + } + + logf("TryConfigFileMigration: successfully migrated: from %v to %v", + oldFile, newFile) + + return newFile +} diff --git a/paths/paths.go b/paths/paths.go index 28c3be02a9c86..8cee4cabfd2a9 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package paths returns platform and user-specific default paths to -// Tailscale files and directories. -package paths - -import ( - "os" - "path/filepath" - "runtime" - - "tailscale.com/syncs" - "tailscale.com/version/distro" -) - -// AppSharedDir is a string set by the iOS or Android app on start -// containing a directory we can read/write in. -var AppSharedDir syncs.AtomicValue[string] - -// DefaultTailscaledSocket returns the path to the tailscaled Unix socket -// or the empty string if there's no reasonable default. -func DefaultTailscaledSocket() string { - if runtime.GOOS == "windows" { - return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` - } - if runtime.GOOS == "darwin" { - return "/var/run/tailscaled.socket" - } - if runtime.GOOS == "plan9" { - return "/srv/tailscaled.sock" - } - switch distro.Get() { - case distro.Synology: - if distro.DSMVersion() == 6 { - return "/var/packages/Tailscale/etc/tailscaled.sock" - } - // DSM 7 (and higher? or failure to detect.) - return "/var/packages/Tailscale/var/tailscaled.sock" - case distro.Gokrazy: - return "/perm/tailscaled/tailscaled.sock" - case distro.QNAP: - return "/tmp/tailscale/tailscaled.sock" - } - if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { - return "/var/run/tailscale/tailscaled.sock" - } - return "tailscaled.sock" -} - -// Overridden in init by OS-specific files. -var ( - stateFileFunc func() string - - // ensureStateDirPerms applies a restrictive ACL/chmod - // to the provided directory. - ensureStateDirPerms = func(string) error { return nil } -) - -// DefaultTailscaledStateFile returns the default path to the -// tailscaled state file, or the empty string if there's no reasonable -// default value. -func DefaultTailscaledStateFile() string { - if f := stateFileFunc; f != nil { - return f() - } - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") - } - return "" -} - -// MkStateDir ensures that dirPath, the daemon's configuration directory -// containing machine keys etc, both exists and has the correct permissions. -// We want it to only be accessible to the user the daemon is running under. -func MkStateDir(dirPath string) error { - if err := os.MkdirAll(dirPath, 0700); err != nil { - return err - } - return ensureStateDirPerms(dirPath) -} - -// LegacyStateFilePath returns the legacy path to the state file when -// it was stored under the current user's %LocalAppData%. -// -// It is only called on Windows. -func LegacyStateFilePath() string { - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") - } - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package paths returns platform and user-specific default paths to +// Tailscale files and directories. +package paths + +import ( + "os" + "path/filepath" + "runtime" + + "tailscale.com/syncs" + "tailscale.com/version/distro" +) + +// AppSharedDir is a string set by the iOS or Android app on start +// containing a directory we can read/write in. +var AppSharedDir syncs.AtomicValue[string] + +// DefaultTailscaledSocket returns the path to the tailscaled Unix socket +// or the empty string if there's no reasonable default. +func DefaultTailscaledSocket() string { + if runtime.GOOS == "windows" { + return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` + } + if runtime.GOOS == "darwin" { + return "/var/run/tailscaled.socket" + } + if runtime.GOOS == "plan9" { + return "/srv/tailscaled.sock" + } + switch distro.Get() { + case distro.Synology: + if distro.DSMVersion() == 6 { + return "/var/packages/Tailscale/etc/tailscaled.sock" + } + // DSM 7 (and higher? or failure to detect.) + return "/var/packages/Tailscale/var/tailscaled.sock" + case distro.Gokrazy: + return "/perm/tailscaled/tailscaled.sock" + case distro.QNAP: + return "/tmp/tailscale/tailscaled.sock" + } + if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { + return "/var/run/tailscale/tailscaled.sock" + } + return "tailscaled.sock" +} + +// Overridden in init by OS-specific files. +var ( + stateFileFunc func() string + + // ensureStateDirPerms applies a restrictive ACL/chmod + // to the provided directory. + ensureStateDirPerms = func(string) error { return nil } +) + +// DefaultTailscaledStateFile returns the default path to the +// tailscaled state file, or the empty string if there's no reasonable +// default value. +func DefaultTailscaledStateFile() string { + if f := stateFileFunc; f != nil { + return f() + } + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") + } + return "" +} + +// MkStateDir ensures that dirPath, the daemon's configuration directory +// containing machine keys etc, both exists and has the correct permissions. +// We want it to only be accessible to the user the daemon is running under. +func MkStateDir(dirPath string) error { + if err := os.MkdirAll(dirPath, 0700); err != nil { + return err + } + return ensureStateDirPerms(dirPath) +} + +// LegacyStateFilePath returns the legacy path to the state file when +// it was stored under the current user's %LocalAppData%. +// +// It is only called on Windows. +func LegacyStateFilePath() string { + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") + } + return "" +} diff --git a/paths/paths_windows.go b/paths/paths_windows.go index 4705400655212..2249810494b14 100644 --- a/paths/paths_windows.go +++ b/paths/paths_windows.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - "strings" - - "golang.org/x/sys/windows" - "tailscale.com/util/winutil" -) - -func init() { - ensureStateDirPerms = ensureStateDirPermsWindows -} - -// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. -// It sets the following security attributes on the directory: -// Owner: The user for the current process; -// Primary Group: The primary group for the current process; -// DACL: Full control to the current user and to the Administrators group. -// -// (We include Administrators so that admin users may still access logs; -// granting access exclusively to LocalSystem would require admins to use -// special tools to access the Log directory) -// -// Inheritance: The directory does not inherit the ACL from its parent. -// -// However, any directories and/or files created within this -// directory *do* inherit the ACL that we are setting. -func ensureStateDirPermsWindows(dirPath string) error { - fi, err := os.Stat(dirPath) - if err != nil { - return err - } - if !fi.IsDir() { - return os.ErrInvalid - } - if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { - return nil - } - - // We need the info for our current user as SIDs - sids, err := winutil.GetCurrentUserSIDs() - if err != nil { - return err - } - - // We also need the SID for the Administrators group so that admins may - // easily access logs. - adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) - if err != nil { - return err - } - - // Munge the SIDs into the format required by EXPLICIT_ACCESS. - userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, - windows.TrusteeValueFromSID(sids.User)} - - adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, - windows.TrusteeValueFromSID(adminGroupSid)} - - // We declare our access rights via this array of EXPLICIT_ACCESS structures. - // We set full access to our user and to Administrators. - // We configure the DACL such that any files or directories created within - // dirPath will also inherit this DACL. - explicitAccess := []windows.EXPLICIT_ACCESS{ - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - userTrustee, - }, - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - adminTrustee, - }, - } - - dacl, err := windows.ACLFromEntries(explicitAccess, nil) - if err != nil { - return err - } - - // We now reset the file's owner, primary group, and DACL. - // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL - // does not inherit any ACL entries from the parent directory. - const flags = windows.OWNER_SECURITY_INFORMATION | - windows.GROUP_SECURITY_INFORMATION | - windows.DACL_SECURITY_INFORMATION | - windows.PROTECTED_DACL_SECURITY_INFORMATION - return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, - sids.User, sids.PrimaryGroup, dacl, nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/windows" + "tailscale.com/util/winutil" +) + +func init() { + ensureStateDirPerms = ensureStateDirPermsWindows +} + +// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. +// It sets the following security attributes on the directory: +// Owner: The user for the current process; +// Primary Group: The primary group for the current process; +// DACL: Full control to the current user and to the Administrators group. +// +// (We include Administrators so that admin users may still access logs; +// granting access exclusively to LocalSystem would require admins to use +// special tools to access the Log directory) +// +// Inheritance: The directory does not inherit the ACL from its parent. +// +// However, any directories and/or files created within this +// directory *do* inherit the ACL that we are setting. +func ensureStateDirPermsWindows(dirPath string) error { + fi, err := os.Stat(dirPath) + if err != nil { + return err + } + if !fi.IsDir() { + return os.ErrInvalid + } + if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { + return nil + } + + // We need the info for our current user as SIDs + sids, err := winutil.GetCurrentUserSIDs() + if err != nil { + return err + } + + // We also need the SID for the Administrators group so that admins may + // easily access logs. + adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + + // Munge the SIDs into the format required by EXPLICIT_ACCESS. + userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, + windows.TrusteeValueFromSID(sids.User)} + + adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + windows.TrusteeValueFromSID(adminGroupSid)} + + // We declare our access rights via this array of EXPLICIT_ACCESS structures. + // We set full access to our user and to Administrators. + // We configure the DACL such that any files or directories created within + // dirPath will also inherit this DACL. + explicitAccess := []windows.EXPLICIT_ACCESS{ + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + userTrustee, + }, + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + adminTrustee, + }, + } + + dacl, err := windows.ACLFromEntries(explicitAccess, nil) + if err != nil { + return err + } + + // We now reset the file's owner, primary group, and DACL. + // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL + // does not inherit any ACL entries from the parent directory. + const flags = windows.OWNER_SECURITY_INFORMATION | + windows.GROUP_SECURITY_INFORMATION | + windows.DACL_SECURITY_INFORMATION | + windows.PROTECTED_DACL_SECURITY_INFORMATION + return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, + sids.User, sids.PrimaryGroup, dacl, nil) +} diff --git a/portlist/clean.go b/portlist/clean.go index 7e137de948e99..cad1562c3e1d8 100644 --- a/portlist/clean.go +++ b/portlist/clean.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "path/filepath" - "strings" -) - -// argvSubject takes a command and its flags, and returns the -// short/pretty name for the process. This is usually the basename of -// the binary being executed, but can sometimes vary (e.g. so that we -// don't report all Java programs as "java"). -func argvSubject(argv ...string) string { - if len(argv) == 0 { - return "" - } - ret := filepath.Base(argv[0]) - - // Handle special cases. - switch { - case ret == "mono" && len(argv) >= 2: - // .Net programs execute as `mono actualProgram.exe`. - ret = filepath.Base(argv[1]) - } - - // Handle space separated argv - ret, _, _ = strings.Cut(ret, " ") - - // Remove common noise. - ret = strings.TrimSpace(ret) - ret = strings.TrimSuffix(ret, ".exe") - - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "path/filepath" + "strings" +) + +// argvSubject takes a command and its flags, and returns the +// short/pretty name for the process. This is usually the basename of +// the binary being executed, but can sometimes vary (e.g. so that we +// don't report all Java programs as "java"). +func argvSubject(argv ...string) string { + if len(argv) == 0 { + return "" + } + ret := filepath.Base(argv[0]) + + // Handle special cases. + switch { + case ret == "mono" && len(argv) >= 2: + // .Net programs execute as `mono actualProgram.exe`. + ret = filepath.Base(argv[1]) + } + + // Handle space separated argv + ret, _, _ = strings.Cut(ret, " ") + + // Remove common noise. + ret = strings.TrimSpace(ret) + ret = strings.TrimSuffix(ret, ".exe") + + return ret +} diff --git a/portlist/clean_test.go b/portlist/clean_test.go index 5a1e34405eed0..cca18ab8eb2c6 100644 --- a/portlist/clean_test.go +++ b/portlist/clean_test.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import "testing" - -func TestArgvSubject(t *testing.T) { - tests := []struct { - in []string - want string - }{ - { - in: nil, - want: "", - }, - { - in: []string{"/usr/bin/sshd"}, - want: "sshd", - }, - { - in: []string{"/bin/mono"}, - want: "mono", - }, - { - in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, - want: "exampleProgram", - }, - { - in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, - want: "exampleProgram.bin", - }, - { - in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, - want: "sshd_config", - }, - { - in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, - want: "sshd", - }, - { - in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, - want: "eic_run_authorized_keys", - }, - { - in: []string{"/usr/bin/nginx worker"}, - want: "nginx", - }, - } - - for _, test := range tests { - got := argvSubject(test.in...) - if got != test.want { - t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import "testing" + +func TestArgvSubject(t *testing.T) { + tests := []struct { + in []string + want string + }{ + { + in: nil, + want: "", + }, + { + in: []string{"/usr/bin/sshd"}, + want: "sshd", + }, + { + in: []string{"/bin/mono"}, + want: "mono", + }, + { + in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, + want: "exampleProgram", + }, + { + in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, + want: "exampleProgram.bin", + }, + { + in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, + want: "sshd_config", + }, + { + in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, + want: "sshd", + }, + { + in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, + want: "eic_run_authorized_keys", + }, + { + in: []string{"/usr/bin/nginx worker"}, + want: "nginx", + }, + } + + for _, test := range tests { + got := argvSubject(test.in...) + if got != test.want { + t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/portlist/netstat_test.go b/portlist/netstat_test.go index 023b75b794426..d04b657f623f4 100644 --- a/portlist/netstat_test.go +++ b/portlist/netstat_test.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "encoding/json" - "fmt" - "strings" - "testing" - - "go4.org/mem" -) - -func TestParsePort(t *testing.T) { - type InOut struct { - in string - expect int - } - tests := []InOut{ - {"1.2.3.4:5678", 5678}, - {"0.0.0.0.999", 999}, - {"1.2.3.4:*", 0}, - {"5.5.5.5:0", 0}, - {"[1::2]:5", 5}, - {"[1::2].5", 5}, - {"gibberish", -1}, - } - - for _, io := range tests { - got := parsePort(mem.S(io.in)) - if got != io.expect { - t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) - } - } -} - -const netstatOutput = ` -// macOS -tcp4 0 0 *.23 *.* LISTEN -tcp6 0 0 *.24 *.* LISTEN -tcp4 0 0 *.8185 *.* LISTEN -tcp4 0 0 127.0.0.1.8186 *.* LISTEN -tcp6 0 0 ::1.8187 *.* LISTEN -tcp4 0 0 127.1.2.3.8188 *.* LISTEN - -udp6 0 0 *.106 *.* -udp4 0 0 *.104 *.* -udp46 0 0 *.146 *.* -` - -func TestParsePortsNetstat(t *testing.T) { - for _, loopBack := range [...]bool{false, true} { - t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { - want := List{ - {"tcp", 23, "", 0}, - {"tcp", 24, "", 0}, - {"udp", 104, "", 0}, - {"udp", 106, "", 0}, - {"udp", 146, "", 0}, - {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false - } - if loopBack { - want = append(want, - Port{"tcp", 8186, "", 0}, - Port{"tcp", 8187, "", 0}, - Port{"tcp", 8188, "", 0}, - ) - } - pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) - if err != nil { - t.Fatal(err) - } - pl = sortAndDedup(pl) - jgot, _ := json.MarshalIndent(pl, "", "\t") - jwant, _ := json.MarshalIndent(want, "", "\t") - if len(pl) != len(want) { - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - for i := range pl { - if pl[i] != want[i] { - t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", - i, pl[i], want[i]) - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "encoding/json" + "fmt" + "strings" + "testing" + + "go4.org/mem" +) + +func TestParsePort(t *testing.T) { + type InOut struct { + in string + expect int + } + tests := []InOut{ + {"1.2.3.4:5678", 5678}, + {"0.0.0.0.999", 999}, + {"1.2.3.4:*", 0}, + {"5.5.5.5:0", 0}, + {"[1::2]:5", 5}, + {"[1::2].5", 5}, + {"gibberish", -1}, + } + + for _, io := range tests { + got := parsePort(mem.S(io.in)) + if got != io.expect { + t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) + } + } +} + +const netstatOutput = ` +// macOS +tcp4 0 0 *.23 *.* LISTEN +tcp6 0 0 *.24 *.* LISTEN +tcp4 0 0 *.8185 *.* LISTEN +tcp4 0 0 127.0.0.1.8186 *.* LISTEN +tcp6 0 0 ::1.8187 *.* LISTEN +tcp4 0 0 127.1.2.3.8188 *.* LISTEN + +udp6 0 0 *.106 *.* +udp4 0 0 *.104 *.* +udp46 0 0 *.146 *.* +` + +func TestParsePortsNetstat(t *testing.T) { + for _, loopBack := range [...]bool{false, true} { + t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { + want := List{ + {"tcp", 23, "", 0}, + {"tcp", 24, "", 0}, + {"udp", 104, "", 0}, + {"udp", 106, "", 0}, + {"udp", 146, "", 0}, + {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false + } + if loopBack { + want = append(want, + Port{"tcp", 8186, "", 0}, + Port{"tcp", 8187, "", 0}, + Port{"tcp", 8188, "", 0}, + ) + } + pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) + if err != nil { + t.Fatal(err) + } + pl = sortAndDedup(pl) + jgot, _ := json.MarshalIndent(pl, "", "\t") + jwant, _ := json.MarshalIndent(want, "", "\t") + if len(pl) != len(want) { + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + for i := range pl { + if pl[i] != want[i] { + t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", + i, pl[i], want[i]) + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + } + }) + } +} diff --git a/portlist/poller.go b/portlist/poller.go index 423bad3be33ba..226f3b9958e8d 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -1,122 +1,122 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file contains the code related to the Poller type and its methods. -// The hot loop to keep efficient is Poller.Run. - -package portlist - -import ( - "errors" - "fmt" - "runtime" - "slices" - "sync" - "time" - - "tailscale.com/envknob" -) - -var ( - newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. - pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs - debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") -) - -// PollInterval is the recommended OS-specific interval -// to wait between *Poller.Poll method calls. -func PollInterval() time.Duration { - return pollInterval -} - -// Poller scans the systems for listening ports periodically and sends -// the results to C. -type Poller struct { - // IncludeLocalhost controls whether services bound to localhost are included. - // - // This field should only be changed before calling Run. - IncludeLocalhost bool - - // os, if non-nil, is an OS-specific implementation of the portlist getting - // code. When non-nil, it's responsible for getting the complete list of - // cached ports complete with the process name. That is, when set, - // addProcesses is not used. - // A nil values means we don't have code for getting the list on the current - // operating system. - os osImpl - initOnce sync.Once // guards init of os - initErr error - - // scatch is memory for Poller.getList to reuse between calls. - scratch []Port - - prev List // most recent data, not aliasing scratch -} - -// osImpl is the OS-specific implementation of getting the open listening ports. -type osImpl interface { - Close() error - - // AppendListeningPorts appends to base (which must have length 0 but - // optional capacity) the list of listening ports. The Port struct should be - // populated as completely as possible. Another pass will not add anything - // to it. - // - // The appended ports should be in a sorted (or at least stable) order so - // the caller can cheaply detect when there are no changes. - AppendListeningPorts(base []Port) ([]Port, error) -} - -func (p *Poller) setPrev(pl List) { - // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want - // that to except to the caller. - p.prev = slices.Clone(pl) -} - -// init initializes the Poller by ensuring it has an underlying -// OS implementation and is not turned off by envknob. -func (p *Poller) init() { - switch { - case debugDisablePortlist(): - p.initErr = errors.New("portlist disabled by envknob") - case newOSImpl == nil: - p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) - default: - p.os = newOSImpl(p.IncludeLocalhost) - } -} - -// Close closes the Poller. -func (p *Poller) Close() error { - if p.initErr != nil { - return p.initErr - } - if p.os == nil { - return nil - } - return p.os.Close() -} - -// Poll returns the list of listening ports, if changed from -// a previous call as indicated by the changed result. -func (p *Poller) Poll() (ports []Port, changed bool, err error) { - p.initOnce.Do(p.init) - if p.initErr != nil { - return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) - } - pl, err := p.getList() - if err != nil { - return nil, false, err - } - if pl.equal(p.prev) { - return nil, false, nil - } - p.setPrev(pl) - return p.prev, true, nil -} - -func (p *Poller) getList() (List, error) { - var err error - p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) - return p.scratch, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains the code related to the Poller type and its methods. +// The hot loop to keep efficient is Poller.Run. + +package portlist + +import ( + "errors" + "fmt" + "runtime" + "slices" + "sync" + "time" + + "tailscale.com/envknob" +) + +var ( + newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. + pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs + debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") +) + +// PollInterval is the recommended OS-specific interval +// to wait between *Poller.Poll method calls. +func PollInterval() time.Duration { + return pollInterval +} + +// Poller scans the systems for listening ports periodically and sends +// the results to C. +type Poller struct { + // IncludeLocalhost controls whether services bound to localhost are included. + // + // This field should only be changed before calling Run. + IncludeLocalhost bool + + // os, if non-nil, is an OS-specific implementation of the portlist getting + // code. When non-nil, it's responsible for getting the complete list of + // cached ports complete with the process name. That is, when set, + // addProcesses is not used. + // A nil values means we don't have code for getting the list on the current + // operating system. + os osImpl + initOnce sync.Once // guards init of os + initErr error + + // scatch is memory for Poller.getList to reuse between calls. + scratch []Port + + prev List // most recent data, not aliasing scratch +} + +// osImpl is the OS-specific implementation of getting the open listening ports. +type osImpl interface { + Close() error + + // AppendListeningPorts appends to base (which must have length 0 but + // optional capacity) the list of listening ports. The Port struct should be + // populated as completely as possible. Another pass will not add anything + // to it. + // + // The appended ports should be in a sorted (or at least stable) order so + // the caller can cheaply detect when there are no changes. + AppendListeningPorts(base []Port) ([]Port, error) +} + +func (p *Poller) setPrev(pl List) { + // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want + // that to except to the caller. + p.prev = slices.Clone(pl) +} + +// init initializes the Poller by ensuring it has an underlying +// OS implementation and is not turned off by envknob. +func (p *Poller) init() { + switch { + case debugDisablePortlist(): + p.initErr = errors.New("portlist disabled by envknob") + case newOSImpl == nil: + p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) + default: + p.os = newOSImpl(p.IncludeLocalhost) + } +} + +// Close closes the Poller. +func (p *Poller) Close() error { + if p.initErr != nil { + return p.initErr + } + if p.os == nil { + return nil + } + return p.os.Close() +} + +// Poll returns the list of listening ports, if changed from +// a previous call as indicated by the changed result. +func (p *Poller) Poll() (ports []Port, changed bool, err error) { + p.initOnce.Do(p.init) + if p.initErr != nil { + return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) + } + pl, err := p.getList() + if err != nil { + return nil, false, err + } + if pl.equal(p.prev) { + return nil, false, nil + } + p.setPrev(pl) + return p.prev, true, nil +} + +func (p *Poller) getList() (List, error) { + var err error + p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) + return p.scratch, err +} diff --git a/portlist/portlist.go b/portlist/portlist.go index 9f7af40d08dc1..6d24cedcc5038 100644 --- a/portlist/portlist.go +++ b/portlist/portlist.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file is just the types. The bulk of the code is in poller.go. - -// The portlist package contains code that checks what ports are open and -// listening on the current machine. -package portlist - -import ( - "fmt" - "sort" - "strings" -) - -// Port is a listening port on the machine. -type Port struct { - Proto string // "tcp" or "udp" - Port uint16 // port number - Process string // optional process name, if found (requires suitable permissions) - Pid int // process ID, if known (requires suitable permissions) -} - -// List is a list of Ports. -type List []Port - -func (a *Port) lessThan(b *Port) bool { - if a.Port != b.Port { - return a.Port < b.Port - } - if a.Proto != b.Proto { - return a.Proto < b.Proto - } - return a.Process < b.Process -} - -func (a *Port) equal(b *Port) bool { - return a.Port == b.Port && - a.Proto == b.Proto && - a.Process == b.Process -} - -func (a List) equal(b List) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if !a[i].equal(&b[i]) { - return false - } - } - return true -} - -func (pl List) String() string { - var sb strings.Builder - for _, v := range pl { - fmt.Fprintf(&sb, "%-3s %5d %#v\n", - v.Proto, v.Port, v.Process) - } - return strings.TrimRight(sb.String(), "\n") -} - -// sortAndDedup sorts ps in place (by Port.lessThan) and then returns -// a subset of it with duplicate (Proto, Port) removed. -func sortAndDedup(ps List) List { - sort.Slice(ps, func(i, j int) bool { - return (&ps[i]).lessThan(&ps[j]) - }) - out := ps[:0] - var last Port - for _, p := range ps { - if last.Proto == p.Proto && last.Port == p.Port { - continue - } - out = append(out, p) - last = p - } - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file is just the types. The bulk of the code is in poller.go. + +// The portlist package contains code that checks what ports are open and +// listening on the current machine. +package portlist + +import ( + "fmt" + "sort" + "strings" +) + +// Port is a listening port on the machine. +type Port struct { + Proto string // "tcp" or "udp" + Port uint16 // port number + Process string // optional process name, if found (requires suitable permissions) + Pid int // process ID, if known (requires suitable permissions) +} + +// List is a list of Ports. +type List []Port + +func (a *Port) lessThan(b *Port) bool { + if a.Port != b.Port { + return a.Port < b.Port + } + if a.Proto != b.Proto { + return a.Proto < b.Proto + } + return a.Process < b.Process +} + +func (a *Port) equal(b *Port) bool { + return a.Port == b.Port && + a.Proto == b.Proto && + a.Process == b.Process +} + +func (a List) equal(b List) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].equal(&b[i]) { + return false + } + } + return true +} + +func (pl List) String() string { + var sb strings.Builder + for _, v := range pl { + fmt.Fprintf(&sb, "%-3s %5d %#v\n", + v.Proto, v.Port, v.Process) + } + return strings.TrimRight(sb.String(), "\n") +} + +// sortAndDedup sorts ps in place (by Port.lessThan) and then returns +// a subset of it with duplicate (Proto, Port) removed. +func sortAndDedup(ps List) List { + sort.Slice(ps, func(i, j int) bool { + return (&ps[i]).lessThan(&ps[j]) + }) + out := ps[:0] + var last Port + for _, p := range ps { + if last.Proto == p.Proto && last.Port == p.Port { + continue + } + out = append(out, p) + last = p + } + return out +} diff --git a/portlist/portlist_macos.go b/portlist/portlist_macos.go index e67b2c9b8c064..2f4fee351f1cf 100644 --- a/portlist/portlist_macos.go +++ b/portlist/portlist_macos.go @@ -1,230 +1,230 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "bytes" - "fmt" - "log" - "os/exec" - "strings" - "sync/atomic" - "time" - - "go4.org/mem" -) - -func init() { - newOSImpl = newMacOSImpl - - // We have to run netstat, which is a bit expensive, so don't do it too often. - pollInterval = 5 * time.Second -} - -type macOSImpl struct { - known map[protoPort]*portMeta // inode string => metadata - netstatPath string // lazily populated - - br *bufio.Reader // reused - portsBuf []Port - includeLocalhost bool -} - -type protoPort struct { - proto string - port uint16 -} - -type portMeta struct { - port Port - keep bool -} - -func newMacOSImpl(includeLocalhost bool) osImpl { - return &macOSImpl{ - known: map[protoPort]*portMeta{}, - br: bufio.NewReader(bytes.NewReader(nil)), - includeLocalhost: includeLocalhost, - } -} - -func (*macOSImpl) Close() error { return nil } - -func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { - var err error - im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - var needProcs bool - for _, p := range im.portsBuf { - fp := protoPort{ - proto: p.Proto, - port: p.Port, - } - if pm, ok := im.known[fp]; ok { - pm.keep = true - } else { - needProcs = true - im.known[fp] = &portMeta{ - port: p, - keep: true, - } - } - } - - ret := base - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - } - } - - if needProcs { - im.addProcesses() // best effort - } - - for _, m := range im.known { - ret = append(ret, m.port) - } - return sortAndDedup(ret), nil -} - -func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { - if im.netstatPath == "" { - var err error - im.netstatPath, err = exec.LookPath("netstat") - if err != nil { - return nil, fmt.Errorf("netstat: lookup: %v", err) - } - } - - cmd := exec.Command(im.netstatPath, "-na") - outPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - im.br.Reset(outPipe) - - if err := cmd.Start(); err != nil { - return nil, err - } - defer cmd.Process.Wait() - defer cmd.Process.Kill() - - return appendParsePortsNetstat(base, im.br, im.includeLocalhost) -} - -var lsofFailed atomic.Bool - -// In theory, lsof could replace the function of both listPorts() and -// addProcesses(), since it provides a superset of the netstat output. -// However, "netstat -na" runs ~100x faster than lsof on my machine, so -// we should do it only if the list of open ports has actually changed. -// -// This fails in a macOS sandbox (i.e. in the Mac App Store or System -// Extension GUI build), but does at least work in the -// tailscaled-on-macos mode. -func (im *macOSImpl) addProcesses() error { - if lsofFailed.Load() { - // This previously failed in the macOS sandbox, so don't try again. - return nil - } - exe, err := exec.LookPath("lsof") - if err != nil { - return fmt.Errorf("lsof: lookup: %v", err) - } - lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") - outPipe, err := lsofCmd.StdoutPipe() - if err != nil { - return err - } - err = lsofCmd.Start() - if err != nil { - var stderr []byte - if xe, ok := err.(*exec.ExitError); ok { - stderr = xe.Stderr - } - // fails when run in a macOS sandbox, so make this non-fatal. - if lsofFailed.CompareAndSwap(false, true) { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) - } - return nil - } - defer func() { - ps, err := lsofCmd.Process.Wait() - if err != nil || ps.ExitCode() != 0 { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) - lsofFailed.Store(true) - } - }() - defer lsofCmd.Process.Kill() - - im.br.Reset(outPipe) - - var cmd, proto string - var pid int - for { - line, err := im.br.ReadBytes('\n') - if err != nil { - break - } - if len(line) < 1 { - continue - } - field, val := line[0], bytes.TrimSpace(line[1:]) - switch field { - case 'p': - // starting a new process - cmd = "" - proto = "" - pid = 0 - if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { - pid = int(p) - } - case 'c': - cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? - case 'P': - proto = lsofProtoLower(val) - case 'n': - if mem.Contains(mem.B(val), mem.S("->")) { - continue - } - // a listening port - port := parsePort(mem.B(val)) - if port <= 0 { - continue - } - pp := protoPort{proto, uint16(port)} - m := im.known[pp] - switch { - case m != nil: - m.port.Process = cmd - m.port.Pid = pid - default: - // ignore: processes and ports come and go - } - } - } - - return nil -} - -func lsofProtoLower(p []byte) string { - if string(p) == "TCP" { - return "tcp" - } - if string(p) == "UDP" { - return "udp" - } - return strings.ToLower(string(p)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "bytes" + "fmt" + "log" + "os/exec" + "strings" + "sync/atomic" + "time" + + "go4.org/mem" +) + +func init() { + newOSImpl = newMacOSImpl + + // We have to run netstat, which is a bit expensive, so don't do it too often. + pollInterval = 5 * time.Second +} + +type macOSImpl struct { + known map[protoPort]*portMeta // inode string => metadata + netstatPath string // lazily populated + + br *bufio.Reader // reused + portsBuf []Port + includeLocalhost bool +} + +type protoPort struct { + proto string + port uint16 +} + +type portMeta struct { + port Port + keep bool +} + +func newMacOSImpl(includeLocalhost bool) osImpl { + return &macOSImpl{ + known: map[protoPort]*portMeta{}, + br: bufio.NewReader(bytes.NewReader(nil)), + includeLocalhost: includeLocalhost, + } +} + +func (*macOSImpl) Close() error { return nil } + +func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { + var err error + im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + var needProcs bool + for _, p := range im.portsBuf { + fp := protoPort{ + proto: p.Proto, + port: p.Port, + } + if pm, ok := im.known[fp]; ok { + pm.keep = true + } else { + needProcs = true + im.known[fp] = &portMeta{ + port: p, + keep: true, + } + } + } + + ret := base + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + } + } + + if needProcs { + im.addProcesses() // best effort + } + + for _, m := range im.known { + ret = append(ret, m.port) + } + return sortAndDedup(ret), nil +} + +func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { + if im.netstatPath == "" { + var err error + im.netstatPath, err = exec.LookPath("netstat") + if err != nil { + return nil, fmt.Errorf("netstat: lookup: %v", err) + } + } + + cmd := exec.Command(im.netstatPath, "-na") + outPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + im.br.Reset(outPipe) + + if err := cmd.Start(); err != nil { + return nil, err + } + defer cmd.Process.Wait() + defer cmd.Process.Kill() + + return appendParsePortsNetstat(base, im.br, im.includeLocalhost) +} + +var lsofFailed atomic.Bool + +// In theory, lsof could replace the function of both listPorts() and +// addProcesses(), since it provides a superset of the netstat output. +// However, "netstat -na" runs ~100x faster than lsof on my machine, so +// we should do it only if the list of open ports has actually changed. +// +// This fails in a macOS sandbox (i.e. in the Mac App Store or System +// Extension GUI build), but does at least work in the +// tailscaled-on-macos mode. +func (im *macOSImpl) addProcesses() error { + if lsofFailed.Load() { + // This previously failed in the macOS sandbox, so don't try again. + return nil + } + exe, err := exec.LookPath("lsof") + if err != nil { + return fmt.Errorf("lsof: lookup: %v", err) + } + lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") + outPipe, err := lsofCmd.StdoutPipe() + if err != nil { + return err + } + err = lsofCmd.Start() + if err != nil { + var stderr []byte + if xe, ok := err.(*exec.ExitError); ok { + stderr = xe.Stderr + } + // fails when run in a macOS sandbox, so make this non-fatal. + if lsofFailed.CompareAndSwap(false, true) { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) + } + return nil + } + defer func() { + ps, err := lsofCmd.Process.Wait() + if err != nil || ps.ExitCode() != 0 { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) + lsofFailed.Store(true) + } + }() + defer lsofCmd.Process.Kill() + + im.br.Reset(outPipe) + + var cmd, proto string + var pid int + for { + line, err := im.br.ReadBytes('\n') + if err != nil { + break + } + if len(line) < 1 { + continue + } + field, val := line[0], bytes.TrimSpace(line[1:]) + switch field { + case 'p': + // starting a new process + cmd = "" + proto = "" + pid = 0 + if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { + pid = int(p) + } + case 'c': + cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? + case 'P': + proto = lsofProtoLower(val) + case 'n': + if mem.Contains(mem.B(val), mem.S("->")) { + continue + } + // a listening port + port := parsePort(mem.B(val)) + if port <= 0 { + continue + } + pp := protoPort{proto, uint16(port)} + m := im.known[pp] + switch { + case m != nil: + m.port.Process = cmd + m.port.Pid = pid + default: + // ignore: processes and ports come and go + } + } + } + + return nil +} + +func lsofProtoLower(p []byte) string { + if string(p) == "TCP" { + return "tcp" + } + if string(p) == "UDP" { + return "udp" + } + return strings.ToLower(string(p)) +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go index f449973599247..c164dbad75485 100644 --- a/portlist/portlist_windows.go +++ b/portlist/portlist_windows.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "time" - - "tailscale.com/net/netstat" -) - -func init() { - newOSImpl = newWindowsImpl - // The portlist poller used to fork on Windows, which is insanely expensive, - // so historically we only did this every 5 seconds on Windows. Maybe we - // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as - // of 2022-11-04. - pollInterval = 5 * time.Second -} - -type famPort struct { - proto string - port uint16 - pid uint32 -} - -type windowsImpl struct { - known map[famPort]*portMeta // inode string => metadata - includeLocalhost bool -} - -type portMeta struct { - port Port - keep bool -} - -func newWindowsImpl(includeLocalhost bool) osImpl { - return &windowsImpl{ - known: map[famPort]*portMeta{}, - includeLocalhost: includeLocalhost, - } -} - -func (*windowsImpl) Close() error { return nil } - -func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { - // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style - // API to that package instead/additionally. - tab, err := netstat.Get() - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - ret := base - for _, e := range tab.Entries { - if e.State != "LISTEN" { - continue - } - if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { - continue - } - fp := famPort{ - proto: "tcp", // TODO(bradfitz): UDP too; add to netstat - port: e.Local.Port(), - pid: uint32(e.Pid), - } - pm, ok := im.known[fp] - if ok { - pm.keep = true - continue - } - var process string - if e.OSMetadata != nil { - if module, err := e.OSMetadata.GetModule(); err == nil { - process = module - } - } - pm = &portMeta{ - keep: true, - port: Port{ - Proto: "tcp", - Port: e.Local.Port(), - Process: process, - Pid: e.Pid, - }, - } - im.known[fp] = pm - } - - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - continue - } - ret = append(ret, m.port) - } - - return sortAndDedup(ret), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "time" + + "tailscale.com/net/netstat" +) + +func init() { + newOSImpl = newWindowsImpl + // The portlist poller used to fork on Windows, which is insanely expensive, + // so historically we only did this every 5 seconds on Windows. Maybe we + // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as + // of 2022-11-04. + pollInterval = 5 * time.Second +} + +type famPort struct { + proto string + port uint16 + pid uint32 +} + +type windowsImpl struct { + known map[famPort]*portMeta // inode string => metadata + includeLocalhost bool +} + +type portMeta struct { + port Port + keep bool +} + +func newWindowsImpl(includeLocalhost bool) osImpl { + return &windowsImpl{ + known: map[famPort]*portMeta{}, + includeLocalhost: includeLocalhost, + } +} + +func (*windowsImpl) Close() error { return nil } + +func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { + // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style + // API to that package instead/additionally. + tab, err := netstat.Get() + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + ret := base + for _, e := range tab.Entries { + if e.State != "LISTEN" { + continue + } + if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { + continue + } + fp := famPort{ + proto: "tcp", // TODO(bradfitz): UDP too; add to netstat + port: e.Local.Port(), + pid: uint32(e.Pid), + } + pm, ok := im.known[fp] + if ok { + pm.keep = true + continue + } + var process string + if e.OSMetadata != nil { + if module, err := e.OSMetadata.GetModule(); err == nil { + process = module + } + } + pm = &portMeta{ + keep: true, + port: Port{ + Proto: "tcp", + Port: e.Local.Port(), + Process: process, + Pid: e.Pid, + }, + } + im.known[fp] = pm + } + + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + continue + } + ret = append(ret, m.port) + } + + return sortAndDedup(ret), nil +} diff --git a/posture/serialnumber_macos.go b/posture/serialnumber_macos.go index 48355d31393ee..ce0b996837889 100644 --- a/posture/serialnumber_macos.go +++ b/posture/serialnumber_macos.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && darwin && !ios - -package posture - -// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit -// #include -// #include -// -// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 -// #define kIOMainPortDefault kIOMasterPortDefault -// #endif -// -// const char * -// getSerialNumber() -// { -// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); -// if (!matching) { -// return "err: failed to create dictionary to match IOServices"; -// } -// -// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); -// if (!service) { -// return "err: failed to look up registered IOService objects that match a matching dictionary"; -// } -// -// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( -// service, -// CFSTR("IOPlatformSerialNumber"), -// kCFAllocatorDefault, -// 0 -// ); -// if (!serialNumberRef) { -// return "err: failed to look up serial number in IORegistry"; -// } -// -// CFIndex length = CFStringGetLength(serialNumberRef); -// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; -// char *serialNumberBuf = (char *)malloc(max_size); -// -// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); -// -// CFRelease(serialNumberRef); -// IOObjectRelease(service); -// -// if (!result) { -// free(serialNumberBuf); -// -// return "err: failed to convert serial number reference to string"; -// } -// -// return serialNumberBuf; -// } -import "C" -import ( - "fmt" - "strings" - - "tailscale.com/types/logger" -) - -// GetSerialNumber returns the platform serial sumber as reported by IOKit. -func GetSerialNumbers(_ logger.Logf) ([]string, error) { - csn := C.getSerialNumber() - serialNumber := C.GoString(csn) - - if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { - return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) - } - - return []string{serialNumber}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && darwin && !ios + +package posture + +// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit +// #include +// #include +// +// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 +// #define kIOMainPortDefault kIOMasterPortDefault +// #endif +// +// const char * +// getSerialNumber() +// { +// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); +// if (!matching) { +// return "err: failed to create dictionary to match IOServices"; +// } +// +// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); +// if (!service) { +// return "err: failed to look up registered IOService objects that match a matching dictionary"; +// } +// +// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( +// service, +// CFSTR("IOPlatformSerialNumber"), +// kCFAllocatorDefault, +// 0 +// ); +// if (!serialNumberRef) { +// return "err: failed to look up serial number in IORegistry"; +// } +// +// CFIndex length = CFStringGetLength(serialNumberRef); +// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; +// char *serialNumberBuf = (char *)malloc(max_size); +// +// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); +// +// CFRelease(serialNumberRef); +// IOObjectRelease(service); +// +// if (!result) { +// free(serialNumberBuf); +// +// return "err: failed to convert serial number reference to string"; +// } +// +// return serialNumberBuf; +// } +import "C" +import ( + "fmt" + "strings" + + "tailscale.com/types/logger" +) + +// GetSerialNumber returns the platform serial sumber as reported by IOKit. +func GetSerialNumbers(_ logger.Logf) ([]string, error) { + csn := C.getSerialNumber() + serialNumber := C.GoString(csn) + + if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { + return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) + } + + return []string{serialNumber}, nil +} diff --git a/posture/serialnumber_notmacos_test.go b/posture/serialnumber_notmacos_test.go index f2a15e0373caf..8106c34b36541 100644 --- a/posture/serialnumber_notmacos_test.go +++ b/posture/serialnumber_notmacos_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Build on Windows, Linux and *BSD - -//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd - -package posture - -import ( - "fmt" - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumberNotMac(t *testing.T) { - // This test is intentionally skipped as it will - // require root on Linux to get access to the serials. - // The test case is intended for local testing. - // Comment out skip for local testing. - t.Skip() - - sns, err := GetSerialNumbers(logger.Discard) - if err != nil { - t.Fatalf("failed to get serial number: %s", err) - } - - if len(sns) == 0 { - t.Fatalf("expected at least one serial number, got %v", sns) - } - - if len(sns[0]) <= 0 { - t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) - } - - fmt.Printf("serials: %v\n", sns) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Build on Windows, Linux and *BSD + +//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd + +package posture + +import ( + "fmt" + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumberNotMac(t *testing.T) { + // This test is intentionally skipped as it will + // require root on Linux to get access to the serials. + // The test case is intended for local testing. + // Comment out skip for local testing. + t.Skip() + + sns, err := GetSerialNumbers(logger.Discard) + if err != nil { + t.Fatalf("failed to get serial number: %s", err) + } + + if len(sns) == 0 { + t.Fatalf("expected at least one serial number, got %v", sns) + } + + if len(sns[0]) <= 0 { + t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) + } + + fmt.Printf("serials: %v\n", sns) +} diff --git a/posture/serialnumber_test.go b/posture/serialnumber_test.go index fac4392fab7d3..1ab8193367bc2 100644 --- a/posture/serialnumber_test.go +++ b/posture/serialnumber_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package posture - -import ( - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumber(t *testing.T) { - // ensure GetSerialNumbers is implemented - // or covered by a stub on a given platform. - _, _ = GetSerialNumbers(logger.Discard) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package posture + +import ( + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumber(t *testing.T) { + // ensure GetSerialNumbers is implemented + // or covered by a stub on a given platform. + _, _ = GetSerialNumbers(logger.Discard) +} diff --git a/pull-toolchain.sh b/pull-toolchain.sh index f5a19e7d75de1..87350ff53e39a 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -1,16 +1,16 @@ -#!/bin/sh -# Retrieve the latest Go toolchain. -# -set -eu -cd "$(dirname "$0")" - -read -r go_branch go.toolchain.rev -fi - -if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then - echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 -fi +#!/bin/sh +# Retrieve the latest Go toolchain. +# +set -eu +cd "$(dirname "$0")" + +read -r go_branch go.toolchain.rev +fi + +if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then + echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 +fi diff --git a/release/deb/debian.postrm.sh b/release/deb/debian.postrm.sh index f4dd4ed9cdc15..93d90b0ea2707 100755 --- a/release/deb/debian.postrm.sh +++ b/release/deb/debian.postrm.sh @@ -1,17 +1,17 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscaled.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscaled.service' >/dev/null || true - deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true - rm -rf /var/lib/tailscale - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscaled.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscaled.service' >/dev/null || true + deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true + rm -rf /var/lib/tailscale + fi +fi diff --git a/release/deb/debian.prerm.sh b/release/deb/debian.prerm.sh index 9be58ede4d963..a712a08c8181f 100755 --- a/release/deb/debian.prerm.sh +++ b/release/deb/debian.prerm.sh @@ -1,7 +1,7 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true + fi +fi diff --git a/release/dist/memoize.go b/release/dist/memoize.go index 0927ac0a81540..f148cd2b79c86 100644 --- a/release/dist/memoize.go +++ b/release/dist/memoize.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dist - -import ( - "sync" - - "tailscale.com/util/deephash" -) - -// MemoizedFn is a function that memoize.Do can call. -type MemoizedFn[T any] func() (T, error) - -// Memoize runs MemoizedFns and remembers their results. -type Memoize[O any] struct { - mu sync.Mutex - cond *sync.Cond - outs map[deephash.Sum]O - errs map[deephash.Sum]error - inflight map[deephash.Sum]bool -} - -// Do runs fn and returns its result. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.cond == nil { - m.cond = sync.NewCond(&m.mu) - m.outs = map[deephash.Sum]O{} - m.errs = map[deephash.Sum]error{} - m.inflight = map[deephash.Sum]bool{} - } - - k := deephash.Hash(&key) - - for m.inflight[k] { - m.cond.Wait() - } - if err := m.errs[k]; err != nil { - var ret O - return ret, err - } - if ret, ok := m.outs[k]; ok { - return ret, nil - } - - m.inflight[k] = true - m.mu.Unlock() - defer func() { - m.mu.Lock() - delete(m.inflight, k) - if err != nil { - m.errs[k] = err - } else { - m.outs[k] = ret - } - m.cond.Broadcast() - }() - - ret, err = fn() - if err != nil { - var ret O - return ret, err - } - return ret, nil -} - -// once is like memoize, but for functions that don't return non-error values. -type once struct { - m Memoize[any] -} - -// Do runs fn. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (o *once) Do(key any, fn func() error) error { - _, err := o.m.Do(key, func() (any, error) { - return nil, fn() - }) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dist + +import ( + "sync" + + "tailscale.com/util/deephash" +) + +// MemoizedFn is a function that memoize.Do can call. +type MemoizedFn[T any] func() (T, error) + +// Memoize runs MemoizedFns and remembers their results. +type Memoize[O any] struct { + mu sync.Mutex + cond *sync.Cond + outs map[deephash.Sum]O + errs map[deephash.Sum]error + inflight map[deephash.Sum]bool +} + +// Do runs fn and returns its result. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cond == nil { + m.cond = sync.NewCond(&m.mu) + m.outs = map[deephash.Sum]O{} + m.errs = map[deephash.Sum]error{} + m.inflight = map[deephash.Sum]bool{} + } + + k := deephash.Hash(&key) + + for m.inflight[k] { + m.cond.Wait() + } + if err := m.errs[k]; err != nil { + var ret O + return ret, err + } + if ret, ok := m.outs[k]; ok { + return ret, nil + } + + m.inflight[k] = true + m.mu.Unlock() + defer func() { + m.mu.Lock() + delete(m.inflight, k) + if err != nil { + m.errs[k] = err + } else { + m.outs[k] = ret + } + m.cond.Broadcast() + }() + + ret, err = fn() + if err != nil { + var ret O + return ret, err + } + return ret, nil +} + +// once is like memoize, but for functions that don't return non-error values. +type once struct { + m Memoize[any] +} + +// Do runs fn. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (o *once) Do(key any, fn func() error) error { + _, err := o.m.Do(key, func() (any, error) { + return nil, fn() + }) + return err +} diff --git a/release/dist/synology/files/Tailscale.sc b/release/dist/synology/files/Tailscale.sc index 707ac6bb079b1..f3bb1f0bdbe5d 100644 --- a/release/dist/synology/files/Tailscale.sc +++ b/release/dist/synology/files/Tailscale.sc @@ -1,6 +1,6 @@ -[Tailscale] -title="Tailscale" -desc="Tailscale VPN" -port_forward="no" -src.ports="41641/udp" +[Tailscale] +title="Tailscale" +desc="Tailscale VPN" +port_forward="no" +src.ports="41641/udp" dst.ports="41641/udp" \ No newline at end of file diff --git a/release/dist/synology/files/config b/release/dist/synology/files/config index 4dbc48dfb9434..1cf1a6cfaee47 100644 --- a/release/dist/synology/files/config +++ b/release/dist/synology/files/config @@ -1,11 +1,11 @@ -{ - ".url": { - "SYNO.SDS.Tailscale": { - "type": "url", - "title": "Tailscale", - "icon": "PACKAGE_ICON_256.PNG", - "url": "webman/3rdparty/Tailscale/index.cgi/", - "urlTarget": "_syno_tailscale" - } - } -} +{ + ".url": { + "SYNO.SDS.Tailscale": { + "type": "url", + "title": "Tailscale", + "icon": "PACKAGE_ICON_256.PNG", + "url": "webman/3rdparty/Tailscale/index.cgi/", + "urlTarget": "_syno_tailscale" + } + } +} diff --git a/release/dist/synology/files/index.cgi b/release/dist/synology/files/index.cgi index 2c1990cfd138a..996160d1dca4e 100755 --- a/release/dist/synology/files/index.cgi +++ b/release/dist/synology/files/index.cgi @@ -1,2 +1,2 @@ -#! /bin/sh -exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" +#! /bin/sh +exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" diff --git a/release/dist/synology/files/logrotate-dsm6 b/release/dist/synology/files/logrotate-dsm6 index 2df64283afc30..a52a6ba24c59e 100644 --- a/release/dist/synology/files/logrotate-dsm6 +++ b/release/dist/synology/files/logrotate-dsm6 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/etc/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/etc/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/logrotate-dsm7 b/release/dist/synology/files/logrotate-dsm7 index 7020dc925c2ca..3fe6775102b72 100644 --- a/release/dist/synology/files/logrotate-dsm7 +++ b/release/dist/synology/files/logrotate-dsm7 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/var/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/var/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/privilege-dsm6 b/release/dist/synology/files/privilege-dsm6 index 4b6fe093a1f23..c638528d199bc 100644 --- a/release/dist/synology/files/privilege-dsm6 +++ b/release/dist/synology/files/privilege-dsm6 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "root" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "root" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7 b/release/dist/synology/files/privilege-dsm7 index 93a9c4f7d7bb5..4eca66cff5dd0 100644 --- a/release/dist/synology/files/privilege-dsm7 +++ b/release/dist/synology/files/privilege-dsm7 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7.for-package-center b/release/dist/synology/files/privilege-dsm7.for-package-center index db14683460909..b2f93cee1a3c6 100644 --- a/release/dist/synology/files/privilege-dsm7.for-package-center +++ b/release/dist/synology/files/privilege-dsm7.for-package-center @@ -1,13 +1,13 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale", - "tool": [{ - "relpath": "bin/tailscaled", - "user": "package", - "group": "package", - "capabilities": "cap_net_admin,cap_chown,cap_net_raw" - }] -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale", + "tool": [{ + "relpath": "bin/tailscaled", + "user": "package", + "group": "package", + "capabilities": "cap_net_admin,cap_chown,cap_net_raw" + }] +} diff --git a/release/dist/synology/files/resource b/release/dist/synology/files/resource index 0da0002ef2fb2..706c97671ed47 100644 --- a/release/dist/synology/files/resource +++ b/release/dist/synology/files/resource @@ -1,11 +1,11 @@ -{ - "port-config": { - "protocol-file": "conf/Tailscale.sc" - }, - "usr-local-linker": { - "bin": ["bin/tailscale"] - }, - "syslog-config": { - "logrotate-relpath": "conf/logrotate.conf" - } +{ + "port-config": { + "protocol-file": "conf/Tailscale.sc" + }, + "usr-local-linker": { + "bin": ["bin/tailscale"] + }, + "syslog-config": { + "logrotate-relpath": "conf/logrotate.conf" + } } \ No newline at end of file diff --git a/release/dist/synology/files/scripts/postupgrade b/release/dist/synology/files/scripts/postupgrade index 92b94c40c5f2b..2a7fba5b6f483 100644 --- a/release/dist/synology/files/scripts/postupgrade +++ b/release/dist/synology/files/scripts/postupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/preupgrade b/release/dist/synology/files/scripts/preupgrade index 92b94c40c5f2b..2a7fba5b6f483 100644 --- a/release/dist/synology/files/scripts/preupgrade +++ b/release/dist/synology/files/scripts/preupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/start-stop-status b/release/dist/synology/files/scripts/start-stop-status index e6ece04e3383e..311f9293bd62a 100755 --- a/release/dist/synology/files/scripts/start-stop-status +++ b/release/dist/synology/files/scripts/start-stop-status @@ -1,129 +1,129 @@ -#!/bin/bash - -SERVICE_NAME="tailscale" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - PKGVAR="/var/packages/Tailscale/etc" -else - PKGVAR="${SYNOPKG_PKGVAR}" -fi - -PID_FILE="${PKGVAR}/tailscaled.pid" -LOG_FILE="${PKGVAR}/tailscaled.stdout.log" -STATE_FILE="${PKGVAR}/tailscaled.state" -SOCKET_FILE="${PKGVAR}/tailscaled.sock" -PORT="41641" - -SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ ---state=${STATE_FILE} \ ---socket=${SOCKET_FILE} \ ---port=$PORT" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" -fi - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - chown -R tailscale:tailscale "${PKGVAR}/" -fi - -start_daemon() { - local ts=$(date --iso-8601=second) - echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} - STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & - # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. - # Use jobs -p to retrieve the PID of the most recent process group leader. - jobs -p >"${PID_FILE}" -} - -stop_daemon() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - local ts=$(date --iso-8601=second) - echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} - kill -TERM $PID >>${LOG_FILE} 2>&1 - wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 - rm -f "${PID_FILE}" >/dev/null - fi -} - -daemon_status() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - if ps -o pid -p ${PID} > /dev/null; then - return - fi - rm -f "${PID_FILE}" >/dev/null - fi - return 1 -} - -wait_for_status() { - # 20 tries - # sleeps for 1 second after each try - local counter=20 - while [ ${counter} -gt 0 ]; do - daemon_status - [ $? -eq $1 ] && return - counter=$((counter - 1)) - sleep 1 - done - return 1 -} - -ensure_tun_created() { - if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - return - fi - # Create the necessary file structure for /dev/net/tun - if ([ ! -c /dev/net/tun ]); then - if ([ ! -d /dev/net ]); then - mkdir -m 755 /dev/net - fi - mknod /dev/net/tun c 10 200 - chmod 0755 /dev/net/tun - fi - - # Load the tun module if not already loaded - if (!(lsmod | grep -q "^tun\s")); then - insmod /lib/modules/tun.ko - fi -} - -case $1 in -start) - if daemon_status; then - exit 0 - else - ensure_tun_created - start_daemon - exit $? - fi - ;; -stop) - if daemon_status; then - stop_daemon - exit $? - else - exit 0 - fi - ;; -status) - if daemon_status; then - echo "${SERVICE_NAME} is running" - exit 0 - else - echo "${SERVICE_NAME} is not running" - exit 3 - fi - ;; -log) - exit 0 - ;; -*) - echo "command $1 is not implemented" - exit 0 - ;; -esac +#!/bin/bash + +SERVICE_NAME="tailscale" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + PKGVAR="/var/packages/Tailscale/etc" +else + PKGVAR="${SYNOPKG_PKGVAR}" +fi + +PID_FILE="${PKGVAR}/tailscaled.pid" +LOG_FILE="${PKGVAR}/tailscaled.stdout.log" +STATE_FILE="${PKGVAR}/tailscaled.state" +SOCKET_FILE="${PKGVAR}/tailscaled.sock" +PORT="41641" + +SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ +--state=${STATE_FILE} \ +--socket=${SOCKET_FILE} \ +--port=$PORT" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" +fi + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + chown -R tailscale:tailscale "${PKGVAR}/" +fi + +start_daemon() { + local ts=$(date --iso-8601=second) + echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} + STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & + # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. + # Use jobs -p to retrieve the PID of the most recent process group leader. + jobs -p >"${PID_FILE}" +} + +stop_daemon() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + local ts=$(date --iso-8601=second) + echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} + kill -TERM $PID >>${LOG_FILE} 2>&1 + wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 + rm -f "${PID_FILE}" >/dev/null + fi +} + +daemon_status() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + if ps -o pid -p ${PID} > /dev/null; then + return + fi + rm -f "${PID_FILE}" >/dev/null + fi + return 1 +} + +wait_for_status() { + # 20 tries + # sleeps for 1 second after each try + local counter=20 + while [ ${counter} -gt 0 ]; do + daemon_status + [ $? -eq $1 ] && return + counter=$((counter - 1)) + sleep 1 + done + return 1 +} + +ensure_tun_created() { + if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + return + fi + # Create the necessary file structure for /dev/net/tun + if ([ ! -c /dev/net/tun ]); then + if ([ ! -d /dev/net ]); then + mkdir -m 755 /dev/net + fi + mknod /dev/net/tun c 10 200 + chmod 0755 /dev/net/tun + fi + + # Load the tun module if not already loaded + if (!(lsmod | grep -q "^tun\s")); then + insmod /lib/modules/tun.ko + fi +} + +case $1 in +start) + if daemon_status; then + exit 0 + else + ensure_tun_created + start_daemon + exit $? + fi + ;; +stop) + if daemon_status; then + stop_daemon + exit $? + else + exit 0 + fi + ;; +status) + if daemon_status; then + echo "${SERVICE_NAME} is running" + exit 0 + else + echo "${SERVICE_NAME} is not running" + exit 3 + fi + ;; +log) + exit 0 + ;; +*) + echo "command $1 is not implemented" + exit 0 + ;; +esac diff --git a/release/dist/unixpkgs/pkgs.go b/release/dist/unixpkgs/pkgs.go index bad6ce572e675..60a038eb49d21 100644 --- a/release/dist/unixpkgs/pkgs.go +++ b/release/dist/unixpkgs/pkgs.go @@ -1,472 +1,472 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package unixpkgs contains dist Targets for building unix Tailscale packages. -package unixpkgs - -import ( - "archive/tar" - "compress/gzip" - "errors" - "fmt" - "io" - "log" - "os" - "path/filepath" - "strings" - - "github.com/goreleaser/nfpm/v2" - "github.com/goreleaser/nfpm/v2/files" - "tailscale.com/release/dist" -) - -type tgzTarget struct { - filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] - goEnv map[string]string - signer dist.Signer -} - -func (t *tgzTarget) arch() string { - if t.filenameArch != "" { - return t.filenameArch - } - return t.goEnv["GOARCH"] -} - -func (t *tgzTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *tgzTarget) String() string { - return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) -} - -func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { - var filename string - if t.goEnv["GOOS"] == "linux" { - // Linux used to be the only tgz architecture, so we didn't put the OS - // name in the filename. - filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) - } else { - filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) - } - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - log.Printf("Building %s", filename) - - out := filepath.Join(b.Out, filename) - f, err := os.Create(out) - if err != nil { - return nil, err - } - defer f.Close() - gw := gzip.NewWriter(f) - defer gw.Close() - tw := tar.NewWriter(gw) - defer tw.Close() - - addFile := func(src, dst string, mode int64) error { - f, err := os.Open(src) - if err != nil { - return err - } - defer f.Close() - fi, err := f.Stat() - if err != nil { - return err - } - hdr := &tar.Header{ - Name: dst, - Size: fi.Size(), - Mode: mode, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - if err := tw.WriteHeader(hdr); err != nil { - return err - } - if _, err = io.Copy(tw, f); err != nil { - return err - } - return nil - } - addDir := func(name string) error { - hdr := &tar.Header{ - Name: name + "/", - Mode: 0755, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - return tw.WriteHeader(hdr) - } - dir := strings.TrimSuffix(filename, ".tgz") - if err := addDir(dir); err != nil { - return nil, err - } - if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { - return nil, err - } - if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { - return nil, err - } - if t.os() == "linux" { - dir = filepath.Join(dir, "systemd") - if err := addDir(dir); err != nil { - return nil, err - } - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { - return nil, err - } - } - if err := tw.Close(); err != nil { - return nil, err - } - if err := gw.Close(); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - files := []string{filename} - - if t.signer != nil { - outSig := out + ".sig" - if err := t.signer.SignFile(out, outSig); err != nil { - return nil, err - } - files = append(files, filepath.Base(outSig)) - } - - return files, nil -} - -type debTarget struct { - goEnv map[string]string -} - -func (t *debTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *debTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *debTarget) String() string { - return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) -} - -func (t *debTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("deb only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := debArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - }, 0, "deb", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Section: "net", - Priority: "extra", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), - }, - Depends: []string{ - // iptables is almost always required but not strictly needed. - // Even if you can technically run Tailscale without it (by - // manually configuring nftables or userspace mode), we still - // mark this as "Depends" because our previous experiment in - // https://github.com/tailscale/tailscale/issues/9236 of making - // it only Recommends caused too many problems. Until our - // nftables table is more mature, we'd rather err on the side of - // wasting a little disk by including iptables for people who - // might not need it rather than handle reports of it being - // missing. - "iptables", - }, - Recommends: []string{ - "tailscale-archive-keyring (>= 1.35.181)", - // The "ip" command isn't needed since 2021-11-01 in - // 408b0923a61972ed but kept as an option as of - // 2021-11-18 in d24ed3f68e35e802d531371. See - // https://github.com/tailscale/tailscale/issues/391. - // We keep it recommended because it's usually - // installed anyway and it's useful for debugging. But - // we can live without it, so it's not Depends. - "iproute2", - }, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - }, - }) - pkg, err := nfpm.Get("deb") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) - log.Printf("Building %s", filename) - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -type rpmTarget struct { - goEnv map[string]string - signer dist.Signer -} - -func (t *rpmTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *rpmTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *rpmTarget) String() string { - return fmt.Sprintf("linux/%s/rpm", t.arch()) -} - -func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("rpm only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := rpmArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. - // Creating an empty directory at install time resolves this issue. - &files.Content{ - Type: files.TypeDir, - Destination: "/var/cache/tailscale", - }, - }, 0, "rpm", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), - }, - Depends: []string{"iptables", "iproute"}, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - RPM: nfpm.RPM{ - Group: "Network", - Signature: nfpm.RPMSignature{ - PackageSignature: nfpm.PackageSignature{ - SignFn: t.signer, - }, - }, - }, - }, - }) - pkg, err := nfpm.Get("rpm") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) - log.Printf("Building %s", filename) - - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -// debArch returns the debian arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func debArch(arch string) string { - switch arch { - case "386": - return "i386" - case "arm": - // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for - // GOARM=6 and 7. But we have some tech debt to pay off here before we - // can ship more than 1 ARM deb, so for now match redo's behavior of - // shipping armv5 binaries in an armv7 trenchcoat. - return "armhf" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} - -// rpmArch returns the RPM arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func rpmArch(arch string) string { - switch arch { - case "amd64": - return "x86_64" - case "386": - return "i386" - case "arm": - return "armv7hl" - case "arm64": - return "aarch64" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package unixpkgs contains dist Targets for building unix Tailscale packages. +package unixpkgs + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/goreleaser/nfpm/v2" + "github.com/goreleaser/nfpm/v2/files" + "tailscale.com/release/dist" +) + +type tgzTarget struct { + filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] + goEnv map[string]string + signer dist.Signer +} + +func (t *tgzTarget) arch() string { + if t.filenameArch != "" { + return t.filenameArch + } + return t.goEnv["GOARCH"] +} + +func (t *tgzTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *tgzTarget) String() string { + return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) +} + +func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { + var filename string + if t.goEnv["GOOS"] == "linux" { + // Linux used to be the only tgz architecture, so we didn't put the OS + // name in the filename. + filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) + } else { + filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) + } + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + log.Printf("Building %s", filename) + + out := filepath.Join(b.Out, filename) + f, err := os.Create(out) + if err != nil { + return nil, err + } + defer f.Close() + gw := gzip.NewWriter(f) + defer gw.Close() + tw := tar.NewWriter(gw) + defer tw.Close() + + addFile := func(src, dst string, mode int64) error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + return err + } + hdr := &tar.Header{ + Name: dst, + Size: fi.Size(), + Mode: mode, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + if err := tw.WriteHeader(hdr); err != nil { + return err + } + if _, err = io.Copy(tw, f); err != nil { + return err + } + return nil + } + addDir := func(name string) error { + hdr := &tar.Header{ + Name: name + "/", + Mode: 0755, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + return tw.WriteHeader(hdr) + } + dir := strings.TrimSuffix(filename, ".tgz") + if err := addDir(dir); err != nil { + return nil, err + } + if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { + return nil, err + } + if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { + return nil, err + } + if t.os() == "linux" { + dir = filepath.Join(dir, "systemd") + if err := addDir(dir); err != nil { + return nil, err + } + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { + return nil, err + } + } + if err := tw.Close(); err != nil { + return nil, err + } + if err := gw.Close(); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + files := []string{filename} + + if t.signer != nil { + outSig := out + ".sig" + if err := t.signer.SignFile(out, outSig); err != nil { + return nil, err + } + files = append(files, filepath.Base(outSig)) + } + + return files, nil +} + +type debTarget struct { + goEnv map[string]string +} + +func (t *debTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *debTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *debTarget) String() string { + return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) +} + +func (t *debTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("deb only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := debArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + }, 0, "deb", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Section: "net", + Priority: "extra", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), + }, + Depends: []string{ + // iptables is almost always required but not strictly needed. + // Even if you can technically run Tailscale without it (by + // manually configuring nftables or userspace mode), we still + // mark this as "Depends" because our previous experiment in + // https://github.com/tailscale/tailscale/issues/9236 of making + // it only Recommends caused too many problems. Until our + // nftables table is more mature, we'd rather err on the side of + // wasting a little disk by including iptables for people who + // might not need it rather than handle reports of it being + // missing. + "iptables", + }, + Recommends: []string{ + "tailscale-archive-keyring (>= 1.35.181)", + // The "ip" command isn't needed since 2021-11-01 in + // 408b0923a61972ed but kept as an option as of + // 2021-11-18 in d24ed3f68e35e802d531371. See + // https://github.com/tailscale/tailscale/issues/391. + // We keep it recommended because it's usually + // installed anyway and it's useful for debugging. But + // we can live without it, so it's not Depends. + "iproute2", + }, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + }, + }) + pkg, err := nfpm.Get("deb") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) + log.Printf("Building %s", filename) + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +type rpmTarget struct { + goEnv map[string]string + signer dist.Signer +} + +func (t *rpmTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *rpmTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *rpmTarget) String() string { + return fmt.Sprintf("linux/%s/rpm", t.arch()) +} + +func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("rpm only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := rpmArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. + // Creating an empty directory at install time resolves this issue. + &files.Content{ + Type: files.TypeDir, + Destination: "/var/cache/tailscale", + }, + }, 0, "rpm", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), + }, + Depends: []string{"iptables", "iproute"}, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + RPM: nfpm.RPM{ + Group: "Network", + Signature: nfpm.RPMSignature{ + PackageSignature: nfpm.PackageSignature{ + SignFn: t.signer, + }, + }, + }, + }, + }) + pkg, err := nfpm.Get("rpm") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) + log.Printf("Building %s", filename) + + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +// debArch returns the debian arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func debArch(arch string) string { + switch arch { + case "386": + return "i386" + case "arm": + // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for + // GOARM=6 and 7. But we have some tech debt to pay off here before we + // can ship more than 1 ARM deb, so for now match redo's behavior of + // shipping armv5 binaries in an armv7 trenchcoat. + return "armhf" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} + +// rpmArch returns the RPM arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func rpmArch(arch string) string { + switch arch { + case "amd64": + return "x86_64" + case "386": + return "i386" + case "arm": + return "armv7hl" + case "arm64": + return "aarch64" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} diff --git a/release/dist/unixpkgs/targets.go b/release/dist/unixpkgs/targets.go index 42bab6d3b2685..f87c56d317d9f 100644 --- a/release/dist/unixpkgs/targets.go +++ b/release/dist/unixpkgs/targets.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package unixpkgs - -import ( - "fmt" - "sort" - "strings" - - "tailscale.com/release/dist" - - _ "github.com/goreleaser/nfpm/v2/deb" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -type Signers struct { - Tarball dist.Signer - RPM dist.Signer -} - -func Targets(signers Signers) []dist.Target { - var ret []dist.Target - for goosgoarch := range tarballs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &tgzTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.Tarball, - }) - } - for goosgoarch := range debs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &debTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - }) - } - for goosgoarch := range rpms { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &rpmTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.RPM, - }) - } - - // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's - // an ancient architecture. - ret = append(ret, &tgzTarget{ - filenameArch: "geode", - goEnv: map[string]string{ - "GOOS": "linux", - "GOARCH": "386", - "GO386": "softfloat", - }, - signer: signers.Tarball, - }) - - sort.Slice(ret, func(i, j int) bool { - return ret[i].String() < ret[j].String() - }) - - return ret -} - -var ( - tarballs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/mips64": true, - "linux/mips64le": true, - "linux/mips": true, - "linux/mipsle": true, - "linux/riscv64": true, - // TODO: more tarballs we could distribute, but don't currently. Leaving - // out for initial parity with redo. - // "darwin/amd64": true, - // "darwin/arm64": true, - // "freebsd/amd64": true, - // "openbsd/amd64": true, - } - - debs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - "linux/mips": true, - // Debian does not support big endian mips64. Leave that out until we know - // we need it. - // "linux/mips64": true, - } - - rpms = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - // Fedora only supports little endian mipses. Maybe some other distribution - // supports big-endian? Leave them out for now. - // "linux/mips": true, - // "linux/mips64": true, - } -) - -func splitGoosGoarch(s string) (string, string) { - goos, goarch, ok := strings.Cut(s, "/") - if !ok { - panic(fmt.Sprintf("invalid target %q", s)) - } - return goos, goarch -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package unixpkgs + +import ( + "fmt" + "sort" + "strings" + + "tailscale.com/release/dist" + + _ "github.com/goreleaser/nfpm/v2/deb" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +type Signers struct { + Tarball dist.Signer + RPM dist.Signer +} + +func Targets(signers Signers) []dist.Target { + var ret []dist.Target + for goosgoarch := range tarballs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &tgzTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.Tarball, + }) + } + for goosgoarch := range debs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &debTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + }) + } + for goosgoarch := range rpms { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &rpmTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.RPM, + }) + } + + // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's + // an ancient architecture. + ret = append(ret, &tgzTarget{ + filenameArch: "geode", + goEnv: map[string]string{ + "GOOS": "linux", + "GOARCH": "386", + "GO386": "softfloat", + }, + signer: signers.Tarball, + }) + + sort.Slice(ret, func(i, j int) bool { + return ret[i].String() < ret[j].String() + }) + + return ret +} + +var ( + tarballs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/mips64": true, + "linux/mips64le": true, + "linux/mips": true, + "linux/mipsle": true, + "linux/riscv64": true, + // TODO: more tarballs we could distribute, but don't currently. Leaving + // out for initial parity with redo. + // "darwin/amd64": true, + // "darwin/arm64": true, + // "freebsd/amd64": true, + // "openbsd/amd64": true, + } + + debs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + "linux/mips": true, + // Debian does not support big endian mips64. Leave that out until we know + // we need it. + // "linux/mips64": true, + } + + rpms = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + // Fedora only supports little endian mipses. Maybe some other distribution + // supports big-endian? Leave them out for now. + // "linux/mips": true, + // "linux/mips64": true, + } +) + +func splitGoosGoarch(s string) (string, string) { + goos, goarch, ok := strings.Cut(s, "/") + if !ok { + panic(fmt.Sprintf("invalid target %q", s)) + } + return goos, goarch +} diff --git a/release/release.go b/release/release.go index a8d0e6b62e8d7..638635b6d23e9 100644 --- a/release/release.go +++ b/release/release.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package release provides functionality for building client releases. -package release - -import "embed" - -// This contains all files in the release directory, -// notably the files needed for deb, rpm, and similar packages. -// Because we assign this to the blank identifier, it does not actually embed the files. -// However, this does cause `go mod vendor` to include the files when vendoring the package. -// -//go:embed * -var _ embed.FS +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package release provides functionality for building client releases. +package release + +import "embed" + +// This contains all files in the release directory, +// notably the files needed for deb, rpm, and similar packages. +// Because we assign this to the blank identifier, it does not actually embed the files. +// However, this does cause `go mod vendor` to include the files when vendoring the package. +// +//go:embed * +var _ embed.FS diff --git a/release/rpm/rpm.postinst.sh b/release/rpm/rpm.postinst.sh index 3d264c5f60b18..f9c1fddfdfc73 100755 --- a/release/rpm/rpm.postinst.sh +++ b/release/rpm/rpm.postinst.sh @@ -1,41 +1,41 @@ -# $1 == 1 for initial installation. -# $1 == 2 for upgrades. - -if [ $1 -eq 1 ] ; then - # Normally, the tailscale-relay package would request shutdown of - # its service before uninstallation. Unfortunately, the - # tailscale-relay package we distributed doesn't have those - # scriptlets. We definitely want relaynode to be stopped when - # installing tailscaled though, so we blindly try to turn off - # relaynode here. - # - # However, we also want this package installation to look like an - # upgrade from relaynode! Therefore, if relaynode is currently - # enabled, we want to also enable tailscaled. If relaynode is - # currently running, we also want to start tailscaled. - # - # If there doesn't seem to be an active or enabled relaynode on - # the system, we follow the RPM convention for package installs, - # which is to not enable or start the service. - relaynode_enabled=0 - relaynode_running=0 - if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then - relaynode_enabled=1 - fi - if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then - relaynode_running=1 - fi - - systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : - systemctl stop tailscale-relay.service >/dev/null 2>&1 || : - - if [ $relaynode_enabled -eq 1 ]; then - systemctl enable tailscaled.service >/dev/null 2>&1 || : - else - systemctl preset tailscaled.service >/dev/null 2>&1 || : - fi - - if [ $relaynode_running -eq 1 ]; then - systemctl start tailscaled.service >/dev/null 2>&1 || : - fi -fi +# $1 == 1 for initial installation. +# $1 == 2 for upgrades. + +if [ $1 -eq 1 ] ; then + # Normally, the tailscale-relay package would request shutdown of + # its service before uninstallation. Unfortunately, the + # tailscale-relay package we distributed doesn't have those + # scriptlets. We definitely want relaynode to be stopped when + # installing tailscaled though, so we blindly try to turn off + # relaynode here. + # + # However, we also want this package installation to look like an + # upgrade from relaynode! Therefore, if relaynode is currently + # enabled, we want to also enable tailscaled. If relaynode is + # currently running, we also want to start tailscaled. + # + # If there doesn't seem to be an active or enabled relaynode on + # the system, we follow the RPM convention for package installs, + # which is to not enable or start the service. + relaynode_enabled=0 + relaynode_running=0 + if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then + relaynode_enabled=1 + fi + if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then + relaynode_running=1 + fi + + systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : + systemctl stop tailscale-relay.service >/dev/null 2>&1 || : + + if [ $relaynode_enabled -eq 1 ]; then + systemctl enable tailscaled.service >/dev/null 2>&1 || : + else + systemctl preset tailscaled.service >/dev/null 2>&1 || : + fi + + if [ $relaynode_running -eq 1 ]; then + systemctl start tailscaled.service >/dev/null 2>&1 || : + fi +fi diff --git a/release/rpm/rpm.postrm.sh b/release/rpm/rpm.postrm.sh index d74f7e9deac77..e19a7305cac23 100755 --- a/release/rpm/rpm.postrm.sh +++ b/release/rpm/rpm.postrm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl try-restart tailscaled.service >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl try-restart tailscaled.service >/dev/null 2>&1 || : +fi diff --git a/release/rpm/rpm.prerm.sh b/release/rpm/rpm.prerm.sh index 682c01bd574d8..eeabf3b584721 100755 --- a/release/rpm/rpm.prerm.sh +++ b/release/rpm/rpm.prerm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : - systemctl stop tailscaled.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : + systemctl stop tailscaled.service > /dev/null 2>&1 || : +fi diff --git a/safesocket/safesocket_test.go b/safesocket/safesocket_test.go index 3f36a1cf6ca1f..85b317bd6e70f 100644 --- a/safesocket/safesocket_test.go +++ b/safesocket/safesocket_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package safesocket - -import "testing" - -func TestLocalTCPPortAndToken(t *testing.T) { - // Just test that it compiles for now (is available on all platforms). - port, token, err := LocalTCPPortAndToken() - t.Logf("got %v, %s, %v", port, token, err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safesocket + +import "testing" + +func TestLocalTCPPortAndToken(t *testing.T) { + // Just test that it compiles for now (is available on all platforms). + port, token, err := LocalTCPPortAndToken() + t.Logf("got %v, %s, %v", port, token, err) +} diff --git a/smallzstd/testdata b/smallzstd/testdata index 76640fdc57df0..498b014fd8d36 100644 --- a/smallzstd/testdata +++ b/smallzstd/testdata @@ -1,14 +1,14 @@ -{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} diff --git a/smallzstd/zstd.go b/smallzstd/zstd.go index 1d80854224359..d91afeb67e254 100644 --- a/smallzstd/zstd.go +++ b/smallzstd/zstd.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package smallzstd produces zstd encoders and decoders optimized for -// low memory usage, at the expense of compression efficiency. -// -// This package is optimized primarily for the memory cost of -// compressing and decompressing data. We reduce this cost in two -// major ways: disable parallelism within the library (i.e. don't use -// multiple CPU cores to decompress), and drop the compression window -// down from the defaults of 4-16MiB, to 8kiB. -// -// Decompressors cost 2x the window size in RAM to run, so by using an -// 8kiB window, we can run ~1000 more decompressors per unit of memory -// than with the defaults. -// -// Depending on context, the benefit is either being able to run more -// decoders (e.g. in our logs processing system), or having a lower -// memory footprint when using compression in network protocols -// (e.g. in tailscaled, which should have a minimal RAM cost). -package smallzstd - -import ( - "io" - - "github.com/klauspost/compress/zstd" -) - -// WindowSize is the window size used for zstd compression. Decoder -// memory usage scales linearly with WindowSize. -const WindowSize = 8 << 10 // 8kiB - -// NewDecoder returns a zstd.Decoder configured for low memory usage, -// at the expense of decompression performance. -func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { - defaults := []zstd.DOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithDecoderConcurrency(1), - // Default is to allocate more upfront for performance. We - // prefer lower memory use and a bit of GC load. - zstd.WithDecoderLowmem(true), - // You might expect to see zstd.WithDecoderMaxMemory - // here. However, it's not terribly safe to use if you're - // doing stateless decoding, because it sets the maximum - // amount of memory the decompressed data can occupy, rather - // than the window size of the zstd stream. This means a very - // compressible piece of data might violate the max memory - // limit here, even if the window size (and thus total memory - // required to decompress the data) is small. - // - // As a result, we don't set a decoder limit here, and rely on - // the encoder below producing "cheap" streams. Callers are - // welcome to set their own max memory setting, if - // contextually there is a clearly correct value (e.g. it's - // known from the upper layer protocol that the decoded data - // can never be more than 1MiB). - } - - return zstd.NewReader(r, append(defaults, options...)...) -} - -// NewEncoder returns a zstd.Encoder configured for low memory usage, -// both during compression and at decompression time, at the expense -// of performance and compression efficiency. -func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { - defaults := []zstd.EOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithEncoderConcurrency(1), - // Default is several MiB, which bloats both encoders and - // their corresponding decoders. - zstd.WithWindowSize(WindowSize), - // Encode zero-length inputs in a way that the `zstd` utility - // can read, because interoperability is handy. - zstd.WithZeroFrames(true), - } - - return zstd.NewWriter(w, append(defaults, options...)...) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package smallzstd produces zstd encoders and decoders optimized for +// low memory usage, at the expense of compression efficiency. +// +// This package is optimized primarily for the memory cost of +// compressing and decompressing data. We reduce this cost in two +// major ways: disable parallelism within the library (i.e. don't use +// multiple CPU cores to decompress), and drop the compression window +// down from the defaults of 4-16MiB, to 8kiB. +// +// Decompressors cost 2x the window size in RAM to run, so by using an +// 8kiB window, we can run ~1000 more decompressors per unit of memory +// than with the defaults. +// +// Depending on context, the benefit is either being able to run more +// decoders (e.g. in our logs processing system), or having a lower +// memory footprint when using compression in network protocols +// (e.g. in tailscaled, which should have a minimal RAM cost). +package smallzstd + +import ( + "io" + + "github.com/klauspost/compress/zstd" +) + +// WindowSize is the window size used for zstd compression. Decoder +// memory usage scales linearly with WindowSize. +const WindowSize = 8 << 10 // 8kiB + +// NewDecoder returns a zstd.Decoder configured for low memory usage, +// at the expense of decompression performance. +func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { + defaults := []zstd.DOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithDecoderConcurrency(1), + // Default is to allocate more upfront for performance. We + // prefer lower memory use and a bit of GC load. + zstd.WithDecoderLowmem(true), + // You might expect to see zstd.WithDecoderMaxMemory + // here. However, it's not terribly safe to use if you're + // doing stateless decoding, because it sets the maximum + // amount of memory the decompressed data can occupy, rather + // than the window size of the zstd stream. This means a very + // compressible piece of data might violate the max memory + // limit here, even if the window size (and thus total memory + // required to decompress the data) is small. + // + // As a result, we don't set a decoder limit here, and rely on + // the encoder below producing "cheap" streams. Callers are + // welcome to set their own max memory setting, if + // contextually there is a clearly correct value (e.g. it's + // known from the upper layer protocol that the decoded data + // can never be more than 1MiB). + } + + return zstd.NewReader(r, append(defaults, options...)...) +} + +// NewEncoder returns a zstd.Encoder configured for low memory usage, +// both during compression and at decompression time, at the expense +// of performance and compression efficiency. +func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { + defaults := []zstd.EOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithEncoderConcurrency(1), + // Default is several MiB, which bloats both encoders and + // their corresponding decoders. + zstd.WithWindowSize(WindowSize), + // Encode zero-length inputs in a way that the `zstd` utility + // can read, because interoperability is handy. + zstd.WithZeroFrames(true), + } + + return zstd.NewWriter(w, append(defaults, options...)...) +} diff --git a/syncs/locked.go b/syncs/locked.go index d2048665dee3d..abde5bca62415 100644 --- a/syncs/locked.go +++ b/syncs/locked.go @@ -1,32 +1,32 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" -) - -// AssertLocked panics if m is not locked. -func AssertLocked(m *sync.Mutex) { - if m.TryLock() { - m.Unlock() - panic("mutex is not locked") - } -} - -// AssertRLocked panics if rw is not locked for reading or writing. -func AssertRLocked(rw *sync.RWMutex) { - if rw.TryLock() { - rw.Unlock() - panic("mutex is not locked") - } -} - -// AssertWLocked panics if rw is not locked for writing. -func AssertWLocked(rw *sync.RWMutex) { - if rw.TryRLock() { - rw.RUnlock() - panic("mutex is not rlocked") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" +) + +// AssertLocked panics if m is not locked. +func AssertLocked(m *sync.Mutex) { + if m.TryLock() { + m.Unlock() + panic("mutex is not locked") + } +} + +// AssertRLocked panics if rw is not locked for reading or writing. +func AssertRLocked(rw *sync.RWMutex) { + if rw.TryLock() { + rw.Unlock() + panic("mutex is not locked") + } +} + +// AssertWLocked panics if rw is not locked for writing. +func AssertWLocked(rw *sync.RWMutex) { + if rw.TryRLock() { + rw.RUnlock() + panic("mutex is not rlocked") + } +} diff --git a/syncs/locked_test.go b/syncs/locked_test.go index 90b36e8321d82..44877be50be1a 100644 --- a/syncs/locked_test.go +++ b/syncs/locked_test.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.13 && !go1.19 - -package syncs - -import ( - "sync" - "testing" - "time" -) - -func wantPanic(t *testing.T, fn func()) { - t.Helper() - defer func() { - recover() - }() - fn() - t.Fatal("failed to panic") -} - -func TestAssertLocked(t *testing.T) { - m := new(sync.Mutex) - wantPanic(t, func() { AssertLocked(m) }) - m.Lock() - AssertLocked(m) - m.Unlock() - wantPanic(t, func() { AssertLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertLocked(m) -} - -func TestAssertWLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertWLocked(m) }) - m.Lock() - AssertWLocked(m) - m.Unlock() - wantPanic(t, func() { AssertWLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertWLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertWLocked(m) -} - -func TestAssertRLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertRLocked(m) }) - - m.Lock() - AssertRLocked(m) - m.Unlock() - - m.RLock() - AssertRLocked(m) - m.RUnlock() - - wantPanic(t, func() { AssertRLocked(m) }) - - // Test correct handling of mutex with waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - m.RUnlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with write waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with other rlocks. - // This is a bit racy, but losing the race hurts nothing, - // and winning the race means correct test coverage. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - time.Sleep(10 * time.Millisecond) - m.RUnlock() - }() - time.Sleep(5 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.13 && !go1.19 + +package syncs + +import ( + "sync" + "testing" + "time" +) + +func wantPanic(t *testing.T, fn func()) { + t.Helper() + defer func() { + recover() + }() + fn() + t.Fatal("failed to panic") +} + +func TestAssertLocked(t *testing.T) { + m := new(sync.Mutex) + wantPanic(t, func() { AssertLocked(m) }) + m.Lock() + AssertLocked(m) + m.Unlock() + wantPanic(t, func() { AssertLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertLocked(m) +} + +func TestAssertWLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertWLocked(m) }) + m.Lock() + AssertWLocked(m) + m.Unlock() + wantPanic(t, func() { AssertWLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertWLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertWLocked(m) +} + +func TestAssertRLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertRLocked(m) }) + + m.Lock() + AssertRLocked(m) + m.Unlock() + + m.RLock() + AssertRLocked(m) + m.RUnlock() + + wantPanic(t, func() { AssertRLocked(m) }) + + // Test correct handling of mutex with waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + m.RUnlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with write waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with other rlocks. + // This is a bit racy, but losing the race hurts nothing, + // and winning the race means correct test coverage. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + time.Sleep(10 * time.Millisecond) + m.RUnlock() + }() + time.Sleep(5 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() +} diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go index 12edf5bfce475..906de3ade2d5c 100644 --- a/syncs/shardedmap.go +++ b/syncs/shardedmap.go @@ -1,138 +1,138 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" - - "golang.org/x/sys/cpu" -) - -// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined -// K-sharding function. -// -// The zero value is not safe for use; use NewShardedMap. -type ShardedMap[K comparable, V any] struct { - shardFunc func(K) int - shards []mapShard[K, V] -} - -type mapShard[K comparable, V any] struct { - mu sync.Mutex - m map[K]V - _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes -} - -// NewShardedMap returns a new ShardedMap with the given number of shards and -// sharding function. -// -// The shard func must return a integer in the range [0, shards) purely -// deterministically based on the provided K. -func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { - m := &ShardedMap[K, V]{ - shardFunc: shard, - shards: make([]mapShard[K, V], shards), - } - for i := range m.shards { - m.shards[i].m = make(map[K]V) - } - return m -} - -func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { - return &m.shards[m.shardFunc(key)] -} - -// GetOk returns m[key] and whether it was present. -func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - value, ok = shard.m[key] - return -} - -// Get returns m[key] or the zero value of V if key is not present. -func (m *ShardedMap[K, V]) Get(key K) (value V) { - value, _ = m.GetOk(key) - return -} - -// Mutate atomically mutates m[k] by calling mutator. -// -// The mutator function is called with the old value (or its zero value) and -// whether it existed in the map and it returns the new value and whether it -// should be set in the map (true) or deleted from the map (false). -// -// It returns the change in size of the map as a result of the mutation, one of -// -1 (delete), 0 (change), or 1 (addition). -func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - oldV, oldOK := shard.m[key] - newV, newOK := mutator(oldV, oldOK) - if newOK { - shard.m[key] = newV - if oldOK { - return 0 - } - return 1 - } - delete(shard.m, key) - if oldOK { - return -1 - } - return 0 -} - -// Set sets m[key] = value. -// -// present in m). -func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - shard.m[key] = value - return len(shard.m) > s0 -} - -// Delete removes key from m. -// -// It reports whether the map size shrunk (that is, whether key was present in -// the map). -func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - delete(shard.m, key) - return len(shard.m) < s0 -} - -// Contains reports whether m contains key. -func (m *ShardedMap[K, V]) Contains(key K) bool { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - _, ok := shard.m[key] - return ok -} - -// Len returns the number of elements in m. -// -// It does so by locking shards one at a time, so it's not particularly cheap, -// nor does it give a consistent snapshot of the map. It's mostly intended for -// metrics or testing. -func (m *ShardedMap[K, V]) Len() int { - n := 0 - for i := range m.shards { - shard := &m.shards[i] - shard.mu.Lock() - n += len(shard.m) - shard.mu.Unlock() - } - return n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" + + "golang.org/x/sys/cpu" +) + +// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined +// K-sharding function. +// +// The zero value is not safe for use; use NewShardedMap. +type ShardedMap[K comparable, V any] struct { + shardFunc func(K) int + shards []mapShard[K, V] +} + +type mapShard[K comparable, V any] struct { + mu sync.Mutex + m map[K]V + _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes +} + +// NewShardedMap returns a new ShardedMap with the given number of shards and +// sharding function. +// +// The shard func must return a integer in the range [0, shards) purely +// deterministically based on the provided K. +func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { + m := &ShardedMap[K, V]{ + shardFunc: shard, + shards: make([]mapShard[K, V], shards), + } + for i := range m.shards { + m.shards[i].m = make(map[K]V) + } + return m +} + +func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { + return &m.shards[m.shardFunc(key)] +} + +// GetOk returns m[key] and whether it was present. +func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + value, ok = shard.m[key] + return +} + +// Get returns m[key] or the zero value of V if key is not present. +func (m *ShardedMap[K, V]) Get(key K) (value V) { + value, _ = m.GetOk(key) + return +} + +// Mutate atomically mutates m[k] by calling mutator. +// +// The mutator function is called with the old value (or its zero value) and +// whether it existed in the map and it returns the new value and whether it +// should be set in the map (true) or deleted from the map (false). +// +// It returns the change in size of the map as a result of the mutation, one of +// -1 (delete), 0 (change), or 1 (addition). +func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + oldV, oldOK := shard.m[key] + newV, newOK := mutator(oldV, oldOK) + if newOK { + shard.m[key] = newV + if oldOK { + return 0 + } + return 1 + } + delete(shard.m, key) + if oldOK { + return -1 + } + return 0 +} + +// Set sets m[key] = value. +// +// present in m). +func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + shard.m[key] = value + return len(shard.m) > s0 +} + +// Delete removes key from m. +// +// It reports whether the map size shrunk (that is, whether key was present in +// the map). +func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + delete(shard.m, key) + return len(shard.m) < s0 +} + +// Contains reports whether m contains key. +func (m *ShardedMap[K, V]) Contains(key K) bool { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + _, ok := shard.m[key] + return ok +} + +// Len returns the number of elements in m. +// +// It does so by locking shards one at a time, so it's not particularly cheap, +// nor does it give a consistent snapshot of the map. It's mostly intended for +// metrics or testing. +func (m *ShardedMap[K, V]) Len() int { + n := 0 + for i := range m.shards { + shard := &m.shards[i] + shard.mu.Lock() + n += len(shard.m) + shard.mu.Unlock() + } + return n +} diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go index 993ffdff875c2..170201c0a2b13 100644 --- a/syncs/shardedmap_test.go +++ b/syncs/shardedmap_test.go @@ -1,81 +1,81 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import "testing" - -func TestShardedMap(t *testing.T) { - m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) - - if m.Contains(1) { - t.Errorf("got contains; want !contains") - } - if !m.Set(1, "one") { - t.Errorf("got !set; want set") - } - if m.Set(1, "one") { - t.Errorf("got set; want !set") - } - if !m.Contains(1) { - t.Errorf("got !contains; want contains") - } - if g, w := m.Get(1), "one"; g != w { - t.Errorf("got %q; want %q", g, w) - } - if _, ok := m.GetOk(1); !ok { - t.Errorf("got ok; want !ok") - } - if _, ok := m.GetOk(2); ok { - t.Errorf("got ok; want !ok") - } - if g, w := m.Len(), 1; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - if m.Delete(2) { - t.Errorf("got deleted; want !deleted") - } - if !m.Delete(1) { - t.Errorf("got !deleted; want deleted") - } - if g, w := m.Len(), 0; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - - // Mutation adding an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if ok { - t.Fatal("was okay") - } - return "ONE", true - }); v != 1 { - t.Errorf("Mutate = %v; want 1", v) - } - if g, w := m.Get(1), "ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation changing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return was + "-" + was, true - }); v != 0 { - t.Errorf("Mutate = %v; want 0", v) - } - if g, w := m.Get(1), "ONE-ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation removing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return "", false - }); v != -1 { - t.Errorf("Mutate = %v; want -1", v) - } - if g, w := m.Get(1), ""; g != w { - t.Errorf("got %q; want %q", g, w) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import "testing" + +func TestShardedMap(t *testing.T) { + m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) + + if m.Contains(1) { + t.Errorf("got contains; want !contains") + } + if !m.Set(1, "one") { + t.Errorf("got !set; want set") + } + if m.Set(1, "one") { + t.Errorf("got set; want !set") + } + if !m.Contains(1) { + t.Errorf("got !contains; want contains") + } + if g, w := m.Get(1), "one"; g != w { + t.Errorf("got %q; want %q", g, w) + } + if _, ok := m.GetOk(1); !ok { + t.Errorf("got ok; want !ok") + } + if _, ok := m.GetOk(2); ok { + t.Errorf("got ok; want !ok") + } + if g, w := m.Len(), 1; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + if m.Delete(2) { + t.Errorf("got deleted; want !deleted") + } + if !m.Delete(1) { + t.Errorf("got !deleted; want deleted") + } + if g, w := m.Len(), 0; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + + // Mutation adding an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if ok { + t.Fatal("was okay") + } + return "ONE", true + }); v != 1 { + t.Errorf("Mutate = %v; want 1", v) + } + if g, w := m.Get(1), "ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation changing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return was + "-" + was, true + }); v != 0 { + t.Errorf("Mutate = %v; want 0", v) + } + if g, w := m.Get(1), "ONE-ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation removing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return "", false + }); v != -1 { + t.Errorf("Mutate = %v; want -1", v) + } + if g, w := m.Get(1), ""; g != w { + t.Errorf("got %q; want %q", g, w) + } +} diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go index f65c58804d44d..0bb7e388eaaa8 100644 --- a/tailcfg/proto_port_range.go +++ b/tailcfg/proto_port_range.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "errors" - "fmt" - "strconv" - "strings" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var ( - errEmptyProtocol = errors.New("empty protocol") - errEmptyString = errors.New("empty string") -) - -// ProtoPortRange is used to encode "proto:port" format. -// The following formats are supported: -// -// "*" allows all TCP, UDP and ICMP traffic on all ports. -// "" allows all TCP, UDP and ICMP traffic on the specified ports. -// "proto:*" allows traffic of the specified proto on all ports. -// "proto:" allows traffic of the specified proto on the specified port. -// -// Ports are either a single port number or a range of ports (e.g. "80-90"). -// String named protocols support names that ipproto.Proto accepts. -type ProtoPortRange struct { - // Proto is the IP protocol number. - // If Proto is 0, it means TCP+UDP+ICMP(4+6). - Proto int - Ports PortRange -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { - ppr2, err := parseProtoPortRange(string(text)) - if err != nil { - return err - } - *ppr = *ppr2 - return nil -} - -// MarshalText implements the encoding.TextMarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { - if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { - return []byte{}, nil - } - return []byte(ppr.String()), nil -} - -// String implements the stringer interface. See ProtoPortRange for the -// format. -func (ppr ProtoPortRange) String() string { - if ppr.Proto == 0 { - if ppr.Ports == PortRangeAny { - return "*" - } - } - var buf strings.Builder - if ppr.Proto != 0 { - // Proto.MarshalText is infallible. - text, _ := ipproto.Proto(ppr.Proto).MarshalText() - buf.Write(text) - buf.Write([]byte(":")) - } - pr := ppr.Ports - if pr.First == pr.Last { - fmt.Fprintf(&buf, "%d", pr.First) - } else if pr == PortRangeAny { - buf.WriteByte('*') - } else { - fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) - } - return buf.String() -} - -// ParseProtoPortRanges parses a slice of IP port range fields. -func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { - var out []ProtoPortRange - for _, p := range ips { - ppr, err := parseProtoPortRange(p) - if err != nil { - return nil, err - } - out = append(out, *ppr) - } - return out, nil -} - -func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { - if ipProtoPort == "" { - return nil, errEmptyString - } - if ipProtoPort == "*" { - return &ProtoPortRange{Ports: PortRangeAny}, nil - } - if !strings.Contains(ipProtoPort, ":") { - ipProtoPort = "*:" + ipProtoPort - } - protoStr, portRange, err := parseHostPortRange(ipProtoPort) - if err != nil { - return nil, err - } - if protoStr == "" { - return nil, errEmptyProtocol - } - - ppr := &ProtoPortRange{ - Ports: portRange, - } - if protoStr == "*" { - return ppr, nil - } - var ipProto ipproto.Proto - if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { - return nil, err - } - ppr.Proto = int(ipProto) - return ppr, nil -} - -// parseHostPortRange parses hostport as HOST:PORTS where HOST is -// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. -func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { - hostport = strings.ToLower(hostport) - colon := strings.LastIndexByte(hostport, ':') - if colon < 0 { - return "", ports, vizerror.New("hostport must contain a colon (\":\")") - } - host = hostport[:colon] - portlist := hostport[colon+1:] - - if strings.Contains(host, ",") { - return "", ports, vizerror.New("host cannot contain a comma (\",\")") - } - - if portlist == "*" { - // Special case: permit hostname:* as a port wildcard. - return host, PortRangeAny, nil - } - - if len(portlist) == 0 { - return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) - } - - if strings.Count(portlist, "-") > 1 { - return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) - } - - firstStr, lastStr, isRange := strings.Cut(portlist, "-") - - var first, last uint64 - first, err = strconv.ParseUint(firstStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) - } - - if isRange { - last, err = strconv.ParseUint(lastStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) - } - } else { - last = first - } - - if first == 0 { - return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) - } - - if first > last { - return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) - } - - return host, newPortRange(uint16(first), uint16(last)), nil -} - -func newPortRange(first, last uint16) PortRange { - return PortRange{First: first, Last: last} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var ( + errEmptyProtocol = errors.New("empty protocol") + errEmptyString = errors.New("empty string") +) + +// ProtoPortRange is used to encode "proto:port" format. +// The following formats are supported: +// +// "*" allows all TCP, UDP and ICMP traffic on all ports. +// "" allows all TCP, UDP and ICMP traffic on the specified ports. +// "proto:*" allows traffic of the specified proto on all ports. +// "proto:" allows traffic of the specified proto on the specified port. +// +// Ports are either a single port number or a range of ports (e.g. "80-90"). +// String named protocols support names that ipproto.Proto accepts. +type ProtoPortRange struct { + // Proto is the IP protocol number. + // If Proto is 0, it means TCP+UDP+ICMP(4+6). + Proto int + Ports PortRange +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { + ppr2, err := parseProtoPortRange(string(text)) + if err != nil { + return err + } + *ppr = *ppr2 + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { + if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { + return []byte{}, nil + } + return []byte(ppr.String()), nil +} + +// String implements the stringer interface. See ProtoPortRange for the +// format. +func (ppr ProtoPortRange) String() string { + if ppr.Proto == 0 { + if ppr.Ports == PortRangeAny { + return "*" + } + } + var buf strings.Builder + if ppr.Proto != 0 { + // Proto.MarshalText is infallible. + text, _ := ipproto.Proto(ppr.Proto).MarshalText() + buf.Write(text) + buf.Write([]byte(":")) + } + pr := ppr.Ports + if pr.First == pr.Last { + fmt.Fprintf(&buf, "%d", pr.First) + } else if pr == PortRangeAny { + buf.WriteByte('*') + } else { + fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) + } + return buf.String() +} + +// ParseProtoPortRanges parses a slice of IP port range fields. +func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { + var out []ProtoPortRange + for _, p := range ips { + ppr, err := parseProtoPortRange(p) + if err != nil { + return nil, err + } + out = append(out, *ppr) + } + return out, nil +} + +func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { + if ipProtoPort == "" { + return nil, errEmptyString + } + if ipProtoPort == "*" { + return &ProtoPortRange{Ports: PortRangeAny}, nil + } + if !strings.Contains(ipProtoPort, ":") { + ipProtoPort = "*:" + ipProtoPort + } + protoStr, portRange, err := parseHostPortRange(ipProtoPort) + if err != nil { + return nil, err + } + if protoStr == "" { + return nil, errEmptyProtocol + } + + ppr := &ProtoPortRange{ + Ports: portRange, + } + if protoStr == "*" { + return ppr, nil + } + var ipProto ipproto.Proto + if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { + return nil, err + } + ppr.Proto = int(ipProto) + return ppr, nil +} + +// parseHostPortRange parses hostport as HOST:PORTS where HOST is +// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. +func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { + hostport = strings.ToLower(hostport) + colon := strings.LastIndexByte(hostport, ':') + if colon < 0 { + return "", ports, vizerror.New("hostport must contain a colon (\":\")") + } + host = hostport[:colon] + portlist := hostport[colon+1:] + + if strings.Contains(host, ",") { + return "", ports, vizerror.New("host cannot contain a comma (\",\")") + } + + if portlist == "*" { + // Special case: permit hostname:* as a port wildcard. + return host, PortRangeAny, nil + } + + if len(portlist) == 0 { + return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) + } + + if strings.Count(portlist, "-") > 1 { + return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) + } + + firstStr, lastStr, isRange := strings.Cut(portlist, "-") + + var first, last uint64 + first, err = strconv.ParseUint(firstStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) + } + + if isRange { + last, err = strconv.ParseUint(lastStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) + } + } else { + last = first + } + + if first == 0 { + return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) + } + + if first > last { + return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) + } + + return host, newPortRange(uint16(first), uint16(last)), nil +} + +func newPortRange(first, last uint16) PortRange { + return PortRange{First: first, Last: last} +} diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go index 59ccc9be4a1a8..31b282641e975 100644 --- a/tailcfg/proto_port_range_test.go +++ b/tailcfg/proto_port_range_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "encoding" - "testing" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) - -func TestProtoPortRangeParsing(t *testing.T) { - pr := func(s, e uint16) PortRange { - return PortRange{First: s, Last: e} - } - tests := []struct { - in string - out ProtoPortRange - err error - }{ - {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, - {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, - {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, - { - in: "tcp:", - err: vizerror.Errorf("invalid port list: %#v", ""), - }, - { - in: ":80", - err: errEmptyProtocol, - }, - { - in: "", - err: errEmptyString, - }, - } - - for _, tc := range tests { - t.Run(tc.in, func(t *testing.T) { - var ppr ProtoPortRange - err := ppr.UnmarshalText([]byte(tc.in)) - if tc.err != err { - if err == nil || tc.err.Error() != err.Error() { - t.Fatalf("want err=%v, got %v", tc.err, err) - } - } - if ppr != tc.out { - t.Fatalf("got %v; want %v", ppr, tc.out) - } - }) - } -} - -func TestProtoPortRangeString(t *testing.T) { - tests := []struct { - input ProtoPortRange - want string - }{ - {ProtoPortRange{}, "0"}, - - // Zero protocol. - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - - // Non-zero unnamed protocol. - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - - // Non-zero named protocol. - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - for _, tc := range tests { - if got := tc.input.String(); got != tc.want { - t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestProtoPortRangeRoundTrip(t *testing.T) { - tests := []struct { - input ProtoPortRange - text string - }{ - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - - for _, tc := range tests { - out, err := tc.input.MarshalText() - if err != nil { - t.Errorf("MarshalText for %v: %v", tc.input, err) - continue - } - if got := string(out); got != tc.text { - t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) - } - var ppr ProtoPortRange - if err := ppr.UnmarshalText(out); err != nil { - t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) - continue - } - if ppr != tc.input { - t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "encoding" + "testing" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) + +func TestProtoPortRangeParsing(t *testing.T) { + pr := func(s, e uint16) PortRange { + return PortRange{First: s, Last: e} + } + tests := []struct { + in string + out ProtoPortRange + err error + }{ + {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, + {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, + {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, + { + in: "tcp:", + err: vizerror.Errorf("invalid port list: %#v", ""), + }, + { + in: ":80", + err: errEmptyProtocol, + }, + { + in: "", + err: errEmptyString, + }, + } + + for _, tc := range tests { + t.Run(tc.in, func(t *testing.T) { + var ppr ProtoPortRange + err := ppr.UnmarshalText([]byte(tc.in)) + if tc.err != err { + if err == nil || tc.err.Error() != err.Error() { + t.Fatalf("want err=%v, got %v", tc.err, err) + } + } + if ppr != tc.out { + t.Fatalf("got %v; want %v", ppr, tc.out) + } + }) + } +} + +func TestProtoPortRangeString(t *testing.T) { + tests := []struct { + input ProtoPortRange + want string + }{ + {ProtoPortRange{}, "0"}, + + // Zero protocol. + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + + // Non-zero unnamed protocol. + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + + // Non-zero named protocol. + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + for _, tc := range tests { + if got := tc.input.String(); got != tc.want { + t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestProtoPortRangeRoundTrip(t *testing.T) { + tests := []struct { + input ProtoPortRange + text string + }{ + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + + for _, tc := range tests { + out, err := tc.input.MarshalText() + if err != nil { + t.Errorf("MarshalText for %v: %v", tc.input, err) + continue + } + if got := string(out); got != tc.text { + t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) + } + var ppr ProtoPortRange + if err := ppr.UnmarshalText(out); err != nil { + t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) + continue + } + if ppr != tc.input { + t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) + } + } +} diff --git a/tailcfg/tka.go b/tailcfg/tka.go index 97fdcc0db687a..ca7e6be76ba1e 100644 --- a/tailcfg/tka.go +++ b/tailcfg/tka.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// TKAInitBeginRequest submits a genesis AUM to seed the creation of the -// tailnet's key authority. -type TKAInitBeginRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // GenesisAUM is the initial (genesis) AUM that the node generated - // to bootstrap tailnet key authority state. - GenesisAUM tkatype.MarshaledAUM -} - -// TKASignInfo describes information about an existing node that needs -// to be signed into a node-key signature. -type TKASignInfo struct { - // NodeID is the ID of the node which needs a signature. It must - // correspond to NodePublic. - NodeID NodeID - // NodePublic is the node (Wireguard) public key which is being - // signed. - NodePublic key.NodePublic - - // RotationPubkey specifies the public key which may sign - // a NodeKeySignature (NKS), which rotates the node key. - // - // This is necessary so the node can rotate its node-key without - // talking to a node which holds a trusted network-lock key. - // It does this by nesting the original NKS in a 'rotation' NKS, - // which it then signs with the key corresponding to RotationPubkey. - // - // This field expects a raw ed25519 public key. - RotationPubkey []byte -} - -// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. -// This structure describes node information which must be signed to -// complete initialization of the tailnets' key authority. -type TKAInitBeginResponse struct { - // NeedSignatures specify information about the nodes in your tailnet - // which need initial signatures to function once the tailnet key - // authority is in use. The generated signatures should then be - // submitted in a /tka/init/finish RPC. - NeedSignatures []TKASignInfo -} - -// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. -// This RPC finalizes initialization of the tailnet key authority -// by submitting node-key signatures for all existing nodes. -type TKAInitFinishRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Signatures are serialized tka.NodeKeySignatures for all nodes - // in the tailnet. - Signatures map[NodeID]tkatype.MarshaledSignature - - // SupportDisablement is a disablement secret for Tailscale support. - // This is only generated if --gen-disablement-for-support is specified - // in an invocation to 'tailscale lock init'. - SupportDisablement []byte `json:",omitempty"` -} - -// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. -// This schema describes the successful enablement of the tailnet's -// key authority. -type TKAInitFinishResponse struct { - // Nothing. (yet?) -} - -// TKAInfo encodes the control plane's view of tailnet key authority (TKA) -// state. This information is transmitted as part of the MapResponse. -type TKAInfo struct { - // Head describes the hash of the latest AUM applied to the authority. - // Head is encoded as tka.AUMHash.MarshalText. - // - // If the Head state differs to that known locally, the node should perform - // synchronization via a separate RPC. - Head string `json:",omitempty"` - - // Disabled indicates the control plane believes TKA should be disabled, - // and the node should reach out to fetch a disablement - // secret. If the disablement secret verifies, then the node should then - // disable TKA locally. - // This field exists to disambiguate a nil TKAInfo in a delta mapresponse - // from a nil TKAInfo indicating TKA should be disabled. - Disabled bool `json:",omitempty"` -} - -// TKABootstrapRequest is sent by a node to get information necessary for -// enabling or disabling the tailnet key authority. -type TKABootstrapRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head), if - // network lock is enabled. - Head string -} - -// TKABootstrapResponse encodes values necessary to enable or disable -// the tailnet key authority (TKA). -type TKABootstrapResponse struct { - // GenesisAUM returns the initial AUM necessary to initialize TKA. - GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` - - // DisablementSecret encodes a secret necessary to disable TKA. - DisablementSecret []byte `json:",omitempty"` -} - -// TKASyncOfferRequest encodes a request to synchronize tailnet key authority -// state (TKA). Values of type tka.AUMHash are encoded as strings in their -// MarshalText form. -type TKASyncOfferRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). This - // corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the current head. This corresponds to tka.SyncOffer.Ancestors. - Ancestors []string -} - -// TKASyncOfferResponse encodes a response in synchronizing a node's -// tailnet key authority state. Values of type tka.AUMHash are encoded as -// strings in their MarshalText form. -type TKASyncOfferResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head). - // This corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the control plane's head. This corresponds to - // tka.SyncOffer.Ancestors. - Ancestors []string - // MissingAUMs encodes AUMs that the control plane believes the node - // is missing. - MissingAUMs []tkatype.MarshaledAUM -} - -// TKASyncSendRequest encodes AUMs that a node believes the control plane -// is missing, and notifies control of its local TKA state (specifically -// the head hash). -type TKASyncSendRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head) after - // applying any AUMs from the sync-offer response. - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // MissingAUMs encodes AUMs that the node believes the control plane - // is missing. - MissingAUMs []tkatype.MarshaledAUM - - // Interactive is true if additional error checking should be performed as - // the request is on behalf of an interactive operation (e.g., an - // administrator publishing new changes) as opposed to an automatic - // synchronization that may be reporting lost data. - Interactive bool -} - -// TKASyncSendResponse encodes the control plane's response to a node -// submitting AUMs during AUM synchronization. -type TKASyncSendResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head), - // after applying the missing AUMs. - Head string -} - -// TKADisableRequest disables network-lock across the tailnet using the -// provided disablement secret. -// -// This is the request schema for a /tka/disable noise RPC. -type TKADisableRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // DisablementSecret encodes the secret necessary to disable TKA. - DisablementSecret []byte -} - -// TKADisableResponse is the JSON response from a /tka/disable RPC. -// This schema describes the successful disablement of the tailnet's -// key authority. -type TKADisableResponse struct { - // Nothing. (yet?) -} - -// TKASubmitSignatureRequest transmits a node-key signature to the control plane. -// -// This is the request schema for a /tka/sign noise RPC. -type TKASubmitSignatureRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. The node-key which - // is being signed is embedded in Signature. - NodeKey key.NodePublic - - // Signature encodes the node-key signature being submitted. - Signature tkatype.MarshaledSignature -} - -// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. -type TKASubmitSignatureResponse struct { - // Nothing. (yet?) -} - -// TKASignaturesUsingKeyRequest asks the control plane for -// all signatures which are signed by the provided keyID. -// -// This is the request schema for a /tka/affected-sigs RPC. -type TKASignaturesUsingKeyRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // KeyID is the key we are querying using. - KeyID tkatype.KeyID -} - -// TKASignaturesUsingKeyResponse is the JSON response to -// a /tka/affected-sigs RPC. -// -// It enumerates all signatures which are signed by the -// queried keyID. -type TKASignaturesUsingKeyResponse struct { - Signatures []tkatype.MarshaledSignature -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// TKAInitBeginRequest submits a genesis AUM to seed the creation of the +// tailnet's key authority. +type TKAInitBeginRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // GenesisAUM is the initial (genesis) AUM that the node generated + // to bootstrap tailnet key authority state. + GenesisAUM tkatype.MarshaledAUM +} + +// TKASignInfo describes information about an existing node that needs +// to be signed into a node-key signature. +type TKASignInfo struct { + // NodeID is the ID of the node which needs a signature. It must + // correspond to NodePublic. + NodeID NodeID + // NodePublic is the node (Wireguard) public key which is being + // signed. + NodePublic key.NodePublic + + // RotationPubkey specifies the public key which may sign + // a NodeKeySignature (NKS), which rotates the node key. + // + // This is necessary so the node can rotate its node-key without + // talking to a node which holds a trusted network-lock key. + // It does this by nesting the original NKS in a 'rotation' NKS, + // which it then signs with the key corresponding to RotationPubkey. + // + // This field expects a raw ed25519 public key. + RotationPubkey []byte +} + +// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. +// This structure describes node information which must be signed to +// complete initialization of the tailnets' key authority. +type TKAInitBeginResponse struct { + // NeedSignatures specify information about the nodes in your tailnet + // which need initial signatures to function once the tailnet key + // authority is in use. The generated signatures should then be + // submitted in a /tka/init/finish RPC. + NeedSignatures []TKASignInfo +} + +// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. +// This RPC finalizes initialization of the tailnet key authority +// by submitting node-key signatures for all existing nodes. +type TKAInitFinishRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Signatures are serialized tka.NodeKeySignatures for all nodes + // in the tailnet. + Signatures map[NodeID]tkatype.MarshaledSignature + + // SupportDisablement is a disablement secret for Tailscale support. + // This is only generated if --gen-disablement-for-support is specified + // in an invocation to 'tailscale lock init'. + SupportDisablement []byte `json:",omitempty"` +} + +// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. +// This schema describes the successful enablement of the tailnet's +// key authority. +type TKAInitFinishResponse struct { + // Nothing. (yet?) +} + +// TKAInfo encodes the control plane's view of tailnet key authority (TKA) +// state. This information is transmitted as part of the MapResponse. +type TKAInfo struct { + // Head describes the hash of the latest AUM applied to the authority. + // Head is encoded as tka.AUMHash.MarshalText. + // + // If the Head state differs to that known locally, the node should perform + // synchronization via a separate RPC. + Head string `json:",omitempty"` + + // Disabled indicates the control plane believes TKA should be disabled, + // and the node should reach out to fetch a disablement + // secret. If the disablement secret verifies, then the node should then + // disable TKA locally. + // This field exists to disambiguate a nil TKAInfo in a delta mapresponse + // from a nil TKAInfo indicating TKA should be disabled. + Disabled bool `json:",omitempty"` +} + +// TKABootstrapRequest is sent by a node to get information necessary for +// enabling or disabling the tailnet key authority. +type TKABootstrapRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head), if + // network lock is enabled. + Head string +} + +// TKABootstrapResponse encodes values necessary to enable or disable +// the tailnet key authority (TKA). +type TKABootstrapResponse struct { + // GenesisAUM returns the initial AUM necessary to initialize TKA. + GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` + + // DisablementSecret encodes a secret necessary to disable TKA. + DisablementSecret []byte `json:",omitempty"` +} + +// TKASyncOfferRequest encodes a request to synchronize tailnet key authority +// state (TKA). Values of type tka.AUMHash are encoded as strings in their +// MarshalText form. +type TKASyncOfferRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). This + // corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the current head. This corresponds to tka.SyncOffer.Ancestors. + Ancestors []string +} + +// TKASyncOfferResponse encodes a response in synchronizing a node's +// tailnet key authority state. Values of type tka.AUMHash are encoded as +// strings in their MarshalText form. +type TKASyncOfferResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head). + // This corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the control plane's head. This corresponds to + // tka.SyncOffer.Ancestors. + Ancestors []string + // MissingAUMs encodes AUMs that the control plane believes the node + // is missing. + MissingAUMs []tkatype.MarshaledAUM +} + +// TKASyncSendRequest encodes AUMs that a node believes the control plane +// is missing, and notifies control of its local TKA state (specifically +// the head hash). +type TKASyncSendRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head) after + // applying any AUMs from the sync-offer response. + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // MissingAUMs encodes AUMs that the node believes the control plane + // is missing. + MissingAUMs []tkatype.MarshaledAUM + + // Interactive is true if additional error checking should be performed as + // the request is on behalf of an interactive operation (e.g., an + // administrator publishing new changes) as opposed to an automatic + // synchronization that may be reporting lost data. + Interactive bool +} + +// TKASyncSendResponse encodes the control plane's response to a node +// submitting AUMs during AUM synchronization. +type TKASyncSendResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head), + // after applying the missing AUMs. + Head string +} + +// TKADisableRequest disables network-lock across the tailnet using the +// provided disablement secret. +// +// This is the request schema for a /tka/disable noise RPC. +type TKADisableRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // DisablementSecret encodes the secret necessary to disable TKA. + DisablementSecret []byte +} + +// TKADisableResponse is the JSON response from a /tka/disable RPC. +// This schema describes the successful disablement of the tailnet's +// key authority. +type TKADisableResponse struct { + // Nothing. (yet?) +} + +// TKASubmitSignatureRequest transmits a node-key signature to the control plane. +// +// This is the request schema for a /tka/sign noise RPC. +type TKASubmitSignatureRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. The node-key which + // is being signed is embedded in Signature. + NodeKey key.NodePublic + + // Signature encodes the node-key signature being submitted. + Signature tkatype.MarshaledSignature +} + +// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. +type TKASubmitSignatureResponse struct { + // Nothing. (yet?) +} + +// TKASignaturesUsingKeyRequest asks the control plane for +// all signatures which are signed by the provided keyID. +// +// This is the request schema for a /tka/affected-sigs RPC. +type TKASignaturesUsingKeyRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // KeyID is the key we are querying using. + KeyID tkatype.KeyID +} + +// TKASignaturesUsingKeyResponse is the JSON response to +// a /tka/affected-sigs RPC. +// +// It enumerates all signatures which are signed by the +// queried keyID. +type TKASignaturesUsingKeyResponse struct { + Signatures []tkatype.MarshaledSignature +} diff --git a/taildrop/delete.go b/taildrop/delete.go index aaef34df1a7e4..7279a7687b2ec 100644 --- a/taildrop/delete.go +++ b/taildrop/delete.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "container/list" - "context" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "tailscale.com/ipn" - "tailscale.com/syncs" - "tailscale.com/tstime" - "tailscale.com/types/logger" -) - -// deleteDelay is the amount of time to wait before we delete a file. -// A shorter value ensures timely deletion of deleted and partial files, while -// a longer value provides more opportunity for partial files to be resumed. -const deleteDelay = time.Hour - -// fileDeleter manages asynchronous deletion of files after deleteDelay. -type fileDeleter struct { - logf logger.Logf - clock tstime.DefaultClock - dir string - event func(string) // called for certain events; for testing only - - mu sync.Mutex - queue list.List - byName map[string]*list.Element - - emptySignal chan struct{} // signal that the queue is empty - group syncs.WaitGroup - shutdownCtx context.Context - shutdown context.CancelFunc -} - -// deleteFile is a specific file to delete after deleteDelay. -type deleteFile struct { - name string - inserted time.Time -} - -func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { - d.logf = m.opts.Logf - d.clock = m.opts.Clock - d.dir = m.opts.Dir - d.event = eventHook - - d.byName = make(map[string]*list.Element) - d.emptySignal = make(chan struct{}) - d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) - - // From a cold-start, load the list of partial and deleted files. - // - // Only run this if we have ever received at least one file - // to avoid ever touching the taildrop directory on systems (e.g., MacOS) - // that pop up a security dialog window upon first access. - if m.opts.State == nil { - return - } - if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { - return - } - d.group.Go(func() { - d.event("start full-scan") - defer d.event("end full-scan") - rangeDir(d.dir, func(de fs.DirEntry) bool { - switch { - case d.shutdownCtx.Err() != nil: - return false // terminate early - case !de.Type().IsRegular(): - return true - case strings.HasSuffix(de.Name(), partialSuffix): - // Only enqueue the file for deletion if there is no active put. - nameID := strings.TrimSuffix(de.Name(), partialSuffix) - if i := strings.LastIndexByte(nameID, '.'); i > 0 { - key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} - m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { - if !loaded { - d.Insert(de.Name()) - } - }) - } else { - d.Insert(de.Name()) - } - case strings.HasSuffix(de.Name(), deletedSuffix): - // Best-effort immediate deletion of deleted files. - name := strings.TrimSuffix(de.Name(), deletedSuffix) - if os.Remove(filepath.Join(d.dir, name)) == nil { - if os.Remove(filepath.Join(d.dir, de.Name())) == nil { - break - } - } - // Otherwise, enqueue the file for later deletion. - d.Insert(de.Name()) - } - return true - }) - }) -} - -// Insert enqueues baseName for eventual deletion. -func (d *fileDeleter) Insert(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if d.shutdownCtx.Err() != nil { - return - } - if _, ok := d.byName[baseName]; ok { - return // already queued for deletion - } - d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) - if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { - d.group.Go(func() { d.waitAndDelete(deleteDelay) }) - } -} - -// waitAndDelete is an asynchronous deletion goroutine. -// At most one waitAndDelete routine is ever running at a time. -// It is not started unless there is at least one file in the queue. -func (d *fileDeleter) waitAndDelete(wait time.Duration) { - tc, ch := d.clock.NewTimer(wait) - defer tc.Stop() // cleanup the timer resource if we stop early - d.event("start waitAndDelete") - defer d.event("end waitAndDelete") - select { - case <-d.shutdownCtx.Done(): - case <-d.emptySignal: - case now := <-ch: - d.mu.Lock() - defer d.mu.Unlock() - - // Iterate over all files to delete, and delete anything old enough. - var next *list.Element - var failed []*list.Element - for elem := d.queue.Front(); elem != nil; elem = next { - next = elem.Next() - file := elem.Value.(*deleteFile) - if now.Sub(file.inserted) < deleteDelay { - break // everything after this is recently inserted - } - - // Delete the expired file. - if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { - if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - } - if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - d.queue.Remove(elem) - delete(d.byName, file.name) - d.event("deleted " + file.name) - } - for _, elem := range failed { - elem.Value.(*deleteFile).inserted = now // retry after deleteDelay - d.queue.MoveToBack(elem) - } - - // If there are still some files to delete, retry again later. - if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { - file := d.queue.Front().Value.(*deleteFile) - retryAfter := deleteDelay - now.Sub(file.inserted) - d.group.Go(func() { d.waitAndDelete(retryAfter) }) - } - } -} - -// Remove dequeues baseName from eventual deletion. -func (d *fileDeleter) Remove(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if elem := d.byName[baseName]; elem != nil { - d.queue.Remove(elem) - delete(d.byName, baseName) - // Signal to terminate any waitAndDelete goroutines. - if d.queue.Len() == 0 { - select { - case <-d.shutdownCtx.Done(): - case d.emptySignal <- struct{}{}: - } - } - } -} - -// Shutdown shuts down the deleter. -// It blocks until all goroutines are stopped. -func (d *fileDeleter) Shutdown() { - d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown - d.shutdown() - d.mu.Unlock() - d.group.Wait() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "container/list" + "context" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/syncs" + "tailscale.com/tstime" + "tailscale.com/types/logger" +) + +// deleteDelay is the amount of time to wait before we delete a file. +// A shorter value ensures timely deletion of deleted and partial files, while +// a longer value provides more opportunity for partial files to be resumed. +const deleteDelay = time.Hour + +// fileDeleter manages asynchronous deletion of files after deleteDelay. +type fileDeleter struct { + logf logger.Logf + clock tstime.DefaultClock + dir string + event func(string) // called for certain events; for testing only + + mu sync.Mutex + queue list.List + byName map[string]*list.Element + + emptySignal chan struct{} // signal that the queue is empty + group syncs.WaitGroup + shutdownCtx context.Context + shutdown context.CancelFunc +} + +// deleteFile is a specific file to delete after deleteDelay. +type deleteFile struct { + name string + inserted time.Time +} + +func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { + d.logf = m.opts.Logf + d.clock = m.opts.Clock + d.dir = m.opts.Dir + d.event = eventHook + + d.byName = make(map[string]*list.Element) + d.emptySignal = make(chan struct{}) + d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) + + // From a cold-start, load the list of partial and deleted files. + // + // Only run this if we have ever received at least one file + // to avoid ever touching the taildrop directory on systems (e.g., MacOS) + // that pop up a security dialog window upon first access. + if m.opts.State == nil { + return + } + if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { + return + } + d.group.Go(func() { + d.event("start full-scan") + defer d.event("end full-scan") + rangeDir(d.dir, func(de fs.DirEntry) bool { + switch { + case d.shutdownCtx.Err() != nil: + return false // terminate early + case !de.Type().IsRegular(): + return true + case strings.HasSuffix(de.Name(), partialSuffix): + // Only enqueue the file for deletion if there is no active put. + nameID := strings.TrimSuffix(de.Name(), partialSuffix) + if i := strings.LastIndexByte(nameID, '.'); i > 0 { + key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} + m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { + if !loaded { + d.Insert(de.Name()) + } + }) + } else { + d.Insert(de.Name()) + } + case strings.HasSuffix(de.Name(), deletedSuffix): + // Best-effort immediate deletion of deleted files. + name := strings.TrimSuffix(de.Name(), deletedSuffix) + if os.Remove(filepath.Join(d.dir, name)) == nil { + if os.Remove(filepath.Join(d.dir, de.Name())) == nil { + break + } + } + // Otherwise, enqueue the file for later deletion. + d.Insert(de.Name()) + } + return true + }) + }) +} + +// Insert enqueues baseName for eventual deletion. +func (d *fileDeleter) Insert(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if d.shutdownCtx.Err() != nil { + return + } + if _, ok := d.byName[baseName]; ok { + return // already queued for deletion + } + d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) + if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { + d.group.Go(func() { d.waitAndDelete(deleteDelay) }) + } +} + +// waitAndDelete is an asynchronous deletion goroutine. +// At most one waitAndDelete routine is ever running at a time. +// It is not started unless there is at least one file in the queue. +func (d *fileDeleter) waitAndDelete(wait time.Duration) { + tc, ch := d.clock.NewTimer(wait) + defer tc.Stop() // cleanup the timer resource if we stop early + d.event("start waitAndDelete") + defer d.event("end waitAndDelete") + select { + case <-d.shutdownCtx.Done(): + case <-d.emptySignal: + case now := <-ch: + d.mu.Lock() + defer d.mu.Unlock() + + // Iterate over all files to delete, and delete anything old enough. + var next *list.Element + var failed []*list.Element + for elem := d.queue.Front(); elem != nil; elem = next { + next = elem.Next() + file := elem.Value.(*deleteFile) + if now.Sub(file.inserted) < deleteDelay { + break // everything after this is recently inserted + } + + // Delete the expired file. + if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { + if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + } + if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + d.queue.Remove(elem) + delete(d.byName, file.name) + d.event("deleted " + file.name) + } + for _, elem := range failed { + elem.Value.(*deleteFile).inserted = now // retry after deleteDelay + d.queue.MoveToBack(elem) + } + + // If there are still some files to delete, retry again later. + if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { + file := d.queue.Front().Value.(*deleteFile) + retryAfter := deleteDelay - now.Sub(file.inserted) + d.group.Go(func() { d.waitAndDelete(retryAfter) }) + } + } +} + +// Remove dequeues baseName from eventual deletion. +func (d *fileDeleter) Remove(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if elem := d.byName[baseName]; elem != nil { + d.queue.Remove(elem) + delete(d.byName, baseName) + // Signal to terminate any waitAndDelete goroutines. + if d.queue.Len() == 0 { + select { + case <-d.shutdownCtx.Done(): + case d.emptySignal <- struct{}{}: + } + } + } +} + +// Shutdown shuts down the deleter. +// It blocks until all goroutines are stopped. +func (d *fileDeleter) Shutdown() { + d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown + d.shutdown() + d.mu.Unlock() + d.group.Wait() +} diff --git a/taildrop/delete_test.go b/taildrop/delete_test.go index 5fa4b9c374fdf..b40fa35bfb0e3 100644 --- a/taildrop/delete_test.go +++ b/taildrop/delete_test.go @@ -1,152 +1,152 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "os" - "path/filepath" - "slices" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/tstime" - "tailscale.com/util/must" -) - -func TestDeleter(t *testing.T) { - dir := t.TempDir() - must.Do(touchFile(filepath.Join(dir, "foo.partial"))) - must.Do(touchFile(filepath.Join(dir, "bar.partial"))) - must.Do(touchFile(filepath.Join(dir, "fizz"))) - must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) - must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file - - checkDirectory := func(want ...string) { - t.Helper() - var got []string - for _, de := range must.Get(os.ReadDir(dir)) { - got = append(got, de.Name()) - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("directory mismatch (-got +want):\n%s", diff) - } - } - - clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) - advance := func(d time.Duration) { - t.Helper() - t.Logf("advance: %v", d) - clock.Advance(d) - } - - eventsChan := make(chan string, 1000) - checkEvents := func(want ...string) { - t.Helper() - tm := time.NewTimer(10 * time.Second) - defer tm.Stop() - var got []string - for range want { - select { - case event := <-eventsChan: - t.Logf("event: %s", event) - got = append(got, event) - case <-tm.C: - t.Fatalf("timed out waiting for event: got %v, want %v", got, want) - } - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("events mismatch (-got +want):\n%s", diff) - } - } - eventHook := func(event string) { eventsChan <- event } - - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Clock = tstime.DefaultClock{Clock: clock} - m.opts.Dir = dir - m.opts.State = must.Get(mem.New(nil, "")) - must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) - fd.Init(&m, eventHook) - defer fd.Shutdown() - insert := func(name string) { - t.Helper() - t.Logf("insert: %v", name) - fd.Insert(name) - } - remove := func(name string) { - t.Helper() - t.Logf("remove: %v", name) - fd.Remove(name) - } - - checkEvents("start full-scan") - checkEvents("end full-scan", "start waitAndDelete") - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - - advance(deleteDelay / 2) - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - advance(deleteDelay / 2) - checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") - checkEvents("end waitAndDelete") - checkDirectory() - - must.Do(touchFile(filepath.Join(dir, "one.partial"))) - insert("one.partial") - checkEvents("start waitAndDelete") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "two.partial"))) - insert("two.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "three.partial"))) - insert("three.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "four.partial"))) - insert("four.partial") - - advance(deleteDelay / 4) - checkEvents("deleted one.partial") - checkDirectory("two.partial", "three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted two.partial") - checkDirectory("three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted three.partial") - checkDirectory("four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted four.partial") - checkDirectory() - checkEvents("end waitAndDelete") - - insert("wuzz.partial") - checkEvents("start waitAndDelete") - remove("wuzz.partial") - checkEvents("end waitAndDelete") -} - -// Test that the asynchronous full scan of the taildrop directory does not occur -// on a cold start if taildrop has never received any files. -func TestDeleterInitWithoutTaildrop(t *testing.T) { - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Dir = t.TempDir() - m.opts.State = must.Get(mem.New(nil, "")) - fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) - fd.Shutdown() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "os" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/util/must" +) + +func TestDeleter(t *testing.T) { + dir := t.TempDir() + must.Do(touchFile(filepath.Join(dir, "foo.partial"))) + must.Do(touchFile(filepath.Join(dir, "bar.partial"))) + must.Do(touchFile(filepath.Join(dir, "fizz"))) + must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) + must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file + + checkDirectory := func(want ...string) { + t.Helper() + var got []string + for _, de := range must.Get(os.ReadDir(dir)) { + got = append(got, de.Name()) + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("directory mismatch (-got +want):\n%s", diff) + } + } + + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) + advance := func(d time.Duration) { + t.Helper() + t.Logf("advance: %v", d) + clock.Advance(d) + } + + eventsChan := make(chan string, 1000) + checkEvents := func(want ...string) { + t.Helper() + tm := time.NewTimer(10 * time.Second) + defer tm.Stop() + var got []string + for range want { + select { + case event := <-eventsChan: + t.Logf("event: %s", event) + got = append(got, event) + case <-tm.C: + t.Fatalf("timed out waiting for event: got %v, want %v", got, want) + } + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("events mismatch (-got +want):\n%s", diff) + } + } + eventHook := func(event string) { eventsChan <- event } + + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Clock = tstime.DefaultClock{Clock: clock} + m.opts.Dir = dir + m.opts.State = must.Get(mem.New(nil, "")) + must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) + fd.Init(&m, eventHook) + defer fd.Shutdown() + insert := func(name string) { + t.Helper() + t.Logf("insert: %v", name) + fd.Insert(name) + } + remove := func(name string) { + t.Helper() + t.Logf("remove: %v", name) + fd.Remove(name) + } + + checkEvents("start full-scan") + checkEvents("end full-scan", "start waitAndDelete") + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + + advance(deleteDelay / 2) + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + advance(deleteDelay / 2) + checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") + checkEvents("end waitAndDelete") + checkDirectory() + + must.Do(touchFile(filepath.Join(dir, "one.partial"))) + insert("one.partial") + checkEvents("start waitAndDelete") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "two.partial"))) + insert("two.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "three.partial"))) + insert("three.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "four.partial"))) + insert("four.partial") + + advance(deleteDelay / 4) + checkEvents("deleted one.partial") + checkDirectory("two.partial", "three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted two.partial") + checkDirectory("three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted three.partial") + checkDirectory("four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted four.partial") + checkDirectory() + checkEvents("end waitAndDelete") + + insert("wuzz.partial") + checkEvents("start waitAndDelete") + remove("wuzz.partial") + checkEvents("end waitAndDelete") +} + +// Test that the asynchronous full scan of the taildrop directory does not occur +// on a cold start if taildrop has never received any files. +func TestDeleterInitWithoutTaildrop(t *testing.T) { + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Dir = t.TempDir() + m.opts.State = must.Get(mem.New(nil, "")) + fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) + fd.Shutdown() +} diff --git a/taildrop/resume_test.go b/taildrop/resume_test.go index d366340eb6efa..8758ddd29d48c 100644 --- a/taildrop/resume_test.go +++ b/taildrop/resume_test.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "bytes" - "io" - "math/rand" - "os" - "testing" - "testing/iotest" - - "tailscale.com/util/must" -) - -func TestResume(t *testing.T) { - oldBlockSize := blockSize - defer func() { blockSize = oldBlockSize }() - blockSize = 256 - - m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() - defer m.Shutdown() - - rn := rand.New(rand.NewSource(0)) - want := make([]byte, 12345) - must.Get(io.ReadFull(rn, want)) - - t.Run("resume-noexist", func(t *testing.T) { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "foo") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - must.Get(m.PutFile("", "foo", r, offset, -1)) - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) - - t.Run("resume-retry", func(t *testing.T) { - rn := rand.New(rand.NewSource(0)) - for i := 0; true; i++ { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "bar") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) - if offset < int64(len(want)) { - r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) - } - if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { - break - } - if i > 1000 { - t.Fatalf("too many iterations to complete the test") - } - } - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "io" + "math/rand" + "os" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestResume(t *testing.T) { + oldBlockSize := blockSize + defer func() { blockSize = oldBlockSize }() + blockSize = 256 + + m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() + defer m.Shutdown() + + rn := rand.New(rand.NewSource(0)) + want := make([]byte, 12345) + must.Get(io.ReadFull(rn, want)) + + t.Run("resume-noexist", func(t *testing.T) { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "foo") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + must.Get(m.PutFile("", "foo", r, offset, -1)) + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) + + t.Run("resume-retry", func(t *testing.T) { + rn := rand.New(rand.NewSource(0)) + for i := 0; true; i++ { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "bar") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) + if offset < int64(len(want)) { + r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) + } + if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { + break + } + if i > 1000 { + t.Fatalf("too many iterations to complete the test") + } + } + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) +} diff --git a/taildrop/retrieve.go b/taildrop/retrieve.go index 3e37b492adc0a..527f8caed2bf5 100644 --- a/taildrop/retrieve.go +++ b/taildrop/retrieve.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "context" - "errors" - "io" - "io/fs" - "os" - "path/filepath" - "runtime" - "sort" - "time" - - "tailscale.com/client/tailscale/apitype" - "tailscale.com/logtail/backoff" -) - -// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. -// This always returns false when [Handler.DirectFileMode] is false. -func (m *Manager) HasFilesWaiting() (has bool) { - if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { - return false - } - - // Optimization: this is usually empty, so avoid opening - // the directory and checking. We can't cache the actual - // has-files-or-not values as the macOS/iOS client might - // in the future use+delete the files directly. So only - // keep this negative cache. - totalReceived := m.totalReceived.Load() - if totalReceived == m.emptySince.Load() { - return false - } - - // Check whether there is at least one one waiting file. - err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - has = true - return false - } - return true - }) - - // If there are no more waiting files, record totalReceived as emptySince - // so that we can short-circuit the expensive directory traversal - // if no files have been received after the start of this call. - if err == nil && !has { - m.emptySince.Store(totalReceived) - } - return has -} - -// WaitingFiles returns the list of files that have been sent by a -// peer that are waiting in [Handler.Dir]. -// This always returns nil when [Handler.DirectFileMode] is false. -func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { - if m == nil || m.opts.Dir == "" { - return nil, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, nil - } - if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - fi, err := de.Info() - if err != nil { - return true - } - ret = append(ret, apitype.WaitingFile{ - Name: filepath.Base(name), - Size: fi.Size(), - }) - } - return true - }); err != nil { - return nil, redactError(err) - } - sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) - return ret, nil -} - -// DeleteFile deletes a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) DeleteFile(baseName string) error { - if m == nil || m.opts.Dir == "" { - return ErrNoTaildrop - } - if m.opts.DirectFileMode { - return errors.New("deletes not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return err - } - var bo *backoff.Backoff - logf := m.opts.Logf - t0 := m.opts.Clock.Now() - for { - err := os.Remove(path) - if err != nil && !os.IsNotExist(err) { - err = redactError(err) - // Put a retry loop around deletes on Windows. - // - // Windows file descriptor closes are effectively asynchronous, - // as a bunch of hooks run on/after close, - // and we can't necessarily delete the file for a while after close, - // as we need to wait for everybody to be done with it. - // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. - // So try a few times but ultimately just leave a "foo.jpg.deleted" - // marker file to note that it's deleted and we clean it up later. - if runtime.GOOS == "windows" { - if bo == nil { - bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) - } - if m.opts.Clock.Since(t0) < 5*time.Second { - bo.BackOff(context.Background(), err) - continue - } - if err := touchFile(path + deletedSuffix); err != nil { - logf("peerapi: failed to leave deleted marker: %v", err) - } - m.deleter.Insert(baseName + deletedSuffix) - } - logf("peerapi: failed to DeleteFile: %v", err) - return err - } - return nil - } -} - -func touchFile(path string) error { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) - if err != nil { - return redactError(err) - } - return f.Close() -} - -// OpenFile opens a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { - if m == nil || m.opts.Dir == "" { - return nil, 0, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, 0, errors.New("opens not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return nil, 0, err - } - if _, err := os.Stat(path + deletedSuffix); err == nil { - return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) - } - f, err := os.Open(path) - if err != nil { - return nil, 0, redactError(err) - } - fi, err := f.Stat() - if err != nil { - f.Close() - return nil, 0, redactError(err) - } - return f, fi.Size(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "context" + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "sort" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/logtail/backoff" +) + +// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. +// This always returns false when [Handler.DirectFileMode] is false. +func (m *Manager) HasFilesWaiting() (has bool) { + if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { + return false + } + + // Optimization: this is usually empty, so avoid opening + // the directory and checking. We can't cache the actual + // has-files-or-not values as the macOS/iOS client might + // in the future use+delete the files directly. So only + // keep this negative cache. + totalReceived := m.totalReceived.Load() + if totalReceived == m.emptySince.Load() { + return false + } + + // Check whether there is at least one one waiting file. + err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + has = true + return false + } + return true + }) + + // If there are no more waiting files, record totalReceived as emptySince + // so that we can short-circuit the expensive directory traversal + // if no files have been received after the start of this call. + if err == nil && !has { + m.emptySince.Store(totalReceived) + } + return has +} + +// WaitingFiles returns the list of files that have been sent by a +// peer that are waiting in [Handler.Dir]. +// This always returns nil when [Handler.DirectFileMode] is false. +func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { + if m == nil || m.opts.Dir == "" { + return nil, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, nil + } + if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + fi, err := de.Info() + if err != nil { + return true + } + ret = append(ret, apitype.WaitingFile{ + Name: filepath.Base(name), + Size: fi.Size(), + }) + } + return true + }); err != nil { + return nil, redactError(err) + } + sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) + return ret, nil +} + +// DeleteFile deletes a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) DeleteFile(baseName string) error { + if m == nil || m.opts.Dir == "" { + return ErrNoTaildrop + } + if m.opts.DirectFileMode { + return errors.New("deletes not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return err + } + var bo *backoff.Backoff + logf := m.opts.Logf + t0 := m.opts.Clock.Now() + for { + err := os.Remove(path) + if err != nil && !os.IsNotExist(err) { + err = redactError(err) + // Put a retry loop around deletes on Windows. + // + // Windows file descriptor closes are effectively asynchronous, + // as a bunch of hooks run on/after close, + // and we can't necessarily delete the file for a while after close, + // as we need to wait for everybody to be done with it. + // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. + // So try a few times but ultimately just leave a "foo.jpg.deleted" + // marker file to note that it's deleted and we clean it up later. + if runtime.GOOS == "windows" { + if bo == nil { + bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) + } + if m.opts.Clock.Since(t0) < 5*time.Second { + bo.BackOff(context.Background(), err) + continue + } + if err := touchFile(path + deletedSuffix); err != nil { + logf("peerapi: failed to leave deleted marker: %v", err) + } + m.deleter.Insert(baseName + deletedSuffix) + } + logf("peerapi: failed to DeleteFile: %v", err) + return err + } + return nil + } +} + +func touchFile(path string) error { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + return redactError(err) + } + return f.Close() +} + +// OpenFile opens a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { + if m == nil || m.opts.Dir == "" { + return nil, 0, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, 0, errors.New("opens not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return nil, 0, err + } + if _, err := os.Stat(path + deletedSuffix); err == nil { + return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) + } + f, err := os.Open(path) + if err != nil { + return nil, 0, redactError(err) + } + fi, err := f.Stat() + if err != nil { + f.Close() + return nil, 0, redactError(err) + } + return f, fi.Size(), nil +} diff --git a/tempfork/gliderlabs/ssh/LICENSE b/tempfork/gliderlabs/ssh/LICENSE index 4a03f02a28185..80b2b2baa7d2f 100644 --- a/tempfork/gliderlabs/ssh/LICENSE +++ b/tempfork/gliderlabs/ssh/LICENSE @@ -1,27 +1,27 @@ -Copyright (c) 2016 Glider Labs. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Glider Labs nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Copyright (c) 2016 Glider Labs. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Glider Labs nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tempfork/gliderlabs/ssh/README.md b/tempfork/gliderlabs/ssh/README.md index 79b5b89fa8a94..ecef6b7c47895 100644 --- a/tempfork/gliderlabs/ssh/README.md +++ b/tempfork/gliderlabs/ssh/README.md @@ -1,96 +1,96 @@ -# gliderlabs/ssh - -[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) -[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) -[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) -[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) -[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) -[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) - -> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member - -This Go package wraps the [crypto/ssh -package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for -building SSH servers. The goal of the API was to make it as simple as using -[net/http](https://golang.org/pkg/net/http/), so the API is very similar: - -```go - package main - - import ( - "tailscale.com/tempfork/gliderlabs/ssh" - "io" - "log" - ) - - func main() { - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - } - -``` -This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). - -## Examples - -A bunch of great examples are in the `_examples` directory. - -## Usage - -[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) - -## Contributing - -Pull requests are welcome! However, since this project is very much about API -design, please submit API changes as issues to discuss before submitting PRs. - -Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. - -## Roadmap - -* Non-session channel handlers -* Cleanup callback API -* 1.0 release -* High-level client? - -## Sponsors - -Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -## License - -[BSD](LICENSE) +# gliderlabs/ssh + +[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) +[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) +[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) +[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) +[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) +[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) + +> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member + +This Go package wraps the [crypto/ssh +package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for +building SSH servers. The goal of the API was to make it as simple as using +[net/http](https://golang.org/pkg/net/http/), so the API is very similar: + +```go + package main + + import ( + "tailscale.com/tempfork/gliderlabs/ssh" + "io" + "log" + ) + + func main() { + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + } + +``` +This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). + +## Examples + +A bunch of great examples are in the `_examples` directory. + +## Usage + +[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) + +## Contributing + +Pull requests are welcome! However, since this project is very much about API +design, please submit API changes as issues to discuss before submitting PRs. + +Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. + +## Roadmap + +* Non-session channel handlers +* Cleanup callback API +* 1.0 release +* High-level client? + +## Sponsors + +Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +## License + +[BSD](LICENSE) diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go index 86a5bce7f8ebc..3da665292a447 100644 --- a/tempfork/gliderlabs/ssh/agent.go +++ b/tempfork/gliderlabs/ssh/agent.go @@ -1,83 +1,83 @@ -package ssh - -import ( - "io" - "net" - "os" - "path" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - agentRequestType = "auth-agent-req@openssh.com" - agentChannelType = "auth-agent@openssh.com" - - agentTempDir = "auth-agent" - agentListenFile = "listener.sock" -) - -// contextKeyAgentRequest is an internal context key for storing if the -// client requested agent forwarding -var contextKeyAgentRequest = &contextKey{"auth-agent-req"} - -// SetAgentRequested sets up the session context so that AgentRequested -// returns true. -func SetAgentRequested(ctx Context) { - ctx.SetValue(contextKeyAgentRequest, true) -} - -// AgentRequested returns true if the client requested agent forwarding. -func AgentRequested(sess Session) bool { - return sess.Context().Value(contextKeyAgentRequest) == true -} - -// NewAgentListener sets up a temporary Unix socket that can be communicated -// to the session environment and used for forwarding connections. -func NewAgentListener() (net.Listener, error) { - dir, err := os.MkdirTemp("", agentTempDir) - if err != nil { - return nil, err - } - l, err := net.Listen("unix", path.Join(dir, agentListenFile)) - if err != nil { - return nil, err - } - return l, nil -} - -// ForwardAgentConnections takes connections from a listener to proxy into the -// session on the OpenSSH channel for agent connections. It blocks and services -// connections until the listener stop accepting. -func ForwardAgentConnections(l net.Listener, s Session) { - sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) - for { - conn, err := l.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) - if err != nil { - return - } - defer channel.Close() - go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) - go func() { - io.Copy(conn, channel) - conn.(*net.UnixConn).CloseWrite() - wg.Done() - }() - go func() { - io.Copy(channel, conn) - channel.CloseWrite() - wg.Done() - }() - wg.Wait() - }(conn) - } -} +package ssh + +import ( + "io" + "net" + "os" + "path" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + agentRequestType = "auth-agent-req@openssh.com" + agentChannelType = "auth-agent@openssh.com" + + agentTempDir = "auth-agent" + agentListenFile = "listener.sock" +) + +// contextKeyAgentRequest is an internal context key for storing if the +// client requested agent forwarding +var contextKeyAgentRequest = &contextKey{"auth-agent-req"} + +// SetAgentRequested sets up the session context so that AgentRequested +// returns true. +func SetAgentRequested(ctx Context) { + ctx.SetValue(contextKeyAgentRequest, true) +} + +// AgentRequested returns true if the client requested agent forwarding. +func AgentRequested(sess Session) bool { + return sess.Context().Value(contextKeyAgentRequest) == true +} + +// NewAgentListener sets up a temporary Unix socket that can be communicated +// to the session environment and used for forwarding connections. +func NewAgentListener() (net.Listener, error) { + dir, err := os.MkdirTemp("", agentTempDir) + if err != nil { + return nil, err + } + l, err := net.Listen("unix", path.Join(dir, agentListenFile)) + if err != nil { + return nil, err + } + return l, nil +} + +// ForwardAgentConnections takes connections from a listener to proxy into the +// session on the OpenSSH channel for agent connections. It blocks and services +// connections until the listener stop accepting. +func ForwardAgentConnections(l net.Listener, s Session) { + sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) + for { + conn, err := l.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) + if err != nil { + return + } + defer channel.Close() + go gossh.DiscardRequests(reqs) + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + wg.Wait() + }(conn) + } +} diff --git a/tempfork/gliderlabs/ssh/conn.go b/tempfork/gliderlabs/ssh/conn.go index ebef8845baccb..ec277bf27676f 100644 --- a/tempfork/gliderlabs/ssh/conn.go +++ b/tempfork/gliderlabs/ssh/conn.go @@ -1,55 +1,55 @@ -package ssh - -import ( - "context" - "net" - "time" -) - -type serverConn struct { - net.Conn - - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Write(p) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Read(b) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Close() (err error) { - err = c.Conn.Close() - if c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: - idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return - } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) - } -} +package ssh + +import ( + "context" + "net" + "time" +) + +type serverConn struct { + net.Conn + + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + switch { + case c.idleTimeout > 0: + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(idleDeadline) + return + } + fallthrough + default: + c.Conn.SetDeadline(c.maxDeadline) + } +} diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go index d43de6f09c8a5..6f7245574060d 100644 --- a/tempfork/gliderlabs/ssh/context.go +++ b/tempfork/gliderlabs/ssh/context.go @@ -1,164 +1,164 @@ -package ssh - -import ( - "context" - "encoding/hex" - "net" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// contextKey is a value for use with context.WithValue. It's used as -// a pointer so it fits in an interface{} without allocation. -type contextKey struct { - name string -} - -var ( - // ContextKeyUser is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyUser = &contextKey{"user"} - - // ContextKeySessionID is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeySessionID = &contextKey{"session-id"} - - // ContextKeyPermissions is a context key for use with Contexts in this package. - // The associated value will be of type *Permissions. - ContextKeyPermissions = &contextKey{"permissions"} - - // ContextKeyClientVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyClientVersion = &contextKey{"client-version"} - - // ContextKeyServerVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyServerVersion = &contextKey{"server-version"} - - // ContextKeyLocalAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyLocalAddr = &contextKey{"local-addr"} - - // ContextKeyRemoteAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyRemoteAddr = &contextKey{"remote-addr"} - - // ContextKeyServer is a context key for use with Contexts in this package. - // The associated value will be of type *Server. - ContextKeyServer = &contextKey{"ssh-server"} - - // ContextKeyConn is a context key for use with Contexts in this package. - // The associated value will be of type gossh.ServerConn. - ContextKeyConn = &contextKey{"ssh-conn"} - - // ContextKeyPublicKey is a context key for use with Contexts in this package. - // The associated value will be of type PublicKey. - ContextKeyPublicKey = &contextKey{"public-key"} - - ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} -) - -// Context is a package specific context interface. It exposes connection -// metadata and allows new values to be easily written to it. It's used in -// authentication handlers and callbacks, and its underlying context.Context is -// exposed on Session in the session Handler. A connection-scoped lock is also -// embedded in the context to make it easier to limit operations per-connection. -type Context interface { - context.Context - sync.Locker - - // User returns the username used when establishing the SSH connection. - User() string - - // SessionID returns the session hash. - SessionID() string - - // ClientVersion returns the version reported by the client. - ClientVersion() string - - // ServerVersion returns the version reported by the server. - ServerVersion() string - - // RemoteAddr returns the remote address for this connection. - RemoteAddr() net.Addr - - // LocalAddr returns the local address for this connection. - LocalAddr() net.Addr - - // Permissions returns the Permissions object used for this connection. - Permissions() *Permissions - - // SetValue allows you to easily write new values into the underlying context. - SetValue(key, value interface{}) - - SendAuthBanner(banner string) error -} - -type sshContext struct { - context.Context - *sync.Mutex -} - -func newContext(srv *Server) (*sshContext, context.CancelFunc) { - innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} - ctx.SetValue(ContextKeyServer, srv) - perms := &Permissions{&gossh.Permissions{}} - ctx.SetValue(ContextKeyPermissions, perms) - return ctx, cancel -} - -// this is separate from newContext because we will get ConnMetadata -// at different points so it needs to be applied separately -func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { - if ctx.Value(ContextKeySessionID) != nil { - return - } - ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) - ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) - ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) - ctx.SetValue(ContextKeyUser, conn.User()) - ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) - ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) - ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) -} - -func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) -} - -func (ctx *sshContext) User() string { - return ctx.Value(ContextKeyUser).(string) -} - -func (ctx *sshContext) SessionID() string { - return ctx.Value(ContextKeySessionID).(string) -} - -func (ctx *sshContext) ClientVersion() string { - return ctx.Value(ContextKeyClientVersion).(string) -} - -func (ctx *sshContext) ServerVersion() string { - return ctx.Value(ContextKeyServerVersion).(string) -} - -func (ctx *sshContext) RemoteAddr() net.Addr { - if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { - return addr - } - return nil -} - -func (ctx *sshContext) LocalAddr() net.Addr { - return ctx.Value(ContextKeyLocalAddr).(net.Addr) -} - -func (ctx *sshContext) Permissions() *Permissions { - return ctx.Value(ContextKeyPermissions).(*Permissions) -} - -func (ctx *sshContext) SendAuthBanner(msg string) error { - return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) -} +package ssh + +import ( + "context" + "encoding/hex" + "net" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +var ( + // ContextKeyUser is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyUser = &contextKey{"user"} + + // ContextKeySessionID is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeySessionID = &contextKey{"session-id"} + + // ContextKeyPermissions is a context key for use with Contexts in this package. + // The associated value will be of type *Permissions. + ContextKeyPermissions = &contextKey{"permissions"} + + // ContextKeyClientVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyClientVersion = &contextKey{"client-version"} + + // ContextKeyServerVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyServerVersion = &contextKey{"server-version"} + + // ContextKeyLocalAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyLocalAddr = &contextKey{"local-addr"} + + // ContextKeyRemoteAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyRemoteAddr = &contextKey{"remote-addr"} + + // ContextKeyServer is a context key for use with Contexts in this package. + // The associated value will be of type *Server. + ContextKeyServer = &contextKey{"ssh-server"} + + // ContextKeyConn is a context key for use with Contexts in this package. + // The associated value will be of type gossh.ServerConn. + ContextKeyConn = &contextKey{"ssh-conn"} + + // ContextKeyPublicKey is a context key for use with Contexts in this package. + // The associated value will be of type PublicKey. + ContextKeyPublicKey = &contextKey{"public-key"} + + ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} +) + +// Context is a package specific context interface. It exposes connection +// metadata and allows new values to be easily written to it. It's used in +// authentication handlers and callbacks, and its underlying context.Context is +// exposed on Session in the session Handler. A connection-scoped lock is also +// embedded in the context to make it easier to limit operations per-connection. +type Context interface { + context.Context + sync.Locker + + // User returns the username used when establishing the SSH connection. + User() string + + // SessionID returns the session hash. + SessionID() string + + // ClientVersion returns the version reported by the client. + ClientVersion() string + + // ServerVersion returns the version reported by the server. + ServerVersion() string + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr + + // Permissions returns the Permissions object used for this connection. + Permissions() *Permissions + + // SetValue allows you to easily write new values into the underlying context. + SetValue(key, value interface{}) + + SendAuthBanner(banner string) error +} + +type sshContext struct { + context.Context + *sync.Mutex +} + +func newContext(srv *Server) (*sshContext, context.CancelFunc) { + innerCtx, cancel := context.WithCancel(context.Background()) + ctx := &sshContext{innerCtx, &sync.Mutex{}} + ctx.SetValue(ContextKeyServer, srv) + perms := &Permissions{&gossh.Permissions{}} + ctx.SetValue(ContextKeyPermissions, perms) + return ctx, cancel +} + +// this is separate from newContext because we will get ConnMetadata +// at different points so it needs to be applied separately +func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { + if ctx.Value(ContextKeySessionID) != nil { + return + } + ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) + ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) + ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) + ctx.SetValue(ContextKeyUser, conn.User()) + ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) + ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) + ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) +} + +func (ctx *sshContext) SetValue(key, value interface{}) { + ctx.Context = context.WithValue(ctx.Context, key, value) +} + +func (ctx *sshContext) User() string { + return ctx.Value(ContextKeyUser).(string) +} + +func (ctx *sshContext) SessionID() string { + return ctx.Value(ContextKeySessionID).(string) +} + +func (ctx *sshContext) ClientVersion() string { + return ctx.Value(ContextKeyClientVersion).(string) +} + +func (ctx *sshContext) ServerVersion() string { + return ctx.Value(ContextKeyServerVersion).(string) +} + +func (ctx *sshContext) RemoteAddr() net.Addr { + if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { + return addr + } + return nil +} + +func (ctx *sshContext) LocalAddr() net.Addr { + return ctx.Value(ContextKeyLocalAddr).(net.Addr) +} + +func (ctx *sshContext) Permissions() *Permissions { + return ctx.Value(ContextKeyPermissions).(*Permissions) +} + +func (ctx *sshContext) SendAuthBanner(msg string) error { + return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) +} diff --git a/tempfork/gliderlabs/ssh/context_test.go b/tempfork/gliderlabs/ssh/context_test.go index dcbd326b77809..8f71c395841c9 100644 --- a/tempfork/gliderlabs/ssh/context_test.go +++ b/tempfork/gliderlabs/ssh/context_test.go @@ -1,49 +1,49 @@ -//go:build glidertests - -package ssh - -import "testing" - -func TestSetPermissions(t *testing.T) { - t.Parallel() - permsExt := map[string]string{ - "foo": "bar", - } - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - if _, ok := s.Permissions().Extensions["foo"]; !ok { - t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.Permissions().Extensions = permsExt - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestSetValue(t *testing.T) { - t.Parallel() - value := map[string]string{ - "foo": "bar", - } - key := "testValue" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - v := s.Context().Value(key).(map[string]string) - if v["foo"] != value["foo"] { - t.Fatalf("got %#v; want %#v", v, value) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.SetValue(key, value) - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} +//go:build glidertests + +package ssh + +import "testing" + +func TestSetPermissions(t *testing.T) { + t.Parallel() + permsExt := map[string]string{ + "foo": "bar", + } + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + if _, ok := s.Permissions().Extensions["foo"]; !ok { + t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.Permissions().Extensions = permsExt + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestSetValue(t *testing.T) { + t.Parallel() + value := map[string]string{ + "foo": "bar", + } + key := "testValue" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + v := s.Context().Value(key).(map[string]string) + if v["foo"] != value["foo"] { + t.Fatalf("got %#v; want %#v", v, value) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.SetValue(key, value) + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} diff --git a/tempfork/gliderlabs/ssh/doc.go b/tempfork/gliderlabs/ssh/doc.go index d139191768d55..46c47d650a06c 100644 --- a/tempfork/gliderlabs/ssh/doc.go +++ b/tempfork/gliderlabs/ssh/doc.go @@ -1,45 +1,45 @@ -/* -Package ssh wraps the crypto/ssh package with a higher-level API for building -SSH servers. The goal of the API was to make it as simple as using net/http, so -the API is very similar. - -You should be able to build any SSH server using only this package, which wraps -relevant types and some functions from crypto/ssh. However, you still need to -use crypto/ssh for building SSH clients. - -ListenAndServe starts an SSH server with a given address, handler, and options. The -handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: - - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - -If you don't specify a host key, it will generate one every time. This is convenient -except you'll have to deal with clients being confused that the host key is different. -It's a better idea to generate or point to an existing key on your system: - - log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) - -Although all options have functional option helpers, another way to control the -server's behavior is by creating a custom Server: - - s := &ssh.Server{ - Addr: ":2222", - Handler: sessionHandler, - PublicKeyHandler: authHandler, - } - s.AddHostKey(hostKeySigner) - - log.Fatal(s.ListenAndServe()) - -This package automatically handles basic SSH requests like setting environment -variables, requesting PTY, and changing window size. These requests are -processed, responded to, and any relevant state is updated. This state is then -exposed to you via the Session interface. - -The one big feature missing from the Session abstraction is signals. This was -started, but not completed. Pull Requests welcome! -*/ -package ssh +/* +Package ssh wraps the crypto/ssh package with a higher-level API for building +SSH servers. The goal of the API was to make it as simple as using net/http, so +the API is very similar. + +You should be able to build any SSH server using only this package, which wraps +relevant types and some functions from crypto/ssh. However, you still need to +use crypto/ssh for building SSH clients. + +ListenAndServe starts an SSH server with a given address, handler, and options. The +handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: + + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + +If you don't specify a host key, it will generate one every time. This is convenient +except you'll have to deal with clients being confused that the host key is different. +It's a better idea to generate or point to an existing key on your system: + + log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) + +Although all options have functional option helpers, another way to control the +server's behavior is by creating a custom Server: + + s := &ssh.Server{ + Addr: ":2222", + Handler: sessionHandler, + PublicKeyHandler: authHandler, + } + s.AddHostKey(hostKeySigner) + + log.Fatal(s.ListenAndServe()) + +This package automatically handles basic SSH requests like setting environment +variables, requesting PTY, and changing window size. These requests are +processed, responded to, and any relevant state is updated. This state is then +exposed to you via the Session interface. + +The one big feature missing from the Session abstraction is signals. This was +started, but not completed. Pull Requests welcome! +*/ +package ssh diff --git a/tempfork/gliderlabs/ssh/example_test.go b/tempfork/gliderlabs/ssh/example_test.go index c174bc4ae190e..61ffebbc045dc 100644 --- a/tempfork/gliderlabs/ssh/example_test.go +++ b/tempfork/gliderlabs/ssh/example_test.go @@ -1,50 +1,50 @@ -package ssh_test - -import ( - "errors" - "io" - "os" - - "tailscale.com/tempfork/gliderlabs/ssh" -) - -func ExampleListenAndServe() { - ssh.ListenAndServe(":2222", func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) -} - -func ExamplePasswordAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { - return pass == "secret" - }), - ) -} - -func ExampleNoPty() { - ssh.ListenAndServe(":2222", nil, ssh.NoPty()) -} - -func ExamplePublicKeyAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { - data, err := os.ReadFile("/path/to/allowed/key.pub") - if err != nil { - return err - } - allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) - if err != nil { - return err - } - if !ssh.KeysEqual(key, allowed) { - return errors.New("some error") - } - return nil - }), - ) -} - -func ExampleHostKeyFile() { - ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) -} +package ssh_test + +import ( + "errors" + "io" + "os" + + "tailscale.com/tempfork/gliderlabs/ssh" +) + +func ExampleListenAndServe() { + ssh.ListenAndServe(":2222", func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) +} + +func ExamplePasswordAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { + return pass == "secret" + }), + ) +} + +func ExampleNoPty() { + ssh.ListenAndServe(":2222", nil, ssh.NoPty()) +} + +func ExamplePublicKeyAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { + data, err := os.ReadFile("/path/to/allowed/key.pub") + if err != nil { + return err + } + allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) + if err != nil { + return err + } + if !ssh.KeysEqual(key, allowed) { + return errors.New("some error") + } + return nil + }), + ) +} + +func ExampleHostKeyFile() { + ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) +} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go index aa87a4f39db9e..bb24909bebd2a 100644 --- a/tempfork/gliderlabs/ssh/options.go +++ b/tempfork/gliderlabs/ssh/options.go @@ -1,84 +1,84 @@ -package ssh - -import ( - "os" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// PasswordAuth returns a functional option that sets PasswordHandler on the server. -func PasswordAuth(fn PasswordHandler) Option { - return func(srv *Server) error { - srv.PasswordHandler = fn - return nil - } -} - -// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. -func PublicKeyAuth(fn PublicKeyHandler) Option { - return func(srv *Server) error { - srv.PublicKeyHandler = fn - return nil - } -} - -// HostKeyFile returns a functional option that adds HostSigners to the server -// from a PEM file at filepath. -func HostKeyFile(filepath string) Option { - return func(srv *Server) error { - pemBytes, err := os.ReadFile(filepath) - if err != nil { - return err - } - - signer, err := gossh.ParsePrivateKey(pemBytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { - return func(srv *Server) error { - srv.KeyboardInteractiveHandler = fn - return nil - } -} - -// HostKeyPEM returns a functional option that adds HostSigners to the server -// from a PEM file as bytes. -func HostKeyPEM(bytes []byte) Option { - return func(srv *Server) error { - signer, err := gossh.ParsePrivateKey(bytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -// NoPty returns a functional option that sets PtyCallback to return false, -// denying PTY requests. -func NoPty() Option { - return func(srv *Server) error { - srv.PtyCallback = func(ctx Context, pty Pty) bool { - return false - } - return nil - } -} - -// WrapConn returns a functional option that sets ConnCallback on the server. -func WrapConn(fn ConnCallback) Option { - return func(srv *Server) error { - srv.ConnCallback = fn - return nil - } -} +package ssh + +import ( + "os" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// PasswordAuth returns a functional option that sets PasswordHandler on the server. +func PasswordAuth(fn PasswordHandler) Option { + return func(srv *Server) error { + srv.PasswordHandler = fn + return nil + } +} + +// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. +func PublicKeyAuth(fn PublicKeyHandler) Option { + return func(srv *Server) error { + srv.PublicKeyHandler = fn + return nil + } +} + +// HostKeyFile returns a functional option that adds HostSigners to the server +// from a PEM file at filepath. +func HostKeyFile(filepath string) Option { + return func(srv *Server) error { + pemBytes, err := os.ReadFile(filepath) + if err != nil { + return err + } + + signer, err := gossh.ParsePrivateKey(pemBytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { + return func(srv *Server) error { + srv.KeyboardInteractiveHandler = fn + return nil + } +} + +// HostKeyPEM returns a functional option that adds HostSigners to the server +// from a PEM file as bytes. +func HostKeyPEM(bytes []byte) Option { + return func(srv *Server) error { + signer, err := gossh.ParsePrivateKey(bytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +// NoPty returns a functional option that sets PtyCallback to return false, +// denying PTY requests. +func NoPty() Option { + return func(srv *Server) error { + srv.PtyCallback = func(ctx Context, pty Pty) bool { + return false + } + return nil + } +} + +// WrapConn returns a functional option that sets ConnCallback on the server. +func WrapConn(fn ConnCallback) Option { + return func(srv *Server) error { + srv.ConnCallback = fn + return nil + } +} diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 7cf6f376c6a88..3aa2f1cf5e31b 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -1,111 +1,111 @@ -//go:build glidertests - -package ssh - -import ( - "net" - "strings" - "sync/atomic" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { - for _, option := range options { - if err := srv.SetOption(option); err != nil { - t.Fatal(err) - } - } - return newTestSession(t, srv, cfg) -} - -func TestPasswordAuth(t *testing.T) { - t.Parallel() - testUser := "testuser" - testPass := "testpass" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, &gossh.ClientConfig{ - User: testUser, - Auth: []gossh.AuthMethod{ - gossh.Password(testPass), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - if ctx.User() != testUser { - t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) - } - if password != testPass { - t.Fatalf("user = %#v; want %#v", password, testPass) - } - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestPasswordAuthBadPass(t *testing.T) { - t.Parallel() - l := newLocalListener() - srv := &Server{Handler: func(s Session) {}} - srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { - return false - })) - go srv.serveOnce(l) - _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - if !strings.Contains(err.Error(), "unable to authenticate") { - t.Fatal(err) - } - } -} - -type wrappedConn struct { - net.Conn - written int32 -} - -func (c *wrappedConn) Write(p []byte) (n int, err error) { - n, err = c.Conn.Write(p) - atomic.AddInt32(&(c.written), int32(n)) - return -} - -func TestConnWrapping(t *testing.T) { - t.Parallel() - var wrapped *wrappedConn - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // nothing - }, - }, &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - return true - }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { - wrapped = &wrappedConn{conn, 0} - return wrapped - })) - defer cleanup() - if err := session.Shell(); err != nil { - t.Fatal(err) - } - if atomic.LoadInt32(&(wrapped.written)) == 0 { - t.Fatal("wrapped conn not written to") - } -} +//go:build glidertests + +package ssh + +import ( + "net" + "strings" + "sync/atomic" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { + for _, option := range options { + if err := srv.SetOption(option); err != nil { + t.Fatal(err) + } + } + return newTestSession(t, srv, cfg) +} + +func TestPasswordAuth(t *testing.T) { + t.Parallel() + testUser := "testuser" + testPass := "testpass" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, &gossh.ClientConfig{ + User: testUser, + Auth: []gossh.AuthMethod{ + gossh.Password(testPass), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + if ctx.User() != testUser { + t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) + } + if password != testPass { + t.Fatalf("user = %#v; want %#v", password, testPass) + } + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestPasswordAuthBadPass(t *testing.T) { + t.Parallel() + l := newLocalListener() + srv := &Server{Handler: func(s Session) {}} + srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { + return false + })) + go srv.serveOnce(l) + _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + if !strings.Contains(err.Error(), "unable to authenticate") { + t.Fatal(err) + } + } +} + +type wrappedConn struct { + net.Conn + written int32 +} + +func (c *wrappedConn) Write(p []byte) (n int, err error) { + n, err = c.Conn.Write(p) + atomic.AddInt32(&(c.written), int32(n)) + return +} + +func TestConnWrapping(t *testing.T) { + t.Parallel() + var wrapped *wrappedConn + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // nothing + }, + }, &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + return true + }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { + wrapped = &wrappedConn{conn, 0} + return wrapped + })) + defer cleanup() + if err := session.Shell(); err != nil { + t.Fatal(err) + } + if atomic.LoadInt32(&(wrapped.written)) == 0 { + t.Fatal("wrapped conn not written to") + } +} diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 1086a72caf0e5..32f633e87b58e 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -1,459 +1,459 @@ -package ssh - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// ErrServerClosed is returned by the Server's Serve, ListenAndServe, -// and ListenAndServeTLS methods after a call to Shutdown or Close. -var ErrServerClosed = errors.New("ssh: Server closed") - -type SubsystemHandler func(s Session) - -var DefaultSubsystemHandlers = map[string]SubsystemHandler{} - -type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) - -var DefaultRequestHandlers = map[string]RequestHandler{} - -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) - -var DefaultChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, -} - -// Server defines parameters for running an SSH server. The zero value for -// Server is a valid configuration. When both PasswordHandler and -// PublicKeyHandler are nil, no client authentication is performed. -type Server struct { - Addr string // TCP address to listen on, ":22" if empty - Handler Handler // handler to invoke, ssh.DefaultHandler if nil - HostSigners []Signer // private keys for the host key, must have at least one - Version string // server version to be sent before the initial handshake - - KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - NoClientAuthHandler NoClientAuthHandler // no client authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling - LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil - ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil - ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options - SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions - - ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures - - IdleTimeout time.Duration // connection timeout when no activity, none if empty - MaxTimeout time.Duration // absolute connection timeout, none if empty - - // ChannelHandlers allow overriding the built-in session handlers or provide - // extensions to the protocol, such as tcpip forwarding. By default only the - // "session" handler is enabled. - ChannelHandlers map[string]ChannelHandler - - // RequestHandlers allow overriding the server-level request handlers or - // provide extensions to the protocol, such as tcpip forwarding. By default - // no handlers are enabled. - RequestHandlers map[string]RequestHandler - - // SubsystemHandlers are handlers which are similar to the usual SSH command - // handlers, but handle named subsystems. - SubsystemHandlers map[string]SubsystemHandler - - listenerWg sync.WaitGroup - mu sync.RWMutex - listeners map[net.Listener]struct{} - conns map[*gossh.ServerConn]struct{} - connWg sync.WaitGroup - doneChan chan struct{} -} - -func (srv *Server) ensureHostSigner() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - if len(srv.HostSigners) == 0 { - signer, err := generateSigner() - if err != nil { - return err - } - srv.HostSigners = append(srv.HostSigners, signer) - } - return nil -} - -func (srv *Server) ensureHandlers() { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } - } - if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } - } - if srv.SubsystemHandlers == nil { - srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } - } -} - -func (srv *Server) config(ctx Context) *gossh.ServerConfig { - srv.mu.RLock() - defer srv.mu.RUnlock() - - var config *gossh.ServerConfig - if srv.ServerConfigCallback == nil { - config = &gossh.ServerConfig{} - } else { - config = srv.ServerConfigCallback(ctx) - } - for _, signer := range srv.HostSigners { - config.AddHostKey(signer) - } - if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { - config.NoClientAuth = true - } - if srv.Version != "" { - config.ServerVersion = "SSH-2.0-" + srv.Version - } - if srv.PasswordHandler != nil { - config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.PasswordHandler(ctx, string(password)); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.PublicKeyHandler != nil { - config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.PublicKeyHandler(ctx, key); err != nil { - return ctx.Permissions().Permissions, err - } - ctx.SetValue(ContextKeyPublicKey, key) - return ctx.Permissions().Permissions, nil - } - } - if srv.KeyboardInteractiveHandler != nil { - config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.NoClientAuthHandler != nil { - config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.NoClientAuthHandler(ctx); err != nil { - return ctx.Permissions().Permissions, err - } - return ctx.Permissions().Permissions, nil - } - } - return config -} - -// Handle sets the Handler for the server. -func (srv *Server) Handle(fn Handler) { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.Handler = fn -} - -// Close immediately closes all active listeners and all active -// connections. -// -// Close returns any error returned from closing the Server's -// underlying Listener(s). -func (srv *Server) Close() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.closeDoneChanLocked() - err := srv.closeListenersLocked() - for c := range srv.conns { - c.Close() - delete(srv.conns, c) - } - return err -} - -// Shutdown gracefully shuts down the server without interrupting any -// active connections. Shutdown works by first closing all open -// listeners, and then waiting indefinitely for connections to close. -// If the provided context expires before the shutdown is complete, -// then the context's error is returned. -func (srv *Server) Shutdown(ctx context.Context) error { - srv.mu.Lock() - lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() - srv.mu.Unlock() - - finished := make(chan struct{}, 1) - go func() { - srv.listenerWg.Wait() - srv.connWg.Wait() - finished <- struct{}{} - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-finished: - return lnerr - } -} - -// Serve accepts incoming connections on the Listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and then -// calls srv.Handler to handle sessions. -// -// Serve always returns a non-nil error. -func (srv *Server) Serve(l net.Listener) error { - srv.ensureHandlers() - defer l.Close() - if err := srv.ensureHostSigner(); err != nil { - return err - } - if srv.Handler == nil { - srv.Handler = DefaultHandler - } - var tempDelay time.Duration - - srv.trackListener(l, true) - defer srv.trackListener(l, false) - for { - conn, e := l.Accept() - if e != nil { - select { - case <-srv.getDoneChan(): - return ErrServerClosed - default: - } - if ne, ok := e.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - return e - } - go srv.HandleConn(conn) - } -} - -func (srv *Server) HandleConn(newConn net.Conn) { - ctx, cancel := newContext(srv) - if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(ctx, newConn) - if cbConn == nil { - newConn.Close() - return - } - newConn = cbConn - } - conn := &serverConn{ - Conn: newConn, - idleTimeout: srv.IdleTimeout, - closeCanceler: cancel, - } - if srv.MaxTimeout > 0 { - conn.maxDeadline = time.Now().Add(srv.MaxTimeout) - } - defer conn.Close() - sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) - if err != nil { - if srv.ConnectionFailedCallback != nil { - srv.ConnectionFailedCallback(conn, err) - } - return - } - - srv.trackConn(sshConn, true) - defer srv.trackConn(sshConn, false) - - ctx.SetValue(ContextKeyConn, sshConn) - applyConnMetadata(ctx, sshConn) - //go gossh.DiscardRequests(reqs) - go srv.handleRequests(ctx, reqs) - for ch := range chans { - handler := srv.ChannelHandlers[ch.ChannelType()] - if handler == nil { - handler = srv.ChannelHandlers["default"] - } - if handler == nil { - ch.Reject(gossh.UnknownChannelType, "unsupported channel type") - continue - } - go handler(srv, sshConn, ch, ctx) - } -} - -func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { - for req := range in { - handler := srv.RequestHandlers[req.Type] - if handler == nil { - handler = srv.RequestHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - /*reqCtx, cancel := context.WithCancel(ctx) - defer cancel() */ - ret, payload := handler(ctx, srv, req) - req.Reply(ret, payload) - } -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls -// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. -// ListenAndServe always returns a non-nil error. -func (srv *Server) ListenAndServe() error { - addr := srv.Addr - if addr == "" { - addr = ":22" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - return err - } - return srv.Serve(ln) -} - -// AddHostKey adds a private key as a host key. If an existing host key exists -// with the same algorithm, it is overwritten. Each server config must have at -// least one host key. -func (srv *Server) AddHostKey(key Signer) { - srv.mu.Lock() - defer srv.mu.Unlock() - - // these are later added via AddHostKey on ServerConfig, which performs the - // check for one of every algorithm. - - // This check is based on the AddHostKey method from the x/crypto/ssh - // library. This allows us to only keep one active key for each type on a - // server at once. So, if you're dynamically updating keys at runtime, this - // list will not keep growing. - for i, k := range srv.HostSigners { - if k.PublicKey().Type() == key.PublicKey().Type() { - srv.HostSigners[i] = key - return - } - } - - srv.HostSigners = append(srv.HostSigners, key) -} - -// SetOption runs a functional option against the server. -func (srv *Server) SetOption(option Option) error { - // NOTE: there is a potential race here for any option that doesn't call an - // internal method. We can't actually lock here because if something calls - // (as an example) AddHostKey, it will deadlock. - - //srv.mu.Lock() - //defer srv.mu.Unlock() - - return option(srv) -} - -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - -func (srv *Server) closeListenersLocked() error { - var err error - for ln := range srv.listeners { - if cerr := ln.Close(); cerr != nil && err == nil { - err = cerr - } - delete(srv.listeners, ln) - } - return err -} - -func (srv *Server) trackListener(ln net.Listener, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.listeners == nil { - srv.listeners = make(map[net.Listener]struct{}) - } - if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil - } - srv.listeners[ln] = struct{}{} - srv.listenerWg.Add(1) - } else { - delete(srv.listeners, ln) - srv.listenerWg.Done() - } -} - -func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.conns == nil { - srv.conns = make(map[*gossh.ServerConn]struct{}) - } - if add { - srv.conns[c] = struct{}{} - srv.connWg.Add(1) - } else { - delete(srv.conns, c) - srv.connWg.Done() - } -} +package ssh + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// ErrServerClosed is returned by the Server's Serve, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("ssh: Server closed") + +type SubsystemHandler func(s Session) + +var DefaultSubsystemHandlers = map[string]SubsystemHandler{} + +type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +var DefaultRequestHandlers = map[string]RequestHandler{} + +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, +} + +// Server defines parameters for running an SSH server. The zero value for +// Server is a valid configuration. When both PasswordHandler and +// PublicKeyHandler are nil, no client authentication is performed. +type Server struct { + Addr string // TCP address to listen on, ":22" if empty + Handler Handler // handler to invoke, ssh.DefaultHandler if nil + HostSigners []Signer // private keys for the host key, must have at least one + Version string // server version to be sent before the initial handshake + + KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler + PasswordHandler PasswordHandler // password authentication handler + PublicKeyHandler PublicKeyHandler // public key authentication handler + NoClientAuthHandler NoClientAuthHandler // no client authentication handler + PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil + ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling + LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options + SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions + + ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures + + IdleTimeout time.Duration // connection timeout when no activity, none if empty + MaxTimeout time.Duration // absolute connection timeout, none if empty + + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler + + // SubsystemHandlers are handlers which are similar to the usual SSH command + // handlers, but handle named subsystems. + SubsystemHandlers map[string]SubsystemHandler + + listenerWg sync.WaitGroup + mu sync.RWMutex + listeners map[net.Listener]struct{} + conns map[*gossh.ServerConn]struct{} + connWg sync.WaitGroup + doneChan chan struct{} +} + +func (srv *Server) ensureHostSigner() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + if len(srv.HostSigners) == 0 { + signer, err := generateSigner() + if err != nil { + return err + } + srv.HostSigners = append(srv.HostSigners, signer) + } + return nil +} + +func (srv *Server) ensureHandlers() { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v + } + } + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v + } + } + if srv.SubsystemHandlers == nil { + srv.SubsystemHandlers = map[string]SubsystemHandler{} + for k, v := range DefaultSubsystemHandlers { + srv.SubsystemHandlers[k] = v + } + } +} + +func (srv *Server) config(ctx Context) *gossh.ServerConfig { + srv.mu.RLock() + defer srv.mu.RUnlock() + + var config *gossh.ServerConfig + if srv.ServerConfigCallback == nil { + config = &gossh.ServerConfig{} + } else { + config = srv.ServerConfigCallback(ctx) + } + for _, signer := range srv.HostSigners { + config.AddHostKey(signer) + } + if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { + config.NoClientAuth = true + } + if srv.Version != "" { + config.ServerVersion = "SSH-2.0-" + srv.Version + } + if srv.PasswordHandler != nil { + config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.PasswordHandler(ctx, string(password)); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.PublicKeyHandler != nil { + config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.PublicKeyHandler(ctx, key); err != nil { + return ctx.Permissions().Permissions, err + } + ctx.SetValue(ContextKeyPublicKey, key) + return ctx.Permissions().Permissions, nil + } + } + if srv.KeyboardInteractiveHandler != nil { + config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.NoClientAuthHandler != nil { + config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.NoClientAuthHandler(ctx); err != nil { + return ctx.Permissions().Permissions, err + } + return ctx.Permissions().Permissions, nil + } + } + return config +} + +// Handle sets the Handler for the server. +func (srv *Server) Handle(fn Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.Handler = fn +} + +// Close immediately closes all active listeners and all active +// connections. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.conns { + c.Close() + delete(srv.conns, c) + } + return err +} + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, and then waiting indefinitely for connections to close. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + finished := make(chan struct{}, 1) + go func() { + srv.listenerWg.Wait() + srv.connWg.Wait() + finished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-finished: + return lnerr + } +} + +// Serve accepts incoming connections on the Listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and then +// calls srv.Handler to handle sessions. +// +// Serve always returns a non-nil error. +func (srv *Server) Serve(l net.Listener) error { + srv.ensureHandlers() + defer l.Close() + if err := srv.ensureHostSigner(); err != nil { + return err + } + if srv.Handler == nil { + srv.Handler = DefaultHandler + } + var tempDelay time.Duration + + srv.trackListener(l, true) + defer srv.trackListener(l, false) + for { + conn, e := l.Accept() + if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + return e + } + go srv.HandleConn(conn) + } +} + +func (srv *Server) HandleConn(newConn net.Conn) { + ctx, cancel := newContext(srv) + if srv.ConnCallback != nil { + cbConn := srv.ConnCallback(ctx, newConn) + if cbConn == nil { + newConn.Close() + return + } + newConn = cbConn + } + conn := &serverConn{ + Conn: newConn, + idleTimeout: srv.IdleTimeout, + closeCanceler: cancel, + } + if srv.MaxTimeout > 0 { + conn.maxDeadline = time.Now().Add(srv.MaxTimeout) + } + defer conn.Close() + sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) + if err != nil { + if srv.ConnectionFailedCallback != nil { + srv.ConnectionFailedCallback(conn, err) + } + return + } + + srv.trackConn(sshConn, true) + defer srv.trackConn(sshConn, false) + + ctx.SetValue(ContextKeyConn, sshConn) + applyConnMetadata(ctx, sshConn) + //go gossh.DiscardRequests(reqs) + go srv.handleRequests(ctx, reqs) + for ch := range chans { + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] + } + if handler == nil { + ch.Reject(gossh.UnknownChannelType, "unsupported channel type") + continue + } + go handler(srv, sshConn, ch, ctx) + } +} + +func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { + for req := range in { + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + /*reqCtx, cancel := context.WithCancel(ctx) + defer cancel() */ + ret, payload := handler(ctx, srv, req) + req.Reply(ret, payload) + } +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls +// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. +// ListenAndServe always returns a non-nil error. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":22" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +// AddHostKey adds a private key as a host key. If an existing host key exists +// with the same algorithm, it is overwritten. Each server config must have at +// least one host key. +func (srv *Server) AddHostKey(key Signer) { + srv.mu.Lock() + defer srv.mu.Unlock() + + // these are later added via AddHostKey on ServerConfig, which performs the + // check for one of every algorithm. + + // This check is based on the AddHostKey method from the x/crypto/ssh + // library. This allows us to only keep one active key for each type on a + // server at once. So, if you're dynamically updating keys at runtime, this + // list will not keep growing. + for i, k := range srv.HostSigners { + if k.PublicKey().Type() == key.PublicKey().Type() { + srv.HostSigners[i] = key + return + } + } + + srv.HostSigners = append(srv.HostSigners, key) +} + +// SetOption runs a functional option against the server. +func (srv *Server) SetOption(option Option) error { + // NOTE: there is a potential race here for any option that doesn't call an + // internal method. We can't actually lock here because if something calls + // (as an example) AddHostKey, it will deadlock. + + //srv.mu.Lock() + //defer srv.mu.Unlock() + + return option(srv) +} + +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + + return srv.getDoneChanLocked() +} + +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + return srv.doneChan +} + +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by srv.mu. + close(ch) + } +} + +func (srv *Server) closeListenersLocked() error { + var err error + for ln := range srv.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(srv.listeners, ln) + } + return err +} + +func (srv *Server) trackListener(ln net.Listener, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.listeners == nil { + srv.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(srv.listeners) == 0 && len(srv.conns) == 0 { + srv.doneChan = nil + } + srv.listeners[ln] = struct{}{} + srv.listenerWg.Add(1) + } else { + delete(srv.listeners, ln) + srv.listenerWg.Done() + } +} + +func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.conns == nil { + srv.conns = make(map[*gossh.ServerConn]struct{}) + } + if add { + srv.conns[c] = struct{}{} + srv.connWg.Add(1) + } else { + delete(srv.conns, c) + srv.connWg.Done() + } +} diff --git a/tempfork/gliderlabs/ssh/server_test.go b/tempfork/gliderlabs/ssh/server_test.go index 177c071170c4e..1a63bb4b2f3d5 100644 --- a/tempfork/gliderlabs/ssh/server_test.go +++ b/tempfork/gliderlabs/ssh/server_test.go @@ -1,128 +1,128 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "context" - "io" - "testing" - "time" -) - -func TestAddHostKey(t *testing.T) { - s := Server{} - signer, err := generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly added") - } - signer, err = generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly replaced") - } -} - -func TestServerShutdown(t *testing.T) { - l := newLocalListener() - testBytes := []byte("Hello world\n") - s := &Server{ - Handler: func(s Session) { - s.Write(testBytes) - time.Sleep(50 * time.Millisecond) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - sessDone := make(chan struct{}) - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(sessDone) - var stdout bytes.Buffer - sess.Stdout = &stdout - if err := sess.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) - } - }() - - srvDone := make(chan struct{}) - go func() { - defer close(srvDone) - err := s.Shutdown(context.Background()) - if err != nil { - t.Fatal(err) - } - }() - - timeout := time.After(2 * time.Second) - select { - case <-timeout: - t.Fatal("timeout") - return - case <-srvDone: - // TODO: add timeout for sessDone - <-sessDone - return - } -} - -func TestServerClose(t *testing.T) { - l := newLocalListener() - s := &Server{ - Handler: func(s Session) { - time.Sleep(5 * time.Second) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - - clientDoneChan := make(chan struct{}) - closeDoneChan := make(chan struct{}) - - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(clientDoneChan) - <-closeDoneChan - if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) - } - }() - - go func() { - err := s.Close() - if err != nil { - t.Fatal(err) - } - close(closeDoneChan) - }() - - timeout := time.After(100 * time.Millisecond) - select { - case <-timeout: - t.Error("timeout") - return - case <-s.getDoneChan(): - <-clientDoneChan - return - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "context" + "io" + "testing" + "time" +) + +func TestAddHostKey(t *testing.T) { + s := Server{} + signer, err := generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly added") + } + signer, err = generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly replaced") + } +} + +func TestServerShutdown(t *testing.T) { + l := newLocalListener() + testBytes := []byte("Hello world\n") + s := &Server{ + Handler: func(s Session) { + s.Write(testBytes) + time.Sleep(50 * time.Millisecond) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + sessDone := make(chan struct{}) + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(sessDone) + var stdout bytes.Buffer + sess.Stdout = &stdout + if err := sess.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + } + }() + + srvDone := make(chan struct{}) + go func() { + defer close(srvDone) + err := s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + }() + + timeout := time.After(2 * time.Second) + select { + case <-timeout: + t.Fatal("timeout") + return + case <-srvDone: + // TODO: add timeout for sessDone + <-sessDone + return + } +} + +func TestServerClose(t *testing.T) { + l := newLocalListener() + s := &Server{ + Handler: func(s Session) { + time.Sleep(5 * time.Second) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + + clientDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(clientDoneChan) + <-closeDoneChan + if err := sess.Run(""); err != nil && err != io.EOF { + t.Fatal(err) + } + }() + + go func() { + err := s.Close() + if err != nil { + t.Fatal(err) + } + close(closeDoneChan) + }() + + timeout := time.After(100 * time.Millisecond) + select { + case <-timeout: + t.Error("timeout") + return + case <-s.getDoneChan(): + <-clientDoneChan + return + } +} diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index 0a4a21e534401..2f43de739d6d0 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -1,386 +1,386 @@ -package ssh - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "sync" - - "github.com/anmitsu/go-shlex" - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// Session provides access to information about an SSH session and methods -// to read and write to the SSH channel with an embedded Channel interface from -// crypto/ssh. -// -// When Command() returns an empty slice, the user requested a shell. Otherwise -// the user is performing an exec with those command arguments. -// -// TODO: Signals -type Session interface { - gossh.Channel - - // User returns the username used when establishing the SSH connection. - User() string - - // RemoteAddr returns the net.Addr of the client side of the connection. - RemoteAddr() net.Addr - - // LocalAddr returns the net.Addr of the server side of the connection. - LocalAddr() net.Addr - - // Environ returns a copy of strings representing the environment set by the - // user for this session, in the form "key=value". - Environ() []string - - // Exit sends an exit status and then closes the session. - Exit(code int) error - - // Command returns a shell parsed slice of arguments that were provided by the - // user. Shell parsing splits the command string according to POSIX shell rules, - // which considers quoting not just whitespace. - Command() []string - - // RawCommand returns the exact command that was provided by the user. - RawCommand() string - - // Subsystem returns the subsystem requested by the user. - Subsystem() string - - // PublicKey returns the PublicKey used to authenticate. If a public key was not - // used it will return nil. - PublicKey() PublicKey - - // Context returns the connection's context. The returned context is always - // non-nil and holds the same data as the Context passed into auth - // handlers and callbacks. - // - // The context is canceled when the client's connection closes or I/O - // operation fails. - Context() context.Context - - // Permissions returns a copy of the Permissions object that was available for - // setup in the auth handlers via the Context. - Permissions() Permissions - - // Pty returns PTY information, a channel of window size changes, and a boolean - // of whether or not a PTY was accepted for this session. - Pty() (Pty, <-chan Window, bool) - - // Signals registers a channel to receive signals sent from the client. The - // channel must handle signal sends or it will block the SSH request loop. - // Registering nil will unregister the channel from signal sends. During the - // time no channel is registered signals are buffered up to a reasonable amount. - // If there are buffered signals when a channel is registered, they will be - // sent in order on the channel immediately after registering. - Signals(c chan<- Signal) - - // Break regisers a channel to receive notifications of break requests sent - // from the client. The channel must handle break requests, or it will block - // the request handling loop. Registering nil will unregister the channel. - // During the time that no channel is registered, breaks are ignored. - Break(c chan<- bool) - - // DisablePTYEmulation disables the session's default minimal PTY emulation. - // If you're setting the pty's termios settings from the Pty request, use - // this method to avoid corruption. - // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). - // A call of DisablePTYEmulation must precede any call to Write. - DisablePTYEmulation() -} - -// maxSigBufSize is how many signals will be buffered -// when there is no signal channel specified -const maxSigBufSize = 128 - -func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - ch, reqs, err := newChan.Accept() - if err != nil { - // TODO: trigger event callback - return - } - sess := &session{ - Channel: ch, - conn: conn, - handler: srv.Handler, - ptyCb: srv.PtyCallback, - sessReqCb: srv.SessionRequestCallback, - subsystemHandlers: srv.SubsystemHandlers, - ctx: ctx, - } - sess.handleRequests(reqs) -} - -type session struct { - sync.Mutex - gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context - sigCh chan<- Signal - sigBuf []Signal - breakCh chan<- bool - disablePtyEmulation bool -} - -func (sess *session) DisablePTYEmulation() { - sess.disablePtyEmulation = true -} - -func (sess *session) Write(p []byte) (n int, err error) { - if sess.pty != nil && !sess.disablePtyEmulation { - m := len(p) - // normalize \n to \r\n when pty is accepted. - // this is a hardcoded shortcut since we don't support terminal modes. - p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) - p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) - n, err = sess.Channel.Write(p) - if n > m { - n = m - } - return - } - return sess.Channel.Write(p) -} - -func (sess *session) PublicKey() PublicKey { - sessionkey := sess.ctx.Value(ContextKeyPublicKey) - if sessionkey == nil { - return nil - } - return sessionkey.(PublicKey) -} - -func (sess *session) Permissions() Permissions { - // use context permissions because its properly - // wrapped and easier to dereference - perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) - return *perms -} - -func (sess *session) Context() context.Context { - return sess.ctx -} - -func (sess *session) Exit(code int) error { - sess.Lock() - defer sess.Unlock() - if sess.exited { - return errors.New("Session.Exit called multiple times") - } - sess.exited = true - - status := struct{ Status uint32 }{uint32(code)} - _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) - if err != nil { - return err - } - return sess.Close() -} - -func (sess *session) User() string { - return sess.conn.User() -} - -func (sess *session) RemoteAddr() net.Addr { - return sess.conn.RemoteAddr() -} - -func (sess *session) LocalAddr() net.Addr { - return sess.conn.LocalAddr() -} - -func (sess *session) Environ() []string { - return append([]string(nil), sess.env...) -} - -func (sess *session) RawCommand() string { - return sess.rawCmd -} - -func (sess *session) Command() []string { - cmd, _ := shlex.Split(sess.rawCmd, true) - return append([]string(nil), cmd...) -} - -func (sess *session) Subsystem() string { - return sess.subsystem -} - -func (sess *session) Pty() (Pty, <-chan Window, bool) { - if sess.pty != nil { - return *sess.pty, sess.winch, true - } - return Pty{}, sess.winch, false -} - -func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() - sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() - } -} - -func (sess *session) Break(c chan<- bool) { - sess.Lock() - defer sess.Unlock() - sess.breakCh = c -} - -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - sess.handler(sess) - sess.Exit(0) - }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) - } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) - if !ok { - req.Reply(false, nil) - continue - } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "window-change": - if sess.pty == nil { - req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win - } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) - } - } -} +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/anmitsu/go-shlex" + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// Session provides access to information about an SSH session and methods +// to read and write to the SSH channel with an embedded Channel interface from +// crypto/ssh. +// +// When Command() returns an empty slice, the user requested a shell. Otherwise +// the user is performing an exec with those command arguments. +// +// TODO: Signals +type Session interface { + gossh.Channel + + // User returns the username used when establishing the SSH connection. + User() string + + // RemoteAddr returns the net.Addr of the client side of the connection. + RemoteAddr() net.Addr + + // LocalAddr returns the net.Addr of the server side of the connection. + LocalAddr() net.Addr + + // Environ returns a copy of strings representing the environment set by the + // user for this session, in the form "key=value". + Environ() []string + + // Exit sends an exit status and then closes the session. + Exit(code int) error + + // Command returns a shell parsed slice of arguments that were provided by the + // user. Shell parsing splits the command string according to POSIX shell rules, + // which considers quoting not just whitespace. + Command() []string + + // RawCommand returns the exact command that was provided by the user. + RawCommand() string + + // Subsystem returns the subsystem requested by the user. + Subsystem() string + + // PublicKey returns the PublicKey used to authenticate. If a public key was not + // used it will return nil. + PublicKey() PublicKey + + // Context returns the connection's context. The returned context is always + // non-nil and holds the same data as the Context passed into auth + // handlers and callbacks. + // + // The context is canceled when the client's connection closes or I/O + // operation fails. + Context() context.Context + + // Permissions returns a copy of the Permissions object that was available for + // setup in the auth handlers via the Context. + Permissions() Permissions + + // Pty returns PTY information, a channel of window size changes, and a boolean + // of whether or not a PTY was accepted for this session. + Pty() (Pty, <-chan Window, bool) + + // Signals registers a channel to receive signals sent from the client. The + // channel must handle signal sends or it will block the SSH request loop. + // Registering nil will unregister the channel from signal sends. During the + // time no channel is registered signals are buffered up to a reasonable amount. + // If there are buffered signals when a channel is registered, they will be + // sent in order on the channel immediately after registering. + Signals(c chan<- Signal) + + // Break regisers a channel to receive notifications of break requests sent + // from the client. The channel must handle break requests, or it will block + // the request handling loop. Registering nil will unregister the channel. + // During the time that no channel is registered, breaks are ignored. + Break(c chan<- bool) + + // DisablePTYEmulation disables the session's default minimal PTY emulation. + // If you're setting the pty's termios settings from the Pty request, use + // this method to avoid corruption. + // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). + // A call of DisablePTYEmulation must precede any call to Write. + DisablePTYEmulation() +} + +// maxSigBufSize is how many signals will be buffered +// when there is no signal channel specified +const maxSigBufSize = 128 + +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + ch, reqs, err := newChan.Accept() + if err != nil { + // TODO: trigger event callback + return + } + sess := &session{ + Channel: ch, + conn: conn, + handler: srv.Handler, + ptyCb: srv.PtyCallback, + sessReqCb: srv.SessionRequestCallback, + subsystemHandlers: srv.SubsystemHandlers, + ctx: ctx, + } + sess.handleRequests(reqs) +} + +type session struct { + sync.Mutex + gossh.Channel + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + winch chan Window + env []string + ptyCb PtyCallback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + sigCh chan<- Signal + sigBuf []Signal + breakCh chan<- bool + disablePtyEmulation bool +} + +func (sess *session) DisablePTYEmulation() { + sess.disablePtyEmulation = true +} + +func (sess *session) Write(p []byte) (n int, err error) { + if sess.pty != nil && !sess.disablePtyEmulation { + m := len(p) + // normalize \n to \r\n when pty is accepted. + // this is a hardcoded shortcut since we don't support terminal modes. + p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) + p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) + n, err = sess.Channel.Write(p) + if n > m { + n = m + } + return + } + return sess.Channel.Write(p) +} + +func (sess *session) PublicKey() PublicKey { + sessionkey := sess.ctx.Value(ContextKeyPublicKey) + if sessionkey == nil { + return nil + } + return sessionkey.(PublicKey) +} + +func (sess *session) Permissions() Permissions { + // use context permissions because its properly + // wrapped and easier to dereference + perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) + return *perms +} + +func (sess *session) Context() context.Context { + return sess.ctx +} + +func (sess *session) Exit(code int) error { + sess.Lock() + defer sess.Unlock() + if sess.exited { + return errors.New("Session.Exit called multiple times") + } + sess.exited = true + + status := struct{ Status uint32 }{uint32(code)} + _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) + if err != nil { + return err + } + return sess.Close() +} + +func (sess *session) User() string { + return sess.conn.User() +} + +func (sess *session) RemoteAddr() net.Addr { + return sess.conn.RemoteAddr() +} + +func (sess *session) LocalAddr() net.Addr { + return sess.conn.LocalAddr() +} + +func (sess *session) Environ() []string { + return append([]string(nil), sess.env...) +} + +func (sess *session) RawCommand() string { + return sess.rawCmd +} + +func (sess *session) Command() []string { + cmd, _ := shlex.Split(sess.rawCmd, true) + return append([]string(nil), cmd...) +} + +func (sess *session) Subsystem() string { + return sess.subsystem +} + +func (sess *session) Pty() (Pty, <-chan Window, bool) { + if sess.pty != nil { + return *sess.pty, sess.winch, true + } + return Pty{}, sess.winch, false +} + +func (sess *session) Signals(c chan<- Signal) { + sess.Lock() + defer sess.Unlock() + sess.sigCh = c + if len(sess.sigBuf) > 0 { + go func() { + for _, sig := range sess.sigBuf { + sess.sigCh <- sig + } + }() + } +} + +func (sess *session) Break(c chan<- bool) { + sess.Lock() + defer sess.Unlock() + sess.breakCh = c +} + +func (sess *session) handleRequests(reqs <-chan *gossh.Request) { + for req := range reqs { + switch req.Type { + case "shell", "exec": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.rawCmd = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + sess.handler(sess) + sess.Exit(0) + }() + case "subsystem": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.subsystem = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + handler := sess.subsystemHandlers[payload.Value] + if handler == nil { + handler = sess.subsystemHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + handler(sess) + sess.Exit(0) + }() + case "env": + if sess.handled { + req.Reply(false, nil) + continue + } + var kv struct{ Key, Value string } + gossh.Unmarshal(req.Payload, &kv) + sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) + case "signal": + var payload struct{ Signal string } + gossh.Unmarshal(req.Payload, &payload) + sess.Lock() + if sess.sigCh != nil { + sess.sigCh <- Signal(payload.Signal) + } else { + if len(sess.sigBuf) < maxSigBufSize { + sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + } + } + sess.Unlock() + case "pty-req": + if sess.handled || sess.pty != nil { + req.Reply(false, nil) + continue + } + ptyReq, ok := parsePtyRequest(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + if sess.ptyCb != nil { + ok := sess.ptyCb(sess.ctx, ptyReq) + if !ok { + req.Reply(false, nil) + continue + } + } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() + req.Reply(ok, nil) + case "window-change": + if sess.pty == nil { + req.Reply(false, nil) + continue + } + win, _, ok := parseWindow(req.Payload) + if ok { + sess.pty.Window = win + sess.winch <- win + } + req.Reply(ok, nil) + case agentRequestType: + // TODO: option/callback to allow agent forwarding + SetAgentRequested(sess.ctx) + req.Reply(true, nil) + case "break": + ok := false + sess.Lock() + if sess.breakCh != nil { + sess.breakCh <- true + ok = true + } + req.Reply(ok, nil) + sess.Unlock() + default: + // TODO: debug log + req.Reply(false, nil) + } + } +} diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index a60be5ec12d4e..fddd67f6d41cc 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -1,440 +1,440 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func (srv *Server) serveOnce(l net.Listener) error { - srv.ensureHandlers() - if err := srv.ensureHostSigner(); err != nil { - return err - } - conn, e := l.Accept() - if e != nil { - return e - } - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, - } - srv.HandleConn(conn) - return nil -} - -func newLocalListener() net.Listener { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) - } - } - return l -} - -func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - if config == nil { - config = &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - } - } - if config.HostKeyCallback == nil { - config.HostKeyCallback = gossh.InsecureIgnoreHostKey() - } - client, err := gossh.Dial("tcp", addr, config) - if err != nil { - t.Fatal(err) - } - session, err := client.NewSession() - if err != nil { - t.Fatal(err) - } - return session, client, func() { - session.Close() - client.Close() - } -} - -func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() - go srv.serveOnce(l) - return newClientSession(t, l.Addr().String(), cfg) -} - -func TestStdout(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Write(testBytes) - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) - } -} - -func TestStderr(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Stderr().Write(testBytes) - }, - }, nil) - defer cleanup() - var stderr bytes.Buffer - session.Stderr = &stderr - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stderr.Bytes(), testBytes) { - t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) - } -} - -func TestStdin(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.Copy(s, s) // stdin back into stdout - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stdin = bytes.NewBuffer(testBytes) - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) - } -} - -func TestUser(t *testing.T) { - t.Parallel() - testUser := []byte("progrium") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.WriteString(s, s.User()) - }, - }, &gossh.ClientConfig{ - User: string(testUser), - }) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testUser) { - t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) - } -} - -func TestDefaultExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExplicitExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(0) - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExitStatusNonZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(1) - }, - }, nil) - defer cleanup() - err := session.Run("") - e, ok := err.(*gossh.ExitError) - if !ok { - t.Fatalf("expected ExitError but got %T", err) - } - if e.ExitStatus() != 1 { - t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) - } -} - -func TestPty(t *testing.T) { - t.Parallel() - term := "xterm" - winWidth := 40 - winHeight := 80 - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, _, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Term != term { - t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) - } - if ptyReq.Window.Width != winWidth { - t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) - } - if ptyReq.Window.Height != winHeight { - t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) - } - close(done) - }, - }, nil) - defer cleanup() - if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - <-done -} - -func TestPtyResize(t *testing.T) { - t.Parallel() - winch0 := Window{Width: 40, Height: 80} - winch1 := Window{Width: 80, Height: 160} - winch2 := Window{Width: 20, Height: 40} - winches := make(chan Window) - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, winCh, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Window != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) - } - for win := range winCh { - winches <- win - } - close(done) - }, - }, nil) - defer cleanup() - // winch0 - if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - gotWinch := <-winches - if gotWinch != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) - } - // winch1 - winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} - ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch1 { - t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) - } - // winch2 - winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} - ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch2 { - t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) - } - session.Close() - <-done -} - -func TestSignals(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // We need to use a buffered channel here, otherwise it's possible for the - // second call to Signal to get discarded. - signals := make(chan Signal, 2) - s.Signals(signals) - - select { - case sig := <-signals: - if sig != SIGINT { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - - select { - case sig := <-signals: - if sig != SIGKILL { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - - go func() { - session.Signal(gossh.SIGINT) - session.Signal(gossh.SIGKILL) - }() - - go func() { - errChan <- session.Run("") - }() - - err := <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestBreakWithChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - breakChan := make(chan bool) - - readyToReceiveBreak := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Break(breakChan) // register a break channel with the session - readyToReceiveBreak <- true - - select { - case <-breakChan: - io.WriteString(s, "break") - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - <-readyToReceiveBreak - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != true { - t.Fatalf("expected true but got %v", ok) - } - - err = <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if !bytes.Equal(stdout.Bytes(), []byte("break")) { - t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) - } -} - -func TestBreakWithoutChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - waitUntilAfterBreakSent := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - <-waitUntilAfterBreakSent - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != false { - t.Fatalf("expected false but got %v", ok) - } - waitUntilAfterBreakSent <- true - - err = <-errChan - close(doneChan) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func (srv *Server) serveOnce(l net.Listener) error { + srv.ensureHandlers() + if err := srv.ensureHostSigner(); err != nil { + return err + } + conn, e := l.Accept() + if e != nil { + return e + } + srv.ChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + } + srv.HandleConn(conn) + return nil +} + +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on a port: %v", err)) + } + } + return l +} + +func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + if config == nil { + config = &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + } + } + if config.HostKeyCallback == nil { + config.HostKeyCallback = gossh.InsecureIgnoreHostKey() + } + client, err := gossh.Dial("tcp", addr, config) + if err != nil { + t.Fatal(err) + } + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + return session, client, func() { + session.Close() + client.Close() + } +} + +func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + l := newLocalListener() + go srv.serveOnce(l) + return newClientSession(t, l.Addr().String(), cfg) +} + +func TestStdout(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Write(testBytes) + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) + } +} + +func TestStderr(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Stderr().Write(testBytes) + }, + }, nil) + defer cleanup() + var stderr bytes.Buffer + session.Stderr = &stderr + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stderr.Bytes(), testBytes) { + t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) + } +} + +func TestStdin(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.Copy(s, s) // stdin back into stdout + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stdin = bytes.NewBuffer(testBytes) + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) + } +} + +func TestUser(t *testing.T) { + t.Parallel() + testUser := []byte("progrium") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.WriteString(s, s.User()) + }, + }, &gossh.ClientConfig{ + User: string(testUser), + }) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testUser) { + t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) + } +} + +func TestDefaultExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExplicitExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(0) + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExitStatusNonZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(1) + }, + }, nil) + defer cleanup() + err := session.Run("") + e, ok := err.(*gossh.ExitError) + if !ok { + t.Fatalf("expected ExitError but got %T", err) + } + if e.ExitStatus() != 1 { + t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) + } +} + +func TestPty(t *testing.T) { + t.Parallel() + term := "xterm" + winWidth := 40 + winHeight := 80 + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, _, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Term != term { + t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) + } + if ptyReq.Window.Width != winWidth { + t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) + } + if ptyReq.Window.Height != winHeight { + t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) + } + close(done) + }, + }, nil) + defer cleanup() + if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + <-done +} + +func TestPtyResize(t *testing.T) { + t.Parallel() + winch0 := Window{Width: 40, Height: 80} + winch1 := Window{Width: 80, Height: 160} + winch2 := Window{Width: 20, Height: 40} + winches := make(chan Window) + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, winCh, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Window != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) + } + for win := range winCh { + winches <- win + } + close(done) + }, + }, nil) + defer cleanup() + // winch0 + if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + gotWinch := <-winches + if gotWinch != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) + } + // winch1 + winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} + ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch1 { + t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) + } + // winch2 + winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} + ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch2 { + t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) + } + session.Close() + <-done +} + +func TestSignals(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // We need to use a buffered channel here, otherwise it's possible for the + // second call to Signal to get discarded. + signals := make(chan Signal, 2) + s.Signals(signals) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + + go func() { + session.Signal(gossh.SIGINT) + session.Signal(gossh.SIGKILL) + }() + + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestBreakWithChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + breakChan := make(chan bool) + + readyToReceiveBreak := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Break(breakChan) // register a break channel with the session + readyToReceiveBreak <- true + + select { + case <-breakChan: + io.WriteString(s, "break") + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + <-readyToReceiveBreak + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != true { + t.Fatalf("expected true but got %v", ok) + } + + err = <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if !bytes.Equal(stdout.Bytes(), []byte("break")) { + t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) + } +} + +func TestBreakWithoutChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + waitUntilAfterBreakSent := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + <-waitUntilAfterBreakSent + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != false { + t.Fatalf("expected false but got %v", ok) + } + waitUntilAfterBreakSent <- true + + err = <-errChan + close(doneChan) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 644cb257d9afa..4216ea97ab932 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -1,156 +1,156 @@ -package ssh - -import ( - "crypto/subtle" - "net" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -type Signal string - -// POSIX signals as listed in RFC 4254 Section 6.10. -const ( - SIGABRT Signal = "ABRT" - SIGALRM Signal = "ALRM" - SIGFPE Signal = "FPE" - SIGHUP Signal = "HUP" - SIGILL Signal = "ILL" - SIGINT Signal = "INT" - SIGKILL Signal = "KILL" - SIGPIPE Signal = "PIPE" - SIGQUIT Signal = "QUIT" - SIGSEGV Signal = "SEGV" - SIGTERM Signal = "TERM" - SIGUSR1 Signal = "USR1" - SIGUSR2 Signal = "USR2" -) - -// DefaultHandler is the default Handler used by Serve. -var DefaultHandler Handler - -// Option is a functional option handler for Server. -type Option func(*Server) error - -// Handler is a callback for handling established SSH sessions. -type Handler func(Session) - -// PublicKeyHandler is a callback for performing public key authentication. -type PublicKeyHandler func(ctx Context, key PublicKey) error - -type NoClientAuthHandler func(ctx Context) error - -type BannerHandler func(ctx Context) string - -// PasswordHandler is a callback for performing password authentication. -type PasswordHandler func(ctx Context, password string) bool - -// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. -type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool - -// PtyCallback is a hook for allowing PTY sessions. -type PtyCallback func(ctx Context, pty Pty) bool - -// SessionRequestCallback is a callback for allowing or denying SSH sessions. -type SessionRequestCallback func(sess Session, requestType string) bool - -// ConnCallback is a hook for new connections before handling. -// It allows wrapping for timeouts and limiting by returning -// the net.Conn that will be used as the underlying connection. -type ConnCallback func(ctx Context, conn net.Conn) net.Conn - -// LocalPortForwardingCallback is a hook for allowing port forwarding -type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool - -// ReversePortForwardingCallback is a hook for allowing reverse port forwarding -type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool - -// ServerConfigCallback is a hook for creating custom default server configs -type ServerConfigCallback func(ctx Context) *gossh.ServerConfig - -// ConnectionFailedCallback is a hook for reporting failed connections -// Please note: the net.Conn is likely to be closed at this point -type ConnectionFailedCallback func(conn net.Conn, err error) - -// Window represents the size of a PTY window. -// -// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 -// -// Zero dimension parameters MUST be ignored. The character/row dimensions -// override the pixel dimensions (when nonzero). Pixel dimensions refer -// to the drawable area of the window. -type Window struct { - // Width is the number of columns. - // It overrides WidthPixels. - Width int - // Height is the number of rows. - // It overrides HeightPixels. - Height int - - // WidthPixels is the drawable width of the window, in pixels. - WidthPixels int - // HeightPixels is the drawable height of the window, in pixels. - HeightPixels int -} - -// Pty represents a PTY request and configuration. -type Pty struct { - // Term is the TERM environment variable value. - Term string - - // Window is the Window sent as part of the pty-req. - Window Window - - // Modes represent a mapping of Terminal Mode opcode to value as it was - // requested by the client as part of the pty-req. These are outlined as - // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. - // - // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). - // Boolean opcodes have values 0 or 1. - Modes gossh.TerminalModes -} - -// Serve accepts incoming SSH connections on the listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and -// then calls handler to handle sessions. Handler is typically nil, in which -// case the DefaultHandler is used. -func Serve(l net.Listener, handler Handler, options ...Option) error { - srv := &Server{Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.Serve(l) -} - -// ListenAndServe listens on the TCP network address addr and then calls Serve -// with handler to handle sessions on incoming connections. Handler is typically -// nil, in which case the DefaultHandler is used. -func ListenAndServe(addr string, handler Handler, options ...Option) error { - srv := &Server{Addr: addr, Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.ListenAndServe() -} - -// Handle registers the handler as the DefaultHandler. -func Handle(handler Handler) { - DefaultHandler = handler -} - -// KeysEqual is constant time compare of the keys to avoid timing attacks. -func KeysEqual(ak, bk PublicKey) bool { - - //avoid panic if one of the keys is nil, return false instead - if ak == nil || bk == nil { - return false - } - - a := ak.Marshal() - b := bk.Marshal() - return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) -} +package ssh + +import ( + "crypto/subtle" + "net" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +// DefaultHandler is the default Handler used by Serve. +var DefaultHandler Handler + +// Option is a functional option handler for Server. +type Option func(*Server) error + +// Handler is a callback for handling established SSH sessions. +type Handler func(Session) + +// PublicKeyHandler is a callback for performing public key authentication. +type PublicKeyHandler func(ctx Context, key PublicKey) error + +type NoClientAuthHandler func(ctx Context) error + +type BannerHandler func(ctx Context) string + +// PasswordHandler is a callback for performing password authentication. +type PasswordHandler func(ctx Context, password string) bool + +// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. +type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool + +// PtyCallback is a hook for allowing PTY sessions. +type PtyCallback func(ctx Context, pty Pty) bool + +// SessionRequestCallback is a callback for allowing or denying SSH sessions. +type SessionRequestCallback func(sess Session, requestType string) bool + +// ConnCallback is a hook for new connections before handling. +// It allows wrapping for timeouts and limiting by returning +// the net.Conn that will be used as the underlying connection. +type ConnCallback func(ctx Context, conn net.Conn) net.Conn + +// LocalPortForwardingCallback is a hook for allowing port forwarding +type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool + +// ReversePortForwardingCallback is a hook for allowing reverse port forwarding +type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool + +// ServerConfigCallback is a hook for creating custom default server configs +type ServerConfigCallback func(ctx Context) *gossh.ServerConfig + +// ConnectionFailedCallback is a hook for reporting failed connections +// Please note: the net.Conn is likely to be closed at this point +type ConnectionFailedCallback func(conn net.Conn, err error) + +// Window represents the size of a PTY window. +// +// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 +// +// Zero dimension parameters MUST be ignored. The character/row dimensions +// override the pixel dimensions (when nonzero). Pixel dimensions refer +// to the drawable area of the window. +type Window struct { + // Width is the number of columns. + // It overrides WidthPixels. + Width int + // Height is the number of rows. + // It overrides HeightPixels. + Height int + + // WidthPixels is the drawable width of the window, in pixels. + WidthPixels int + // HeightPixels is the drawable height of the window, in pixels. + HeightPixels int +} + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the Window sent as part of the pty-req. + Window Window + + // Modes represent a mapping of Terminal Mode opcode to value as it was + // requested by the client as part of the pty-req. These are outlined as + // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. + // + // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). + // Boolean opcodes have values 0 or 1. + Modes gossh.TerminalModes +} + +// Serve accepts incoming SSH connections on the listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and +// then calls handler to handle sessions. Handler is typically nil, in which +// case the DefaultHandler is used. +func Serve(l net.Listener, handler Handler, options ...Option) error { + srv := &Server{Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.Serve(l) +} + +// ListenAndServe listens on the TCP network address addr and then calls Serve +// with handler to handle sessions on incoming connections. Handler is typically +// nil, in which case the DefaultHandler is used. +func ListenAndServe(addr string, handler Handler, options ...Option) error { + srv := &Server{Addr: addr, Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.ListenAndServe() +} + +// Handle registers the handler as the DefaultHandler. +func Handle(handler Handler) { + DefaultHandler = handler +} + +// KeysEqual is constant time compare of the keys to avoid timing attacks. +func KeysEqual(ak, bk PublicKey) bool { + + //avoid panic if one of the keys is nil, return false instead + if ak == nil || bk == nil { + return false + } + + a := ak.Marshal() + b := bk.Marshal() + return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) +} diff --git a/tempfork/gliderlabs/ssh/ssh_test.go b/tempfork/gliderlabs/ssh/ssh_test.go index aa301b0489f21..8772c03adea53 100644 --- a/tempfork/gliderlabs/ssh/ssh_test.go +++ b/tempfork/gliderlabs/ssh/ssh_test.go @@ -1,17 +1,17 @@ -package ssh - -import ( - "testing" -) - -func TestKeysEqual(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("The code did panic") - } - }() - - if KeysEqual(nil, nil) { - t.Error("two nil keys should not return true") - } -} +package ssh + +import ( + "testing" +) + +func TestKeysEqual(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("The code did panic") + } + }() + + if KeysEqual(nil, nil) { + t.Error("two nil keys should not return true") + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index 056a0c7343daf..d30bb15ac284b 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -1,193 +1,193 @@ -package ssh - -import ( - "io" - "log" - "net" - "strconv" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - forwardedTCPChannelType = "forwarded-tcpip" -) - -// direct-tcpip data struct as specified in RFC4254, Section 7.2 -type localForwardChannelData struct { - DestAddr string - DestPort uint32 - - OriginAddr string - OriginPort uint32 -} - -// DirectTCPIPHandler can be enabled by adding it to the server's -// ChannelHandlers under direct-tcpip. -func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - d := localForwardChannelData{} - if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { - newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) - return - } - - if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { - newChan.Reject(gossh.Prohibited, "port forwarding is disabled") - return - } - - dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) - - var dialer net.Dialer - dconn, err := dialer.DialContext(ctx, "tcp", dest) - if err != nil { - newChan.Reject(gossh.ConnectionFailed, err.Error()) - return - } - - ch, reqs, err := newChan.Accept() - if err != nil { - dconn.Close() - return - } - go gossh.DiscardRequests(reqs) - - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() -} - -type remoteForwardRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardSuccess struct { - BindPort uint32 -} - -type remoteForwardCancelRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardChannelData struct { - DestAddr string - DestPort uint32 - OriginAddr string - OriginPort uint32 -} - -// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and -// adding the HandleSSHRequest callback to the server's RequestHandlers under -// tcpip-forward and cancel-tcpip-forward. -type ForwardedTCPHandler struct { - forwards map[string]net.Listener - sync.Mutex -} - -func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { - h.Lock() - if h.forwards == nil { - h.forwards = make(map[string]net.Listener) - } - h.Unlock() - conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) - switch req.Type { - case "tcpip-forward": - var reqPayload remoteForwardRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - ln, err := net.Listen("tcp", addr) - if err != nil { - // TODO: log listen failure - return false, []byte{} - } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) - destPort, _ := strconv.Atoi(destPortStr) - h.Lock() - h.forwards[addr] = ln - h.Unlock() - go func() { - <-ctx.Done() - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - }() - go func() { - for { - c, err := ln.Accept() - if err != nil { - // TODO: log accept failure - break - } - originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) - originPort, _ := strconv.Atoi(orignPortStr) - payload := gossh.Marshal(&remoteForwardChannelData{ - DestAddr: reqPayload.BindAddr, - DestPort: uint32(destPort), - OriginAddr: originAddr, - OriginPort: uint32(originPort), - }) - go func() { - ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) - if err != nil { - // TODO: log failure to open channel - log.Println(err) - c.Close() - return - } - go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() - }() - } - h.Lock() - delete(h.forwards, addr) - h.Unlock() - }() - return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) - - case "cancel-tcpip-forward": - var reqPayload remoteForwardCancelRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - return true, nil - default: - return false, nil - } -} +package ssh + +import ( + "io" + "log" + "net" + "strconv" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + forwardedTCPChannelType = "forwarded-tcpip" +) + +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) + return + } + + if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { + newChan.Reject(gossh.Prohibited, "port forwarding is disabled") + return + } + + dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + newChan.Reject(gossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(ch, dconn) + }() + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(dconn, ch) + }() +} + +type remoteForwardRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardSuccess struct { + BindPort uint32 +} + +type remoteForwardCancelRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardChannelData struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 +} + +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// tcpip-forward and cancel-tcpip-forward. +type ForwardedTCPHandler struct { + forwards map[string]net.Listener + sync.Mutex +} + +func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + switch req.Type { + case "tcpip-forward": + var reqPayload remoteForwardRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { + return false, []byte("port forwarding is disabled") + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + ln, err := net.Listen("tcp", addr) + if err != nil { + // TODO: log listen failure + return false, []byte{} + } + _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + destPort, _ := strconv.Atoi(destPortStr) + h.Lock() + h.forwards[addr] = ln + h.Unlock() + go func() { + <-ctx.Done() + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + // TODO: log accept failure + break + } + originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) + originPort, _ := strconv.Atoi(orignPortStr) + payload := gossh.Marshal(&remoteForwardChannelData{ + DestAddr: reqPayload.BindAddr, + DestPort: uint32(destPort), + OriginAddr: originAddr, + OriginPort: uint32(originPort), + }) + go func() { + ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) + if err != nil { + // TODO: log failure to open channel + log.Println(err) + c.Close() + return + } + go gossh.DiscardRequests(reqs) + go func() { + defer ch.Close() + defer c.Close() + io.Copy(ch, c) + }() + go func() { + defer ch.Close() + defer c.Close() + io.Copy(c, ch) + }() + }() + } + h.Lock() + delete(h.forwards, addr) + h.Unlock() + }() + return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) + + case "cancel-tcpip-forward": + var reqPayload remoteForwardCancelRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + return true, nil + default: + return false, nil + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index 118b5d53ac4a1..e1d74d566c7bf 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -1,85 +1,85 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "io" - "net" - "strconv" - "strings" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -var sampleServerResponse = []byte("Hello world") - -func sampleSocketServer() net.Listener { - l := newLocalListener() - - go func() { - conn, err := l.Accept() - if err != nil { - return - } - conn.Write(sampleServerResponse) - conn.Close() - }() - - return l -} - -func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() - - _, client, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) {}, - LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { - addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) - if addr != l.Addr().String() { - panic("unexpected destinationHost: " + addr) - } - return forwardingEnabled - }, - }, nil) - - return l, client, func() { - cleanup() - l.Close() - } -} - -func TestLocalPortForwardingWorks(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, true) - defer cleanup() - - conn, err := client.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) - } - result, err := io.ReadAll(conn) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, sampleServerResponse) { - t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) - } -} - -func TestLocalPortForwardingRespectsCallback(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, false) - defer cleanup() - - _, err := client.Dial("tcp", l.Addr().String()) - if err == nil { - t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) - } - if !strings.Contains(err.Error(), "port forwarding is disabled") { - t.Fatalf("Expected permission error but got %#v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "io" + "net" + "strconv" + "strings" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +var sampleServerResponse = []byte("Hello world") + +func sampleSocketServer() net.Listener { + l := newLocalListener() + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleSocketServer() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { + addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) + if addr != l.Addr().String() { + panic("unexpected destinationHost: " + addr) + } + return forwardingEnabled + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalPortForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalPortForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, false) + defer cleanup() + + _, err := client.Dial("tcp", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "port forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go index e3b5716a3ab55..7a6a1824109bf 100644 --- a/tempfork/gliderlabs/ssh/util.go +++ b/tempfork/gliderlabs/ssh/util.go @@ -1,157 +1,157 @@ -package ssh - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/binary" - - "github.com/tailscale/golang-x-crypto/ssh" -) - -func generateSigner() (ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - return ssh.NewSignerFromKey(key) -} - -func parsePtyRequest(payload []byte) (pty Pty, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 - // 6.2. Requesting a Pseudo-Terminal - // A pseudo-terminal can be allocated for the session by sending the - // following message. - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "pty-req" - // boolean want_reply - // string TERM environment variable value (e.g., vt100) - // uint32 terminal width, characters (e.g., 80) - // uint32 terminal height, rows (e.g., 24) - // uint32 terminal width, pixels (e.g., 640) - // uint32 terminal height, pixels (e.g., 480) - // string encoded terminal modes - - // The payload starts from the TERM variable. - term, rem, ok := parseString(payload) - if !ok { - return - } - win, rem, ok := parseWindow(rem) - if !ok { - return - } - modes, ok := parseTerminalModes(rem) - if !ok { - return - } - pty = Pty{ - Term: term, - Window: win, - Modes: modes, - } - return -} - -func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 - // 8. Encoding of Terminal Modes - // - // All 'encoded terminal modes' (as passed in a pty request) are encoded - // into a byte stream. It is intended that the coding be portable - // across different environments. The stream consists of opcode- - // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 - // have a single uint32 argument. Opcodes 160 to 255 are not yet - // defined, and cause parsing to stop (they should only be used after - // any other data). The stream is terminated by opcode TTY_OP_END - // (0x00). - // - // The client SHOULD put any modes it knows about in the stream, and the - // server MAY ignore any modes it does not know about. This allows some - // degree of machine-independence, at least between systems that use a - // POSIX-like tty interface. The protocol can support other systems as - // well, but the client may need to fill reasonable values for a number - // of parameters so the server pty gets set to a reasonable mode (the - // server leaves all unspecified mode bits in their default values, and - // only some combinations make sense). - _, rem, ok := parseUint32(in) - if !ok { - return - } - const ttyOpEnd = 0 - for len(rem) > 0 { - if modes == nil { - modes = make(ssh.TerminalModes) - } - code := uint8(rem[0]) - rem = rem[1:] - if code == ttyOpEnd || code > 160 { - break - } - var val uint32 - val, rem, ok = parseUint32(rem) - if !ok { - return - } - modes[code] = val - } - ok = true - return -} - -func parseWindow(s []byte) (win Window, rem []byte, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 - // 6.7. Window Dimension Change Message - // When the window (terminal) size changes on the client side, it MAY - // send a message to the other side to inform it of the new dimensions. - - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "window-change" - // boolean FALSE - // uint32 terminal width, columns - // uint32 terminal height, rows - // uint32 terminal width, pixels - // uint32 terminal height, pixels - wCols, rem, ok := parseUint32(s) - if !ok { - return - } - hRows, rem, ok := parseUint32(rem) - if !ok { - return - } - wPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - hPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - win = Window{ - Width: int(wCols), - Height: int(hRows), - WidthPixels: int(wPixels), - HeightPixels: int(hPixels), - } - return -} - -func parseString(in []byte) (out string, rem []byte, ok bool) { - length, rem, ok := parseUint32(in) - if uint32(len(rem)) < length || !ok { - ok = false - return - } - out, rem = string(rem[:length]), rem[length:] - ok = true - return -} - -func parseUint32(in []byte) (uint32, []byte, bool) { - if len(in) < 4 { - return 0, nil, false - } - return binary.BigEndian.Uint32(in), in[4:], true -} +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/binary" + + "github.com/tailscale/golang-x-crypto/ssh" +) + +func generateSigner() (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +func parsePtyRequest(payload []byte) (pty Pty, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 + // 6.2. Requesting a Pseudo-Terminal + // A pseudo-terminal can be allocated for the session by sending the + // following message. + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "pty-req" + // boolean want_reply + // string TERM environment variable value (e.g., vt100) + // uint32 terminal width, characters (e.g., 80) + // uint32 terminal height, rows (e.g., 24) + // uint32 terminal width, pixels (e.g., 640) + // uint32 terminal height, pixels (e.g., 480) + // string encoded terminal modes + + // The payload starts from the TERM variable. + term, rem, ok := parseString(payload) + if !ok { + return + } + win, rem, ok := parseWindow(rem) + if !ok { + return + } + modes, ok := parseTerminalModes(rem) + if !ok { + return + } + pty = Pty{ + Term: term, + Window: win, + Modes: modes, + } + return +} + +func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 + // 8. Encoding of Terminal Modes + // + // All 'encoded terminal modes' (as passed in a pty request) are encoded + // into a byte stream. It is intended that the coding be portable + // across different environments. The stream consists of opcode- + // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 + // have a single uint32 argument. Opcodes 160 to 255 are not yet + // defined, and cause parsing to stop (they should only be used after + // any other data). The stream is terminated by opcode TTY_OP_END + // (0x00). + // + // The client SHOULD put any modes it knows about in the stream, and the + // server MAY ignore any modes it does not know about. This allows some + // degree of machine-independence, at least between systems that use a + // POSIX-like tty interface. The protocol can support other systems as + // well, but the client may need to fill reasonable values for a number + // of parameters so the server pty gets set to a reasonable mode (the + // server leaves all unspecified mode bits in their default values, and + // only some combinations make sense). + _, rem, ok := parseUint32(in) + if !ok { + return + } + const ttyOpEnd = 0 + for len(rem) > 0 { + if modes == nil { + modes = make(ssh.TerminalModes) + } + code := uint8(rem[0]) + rem = rem[1:] + if code == ttyOpEnd || code > 160 { + break + } + var val uint32 + val, rem, ok = parseUint32(rem) + if !ok { + return + } + modes[code] = val + } + ok = true + return +} + +func parseWindow(s []byte) (win Window, rem []byte, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 + // 6.7. Window Dimension Change Message + // When the window (terminal) size changes on the client side, it MAY + // send a message to the other side to inform it of the new dimensions. + + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "window-change" + // boolean FALSE + // uint32 terminal width, columns + // uint32 terminal height, rows + // uint32 terminal width, pixels + // uint32 terminal height, pixels + wCols, rem, ok := parseUint32(s) + if !ok { + return + } + hRows, rem, ok := parseUint32(rem) + if !ok { + return + } + wPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + hPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + win = Window{ + Width: int(wCols), + Height: int(hRows), + WidthPixels: int(wPixels), + HeightPixels: int(hPixels), + } + return +} + +func parseString(in []byte) (out string, rem []byte, ok bool) { + length, rem, ok := parseUint32(in) + if uint32(len(rem)) < length || !ok { + ok = false + return + } + out, rem = string(rem[:length]), rem[length:] + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go index 17867d7518dd1..f44f5d9bff299 100644 --- a/tempfork/gliderlabs/ssh/wrap.go +++ b/tempfork/gliderlabs/ssh/wrap.go @@ -1,33 +1,33 @@ -package ssh - -import gossh "github.com/tailscale/golang-x-crypto/ssh" - -// PublicKey is an abstraction of different types of public keys. -type PublicKey interface { - gossh.PublicKey -} - -// The Permissions type holds fine-grained permissions that are specific to a -// user or a specific authentication method for a user. Permissions, except for -// "source-address", must be enforced in the server application layer, after -// successful authentication. -type Permissions struct { - *gossh.Permissions -} - -// A Signer can create signatures that verify against a public key. -type Signer interface { - gossh.Signer -} - -// ParseAuthorizedKey parses a public key from an authorized_keys file used in -// OpenSSH according to the sshd(8) manual page. -func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { - return gossh.ParseAuthorizedKey(in) -} - -// ParsePublicKey parses an SSH public key formatted for use in -// the SSH wire protocol according to RFC 4253, section 6.6. -func ParsePublicKey(in []byte) (out PublicKey, err error) { - return gossh.ParsePublicKey(in) -} +package ssh + +import gossh "github.com/tailscale/golang-x-crypto/ssh" + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + gossh.PublicKey +} + +// The Permissions type holds fine-grained permissions that are specific to a +// user or a specific authentication method for a user. Permissions, except for +// "source-address", must be enforced in the server application layer, after +// successful authentication. +type Permissions struct { + *gossh.Permissions +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + gossh.Signer +} + +// ParseAuthorizedKey parses a public key from an authorized_keys file used in +// OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + return gossh.ParseAuthorizedKey(in) +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + return gossh.ParsePublicKey(in) +} diff --git a/tempfork/heap/heap.go b/tempfork/heap/heap.go index 3dfab492ad0b8..080b80ca5f7f0 100644 --- a/tempfork/heap/heap.go +++ b/tempfork/heap/heap.go @@ -1,121 +1,121 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package heap provides heap operations for any type that implements -// heap.Interface. A heap is a tree with the property that each node is the -// minimum-valued node in its subtree. -// -// The minimum element in the tree is the root, at index 0. -// -// A heap is a common way to implement a priority queue. To build a priority -// queue, implement the Heap interface with the (negative) priority as the -// ordering for the Less method, so Push adds items while Pop removes the -// highest-priority item from the queue. The Examples include such an -// implementation; the file example_pq_test.go has the complete source. -// -// This package is a copy of the Go standard library's -// container/heap, but using generics. -package heap - -import "sort" - -// The Interface type describes the requirements -// for a type using the routines in this package. -// Any type that implements it may be used as a -// min-heap with the following invariants (established after -// Init has been called or if the data is empty or sorted): -// -// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() -// -// Note that Push and Pop in this interface are for package heap's -// implementation to call. To add and remove things from the heap, -// use heap.Push and heap.Pop. -type Interface[V any] interface { - sort.Interface - Push(x V) // add x as element Len() - Pop() V // remove and return element Len() - 1. -} - -// Init establishes the heap invariants required by the other routines in this package. -// Init is idempotent with respect to the heap invariants -// and may be called whenever the heap invariants may have been invalidated. -// The complexity is O(n) where n = h.Len(). -func Init[V any](h Interface[V]) { - // heapify - n := h.Len() - for i := n/2 - 1; i >= 0; i-- { - down(h, i, n) - } -} - -// Push pushes the element x onto the heap. -// The complexity is O(log n) where n = h.Len(). -func Push[V any](h Interface[V], x V) { - h.Push(x) - up(h, h.Len()-1) -} - -// Pop removes and returns the minimum element (according to Less) from the heap. -// The complexity is O(log n) where n = h.Len(). -// Pop is equivalent to Remove(h, 0). -func Pop[V any](h Interface[V]) V { - n := h.Len() - 1 - h.Swap(0, n) - down(h, 0, n) - return h.Pop() -} - -// Remove removes and returns the element at index i from the heap. -// The complexity is O(log n) where n = h.Len(). -func Remove[V any](h Interface[V], i int) V { - n := h.Len() - 1 - if n != i { - h.Swap(i, n) - if !down(h, i, n) { - up(h, i) - } - } - return h.Pop() -} - -// Fix re-establishes the heap ordering after the element at index i has changed its value. -// Changing the value of the element at index i and then calling Fix is equivalent to, -// but less expensive than, calling Remove(h, i) followed by a Push of the new value. -// The complexity is O(log n) where n = h.Len(). -func Fix[V any](h Interface[V], i int) { - if !down(h, i, h.Len()) { - up(h, i) - } -} - -func up[V any](h Interface[V], j int) { - for { - i := (j - 1) / 2 // parent - if i == j || !h.Less(j, i) { - break - } - h.Swap(i, j) - j = i - } -} - -func down[V any](h Interface[V], i0, n int) bool { - i := i0 - for { - j1 := 2*i + 1 - if j1 >= n || j1 < 0 { // j1 < 0 after int overflow - break - } - j := j1 // left child - if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { - j = j2 // = 2*i + 2 // right child - } - if !h.Less(j, i) { - break - } - h.Swap(i, j) - i = j - } - return i > i0 -} +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package heap provides heap operations for any type that implements +// heap.Interface. A heap is a tree with the property that each node is the +// minimum-valued node in its subtree. +// +// The minimum element in the tree is the root, at index 0. +// +// A heap is a common way to implement a priority queue. To build a priority +// queue, implement the Heap interface with the (negative) priority as the +// ordering for the Less method, so Push adds items while Pop removes the +// highest-priority item from the queue. The Examples include such an +// implementation; the file example_pq_test.go has the complete source. +// +// This package is a copy of the Go standard library's +// container/heap, but using generics. +package heap + +import "sort" + +// The Interface type describes the requirements +// for a type using the routines in this package. +// Any type that implements it may be used as a +// min-heap with the following invariants (established after +// Init has been called or if the data is empty or sorted): +// +// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() +// +// Note that Push and Pop in this interface are for package heap's +// implementation to call. To add and remove things from the heap, +// use heap.Push and heap.Pop. +type Interface[V any] interface { + sort.Interface + Push(x V) // add x as element Len() + Pop() V // remove and return element Len() - 1. +} + +// Init establishes the heap invariants required by the other routines in this package. +// Init is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = h.Len(). +func Init[V any](h Interface[V]) { + // heapify + n := h.Len() + for i := n/2 - 1; i >= 0; i-- { + down(h, i, n) + } +} + +// Push pushes the element x onto the heap. +// The complexity is O(log n) where n = h.Len(). +func Push[V any](h Interface[V], x V) { + h.Push(x) + up(h, h.Len()-1) +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = h.Len(). +// Pop is equivalent to Remove(h, 0). +func Pop[V any](h Interface[V]) V { + n := h.Len() - 1 + h.Swap(0, n) + down(h, 0, n) + return h.Pop() +} + +// Remove removes and returns the element at index i from the heap. +// The complexity is O(log n) where n = h.Len(). +func Remove[V any](h Interface[V], i int) V { + n := h.Len() - 1 + if n != i { + h.Swap(i, n) + if !down(h, i, n) { + up(h, i) + } + } + return h.Pop() +} + +// Fix re-establishes the heap ordering after the element at index i has changed its value. +// Changing the value of the element at index i and then calling Fix is equivalent to, +// but less expensive than, calling Remove(h, i) followed by a Push of the new value. +// The complexity is O(log n) where n = h.Len(). +func Fix[V any](h Interface[V], i int) { + if !down(h, i, h.Len()) { + up(h, i) + } +} + +func up[V any](h Interface[V], j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.Less(j, i) { + break + } + h.Swap(i, j) + j = i + } +} + +func down[V any](h Interface[V], i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.Less(j, i) { + break + } + h.Swap(i, j) + i = j + } + return i > i0 +} diff --git a/tka/aum_test.go b/tka/aum_test.go index 4297efabff13f..84b5674776319 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -1,253 +1,253 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/crypto/blake2s" - "tailscale.com/types/tkatype" -) - -func TestSerialization(t *testing.T) { - uint2 := uint(2) - var fakeAUMHash AUMHash - - tcs := []struct { - Name string - AUM AUM - Expect []byte - }{ - { - "AddKey", - AUM{MessageKind: AUMAddKey, Key: &Key{}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Key) - 0xa3, // |- major type 5 (map), 3 items (type Key) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0xf6, // |- major type 7 (val), value null (third value, nil) - }, - }, - { - "RemoveKey", - AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - }, - }, - { - "UpdateKey", - AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, - []byte{ - 0xa5, // major type 5 (map), 5 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) - 0x02, // |- major type 0 (int), value 2 (forth value, 2) - 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) - 0xa1, // |- major type 5 (map), 1 item (map[string]string type) - 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) - 0x61, // |- byte 'a' - 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) - 0x62, // |- byte 'b' - }, - }, - { - "Checkpoint", - AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ - LastAUMHash: &fakeAUMHash, - Keys: []Key{ - {Kind: Key25519, Public: []byte{5, 6}}, - }, - }}, - append( - append([]byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0x42, // |- major type 2 (byte string), 2 items (second value) - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x05, // |- major type 0 (int), value 5 (third key, State) - 0xa3, // |- major type 5 (map), 3 items (third value, State type) - 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) - 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) - }, - bytes.Repeat([]byte{0}, 32)...), - []byte{ - 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Keys) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa3, // |- major type 5 (map), 3 items (Key type) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x01, // |- major type 0 (int), value 1 (first value, Key25519) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (second value, 0) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0x42, // |- major type 2 (byte string), 2 items (third value) - 0x05, // |- major type 0 (int), value 5 (byte 5) - 0x06, // |- major type 0 (int), value 6 (byte 6) - }...), - }, - { - "Signature", - AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x17, // |- major type 0 (int), value 22 (third key, Signatures) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa2, // |- major type 5 (map), 2 items (Signature type) - 0x01, // |- major type 0 (int), value 1 (first key, KeyID) - 0x41, // |- major type 2 (byte string), 1 item - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (second key, Signature) - 0xf6, // |- major type 7 (val), value null (second value, nil) - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - data := []byte(tc.AUM.Serialize()) - if diff := cmp.Diff(tc.Expect, data); diff != "" { - t.Errorf("serialization differs (-want, +got):\n%s", diff) - } - - var decodedAUM AUM - if err := decodedAUM.Unserialize(data); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { - t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestAUMWeight(t *testing.T) { - var fakeKeyID [blake2s.Size]byte - testingRand(t, 1).Read(fakeKeyID[:]) - - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub, _ = testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub, Votes: 2} - - tcs := []struct { - Name string - AUM AUM - State State - Want uint - }{ - { - "Empty", - AUM{}, - State{}, - 0, - }, - { - "Key unknown", - AUM{ - Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, - }, - State{}, - 0, - }, - { - "Unary key", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - { - "Multiple keys", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, - }, - State{ - Keys: []Key{key, key2}, - }, - 4, - }, - { - "Double use", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - got := tc.AUM.Weight(tc.State) - if got != tc.Want { - t.Errorf("Weight() = %d, want %d", got, tc.Want) - } - }) - } -} - -func TestAUMHashes(t *testing.T) { - // .Hash(): a hash over everything. - // .SigHash(): a hash over everything except the signatures. - // The signatures are over a hash of the AUM, so - // using SigHash() breaks this circularity. - - aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} - sigHash1 := aum.SigHash() - aumHash1 := aum.Hash() - - aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} - sigHash2 := aum.SigHash() - aumHash2 := aum.Hash() - if len(aum.Signatures) != 1 { - t.Error("signature was removed by one of the hash functions") - } - - if !bytes.Equal(sigHash1[:], sigHash1[:]) { - t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) - } - if bytes.Equal(aumHash1[:], aumHash2[:]) { - t.Error("aum hash didnt change") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/crypto/blake2s" + "tailscale.com/types/tkatype" +) + +func TestSerialization(t *testing.T) { + uint2 := uint(2) + var fakeAUMHash AUMHash + + tcs := []struct { + Name string + AUM AUM + Expect []byte + }{ + { + "AddKey", + AUM{MessageKind: AUMAddKey, Key: &Key{}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Key) + 0xa3, // |- major type 5 (map), 3 items (type Key) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0xf6, // |- major type 7 (val), value null (third value, nil) + }, + }, + { + "RemoveKey", + AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + }, + }, + { + "UpdateKey", + AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, + []byte{ + 0xa5, // major type 5 (map), 5 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) + 0x02, // |- major type 0 (int), value 2 (forth value, 2) + 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) + 0xa1, // |- major type 5 (map), 1 item (map[string]string type) + 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) + 0x61, // |- byte 'a' + 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) + 0x62, // |- byte 'b' + }, + }, + { + "Checkpoint", + AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ + LastAUMHash: &fakeAUMHash, + Keys: []Key{ + {Kind: Key25519, Public: []byte{5, 6}}, + }, + }}, + append( + append([]byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0x42, // |- major type 2 (byte string), 2 items (second value) + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x05, // |- major type 0 (int), value 5 (third key, State) + 0xa3, // |- major type 5 (map), 3 items (third value, State type) + 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) + 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) + }, + bytes.Repeat([]byte{0}, 32)...), + []byte{ + 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Keys) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa3, // |- major type 5 (map), 3 items (Key type) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x01, // |- major type 0 (int), value 1 (first value, Key25519) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (second value, 0) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0x42, // |- major type 2 (byte string), 2 items (third value) + 0x05, // |- major type 0 (int), value 5 (byte 5) + 0x06, // |- major type 0 (int), value 6 (byte 6) + }...), + }, + { + "Signature", + AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x17, // |- major type 0 (int), value 22 (third key, Signatures) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa2, // |- major type 5 (map), 2 items (Signature type) + 0x01, // |- major type 0 (int), value 1 (first key, KeyID) + 0x41, // |- major type 2 (byte string), 1 item + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (second key, Signature) + 0xf6, // |- major type 7 (val), value null (second value, nil) + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + data := []byte(tc.AUM.Serialize()) + if diff := cmp.Diff(tc.Expect, data); diff != "" { + t.Errorf("serialization differs (-want, +got):\n%s", diff) + } + + var decodedAUM AUM + if err := decodedAUM.Unserialize(data); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { + t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestAUMWeight(t *testing.T) { + var fakeKeyID [blake2s.Size]byte + testingRand(t, 1).Read(fakeKeyID[:]) + + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub, _ = testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub, Votes: 2} + + tcs := []struct { + Name string + AUM AUM + State State + Want uint + }{ + { + "Empty", + AUM{}, + State{}, + 0, + }, + { + "Key unknown", + AUM{ + Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, + }, + State{}, + 0, + }, + { + "Unary key", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + { + "Multiple keys", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, + }, + State{ + Keys: []Key{key, key2}, + }, + 4, + }, + { + "Double use", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + got := tc.AUM.Weight(tc.State) + if got != tc.Want { + t.Errorf("Weight() = %d, want %d", got, tc.Want) + } + }) + } +} + +func TestAUMHashes(t *testing.T) { + // .Hash(): a hash over everything. + // .SigHash(): a hash over everything except the signatures. + // The signatures are over a hash of the AUM, so + // using SigHash() breaks this circularity. + + aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} + sigHash1 := aum.SigHash() + aumHash1 := aum.Hash() + + aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} + sigHash2 := aum.SigHash() + aumHash2 := aum.Hash() + if len(aum.Signatures) != 1 { + t.Error("signature was removed by one of the hash functions") + } + + if !bytes.Equal(sigHash1[:], sigHash1[:]) { + t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) + } + if bytes.Equal(aumHash1[:], aumHash2[:]) { + t.Error("aum hash didnt change") + } +} diff --git a/tka/builder.go b/tka/builder.go index c14ba2330ae0d..19cd340f03823 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -1,180 +1,180 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "fmt" - "os" - - "tailscale.com/types/tkatype" -) - -// Types implementing Signer can sign update messages. -type Signer interface { - // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. - SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) -} - -// UpdateBuilder implements a builder for changes to the tailnet -// key authority. -// -// Finalize must be called to compute the update messages, which -// must then be applied to all Authority objects using Inform(). -type UpdateBuilder struct { - a *Authority - signer Signer - - state State - parent AUMHash - - out []AUM -} - -func (b *UpdateBuilder) mkUpdate(update AUM) error { - prevHash := make([]byte, len(b.parent)) - copy(prevHash, b.parent[:]) - update.PrevAUMHash = prevHash - - if b.signer != nil { - sigs, err := b.signer.SignAUM(update.SigHash()) - if err != nil { - return fmt.Errorf("signing failed: %v", err) - } - update.Signatures = append(update.Signatures, sigs...) - } - if err := update.StaticValidate(); err != nil { - return fmt.Errorf("generated update was invalid: %v", err) - } - state, err := b.state.applyVerifiedAUM(update) - if err != nil { - return fmt.Errorf("update cannot be applied: %v", err) - } - - b.state = state - b.parent = update.Hash() - b.out = append(b.out, update) - return nil -} - -// AddKey adds a new key to the authority. -func (b *UpdateBuilder) AddKey(key Key) error { - keyID, err := key.ID() - if err != nil { - return err - } - - if _, err := b.state.GetKey(keyID); err == nil { - return fmt.Errorf("cannot add key %v: already exists", key) - } - return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) -} - -// RemoveKey removes a key from the authority. -func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) -} - -// SetKeyVote updates the number of votes of an existing key. -func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) -} - -// SetKeyMeta updates key-value metadata stored against an existing key. -// -// TODO(tom): Provide an API to update specific values rather than the whole -// map. -func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) -} - -func (b *UpdateBuilder) generateCheckpoint() error { - // Compute the checkpoint state. - state := b.a.state - for i, update := range b.out { - var err error - if state, err = state.applyVerifiedAUM(update); err != nil { - return fmt.Errorf("applying update %d: %v", i, err) - } - } - - // Checkpoints cant specify a parent AUM. - state.LastAUMHash = nil - return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) -} - -// checkpointEvery sets how often a checkpoint AUM should be generated. -const checkpointEvery = 50 - -// Finalize returns the set of update message to actuate the update. -func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { - var ( - needCheckpoint bool = true - cursor AUMHash = b.a.Head() - ) - for i := len(b.out); i < checkpointEvery; i++ { - aum, err := storage.AUM(cursor) - if err != nil { - if err == os.ErrNotExist { - // The available chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - return nil, fmt.Errorf("reading AUM: %v", err) - } - - if aum.MessageKind == AUMCheckpoint { - needCheckpoint = false - break - } - - parent, hasParent := aum.Parent() - if !hasParent { - // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - cursor = parent - } - - if needCheckpoint { - if err := b.generateCheckpoint(); err != nil { - return nil, fmt.Errorf("generating checkpoint: %v", err) - } - } - - // Check no AUMs were applied in the meantime - if len(b.out) > 0 { - if parent, _ := b.out[0].Parent(); parent != b.a.Head() { - return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) - } - } - return b.out, nil -} - -// NewUpdater returns a builder you can use to make changes to -// the tailnet key authority. -// -// The provided signer function, if non-nil, is called with each update -// to compute and apply signatures. -// -// Updates are specified by calling methods on the returned UpdatedBuilder. -// Call Finalize() when you are done to obtain the specific update messages -// which actuate the changes. -func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { - return &UpdateBuilder{ - a: a, - signer: signer, - parent: a.Head(), - state: a.state, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "fmt" + "os" + + "tailscale.com/types/tkatype" +) + +// Types implementing Signer can sign update messages. +type Signer interface { + // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. + SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) +} + +// UpdateBuilder implements a builder for changes to the tailnet +// key authority. +// +// Finalize must be called to compute the update messages, which +// must then be applied to all Authority objects using Inform(). +type UpdateBuilder struct { + a *Authority + signer Signer + + state State + parent AUMHash + + out []AUM +} + +func (b *UpdateBuilder) mkUpdate(update AUM) error { + prevHash := make([]byte, len(b.parent)) + copy(prevHash, b.parent[:]) + update.PrevAUMHash = prevHash + + if b.signer != nil { + sigs, err := b.signer.SignAUM(update.SigHash()) + if err != nil { + return fmt.Errorf("signing failed: %v", err) + } + update.Signatures = append(update.Signatures, sigs...) + } + if err := update.StaticValidate(); err != nil { + return fmt.Errorf("generated update was invalid: %v", err) + } + state, err := b.state.applyVerifiedAUM(update) + if err != nil { + return fmt.Errorf("update cannot be applied: %v", err) + } + + b.state = state + b.parent = update.Hash() + b.out = append(b.out, update) + return nil +} + +// AddKey adds a new key to the authority. +func (b *UpdateBuilder) AddKey(key Key) error { + keyID, err := key.ID() + if err != nil { + return err + } + + if _, err := b.state.GetKey(keyID); err == nil { + return fmt.Errorf("cannot add key %v: already exists", key) + } + return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) +} + +// RemoveKey removes a key from the authority. +func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) +} + +// SetKeyVote updates the number of votes of an existing key. +func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) +} + +// SetKeyMeta updates key-value metadata stored against an existing key. +// +// TODO(tom): Provide an API to update specific values rather than the whole +// map. +func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) +} + +func (b *UpdateBuilder) generateCheckpoint() error { + // Compute the checkpoint state. + state := b.a.state + for i, update := range b.out { + var err error + if state, err = state.applyVerifiedAUM(update); err != nil { + return fmt.Errorf("applying update %d: %v", i, err) + } + } + + // Checkpoints cant specify a parent AUM. + state.LastAUMHash = nil + return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) +} + +// checkpointEvery sets how often a checkpoint AUM should be generated. +const checkpointEvery = 50 + +// Finalize returns the set of update message to actuate the update. +func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { + var ( + needCheckpoint bool = true + cursor AUMHash = b.a.Head() + ) + for i := len(b.out); i < checkpointEvery; i++ { + aum, err := storage.AUM(cursor) + if err != nil { + if err == os.ErrNotExist { + // The available chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + return nil, fmt.Errorf("reading AUM: %v", err) + } + + if aum.MessageKind == AUMCheckpoint { + needCheckpoint = false + break + } + + parent, hasParent := aum.Parent() + if !hasParent { + // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + cursor = parent + } + + if needCheckpoint { + if err := b.generateCheckpoint(); err != nil { + return nil, fmt.Errorf("generating checkpoint: %v", err) + } + } + + // Check no AUMs were applied in the meantime + if len(b.out) > 0 { + if parent, _ := b.out[0].Parent(); parent != b.a.Head() { + return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) + } + } + return b.out, nil +} + +// NewUpdater returns a builder you can use to make changes to +// the tailnet key authority. +// +// The provided signer function, if non-nil, is called with each update +// to compute and apply signatures. +// +// Updates are specified by calling methods on the returned UpdatedBuilder. +// Call Finalize() when you are done to obtain the specific update messages +// which actuate the changes. +func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { + return &UpdateBuilder{ + a: a, + signer: signer, + parent: a.Head(), + state: a.state, + } +} diff --git a/tka/builder_test.go b/tka/builder_test.go index 666af9ad07daf..758fb170c0b5e 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -1,270 +1,270 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/tkatype" -) - -type signer25519 ed25519.PrivateKey - -func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { - priv := ed25519.PrivateKey(s) - key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} - - return []tkatype.Signature{{ - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }}, nil -} - -func TestAuthorityBuilderAddKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Errorf("could not read new key: %v", err) - } -} - -func TestAuthorityBuilderRemoveKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key, key2}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.RemoveKey(key2.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the key has been removed. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderSetKeyVote(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyVote(key.MustID(), 5); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(5); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } -} - -func TestAuthorityBuilderSetKeyMeta(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { - t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { - t.Errorf("updated meta differs (-want, +got):\n%s", diff) - } -} - -func TestAuthorityBuilderMultiple(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - if err := b.SetKeyVote(key2.MustID(), 42); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) - } - if err := b.RemoveKey(key.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key2.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(42); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } - if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - for i := 0; i <= checkpointEvery; i++ { - pub2, _ := testingKey25519(t, int64(i+2)) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Fatal(err) - } - - wantKind := AUMAddKey - if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) - wantKind = AUMCheckpoint - } - lastAUM, err := storage.AUM(a.Head()) - if err != nil { - t.Fatal(err) - } - if lastAUM.MessageKind != wantKind { - t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) - } - } - - // Try starting an authority just based on storage. - a2, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open from stored AUMs: %v", err) - } - if a.Head() != a2.Head() { - t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/tkatype" +) + +type signer25519 ed25519.PrivateKey + +func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { + priv := ed25519.PrivateKey(s) + key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} + + return []tkatype.Signature{{ + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }}, nil +} + +func TestAuthorityBuilderAddKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Errorf("could not read new key: %v", err) + } +} + +func TestAuthorityBuilderRemoveKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key, key2}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.RemoveKey(key2.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the key has been removed. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderSetKeyVote(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyVote(key.MustID(), 5); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(5); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } +} + +func TestAuthorityBuilderSetKeyMeta(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { + t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { + t.Errorf("updated meta differs (-want, +got):\n%s", diff) + } +} + +func TestAuthorityBuilderMultiple(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + if err := b.SetKeyVote(key2.MustID(), 42); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) + } + if err := b.RemoveKey(key.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key2.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(42); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } + if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + for i := 0; i <= checkpointEvery; i++ { + pub2, _ := testingKey25519(t, int64(i+2)) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Fatal(err) + } + + wantKind := AUMAddKey + if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) + wantKind = AUMCheckpoint + } + lastAUM, err := storage.AUM(a.Head()) + if err != nil { + t.Fatal(err) + } + if lastAUM.MessageKind != wantKind { + t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) + } + } + + // Try starting an authority just based on storage. + a2, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open from stored AUMs: %v", err) + } + if a.Head() != a2.Head() { + t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) + } +} diff --git a/tka/deeplink.go b/tka/deeplink.go index 5cf24fc5c2c82..97bcd664b23ec 100644 --- a/tka/deeplink.go +++ b/tka/deeplink.go @@ -1,221 +1,221 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "net/url" - "strings" -) - -const ( - DeeplinkTailscaleURLScheme = "tailscale" - DeeplinkCommandSign = "sign-device" -) - -// generateHMAC computes a SHA-256 HMAC for the concatenation of components, -// using the Authority stateID as secret. -func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { - stateID, _ := a.StateIDs() - - key := make([]byte, 8) - binary.LittleEndian.PutUint64(key, stateID) - mac := hmac.New(sha256.New, key) - mac.Write([]byte(params.NodeKey)) - mac.Write([]byte(params.TLPub)) - mac.Write([]byte(params.DeviceName)) - mac.Write([]byte(params.OSName)) - mac.Write([]byte(params.LoginName)) - return mac.Sum(nil) -} - -type NewDeeplinkParams struct { - NodeKey string - TLPub string - DeviceName string - OSName string - LoginName string -} - -// NewDeeplink creates a signed deeplink using the authority's stateID as a -// secret. This deeplink can then be validated by ValidateDeeplink. -func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { - if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { - return "", fmt.Errorf("invalid node key %q", params.NodeKey) - } - if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { - return "", fmt.Errorf("invalid tlpub %q", params.TLPub) - } - if params.DeviceName == "" { - return "", fmt.Errorf("invalid device name %q", params.DeviceName) - } - if params.OSName == "" { - return "", fmt.Errorf("invalid os name %q", params.OSName) - } - if params.LoginName == "" { - return "", fmt.Errorf("invalid login name %q", params.LoginName) - } - - u := url.URL{ - Scheme: DeeplinkTailscaleURLScheme, - Host: DeeplinkCommandSign, - Path: "/v1/", - } - v := url.Values{} - v.Set("nk", params.NodeKey) - v.Set("tp", params.TLPub) - v.Set("dn", params.DeviceName) - v.Set("os", params.OSName) - v.Set("em", params.LoginName) - - hmac := a.generateHMAC(params) - v.Set("hm", hex.EncodeToString(hmac)) - - u.RawQuery = v.Encode() - return u.String(), nil -} - -type DeeplinkValidationResult struct { - IsValid bool - Error string - Version uint8 - NodeKey string - TLPub string - DeviceName string - OSName string - EmailAddress string -} - -// ValidateDeeplink validates a device signing deeplink using the authority's stateID. -// The input urlString follows this structure: -// -// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx -// -// where: -// - "nk" is the nodekey of the node being signed -// - "tp" is the tailnet lock public key -// - "dn" is the name of the node -// - "os" is the operating system of the node -// - "em" is the email address associated with the node -// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string -func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { - parsedUrl, err := url.Parse(urlString) - if err != nil { - return DeeplinkValidationResult{ - IsValid: false, - Error: err.Error(), - } - } - - if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), - } - } - - if parsedUrl.Host != DeeplinkCommandSign { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), - } - } - - path := parsedUrl.EscapedPath() - pathComponents := strings.Split(path, "/") - if len(pathComponents) != 3 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "invalid path components number found", - } - } - - if pathComponents[1] != "v1" { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), - } - } - - nodeKey := parsedUrl.Query().Get("nk") - if len(nodeKey) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing nk (NodeKey) query parameter", - } - } - - tlPub := parsedUrl.Query().Get("tp") - if len(tlPub) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing tp (TLPub) query parameter", - } - } - - deviceName := parsedUrl.Query().Get("dn") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing dn (DeviceName) query parameter", - } - } - - osName := parsedUrl.Query().Get("os") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing os (OSName) query parameter", - } - } - - emailAddress := parsedUrl.Query().Get("em") - if len(emailAddress) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing em (EmailAddress) query parameter", - } - } - - hmacString := parsedUrl.Query().Get("hm") - if len(hmacString) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing hm (HMAC) query parameter", - } - } - - computedHMAC := a.generateHMAC(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: emailAddress, - }) - - hmacHexBytes, err := hex.DecodeString(hmacString) - if err != nil { - return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} - } - - if !hmac.Equal(computedHMAC, hmacHexBytes) { - return DeeplinkValidationResult{ - IsValid: false, - Error: "hmac authentication failed", - } - } - - return DeeplinkValidationResult{ - IsValid: true, - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - EmailAddress: emailAddress, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "strings" +) + +const ( + DeeplinkTailscaleURLScheme = "tailscale" + DeeplinkCommandSign = "sign-device" +) + +// generateHMAC computes a SHA-256 HMAC for the concatenation of components, +// using the Authority stateID as secret. +func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { + stateID, _ := a.StateIDs() + + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, stateID) + mac := hmac.New(sha256.New, key) + mac.Write([]byte(params.NodeKey)) + mac.Write([]byte(params.TLPub)) + mac.Write([]byte(params.DeviceName)) + mac.Write([]byte(params.OSName)) + mac.Write([]byte(params.LoginName)) + return mac.Sum(nil) +} + +type NewDeeplinkParams struct { + NodeKey string + TLPub string + DeviceName string + OSName string + LoginName string +} + +// NewDeeplink creates a signed deeplink using the authority's stateID as a +// secret. This deeplink can then be validated by ValidateDeeplink. +func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { + if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { + return "", fmt.Errorf("invalid node key %q", params.NodeKey) + } + if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { + return "", fmt.Errorf("invalid tlpub %q", params.TLPub) + } + if params.DeviceName == "" { + return "", fmt.Errorf("invalid device name %q", params.DeviceName) + } + if params.OSName == "" { + return "", fmt.Errorf("invalid os name %q", params.OSName) + } + if params.LoginName == "" { + return "", fmt.Errorf("invalid login name %q", params.LoginName) + } + + u := url.URL{ + Scheme: DeeplinkTailscaleURLScheme, + Host: DeeplinkCommandSign, + Path: "/v1/", + } + v := url.Values{} + v.Set("nk", params.NodeKey) + v.Set("tp", params.TLPub) + v.Set("dn", params.DeviceName) + v.Set("os", params.OSName) + v.Set("em", params.LoginName) + + hmac := a.generateHMAC(params) + v.Set("hm", hex.EncodeToString(hmac)) + + u.RawQuery = v.Encode() + return u.String(), nil +} + +type DeeplinkValidationResult struct { + IsValid bool + Error string + Version uint8 + NodeKey string + TLPub string + DeviceName string + OSName string + EmailAddress string +} + +// ValidateDeeplink validates a device signing deeplink using the authority's stateID. +// The input urlString follows this structure: +// +// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx +// +// where: +// - "nk" is the nodekey of the node being signed +// - "tp" is the tailnet lock public key +// - "dn" is the name of the node +// - "os" is the operating system of the node +// - "em" is the email address associated with the node +// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string +func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { + parsedUrl, err := url.Parse(urlString) + if err != nil { + return DeeplinkValidationResult{ + IsValid: false, + Error: err.Error(), + } + } + + if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), + } + } + + if parsedUrl.Host != DeeplinkCommandSign { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), + } + } + + path := parsedUrl.EscapedPath() + pathComponents := strings.Split(path, "/") + if len(pathComponents) != 3 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "invalid path components number found", + } + } + + if pathComponents[1] != "v1" { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), + } + } + + nodeKey := parsedUrl.Query().Get("nk") + if len(nodeKey) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing nk (NodeKey) query parameter", + } + } + + tlPub := parsedUrl.Query().Get("tp") + if len(tlPub) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing tp (TLPub) query parameter", + } + } + + deviceName := parsedUrl.Query().Get("dn") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing dn (DeviceName) query parameter", + } + } + + osName := parsedUrl.Query().Get("os") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing os (OSName) query parameter", + } + } + + emailAddress := parsedUrl.Query().Get("em") + if len(emailAddress) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing em (EmailAddress) query parameter", + } + } + + hmacString := parsedUrl.Query().Get("hm") + if len(hmacString) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing hm (HMAC) query parameter", + } + } + + computedHMAC := a.generateHMAC(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: emailAddress, + }) + + hmacHexBytes, err := hex.DecodeString(hmacString) + if err != nil { + return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} + } + + if !hmac.Equal(computedHMAC, hmacHexBytes) { + return DeeplinkValidationResult{ + IsValid: false, + Error: "hmac authentication failed", + } + } + + return DeeplinkValidationResult{ + IsValid: true, + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + EmailAddress: emailAddress, + } +} diff --git a/tka/deeplink_test.go b/tka/deeplink_test.go index 03523202fed8b..397cc6917f289 100644 --- a/tka/deeplink_test.go +++ b/tka/deeplink_test.go @@ -1,52 +1,52 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "testing" -) - -func TestGenerateDeeplink(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - a, _ := Open(c.Chonk()) - - nodeKey := "nodekey:1234567890" - tlPub := "tlpub:1234567890" - deviceName := "Example Device" - osName := "iOS" - loginName := "insecure@example.com" - - deeplink, err := a.NewDeeplink(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: loginName, - }) - if err != nil { - t.Errorf("deeplink generation failed: %v", err) - } - - res := a.ValidateDeeplink(deeplink) - if !res.IsValid { - t.Errorf("deeplink validation failed: %s", res.Error) - } - if res.NodeKey != nodeKey { - t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) - } - if res.TLPub != tlPub { - t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "testing" +) + +func TestGenerateDeeplink(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + a, _ := Open(c.Chonk()) + + nodeKey := "nodekey:1234567890" + tlPub := "tlpub:1234567890" + deviceName := "Example Device" + osName := "iOS" + loginName := "insecure@example.com" + + deeplink, err := a.NewDeeplink(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: loginName, + }) + if err != nil { + t.Errorf("deeplink generation failed: %v", err) + } + + res := a.ValidateDeeplink(deeplink) + if !res.IsValid { + t.Errorf("deeplink validation failed: %s", res.Error) + } + if res.NodeKey != nodeKey { + t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) + } + if res.TLPub != tlPub { + t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) + } +} diff --git a/tka/key.go b/tka/key.go index 07736795d8e58..47218438d88ea 100644 --- a/tka/key.go +++ b/tka/key.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "errors" - "fmt" - - "github.com/hdevalence/ed25519consensus" - "tailscale.com/types/tkatype" -) - -// KeyKind describes the different varieties of a Key. -type KeyKind uint8 - -// Valid KeyKind values. -const ( - KeyInvalid KeyKind = iota - Key25519 -) - -func (k KeyKind) String() string { - switch k { - case KeyInvalid: - return "invalid" - case Key25519: - return "25519" - default: - return fmt.Sprintf("Key?<%d>", int(k)) - } -} - -// Key describes the public components of a key known to network-lock. -type Key struct { - Kind KeyKind `cbor:"1,keyasint"` - - // Votes describes the weight applied to signatures using this key. - // Weighting is used to deterministically resolve branches in the AUM - // chain (i.e. forks, where two AUMs exist with the same parent). - Votes uint `cbor:"2,keyasint"` - - // Public encodes the public key of the key. For 25519 keys, - // this is simply the point on the curve representing the public - // key. - Public []byte `cbor:"3,keyasint"` - - // Meta describes arbitrary metadata about the key. This could be - // used to store the name of the key, for instance. - Meta map[string]string `cbor:"12,keyasint,omitempty"` -} - -// Clone makes an independent copy of Key. -// -// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, -// so an implementation of Clone() must take care to preserve this. -func (k Key) Clone() Key { - out := k - - if k.Public != nil { - out.Public = make([]byte, len(k.Public)) - copy(out.Public, k.Public) - } - - if k.Meta != nil { - out.Meta = make(map[string]string, len(k.Meta)) - for k, v := range k.Meta { - out.Meta[k] = v - } - } - - return out -} - -// MustID returns the KeyID of the key, panicking if an error is -// encountered. This must only be used for tests. -func (k Key) MustID() tkatype.KeyID { - id, err := k.ID() - if err != nil { - panic(err) - } - return id -} - -// ID returns the KeyID of the key. -func (k Key) ID() (tkatype.KeyID, error) { - switch k.Kind { - // Because 25519 public keys are so short, we just use the 32-byte - // public as their 'key ID'. - case Key25519: - return tkatype.KeyID(k.Public), nil - default: - return nil, fmt.Errorf("unknown key kind: %v", k.Kind) - } -} - -// Ed25519 returns the ed25519 public key encoded by Key. An error is -// returned for keys which do not represent ed25519 public keys. -func (k Key) Ed25519() (ed25519.PublicKey, error) { - switch k.Kind { - case Key25519: - return ed25519.PublicKey(k.Public), nil - default: - return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) - } -} - -const maxMetaBytes = 512 - -func (k Key) StaticValidate() error { - if k.Votes > 4096 { - return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) - } - if k.Votes == 0 { - return errors.New("key votes must be non-zero") - } - - // We have an arbitrary upper limit on the amount - // of metadata that can be associated with a key, so - // people don't start using it as a key-value store and - // causing pathological cases due to the number + size of - // AUMs. - var metaBytes uint - for k, v := range k.Meta { - metaBytes += uint(len(k) + len(v)) - } - if metaBytes > maxMetaBytes { - return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) - } - - switch k.Kind { - case Key25519: - default: - return fmt.Errorf("unrecognized key kind: %v", k.Kind) - } - return nil -} - -// Verify returns a nil error if the signature is valid over the -// provided AUM BLAKE2s digest, using the given key. -func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { - // NOTE(tom): Even if we can compute the public from the KeyID, - // its possible for the KeyID to be attacker-controlled - // so we should use the public contained in the state machine. - switch key.Kind { - case Key25519: - if len(key.Public) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) - } - if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { - return nil - } - return errors.New("invalid signature") - - default: - return fmt.Errorf("unhandled key type: %v", key.Kind) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "errors" + "fmt" + + "github.com/hdevalence/ed25519consensus" + "tailscale.com/types/tkatype" +) + +// KeyKind describes the different varieties of a Key. +type KeyKind uint8 + +// Valid KeyKind values. +const ( + KeyInvalid KeyKind = iota + Key25519 +) + +func (k KeyKind) String() string { + switch k { + case KeyInvalid: + return "invalid" + case Key25519: + return "25519" + default: + return fmt.Sprintf("Key?<%d>", int(k)) + } +} + +// Key describes the public components of a key known to network-lock. +type Key struct { + Kind KeyKind `cbor:"1,keyasint"` + + // Votes describes the weight applied to signatures using this key. + // Weighting is used to deterministically resolve branches in the AUM + // chain (i.e. forks, where two AUMs exist with the same parent). + Votes uint `cbor:"2,keyasint"` + + // Public encodes the public key of the key. For 25519 keys, + // this is simply the point on the curve representing the public + // key. + Public []byte `cbor:"3,keyasint"` + + // Meta describes arbitrary metadata about the key. This could be + // used to store the name of the key, for instance. + Meta map[string]string `cbor:"12,keyasint,omitempty"` +} + +// Clone makes an independent copy of Key. +// +// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, +// so an implementation of Clone() must take care to preserve this. +func (k Key) Clone() Key { + out := k + + if k.Public != nil { + out.Public = make([]byte, len(k.Public)) + copy(out.Public, k.Public) + } + + if k.Meta != nil { + out.Meta = make(map[string]string, len(k.Meta)) + for k, v := range k.Meta { + out.Meta[k] = v + } + } + + return out +} + +// MustID returns the KeyID of the key, panicking if an error is +// encountered. This must only be used for tests. +func (k Key) MustID() tkatype.KeyID { + id, err := k.ID() + if err != nil { + panic(err) + } + return id +} + +// ID returns the KeyID of the key. +func (k Key) ID() (tkatype.KeyID, error) { + switch k.Kind { + // Because 25519 public keys are so short, we just use the 32-byte + // public as their 'key ID'. + case Key25519: + return tkatype.KeyID(k.Public), nil + default: + return nil, fmt.Errorf("unknown key kind: %v", k.Kind) + } +} + +// Ed25519 returns the ed25519 public key encoded by Key. An error is +// returned for keys which do not represent ed25519 public keys. +func (k Key) Ed25519() (ed25519.PublicKey, error) { + switch k.Kind { + case Key25519: + return ed25519.PublicKey(k.Public), nil + default: + return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) + } +} + +const maxMetaBytes = 512 + +func (k Key) StaticValidate() error { + if k.Votes > 4096 { + return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) + } + if k.Votes == 0 { + return errors.New("key votes must be non-zero") + } + + // We have an arbitrary upper limit on the amount + // of metadata that can be associated with a key, so + // people don't start using it as a key-value store and + // causing pathological cases due to the number + size of + // AUMs. + var metaBytes uint + for k, v := range k.Meta { + metaBytes += uint(len(k) + len(v)) + } + if metaBytes > maxMetaBytes { + return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) + } + + switch k.Kind { + case Key25519: + default: + return fmt.Errorf("unrecognized key kind: %v", k.Kind) + } + return nil +} + +// Verify returns a nil error if the signature is valid over the +// provided AUM BLAKE2s digest, using the given key. +func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { + // NOTE(tom): Even if we can compute the public from the KeyID, + // its possible for the KeyID to be attacker-controlled + // so we should use the public contained in the state machine. + switch key.Kind { + case Key25519: + if len(key.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) + } + if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { + return nil + } + return errors.New("invalid signature") + + default: + return fmt.Errorf("unhandled key type: %v", key.Kind) + } +} diff --git a/tka/key_test.go b/tka/key_test.go index e912f89c4f7eb..aaddb2f404f10 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "crypto/ed25519" - "encoding/binary" - "math/rand" - "testing" - - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// returns a random source based on the test name + extraSeed. -func testingRand(t *testing.T, extraSeed int64) *rand.Rand { - var seed int64 - if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { - panic(err) - } - return rand.New(rand.NewSource(seed + extraSeed)) -} - -// generates a 25519 private key based on the seed + test name. -func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { - pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) - if err != nil { - panic(err) - } - return pub, priv -} - -func TestVerify25519(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{ - Kind: Key25519, - Public: pub, - } - - aum := AUM{ - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2, 3, 4}, - // Signatures is set to crap so we are sure its ignored in the sigHash computation. - Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, - } - sigHash := aum.SigHash() - aum.Signatures = []tkatype.Signature{ - { - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }, - } - - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { - t.Errorf("signature verification failed: %v", err) - } - - // Make sure it fails with a different public key. - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2} - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { - t.Error("signature verification with different key did not fail") - } -} - -func TestNLPrivate(t *testing.T) { - p := key.NewNLPrivate() - pub := p.Public() - - // Test that key.NLPrivate implements Signer by making a new - // authority. - k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(&Mem{}, State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - }, p) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - // Make sure the generated genesis AUM was signed. - if got, want := len(aum.Signatures), 1; got != want { - t.Fatalf("len(signatures) = %d, want %d", got, want) - } - sigHash := aum.SigHash() - if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { - t.Error("signature did not verify") - } - - // We manually compute the keyID, so make sure its consistent with - // tka.Key.ID(). - if !bytes.Equal(k.MustID(), p.KeyID()) { - t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "crypto/ed25519" + "encoding/binary" + "math/rand" + "testing" + + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// returns a random source based on the test name + extraSeed. +func testingRand(t *testing.T, extraSeed int64) *rand.Rand { + var seed int64 + if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { + panic(err) + } + return rand.New(rand.NewSource(seed + extraSeed)) +} + +// generates a 25519 private key based on the seed + test name. +func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) + if err != nil { + panic(err) + } + return pub, priv +} + +func TestVerify25519(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{ + Kind: Key25519, + Public: pub, + } + + aum := AUM{ + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2, 3, 4}, + // Signatures is set to crap so we are sure its ignored in the sigHash computation. + Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, + } + sigHash := aum.SigHash() + aum.Signatures = []tkatype.Signature{ + { + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }, + } + + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { + t.Errorf("signature verification failed: %v", err) + } + + // Make sure it fails with a different public key. + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2} + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { + t.Error("signature verification with different key did not fail") + } +} + +func TestNLPrivate(t *testing.T) { + p := key.NewNLPrivate() + pub := p.Public() + + // Test that key.NLPrivate implements Signer by making a new + // authority. + k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} + _, aum, err := Create(&Mem{}, State{ + Keys: []Key{k}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + }, p) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + // Make sure the generated genesis AUM was signed. + if got, want := len(aum.Signatures), 1; got != want { + t.Fatalf("len(signatures) = %d, want %d", got, want) + } + sigHash := aum.SigHash() + if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { + t.Error("signature did not verify") + } + + // We manually compute the keyID, so make sure its consistent with + // tka.Key.ID(). + if !bytes.Equal(k.MustID(), p.KeyID()) { + t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) + } +} diff --git a/tka/state.go b/tka/state.go index 0a459bd9a1b24..e99b731ccb2ad 100644 --- a/tka/state.go +++ b/tka/state.go @@ -1,315 +1,315 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "errors" - "fmt" - - "golang.org/x/crypto/argon2" - "tailscale.com/types/tkatype" -) - -// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. -var ErrNoSuchKey = errors.New("key not found") - -// State describes Tailnet Key Authority state at an instant in time. -// -// State is mutated by applying Authority Update Messages (AUMs), resulting -// in a new State. -type State struct { - // LastAUMHash is the blake2s digest of the last-applied AUM. - // Because AUMs are strictly ordered and form a hash chain, we - // check the previous AUM hash in an update we are applying - // is the same as the LastAUMHash. - LastAUMHash *AUMHash `cbor:"1,keyasint"` - - // DisablementSecrets are KDF-derived values which can be used - // to turn off the TKA in the event of a consensus-breaking bug. - DisablementSecrets [][]byte `cbor:"2,keyasint"` - - // Keys are the public keys of either: - // - // 1. The signing nodes currently trusted by the TKA. - // 2. Ephemeral keys that were used to generate pre-signed auth keys. - Keys []Key `cbor:"3,keyasint"` - - // StateID's are nonce's, generated on enablement and fixed for - // the lifetime of the Tailnet Key Authority. We generate 16-bytes - // worth of keyspace here just in case we come up with a cool future - // use for this. - StateID1 uint64 `cbor:"4,keyasint,omitempty"` - StateID2 uint64 `cbor:"5,keyasint,omitempty"` -} - -// GetKey returns the trusted key with the specified KeyID. -func (s State) GetKey(key tkatype.KeyID) (Key, error) { - for _, k := range s.Keys { - keyID, err := k.ID() - if err != nil { - return Key{}, err - } - - if bytes.Equal(keyID, key) { - return k, nil - } - } - - return Key{}, ErrNoSuchKey -} - -// Clone makes an independent copy of State. -// -// NOTE: There is a difference between a nil slice and an empty -// slice for encoding purposes, so an implementation of Clone() -// must take care to preserve this. -func (s State) Clone() State { - out := State{ - StateID1: s.StateID1, - StateID2: s.StateID2, - } - - if s.LastAUMHash != nil { - dupe := *s.LastAUMHash - out.LastAUMHash = &dupe - } - - if s.DisablementSecrets != nil { - out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) - for i := range s.DisablementSecrets { - out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) - copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) - } - } - - if s.Keys != nil { - out.Keys = make([]Key, len(s.Keys)) - for i := range s.Keys { - out.Keys[i] = s.Keys[i].Clone() - } - } - - return out -} - -// cloneForUpdate is like Clone, except LastAUMHash is set based -// on the hash of the given update. -func (s State) cloneForUpdate(update *AUM) State { - out := s.Clone() - aumHash := update.Hash() - out.LastAUMHash = &aumHash - return out -} - -const disablementLength = 32 - -var disablementSalt = []byte("tailscale network-lock disablement salt") - -// DisablementKDF computes a public value which can be stored in a -// key authority, but cannot be reversed to find the input secret. -// -// When the output of this function is stored in tka state (i.e. in -// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() -// with the input of this function as the argument will return true. -func DisablementKDF(secret []byte) []byte { - // time = 4 (3 recommended, booped to 4 to compensate for less memory) - // memory = 16 (32 recommended) - // threads = 4 - // keyLen = 32 (256 bits) - return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) -} - -// checkDisablement returns true for a valid disablement secret. -func (s State) checkDisablement(secret []byte) bool { - derived := DisablementKDF(secret) - for _, candidate := range s.DisablementSecrets { - if bytes.Equal(derived, candidate) { - return true - } - } - return false -} - -// parentMatches returns true if an AUM can chain to (be applied) -// to the current state. -// -// Specifically, the rules are: -// - The last AUM hash must match (transitively, this implies that this -// update follows the last update message applied to the state machine) -// - Or, the state machine knows no parent (its brand new). -func (s State) parentMatches(update AUM) bool { - if s.LastAUMHash == nil { - return true - } - return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) -} - -// applyVerifiedAUM computes a new state based on the update provided. -// -// The provided update MUST be verified: That is, the AUM must be well-formed -// (as defined by StaticValidate()), and signatures over the AUM must have -// been verified. -func (s State) applyVerifiedAUM(update AUM) (State, error) { - // Validate that the update message has the right parent. - if !s.parentMatches(update) { - return State{}, errors.New("parent AUMHash mismatch") - } - - switch update.MessageKind { - case AUMNoOp: - out := s.cloneForUpdate(&update) - return out, nil - - case AUMCheckpoint: - if update.State == nil { - return State{}, errors.New("missing checkpoint state") - } - id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 - if !id1Match || !id2Match { - return State{}, errors.New("checkpointed state has an incorrect stateID") - } - return update.State.cloneForUpdate(&update), nil - - case AUMAddKey: - if update.Key == nil { - return State{}, errors.New("no key to add provided") - } - keyID, err := update.Key.ID() - if err != nil { - return State{}, err - } - if _, err := s.GetKey(keyID); err == nil { - return State{}, errors.New("key already exists") - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys, *update.Key) - return out, nil - - case AUMUpdateKey: - k, err := s.GetKey(update.KeyID) - if err != nil { - return State{}, err - } - if update.Votes != nil { - k.Votes = *update.Votes - } - if update.Meta != nil { - k.Meta = update.Meta - } - if err := k.StaticValidate(); err != nil { - return State{}, fmt.Errorf("updated key fails validation: %v", err) - } - out := s.cloneForUpdate(&update) - for i := range out.Keys { - keyID, err := out.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(keyID, update.KeyID) { - out.Keys[i] = k - } - } - return out, nil - - case AUMRemoveKey: - idx := -1 - for i := range s.Keys { - keyID, err := s.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(update.KeyID, keyID) { - idx = i - break - } - } - if idx < 0 { - return State{}, ErrNoSuchKey - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) - return out, nil - - default: - // An AUM with an unknown message kind was received! That means - // that a future version of tailscaled added some feature we don't - // understand. - // - // The future-compatibility contract for AUM message types is that - // they must only add new features, not change the semantics of existing - // mechanisms or features. As such, old clients can safely ignore them. - out := s.cloneForUpdate(&update) - return out, nil - } -} - -// Upper bound on checkpoint elements, chosen arbitrarily. Intended to -// cap out insanely large AUMs. -const ( - maxDisablementSecrets = 32 - maxKeys = 512 -) - -// staticValidateCheckpoint validates that the state is well-formed for -// inclusion in a checkpoint AUM. -func (s *State) staticValidateCheckpoint() error { - if s.LastAUMHash != nil { - return errors.New("cannot specify a parent AUM") - } - if len(s.DisablementSecrets) == 0 { - return errors.New("at least one disablement secret required") - } - if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { - return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) - } - for i, ds := range s.DisablementSecrets { - if len(ds) != disablementLength { - return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) - } - for j, ds2 := range s.DisablementSecrets { - if i == j { - continue - } - if bytes.Equal(ds, ds2) { - return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) - } - } - } - - if len(s.Keys) == 0 { - return errors.New("at least one key is required") - } - if numKeys := len(s.Keys); numKeys > maxKeys { - return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) - } - for i, k := range s.Keys { - if err := k.StaticValidate(); err != nil { - return fmt.Errorf("key[%d]: %v", i, err) - } - } - // NOTE: The max number of keys is constrained (512), so - // O(n^2) is fine. - for i, k := range s.Keys { - for j, k2 := range s.Keys { - if i == j { - continue - } - - id1, err := k.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", i, err) - } - id2, err := k2.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", j, err) - } - - if bytes.Equal(id1, id2) { - return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) - } - } - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "errors" + "fmt" + + "golang.org/x/crypto/argon2" + "tailscale.com/types/tkatype" +) + +// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. +var ErrNoSuchKey = errors.New("key not found") + +// State describes Tailnet Key Authority state at an instant in time. +// +// State is mutated by applying Authority Update Messages (AUMs), resulting +// in a new State. +type State struct { + // LastAUMHash is the blake2s digest of the last-applied AUM. + // Because AUMs are strictly ordered and form a hash chain, we + // check the previous AUM hash in an update we are applying + // is the same as the LastAUMHash. + LastAUMHash *AUMHash `cbor:"1,keyasint"` + + // DisablementSecrets are KDF-derived values which can be used + // to turn off the TKA in the event of a consensus-breaking bug. + DisablementSecrets [][]byte `cbor:"2,keyasint"` + + // Keys are the public keys of either: + // + // 1. The signing nodes currently trusted by the TKA. + // 2. Ephemeral keys that were used to generate pre-signed auth keys. + Keys []Key `cbor:"3,keyasint"` + + // StateID's are nonce's, generated on enablement and fixed for + // the lifetime of the Tailnet Key Authority. We generate 16-bytes + // worth of keyspace here just in case we come up with a cool future + // use for this. + StateID1 uint64 `cbor:"4,keyasint,omitempty"` + StateID2 uint64 `cbor:"5,keyasint,omitempty"` +} + +// GetKey returns the trusted key with the specified KeyID. +func (s State) GetKey(key tkatype.KeyID) (Key, error) { + for _, k := range s.Keys { + keyID, err := k.ID() + if err != nil { + return Key{}, err + } + + if bytes.Equal(keyID, key) { + return k, nil + } + } + + return Key{}, ErrNoSuchKey +} + +// Clone makes an independent copy of State. +// +// NOTE: There is a difference between a nil slice and an empty +// slice for encoding purposes, so an implementation of Clone() +// must take care to preserve this. +func (s State) Clone() State { + out := State{ + StateID1: s.StateID1, + StateID2: s.StateID2, + } + + if s.LastAUMHash != nil { + dupe := *s.LastAUMHash + out.LastAUMHash = &dupe + } + + if s.DisablementSecrets != nil { + out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) + for i := range s.DisablementSecrets { + out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) + copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) + } + } + + if s.Keys != nil { + out.Keys = make([]Key, len(s.Keys)) + for i := range s.Keys { + out.Keys[i] = s.Keys[i].Clone() + } + } + + return out +} + +// cloneForUpdate is like Clone, except LastAUMHash is set based +// on the hash of the given update. +func (s State) cloneForUpdate(update *AUM) State { + out := s.Clone() + aumHash := update.Hash() + out.LastAUMHash = &aumHash + return out +} + +const disablementLength = 32 + +var disablementSalt = []byte("tailscale network-lock disablement salt") + +// DisablementKDF computes a public value which can be stored in a +// key authority, but cannot be reversed to find the input secret. +// +// When the output of this function is stored in tka state (i.e. in +// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() +// with the input of this function as the argument will return true. +func DisablementKDF(secret []byte) []byte { + // time = 4 (3 recommended, booped to 4 to compensate for less memory) + // memory = 16 (32 recommended) + // threads = 4 + // keyLen = 32 (256 bits) + return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) +} + +// checkDisablement returns true for a valid disablement secret. +func (s State) checkDisablement(secret []byte) bool { + derived := DisablementKDF(secret) + for _, candidate := range s.DisablementSecrets { + if bytes.Equal(derived, candidate) { + return true + } + } + return false +} + +// parentMatches returns true if an AUM can chain to (be applied) +// to the current state. +// +// Specifically, the rules are: +// - The last AUM hash must match (transitively, this implies that this +// update follows the last update message applied to the state machine) +// - Or, the state machine knows no parent (its brand new). +func (s State) parentMatches(update AUM) bool { + if s.LastAUMHash == nil { + return true + } + return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) +} + +// applyVerifiedAUM computes a new state based on the update provided. +// +// The provided update MUST be verified: That is, the AUM must be well-formed +// (as defined by StaticValidate()), and signatures over the AUM must have +// been verified. +func (s State) applyVerifiedAUM(update AUM) (State, error) { + // Validate that the update message has the right parent. + if !s.parentMatches(update) { + return State{}, errors.New("parent AUMHash mismatch") + } + + switch update.MessageKind { + case AUMNoOp: + out := s.cloneForUpdate(&update) + return out, nil + + case AUMCheckpoint: + if update.State == nil { + return State{}, errors.New("missing checkpoint state") + } + id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 + if !id1Match || !id2Match { + return State{}, errors.New("checkpointed state has an incorrect stateID") + } + return update.State.cloneForUpdate(&update), nil + + case AUMAddKey: + if update.Key == nil { + return State{}, errors.New("no key to add provided") + } + keyID, err := update.Key.ID() + if err != nil { + return State{}, err + } + if _, err := s.GetKey(keyID); err == nil { + return State{}, errors.New("key already exists") + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys, *update.Key) + return out, nil + + case AUMUpdateKey: + k, err := s.GetKey(update.KeyID) + if err != nil { + return State{}, err + } + if update.Votes != nil { + k.Votes = *update.Votes + } + if update.Meta != nil { + k.Meta = update.Meta + } + if err := k.StaticValidate(); err != nil { + return State{}, fmt.Errorf("updated key fails validation: %v", err) + } + out := s.cloneForUpdate(&update) + for i := range out.Keys { + keyID, err := out.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(keyID, update.KeyID) { + out.Keys[i] = k + } + } + return out, nil + + case AUMRemoveKey: + idx := -1 + for i := range s.Keys { + keyID, err := s.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(update.KeyID, keyID) { + idx = i + break + } + } + if idx < 0 { + return State{}, ErrNoSuchKey + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) + return out, nil + + default: + // An AUM with an unknown message kind was received! That means + // that a future version of tailscaled added some feature we don't + // understand. + // + // The future-compatibility contract for AUM message types is that + // they must only add new features, not change the semantics of existing + // mechanisms or features. As such, old clients can safely ignore them. + out := s.cloneForUpdate(&update) + return out, nil + } +} + +// Upper bound on checkpoint elements, chosen arbitrarily. Intended to +// cap out insanely large AUMs. +const ( + maxDisablementSecrets = 32 + maxKeys = 512 +) + +// staticValidateCheckpoint validates that the state is well-formed for +// inclusion in a checkpoint AUM. +func (s *State) staticValidateCheckpoint() error { + if s.LastAUMHash != nil { + return errors.New("cannot specify a parent AUM") + } + if len(s.DisablementSecrets) == 0 { + return errors.New("at least one disablement secret required") + } + if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { + return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) + } + for i, ds := range s.DisablementSecrets { + if len(ds) != disablementLength { + return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) + } + for j, ds2 := range s.DisablementSecrets { + if i == j { + continue + } + if bytes.Equal(ds, ds2) { + return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) + } + } + } + + if len(s.Keys) == 0 { + return errors.New("at least one key is required") + } + if numKeys := len(s.Keys); numKeys > maxKeys { + return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) + } + for i, k := range s.Keys { + if err := k.StaticValidate(); err != nil { + return fmt.Errorf("key[%d]: %v", i, err) + } + } + // NOTE: The max number of keys is constrained (512), so + // O(n^2) is fine. + for i, k := range s.Keys { + for j, k2 := range s.Keys { + if i == j { + continue + } + + id1, err := k.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", i, err) + } + id2, err := k2.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", j, err) + } + + if bytes.Equal(id1, id2) { + return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) + } + } + } + return nil +} diff --git a/tka/state_test.go b/tka/state_test.go index 060bd9350dd06..b8337dd8a6cb8 100644 --- a/tka/state_test.go +++ b/tka/state_test.go @@ -1,260 +1,260 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "encoding/hex" - "errors" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func fromHex(in string) []byte { - out, err := hex.DecodeString(in) - if err != nil { - panic(err) - } - return out -} - -func hashFromHex(in string) *AUMHash { - var out AUMHash - copy(out[:], fromHex(in)) - return &out -} - -func TestCloneState(t *testing.T) { - tcs := []struct { - Name string - State State - }{ - { - "Empty", - State{}, - }, - { - "Key", - State{ - Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, - }, - }, - { - "StateID", - State{ - StateID1: 42, - StateID2: 22, - }, - }, - { - "DisablementSecrets", - State{ - DisablementSecrets: [][]byte{ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { - t.Errorf("output state differs (-want, +got):\n%s", diff) - } - - // Make sure the cloned State is the same even after - // an encode + decode into + from CBOR. - t.Run("cbor", func(t *testing.T) { - out := bytes.NewBuffer(nil) - encoder, err := cbor.CTAP2EncOptions().EncMode() - if err != nil { - t.Fatal(err) - } - if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { - t.Fatal(err) - } - - var decodedState State - if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.State, decodedState); diff != "" { - t.Errorf("decoded state differs (-want, +got):\n%s", diff) - } - }) - }) - } -} - -func TestApplyUpdatesChain(t *testing.T) { - intOne := uint(1) - tcs := []struct { - Name string - Updates []AUM - Start State - End State - }{ - { - "AddKey", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - }, - { - "RemoveKey", - []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), - }, - }, - { - "UpdateKey", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), - Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, - }, - }, - { - "ChainedKeyUpdates", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), - }, - }, - { - "Checkpoint", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - t.Fatalf("Apply message[%d] failed: %v", i, err) - } - // t.Logf("update[%d] end-state = %+v", i, state) - - updateHash := tc.Updates[i].Hash() - if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { - t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) - } - } - - if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("output state differs (+got, -want):\n%s", diff) - } - }) - } -} - -func TestApplyUpdateErrors(t *testing.T) { - tooLargeVotes := uint(99999) - tcs := []struct { - Name string - Updates []AUM - Start State - Error error - }{ - { - "AddKey exists", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - errors.New("key already exists"), - }, - { - "RemoveKey notfound", - []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey notfound", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey now fails validation", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, - errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), - }, - { - "Bad lastAUMHash", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - errors.New("parent AUMHash mismatch"), - }, - { - "Bad StateID", - []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, - errors.New("checkpointed state has an incorrect stateID"), - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - if err.Error() != tc.Error.Error() { - t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) - } else { - return - } - } - // t.Logf("update[%d] end-state = %+v", i, state) - } - - t.Errorf("did not error, expected %v", tc.Error) - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "encoding/hex" + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func fromHex(in string) []byte { + out, err := hex.DecodeString(in) + if err != nil { + panic(err) + } + return out +} + +func hashFromHex(in string) *AUMHash { + var out AUMHash + copy(out[:], fromHex(in)) + return &out +} + +func TestCloneState(t *testing.T) { + tcs := []struct { + Name string + State State + }{ + { + "Empty", + State{}, + }, + { + "Key", + State{ + Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, + }, + }, + { + "StateID", + State{ + StateID1: 42, + StateID2: 22, + }, + }, + { + "DisablementSecrets", + State{ + DisablementSecrets: [][]byte{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { + t.Errorf("output state differs (-want, +got):\n%s", diff) + } + + // Make sure the cloned State is the same even after + // an encode + decode into + from CBOR. + t.Run("cbor", func(t *testing.T) { + out := bytes.NewBuffer(nil) + encoder, err := cbor.CTAP2EncOptions().EncMode() + if err != nil { + t.Fatal(err) + } + if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { + t.Fatal(err) + } + + var decodedState State + if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.State, decodedState); diff != "" { + t.Errorf("decoded state differs (-want, +got):\n%s", diff) + } + }) + }) + } +} + +func TestApplyUpdatesChain(t *testing.T) { + intOne := uint(1) + tcs := []struct { + Name string + Updates []AUM + Start State + End State + }{ + { + "AddKey", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + }, + { + "RemoveKey", + []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), + }, + }, + { + "UpdateKey", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), + Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, + }, + }, + { + "ChainedKeyUpdates", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), + }, + }, + { + "Checkpoint", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + t.Fatalf("Apply message[%d] failed: %v", i, err) + } + // t.Logf("update[%d] end-state = %+v", i, state) + + updateHash := tc.Updates[i].Hash() + if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { + t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) + } + } + + if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("output state differs (+got, -want):\n%s", diff) + } + }) + } +} + +func TestApplyUpdateErrors(t *testing.T) { + tooLargeVotes := uint(99999) + tcs := []struct { + Name string + Updates []AUM + Start State + Error error + }{ + { + "AddKey exists", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + errors.New("key already exists"), + }, + { + "RemoveKey notfound", + []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey notfound", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey now fails validation", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, + errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), + }, + { + "Bad lastAUMHash", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + errors.New("parent AUMHash mismatch"), + }, + { + "Bad StateID", + []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, + errors.New("checkpointed state has an incorrect stateID"), + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + if err.Error() != tc.Error.Error() { + t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) + } else { + return + } + } + // t.Logf("update[%d] end-state = %+v", i, state) + } + + t.Errorf("did not error, expected %v", tc.Error) + }) + } +} diff --git a/tka/sync_test.go b/tka/sync_test.go index 7250eacf7d143..d214020c41af4 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -1,377 +1,377 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestSyncOffer(t *testing.T) { - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 - A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 - `) - storage := c.Chonk() - a, err := Open(storage) - if err != nil { - t.Fatal(err) - } - got, err := a.SyncOffer(storage) - if err != nil { - t.Fatal(err) - } - - // A SyncOffer includes a selection of AUMs going backwards in the tree, - // progressively skipping more and more each iteration. - want := SyncOffer{ - Head: c.AUMHashes["A25"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - `) - a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &a1H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - headIntersection: &a2H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { - // The number of nodes in the chain is longer than ancestorSkipStart, - // so that during sync both nodes are able to find a common ancestor - // which was later than A1. - - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - | -> F1 - // Make F1 different to A9. - // hashSeed is chosen such that the hash is higher than A9. - F1.hashSeed = 7 - `) - // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] - - if bytes.Compare(f1H[:], a9H[:]) < 0 { - t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") - } - - chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["F1"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer1); diff != "" { - t.Errorf("offer1 diff (-want, +got):\n%s", diff) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["A10"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer2); diff != "" { - t.Errorf("offer2 diff (-want, +got):\n%s", diff) - } - - // Node 1 only knows about the first eight nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain but doesn't recognize the head. - t.Run("n2", func(t *testing.T) { - // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_FastForward(t *testing.T) { - // Node 1 has: A1 -> A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. As such, it should send history from the newest ancestor, - // A1 (if the chain was longer there would be one in the middle). - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so the only AUM that n2 might not have is - // A2. - want := []AUM{c.AUMs["A2"]} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - want := []AUM{ - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_Fork(t *testing.T) { - // Node 1 has: A1 -> A2 -> A3 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - | -> F1 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n1 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["F1"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n2 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestSyncSimpleE2E(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - nodeStorage := &Mem{} - node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("node Bootstrap() failed: %v", err) - } - controlStorage := c.Chonk() - control, err := Open(controlStorage) - if err != nil { - t.Fatalf("control Open() failed: %v", err) - } - - // Control knows the full chain, node only knows the genesis. Lets see - // if they can sync. - nodeOffer, err := node.SyncOffer(nodeStorage) - if err != nil { - t.Fatal(err) - } - controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) - if err != nil { - t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) - } - if err := node.Inform(nodeStorage, controlAUMs); err != nil { - t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) - } - - if cHash, nHash := control.Head(), node.Head(); cHash != nHash { - t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestSyncOffer(t *testing.T) { + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 + A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 + `) + storage := c.Chonk() + a, err := Open(storage) + if err != nil { + t.Fatal(err) + } + got, err := a.SyncOffer(storage) + if err != nil { + t.Fatal(err) + } + + // A SyncOffer includes a selection of AUMs going backwards in the tree, + // progressively skipping more and more each iteration. + want := SyncOffer{ + Head: c.AUMHashes["A25"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + `) + a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &a1H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + headIntersection: &a2H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { + // The number of nodes in the chain is longer than ancestorSkipStart, + // so that during sync both nodes are able to find a common ancestor + // which was later than A1. + + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + | -> F1 + // Make F1 different to A9. + // hashSeed is chosen such that the hash is higher than A9. + F1.hashSeed = 7 + `) + // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] + + if bytes.Compare(f1H[:], a9H[:]) < 0 { + t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") + } + + chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["F1"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer1); diff != "" { + t.Errorf("offer1 diff (-want, +got):\n%s", diff) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["A10"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer2); diff != "" { + t.Errorf("offer2 diff (-want, +got):\n%s", diff) + } + + // Node 1 only knows about the first eight nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain but doesn't recognize the head. + t.Run("n2", func(t *testing.T) { + // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_FastForward(t *testing.T) { + // Node 1 has: A1 -> A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. As such, it should send history from the newest ancestor, + // A1 (if the chain was longer there would be one in the middle). + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so the only AUM that n2 might not have is + // A2. + want := []AUM{c.AUMs["A2"]} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + want := []AUM{ + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_Fork(t *testing.T) { + // Node 1 has: A1 -> A2 -> A3 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + | -> F1 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n1 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["F1"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n2 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestSyncSimpleE2E(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + nodeStorage := &Mem{} + node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("node Bootstrap() failed: %v", err) + } + controlStorage := c.Chonk() + control, err := Open(controlStorage) + if err != nil { + t.Fatalf("control Open() failed: %v", err) + } + + // Control knows the full chain, node only knows the genesis. Lets see + // if they can sync. + nodeOffer, err := node.SyncOffer(nodeStorage) + if err != nil { + t.Fatal(err) + } + controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) + if err != nil { + t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) + } + if err := node.Inform(nodeStorage, controlAUMs); err != nil { + t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) + } + + if cHash, nHash := control.Head(), node.Head(); cHash != nHash { + t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) + } +} diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index 86d5642a3bd10..13d989f0c3c63 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -1,693 +1,693 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "golang.org/x/crypto/blake2s" -) - -// randHash derives a fake blake2s hash from the test name -// and the given seed. -func randHash(t *testing.T, seed int64) [blake2s.Size]byte { - var out [blake2s.Size]byte - testingRand(t, seed).Read(out[:]) - return out -} - -func TestImplementsChonk(t *testing.T) { - impls := []Chonk{&Mem{}, &FS{}} - t.Logf("chonks: %v", impls) -} - -func TestTailchonk_ChildAUMs(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - parentHash := randHash(t, 1) - data := []AUM{ - { - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2}, - PrevAUMHash: parentHash[:], - }, - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - - if err := chonk.CommitVerifiedAUMs(data); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - stored, err := chonk.ChildAUMs(parentHash) - if err != nil { - t.Fatalf("ChildAUMs failed: %v", err) - } - if diff := cmp.Diff(data, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonk_AUMMissing(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - var notExists AUMHash - notExists[:][0] = 42 - if _, err := chonk.AUM(notExists); err != os.ErrNotExist { - t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) - } - }) - } -} - -func TestTailchonkMem_Orphans(t *testing.T) { - chonk := Mem{} - - parentHash := randHash(t, 1) - orphan := AUM{MessageKind: AUMNoOp} - aums := []AUM{ - orphan, - // A parent is specified, so we shouldnt see it in GetOrphans() - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - if err := chonk.CommitVerifiedAUMs(aums); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - stored, err := chonk.Orphans() - if err != nil { - t.Fatalf("Orphans failed: %v", err) - } - if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } -} - -func TestTailchonk_ReadChainFromHead(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - // t.Logf("genesis hash = %X", genesis.Hash()) - // t.Logf("intermediate hash = %X", intermediate.Hash()) - // t.Logf("leaf hash = %X", leaf.Hash()) - - // Read the chain from the leaf backwards. - gotLeafs, err := chonk.Heads() - if err != nil { - t.Fatalf("Heads failed: %v", err) - } - if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { - t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) - } - - parent, _ := gotLeafs[0].Parent() - gotIntermediate, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { - t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) - } - - parent, _ = gotIntermediate.Parent() - gotGenesis, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(genesis, gotGenesis); diff != "" { - t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonkFS_Commit(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - - dir, base := chonk.aumDir(aum.Hash()) - if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { - t.Errorf("aum dir=%s, want %s", got, want) - } - if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { - t.Errorf("aum base=%s, want %s", base, want) - } - if _, err := os.Stat(filepath.Join(dir, base)); err != nil { - t.Errorf("stat of AUM file failed: %v", err) - } - if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { - t.Errorf("stat of AUM parent failed: %v", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix > 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) - } -} - -func TestTailchonkFS_CommitTime(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - ct, err := chonk.CommitTime(aum.Hash()) - if err != nil { - t.Fatalf("CommitTime() failed: %v", err) - } - if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { - t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) - } -} - -func TestTailchonkFS_PurgeAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { - t.Fatal(err) - } - - if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { - t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix == 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) - } -} - -func TestTailchonkFS_AllAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - hashes, err := chonk.AllAUMs() - if err != nil { - t.Fatal(err) - } - hashesLess := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { - t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) - } -} - -func TestMarkActiveChain(t *testing.T) { - type aumTemplate struct { - AUM AUM - } - - tcs := []struct { - name string - minChain int - chain []aumTemplate - expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. - }{ - { - name: "genesis", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 0, - }, - { - name: "simple truncate", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - { - name: "long truncate", - minChain: 5, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 2, - }, - { - name: "truncate finding checkpoint", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.chain)) - - // Build the state of the tailchonk for tests. - storage := &Mem{} - var prev AUMHash - for i := range tc.chain { - if !prev.IsZero() { - tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) - copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) - } - if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { - t.Fatal(err) - } - - h := tc.chain[i].AUM.Hash() - prev = h - verdict[h] = 0 - } - - got, err := markActiveChain(storage, verdict, tc.minChain, prev) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markActiveChain() failed: %v", err) - } - want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() - if got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - // Make sure the verdict array was marked correctly. - for i := range tc.chain { - h := tc.chain[i].AUM.Hash() - if i >= tc.expectLastActiveIdx { - if (verdict[h] & retainStateActive) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) - } - } else { - if (verdict[h] & retainStateCandidate) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) - } - } - } - }) - } -} - -func TestMarkDescendantAUMs(t *testing.T) { - c := newTestchain(t, ` - genesis -> B -> C -> C2 - | -> D - | -> E -> F -> G -> H - | -> E2 - - // tweak seeds so hashes arent identical - C.hashSeed = 1 - D.hashSeed = 2 - E.hashSeed = 3 - E2.hashSeed = 4 - `) - - verdict := make(map[AUMHash]retainState, len(c.AUMs)) - for _, a := range c.AUMs { - verdict[a.Hash()] = 0 - } - - // Mark E & C. - verdict[c.AUMHashes["C"]] = retainStateActive - verdict[c.AUMHashes["E"]] = retainStateActive - - if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { - t.Errorf("markDescendantAUMs() failed: %v", err) - } - - // Make sure the descendants got marked. - hs := c.AUMHashes - for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { - if (verdict[h] & retainStateLeaf) == 0 { - t.Errorf("%v was not marked as a descendant", h) - } - } - for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { - if (verdict[h] & retainStateLeaf) != 0 { - t.Errorf("%v was marked as a descendant and shouldnt be", h) - } - } -} - -func TestMarkAncestorIntersectionAUMs(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - tcs := []struct { - name string - chain *testChain - verdicts map[string]retainState - initialAncestor string - wantAncestor string - wantRetained []string - wantDeleted []string - }{ - { - name: "genesis", - chain: newTestchain(t, ` - A - A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - }, - wantRetained: []string{"A"}, - }, - { - name: "no adjustment", - chain: newTestchain(t, ` - DEAD -> A -> B -> C - A.template = checkpoint - B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - "B": retainStateActive, - "C": retainStateActive, - "DEAD": retainStateCandidate, - }, - wantRetained: []string{"A", "B", "C"}, - wantDeleted: []string{"DEAD"}, - }, - { - name: "fork", - chain: newTestchain(t, ` - A -> B -> C -> D - | -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"C", "D", "FORK"}, - wantDeleted: []string{"A", "B"}, - }, - { - name: "fork finding earlier checkpoint", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F - | -> FORK - A.template = checkpoint - B.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "E", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateActive, - "F": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, - wantDeleted: []string{"A"}, - }, - { - name: "fork multi", - chain: newTestchain(t, ` - A -> B -> C -> D -> E - | -> DEADFORK - C -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2 - DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "E": retainStateActive, - "FORK": retainStateYoung, - "DEADFORK": 0, - }, - wantRetained: []string{"C", "D", "E", "FORK"}, - wantDeleted: []string{"A", "B", "DEADFORK"}, - }, - { - name: "fork multi 2", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F -> G - - F -> F1 - D -> F2 - B -> F3 - - A.template = checkpoint - B.template = checkpoint - D.template = checkpoint - F.template = checkpoint - F1.hashSeed = 2 - F2.hashSeed = 3 - F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "F", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateCandidate, - "F": retainStateActive, - "G": retainStateActive, - "F1": retainStateYoung, - "F2": retainStateYoung, - "F3": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.verdicts)) - for name, v := range tc.verdicts { - verdict[tc.chain.AUMHashes[name]] = v - } - - got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) - } - if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - for _, name := range tc.wantRetained { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask == 0 { - t.Errorf("AUM %q was not retained: verdict = %v", name, v) - } - } - for _, name := range tc.wantDeleted { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask != 0 { - t.Errorf("AUM %q was retained: verdict = %v", name, v) - } - } - - if t.Failed() { - for name, hash := range tc.chain.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } - }) - } -} - -type compactingChonkFake struct { - Mem - - aumAge map[AUMHash]time.Time - t *testing.T - wantDelete []AUMHash -} - -func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { - out := make([]AUMHash, 0, len(c.Mem.aums)) - for h := range c.Mem.aums { - out = append(out, h) - } - return out, nil -} - -func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { - return c.aumAge[hash], nil -} - -func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { - lessHashes := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { - c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) - } - return nil -} - -// Avoid go vet complaining about copying a lock value -func cloneMem(src, dst *Mem) { - dst.l = sync.RWMutex{} - dst.aums = src.aums - dst.parentIndex = src.parentIndex - dst.lastActiveAncestor = src.lastActiveAncestor -} - -func TestCompact(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - // A & B are deleted because the new lastActiveAncestor advances beyond them. - // OLD is deleted because it does not match retention criteria, and - // though it is a descendant of the new lastActiveAncestor (C), it is not a - // descendant of a retained AUM. - // G, & H are retained as recent (MinChain=2) ancestors of HEAD. - // E & F are retained because they are between retained AUMs (G+) and - // their newest checkpoint ancestor. - // D is retained because it is the newest checkpoint ancestor from - // MinChain-retained AUMs. - // G2 is retained because it is a descendant of a retained AUM (G). - // F1 is retained because it is new enough by wall-clock time. - // F2 is retained because it is a descendant of a retained AUM (F1). - // C2 is retained because it is between an ancestor checkpoint and - // a retained AUM (F1). - // C is retained because it is the new lastActiveAncestor. It is the - // new lastActiveAncestor because it is the newest common checkpoint - // of all retained AUMs. - c := newTestchain(t, ` - A -> B -> C -> C2 -> D -> E -> F -> G -> H - | -> F1 -> F2 | -> G2 - | -> OLD - - // make {A,B,C,D} compaction candidates - A.template = checkpoint - B.template = checkpoint - C.template = checkpoint - D.template = checkpoint - - // tweak seeds of forks so hashes arent identical - F1.hashSeed = 1 - OLD.hashSeed = 2 - G2.hashSeed = 3 - `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) - - storage := &compactingChonkFake{ - aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, - t: t, - wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, - } - - cloneMem(c.Chonk().(*Mem), &storage.Mem) - - lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) - if err != nil { - t.Errorf("Compact() failed: %v", err) - } - if lastActiveAncestor != c.AUMHashes["C"] { - t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) - } - - if t.Failed() { - for name, hash := range c.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/crypto/blake2s" +) + +// randHash derives a fake blake2s hash from the test name +// and the given seed. +func randHash(t *testing.T, seed int64) [blake2s.Size]byte { + var out [blake2s.Size]byte + testingRand(t, seed).Read(out[:]) + return out +} + +func TestImplementsChonk(t *testing.T) { + impls := []Chonk{&Mem{}, &FS{}} + t.Logf("chonks: %v", impls) +} + +func TestTailchonk_ChildAUMs(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + parentHash := randHash(t, 1) + data := []AUM{ + { + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2}, + PrevAUMHash: parentHash[:], + }, + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + + if err := chonk.CommitVerifiedAUMs(data); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + stored, err := chonk.ChildAUMs(parentHash) + if err != nil { + t.Fatalf("ChildAUMs failed: %v", err) + } + if diff := cmp.Diff(data, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonk_AUMMissing(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + var notExists AUMHash + notExists[:][0] = 42 + if _, err := chonk.AUM(notExists); err != os.ErrNotExist { + t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) + } + }) + } +} + +func TestTailchonkMem_Orphans(t *testing.T) { + chonk := Mem{} + + parentHash := randHash(t, 1) + orphan := AUM{MessageKind: AUMNoOp} + aums := []AUM{ + orphan, + // A parent is specified, so we shouldnt see it in GetOrphans() + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + if err := chonk.CommitVerifiedAUMs(aums); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + stored, err := chonk.Orphans() + if err != nil { + t.Fatalf("Orphans failed: %v", err) + } + if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } +} + +func TestTailchonk_ReadChainFromHead(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + // t.Logf("genesis hash = %X", genesis.Hash()) + // t.Logf("intermediate hash = %X", intermediate.Hash()) + // t.Logf("leaf hash = %X", leaf.Hash()) + + // Read the chain from the leaf backwards. + gotLeafs, err := chonk.Heads() + if err != nil { + t.Fatalf("Heads failed: %v", err) + } + if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { + t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) + } + + parent, _ := gotLeafs[0].Parent() + gotIntermediate, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { + t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) + } + + parent, _ = gotIntermediate.Parent() + gotGenesis, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(genesis, gotGenesis); diff != "" { + t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonkFS_Commit(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + + dir, base := chonk.aumDir(aum.Hash()) + if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { + t.Errorf("aum dir=%s, want %s", got, want) + } + if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { + t.Errorf("aum base=%s, want %s", base, want) + } + if _, err := os.Stat(filepath.Join(dir, base)); err != nil { + t.Errorf("stat of AUM file failed: %v", err) + } + if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { + t.Errorf("stat of AUM parent failed: %v", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix > 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) + } +} + +func TestTailchonkFS_CommitTime(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + ct, err := chonk.CommitTime(aum.Hash()) + if err != nil { + t.Fatalf("CommitTime() failed: %v", err) + } + if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { + t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) + } +} + +func TestTailchonkFS_PurgeAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { + t.Fatal(err) + } + + if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { + t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix == 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) + } +} + +func TestTailchonkFS_AllAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + hashes, err := chonk.AllAUMs() + if err != nil { + t.Fatal(err) + } + hashesLess := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { + t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) + } +} + +func TestMarkActiveChain(t *testing.T) { + type aumTemplate struct { + AUM AUM + } + + tcs := []struct { + name string + minChain int + chain []aumTemplate + expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. + }{ + { + name: "genesis", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 0, + }, + { + name: "simple truncate", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + { + name: "long truncate", + minChain: 5, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 2, + }, + { + name: "truncate finding checkpoint", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.chain)) + + // Build the state of the tailchonk for tests. + storage := &Mem{} + var prev AUMHash + for i := range tc.chain { + if !prev.IsZero() { + tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) + copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) + } + if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { + t.Fatal(err) + } + + h := tc.chain[i].AUM.Hash() + prev = h + verdict[h] = 0 + } + + got, err := markActiveChain(storage, verdict, tc.minChain, prev) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markActiveChain() failed: %v", err) + } + want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() + if got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + // Make sure the verdict array was marked correctly. + for i := range tc.chain { + h := tc.chain[i].AUM.Hash() + if i >= tc.expectLastActiveIdx { + if (verdict[h] & retainStateActive) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) + } + } else { + if (verdict[h] & retainStateCandidate) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) + } + } + } + }) + } +} + +func TestMarkDescendantAUMs(t *testing.T) { + c := newTestchain(t, ` + genesis -> B -> C -> C2 + | -> D + | -> E -> F -> G -> H + | -> E2 + + // tweak seeds so hashes arent identical + C.hashSeed = 1 + D.hashSeed = 2 + E.hashSeed = 3 + E2.hashSeed = 4 + `) + + verdict := make(map[AUMHash]retainState, len(c.AUMs)) + for _, a := range c.AUMs { + verdict[a.Hash()] = 0 + } + + // Mark E & C. + verdict[c.AUMHashes["C"]] = retainStateActive + verdict[c.AUMHashes["E"]] = retainStateActive + + if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { + t.Errorf("markDescendantAUMs() failed: %v", err) + } + + // Make sure the descendants got marked. + hs := c.AUMHashes + for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { + if (verdict[h] & retainStateLeaf) == 0 { + t.Errorf("%v was not marked as a descendant", h) + } + } + for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { + if (verdict[h] & retainStateLeaf) != 0 { + t.Errorf("%v was marked as a descendant and shouldnt be", h) + } + } +} + +func TestMarkAncestorIntersectionAUMs(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + tcs := []struct { + name string + chain *testChain + verdicts map[string]retainState + initialAncestor string + wantAncestor string + wantRetained []string + wantDeleted []string + }{ + { + name: "genesis", + chain: newTestchain(t, ` + A + A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + }, + wantRetained: []string{"A"}, + }, + { + name: "no adjustment", + chain: newTestchain(t, ` + DEAD -> A -> B -> C + A.template = checkpoint + B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + "B": retainStateActive, + "C": retainStateActive, + "DEAD": retainStateCandidate, + }, + wantRetained: []string{"A", "B", "C"}, + wantDeleted: []string{"DEAD"}, + }, + { + name: "fork", + chain: newTestchain(t, ` + A -> B -> C -> D + | -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"C", "D", "FORK"}, + wantDeleted: []string{"A", "B"}, + }, + { + name: "fork finding earlier checkpoint", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F + | -> FORK + A.template = checkpoint + B.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "E", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateActive, + "F": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, + wantDeleted: []string{"A"}, + }, + { + name: "fork multi", + chain: newTestchain(t, ` + A -> B -> C -> D -> E + | -> DEADFORK + C -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2 + DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "E": retainStateActive, + "FORK": retainStateYoung, + "DEADFORK": 0, + }, + wantRetained: []string{"C", "D", "E", "FORK"}, + wantDeleted: []string{"A", "B", "DEADFORK"}, + }, + { + name: "fork multi 2", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F -> G + + F -> F1 + D -> F2 + B -> F3 + + A.template = checkpoint + B.template = checkpoint + D.template = checkpoint + F.template = checkpoint + F1.hashSeed = 2 + F2.hashSeed = 3 + F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "F", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateCandidate, + "F": retainStateActive, + "G": retainStateActive, + "F1": retainStateYoung, + "F2": retainStateYoung, + "F3": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.verdicts)) + for name, v := range tc.verdicts { + verdict[tc.chain.AUMHashes[name]] = v + } + + got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) + } + if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + for _, name := range tc.wantRetained { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask == 0 { + t.Errorf("AUM %q was not retained: verdict = %v", name, v) + } + } + for _, name := range tc.wantDeleted { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask != 0 { + t.Errorf("AUM %q was retained: verdict = %v", name, v) + } + } + + if t.Failed() { + for name, hash := range tc.chain.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } + }) + } +} + +type compactingChonkFake struct { + Mem + + aumAge map[AUMHash]time.Time + t *testing.T + wantDelete []AUMHash +} + +func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { + out := make([]AUMHash, 0, len(c.Mem.aums)) + for h := range c.Mem.aums { + out = append(out, h) + } + return out, nil +} + +func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { + return c.aumAge[hash], nil +} + +func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { + lessHashes := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { + c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) + } + return nil +} + +// Avoid go vet complaining about copying a lock value +func cloneMem(src, dst *Mem) { + dst.l = sync.RWMutex{} + dst.aums = src.aums + dst.parentIndex = src.parentIndex + dst.lastActiveAncestor = src.lastActiveAncestor +} + +func TestCompact(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + // A & B are deleted because the new lastActiveAncestor advances beyond them. + // OLD is deleted because it does not match retention criteria, and + // though it is a descendant of the new lastActiveAncestor (C), it is not a + // descendant of a retained AUM. + // G, & H are retained as recent (MinChain=2) ancestors of HEAD. + // E & F are retained because they are between retained AUMs (G+) and + // their newest checkpoint ancestor. + // D is retained because it is the newest checkpoint ancestor from + // MinChain-retained AUMs. + // G2 is retained because it is a descendant of a retained AUM (G). + // F1 is retained because it is new enough by wall-clock time. + // F2 is retained because it is a descendant of a retained AUM (F1). + // C2 is retained because it is between an ancestor checkpoint and + // a retained AUM (F1). + // C is retained because it is the new lastActiveAncestor. It is the + // new lastActiveAncestor because it is the newest common checkpoint + // of all retained AUMs. + c := newTestchain(t, ` + A -> B -> C -> C2 -> D -> E -> F -> G -> H + | -> F1 -> F2 | -> G2 + | -> OLD + + // make {A,B,C,D} compaction candidates + A.template = checkpoint + B.template = checkpoint + C.template = checkpoint + D.template = checkpoint + + // tweak seeds of forks so hashes arent identical + F1.hashSeed = 1 + OLD.hashSeed = 2 + G2.hashSeed = 3 + `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) + + storage := &compactingChonkFake{ + aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, + t: t, + wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, + } + + cloneMem(c.Chonk().(*Mem), &storage.Mem) + + lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) + if err != nil { + t.Errorf("Compact() failed: %v", err) + } + if lastActiveAncestor != c.AUMHashes["C"] { + t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) + } + + if t.Failed() { + for name, hash := range c.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } +} diff --git a/tka/tka_test.go b/tka/tka_test.go index 9e3c4e79d05bd..3438a4016f0f6 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -1,654 +1,654 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -func TestComputeChainCandidates(t *testing.T) { - c := newTestchain(t, ` - G1 -> I1 -> I2 -> I3 -> L2 - | -> L1 | -> L3 - - G2 -> L4 - - // We tweak these AUMs so they are different hashes. - G2.hashSeed = 2 - L1.hashSeed = 2 - L3.hashSeed = 2 - L4.hashSeed = 3 - `) - // Should result in 4 chains: - // G1->L1, G1->L2, G1->L3, G2->L4 - - i1H := c.AUMHashes["I1"] - got, err := computeChainCandidates(c.Chonk(), &i1H, 50) - if err != nil { - t.Fatalf("computeChainCandidates() failed: %v", err) - } - - want := []chain{ - {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { - t.Errorf("chains differ (-want, +got):\n%s", diff) - } -} - -func TestForkResolutionHash(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - // tweak hashes so L1 & L2 are not identical - L1.hashSeed = 2 - L2.hashSeed = 3 - `) - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // The fork with the lowest AUM hash should have been chosen. - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - want := l1H - if bytes.Compare(l2H[:], l1H[:]) < 0 { - want = l2H - } - - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionSigWeight(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - G1.template = addKey - L1.hashSeed = 11 - L2.signedWith = key - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optKey("key", key, priv)) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, l1H should be chosen. - // But based on the signature weight (which has higher - // precedence), it should be l2H - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionMessageType(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - | -> L3 - - G1.template = addKey - L1.hashSeed = 11 - L2.template = removeKey - L3.hashSeed = 18 - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - l3H := c.AUMHashes["L3"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, L1 or L3 should be chosen. - // But based on the preference for AUMRemoveKey messages, - // it should be L2. - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestComputeStateAt(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> I1 -> I2 - I1.template = addKey - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) - - // G1 is before the key, so there shouldn't be a key there. - state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) - if err != nil { - t.Fatalf("computeStateAt(G1) failed: %v", err) - } - if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("expected key to be missing: err = %v", err) - } - if *state.LastAUMHash != c.AUMHashes["G1"] { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) - } - - // I1 & I2 are after the key, so the computed state should contain - // the key. - for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { - state, err = computeStateAt(c.Chonk(), 500, wantHash) - if err != nil { - t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) - } - if *state.LastAUMHash != wantHash { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) - } - if _, err := state.GetKey(key.MustID()); err != nil { - t.Errorf("expected key to be present at state: err = %v", err) - } - } -} - -// fakeAUM generates an AUM structure based on the template. -// If parent is provided, PrevAUMHash is set to that value. -// -// If template is an AUM, the returned AUM is based on that. -// If template is an int, a NOOP AUM is returned, and the -// provided int can be used to tweak the resulting hash (needed -// for tests you want one AUM to be 'lower' than another, so that -// that chain is taken based on fork resolution rules). -func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { - if seed, ok := template.(int); ok { - a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - if a, ok := template.(AUM); ok { - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - panic("template must be an int or an AUM") -} - -func TestOpenAuthority(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - // /- L1 - // G1 - I1 - I2 - I3 -L2 - // \-L3 - // G2 - L4 - // - // We set the previous-known ancestor to G1, so the - // ancestor to start from should be G1. - g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) - i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} - l1, l1H := fakeAUM(t, 13, &i1H) - - i2, i2H := fakeAUM(t, 2, &i1H) - i3, i3H := fakeAUM(t, 5, &i2H) - l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) - l3, l3H := fakeAUM(t, 4, &i3H) - - g2, g2H := fakeAUM(t, 8, nil) - l4, _ := fakeAUM(t, 9, &g2H) - - // We make sure that I2 has a lower hash than L1, so - // it should take that path rather than L1. - if bytes.Compare(l1H[:], i2H[:]) < 0 { - t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") - } - // We make sure L2 has a signature with key, so it should - // take that path over L3. We assert that the L3 hash - // is less than L2 so the test will fail if the signature - // preference logic is broken. - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") - } - - // Construct the state of durable storage. - chonk := &Mem{} - err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) - if err != nil { - t.Fatal(err) - } - chonk.SetLastActiveAncestor(i1H) - - a, err := Open(chonk) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - // Should include the key added in G1 - if _, err := a.state.GetKey(key.MustID()); err != nil { - t.Errorf("missing G1 key: %v", err) - } - // The head of the chain should be L2. - if a.Head() != l2H { - t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) - } -} - -func TestOpenAuthority_EmptyErrors(t *testing.T) { - _, err := Open(&Mem{}) - if err == nil { - t.Error("Expected an error initializing an empty authority, got nil") - } -} - -func TestAuthorityHead(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - L1.hashSeed = 2 - `) - - a, _ := Open(c.Chonk()) - if got, want := a.head.Hash(), a.Head(); got != want { - t.Errorf("Hash() returned %x, want %x", got, want) - } -} - -func TestAuthorityValidDisablement(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - - a, _ := Open(c.Chonk()) - if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { - t.Error("ValidDisablement() returned false, want true") - } -} - -func TestCreateBootstrapAuthority(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - a1, genesisAUM, err := Create(&Mem{}, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - a2, err := Bootstrap(&Mem{}, genesisAUM) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - if a1.Head() != a2.Head() { - t.Fatal("created and bootstrapped authority differ") - } - - // Both authorities should trust the key laid down in the genesis state. - if !a1.KeyTrusted(key.MustID()) { - t.Error("a1 did not trust genesis key") - } - if !a2.KeyTrusted(key.MustID()) { - t.Error("a2 did not trust genesis key") - } -} - -func TestAuthorityInformNonLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 -> L3 - | -> L4 -> L5 - - G1.template = genesis - L1.hashSeed = 3 - L2.hashSeed = 2 - L4.hashSeed = 2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - // L2 does not chain from L1, disabling the isHeadChain optimization - // and forcing Inform() to take the slow path. - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestAuthorityInformLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestInteropWithNLKey(t *testing.T) { - priv1 := key.NewNLPrivate() - pub1 := priv1.Public() - pub2 := key.NewNLPrivate().Public() - pub3 := key.NewNLPrivate().Public() - - a, _, err := Create(&Mem{}, State{ - Keys: []Key{ - { - Kind: Key25519, - Votes: 1, - Public: pub1.KeyID(), - }, - { - Kind: Key25519, - Votes: 1, - Public: pub2.KeyID(), - }, - }, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, priv1) - if err != nil { - t.Errorf("tka.Create: %v", err) - return - } - - if !a.KeyTrusted(pub1.KeyID()) { - t.Error("pub1 want trusted, got untrusted") - } - if !a.KeyTrusted(pub2.KeyID()) { - t.Error("pub2 want trusted, got untrusted") - } - if a.KeyTrusted(pub3.KeyID()) { - t.Error("pub3 want untrusted, got trusted") - } -} - -func TestAuthorityCompact(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G -> A -> B -> C -> D -> E - - G.template = genesis - C.template = checkpoint2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &FS{base: t.TempDir()} - a, err := Bootstrap(storage, c.AUMs["G"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) - - // Should compact down to C -> D -> E - if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { - t.Fatal(err) - } - if a.oldestAncestor.Hash() != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) - } - - // Make sure the stored authority is still openable and resolves to the same state. - stored, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open stored authority: %v", err) - } - if stored.Head() != a.Head() { - t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) - } - t.Logf("original ancestor = %v", c.AUMHashes["G"]) - if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) - } -} - -func TestFindParentForRewrite(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - k2ID, _ := k2.ID() - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D -> E - A.template = genesis - B.template = add2 - C.template = add3 - D.template = remove2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), - optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k1 was trusted at genesis, so there's no better rewrite parent - // than the genesis. - k1ID, _ := k1.ID() - k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k1) failed: %v", err) - } - if k1P != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) - } - - // k3 was trusted at C, so B would be an ideal rewrite point. - k3ID, _ := k3.ID() - k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k3) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) - } - - // k2 was added but then removed, so HEAD is an appropriate rewrite point. - k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k2) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) - } - - // There's no appropriate point where both k2 and k3 are simultaneously not trusted, - // so the best rewrite point is the genesis AUM. - doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) - } - if doubleP != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) - } -} - -func TestMakeRetroactiveRevocation(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D - A.template = genesis - C.template = add2 - D.template = add3 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k2 was added by C, so a forking revocation should: - // - have B as a parent - // - trust the remaining keys at the time, k1 & k3. - k1ID, _ := k1.ID() - k2ID, _ := k2.ID() - k3ID, _ := k3.ID() - forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) - if err != nil { - t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) - } - if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { - t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) - } - if _, err := forkingAUM.State.GetKey(k1ID); err != nil { - t.Error("Forked state did not trust k1") - } - if _, err := forkingAUM.State.GetKey(k3ID); err != nil { - t.Error("Forked state did not trust k3") - } - if _, err := forkingAUM.State.GetKey(k2ID); err == nil { - t.Error("Forked state trusted removed-key k2") - } - - // Test that removing all trusted keys results in an error. - _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) - if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { - t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +func TestComputeChainCandidates(t *testing.T) { + c := newTestchain(t, ` + G1 -> I1 -> I2 -> I3 -> L2 + | -> L1 | -> L3 + + G2 -> L4 + + // We tweak these AUMs so they are different hashes. + G2.hashSeed = 2 + L1.hashSeed = 2 + L3.hashSeed = 2 + L4.hashSeed = 3 + `) + // Should result in 4 chains: + // G1->L1, G1->L2, G1->L3, G2->L4 + + i1H := c.AUMHashes["I1"] + got, err := computeChainCandidates(c.Chonk(), &i1H, 50) + if err != nil { + t.Fatalf("computeChainCandidates() failed: %v", err) + } + + want := []chain{ + {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { + t.Errorf("chains differ (-want, +got):\n%s", diff) + } +} + +func TestForkResolutionHash(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + // tweak hashes so L1 & L2 are not identical + L1.hashSeed = 2 + L2.hashSeed = 3 + `) + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // The fork with the lowest AUM hash should have been chosen. + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + want := l1H + if bytes.Compare(l2H[:], l1H[:]) < 0 { + want = l2H + } + + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionSigWeight(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + G1.template = addKey + L1.hashSeed = 11 + L2.signedWith = key + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optKey("key", key, priv)) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, l1H should be chosen. + // But based on the signature weight (which has higher + // precedence), it should be l2H + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionMessageType(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + | -> L3 + + G1.template = addKey + L1.hashSeed = 11 + L2.template = removeKey + L3.hashSeed = 18 + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + l3H := c.AUMHashes["L3"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, L1 or L3 should be chosen. + // But based on the preference for AUMRemoveKey messages, + // it should be L2. + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestComputeStateAt(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> I1 -> I2 + I1.template = addKey + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) + + // G1 is before the key, so there shouldn't be a key there. + state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) + if err != nil { + t.Fatalf("computeStateAt(G1) failed: %v", err) + } + if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("expected key to be missing: err = %v", err) + } + if *state.LastAUMHash != c.AUMHashes["G1"] { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) + } + + // I1 & I2 are after the key, so the computed state should contain + // the key. + for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { + state, err = computeStateAt(c.Chonk(), 500, wantHash) + if err != nil { + t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) + } + if *state.LastAUMHash != wantHash { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) + } + if _, err := state.GetKey(key.MustID()); err != nil { + t.Errorf("expected key to be present at state: err = %v", err) + } + } +} + +// fakeAUM generates an AUM structure based on the template. +// If parent is provided, PrevAUMHash is set to that value. +// +// If template is an AUM, the returned AUM is based on that. +// If template is an int, a NOOP AUM is returned, and the +// provided int can be used to tweak the resulting hash (needed +// for tests you want one AUM to be 'lower' than another, so that +// that chain is taken based on fork resolution rules). +func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { + if seed, ok := template.(int); ok { + a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + if a, ok := template.(AUM); ok { + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + panic("template must be an int or an AUM") +} + +func TestOpenAuthority(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + // /- L1 + // G1 - I1 - I2 - I3 -L2 + // \-L3 + // G2 - L4 + // + // We set the previous-known ancestor to G1, so the + // ancestor to start from should be G1. + g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) + i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} + l1, l1H := fakeAUM(t, 13, &i1H) + + i2, i2H := fakeAUM(t, 2, &i1H) + i3, i3H := fakeAUM(t, 5, &i2H) + l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) + l3, l3H := fakeAUM(t, 4, &i3H) + + g2, g2H := fakeAUM(t, 8, nil) + l4, _ := fakeAUM(t, 9, &g2H) + + // We make sure that I2 has a lower hash than L1, so + // it should take that path rather than L1. + if bytes.Compare(l1H[:], i2H[:]) < 0 { + t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") + } + // We make sure L2 has a signature with key, so it should + // take that path over L3. We assert that the L3 hash + // is less than L2 so the test will fail if the signature + // preference logic is broken. + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") + } + + // Construct the state of durable storage. + chonk := &Mem{} + err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) + if err != nil { + t.Fatal(err) + } + chonk.SetLastActiveAncestor(i1H) + + a, err := Open(chonk) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + // Should include the key added in G1 + if _, err := a.state.GetKey(key.MustID()); err != nil { + t.Errorf("missing G1 key: %v", err) + } + // The head of the chain should be L2. + if a.Head() != l2H { + t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) + } +} + +func TestOpenAuthority_EmptyErrors(t *testing.T) { + _, err := Open(&Mem{}) + if err == nil { + t.Error("Expected an error initializing an empty authority, got nil") + } +} + +func TestAuthorityHead(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + L1.hashSeed = 2 + `) + + a, _ := Open(c.Chonk()) + if got, want := a.head.Hash(), a.Head(); got != want { + t.Errorf("Hash() returned %x, want %x", got, want) + } +} + +func TestAuthorityValidDisablement(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + + a, _ := Open(c.Chonk()) + if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { + t.Error("ValidDisablement() returned false, want true") + } +} + +func TestCreateBootstrapAuthority(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + a1, genesisAUM, err := Create(&Mem{}, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + a2, err := Bootstrap(&Mem{}, genesisAUM) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + if a1.Head() != a2.Head() { + t.Fatal("created and bootstrapped authority differ") + } + + // Both authorities should trust the key laid down in the genesis state. + if !a1.KeyTrusted(key.MustID()) { + t.Error("a1 did not trust genesis key") + } + if !a2.KeyTrusted(key.MustID()) { + t.Error("a2 did not trust genesis key") + } +} + +func TestAuthorityInformNonLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 -> L3 + | -> L4 -> L5 + + G1.template = genesis + L1.hashSeed = 3 + L2.hashSeed = 2 + L4.hashSeed = 2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + // L2 does not chain from L1, disabling the isHeadChain optimization + // and forcing Inform() to take the slow path. + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestAuthorityInformLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestInteropWithNLKey(t *testing.T) { + priv1 := key.NewNLPrivate() + pub1 := priv1.Public() + pub2 := key.NewNLPrivate().Public() + pub3 := key.NewNLPrivate().Public() + + a, _, err := Create(&Mem{}, State{ + Keys: []Key{ + { + Kind: Key25519, + Votes: 1, + Public: pub1.KeyID(), + }, + { + Kind: Key25519, + Votes: 1, + Public: pub2.KeyID(), + }, + }, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, priv1) + if err != nil { + t.Errorf("tka.Create: %v", err) + return + } + + if !a.KeyTrusted(pub1.KeyID()) { + t.Error("pub1 want trusted, got untrusted") + } + if !a.KeyTrusted(pub2.KeyID()) { + t.Error("pub2 want trusted, got untrusted") + } + if a.KeyTrusted(pub3.KeyID()) { + t.Error("pub3 want untrusted, got trusted") + } +} + +func TestAuthorityCompact(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G -> A -> B -> C -> D -> E + + G.template = genesis + C.template = checkpoint2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &FS{base: t.TempDir()} + a, err := Bootstrap(storage, c.AUMs["G"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) + + // Should compact down to C -> D -> E + if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { + t.Fatal(err) + } + if a.oldestAncestor.Hash() != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) + } + + // Make sure the stored authority is still openable and resolves to the same state. + stored, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open stored authority: %v", err) + } + if stored.Head() != a.Head() { + t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) + } + t.Logf("original ancestor = %v", c.AUMHashes["G"]) + if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) + } +} + +func TestFindParentForRewrite(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + k2ID, _ := k2.ID() + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D -> E + A.template = genesis + B.template = add2 + C.template = add3 + D.template = remove2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), + optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k1 was trusted at genesis, so there's no better rewrite parent + // than the genesis. + k1ID, _ := k1.ID() + k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k1) failed: %v", err) + } + if k1P != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) + } + + // k3 was trusted at C, so B would be an ideal rewrite point. + k3ID, _ := k3.ID() + k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k3) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) + } + + // k2 was added but then removed, so HEAD is an appropriate rewrite point. + k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k2) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) + } + + // There's no appropriate point where both k2 and k3 are simultaneously not trusted, + // so the best rewrite point is the genesis AUM. + doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) + } + if doubleP != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) + } +} + +func TestMakeRetroactiveRevocation(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D + A.template = genesis + C.template = add2 + D.template = add3 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k2 was added by C, so a forking revocation should: + // - have B as a parent + // - trust the remaining keys at the time, k1 & k3. + k1ID, _ := k1.ID() + k2ID, _ := k2.ID() + k3ID, _ := k3.ID() + forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) + if err != nil { + t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) + } + if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { + t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) + } + if _, err := forkingAUM.State.GetKey(k1ID); err != nil { + t.Error("Forked state did not trust k1") + } + if _, err := forkingAUM.State.GetKey(k3ID); err != nil { + t.Error("Forked state did not trust k3") + } + if _, err := forkingAUM.State.GetKey(k2ID); err == nil { + t.Error("Forked state trusted removed-key k2") + } + + // Test that removing all trusted keys results in an error. + _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) + if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { + t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) + } +} diff --git a/tool/binaryen.rev b/tool/binaryen.rev index 58c9bdf9d017f..e0d03ab88bb4a 100644 --- a/tool/binaryen.rev +++ b/tool/binaryen.rev @@ -1 +1 @@ -111 +111 diff --git a/tool/go b/tool/go index 1c53683d52f95..3c99f3e2fceeb 100755 --- a/tool/go +++ b/tool/go @@ -1,7 +1,7 @@ -#!/bin/sh -# -# This script acts like the "go" command, but uses Tailscale's -# currently-desired version from https://github.com/tailscale/go, -# downloading it first if necessary. - -exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" +#!/bin/sh +# +# This script acts like the "go" command, but uses Tailscale's +# currently-desired version from https://github.com/tailscale/go, +# downloading it first if necessary. + +exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" diff --git a/tool/gocross/env.go b/tool/gocross/env.go index 9d8a4f1b390b4..249476dc1b5a3 100644 --- a/tool/gocross/env.go +++ b/tool/gocross/env.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "fmt" - "os" - "sort" - "strings" -) - -// Environment starts from an initial set of environment variables, and tracks -// mutations to the environment. It can then apply those mutations to the -// environment, or produce debugging output that illustrates the changes it -// would make. -type Environment struct { - init map[string]string - set map[string]string - unset map[string]bool - - setenv func(string, string) error - unsetenv func(string) error -} - -// NewEnvironment returns an Environment initialized from os.Environ. -func NewEnvironment() *Environment { - init := map[string]string{} - for _, env := range os.Environ() { - fs := strings.SplitN(env, "=", 2) - if len(fs) != 2 { - panic("bad environ provided") - } - init[fs[0]] = fs[1] - } - - return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) -} - -func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { - return &Environment{ - init: init, - set: map[string]string{}, - unset: map[string]bool{}, - setenv: setenv, - unsetenv: unsetenv, - } -} - -// Set sets the environment variable k to v. -func (e *Environment) Set(k, v string) { - e.set[k] = v - delete(e.unset, k) -} - -// Unset removes the environment variable k. -func (e *Environment) Unset(k string) { - delete(e.set, k) - e.unset[k] = true -} - -// IsSet reports whether the environment variable k is set. -func (e *Environment) IsSet(k string) bool { - if e.unset[k] { - return false - } - if _, ok := e.init[k]; ok { - return true - } - if _, ok := e.set[k]; ok { - return true - } - return false -} - -// Get returns the value of the environment variable k, or defaultVal if it is -// not set. -func (e *Environment) Get(k, defaultVal string) string { - if e.unset[k] { - return defaultVal - } - if v, ok := e.set[k]; ok { - return v - } - if v, ok := e.init[k]; ok { - return v - } - return defaultVal -} - -// Apply applies all pending mutations to the environment. -func (e *Environment) Apply() error { - for k, v := range e.set { - if err := e.setenv(k, v); err != nil { - return fmt.Errorf("setting %q: %v", k, err) - } - e.init[k] = v - delete(e.set, k) - } - for k := range e.unset { - if err := e.unsetenv(k); err != nil { - return fmt.Errorf("unsetting %q: %v", k, err) - } - delete(e.init, k) - delete(e.unset, k) - } - return nil -} - -// Diff returns a string describing the pending mutations to the environment. -func (e *Environment) Diff() string { - lines := make([]string, 0, len(e.set)+len(e.unset)) - for k, v := range e.set { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) - } else { - lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) - } - } - for k := range e.unset { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) - } else { - lines = append(lines, fmt.Sprintf("%s= (was )", k)) - } - } - sort.Strings(lines) - return strings.Join(lines, "\n") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "os" + "sort" + "strings" +) + +// Environment starts from an initial set of environment variables, and tracks +// mutations to the environment. It can then apply those mutations to the +// environment, or produce debugging output that illustrates the changes it +// would make. +type Environment struct { + init map[string]string + set map[string]string + unset map[string]bool + + setenv func(string, string) error + unsetenv func(string) error +} + +// NewEnvironment returns an Environment initialized from os.Environ. +func NewEnvironment() *Environment { + init := map[string]string{} + for _, env := range os.Environ() { + fs := strings.SplitN(env, "=", 2) + if len(fs) != 2 { + panic("bad environ provided") + } + init[fs[0]] = fs[1] + } + + return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) +} + +func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { + return &Environment{ + init: init, + set: map[string]string{}, + unset: map[string]bool{}, + setenv: setenv, + unsetenv: unsetenv, + } +} + +// Set sets the environment variable k to v. +func (e *Environment) Set(k, v string) { + e.set[k] = v + delete(e.unset, k) +} + +// Unset removes the environment variable k. +func (e *Environment) Unset(k string) { + delete(e.set, k) + e.unset[k] = true +} + +// IsSet reports whether the environment variable k is set. +func (e *Environment) IsSet(k string) bool { + if e.unset[k] { + return false + } + if _, ok := e.init[k]; ok { + return true + } + if _, ok := e.set[k]; ok { + return true + } + return false +} + +// Get returns the value of the environment variable k, or defaultVal if it is +// not set. +func (e *Environment) Get(k, defaultVal string) string { + if e.unset[k] { + return defaultVal + } + if v, ok := e.set[k]; ok { + return v + } + if v, ok := e.init[k]; ok { + return v + } + return defaultVal +} + +// Apply applies all pending mutations to the environment. +func (e *Environment) Apply() error { + for k, v := range e.set { + if err := e.setenv(k, v); err != nil { + return fmt.Errorf("setting %q: %v", k, err) + } + e.init[k] = v + delete(e.set, k) + } + for k := range e.unset { + if err := e.unsetenv(k); err != nil { + return fmt.Errorf("unsetting %q: %v", k, err) + } + delete(e.init, k) + delete(e.unset, k) + } + return nil +} + +// Diff returns a string describing the pending mutations to the environment. +func (e *Environment) Diff() string { + lines := make([]string, 0, len(e.set)+len(e.unset)) + for k, v := range e.set { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) + } else { + lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) + } + } + for k := range e.unset { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) + } else { + lines = append(lines, fmt.Sprintf("%s= (was )", k)) + } + } + sort.Strings(lines) + return strings.Join(lines, "\n") +} diff --git a/tool/gocross/env_test.go b/tool/gocross/env_test.go index 001487bb8e1a6..9a797530d72cd 100644 --- a/tool/gocross/env_test.go +++ b/tool/gocross/env_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestEnv(t *testing.T) { - - var ( - init = map[string]string{ - "FOO": "bar", - } - - wasSet = map[string]string{} - wasUnset = map[string]bool{} - - setenv = func(k, v string) error { - wasSet[k] = v - return nil - } - unsetenv = func(k string) error { - wasUnset[k] = true - return nil - } - ) - - env := newEnvironmentForTest(init, setenv, unsetenv) - - if got, want := env.Get("FOO", ""), "bar"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - - if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), false; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - - env.Set("BAR", "quux") - if got, want := env.Get("BAR", ""), "quux"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), true; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - diff := "BAR=quux (was )" - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Set("FOO", "foo2") - if got, want := env.Get("FOO", ""), "foo2"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO=foo2 (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Unset("FOO") - if got, want := env.Get("FOO", "default"), "default"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), false; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO= (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - if err := env.Apply(); err != nil { - t.Fatalf("env.Apply() failed: %v", err) - } - - wantSet := map[string]string{"BAR": "quux"} - wantUnset := map[string]bool{"FOO": true} - - if diff := cmp.Diff(wasSet, wantSet); diff != "" { - t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) - } - if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { - t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestEnv(t *testing.T) { + + var ( + init = map[string]string{ + "FOO": "bar", + } + + wasSet = map[string]string{} + wasUnset = map[string]bool{} + + setenv = func(k, v string) error { + wasSet[k] = v + return nil + } + unsetenv = func(k string) error { + wasUnset[k] = true + return nil + } + ) + + env := newEnvironmentForTest(init, setenv, unsetenv) + + if got, want := env.Get("FOO", ""), "bar"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + + if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), false; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + + env.Set("BAR", "quux") + if got, want := env.Get("BAR", ""), "quux"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), true; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + diff := "BAR=quux (was )" + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Set("FOO", "foo2") + if got, want := env.Get("FOO", ""), "foo2"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO=foo2 (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Unset("FOO") + if got, want := env.Get("FOO", "default"), "default"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), false; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO= (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + if err := env.Apply(); err != nil { + t.Fatalf("env.Apply() failed: %v", err) + } + + wantSet := map[string]string{"BAR": "quux"} + wantUnset := map[string]bool{"FOO": true} + + if diff := cmp.Diff(wasSet, wantSet); diff != "" { + t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) + } + if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { + t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) + } +} diff --git a/tool/gocross/exec_other.go b/tool/gocross/exec_other.go index 8d4df0db334dd..ec9663df7c7d9 100644 --- a/tool/gocross/exec_other.go +++ b/tool/gocross/exec_other.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !unix - -package main - -import ( - "os" - "os/exec" -) - -func doExec(cmd string, args []string, env []string) error { - c := exec.Command(cmd, args...) - c.Env = env - c.Stdin = os.Stdin - c.Stdout = os.Stdout - c.Stderr = os.Stderr - return c.Run() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !unix + +package main + +import ( + "os" + "os/exec" +) + +func doExec(cmd string, args []string, env []string) error { + c := exec.Command(cmd, args...) + c.Env = env + c.Stdin = os.Stdin + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return c.Run() +} diff --git a/tool/gocross/exec_unix.go b/tool/gocross/exec_unix.go index 79cbf764ad2f6..eeffd5f939aab 100644 --- a/tool/gocross/exec_unix.go +++ b/tool/gocross/exec_unix.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package main - -import "golang.org/x/sys/unix" - -func doExec(cmd string, args []string, env []string) error { - return unix.Exec(cmd, args, env) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package main + +import "golang.org/x/sys/unix" + +func doExec(cmd string, args []string, env []string) error { + return unix.Exec(cmd, args, env) +} diff --git a/tool/helm b/tool/helm index 3f9a9dfd5ba21..8cbc2f2065ead 100755 --- a/tool/helm +++ b/tool/helm @@ -1,69 +1,69 @@ -#!/usr/bin/env bash - -# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-helm" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/helm.rev" - - got_rev="" - if [[ -x "${cachedir}/helm" ]]; then - got_rev=$("${cachedir}/helm" version --short) - got_rev="${got_rev#v}" # trim the leading 'v' - got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' - - - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_helm="$(which -a helm | grep /nix/store | head -1)" - nix_helm="${nix_helm%/helm}" - nix_helm_rev="${nix_helm##*-}" - if [[ "$nix_helm_rev" != "$want_rev" ]]; then - echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_helm" "$cachedir" - else - # works for linux and darwin - # https://github.com/helm/helm/releases - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="amd64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-helm:$PATH" -exec "$HOME/.cache/tailscale-helm/helm" "$@" +#!/usr/bin/env bash + +# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-helm" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/helm.rev" + + got_rev="" + if [[ -x "${cachedir}/helm" ]]; then + got_rev=$("${cachedir}/helm" version --short) + got_rev="${got_rev#v}" # trim the leading 'v' + got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' + + + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_helm="$(which -a helm | grep /nix/store | head -1)" + nix_helm="${nix_helm%/helm}" + nix_helm_rev="${nix_helm##*-}" + if [[ "$nix_helm_rev" != "$want_rev" ]]; then + echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_helm" "$cachedir" + else + # works for linux and darwin + # https://github.com/helm/helm/releases + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="amd64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-helm:$PATH" +exec "$HOME/.cache/tailscale-helm/helm" "$@" diff --git a/tool/helm.rev b/tool/helm.rev index c10780c628ad5..0d0e48dd05bb9 100644 --- a/tool/helm.rev +++ b/tool/helm.rev @@ -1 +1 @@ -3.13.1 +3.13.1 diff --git a/tool/node b/tool/node index 310140ae5bfa0..7e96826f34988 100755 --- a/tool/node +++ b/tool/node @@ -1,65 +1,65 @@ -#!/usr/bin/env bash -# Run a command with our local node install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-node" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/node.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/node" ]]; then - got_rev=$("${cachedir}/bin/node" --version) - got_rev="${got_rev#v}" # trim the leading 'v' - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_node="$(which -a node | grep /nix/store | head -1)" - nix_node="${nix_node%/bin/node}" - nix_node_rev="${nix_node##*-}" - if [[ "$nix_node_rev" != "$want_rev" ]]; then - echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_node" "$cachedir" - else - # works for "linux" and "darwin" - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="x64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-node/bin:$PATH" -exec "$HOME/.cache/tailscale-node/bin/node" "$@" +#!/usr/bin/env bash +# Run a command with our local node install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-node" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/node.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/node" ]]; then + got_rev=$("${cachedir}/bin/node" --version) + got_rev="${got_rev#v}" # trim the leading 'v' + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_node="$(which -a node | grep /nix/store | head -1)" + nix_node="${nix_node%/bin/node}" + nix_node_rev="${nix_node##*-}" + if [[ "$nix_node_rev" != "$want_rev" ]]; then + echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_node" "$cachedir" + else + # works for "linux" and "darwin" + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="x64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-node/bin:$PATH" +exec "$HOME/.cache/tailscale-node/bin/node" "$@" diff --git a/tool/wasm-opt b/tool/wasm-opt index 08f3e5bfbb841..88d332f0b2ca4 100755 --- a/tool/wasm-opt +++ b/tool/wasm-opt @@ -1,74 +1,74 @@ -#!/bin/sh -# -# This script acts like the "wasm-opt" command from the Binaryen toolchain, but -# uses Tailscale's currently-desired version, downloading it first if necessary. - -set -eu - -BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" -read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" -# This works for Linux and Darwin, which is sufficient -# (we do not build for other targets). -OS=$(uname -s | tr A-Z a-z) -if [ "$OS" = "darwin" ]; then - # Binaryen uses the name "macos". - OS="macos" -fi -ARCH="$(uname -m)" -if [ "$ARCH" = "aarch64" ]; then - # Binaryen uses the name "arm64". - ARCH="arm64" -fi - -install_binaryen() { - BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" - install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL -} - -install_tool() { - TOOL=$1 - REV=$2 - TOOLCHAIN=$3 - URL=$4 - - archive="$TOOLCHAIN-$REV.tar.gz" - mark="$TOOLCHAIN.extracted" - extracted= - [ ! -e "$mark" ] || read -r extracted junk <$mark - - if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then - # Already extracted, continue silently - return 0 - fi - echo "" - - rm -f "$archive.new" "$TOOLCHAIN.extracted" - if [ ! -e "$archive" ]; then - log "Need to download $TOOL '$REV' from $URL." - curl -f -L -o "$archive.new" $URL - rm -f "$archive" - mv "$archive.new" "$archive" - fi - - log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 - rm -rf "$TOOLCHAIN" - mkdir -p "$TOOLCHAIN" - (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") - echo "$REV" >$mark -} - -log() { - echo "$@" >&2 -} - -if [ "${BINARYEN_DIR}" = "SKIP" ] || - [ "${OS}" != "macos" -a "${OS}" != "linux" ] || - [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then - log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." - log "Using existing wasm-opt (`which wasm-opt`)." - exec wasm-opt "$@" -fi - -install_binaryen - -"$BINARYEN_DIR/bin/wasm-opt" "$@" +#!/bin/sh +# +# This script acts like the "wasm-opt" command from the Binaryen toolchain, but +# uses Tailscale's currently-desired version, downloading it first if necessary. + +set -eu + +BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" +read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" +# This works for Linux and Darwin, which is sufficient +# (we do not build for other targets). +OS=$(uname -s | tr A-Z a-z) +if [ "$OS" = "darwin" ]; then + # Binaryen uses the name "macos". + OS="macos" +fi +ARCH="$(uname -m)" +if [ "$ARCH" = "aarch64" ]; then + # Binaryen uses the name "arm64". + ARCH="arm64" +fi + +install_binaryen() { + BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" + install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL +} + +install_tool() { + TOOL=$1 + REV=$2 + TOOLCHAIN=$3 + URL=$4 + + archive="$TOOLCHAIN-$REV.tar.gz" + mark="$TOOLCHAIN.extracted" + extracted= + [ ! -e "$mark" ] || read -r extracted junk <$mark + + if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then + # Already extracted, continue silently + return 0 + fi + echo "" + + rm -f "$archive.new" "$TOOLCHAIN.extracted" + if [ ! -e "$archive" ]; then + log "Need to download $TOOL '$REV' from $URL." + curl -f -L -o "$archive.new" $URL + rm -f "$archive" + mv "$archive.new" "$archive" + fi + + log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 + rm -rf "$TOOLCHAIN" + mkdir -p "$TOOLCHAIN" + (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") + echo "$REV" >$mark +} + +log() { + echo "$@" >&2 +} + +if [ "${BINARYEN_DIR}" = "SKIP" ] || + [ "${OS}" != "macos" -a "${OS}" != "linux" ] || + [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then + log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." + log "Using existing wasm-opt (`which wasm-opt`)." + exec wasm-opt "$@" +fi + +install_binaryen + +"$BINARYEN_DIR/bin/wasm-opt" "$@" diff --git a/tool/yarn b/tool/yarn index 6357beda61cb9..6bb01d2f223de 100755 --- a/tool/yarn +++ b/tool/yarn @@ -1,43 +1,43 @@ -#!/usr/bin/env bash -# Run a command with our local yarn install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - ./tool/node --version >/dev/null # Ensure node is unpacked and ready - - cachedir="$HOME/.cache/tailscale-yarn" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "./tool/yarn.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/yarn" ]]; then - got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - mkdir -p "$cachedir" - curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi -) - -# Deliberately not using cachedir here, to keep the environment -# completely pristine for execution of yarn. -export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" -exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" +#!/usr/bin/env bash +# Run a command with our local yarn install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + ./tool/node --version >/dev/null # Ensure node is unpacked and ready + + cachedir="$HOME/.cache/tailscale-yarn" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "./tool/yarn.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/yarn" ]]; then + got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + mkdir -p "$cachedir" + curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi +) + +# Deliberately not using cachedir here, to keep the environment +# completely pristine for execution of yarn. +export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" +exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" diff --git a/tool/yarn.rev b/tool/yarn.rev index de5856e86ba27..736c4acbded70 100644 --- a/tool/yarn.rev +++ b/tool/yarn.rev @@ -1 +1 @@ -1.22.19 +1.22.19 diff --git a/tsnet/example/tshello/tshello.go b/tsnet/example/tshello/tshello.go index 0cadcdd837d99..2110c4d9699d8 100644 --- a/tsnet/example/tshello/tshello.go +++ b/tsnet/example/tshello/tshello.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "crypto/tls" - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", ":80", "address to listen on") -) - -func main() { - flag.Parse() - s := new(tsnet.Server) - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - if *addr == ":443" { - ln = tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, - }) - } - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, world!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", ":80", "address to listen on") +) + +func main() { + flag.Parse() + s := new(tsnet.Server) + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + if *addr == ":443" { + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: lc.GetCertificate, + }) + } + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, world!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} diff --git a/tsnet/example/tsnet-http-client/tsnet-http-client.go b/tsnet/example/tsnet-http-client/tsnet-http-client.go index 9666fe9992745..cda52eef75ac1 100644 --- a/tsnet/example/tsnet-http-client/tsnet-http-client.go +++ b/tsnet/example/tsnet-http-client/tsnet-http-client.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "path/filepath" - - "tailscale.com/tsnet" -) - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) - os.Exit(2) - } - flag.Parse() - - if flag.NArg() != 1 { - flag.Usage() - } - tailnetURL := flag.Arg(0) - - s := new(tsnet.Server) - defer s.Close() - - if err := s.Start(); err != nil { - log.Fatal(err) - } - - cli := s.HTTPClient() - - resp, err := cli.Get(tailnetURL) - if err != nil { - log.Fatal(err) - } - - resp.Write(os.Stdout) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "path/filepath" + + "tailscale.com/tsnet" +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) + os.Exit(2) + } + flag.Parse() + + if flag.NArg() != 1 { + flag.Usage() + } + tailnetURL := flag.Arg(0) + + s := new(tsnet.Server) + defer s.Close() + + if err := s.Start(); err != nil { + log.Fatal(err) + } + + cli := s.HTTPClient() + + resp, err := cli.Get(tailnetURL) + if err != nil { + log.Fatal(err) + } + + resp.Write(os.Stdout) +} diff --git a/tsnet/example/web-client/web-client.go b/tsnet/example/web-client/web-client.go index 541efbaedf3d3..dee7fedfab2ba 100644 --- a/tsnet/example/web-client/web-client.go +++ b/tsnet/example/web-client/web-client.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The web-client command demonstrates serving the Tailscale web client over tsnet. -package main - -import ( - "flag" - "log" - "net/http" - - "tailscale.com/client/web" - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") -) - -func main() { - flag.Parse() - - s := &tsnet.Server{RunWebClient: true} - defer s.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - // Serve the Tailscale web client. - ws, err := web.NewServer(web.ServerOpts{ - Mode: web.LoginServerMode, - LocalClient: lc, - }) - if err != nil { - log.Fatal(err) - } - defer ws.Shutdown() - log.Printf("Serving Tailscale web client on http://%s", *addr) - if err := http.ListenAndServe(*addr, ws); err != nil { - if err != http.ErrServerClosed { - log.Fatal(err) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The web-client command demonstrates serving the Tailscale web client over tsnet. +package main + +import ( + "flag" + "log" + "net/http" + + "tailscale.com/client/web" + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") +) + +func main() { + flag.Parse() + + s := &tsnet.Server{RunWebClient: true} + defer s.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + // Serve the Tailscale web client. + ws, err := web.NewServer(web.ServerOpts{ + Mode: web.LoginServerMode, + LocalClient: lc, + }) + if err != nil { + log.Fatal(err) + } + defer ws.Shutdown() + log.Printf("Serving Tailscale web client on http://%s", *addr) + if err := http.ListenAndServe(*addr, ws); err != nil { + if err != http.ErrServerClosed { + log.Fatal(err) + } + } +} diff --git a/tsnet/example_tshello_test.go b/tsnet/example_tshello_test.go index d534bcfd1f1d4..4dec482339e2c 100644 --- a/tsnet/example_tshello_test.go +++ b/tsnet/example_tshello_test.go @@ -1,72 +1,72 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsnet_test - -import ( - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} - -// Example_tshello is a full example on using tsnet. When you run this program it will print -// an authentication link. Open it in your favorite web browser and add it to your tailnet -// like any other machine. Open another terminal window and try to ping it: -// -// $ ping tshello -c 2 -// PING tshello (100.105.183.159) 56(84) bytes of data. -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms -// -// Then connect to it using curl: -// -// $ curl http://tshello -//

Hello, world!

-//

You are Xe from pneuma (100.78.40.86:49214)

-// -// From here you can do anything you want with the Go standard library HTTP stack, or anything -// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). -func Example_tshello() { - var ( - addr = flag.String("addr", ":80", "address to listen on") - hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") - ) - - flag.Parse() - s := new(tsnet.Server) - s.Hostname = *hostname - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, tailnet!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsnet_test + +import ( + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} + +// Example_tshello is a full example on using tsnet. When you run this program it will print +// an authentication link. Open it in your favorite web browser and add it to your tailnet +// like any other machine. Open another terminal window and try to ping it: +// +// $ ping tshello -c 2 +// PING tshello (100.105.183.159) 56(84) bytes of data. +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms +// +// Then connect to it using curl: +// +// $ curl http://tshello +//

Hello, world!

+//

You are Xe from pneuma (100.78.40.86:49214)

+// +// From here you can do anything you want with the Go standard library HTTP stack, or anything +// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). +func Example_tshello() { + var ( + addr = flag.String("addr", ":80", "address to listen on") + hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") + ) + + flag.Parse() + s := new(tsnet.Server) + s.Hostname = *hostname + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, tailnet!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} diff --git a/tstest/allocs.go b/tstest/allocs.go index f15a00508d87f..a6d9c79f69ff7 100644 --- a/tstest/allocs.go +++ b/tstest/allocs.go @@ -1,50 +1,50 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "fmt" - "runtime" - "testing" - "time" -) - -// MinAllocsPerRun asserts that f can run with no more than target allocations. -// It runs f up to 1000 times or 5s, whichever happens first. -// If f has executed more than target allocations on every run, it returns a non-nil error. -// -// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores -// it before returning. -func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) - - var memstats runtime.MemStats - var min, max, sum uint64 - start := time.Now() - var iters int - for { - runtime.ReadMemStats(&memstats) - startMallocs := memstats.Mallocs - f() - runtime.ReadMemStats(&memstats) - mallocs := memstats.Mallocs - startMallocs - // TODO: if mallocs < target, return an error? See discussion in #3204. - if mallocs <= target { - return nil - } - if min == 0 || mallocs < min { - min = mallocs - } - if mallocs > max { - max = mallocs - } - sum += mallocs - iters++ - if iters == 1000 || time.Since(start) > 5*time.Second { - break - } - } - - return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "fmt" + "runtime" + "testing" + "time" +) + +// MinAllocsPerRun asserts that f can run with no more than target allocations. +// It runs f up to 1000 times or 5s, whichever happens first. +// If f has executed more than target allocations on every run, it returns a non-nil error. +// +// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores +// it before returning. +func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + + var memstats runtime.MemStats + var min, max, sum uint64 + start := time.Now() + var iters int + for { + runtime.ReadMemStats(&memstats) + startMallocs := memstats.Mallocs + f() + runtime.ReadMemStats(&memstats) + mallocs := memstats.Mallocs - startMallocs + // TODO: if mallocs < target, return an error? See discussion in #3204. + if mallocs <= target { + return nil + } + if min == 0 || mallocs < min { + min = mallocs + } + if mallocs > max { + max = mallocs + } + sum += mallocs + iters++ + if iters == 1000 || time.Since(start) > 5*time.Second { + break + } + } + + return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) +} diff --git a/tstest/archtest/qemu_test.go b/tstest/archtest/qemu_test.go index 8b59ae5d9fee1..cea3b4b8e9b53 100644 --- a/tstest/archtest/qemu_test.go +++ b/tstest/archtest/qemu_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && amd64 && !race - -package archtest - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "strings" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestInQemu(t *testing.T) { - t.Parallel() - type Arch struct { - Goarch string // GOARCH value - Qarch string // qemu name - } - arches := []Arch{ - {"arm", "arm"}, - {"arm64", "aarch64"}, - {"mips", "mips"}, - {"mipsle", "mipsel"}, - {"mips64", "mips64"}, - {"mips64le", "mips64el"}, - {"386", "386"}, - } - inCI := cibuild.On() - for _, arch := range arches { - arch := arch - t.Run(arch.Goarch, func(t *testing.T) { - t.Parallel() - qemuUser := "qemu-" + arch.Qarch - execVia := qemuUser - if arch.Goarch == "386" { - execVia = "" // amd64 can run it fine - } else { - look, err := exec.LookPath(qemuUser) - if err != nil { - if inCI { - t.Fatalf("in CI and qemu not available: %v", err) - } - t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) - } - t.Logf("using %v", look) - } - cmd := exec.Command("go", - "test", - "--exec="+execVia, - "-v", - "tailscale.com/tstest/archtest", - ) - cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) - out, err := cmd.CombinedOutput() - if err != nil { - if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { - t.Skip("skipping; qemu too old. use 5.x.") - } - t.Errorf("failed: %s", out) - } - sub := fmt.Sprintf("I am linux/%s", arch.Goarch) - if !bytes.Contains(out, []byte(sub)) { - t.Errorf("output didn't contain %q: %s", sub, out) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && amd64 && !race + +package archtest + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestInQemu(t *testing.T) { + t.Parallel() + type Arch struct { + Goarch string // GOARCH value + Qarch string // qemu name + } + arches := []Arch{ + {"arm", "arm"}, + {"arm64", "aarch64"}, + {"mips", "mips"}, + {"mipsle", "mipsel"}, + {"mips64", "mips64"}, + {"mips64le", "mips64el"}, + {"386", "386"}, + } + inCI := cibuild.On() + for _, arch := range arches { + arch := arch + t.Run(arch.Goarch, func(t *testing.T) { + t.Parallel() + qemuUser := "qemu-" + arch.Qarch + execVia := qemuUser + if arch.Goarch == "386" { + execVia = "" // amd64 can run it fine + } else { + look, err := exec.LookPath(qemuUser) + if err != nil { + if inCI { + t.Fatalf("in CI and qemu not available: %v", err) + } + t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) + } + t.Logf("using %v", look) + } + cmd := exec.Command("go", + "test", + "--exec="+execVia, + "-v", + "tailscale.com/tstest/archtest", + ) + cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) + out, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { + t.Skip("skipping; qemu too old. use 5.x.") + } + t.Errorf("failed: %s", out) + } + sub := fmt.Sprintf("I am linux/%s", arch.Goarch) + if !bytes.Contains(out, []byte(sub)) { + t.Errorf("output didn't contain %q: %s", sub, out) + } + }) + } +} diff --git a/tstest/clock.go b/tstest/clock.go index ee7523430ff54..48684957ec421 100644 --- a/tstest/clock.go +++ b/tstest/clock.go @@ -1,694 +1,694 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "container/heap" - "sync" - "time" - - "tailscale.com/tstime" - "tailscale.com/util/mak" -) - -// ClockOpts is used to configure the initial settings for a Clock. Once the -// settings are configured as desired, call NewClock to get the resulting Clock. -type ClockOpts struct { - // Start is the starting time for the Clock. When FollowRealTime is false, - // Start is also the value that will be returned by the first call - // to Clock.Now. - Start time.Time - // Step is the amount of time the Clock will advance whenever Clock.Now is - // called. If set to zero, the Clock will only advance when Clock.Advance is - // called and/or if FollowRealTime is true. - // - // FollowRealTime and Step cannot be enabled at the same time. - Step time.Duration - - // TimerChannelSize configures the maximum buffered ticks that are - // permitted in the channel of any Timer and Ticker created by this Clock. - // The special value 0 means to use the default of 1. The buffer may need to - // be increased if time is advanced by more than a single tick and proper - // functioning of the test requires that the ticks are not lost. - TimerChannelSize int - - // FollowRealTime makes the simulated time increment along with real time. - // It is a compromise between determinism and the difficulty of explicitly - // managing the simulated time via Step or Clock.Advance. When - // FollowRealTime is set, calls to Now() and PeekNow() will add the - // elapsed real-world time to the simulated time. - // - // FollowRealTime and Step cannot be enabled at the same time. - FollowRealTime bool -} - -// NewClock creates a Clock with the specified settings. To create a -// Clock with only the default settings, new(Clock) is equivalent, except that -// the start time will not be computed until one of the receivers is called. -func NewClock(co ClockOpts) *Clock { - if co.FollowRealTime && co.Step != 0 { - panic("only one of FollowRealTime and Step are allowed in NewClock") - } - - return newClockInternal(co, nil) -} - -// newClockInternal creates a Clock with the specified settings and allows -// specifying a non-standard realTimeClock. -func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { - if !co.FollowRealTime && rtClock != nil { - panic("rtClock can only be set with FollowRealTime enabled") - } - - if co.FollowRealTime && rtClock == nil { - rtClock = new(tstime.StdClock) - } - - c := &Clock{ - start: co.Start, - realTimeClock: rtClock, - step: co.Step, - timerChannelSize: co.TimerChannelSize, - } - c.init() // init now to capture the current time when co.Start.IsZero() - return c -} - -// Clock is a testing clock that advances every time its Now method is -// called, beginning at its start time. If no start time is specified using -// ClockBuilder, an arbitrary start time will be selected when the Clock is -// created and can be retrieved by calling Clock.Start(). -type Clock struct { - // start is the first value returned by Now. It must not be modified after - // init is called. - start time.Time - - // realTimeClock, if not nil, indicates that the Clock shall move forward - // according to realTimeClock + the accumulated calls to Advance. This can - // make writing tests easier that require some control over the clock but do - // not need exact control over the clock. While step can also be used for - // this purpose, it is harder to control how quickly time moves using step. - realTimeClock tstime.Clock - - initOnce sync.Once - mu sync.Mutex - - // step is how much to advance with each Now call. - step time.Duration - // present is the last value returned by Now (and will be returned again by - // PeekNow). - present time.Time - // realTime is the time from realTimeClock corresponding to the current - // value of present. - realTime time.Time - // skipStep indicates that the next call to Now should not add step to - // present. This occurs after initialization and after Advance. - skipStep bool - // timerChannelSize is the buffer size to use for channels created by - // NewTimer and NewTicker. - timerChannelSize int - - events eventManager -} - -func (c *Clock) init() { - c.initOnce.Do(func() { - if c.realTimeClock != nil { - c.realTime = c.realTimeClock.Now() - } - if c.start.IsZero() { - if c.realTime.IsZero() { - c.start = time.Now() - } else { - c.start = c.realTime - } - } - if c.timerChannelSize == 0 { - c.timerChannelSize = 1 - } - c.present = c.start - c.skipStep = true - c.events.AdvanceTo(c.present) - }) -} - -// Now returns the virtual clock's current time, and advances it -// according to its step configuration. -func (c *Clock) Now() time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - step := c.step - if c.skipStep { - step = 0 - c.skipStep = false - } - c.advanceLocked(rt, step) - - return c.present -} - -func (c *Clock) maybeGetRealTime() time.Time { - if c.realTimeClock == nil { - return time.Time{} - } - return c.realTimeClock.Now() -} - -func (c *Clock) advanceLocked(now time.Time, add time.Duration) { - if !now.IsZero() { - add += now.Sub(c.realTime) - c.realTime = now - } - if add == 0 { - return - } - c.present = c.present.Add(add) - c.events.AdvanceTo(c.present) -} - -// PeekNow returns the last time reported by Now. If Now has never been called, -// PeekNow returns the same value as GetStart. -func (c *Clock) PeekNow() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.present -} - -// Advance moves simulated time forward or backwards by a relative amount. Any -// Timer or Ticker that is waiting will fire at the requested point in simulated -// time. Advance returns the new simulated time. If this Clock follows real time -// then the next call to Now will equal the return value of Advance + the -// elapsed time since calling Advance. Otherwise, the next call to Now will -// equal the return value of Advance, regardless of the current step. -func (c *Clock) Advance(d time.Duration) time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - - c.advanceLocked(rt, d) - return c.present -} - -// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker -// that is waiting will fire at the requested point in simulated time. If this -// Clock follows real time then the next call to Now will equal t + the elapsed -// time since calling Advance. Otherwise, the next call to Now will equal t, -// regardless of the configured step. -func (c *Clock) AdvanceTo(t time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - c.realTime = rt - c.present = t - c.events.AdvanceTo(c.present) -} - -// GetStart returns the initial simulated time when this Clock was created. -func (c *Clock) GetStart() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.start -} - -// GetStep returns the amount that simulated time advances on every call to Now. -func (c *Clock) GetStep() time.Duration { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.step -} - -// SetStep updates the amount that simulated time advances on every call to Now. -func (c *Clock) SetStep(d time.Duration) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.step = d -} - -// SetTimerChannelSize changes the channel size for any Timer or Ticker created -// in the future. It does not affect those that were already created. -func (c *Clock) SetTimerChannelSize(n int) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.timerChannelSize = n -} - -// NewTicker returns a Ticker that uses this Clock for accessing the current -// time. -func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Ticker{ - nextTrigger: c.present.Add(d), - period: d, - em: &c.events, - } - t.init(c.timerChannelSize) - return t, t.C -} - -// NewTimer returns a Timer that uses this Clock for accessing the current -// time. -func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, nil) - return t, t.C -} - -// AfterFunc returns a Timer that calls f when it fires, using this Clock for -// accessing the current time. -func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, f) - return t -} - -// Since subtracts specified duration from Now(). -func (c *Clock) Since(t time.Time) time.Duration { - return c.Now().Sub(t) -} - -// eventHandler offers a common interface for Timer and Ticker events to avoid -// code duplication in eventManager. -type eventHandler interface { - // Fire signals the event. The provided time is written to the event's - // channel as the current time. The return value is the next time this event - // should fire, otherwise if it is zero then the event will be removed from - // the eventManager. - Fire(time.Time) time.Time -} - -// event tracks details about an upcoming Timer or Ticker firing. -type event struct { - position int // The current index in the heap, needed for heap.Fix and heap.Remove. - when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. - eh eventHandler -} - -// eventManager tracks pending events created by Timer and Ticker. eventManager -// implements heap.Interface for efficient lookups of the next event. -type eventManager struct { - // clock is a real time clock for scheduling events with. When clock is nil, - // events only fire when AdvanceTo is called by the simulated clock that - // this eventManager belongs to. When clock is not nil, events may fire when - // timer triggers. - clock tstime.Clock - - mu sync.Mutex - now time.Time - heap []*event - reverseLookup map[eventHandler]*event - - // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to - // the time represented by clock. In other words, if clock is real world - // time, then if an event is scheduled 1 second into the future in the - // simulated time, then the event will trigger after 1 second of actual test - // execution time (unless the test advances simulated time, in which case - // the timer is updated accordingly). This makes tests easier to write in - // situations where the simulated time only needs to be partially - // controlled, and the test writer wishes for simulated time to pass with an - // offset but still synchronized with the real world. - // - // In the future, this could be extended to allow simulated time to run at a - // multiple of real world time. - timer tstime.TimerController -} - -func (em *eventManager) handleTimer() { - rt := em.clock.Now() - em.AdvanceTo(rt) -} - -// Push implements heap.Interface.Push and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Push(x any) { - e, ok := x.(*event) - if !ok { - panic("incorrect event type") - } - if e == nil { - panic("nil event") - } - - mak.Set(&em.reverseLookup, e.eh, e) - e.position = len(em.heap) - em.heap = append(em.heap, e) -} - -// Pop implements heap.Interface.Pop and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Pop() any { - e := em.heap[len(em.heap)-1] - em.heap = em.heap[:len(em.heap)-1] - delete(em.reverseLookup, e.eh) - return e -} - -// Len implements sort.Interface.Len and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Len() int { - return len(em.heap) -} - -// Less implements sort.Interface.Less and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Less(i, j int) bool { - return em.heap[i].when.Before(em.heap[j].when) -} - -// Swap implements sort.Interface.Swap and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Swap(i, j int) { - em.heap[i], em.heap[j] = em.heap[j], em.heap[i] - em.heap[i].position = i - em.heap[j].position = j -} - -// Reschedule adds/updates/deletes an event in the heap, whichever -// operation is applicable (use a zero time to delete). -func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - e, ok := em.reverseLookup[eh] - if !ok { - if t.IsZero() { - // eh is not scheduled and also not active, so do nothing. - return - } - // eh is not scheduled but is active, so add it. - heap.Push(em, &event{ - when: t, - eh: eh, - }) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). - return - } - - if t.IsZero() { - // e is scheduled but not active, so remove it. - heap.Remove(em, e.position) - return - } - - // e is scheduled and active, so update it. - e.when = t - heap.Fix(em, e.position) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). -} - -// AdvanceTo updates the current time to tm and fires all events scheduled -// before or equal to tm. When an event fires, it may request rescheduling and -// the rescheduled events will be combined with the other existing events that -// are waiting, and will be run in the unified ordering. A poorly behaved event -// may theoretically prevent this from ever completing, but both Timer and -// Ticker require positive steps into the future. -func (em *eventManager) AdvanceTo(tm time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - em.processEventsLocked(tm) - em.now = tm -} - -// Now returns the cached current time. It is intended for use by a Timer or -// Ticker that needs to convert a relative time to an absolute time. -func (em *eventManager) Now() time.Time { - em.mu.Lock() - defer em.mu.Unlock() - return em.now -} - -func (em *eventManager) processEventsLocked(tm time.Time) { - for len(em.heap) > 0 && !em.heap[0].when.After(tm) { - // Ideally some jitter would be added here but it's difficult to do so - // in a deterministic fashion. - em.now = em.heap[0].when - - if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { - em.heap[0].when = nextFire - heap.Fix(em, 0) - } else { - heap.Pop(em) - } - } -} - -func (em *eventManager) updateTimerLocked() { - if em.clock == nil { - return - } - if len(em.heap) == 0 { - if em.timer != nil { - em.timer.Stop() - } - return - } - - timeToEvent := em.heap[0].when.Sub(em.now) - if em.timer == nil { - em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) - return - } - em.timer.Reset(timeToEvent) -} - -// Ticker is a time.Ticker lookalike for use in tests that need to control when -// events fire. Ticker could be made standalone in future but for now is -// expected to be paired with a Clock and created by Clock.NewTicker. -type Ticker struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - c chan<- time.Time // The writer side of C. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time - - // period is the duration that is added to nextTrigger when the ticker - // fires. - period time.Duration -} - -func (t *Ticker) init(channelSize int) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.c = c - t.C = c - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Ticker) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - select { - case t.c <- curTime: - default: - } - t.nextTrigger = t.nextTrigger.Add(t.period) - - return t.nextTrigger -} - -// Reset adjusts the Ticker's period to d and reschedules the next fire time to -// the current simulated time + d. -func (t *Ticker) Reset(d time.Duration) { - if d <= 0 { - // The standard time.Ticker requires a positive period. - panic("non-positive period for Ticker.Reset") - } - - now := t.em.Now() - - t.mu.Lock() - t.resetLocked(now.Add(d), d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire -// time to nextTrigger. -func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - if d <= 0 { - panic("non-positive period for ResetAbsolute") - } - - t.mu.Lock() - t.resetLocked(nextTrigger, d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { - t.nextTrigger = nextTrigger - t.period = d -} - -// Stop deactivates the Ticker. -func (t *Ticker) Stop() { - t.mu.Lock() - t.nextTrigger = time.Time{} - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// Timer is a time.Timer lookalike for use in tests that need to control when -// events fire. Timer could be made standalone in future but for now must be -// paired with a Clock and created by Clock.NewTimer. -type Timer struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - f func(time.Time) // The function to call when the timer expires. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time -} - -func (t *Timer) init(channelSize int, afterFunc func()) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.C = c - if afterFunc == nil { - t.f = func(curTime time.Time) { - select { - case c <- curTime: - default: - } - } - } else { - t.f = func(_ time.Time) { afterFunc() } - } - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Timer) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - t.nextTrigger = time.Time{} - t.f(curTime) - return time.Time{} -} - -// Reset reschedules the next fire time to the current simulated time + d. -// Reset reports whether the timer was still active before the reset. -func (t *Timer) Reset(d time.Duration) bool { - if d <= 0 { - // The standard time.Timer requires a positive delay. - panic("non-positive delay for Timer.Reset") - } - - return t.reset(t.em.Now().Add(d)) -} - -// ResetAbsolute reschedules the next fire time to nextTrigger. -// ResetAbsolute reports whether the timer was still active before the reset. -func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - - return t.reset(nextTrigger) -} - -// Stop deactivates the Timer. Stop reports whether the timer was active before -// stopping. -func (t *Timer) Stop() bool { - return t.reset(time.Time{}) -} - -func (t *Timer) reset(nextTrigger time.Time) bool { - t.mu.Lock() - wasActive := !t.nextTrigger.IsZero() - t.nextTrigger = nextTrigger - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) - return wasActive -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "container/heap" + "sync" + "time" + + "tailscale.com/tstime" + "tailscale.com/util/mak" +) + +// ClockOpts is used to configure the initial settings for a Clock. Once the +// settings are configured as desired, call NewClock to get the resulting Clock. +type ClockOpts struct { + // Start is the starting time for the Clock. When FollowRealTime is false, + // Start is also the value that will be returned by the first call + // to Clock.Now. + Start time.Time + // Step is the amount of time the Clock will advance whenever Clock.Now is + // called. If set to zero, the Clock will only advance when Clock.Advance is + // called and/or if FollowRealTime is true. + // + // FollowRealTime and Step cannot be enabled at the same time. + Step time.Duration + + // TimerChannelSize configures the maximum buffered ticks that are + // permitted in the channel of any Timer and Ticker created by this Clock. + // The special value 0 means to use the default of 1. The buffer may need to + // be increased if time is advanced by more than a single tick and proper + // functioning of the test requires that the ticks are not lost. + TimerChannelSize int + + // FollowRealTime makes the simulated time increment along with real time. + // It is a compromise between determinism and the difficulty of explicitly + // managing the simulated time via Step or Clock.Advance. When + // FollowRealTime is set, calls to Now() and PeekNow() will add the + // elapsed real-world time to the simulated time. + // + // FollowRealTime and Step cannot be enabled at the same time. + FollowRealTime bool +} + +// NewClock creates a Clock with the specified settings. To create a +// Clock with only the default settings, new(Clock) is equivalent, except that +// the start time will not be computed until one of the receivers is called. +func NewClock(co ClockOpts) *Clock { + if co.FollowRealTime && co.Step != 0 { + panic("only one of FollowRealTime and Step are allowed in NewClock") + } + + return newClockInternal(co, nil) +} + +// newClockInternal creates a Clock with the specified settings and allows +// specifying a non-standard realTimeClock. +func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { + if !co.FollowRealTime && rtClock != nil { + panic("rtClock can only be set with FollowRealTime enabled") + } + + if co.FollowRealTime && rtClock == nil { + rtClock = new(tstime.StdClock) + } + + c := &Clock{ + start: co.Start, + realTimeClock: rtClock, + step: co.Step, + timerChannelSize: co.TimerChannelSize, + } + c.init() // init now to capture the current time when co.Start.IsZero() + return c +} + +// Clock is a testing clock that advances every time its Now method is +// called, beginning at its start time. If no start time is specified using +// ClockBuilder, an arbitrary start time will be selected when the Clock is +// created and can be retrieved by calling Clock.Start(). +type Clock struct { + // start is the first value returned by Now. It must not be modified after + // init is called. + start time.Time + + // realTimeClock, if not nil, indicates that the Clock shall move forward + // according to realTimeClock + the accumulated calls to Advance. This can + // make writing tests easier that require some control over the clock but do + // not need exact control over the clock. While step can also be used for + // this purpose, it is harder to control how quickly time moves using step. + realTimeClock tstime.Clock + + initOnce sync.Once + mu sync.Mutex + + // step is how much to advance with each Now call. + step time.Duration + // present is the last value returned by Now (and will be returned again by + // PeekNow). + present time.Time + // realTime is the time from realTimeClock corresponding to the current + // value of present. + realTime time.Time + // skipStep indicates that the next call to Now should not add step to + // present. This occurs after initialization and after Advance. + skipStep bool + // timerChannelSize is the buffer size to use for channels created by + // NewTimer and NewTicker. + timerChannelSize int + + events eventManager +} + +func (c *Clock) init() { + c.initOnce.Do(func() { + if c.realTimeClock != nil { + c.realTime = c.realTimeClock.Now() + } + if c.start.IsZero() { + if c.realTime.IsZero() { + c.start = time.Now() + } else { + c.start = c.realTime + } + } + if c.timerChannelSize == 0 { + c.timerChannelSize = 1 + } + c.present = c.start + c.skipStep = true + c.events.AdvanceTo(c.present) + }) +} + +// Now returns the virtual clock's current time, and advances it +// according to its step configuration. +func (c *Clock) Now() time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + step := c.step + if c.skipStep { + step = 0 + c.skipStep = false + } + c.advanceLocked(rt, step) + + return c.present +} + +func (c *Clock) maybeGetRealTime() time.Time { + if c.realTimeClock == nil { + return time.Time{} + } + return c.realTimeClock.Now() +} + +func (c *Clock) advanceLocked(now time.Time, add time.Duration) { + if !now.IsZero() { + add += now.Sub(c.realTime) + c.realTime = now + } + if add == 0 { + return + } + c.present = c.present.Add(add) + c.events.AdvanceTo(c.present) +} + +// PeekNow returns the last time reported by Now. If Now has never been called, +// PeekNow returns the same value as GetStart. +func (c *Clock) PeekNow() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.present +} + +// Advance moves simulated time forward or backwards by a relative amount. Any +// Timer or Ticker that is waiting will fire at the requested point in simulated +// time. Advance returns the new simulated time. If this Clock follows real time +// then the next call to Now will equal the return value of Advance + the +// elapsed time since calling Advance. Otherwise, the next call to Now will +// equal the return value of Advance, regardless of the current step. +func (c *Clock) Advance(d time.Duration) time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + + c.advanceLocked(rt, d) + return c.present +} + +// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker +// that is waiting will fire at the requested point in simulated time. If this +// Clock follows real time then the next call to Now will equal t + the elapsed +// time since calling Advance. Otherwise, the next call to Now will equal t, +// regardless of the configured step. +func (c *Clock) AdvanceTo(t time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + c.realTime = rt + c.present = t + c.events.AdvanceTo(c.present) +} + +// GetStart returns the initial simulated time when this Clock was created. +func (c *Clock) GetStart() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.start +} + +// GetStep returns the amount that simulated time advances on every call to Now. +func (c *Clock) GetStep() time.Duration { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.step +} + +// SetStep updates the amount that simulated time advances on every call to Now. +func (c *Clock) SetStep(d time.Duration) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.step = d +} + +// SetTimerChannelSize changes the channel size for any Timer or Ticker created +// in the future. It does not affect those that were already created. +func (c *Clock) SetTimerChannelSize(n int) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.timerChannelSize = n +} + +// NewTicker returns a Ticker that uses this Clock for accessing the current +// time. +func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Ticker{ + nextTrigger: c.present.Add(d), + period: d, + em: &c.events, + } + t.init(c.timerChannelSize) + return t, t.C +} + +// NewTimer returns a Timer that uses this Clock for accessing the current +// time. +func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, nil) + return t, t.C +} + +// AfterFunc returns a Timer that calls f when it fires, using this Clock for +// accessing the current time. +func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, f) + return t +} + +// Since subtracts specified duration from Now(). +func (c *Clock) Since(t time.Time) time.Duration { + return c.Now().Sub(t) +} + +// eventHandler offers a common interface for Timer and Ticker events to avoid +// code duplication in eventManager. +type eventHandler interface { + // Fire signals the event. The provided time is written to the event's + // channel as the current time. The return value is the next time this event + // should fire, otherwise if it is zero then the event will be removed from + // the eventManager. + Fire(time.Time) time.Time +} + +// event tracks details about an upcoming Timer or Ticker firing. +type event struct { + position int // The current index in the heap, needed for heap.Fix and heap.Remove. + when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. + eh eventHandler +} + +// eventManager tracks pending events created by Timer and Ticker. eventManager +// implements heap.Interface for efficient lookups of the next event. +type eventManager struct { + // clock is a real time clock for scheduling events with. When clock is nil, + // events only fire when AdvanceTo is called by the simulated clock that + // this eventManager belongs to. When clock is not nil, events may fire when + // timer triggers. + clock tstime.Clock + + mu sync.Mutex + now time.Time + heap []*event + reverseLookup map[eventHandler]*event + + // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to + // the time represented by clock. In other words, if clock is real world + // time, then if an event is scheduled 1 second into the future in the + // simulated time, then the event will trigger after 1 second of actual test + // execution time (unless the test advances simulated time, in which case + // the timer is updated accordingly). This makes tests easier to write in + // situations where the simulated time only needs to be partially + // controlled, and the test writer wishes for simulated time to pass with an + // offset but still synchronized with the real world. + // + // In the future, this could be extended to allow simulated time to run at a + // multiple of real world time. + timer tstime.TimerController +} + +func (em *eventManager) handleTimer() { + rt := em.clock.Now() + em.AdvanceTo(rt) +} + +// Push implements heap.Interface.Push and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Push(x any) { + e, ok := x.(*event) + if !ok { + panic("incorrect event type") + } + if e == nil { + panic("nil event") + } + + mak.Set(&em.reverseLookup, e.eh, e) + e.position = len(em.heap) + em.heap = append(em.heap, e) +} + +// Pop implements heap.Interface.Pop and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Pop() any { + e := em.heap[len(em.heap)-1] + em.heap = em.heap[:len(em.heap)-1] + delete(em.reverseLookup, e.eh) + return e +} + +// Len implements sort.Interface.Len and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Len() int { + return len(em.heap) +} + +// Less implements sort.Interface.Less and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Less(i, j int) bool { + return em.heap[i].when.Before(em.heap[j].when) +} + +// Swap implements sort.Interface.Swap and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Swap(i, j int) { + em.heap[i], em.heap[j] = em.heap[j], em.heap[i] + em.heap[i].position = i + em.heap[j].position = j +} + +// Reschedule adds/updates/deletes an event in the heap, whichever +// operation is applicable (use a zero time to delete). +func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + e, ok := em.reverseLookup[eh] + if !ok { + if t.IsZero() { + // eh is not scheduled and also not active, so do nothing. + return + } + // eh is not scheduled but is active, so add it. + heap.Push(em, &event{ + when: t, + eh: eh, + }) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). + return + } + + if t.IsZero() { + // e is scheduled but not active, so remove it. + heap.Remove(em, e.position) + return + } + + // e is scheduled and active, so update it. + e.when = t + heap.Fix(em, e.position) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). +} + +// AdvanceTo updates the current time to tm and fires all events scheduled +// before or equal to tm. When an event fires, it may request rescheduling and +// the rescheduled events will be combined with the other existing events that +// are waiting, and will be run in the unified ordering. A poorly behaved event +// may theoretically prevent this from ever completing, but both Timer and +// Ticker require positive steps into the future. +func (em *eventManager) AdvanceTo(tm time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + em.processEventsLocked(tm) + em.now = tm +} + +// Now returns the cached current time. It is intended for use by a Timer or +// Ticker that needs to convert a relative time to an absolute time. +func (em *eventManager) Now() time.Time { + em.mu.Lock() + defer em.mu.Unlock() + return em.now +} + +func (em *eventManager) processEventsLocked(tm time.Time) { + for len(em.heap) > 0 && !em.heap[0].when.After(tm) { + // Ideally some jitter would be added here but it's difficult to do so + // in a deterministic fashion. + em.now = em.heap[0].when + + if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { + em.heap[0].when = nextFire + heap.Fix(em, 0) + } else { + heap.Pop(em) + } + } +} + +func (em *eventManager) updateTimerLocked() { + if em.clock == nil { + return + } + if len(em.heap) == 0 { + if em.timer != nil { + em.timer.Stop() + } + return + } + + timeToEvent := em.heap[0].when.Sub(em.now) + if em.timer == nil { + em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) + return + } + em.timer.Reset(timeToEvent) +} + +// Ticker is a time.Ticker lookalike for use in tests that need to control when +// events fire. Ticker could be made standalone in future but for now is +// expected to be paired with a Clock and created by Clock.NewTicker. +type Ticker struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + c chan<- time.Time // The writer side of C. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time + + // period is the duration that is added to nextTrigger when the ticker + // fires. + period time.Duration +} + +func (t *Ticker) init(channelSize int) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.c = c + t.C = c + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Ticker) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + select { + case t.c <- curTime: + default: + } + t.nextTrigger = t.nextTrigger.Add(t.period) + + return t.nextTrigger +} + +// Reset adjusts the Ticker's period to d and reschedules the next fire time to +// the current simulated time + d. +func (t *Ticker) Reset(d time.Duration) { + if d <= 0 { + // The standard time.Ticker requires a positive period. + panic("non-positive period for Ticker.Reset") + } + + now := t.em.Now() + + t.mu.Lock() + t.resetLocked(now.Add(d), d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire +// time to nextTrigger. +func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + if d <= 0 { + panic("non-positive period for ResetAbsolute") + } + + t.mu.Lock() + t.resetLocked(nextTrigger, d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { + t.nextTrigger = nextTrigger + t.period = d +} + +// Stop deactivates the Ticker. +func (t *Ticker) Stop() { + t.mu.Lock() + t.nextTrigger = time.Time{} + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// Timer is a time.Timer lookalike for use in tests that need to control when +// events fire. Timer could be made standalone in future but for now must be +// paired with a Clock and created by Clock.NewTimer. +type Timer struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + f func(time.Time) // The function to call when the timer expires. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time +} + +func (t *Timer) init(channelSize int, afterFunc func()) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.C = c + if afterFunc == nil { + t.f = func(curTime time.Time) { + select { + case c <- curTime: + default: + } + } + } else { + t.f = func(_ time.Time) { afterFunc() } + } + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Timer) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + t.nextTrigger = time.Time{} + t.f(curTime) + return time.Time{} +} + +// Reset reschedules the next fire time to the current simulated time + d. +// Reset reports whether the timer was still active before the reset. +func (t *Timer) Reset(d time.Duration) bool { + if d <= 0 { + // The standard time.Timer requires a positive delay. + panic("non-positive delay for Timer.Reset") + } + + return t.reset(t.em.Now().Add(d)) +} + +// ResetAbsolute reschedules the next fire time to nextTrigger. +// ResetAbsolute reports whether the timer was still active before the reset. +func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + + return t.reset(nextTrigger) +} + +// Stop deactivates the Timer. Stop reports whether the timer was active before +// stopping. +func (t *Timer) Stop() bool { + return t.reset(time.Time{}) +} + +func (t *Timer) reset(nextTrigger time.Time) bool { + t.mu.Lock() + wasActive := !t.nextTrigger.IsZero() + t.nextTrigger = nextTrigger + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) + return wasActive +} diff --git a/tstest/deptest/deptest_test.go b/tstest/deptest/deptest_test.go index ebafa56849efb..3b7b2dde91dec 100644 --- a/tstest/deptest/deptest_test.go +++ b/tstest/deptest/deptest_test.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deptest - -import "testing" - -func TestImports(t *testing.T) { - ImportAliasCheck(t, "../../") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deptest + +import "testing" + +func TestImports(t *testing.T) { + ImportAliasCheck(t, "../../") +} diff --git a/tstest/integration/gen_deps.go b/tstest/integration/gen_deps.go index 23bb95ee56a9f..ab5cc0448b54d 100644 --- a/tstest/integration/gen_deps.go +++ b/tstest/integration/gen_deps.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "log" - "os" - "os/exec" - "strings" -) - -func main() { - for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { - generate(goos) - } -} - -func generate(goos string) { - var x struct { - Imports []string - } - cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") - cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") - j, err := cmd.Output() - if err != nil { - log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) - } - if err := json.Unmarshal(j, &x); err != nil { - log.Fatal(err) - } - var out bytes.Buffer - out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by gen_deps.go; DO NOT EDIT. - -package integration - -import ( - // And depend on a bunch of tailscaled innards, for Go's test caching. - // Otherwise cmd/go never sees that we depend on these packages' - // transitive deps when we run "go install tailscaled" in a child - // process and can cache a prior success when a dependency changes. -`) - for _, dep := range x.Imports { - if !strings.Contains(dep, ".") { - // Omit standard library deps. - continue - } - fmt.Fprintf(&out, "\t_ %q\n", dep) - } - fmt.Fprintf(&out, ")\n") - - filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) - err = os.WriteFile(filename, out.Bytes(), 0644) - if err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "strings" +) + +func main() { + for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { + generate(goos) + } +} + +func generate(goos string) { + var x struct { + Imports []string + } + cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") + cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") + j, err := cmd.Output() + if err != nil { + log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) + } + if err := json.Unmarshal(j, &x); err != nil { + log.Fatal(err) + } + var out bytes.Buffer + out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen_deps.go; DO NOT EDIT. + +package integration + +import ( + // And depend on a bunch of tailscaled innards, for Go's test caching. + // Otherwise cmd/go never sees that we depend on these packages' + // transitive deps when we run "go install tailscaled" in a child + // process and can cache a prior success when a dependency changes. +`) + for _, dep := range x.Imports { + if !strings.Contains(dep, ".") { + // Omit standard library deps. + continue + } + fmt.Fprintf(&out, "\t_ %q\n", dep) + } + fmt.Fprintf(&out, ")\n") + + filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) + err = os.WriteFile(filename, out.Bytes(), 0644) + if err != nil { + log.Fatal(err) + } +} diff --git a/tstest/integration/vms/README.md b/tstest/integration/vms/README.md index 519c3d000fb63..766d8e5741e6d 100644 --- a/tstest/integration/vms/README.md +++ b/tstest/integration/vms/README.md @@ -1,95 +1,95 @@ -# End-to-End VM-based Integration Testing - -This test spins up a bunch of common linux distributions and then tries to get -them to connect to a -[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) -server. - -## Running - -This test currently only runs on Linux. - -This test depends on the following command line tools: - -- [qemu](https://www.qemu.org/) -- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) -- [openssh](https://www.openssh.com/) - -This test also requires the following: - -- about 10 GB of temporary storage -- about 10 GB of cached VM images -- at least 4 GB of ram for virtual machines -- hardware virtualization support - ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS -- the `kvm` module to be loaded (`modprobe kvm`) -- the user running these tests must have access to `/dev/kvm` (being in the - `kvm` group should suffice) - -The `--no-s3` flag is needed to disable downloads from S3, which require -credentials. However keep in mind that some distributions do not use stable URLs -for each individual image artifact, so there may be spurious test failures as a -result. - -If you are using [Nix](https://nixos.org), you can run all of the tests with the -correct command line tools using this command: - -```console -$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" -``` - -Keep the timeout high for the first run, especially if you are not downloading -VM images from S3. The mirrors we pull images from have download rate limits and -will take a while to download. - -Because of the hardware requirements of this test, this test will not run -without the `--run-vm-tests` flag set. - -## Other Fun Flags - -This test's behavior is customized with command line flags. - -### Don't Download Images From S3 - -If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor -of downloading the images directly from upstream sources, which may cause the -test to fail in odd places. - -### Distribution Picking - -This test runs on a large number of distributions. By default it tries to run -everything, which may or may not be ideal for you. If you only want to test a -subset of distributions, you can use the `--distro-regex` flag to match a subset -of distributions using a [regular expression](https://golang.org/pkg/regexp/) -such as like this: - -```console -$ go test -run-vm-tests -distro-regex centos -``` - -This would run all tests on all versions of CentOS. - -```console -$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' -``` - -This would run all tests on all versions of Debian and Ubuntu. - -### Ram Limiting - -This test uses a lot of memory. In order to avoid making machines run out of -memory running this test, a semaphore is used to limit how many megabytes of ram -are being used at once. By default this semaphore is set to 4096 MB of ram -(about 4 gigabytes). You can customize this with the `--ram-limit` flag: - -```console -$ go test --run-vm-tests --ram-limit 2048 -$ go test --run-vm-tests --ram-limit 65536 -``` - -The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The -second example will set the limit to 65536 MB of ram (about 65 gigabytes). -Please be careful with this flag, improper usage of it is known to cause the -Linux out-of-memory killer to engage. Try to keep it within 50-75% of your -machine's available ram (there is some overhead involved with the -virtualization) to be on the safe side. +# End-to-End VM-based Integration Testing + +This test spins up a bunch of common linux distributions and then tries to get +them to connect to a +[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) +server. + +## Running + +This test currently only runs on Linux. + +This test depends on the following command line tools: + +- [qemu](https://www.qemu.org/) +- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) +- [openssh](https://www.openssh.com/) + +This test also requires the following: + +- about 10 GB of temporary storage +- about 10 GB of cached VM images +- at least 4 GB of ram for virtual machines +- hardware virtualization support + ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS +- the `kvm` module to be loaded (`modprobe kvm`) +- the user running these tests must have access to `/dev/kvm` (being in the + `kvm` group should suffice) + +The `--no-s3` flag is needed to disable downloads from S3, which require +credentials. However keep in mind that some distributions do not use stable URLs +for each individual image artifact, so there may be spurious test failures as a +result. + +If you are using [Nix](https://nixos.org), you can run all of the tests with the +correct command line tools using this command: + +```console +$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" +``` + +Keep the timeout high for the first run, especially if you are not downloading +VM images from S3. The mirrors we pull images from have download rate limits and +will take a while to download. + +Because of the hardware requirements of this test, this test will not run +without the `--run-vm-tests` flag set. + +## Other Fun Flags + +This test's behavior is customized with command line flags. + +### Don't Download Images From S3 + +If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor +of downloading the images directly from upstream sources, which may cause the +test to fail in odd places. + +### Distribution Picking + +This test runs on a large number of distributions. By default it tries to run +everything, which may or may not be ideal for you. If you only want to test a +subset of distributions, you can use the `--distro-regex` flag to match a subset +of distributions using a [regular expression](https://golang.org/pkg/regexp/) +such as like this: + +```console +$ go test -run-vm-tests -distro-regex centos +``` + +This would run all tests on all versions of CentOS. + +```console +$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' +``` + +This would run all tests on all versions of Debian and Ubuntu. + +### Ram Limiting + +This test uses a lot of memory. In order to avoid making machines run out of +memory running this test, a semaphore is used to limit how many megabytes of ram +are being used at once. By default this semaphore is set to 4096 MB of ram +(about 4 gigabytes). You can customize this with the `--ram-limit` flag: + +```console +$ go test --run-vm-tests --ram-limit 2048 +$ go test --run-vm-tests --ram-limit 65536 +``` + +The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The +second example will set the limit to 65536 MB of ram (about 65 gigabytes). +Please be careful with this flag, improper usage of it is known to cause the +Linux out-of-memory killer to engage. Try to keep it within 50-75% of your +machine's available ram (there is some overhead involved with the +virtualization) to be on the safe side. diff --git a/tstest/integration/vms/distros.hujson b/tstest/integration/vms/distros.hujson index 049091ed50e6e..5634d6d678562 100644 --- a/tstest/integration/vms/distros.hujson +++ b/tstest/integration/vms/distros.hujson @@ -1,39 +1,39 @@ -// NOTE(Xe): If you run into issues getting the autoconfig to work, run -// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC -// client with a command like this: -// -// $ vncviewer :0 -// -// On NixOS you can get away with something like this: -// -// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' -// -// Login as root with the password root. Then look in -// /var/log/cloud-init-output.log for what you messed up. -[ - { - "Name": "ubuntu-18-04", - "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", - "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "ubuntu-20-04", - "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", - "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "nixos-21-11", - "URL": "channel:nixos-21.11", - "SHA256Sum": "lolfakesha", - "MemoryMegs": 512, - "PackageManager": "nix", - "InitSystem": "systemd", - "HostGenerated": true - }, -] +// NOTE(Xe): If you run into issues getting the autoconfig to work, run +// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC +// client with a command like this: +// +// $ vncviewer :0 +// +// On NixOS you can get away with something like this: +// +// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' +// +// Login as root with the password root. Then look in +// /var/log/cloud-init-output.log for what you messed up. +[ + { + "Name": "ubuntu-18-04", + "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", + "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "ubuntu-20-04", + "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", + "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "nixos-21-11", + "URL": "channel:nixos-21.11", + "SHA256Sum": "lolfakesha", + "MemoryMegs": 512, + "PackageManager": "nix", + "InitSystem": "systemd", + "HostGenerated": true + }, +] diff --git a/tstest/integration/vms/distros_test.go b/tstest/integration/vms/distros_test.go index 462aa2a6bc825..db3bae793b367 100644 --- a/tstest/integration/vms/distros_test.go +++ b/tstest/integration/vms/distros_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "testing" -) - -func TestDistrosGotLoaded(t *testing.T) { - if len(Distros) == 0 { - t.Fatal("no distros were loaded") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "testing" +) + +func TestDistrosGotLoaded(t *testing.T) { + if len(Distros) == 0 { + t.Fatal("no distros were loaded") + } +} diff --git a/tstest/integration/vms/dns_tester.go b/tstest/integration/vms/dns_tester.go index 50b39bb5f1fa1..be7d7ee6d69c8 100644 --- a/tstest/integration/vms/dns_tester.go +++ b/tstest/integration/vms/dns_tester.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command dns_tester exists in order to perform tests of our DNS -// configuration stack. This was written because the state of DNS -// in our target environments is so diverse that we need a little tool -// to do this test for us. -package main - -import ( - "context" - "encoding/json" - "flag" - "net" - "os" - "time" -) - -func main() { - flag.Parse() - target := flag.Arg(0) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCount := 0 - wait := 25 * time.Millisecond - for range make([]struct{}, 5) { - err := lookup(ctx, target) - if err != nil { - errCount++ - time.Sleep(wait) - wait = wait * 2 - continue - } - - break - } -} - -func lookup(ctx context.Context, target string) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - hosts, err := net.LookupHost(target) - if err != nil { - return err - } - - json.NewEncoder(os.Stdout).Encode(hosts) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command dns_tester exists in order to perform tests of our DNS +// configuration stack. This was written because the state of DNS +// in our target environments is so diverse that we need a little tool +// to do this test for us. +package main + +import ( + "context" + "encoding/json" + "flag" + "net" + "os" + "time" +) + +func main() { + flag.Parse() + target := flag.Arg(0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCount := 0 + wait := 25 * time.Millisecond + for range make([]struct{}, 5) { + err := lookup(ctx, target) + if err != nil { + errCount++ + time.Sleep(wait) + wait = wait * 2 + continue + } + + break + } +} + +func lookup(ctx context.Context, target string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + hosts, err := net.LookupHost(target) + if err != nil { + return err + } + + json.NewEncoder(os.Stdout).Encode(hosts) + return nil +} diff --git a/tstest/integration/vms/doc.go b/tstest/integration/vms/doc.go index 6093b53ac8ed5..3008493ea1a33 100644 --- a/tstest/integration/vms/doc.go +++ b/tstest/integration/vms/doc.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package vms does VM-based integration/functional tests by using -// qemu and a bank of pre-made VM images. -package vms +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vms does VM-based integration/functional tests by using +// qemu and a bank of pre-made VM images. +package vms diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 1e080414d72e7..620276ac26491 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "bytes" - "context" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "os/exec" - "path" - "path/filepath" - "strconv" - "sync" - "testing" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/net/proxy" - "tailscale.com/tailcfg" - "tailscale.com/tstest/integration" - "tailscale.com/tstest/integration/testcontrol" - "tailscale.com/types/dnstype" -) - -type Harness struct { - testerDialer proxy.Dialer - testerDir string - binaryDir string - cli string - daemon string - pubKey string - signer ssh.Signer - cs *testcontrol.Server - loginServerURL string - testerV4 netip.Addr - ipMu *sync.Mutex - ipMap map[string]ipMapping -} - -func newHarness(t *testing.T) *Harness { - dir := t.TempDir() - bindHost := deriveBindhost(t) - ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) - if err != nil { - t.Fatalf("can't make TCP listener: %v", err) - } - t.Cleanup(func() { - ln.Close() - }) - t.Logf("host:port: %s", ln.Addr()) - - cs := &testcontrol.Server{ - DNSConfig: &tailcfg.DNSConfig{ - // TODO: this is wrong. - // It is also only one of many configurations. - // Figure out how to scale it up. - Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, - Domains: []string{"record"}, - Proxied: true, - ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, - }, - } - - derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) - cs.DERPMap = derpMap - - var ( - ipMu sync.Mutex - ipMap = map[string]ipMapping{} - ) - - mux := http.NewServeMux() - mux.Handle("/", cs) - - lc := &integration.LogCatcher{} - if *verboseLogcatcher { - lc.UseLogf(t.Logf) - t.Cleanup(func() { - lc.UseLogf(nil) // do not log after test is complete - }) - } - mux.Handle("/c/", lc) - - // This handler will let the virtual machines tell the host information about that VM. - // This is used to maintain a list of port->IP address mappings that are known to be - // working. This allows later steps to connect over SSH. This returns no response to - // clients because no response is needed. - mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { - ipMu.Lock() - defer ipMu.Unlock() - - name := path.Base(r.URL.Path) - host, _, _ := net.SplitHostPort(r.RemoteAddr) - port, err := strconv.Atoi(name) - if err != nil { - log.Panicf("bad port: %v", port) - } - distro := r.UserAgent() - ipMap[distro] = ipMapping{distro, port, host} - t.Logf("%s: %v", name, host) - }) - - hs := &http.Server{Handler: mux} - go hs.Serve(ln) - - cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") - cmd.Dir = dir - if out, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("ssh-keygen: %v, %s", err, out) - } - pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) - if err != nil { - t.Fatalf("can't read ssh key: %v", err) - } - - privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) - if err != nil { - t.Fatalf("can't read ssh private key: %v", err) - } - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - t.Fatalf("can't parse private key: %v", err) - } - - loginServer := fmt.Sprintf("http://%s", ln.Addr()) - t.Logf("loginServer: %s", loginServer) - - h := &Harness{ - pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), - daemon: integration.TailscaledBinary(t), - signer: signer, - loginServerURL: loginServer, - cs: cs, - ipMu: &ipMu, - ipMap: ipMap, - } - - h.makeTestNode(t, loginServer) - - return h -} - -func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { - t.Helper() - - args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) - - cmd := exec.Command(h.cli, args...) - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatal(err) - } - - return out -} - -// makeTestNode creates a userspace tailscaled running in netstack mode that -// enables us to make connections to and from the tailscale network being -// tested. This mutates the Harness to allow tests to dial into the tailscale -// network as well as control the tester's tailscaled. -func (h *Harness) makeTestNode(t *testing.T, controlURL string) { - dir := t.TempDir() - h.testerDir = dir - - port, err := getProbablyFreePortNumber() - if err != nil { - t.Fatalf("can't get free port: %v", err) - } - - cmd := exec.Command( - h.daemon, - "--tun=userspace-networking", - "--state="+filepath.Join(dir, "state.json"), - "--socket="+filepath.Join(dir, "sock"), - fmt.Sprintf("--socks5-server=localhost:%d", port), - ) - - cmd.Env = append( - os.Environ(), - "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), - "TS_LOG_TARGET="+h.loginServerURL, - ) - - err = cmd.Start() - if err != nil { - t.Fatalf("can't start tailscaled: %v", err) - } - - t.Cleanup(func() { - cmd.Process.Kill() - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - ticker := time.NewTicker(100 * time.Millisecond) - -outer: - for { - select { - case <-ctx.Done(): - t.Fatal("timed out waiting for tailscaled to come up") - return - case <-ticker.C: - conn, err := net.Dial("unix", filepath.Join(dir, "sock")) - if err != nil { - continue - } - - conn.Close() - break outer - } - } - - run(t, dir, h.cli, - "--socket="+filepath.Join(dir, "sock"), - "up", - "--login-server="+controlURL, - "--hostname=tester", - ) - - dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) - if err != nil { - t.Fatalf("can't make netstack proxy dialer: %v", err) - } - h.testerDialer = dialer - h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) -} - -func bytes2Netaddr(inp []byte) netip.Addr { - return netip.MustParseAddr(string(bytes.TrimSpace(inp))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "bytes" + "context" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "os/exec" + "path" + "path/filepath" + "strconv" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/proxy" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/dnstype" +) + +type Harness struct { + testerDialer proxy.Dialer + testerDir string + binaryDir string + cli string + daemon string + pubKey string + signer ssh.Signer + cs *testcontrol.Server + loginServerURL string + testerV4 netip.Addr + ipMu *sync.Mutex + ipMap map[string]ipMapping +} + +func newHarness(t *testing.T) *Harness { + dir := t.TempDir() + bindHost := deriveBindhost(t) + ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) + if err != nil { + t.Fatalf("can't make TCP listener: %v", err) + } + t.Cleanup(func() { + ln.Close() + }) + t.Logf("host:port: %s", ln.Addr()) + + cs := &testcontrol.Server{ + DNSConfig: &tailcfg.DNSConfig{ + // TODO: this is wrong. + // It is also only one of many configurations. + // Figure out how to scale it up. + Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, + Domains: []string{"record"}, + Proxied: true, + ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, + }, + } + + derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) + cs.DERPMap = derpMap + + var ( + ipMu sync.Mutex + ipMap = map[string]ipMapping{} + ) + + mux := http.NewServeMux() + mux.Handle("/", cs) + + lc := &integration.LogCatcher{} + if *verboseLogcatcher { + lc.UseLogf(t.Logf) + t.Cleanup(func() { + lc.UseLogf(nil) // do not log after test is complete + }) + } + mux.Handle("/c/", lc) + + // This handler will let the virtual machines tell the host information about that VM. + // This is used to maintain a list of port->IP address mappings that are known to be + // working. This allows later steps to connect over SSH. This returns no response to + // clients because no response is needed. + mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { + ipMu.Lock() + defer ipMu.Unlock() + + name := path.Base(r.URL.Path) + host, _, _ := net.SplitHostPort(r.RemoteAddr) + port, err := strconv.Atoi(name) + if err != nil { + log.Panicf("bad port: %v", port) + } + distro := r.UserAgent() + ipMap[distro] = ipMapping{distro, port, host} + t.Logf("%s: %v", name, host) + }) + + hs := &http.Server{Handler: mux} + go hs.Serve(ln) + + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("ssh-keygen: %v, %s", err, out) + } + pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) + if err != nil { + t.Fatalf("can't read ssh key: %v", err) + } + + privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) + if err != nil { + t.Fatalf("can't read ssh private key: %v", err) + } + + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + t.Fatalf("can't parse private key: %v", err) + } + + loginServer := fmt.Sprintf("http://%s", ln.Addr()) + t.Logf("loginServer: %s", loginServer) + + h := &Harness{ + pubKey: string(pubkey), + binaryDir: integration.BinaryDir(t), + cli: integration.TailscaleBinary(t), + daemon: integration.TailscaledBinary(t), + signer: signer, + loginServerURL: loginServer, + cs: cs, + ipMu: &ipMu, + ipMap: ipMap, + } + + h.makeTestNode(t, loginServer) + + return h +} + +func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { + t.Helper() + + args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) + + cmd := exec.Command(h.cli, args...) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + + return out +} + +// makeTestNode creates a userspace tailscaled running in netstack mode that +// enables us to make connections to and from the tailscale network being +// tested. This mutates the Harness to allow tests to dial into the tailscale +// network as well as control the tester's tailscaled. +func (h *Harness) makeTestNode(t *testing.T, controlURL string) { + dir := t.TempDir() + h.testerDir = dir + + port, err := getProbablyFreePortNumber() + if err != nil { + t.Fatalf("can't get free port: %v", err) + } + + cmd := exec.Command( + h.daemon, + "--tun=userspace-networking", + "--state="+filepath.Join(dir, "state.json"), + "--socket="+filepath.Join(dir, "sock"), + fmt.Sprintf("--socks5-server=localhost:%d", port), + ) + + cmd.Env = append( + os.Environ(), + "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), + "TS_LOG_TARGET="+h.loginServerURL, + ) + + err = cmd.Start() + if err != nil { + t.Fatalf("can't start tailscaled: %v", err) + } + + t.Cleanup(func() { + cmd.Process.Kill() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ticker := time.NewTicker(100 * time.Millisecond) + +outer: + for { + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for tailscaled to come up") + return + case <-ticker.C: + conn, err := net.Dial("unix", filepath.Join(dir, "sock")) + if err != nil { + continue + } + + conn.Close() + break outer + } + } + + run(t, dir, h.cli, + "--socket="+filepath.Join(dir, "sock"), + "up", + "--login-server="+controlURL, + "--hostname=tester", + ) + + dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) + if err != nil { + t.Fatalf("can't make netstack proxy dialer: %v", err) + } + h.testerDialer = dialer + h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) +} + +func bytes2Netaddr(inp []byte) netip.Addr { + return netip.MustParseAddr(string(bytes.TrimSpace(inp))) +} diff --git a/tstest/integration/vms/nixos_test.go b/tstest/integration/vms/nixos_test.go index c2998ff3c087c..06a14e4f6cc21 100644 --- a/tstest/integration/vms/nixos_test.go +++ b/tstest/integration/vms/nixos_test.go @@ -1,231 +1,231 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "flag" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - "text/template" - - "tailscale.com/types/logger" -) - -var ( - verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") -) - -/* - NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than - other distros due to NixOS' determinism. Normally NixOS wants packages to - be defined in either an overlay, a custom packageOverrides or even - yolo-inline as a part of the system configuration. This is going to have - us take a different approach compared to other distributions. The overall - plan here is as following: - - 1. make the binaries as normal - 2. template in their paths as raw strings to the nixos system module - 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` - 4. pass that to the steps that make the virtual machine - - It doesn't really make sense for us to use a premade virtual machine image - for this as that will make it harder to deterministically create the image. -*/ - -const nixosConfigTemplate = ` -# NOTE(Xe): This template is going to be heavily commented. - -# All NixOS modules are functions. Here is the function prelude for this NixOS -# module that defines the system. It is a function that takes in an attribute -# set (effectively a map[string]nix.Value) and destructures it to some variables: -{ - # other NixOS settings as defined in other modules - config, - - # nixpkgs, which is basically the standard library of NixOS - pkgs, - - # the path to some system-scoped NixOS modules that aren't imported by default - modulesPath, - - # the rest of the arguments don't matter - ... -}: - -# Nix's syntax was inspired by Haskell and other functional languages, so the -# let .. in pattern is used to create scoped variables: -let - # Define the package (derivation) for Tailscale based on the binaries we - # just built for this test: - testTailscale = pkgs.stdenv.mkDerivation { - # The name of the package. This usually includes a version however it - # doesn't matter here. - name = "tailscale-test"; - - # The path on disk to the "source code" of the package, in this case it is - # the path to the binaries that are built. This needs to be the raw - # unquoted slash-separated path, not a string containing the path because Nix - # has a special path type. - src = {{.BinPath}}; - - # We only need to worry about the install phase because we've already - # built the binaries. - phases = "installPhase"; - - # We need to wrap tailscaled such that it has iptables in its $PATH. - nativeBuildInputs = [ pkgs.makeWrapper ]; - - # The install instructions for this package ('' ''defines a multi-line string). - # The with statement lets us bring in values into scope as if they were - # defined in the current scope. - installPhase = with pkgs; '' - # This is bash. - - # Make the output folders for the package (systemd unit and binary folders). - mkdir -p $out/bin - - # Install tailscale{,d} - cp $src/tailscale $out/bin/tailscale - cp $src/tailscaled $out/bin/tailscaled - - # Wrap tailscaled with the ip and iptables commands. - wrapProgram $out/bin/tailscaled --prefix PATH : ${ - lib.makeBinPath [ iproute iptables ] - } - - # Install systemd unit. - cp $src/systemd/tailscaled.service . - sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service - install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service - ''; - }; -in { - # This is a QEMU VM. This module has a lot of common qemu VM settings so you - # don't have to set them manually. - imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; - - # We need virtio support to boot. - boot.initrd.availableKernelModules = - [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; - boot.initrd.kernelModules = [ ]; - boot.kernelModules = [ ]; - boot.extraModulePackages = [ ]; - - # Curl is needed for one of the steps in cloud-final - systemd.services.cloud-final.path = with pkgs; [ curl ]; - - # Curl is needed for one of the integration tests - environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; - - # yolo, this vm can sudo freely. - security.sudo.wheelNeedsPassword = false; - - # Enable cloud-init so we can set VM hostnames and the like the same as other - # distros. This will also take care of SSH keys. It's pretty handy. - services.cloud-init = { - enable = true; - ext4.enable = true; - }; - - # We want sshd running. - services.openssh.enable = true; - - # Tailscale settings: - services.tailscale = { - # We want Tailscale to start at boot. - enable = true; - - # Use the Tailscale package we just assembled. - package = testTailscale; - }; - - # Override TS_LOG_TARGET to our private logcatcher. - systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; -}` - -func (h *Harness) copyUnit(t *testing.T) { - t.Helper() - - data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") - if err != nil { - t.Fatal(err) - } - os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) - err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) - if err != nil { - t.Fatal(err) - } -} - -func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { - if d.Name == "nixos-unstable" { - t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") - } - - h.copyUnit(t) - dir := t.TempDir() - fname := filepath.Join(dir, d.Name+".nix") - fout, err := os.Create(fname) - if err != nil { - t.Fatal(err) - } - - tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) - err = tmpl.Execute(fout, struct { - BinPath string - LogTarget string - }{ - BinPath: h.binaryDir, - LogTarget: h.loginServerURL, - }) - if err != nil { - t.Fatal(err) - } - - err = fout.Close() - if err != nil { - t.Fatal(err) - } - - outpath := filepath.Join(cdir, "nixos") - os.MkdirAll(outpath, 0755) - - t.Cleanup(func() { - os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC - }) - - cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) - if *verboseNixOutput { - cmd.Stdout = logger.FuncWriter(t.Logf) - cmd.Stderr = logger.FuncWriter(t.Logf) - } else { - fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) - t.Logf("writing nix logs to %s", fname) - fout, err := os.Create(fname) - if err != nil { - t.Fatalf("can't make log file for nix build: %v", err) - } - cmd.Stdout = fout - cmd.Stderr = fout - defer fout.Close() - } - cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) - cmd.Dir = outpath - t.Logf("running %s %#v", "nixos-generate", cmd.Args) - if err := cmd.Run(); err != nil { - t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) - } - - if !*verboseNixOutput { - t.Log("done") - } - - return filepath.Join(outpath, d.Name, "nixos.qcow2") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "text/template" + + "tailscale.com/types/logger" +) + +var ( + verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") +) + +/* + NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than + other distros due to NixOS' determinism. Normally NixOS wants packages to + be defined in either an overlay, a custom packageOverrides or even + yolo-inline as a part of the system configuration. This is going to have + us take a different approach compared to other distributions. The overall + plan here is as following: + + 1. make the binaries as normal + 2. template in their paths as raw strings to the nixos system module + 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` + 4. pass that to the steps that make the virtual machine + + It doesn't really make sense for us to use a premade virtual machine image + for this as that will make it harder to deterministically create the image. +*/ + +const nixosConfigTemplate = ` +# NOTE(Xe): This template is going to be heavily commented. + +# All NixOS modules are functions. Here is the function prelude for this NixOS +# module that defines the system. It is a function that takes in an attribute +# set (effectively a map[string]nix.Value) and destructures it to some variables: +{ + # other NixOS settings as defined in other modules + config, + + # nixpkgs, which is basically the standard library of NixOS + pkgs, + + # the path to some system-scoped NixOS modules that aren't imported by default + modulesPath, + + # the rest of the arguments don't matter + ... +}: + +# Nix's syntax was inspired by Haskell and other functional languages, so the +# let .. in pattern is used to create scoped variables: +let + # Define the package (derivation) for Tailscale based on the binaries we + # just built for this test: + testTailscale = pkgs.stdenv.mkDerivation { + # The name of the package. This usually includes a version however it + # doesn't matter here. + name = "tailscale-test"; + + # The path on disk to the "source code" of the package, in this case it is + # the path to the binaries that are built. This needs to be the raw + # unquoted slash-separated path, not a string containing the path because Nix + # has a special path type. + src = {{.BinPath}}; + + # We only need to worry about the install phase because we've already + # built the binaries. + phases = "installPhase"; + + # We need to wrap tailscaled such that it has iptables in its $PATH. + nativeBuildInputs = [ pkgs.makeWrapper ]; + + # The install instructions for this package ('' ''defines a multi-line string). + # The with statement lets us bring in values into scope as if they were + # defined in the current scope. + installPhase = with pkgs; '' + # This is bash. + + # Make the output folders for the package (systemd unit and binary folders). + mkdir -p $out/bin + + # Install tailscale{,d} + cp $src/tailscale $out/bin/tailscale + cp $src/tailscaled $out/bin/tailscaled + + # Wrap tailscaled with the ip and iptables commands. + wrapProgram $out/bin/tailscaled --prefix PATH : ${ + lib.makeBinPath [ iproute iptables ] + } + + # Install systemd unit. + cp $src/systemd/tailscaled.service . + sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service + install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service + ''; + }; +in { + # This is a QEMU VM. This module has a lot of common qemu VM settings so you + # don't have to set them manually. + imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; + + # We need virtio support to boot. + boot.initrd.availableKernelModules = + [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; + boot.initrd.kernelModules = [ ]; + boot.kernelModules = [ ]; + boot.extraModulePackages = [ ]; + + # Curl is needed for one of the steps in cloud-final + systemd.services.cloud-final.path = with pkgs; [ curl ]; + + # Curl is needed for one of the integration tests + environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; + + # yolo, this vm can sudo freely. + security.sudo.wheelNeedsPassword = false; + + # Enable cloud-init so we can set VM hostnames and the like the same as other + # distros. This will also take care of SSH keys. It's pretty handy. + services.cloud-init = { + enable = true; + ext4.enable = true; + }; + + # We want sshd running. + services.openssh.enable = true; + + # Tailscale settings: + services.tailscale = { + # We want Tailscale to start at boot. + enable = true; + + # Use the Tailscale package we just assembled. + package = testTailscale; + }; + + # Override TS_LOG_TARGET to our private logcatcher. + systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; +}` + +func (h *Harness) copyUnit(t *testing.T) { + t.Helper() + + data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") + if err != nil { + t.Fatal(err) + } + os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) + err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) + if err != nil { + t.Fatal(err) + } +} + +func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { + if d.Name == "nixos-unstable" { + t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") + } + + h.copyUnit(t) + dir := t.TempDir() + fname := filepath.Join(dir, d.Name+".nix") + fout, err := os.Create(fname) + if err != nil { + t.Fatal(err) + } + + tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) + err = tmpl.Execute(fout, struct { + BinPath string + LogTarget string + }{ + BinPath: h.binaryDir, + LogTarget: h.loginServerURL, + }) + if err != nil { + t.Fatal(err) + } + + err = fout.Close() + if err != nil { + t.Fatal(err) + } + + outpath := filepath.Join(cdir, "nixos") + os.MkdirAll(outpath, 0755) + + t.Cleanup(func() { + os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC + }) + + cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) + if *verboseNixOutput { + cmd.Stdout = logger.FuncWriter(t.Logf) + cmd.Stderr = logger.FuncWriter(t.Logf) + } else { + fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) + t.Logf("writing nix logs to %s", fname) + fout, err := os.Create(fname) + if err != nil { + t.Fatalf("can't make log file for nix build: %v", err) + } + cmd.Stdout = fout + cmd.Stderr = fout + defer fout.Close() + } + cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) + cmd.Dir = outpath + t.Logf("running %s %#v", "nixos-generate", cmd.Args) + if err := cmd.Run(); err != nil { + t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) + } + + if !*verboseNixOutput { + t.Log("done") + } + + return filepath.Join(outpath, d.Name, "nixos.qcow2") +} diff --git a/tstest/integration/vms/regex_flag.go b/tstest/integration/vms/regex_flag.go index 02e399ecdfaad..195f7c7718b7c 100644 --- a/tstest/integration/vms/regex_flag.go +++ b/tstest/integration/vms/regex_flag.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import "regexp" - -type regexValue struct { - r *regexp.Regexp -} - -func (r *regexValue) String() string { - if r.r == nil { - return "" - } - - return r.r.String() -} - -func (r *regexValue) Set(val string) error { - if rex, err := regexp.Compile(val); err != nil { - return err - } else { - r.r = rex - return nil - } -} - -func (r regexValue) Unwrap() *regexp.Regexp { return r.r } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import "regexp" + +type regexValue struct { + r *regexp.Regexp +} + +func (r *regexValue) String() string { + if r.r == nil { + return "" + } + + return r.r.String() +} + +func (r *regexValue) Set(val string) error { + if rex, err := regexp.Compile(val); err != nil { + return err + } else { + r.r = rex + return nil + } +} + +func (r regexValue) Unwrap() *regexp.Regexp { return r.r } diff --git a/tstest/integration/vms/regex_flag_test.go b/tstest/integration/vms/regex_flag_test.go index 0f4e5f8f7bdec..790894080a7d5 100644 --- a/tstest/integration/vms/regex_flag_test.go +++ b/tstest/integration/vms/regex_flag_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "flag" - "testing" -) - -func TestRegexFlag(t *testing.T) { - var v regexValue - fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) - fs.Var(&v, "regex", "regex to parse") - - const want = `.*` - fs.Parse([]string{"-regex", want}) - if v.Unwrap().String() != want { - t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "flag" + "testing" +) + +func TestRegexFlag(t *testing.T) { + var v regexValue + fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) + fs.Var(&v, "regex", "regex to parse") + + const want = `.*` + fs.Parse([]string{"-regex", want}) + if v.Unwrap().String() != want { + t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) + } +} diff --git a/tstest/integration/vms/runner.nix b/tstest/integration/vms/runner.nix index ac569cf658cb1..8d4c0a25dc5f6 100644 --- a/tstest/integration/vms/runner.nix +++ b/tstest/integration/vms/runner.nix @@ -1,89 +1,89 @@ -# This is a NixOS module to allow a machine to act as an integration test -# runner. This is used for the end-to-end VM test suite. - -{ lib, config, pkgs, ... }: - -{ - # The GitHub Actions self-hosted runner service. - services.github-runner = { - enable = true; - url = "https://github.com/tailscale/tailscale"; - replace = true; - extraLabels = [ "vm_integration_test" ]; - - # Justifications for the packages: - extraPackages = with pkgs; [ - # The test suite is written in Go. - go - - # This contains genisoimage, which is needed to create cloud-init - # seeds. - cdrkit - - # This package is the virtual machine hypervisor we use in tests. - qemu - - # This package contains tools like `ssh-keygen`. - openssh - - # The C compiler so cgo builds work. - gcc - - # The package manager Nix, just in case. - nix - - # Used to generate a NixOS image for testing. - nixos-generators - - # Used to extract things. - gnutar - - # Used to decompress things. - lzma - ]; - - # Customize this to include your GitHub username so we can track - # who is running which node. - name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; - - # Replace this with the path to the GitHub Actions runner token on - # your disk. - tokenFile = "/run/decrypted/ts-oss-ghaction-token"; - }; - - # A user account so there is a home directory and so they have kvm - # access. Please don't change this account name. - users.users.ghrunner = { - createHome = true; - isSystemUser = true; - extraGroups = [ "kvm" ]; - }; - - # The default github-runner service sets a lot of isolation features - # that attempt to limit the damage that malicious code can use. - # Unfortunately we rely on some "dangerous" features to do these tests, - # so this shim will peel some of them away. - systemd.services.github-runner = { - serviceConfig = { - # We need access to /dev to poke /dev/kvm. - PrivateDevices = lib.mkForce false; - - # /dev/kvm is how qemu creates a virtual machine with KVM. - DeviceAllow = lib.mkForce [ "/dev/kvm" ]; - - # Ensure the service has KVM permissions with the `kvm` group. - ExtraGroups = [ "kvm" ]; - - # The service runs as a dynamic user by default. This makes it hard - # to persistently store things in /var/lib/ghrunner. This line - # disables the dynamic user feature. - DynamicUser = lib.mkForce false; - - # Run this service as our ghrunner user. - User = "ghrunner"; - - # We need access to /var/lib/ghrunner to store VM images. - ProtectSystem = lib.mkForce null; - }; - }; -} +# This is a NixOS module to allow a machine to act as an integration test +# runner. This is used for the end-to-end VM test suite. + +{ lib, config, pkgs, ... }: + +{ + # The GitHub Actions self-hosted runner service. + services.github-runner = { + enable = true; + url = "https://github.com/tailscale/tailscale"; + replace = true; + extraLabels = [ "vm_integration_test" ]; + + # Justifications for the packages: + extraPackages = with pkgs; [ + # The test suite is written in Go. + go + + # This contains genisoimage, which is needed to create cloud-init + # seeds. + cdrkit + + # This package is the virtual machine hypervisor we use in tests. + qemu + + # This package contains tools like `ssh-keygen`. + openssh + + # The C compiler so cgo builds work. + gcc + + # The package manager Nix, just in case. + nix + + # Used to generate a NixOS image for testing. + nixos-generators + + # Used to extract things. + gnutar + + # Used to decompress things. + lzma + ]; + + # Customize this to include your GitHub username so we can track + # who is running which node. + name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; + + # Replace this with the path to the GitHub Actions runner token on + # your disk. + tokenFile = "/run/decrypted/ts-oss-ghaction-token"; + }; + + # A user account so there is a home directory and so they have kvm + # access. Please don't change this account name. + users.users.ghrunner = { + createHome = true; + isSystemUser = true; + extraGroups = [ "kvm" ]; + }; + + # The default github-runner service sets a lot of isolation features + # that attempt to limit the damage that malicious code can use. + # Unfortunately we rely on some "dangerous" features to do these tests, + # so this shim will peel some of them away. + systemd.services.github-runner = { + serviceConfig = { + # We need access to /dev to poke /dev/kvm. + PrivateDevices = lib.mkForce false; + + # /dev/kvm is how qemu creates a virtual machine with KVM. + DeviceAllow = lib.mkForce [ "/dev/kvm" ]; + + # Ensure the service has KVM permissions with the `kvm` group. + ExtraGroups = [ "kvm" ]; + + # The service runs as a dynamic user by default. This makes it hard + # to persistently store things in /var/lib/ghrunner. This line + # disables the dynamic user feature. + DynamicUser = lib.mkForce false; + + # Run this service as our ghrunner user. + User = "ghrunner"; + + # We need access to /var/lib/ghrunner to store VM images. + ProtectSystem = lib.mkForce null; + }; + }; +} diff --git a/tstest/integration/vms/squid.conf b/tstest/integration/vms/squid.conf index 29d32bd6d8606..e43c5cd1f41d4 100644 --- a/tstest/integration/vms/squid.conf +++ b/tstest/integration/vms/squid.conf @@ -1,39 +1,39 @@ -pid_filename /run/squid.pid -cache_dir ufs /tmp/squid/cache 500 16 256 -maximum_object_size 4096 KB -coredump_dir /tmp/squid/core -visible_hostname localhost -cache_access_log /tmp/squid/access.log -cache_log /tmp/squid/cache.log - -# Access Control lists -acl localhost src 127.0.0.1 ::1 -acl manager proto cache_object -acl SSL_ports port 443 -acl Safe_ports port 80 # http -acl Safe_ports port 21 # ftp -acl Safe_ports port 443 # https -acl Safe_ports port 70 # gopher -acl Safe_ports port 210 # wais -acl Safe_ports port 1025-65535 # unregistered ports -acl Safe_ports port 280 # http-mgmt -acl Safe_ports port 488 # gss-http -acl Safe_ports port 591 # filemaker -acl Safe_ports port 777 # multiling http -acl CONNECT method CONNECT - -http_access allow localhost -http_access deny all -forwarded_for on - -# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB -sslcrtd_children 5 - -http_port 127.0.0.1:3128 \ - ssl-bump \ - generate-host-certificates=on \ - dynamic_cert_mem_cache_size=4MB \ - cert=/tmp/squid/myca-mitm.pem - -ssl_bump stare all # mimic the Client Hello, drop unsupported extensions +pid_filename /run/squid.pid +cache_dir ufs /tmp/squid/cache 500 16 256 +maximum_object_size 4096 KB +coredump_dir /tmp/squid/core +visible_hostname localhost +cache_access_log /tmp/squid/access.log +cache_log /tmp/squid/cache.log + +# Access Control lists +acl localhost src 127.0.0.1 ::1 +acl manager proto cache_object +acl SSL_ports port 443 +acl Safe_ports port 80 # http +acl Safe_ports port 21 # ftp +acl Safe_ports port 443 # https +acl Safe_ports port 70 # gopher +acl Safe_ports port 210 # wais +acl Safe_ports port 1025-65535 # unregistered ports +acl Safe_ports port 280 # http-mgmt +acl Safe_ports port 488 # gss-http +acl Safe_ports port 591 # filemaker +acl Safe_ports port 777 # multiling http +acl CONNECT method CONNECT + +http_access allow localhost +http_access deny all +forwarded_for on + +# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB +sslcrtd_children 5 + +http_port 127.0.0.1:3128 \ + ssl-bump \ + generate-host-certificates=on \ + dynamic_cert_mem_cache_size=4MB \ + cert=/tmp/squid/myca-mitm.pem + +ssl_bump stare all # mimic the Client Hello, drop unsupported extensions ssl_bump bump all # terminate and establish new TLS connection \ No newline at end of file diff --git a/tstest/integration/vms/top_level_test.go b/tstest/integration/vms/top_level_test.go index c107fd89cc886..1b9c10e29297a 100644 --- a/tstest/integration/vms/top_level_test.go +++ b/tstest/integration/vms/top_level_test.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "context" - "testing" - "time" - - "github.com/pkg/sftp" - expect "github.com/tailscale/goexpect" -) - -func TestRunUbuntu1804(t *testing.T) { - testOneDistribution(t, 0, Distros[0]) -} - -func TestRunUbuntu2004(t *testing.T) { - testOneDistribution(t, 1, Distros[1]) -} - -func TestRunNixos2111(t *testing.T) { - t.Parallel() - testOneDistribution(t, 2, Distros[2]) -} - -// TestMITMProxy is a smoke test for derphttp through a MITM proxy. -// Encountering such proxies is unfortunately commonplace in more -// traditional enterprise networks. -// -// We invoke tailscale netcheck because the networking check is done -// by tailscale rather than tailscaled, making it easier to configure -// the proxy. -// -// To provide the actual MITM server, we use squid. -func TestMITMProxy(t *testing.T) { - t.Parallel() - setupTests(t) - distro := Distros[2] // nixos-21.11 - - if distroRex.Unwrap().MatchString(distro.Name) { - t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) - } else { - t.Skip("regex not matched") - } - - ctx, done := context.WithCancel(context.Background()) - t.Cleanup(done) - - h := newHarness(t) - - err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) - if err != nil { - t.Fatalf("can't acquire ram semaphore: %v", err) - } - t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) - - vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) - vm.waitStartup(t) - - ipm := h.waitForIPMap(t, vm, distro) - _, cli := h.setupSSHShell(t, distro, ipm) - - sftpCli, err := sftp.NewClient(cli) - if err != nil { - t.Fatalf("can't connect over sftp to copy binaries: %v", err) - } - defer sftpCli.Close() - - // Initialize a squid installation. - // - // A few things of note here: - // - The first thing we do is append the nsslcrtd_program stanza to the config. - // This must be an absolute path and is based on the nix path of the squid derivation, - // so we compute and write it out here. - // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then - // invoke squid with -z to have it fill in the rest. - // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create - // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. - // - There were some perms issues, so i yeeted 0777. Its only a test anyway - copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, - &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, - &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, - &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, - &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, - &expect.BExp{R: `Done`}, - &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, - &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Start the squid server. - runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon - // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This - // eldritchian incantation lets us wait till squid is up. - &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Uncomment to help debugging this test if it fails. - // - // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ - // &expect.BSnd{S: "sudo ifconfig\n"}, - // &expect.BSnd{S: "sudo ip link\n"}, - // &expect.BSnd{S: "sudo ip route\n"}, - // &expect.BSnd{S: "ps -aux\n"}, - // &expect.BSnd{S: "netstat -a\n"}, - // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, - // &expect.BExp{R: `Success.`}, - // }) - - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, - &expect.BExp{R: `IPv4: yes`}, - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "context" + "testing" + "time" + + "github.com/pkg/sftp" + expect "github.com/tailscale/goexpect" +) + +func TestRunUbuntu1804(t *testing.T) { + testOneDistribution(t, 0, Distros[0]) +} + +func TestRunUbuntu2004(t *testing.T) { + testOneDistribution(t, 1, Distros[1]) +} + +func TestRunNixos2111(t *testing.T) { + t.Parallel() + testOneDistribution(t, 2, Distros[2]) +} + +// TestMITMProxy is a smoke test for derphttp through a MITM proxy. +// Encountering such proxies is unfortunately commonplace in more +// traditional enterprise networks. +// +// We invoke tailscale netcheck because the networking check is done +// by tailscale rather than tailscaled, making it easier to configure +// the proxy. +// +// To provide the actual MITM server, we use squid. +func TestMITMProxy(t *testing.T) { + t.Parallel() + setupTests(t) + distro := Distros[2] // nixos-21.11 + + if distroRex.Unwrap().MatchString(distro.Name) { + t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) + } else { + t.Skip("regex not matched") + } + + ctx, done := context.WithCancel(context.Background()) + t.Cleanup(done) + + h := newHarness(t) + + err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) + if err != nil { + t.Fatalf("can't acquire ram semaphore: %v", err) + } + t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) + + vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) + vm.waitStartup(t) + + ipm := h.waitForIPMap(t, vm, distro) + _, cli := h.setupSSHShell(t, distro, ipm) + + sftpCli, err := sftp.NewClient(cli) + if err != nil { + t.Fatalf("can't connect over sftp to copy binaries: %v", err) + } + defer sftpCli.Close() + + // Initialize a squid installation. + // + // A few things of note here: + // - The first thing we do is append the nsslcrtd_program stanza to the config. + // This must be an absolute path and is based on the nix path of the squid derivation, + // so we compute and write it out here. + // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then + // invoke squid with -z to have it fill in the rest. + // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create + // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. + // - There were some perms issues, so i yeeted 0777. Its only a test anyway + copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, + &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, + &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, + &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, + &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, + &expect.BExp{R: `Done`}, + &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, + &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Start the squid server. + runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon + // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This + // eldritchian incantation lets us wait till squid is up. + &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Uncomment to help debugging this test if it fails. + // + // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ + // &expect.BSnd{S: "sudo ifconfig\n"}, + // &expect.BSnd{S: "sudo ip link\n"}, + // &expect.BSnd{S: "sudo ip route\n"}, + // &expect.BSnd{S: "ps -aux\n"}, + // &expect.BSnd{S: "netstat -a\n"}, + // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, + // &expect.BExp{R: `Success.`}, + // }) + + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, + &expect.BExp{R: `IPv4: yes`}, + }) +} diff --git a/tstest/integration/vms/udp_tester.go b/tstest/integration/vms/udp_tester.go index be44aa9636103..14c8c6ed0c7a5 100644 --- a/tstest/integration/vms/udp_tester.go +++ b/tstest/integration/vms/udp_tester.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command udp_tester exists because all of these distros being tested don't -// have a consistent tool for doing UDP traffic. This is a very hacked up tool -// that does that UDP traffic so these tests can be done. -package main - -import ( - "flag" - "io" - "log" - "net" - "os" -) - -var ( - client = flag.String("client", "", "host:port to connect to for sending UDP") - server = flag.String("server", "", "host:port to bind to for receiving UDP") -) - -func main() { - flag.Parse() - - if *client == "" && *server == "" { - log.Fatal("specify -client or -server") - } - - if *client != "" { - conn, err := net.Dial("udp", *client) - if err != nil { - log.Fatalf("can't dial %s: %v", *client, err) - } - log.Printf("dialed to %s", conn.RemoteAddr()) - defer conn.Close() - - buf := make([]byte, 2048) - n, err := os.Stdin.Read(buf) - if err != nil && err != io.EOF { - log.Fatalf("can't read from stdin: %v", err) - } - - nn, err := conn.Write(buf[:n]) - if err != nil { - log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) - } - - if n == nn { - return - } - - log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) - } - - if *server != "" { - addr, err := net.ResolveUDPAddr("udp", *server) - if err != nil { - log.Fatalf("can't resolve %s: %v", *server, err) - } - ln, err := net.ListenUDP("udp", addr) - if err != nil { - log.Fatalf("can't listen %s: %v", *server, err) - } - defer ln.Close() - - buf := make([]byte, 2048) - - n, _, err := ln.ReadFromUDP(buf) - if err != nil { - log.Fatal(err) - } - - os.Stdout.Write(buf[:n]) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command udp_tester exists because all of these distros being tested don't +// have a consistent tool for doing UDP traffic. This is a very hacked up tool +// that does that UDP traffic so these tests can be done. +package main + +import ( + "flag" + "io" + "log" + "net" + "os" +) + +var ( + client = flag.String("client", "", "host:port to connect to for sending UDP") + server = flag.String("server", "", "host:port to bind to for receiving UDP") +) + +func main() { + flag.Parse() + + if *client == "" && *server == "" { + log.Fatal("specify -client or -server") + } + + if *client != "" { + conn, err := net.Dial("udp", *client) + if err != nil { + log.Fatalf("can't dial %s: %v", *client, err) + } + log.Printf("dialed to %s", conn.RemoteAddr()) + defer conn.Close() + + buf := make([]byte, 2048) + n, err := os.Stdin.Read(buf) + if err != nil && err != io.EOF { + log.Fatalf("can't read from stdin: %v", err) + } + + nn, err := conn.Write(buf[:n]) + if err != nil { + log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) + } + + if n == nn { + return + } + + log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) + } + + if *server != "" { + addr, err := net.ResolveUDPAddr("udp", *server) + if err != nil { + log.Fatalf("can't resolve %s: %v", *server, err) + } + ln, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatalf("can't listen %s: %v", *server, err) + } + defer ln.Close() + + buf := make([]byte, 2048) + + n, _, err := ln.ReadFromUDP(buf) + if err != nil { + log.Fatal(err) + } + + os.Stdout.Write(buf[:n]) + } +} diff --git a/tstest/log_test.go b/tstest/log_test.go index 51a5743c2c7f2..a8cb62cf5ccf2 100644 --- a/tstest/log_test.go +++ b/tstest/log_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "reflect" - "testing" -) - -func TestLogLineTracker(t *testing.T) { - const ( - l1 = "line 1: %s" - l2 = "line 2: %s" - l3 = "line 3: %s" - ) - - lt := NewLogLineTracker(t.Logf, []string{l1, l2}) - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l3, "hi") - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "hi") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "bye") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l2, "hi") - - if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "reflect" + "testing" +) + +func TestLogLineTracker(t *testing.T) { + const ( + l1 = "line 1: %s" + l2 = "line 2: %s" + l3 = "line 3: %s" + ) + + lt := NewLogLineTracker(t.Logf, []string{l1, l2}) + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l3, "hi") + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "hi") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "bye") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l2, "hi") + + if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } +} diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index c427d6692a29c..851f1c56dcf8d 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -1,156 +1,156 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "fmt" - "net/netip" - "sync" - "time" - - "tailscale.com/util/mak" -) - -// FirewallType is the type of filtering a stateful firewall -// does. Values express different modes defined by RFC 4787. -type FirewallType int - -const ( - // AddressAndPortDependentFirewall specifies a destination - // address-and-port dependent firewall. Outbound traffic to an - // ip:port authorizes traffic from that ip:port exactly, and - // nothing else. - AddressAndPortDependentFirewall FirewallType = iota - // AddressDependentFirewall specifies a destination address - // dependent firewall. Once outbound traffic has been seen to an - // IP address, that IP address can talk back from any port. - AddressDependentFirewall - // EndpointIndependentFirewall specifies a destination endpoint - // independent firewall. Once outbound traffic has been seen from - // a source, anyone can talk back to that source. - EndpointIndependentFirewall -) - -// fwKey is the lookup key for a firewall session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out -// some fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type fwKey struct { - src netip.AddrPort - dst netip.AddrPort -} - -// key returns an fwKey for the given src and dst, trimmed according -// to the FirewallType. fwKeys are always constructed from the -// "outbound" point of view (i.e. src is the "trusted" side of the -// world), it's the caller's responsibility to swap src and dst in the -// call to key when processing packets inbound from the "untrusted" -// world. -func (s FirewallType) key(src, dst netip.AddrPort) fwKey { - k := fwKey{src: src} - switch s { - case EndpointIndependentFirewall: - case AddressDependentFirewall: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentFirewall: - k.dst = dst - default: - panic(fmt.Sprintf("unknown firewall selectivity %v", s)) - } - return k -} - -// DefaultSessionTimeout is the default timeout for a firewall -// session. -const DefaultSessionTimeout = 30 * time.Second - -// Firewall is a simple stateful firewall that allows all outbound -// traffic and filters inbound traffic based on recently seen outbound -// traffic. Its HandlePacket method should be attached to a Machine to -// give it a stateful firewall. -type Firewall struct { - // SessionTimeout is the lifetime of idle sessions in the firewall - // state. Packets transiting from the TrustedInterface reset the - // session lifetime to SessionTimeout. If zero, - // DefaultSessionTimeout is used. - SessionTimeout time.Duration - // Type specifies how precisely return traffic must match - // previously seen outbound traffic to be allowed. Defaults to - // AddressAndPortDependentFirewall. - Type FirewallType - // TrustedInterface is an optional interface that is considered - // trusted in addition to PacketConns local to the Machine. All - // other interfaces can only respond to traffic from - // TrustedInterface or the local host. - TrustedInterface *Interface - // TimeNow is a function returning the current time. If nil, - // time.Now is used. - TimeNow func() time.Time - - // TODO: refresh directionality: outbound-only, both - - mu sync.Mutex - seen map[fwKey]time.Time // session -> deadline -} - -func (f *Firewall) timeNow() time.Time { - if f.TimeNow != nil { - return f.TimeNow() - } - return time.Now() -} - -// Reset drops all firewall state, forgetting all flows. -func (f *Firewall) Reset() { - f.mu.Lock() - defer f.mu.Unlock() - f.seen = nil -} - -func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - k := f.Type.key(p.Src, p.Dst) - mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) - p.Trace("firewall out ok") - return p -} - -func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - // reverse src and dst because the session table is from the POV - // of outbound packets. - k := f.Type.key(p.Dst, p.Src) - now := f.timeNow() - if now.After(f.seen[k]) { - p.Trace("firewall drop") - return nil - } - p.Trace("firewall in ok") - return p -} - -func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { - if iif == f.TrustedInterface { - // Treat just like a locally originated packet - return f.HandleOut(p, oif) - } - if oif != f.TrustedInterface { - // Not a possible return packet from our trusted interface, drop. - p.Trace("firewall drop, unexpected oif") - return nil - } - // Otherwise, a session must exist, same as HandleIn. - return f.HandleIn(p, iif) -} - -func (f *Firewall) sessionTimeoutLocked() time.Duration { - if f.SessionTimeout == 0 { - return DefaultSessionTimeout - } - return f.SessionTimeout -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "fmt" + "net/netip" + "sync" + "time" + + "tailscale.com/util/mak" +) + +// FirewallType is the type of filtering a stateful firewall +// does. Values express different modes defined by RFC 4787. +type FirewallType int + +const ( + // AddressAndPortDependentFirewall specifies a destination + // address-and-port dependent firewall. Outbound traffic to an + // ip:port authorizes traffic from that ip:port exactly, and + // nothing else. + AddressAndPortDependentFirewall FirewallType = iota + // AddressDependentFirewall specifies a destination address + // dependent firewall. Once outbound traffic has been seen to an + // IP address, that IP address can talk back from any port. + AddressDependentFirewall + // EndpointIndependentFirewall specifies a destination endpoint + // independent firewall. Once outbound traffic has been seen from + // a source, anyone can talk back to that source. + EndpointIndependentFirewall +) + +// fwKey is the lookup key for a firewall session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out +// some fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type fwKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// key returns an fwKey for the given src and dst, trimmed according +// to the FirewallType. fwKeys are always constructed from the +// "outbound" point of view (i.e. src is the "trusted" side of the +// world), it's the caller's responsibility to swap src and dst in the +// call to key when processing packets inbound from the "untrusted" +// world. +func (s FirewallType) key(src, dst netip.AddrPort) fwKey { + k := fwKey{src: src} + switch s { + case EndpointIndependentFirewall: + case AddressDependentFirewall: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentFirewall: + k.dst = dst + default: + panic(fmt.Sprintf("unknown firewall selectivity %v", s)) + } + return k +} + +// DefaultSessionTimeout is the default timeout for a firewall +// session. +const DefaultSessionTimeout = 30 * time.Second + +// Firewall is a simple stateful firewall that allows all outbound +// traffic and filters inbound traffic based on recently seen outbound +// traffic. Its HandlePacket method should be attached to a Machine to +// give it a stateful firewall. +type Firewall struct { + // SessionTimeout is the lifetime of idle sessions in the firewall + // state. Packets transiting from the TrustedInterface reset the + // session lifetime to SessionTimeout. If zero, + // DefaultSessionTimeout is used. + SessionTimeout time.Duration + // Type specifies how precisely return traffic must match + // previously seen outbound traffic to be allowed. Defaults to + // AddressAndPortDependentFirewall. + Type FirewallType + // TrustedInterface is an optional interface that is considered + // trusted in addition to PacketConns local to the Machine. All + // other interfaces can only respond to traffic from + // TrustedInterface or the local host. + TrustedInterface *Interface + // TimeNow is a function returning the current time. If nil, + // time.Now is used. + TimeNow func() time.Time + + // TODO: refresh directionality: outbound-only, both + + mu sync.Mutex + seen map[fwKey]time.Time // session -> deadline +} + +func (f *Firewall) timeNow() time.Time { + if f.TimeNow != nil { + return f.TimeNow() + } + return time.Now() +} + +// Reset drops all firewall state, forgetting all flows. +func (f *Firewall) Reset() { + f.mu.Lock() + defer f.mu.Unlock() + f.seen = nil +} + +func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + k := f.Type.key(p.Src, p.Dst) + mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) + p.Trace("firewall out ok") + return p +} + +func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + // reverse src and dst because the session table is from the POV + // of outbound packets. + k := f.Type.key(p.Dst, p.Src) + now := f.timeNow() + if now.After(f.seen[k]) { + p.Trace("firewall drop") + return nil + } + p.Trace("firewall in ok") + return p +} + +func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { + if iif == f.TrustedInterface { + // Treat just like a locally originated packet + return f.HandleOut(p, oif) + } + if oif != f.TrustedInterface { + // Not a possible return packet from our trusted interface, drop. + p.Trace("firewall drop, unexpected oif") + return nil + } + // Otherwise, a session must exist, same as HandleIn. + return f.HandleIn(p, iif) +} + +func (f *Firewall) sessionTimeoutLocked() time.Duration { + if f.SessionTimeout == 0 { + return DefaultSessionTimeout + } + return f.SessionTimeout +} diff --git a/tstest/natlab/nat.go b/tstest/natlab/nat.go index d756c5bf11833..36b1322cdb62c 100644 --- a/tstest/natlab/nat.go +++ b/tstest/natlab/nat.go @@ -1,252 +1,252 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - "time" -) - -// mapping is the state of an allocated NAT session. -type mapping struct { - lanSrc netip.AddrPort - lanDst netip.AddrPort - wanSrc netip.AddrPort - deadline time.Time - - // pc is a PacketConn that reserves an outbound port on the NAT's - // WAN interface. We do this because ListenPacket already has - // random port selection logic built in. Additionally this means - // that concurrent use of ListenPacket for connections originating - // from the NAT box won't conflict with NAT mappings, since both - // use PacketConn to reserve ports on the machine. - pc net.PacketConn -} - -// NATType is the mapping behavior of a NAT device. Values express -// different modes defined by RFC 4787. -type NATType int - -const ( - // EndpointIndependentNAT specifies a destination endpoint - // independent NAT. All traffic from a source ip:port gets mapped - // to a single WAN ip:port. - EndpointIndependentNAT NATType = iota - // AddressDependentNAT specifies a destination address dependent - // NAT. Every distinct destination IP gets its own WAN ip:port - // allocation. - AddressDependentNAT - // AddressAndPortDependentNAT specifies a destination - // address-and-port dependent NAT. Every distinct destination - // ip:port gets its own WAN ip:port allocation. - AddressAndPortDependentNAT -) - -// natKey is the lookup key for a NAT session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some -// fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type natKey struct { - src, dst netip.AddrPort -} - -func (t NATType) key(src, dst netip.AddrPort) natKey { - k := natKey{src: src} - switch t { - case EndpointIndependentNAT: - case AddressDependentNAT: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentNAT: - k.dst = dst - default: - panic(fmt.Sprintf("unknown NAT type %v", t)) - } - return k -} - -// DefaultMappingTimeout is the default timeout for a NAT mapping. -const DefaultMappingTimeout = 30 * time.Second - -// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with -// optional builtin firewall. -type SNAT44 struct { - // Machine is the machine to which this NAT is attached. Altered - // packets are injected back into this Machine for processing. - Machine *Machine - // ExternalInterface is the "WAN" interface of Machine. Packets - // from other sources get NATed onto this interface. - ExternalInterface *Interface - // Type specifies the mapping allocation behavior for this NAT. - Type NATType - // MappingTimeout is the lifetime of individual NAT sessions. Once - // a session expires, the mapped port effectively "closes" to new - // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. - MappingTimeout time.Duration - // Firewall is an optional packet handler that will be invoked as - // a firewall during NAT translation. The firewall always sees - // packets in their "LAN form", i.e. before translation in the - // outbound direction and after translation in the inbound - // direction. - Firewall PacketHandler - // TimeNow is a function that returns the current time. If - // nil, time.Now is used. - TimeNow func() time.Time - - mu sync.Mutex - byLAN map[natKey]*mapping // lookup by outbound packet tuple - byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only -} - -func (n *SNAT44) timeNow() time.Time { - if n.TimeNow != nil { - return n.TimeNow() - } - return time.Now() -} - -func (n *SNAT44) mappingTimeout() time.Duration { - if n.MappingTimeout == 0 { - return DefaultMappingTimeout - } - return n.MappingTimeout -} - -func (n *SNAT44) initLocked() { - if n.byLAN == nil { - n.byLAN = map[natKey]*mapping{} - n.byWAN = map[netip.AddrPort]*mapping{} - } - if n.ExternalInterface.Machine() != n.Machine { - panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) - } -} - -func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { - // NATs don't affect locally originated packets. - if n.Firewall != nil { - return n.Firewall.HandleOut(p, oif) - } - return p -} - -func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { - if iif != n.ExternalInterface { - // NAT can't apply, defer to firewall. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - now := n.timeNow() - mapping := n.byWAN[p.Dst] - if mapping == nil || now.After(mapping.deadline) { - // NAT didn't hit, defer to firewall or allow in for local - // socket handling. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - p.Dst = mapping.lanSrc - p.Trace("dnat to %v", p.Dst) - // Don't process firewall here. We mutated the packet such that - // it's no longer destined locally, so we'll get reinvoked as - // HandleForward and need to process the altered packet there. - return p -} - -func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { - switch { - case oif == n.ExternalInterface: - if p.Src.Addr() == oif.V4() { - // Packet already NATed and is just retraversing Forward, - // don't touch it again. - return p - } - - if n.Firewall != nil { - p2 := n.Firewall.HandleForward(p, iif, oif) - if p2 == nil { - // firewall dropped, done - return nil - } - if !p.Equivalent(p2) { - // firewall mutated packet? Weird, but okay. - return p2 - } - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - k := n.Type.key(p.Src, p.Dst) - now := n.timeNow() - m := n.byLAN[k] - if m == nil || now.After(m.deadline) { - pc, wanAddr := n.allocateMappedPort() - m = &mapping{ - lanSrc: p.Src, - lanDst: p.Dst, - wanSrc: wanAddr, - pc: pc, - } - n.byLAN[k] = m - n.byWAN[wanAddr] = m - } - m.deadline = now.Add(n.mappingTimeout()) - p.Src = m.wanSrc - p.Trace("snat from %v", p.Src) - return p - case iif == n.ExternalInterface: - // Packet was already un-NAT-ed, we just need to either - // firewall it or let it through. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return p - default: - // No NAT applies, invoke firewall or drop. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return nil - } -} - -func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { - // Clean up old entries before trying to allocate, to free up any - // expired ports. - n.gc() - - ip := n.ExternalInterface.V4() - pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) - if err != nil { - panic(fmt.Sprintf("ran out of NAT ports: %v", err)) - } - addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) - return pc, addr -} - -func (n *SNAT44) gc() { - now := n.timeNow() - for _, m := range n.byLAN { - if !now.After(m.deadline) { - continue - } - m.pc.Close() - delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) - delete(n.byWAN, m.wanSrc) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" +) + +// mapping is the state of an allocated NAT session. +type mapping struct { + lanSrc netip.AddrPort + lanDst netip.AddrPort + wanSrc netip.AddrPort + deadline time.Time + + // pc is a PacketConn that reserves an outbound port on the NAT's + // WAN interface. We do this because ListenPacket already has + // random port selection logic built in. Additionally this means + // that concurrent use of ListenPacket for connections originating + // from the NAT box won't conflict with NAT mappings, since both + // use PacketConn to reserve ports on the machine. + pc net.PacketConn +} + +// NATType is the mapping behavior of a NAT device. Values express +// different modes defined by RFC 4787. +type NATType int + +const ( + // EndpointIndependentNAT specifies a destination endpoint + // independent NAT. All traffic from a source ip:port gets mapped + // to a single WAN ip:port. + EndpointIndependentNAT NATType = iota + // AddressDependentNAT specifies a destination address dependent + // NAT. Every distinct destination IP gets its own WAN ip:port + // allocation. + AddressDependentNAT + // AddressAndPortDependentNAT specifies a destination + // address-and-port dependent NAT. Every distinct destination + // ip:port gets its own WAN ip:port allocation. + AddressAndPortDependentNAT +) + +// natKey is the lookup key for a NAT session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some +// fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type natKey struct { + src, dst netip.AddrPort +} + +func (t NATType) key(src, dst netip.AddrPort) natKey { + k := natKey{src: src} + switch t { + case EndpointIndependentNAT: + case AddressDependentNAT: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentNAT: + k.dst = dst + default: + panic(fmt.Sprintf("unknown NAT type %v", t)) + } + return k +} + +// DefaultMappingTimeout is the default timeout for a NAT mapping. +const DefaultMappingTimeout = 30 * time.Second + +// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with +// optional builtin firewall. +type SNAT44 struct { + // Machine is the machine to which this NAT is attached. Altered + // packets are injected back into this Machine for processing. + Machine *Machine + // ExternalInterface is the "WAN" interface of Machine. Packets + // from other sources get NATed onto this interface. + ExternalInterface *Interface + // Type specifies the mapping allocation behavior for this NAT. + Type NATType + // MappingTimeout is the lifetime of individual NAT sessions. Once + // a session expires, the mapped port effectively "closes" to new + // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. + MappingTimeout time.Duration + // Firewall is an optional packet handler that will be invoked as + // a firewall during NAT translation. The firewall always sees + // packets in their "LAN form", i.e. before translation in the + // outbound direction and after translation in the inbound + // direction. + Firewall PacketHandler + // TimeNow is a function that returns the current time. If + // nil, time.Now is used. + TimeNow func() time.Time + + mu sync.Mutex + byLAN map[natKey]*mapping // lookup by outbound packet tuple + byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only +} + +func (n *SNAT44) timeNow() time.Time { + if n.TimeNow != nil { + return n.TimeNow() + } + return time.Now() +} + +func (n *SNAT44) mappingTimeout() time.Duration { + if n.MappingTimeout == 0 { + return DefaultMappingTimeout + } + return n.MappingTimeout +} + +func (n *SNAT44) initLocked() { + if n.byLAN == nil { + n.byLAN = map[natKey]*mapping{} + n.byWAN = map[netip.AddrPort]*mapping{} + } + if n.ExternalInterface.Machine() != n.Machine { + panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) + } +} + +func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { + // NATs don't affect locally originated packets. + if n.Firewall != nil { + return n.Firewall.HandleOut(p, oif) + } + return p +} + +func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { + if iif != n.ExternalInterface { + // NAT can't apply, defer to firewall. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + now := n.timeNow() + mapping := n.byWAN[p.Dst] + if mapping == nil || now.After(mapping.deadline) { + // NAT didn't hit, defer to firewall or allow in for local + // socket handling. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + p.Dst = mapping.lanSrc + p.Trace("dnat to %v", p.Dst) + // Don't process firewall here. We mutated the packet such that + // it's no longer destined locally, so we'll get reinvoked as + // HandleForward and need to process the altered packet there. + return p +} + +func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { + switch { + case oif == n.ExternalInterface: + if p.Src.Addr() == oif.V4() { + // Packet already NATed and is just retraversing Forward, + // don't touch it again. + return p + } + + if n.Firewall != nil { + p2 := n.Firewall.HandleForward(p, iif, oif) + if p2 == nil { + // firewall dropped, done + return nil + } + if !p.Equivalent(p2) { + // firewall mutated packet? Weird, but okay. + return p2 + } + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + k := n.Type.key(p.Src, p.Dst) + now := n.timeNow() + m := n.byLAN[k] + if m == nil || now.After(m.deadline) { + pc, wanAddr := n.allocateMappedPort() + m = &mapping{ + lanSrc: p.Src, + lanDst: p.Dst, + wanSrc: wanAddr, + pc: pc, + } + n.byLAN[k] = m + n.byWAN[wanAddr] = m + } + m.deadline = now.Add(n.mappingTimeout()) + p.Src = m.wanSrc + p.Trace("snat from %v", p.Src) + return p + case iif == n.ExternalInterface: + // Packet was already un-NAT-ed, we just need to either + // firewall it or let it through. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return p + default: + // No NAT applies, invoke firewall or drop. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return nil + } +} + +func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { + // Clean up old entries before trying to allocate, to free up any + // expired ports. + n.gc() + + ip := n.ExternalInterface.V4() + pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) + if err != nil { + panic(fmt.Sprintf("ran out of NAT ports: %v", err)) + } + addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) + return pc, addr +} + +func (n *SNAT44) gc() { + now := n.timeNow() + for _, m := range n.byLAN { + if !now.After(m.deadline) { + continue + } + m.pc.Close() + delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) + delete(n.byWAN, m.wanSrc) + } +} diff --git a/tstest/tstest.go b/tstest/tstest.go index 2d0d1351e293a..118aa382749ae 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstest provides utilities for use in unit tests. -package tstest - -import ( - "context" - "os" - "strconv" - "strings" - "sync/atomic" - "testing" - "time" - - "tailscale.com/envknob" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/cibuild" -) - -// Replace replaces the value of target with val. -// The old value is restored when the test ends. -func Replace[T any](t testing.TB, target *T, val T) { - t.Helper() - if target == nil { - t.Fatalf("Replace: nil pointer") - panic("unreachable") // pacify staticcheck - } - old := *target - t.Cleanup(func() { - *target = old - }) - - *target = val - return -} - -// WaitFor retries try for up to maxWait. -// It returns nil once try returns nil the first time. -// If maxWait passes without success, it returns try's last error. -func WaitFor(maxWait time.Duration, try func() error) error { - bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) - deadline := time.Now().Add(maxWait) - var err error - for time.Now().Before(deadline) { - err = try() - if err == nil { - break - } - bo.BackOff(context.Background(), err) - } - return err -} - -var testNum atomic.Int32 - -// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to -// "n/m" and this test execution number in the process mod m is not equal to n-1. -// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 -// for the four jobs. -func Shard(t testing.TB) { - e := os.Getenv("TS_TEST_SHARD") - a, b, ok := strings.Cut(e, "/") - if !ok { - return - } - wantShard, _ := strconv.ParseInt(a, 10, 32) - shards, _ := strconv.ParseInt(b, 10, 32) - if wantShard == 0 || shards == 0 { - return - } - - shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 - if shard != int32(wantShard) { - t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) - } -} - -// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD -// environment variable isn't set. -func SkipOnUnshardedCI(t testing.TB) { - if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { - t.Skip("skipping on CI without TS_TEST_SHARD") - } -} - -var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") - -// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. -func Parallel(t *testing.T) { - if !serializeParallel() { - t.Parallel() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstest provides utilities for use in unit tests. +package tstest + +import ( + "context" + "os" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "tailscale.com/envknob" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/cibuild" +) + +// Replace replaces the value of target with val. +// The old value is restored when the test ends. +func Replace[T any](t testing.TB, target *T, val T) { + t.Helper() + if target == nil { + t.Fatalf("Replace: nil pointer") + panic("unreachable") // pacify staticcheck + } + old := *target + t.Cleanup(func() { + *target = old + }) + + *target = val + return +} + +// WaitFor retries try for up to maxWait. +// It returns nil once try returns nil the first time. +// If maxWait passes without success, it returns try's last error. +func WaitFor(maxWait time.Duration, try func() error) error { + bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) + deadline := time.Now().Add(maxWait) + var err error + for time.Now().Before(deadline) { + err = try() + if err == nil { + break + } + bo.BackOff(context.Background(), err) + } + return err +} + +var testNum atomic.Int32 + +// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to +// "n/m" and this test execution number in the process mod m is not equal to n-1. +// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 +// for the four jobs. +func Shard(t testing.TB) { + e := os.Getenv("TS_TEST_SHARD") + a, b, ok := strings.Cut(e, "/") + if !ok { + return + } + wantShard, _ := strconv.ParseInt(a, 10, 32) + shards, _ := strconv.ParseInt(b, 10, 32) + if wantShard == 0 || shards == 0 { + return + } + + shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 + if shard != int32(wantShard) { + t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) + } +} + +// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD +// environment variable isn't set. +func SkipOnUnshardedCI(t testing.TB) { + if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { + t.Skip("skipping on CI without TS_TEST_SHARD") + } +} + +var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") + +// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. +func Parallel(t *testing.T) { + if !serializeParallel() { + t.Parallel() + } +} diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go index e988d5d5624b6..20a9f7bf1faa2 100644 --- a/tstest/tstest_test.go +++ b/tstest/tstest_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import "testing" - -func TestReplace(t *testing.T) { - before := "before" - done := false - t.Run("replace", func(t *testing.T) { - Replace(t, &before, "after") - if before != "after" { - t.Errorf("before = %q; want %q", before, "after") - } - done = true - }) - if !done { - t.Fatal("subtest didn't run") - } - if before != "before" { - t.Errorf("before = %q; want %q", before, "before") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import "testing" + +func TestReplace(t *testing.T) { + before := "before" + done := false + t.Run("replace", func(t *testing.T) { + Replace(t, &before, "after") + if before != "after" { + t.Errorf("before = %q; want %q", before, "after") + } + done = true + }) + if !done { + t.Fatal("subtest didn't run") + } + if before != "before" { + t.Errorf("before = %q; want %q", before, "before") + } +} diff --git a/tstime/mono/mono.go b/tstime/mono/mono.go index 260e02b0fb0f3..94dca7d79b6bb 100644 --- a/tstime/mono/mono.go +++ b/tstime/mono/mono.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mono provides fast monotonic time. -// On most platforms, mono.Now is about 2x faster than time.Now. -// However, time.Now is really fast, and nicer to use. -// -// For almost all purposes, you should use time.Now. -// -// Package mono exists because we get the current time multiple -// times per network packet, at which point it makes a -// measurable difference. -package mono - -import ( - "fmt" - "sync/atomic" - "time" -) - -// Time is the number of nanoseconds elapsed since an unspecified reference start time. -type Time int64 - -// Now returns the current monotonic time. -func Now() Time { - // On a newly started machine, the monotonic clock might be very near zero. - // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. - // The corresponding package time expression never does, if the wall clock is correct. - // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. - const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds - return Time(int64(time.Since(baseWall)) + baseOffset) -} - -// Since returns the time elapsed since t. -func Since(t Time) time.Duration { - return time.Duration(Now() - t) -} - -// Sub returns t-n, the duration from n to t. -func (t Time) Sub(n Time) time.Duration { - return time.Duration(t - n) -} - -// Add returns t+d. -func (t Time) Add(d time.Duration) Time { - return t + Time(d) -} - -// After reports t > n, whether t is after n. -func (t Time) After(n Time) bool { - return t > n -} - -// Before reports t < n, whether t is before n. -func (t Time) Before(n Time) bool { - return t < n -} - -// IsZero reports whether t == 0. -func (t Time) IsZero() bool { - return t == 0 -} - -// StoreAtomic does an atomic store *t = new. -func (t *Time) StoreAtomic(new Time) { - atomic.StoreInt64((*int64)(t), int64(new)) -} - -// LoadAtomic does an atomic load *t. -func (t *Time) LoadAtomic() Time { - return Time(atomic.LoadInt64((*int64)(t))) -} - -// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. -var ( - baseWall time.Time - baseMono Time -) - -func init() { - baseWall = time.Now() - baseMono = Now() -} - -// String prints t, including an estimated equivalent wall clock. -// This is best-effort only, for rough debugging purposes only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) String() string { - return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) -} - -// WallTime returns an approximate wall time that corresponded to t. -func (t Time) WallTime() time.Time { - if !t.IsZero() { - return baseWall.Add(t.Sub(baseMono)).Truncate(0) - } - return time.Time{} -} - -// MarshalJSON formats t for JSON as if it were a time.Time. -// We format Time this way for backwards-compatibility. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) MarshalJSON() ([]byte, error) { - tt := t.WallTime() - return tt.MarshalJSON() -} - -// UnmarshalJSON sets t according to data. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -func (t *Time) UnmarshalJSON(data []byte) error { - var tt time.Time - err := tt.UnmarshalJSON(data) - if err != nil { - return err - } - if tt.IsZero() { - *t = 0 - return nil - } - *t = baseMono.Add(tt.Sub(baseWall)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mono provides fast monotonic time. +// On most platforms, mono.Now is about 2x faster than time.Now. +// However, time.Now is really fast, and nicer to use. +// +// For almost all purposes, you should use time.Now. +// +// Package mono exists because we get the current time multiple +// times per network packet, at which point it makes a +// measurable difference. +package mono + +import ( + "fmt" + "sync/atomic" + "time" +) + +// Time is the number of nanoseconds elapsed since an unspecified reference start time. +type Time int64 + +// Now returns the current monotonic time. +func Now() Time { + // On a newly started machine, the monotonic clock might be very near zero. + // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. + // The corresponding package time expression never does, if the wall clock is correct. + // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. + const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds + return Time(int64(time.Since(baseWall)) + baseOffset) +} + +// Since returns the time elapsed since t. +func Since(t Time) time.Duration { + return time.Duration(Now() - t) +} + +// Sub returns t-n, the duration from n to t. +func (t Time) Sub(n Time) time.Duration { + return time.Duration(t - n) +} + +// Add returns t+d. +func (t Time) Add(d time.Duration) Time { + return t + Time(d) +} + +// After reports t > n, whether t is after n. +func (t Time) After(n Time) bool { + return t > n +} + +// Before reports t < n, whether t is before n. +func (t Time) Before(n Time) bool { + return t < n +} + +// IsZero reports whether t == 0. +func (t Time) IsZero() bool { + return t == 0 +} + +// StoreAtomic does an atomic store *t = new. +func (t *Time) StoreAtomic(new Time) { + atomic.StoreInt64((*int64)(t), int64(new)) +} + +// LoadAtomic does an atomic load *t. +func (t *Time) LoadAtomic() Time { + return Time(atomic.LoadInt64((*int64)(t))) +} + +// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. +var ( + baseWall time.Time + baseMono Time +) + +func init() { + baseWall = time.Now() + baseMono = Now() +} + +// String prints t, including an estimated equivalent wall clock. +// This is best-effort only, for rough debugging purposes only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) String() string { + return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) +} + +// WallTime returns an approximate wall time that corresponded to t. +func (t Time) WallTime() time.Time { + if !t.IsZero() { + return baseWall.Add(t.Sub(baseMono)).Truncate(0) + } + return time.Time{} +} + +// MarshalJSON formats t for JSON as if it were a time.Time. +// We format Time this way for backwards-compatibility. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) MarshalJSON() ([]byte, error) { + tt := t.WallTime() + return tt.MarshalJSON() +} + +// UnmarshalJSON sets t according to data. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +func (t *Time) UnmarshalJSON(data []byte) error { + var tt time.Time + err := tt.UnmarshalJSON(data) + if err != nil { + return err + } + if tt.IsZero() { + *t = 0 + return nil + } + *t = baseMono.Add(tt.Sub(baseWall)) + return nil +} diff --git a/tstime/rate/rate.go b/tstime/rate/rate.go index f0473862a2890..19dc26e6ae8a7 100644 --- a/tstime/rate/rate.go +++ b/tstime/rate/rate.go @@ -1,90 +1,90 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This is a modified, simplified version of code from golang.org/x/time/rate. - -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package rate provides a rate limiter. -package rate - -import ( - "sync" - "time" - - "tailscale.com/tstime/mono" -) - -// Limit defines the maximum frequency of some events. -// Limit is represented as number of events per second. -// A zero Limit is invalid. -type Limit float64 - -// Every converts a minimum time interval between events to a Limit. -func Every(interval time.Duration) Limit { - if interval <= 0 { - panic("invalid interval") - } - return 1 / Limit(interval.Seconds()) -} - -// A Limiter controls how frequently events are allowed to happen. -// It implements a [token bucket] of a particular size b, -// initially full and refilled at rate r tokens per second. -// Informally, in any large enough time interval, -// the Limiter limits the rate to r tokens per second, -// with a maximum burst size of b events. -// Use NewLimiter to create non-zero Limiters. -// -// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket -type Limiter struct { - limit Limit - burst float64 - mu sync.Mutex // protects following fields - tokens float64 // number of tokens currently in bucket - last mono.Time // the last time the limiter's tokens field was updated -} - -// NewLimiter returns a new Limiter that allows events up to rate r and permits -// bursts of at most b tokens. -func NewLimiter(r Limit, b int) *Limiter { - if b < 1 { - panic("bad burst, must be at least 1") - } - return &Limiter{limit: r, burst: float64(b)} -} - -// Allow reports whether an event may happen now. -func (lim *Limiter) Allow() bool { - return lim.allow(mono.Now()) -} - -func (lim *Limiter) allow(now mono.Time) bool { - lim.mu.Lock() - defer lim.mu.Unlock() - - // If time has moved backwards, look around awkwardly and pretend nothing happened. - if now.Before(lim.last) { - lim.last = now - } - - // Calculate the new number of tokens available due to the passage of time. - elapsed := now.Sub(lim.last) - tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() - if tokens > lim.burst { - tokens = lim.burst - } - - // Consume a token. - tokens-- - - // Update state. - ok := tokens >= 0 - if ok { - lim.last = now - lim.tokens = tokens - } - return ok -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This is a modified, simplified version of code from golang.org/x/time/rate. + +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "sync" + "time" + + "tailscale.com/tstime/mono" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit is invalid. +type Limit float64 + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + panic("invalid interval") + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a [token bucket] of a particular size b, +// initially full and refilled at rate r tokens per second. +// Informally, in any large enough time interval, +// the Limiter limits the rate to r tokens per second, +// with a maximum burst size of b events. +// Use NewLimiter to create non-zero Limiters. +// +// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket +type Limiter struct { + limit Limit + burst float64 + mu sync.Mutex // protects following fields + tokens float64 // number of tokens currently in bucket + last mono.Time // the last time the limiter's tokens field was updated +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + if b < 1 { + panic("bad burst, must be at least 1") + } + return &Limiter{limit: r, burst: float64(b)} +} + +// Allow reports whether an event may happen now. +func (lim *Limiter) Allow() bool { + return lim.allow(mono.Now()) +} + +func (lim *Limiter) allow(now mono.Time) bool { + lim.mu.Lock() + defer lim.mu.Unlock() + + // If time has moved backwards, look around awkwardly and pretend nothing happened. + if now.Before(lim.last) { + lim.last = now + } + + // Calculate the new number of tokens available due to the passage of time. + elapsed := now.Sub(lim.last) + tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() + if tokens > lim.burst { + tokens = lim.burst + } + + // Consume a token. + tokens-- + + // Update state. + ok := tokens >= 0 + if ok { + lim.last = now + lim.tokens = tokens + } + return ok +} diff --git a/tstime/tstime.go b/tstime/tstime.go index 1c006355f8726..22616bca7a47a 100644 --- a/tstime/tstime.go +++ b/tstime/tstime.go @@ -1,185 +1,185 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstime defines Tailscale-specific time utilities. -package tstime - -import ( - "context" - "strconv" - "strings" - "time" -) - -// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). -func Parse3339(s string) (time.Time, error) { - return time.Parse(time.RFC3339, s) -} - -// Parse3339B is Parse3339 but for byte slices. -func Parse3339B(b []byte) (time.Time, error) { - var t time.Time - if err := t.UnmarshalText(b); err != nil { - return Parse3339(string(b)) // reproduce same error message - } - return t, nil -} - -// ParseDuration is more expressive than [time.ParseDuration], -// also accepting 'd' (days) and 'w' (weeks) literals. -func ParseDuration(s string) (time.Duration, error) { - for { - end := strings.IndexAny(s, "dw") - if end < 0 { - break - } - start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) - n, err := strconv.Atoi(s[start:end]) - if err != nil { - return 0, err - } - hours := 24 - if s[end] == 'w' { - hours *= 7 - } - s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" - } - return time.ParseDuration(s) -} - -// Sleep is like [time.Sleep] but returns early upon context cancelation. -// It reports whether the full sleep duration was achieved. -func Sleep(ctx context.Context, d time.Duration) bool { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return false - case <-timer.C: - return true - } -} - -// DefaultClock is a wrapper around a Clock. -// It uses StdClock by default if Clock is nil. -type DefaultClock struct{ Clock } - -// TODO: We should make the methods of DefaultClock inlineable -// so that we can optimize for the common case where c.Clock == nil. - -func (c DefaultClock) Now() time.Time { - if c.Clock == nil { - return time.Now() - } - return c.Clock.Now() -} -func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTimer(d) - return t, t.C - } - return c.Clock.NewTimer(d) -} -func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTicker(d) - return t, t.C - } - return c.Clock.NewTicker(d) -} -func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { - if c.Clock == nil { - return time.AfterFunc(d, f) - } - return c.Clock.AfterFunc(d, f) -} -func (c DefaultClock) Since(t time.Time) time.Duration { - if c.Clock == nil { - return time.Since(t) - } - return c.Clock.Since(t) -} - -// Clock offers a subset of the functionality from the std/time package. -// Normally, applications will use the StdClock implementation that calls the -// appropriate std/time exported funcs. The advantage of using Clock is that -// tests can substitute a different implementation, allowing the test to control -// time precisely, something required for certain types of tests to be possible -// at all, speeds up execution by not needing to sleep, and can dramatically -// reduce the risk of flakes due to tests executing too slowly or quickly. -type Clock interface { - // Now returns the current time, as in time.Now. - Now() time.Time - // NewTimer returns a timer whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTimer as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTimer(d time.Duration) (TimerController, <-chan time.Time) - // NewTicker returns a ticker whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTicker as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTicker(d time.Duration) (TickerController, <-chan time.Time) - // AfterFunc returns a ticker whose notion of the current time is controlled - // by this Clock. When the ticker expires, it will call the provided func. - // It follows the semantics of time.AfterFunc. - AfterFunc(d time.Duration, f func()) TimerController - // Since returns the time elapsed since t. - // It follows the semantics of time.Since. - Since(t time.Time) time.Duration -} - -// TickerController offers the receivers of a time.Ticker to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TickerController interface { - // Reset follows the same semantics as with time.Ticker.Reset. - Reset(d time.Duration) - // Stop follows the same semantics as with time.Ticker.Stop. - Stop() -} - -// TimerController offers the receivers of a time.Timer to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TimerController interface { - // Reset follows the same semantics as with time.Timer.Reset. - Reset(d time.Duration) bool - // Stop follows the same semantics as with time.Timer.Stop. - Stop() bool -} - -// StdClock is a simple implementation of Clock using the relevant funcs in the -// std/time package. -type StdClock struct{} - -// Now calls time.Now. -func (StdClock) Now() time.Time { - return time.Now() -} - -// NewTimer calls time.NewTimer. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - t := time.NewTimer(d) - return t, t.C -} - -// NewTicker calls time.NewTicker. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - t := time.NewTicker(d) - return t, t.C -} - -// AfterFunc calls time.AfterFunc. -func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { - return time.AfterFunc(d, f) -} - -// Since calls time.Since. -func (StdClock) Since(t time.Time) time.Duration { - return time.Since(t) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstime defines Tailscale-specific time utilities. +package tstime + +import ( + "context" + "strconv" + "strings" + "time" +) + +// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). +func Parse3339(s string) (time.Time, error) { + return time.Parse(time.RFC3339, s) +} + +// Parse3339B is Parse3339 but for byte slices. +func Parse3339B(b []byte) (time.Time, error) { + var t time.Time + if err := t.UnmarshalText(b); err != nil { + return Parse3339(string(b)) // reproduce same error message + } + return t, nil +} + +// ParseDuration is more expressive than [time.ParseDuration], +// also accepting 'd' (days) and 'w' (weeks) literals. +func ParseDuration(s string) (time.Duration, error) { + for { + end := strings.IndexAny(s, "dw") + if end < 0 { + break + } + start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) + n, err := strconv.Atoi(s[start:end]) + if err != nil { + return 0, err + } + hours := 24 + if s[end] == 'w' { + hours *= 7 + } + s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" + } + return time.ParseDuration(s) +} + +// Sleep is like [time.Sleep] but returns early upon context cancelation. +// It reports whether the full sleep duration was achieved. +func Sleep(ctx context.Context, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +// DefaultClock is a wrapper around a Clock. +// It uses StdClock by default if Clock is nil. +type DefaultClock struct{ Clock } + +// TODO: We should make the methods of DefaultClock inlineable +// so that we can optimize for the common case where c.Clock == nil. + +func (c DefaultClock) Now() time.Time { + if c.Clock == nil { + return time.Now() + } + return c.Clock.Now() +} +func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTimer(d) + return t, t.C + } + return c.Clock.NewTimer(d) +} +func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTicker(d) + return t, t.C + } + return c.Clock.NewTicker(d) +} +func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { + if c.Clock == nil { + return time.AfterFunc(d, f) + } + return c.Clock.AfterFunc(d, f) +} +func (c DefaultClock) Since(t time.Time) time.Duration { + if c.Clock == nil { + return time.Since(t) + } + return c.Clock.Since(t) +} + +// Clock offers a subset of the functionality from the std/time package. +// Normally, applications will use the StdClock implementation that calls the +// appropriate std/time exported funcs. The advantage of using Clock is that +// tests can substitute a different implementation, allowing the test to control +// time precisely, something required for certain types of tests to be possible +// at all, speeds up execution by not needing to sleep, and can dramatically +// reduce the risk of flakes due to tests executing too slowly or quickly. +type Clock interface { + // Now returns the current time, as in time.Now. + Now() time.Time + // NewTimer returns a timer whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTimer as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTimer(d time.Duration) (TimerController, <-chan time.Time) + // NewTicker returns a ticker whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTicker as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTicker(d time.Duration) (TickerController, <-chan time.Time) + // AfterFunc returns a ticker whose notion of the current time is controlled + // by this Clock. When the ticker expires, it will call the provided func. + // It follows the semantics of time.AfterFunc. + AfterFunc(d time.Duration, f func()) TimerController + // Since returns the time elapsed since t. + // It follows the semantics of time.Since. + Since(t time.Time) time.Duration +} + +// TickerController offers the receivers of a time.Ticker to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TickerController interface { + // Reset follows the same semantics as with time.Ticker.Reset. + Reset(d time.Duration) + // Stop follows the same semantics as with time.Ticker.Stop. + Stop() +} + +// TimerController offers the receivers of a time.Timer to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TimerController interface { + // Reset follows the same semantics as with time.Timer.Reset. + Reset(d time.Duration) bool + // Stop follows the same semantics as with time.Timer.Stop. + Stop() bool +} + +// StdClock is a simple implementation of Clock using the relevant funcs in the +// std/time package. +type StdClock struct{} + +// Now calls time.Now. +func (StdClock) Now() time.Time { + return time.Now() +} + +// NewTimer calls time.NewTimer. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + t := time.NewTimer(d) + return t, t.C +} + +// NewTicker calls time.NewTicker. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + t := time.NewTicker(d) + return t, t.C +} + +// AfterFunc calls time.AfterFunc. +func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { + return time.AfterFunc(d, f) +} + +// Since calls time.Since. +func (StdClock) Since(t time.Time) time.Duration { + return time.Since(t) +} diff --git a/tstime/tstime_test.go b/tstime/tstime_test.go index 3ffeaf0fff1b8..1169408b69b29 100644 --- a/tstime/tstime_test.go +++ b/tstime/tstime_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstime - -import ( - "testing" - "time" -) - -func TestParseDuration(t *testing.T) { - tests := []struct { - in string - want time.Duration - }{ - {"1h", time.Hour}, - {"1d", 24 * time.Hour}, - {"365d", 365 * 24 * time.Hour}, - {"12345d", 12345 * 24 * time.Hour}, - {"67890d", 67890 * 24 * time.Hour}, - {"100d", 100 * 24 * time.Hour}, - {"1d1d", 48 * time.Hour}, - {"1h1d", 25 * time.Hour}, - {"1d1h", 25 * time.Hour}, - {"1w", 7 * 24 * time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1y", 0}, - {"", 0}, - } - for _, tt := range tests { - if got, _ := ParseDuration(tt.in); got != tt.want { - t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstime + +import ( + "testing" + "time" +) + +func TestParseDuration(t *testing.T) { + tests := []struct { + in string + want time.Duration + }{ + {"1h", time.Hour}, + {"1d", 24 * time.Hour}, + {"365d", 365 * 24 * time.Hour}, + {"12345d", 12345 * 24 * time.Hour}, + {"67890d", 67890 * 24 * time.Hour}, + {"100d", 100 * 24 * time.Hour}, + {"1d1d", 48 * time.Hour}, + {"1h1d", 25 * time.Hour}, + {"1d1h", 25 * time.Hour}, + {"1w", 7 * 24 * time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1y", 0}, + {"", 0}, + } + for _, tt := range tests { + if got, _ := ParseDuration(tt.in); got != tt.want { + t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) + } + } +} diff --git a/tsweb/debug_test.go b/tsweb/debug_test.go index 2a68ab6fb27b9..504ec06ba20ab 100644 --- a/tsweb/debug_test.go +++ b/tsweb/debug_test.go @@ -1,208 +1,208 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsweb - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "runtime" - "strings" - "testing" -) - -func TestDebugger(t *testing.T) { - mux := http.NewServeMux() - - dbg1 := Debugger(mux) - if dbg1 == nil { - t.Fatal("didn't get a debugger from mux") - } - - dbg2 := Debugger(mux) - if dbg2 != dbg1 { - t.Fatal("Debugger returned different debuggers for the same mux") - } - - t.Run("cpu_pprof", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping second long test") - } - switch runtime.GOOS { - case "linux", "darwin": - default: - t.Skipf("skipping test on %v", runtime.GOOS) - } - req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) - req.RemoteAddr = "100.101.102.103:1234" - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - res := rec.Result() - if res.StatusCode != 200 { - t.Errorf("unexpected %v", res.Status) - } - }) -} - -func get(m http.Handler, path, srcIP string) (int, string) { - req := httptest.NewRequest("GET", path, nil) - req.RemoteAddr = srcIP + ":1234" - rec := httptest.NewRecorder() - m.ServeHTTP(rec, req) - return rec.Result().StatusCode, rec.Body.String() -} - -const ( - tsIP = "100.100.100.100" - pubIP = "8.8.8.8" -) - -func TestDebuggerKV(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.KV("Donuts", 42) - dbg.KV("Secret code", "hunter2") - val := "red" - dbg.KVFunc("Condition", func() any { return val }) - - code, _ := get(mux, "/debug/", pubIP) - if code != 403 { - t.Fatalf("debug access wasn't denied, got %v", code) - } - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - val = "green" - code, body = get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Condition", "green"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerURL(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.URL("https://www.tailscale.com", "Homepage") - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"https://www.tailscale.com", "Homepage"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerSection(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Section(func(w io.Writer, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - }) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - want := `Test output 100.100.100.100:1234` - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func TestDebuggerHandle(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - })) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"/debug/check", "Consistency check"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - code, _ = get(mux, "/debug/check", pubIP) - if code != 403 { - t.Fatal("/debug/check should be protected, but isn't") - } - - code, body = get(mux, "/debug/check", tsIP) - if code != 200 { - t.Fatal("/debug/check denied debug access") - } - want := "Test output " + tsIP - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func ExampleDebugHandler_Handle() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Registers /debug/flushcache with the given handler, and adds a - // link to /debug/ with the description "Flush caches". - dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) -} - -func ExampleDebugHandler_KV() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds two list items to /debug/, showing that the condition is - // red and there are 42 donuts. - dbg.KV("Condition", "red") - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_KVFunc() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds an count of page renders to /debug/. Note this example - // isn't concurrency-safe. - views := 0 - dbg.KVFunc("Debug pageviews", func() any { - views = views + 1 - return views - }) - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_URL() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Links to the Tailscale website from /debug/. - dbg.URL("https://www.tailscale.com", "Homepage") -} - -func ExampleDebugHandler_Section() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds a section to /debug/ that dumps the HTTP request of the - // visitor. - dbg.Section(func(w io.Writer, r *http.Request) { - io.WriteString(w, "

Dump of your HTTP request

") - fmt.Fprintf(w, "%#v", r) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsweb + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "testing" +) + +func TestDebugger(t *testing.T) { + mux := http.NewServeMux() + + dbg1 := Debugger(mux) + if dbg1 == nil { + t.Fatal("didn't get a debugger from mux") + } + + dbg2 := Debugger(mux) + if dbg2 != dbg1 { + t.Fatal("Debugger returned different debuggers for the same mux") + } + + t.Run("cpu_pprof", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping second long test") + } + switch runtime.GOOS { + case "linux", "darwin": + default: + t.Skipf("skipping test on %v", runtime.GOOS) + } + req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) + req.RemoteAddr = "100.101.102.103:1234" + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != 200 { + t.Errorf("unexpected %v", res.Status) + } + }) +} + +func get(m http.Handler, path, srcIP string) (int, string) { + req := httptest.NewRequest("GET", path, nil) + req.RemoteAddr = srcIP + ":1234" + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + return rec.Result().StatusCode, rec.Body.String() +} + +const ( + tsIP = "100.100.100.100" + pubIP = "8.8.8.8" +) + +func TestDebuggerKV(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.KV("Donuts", 42) + dbg.KV("Secret code", "hunter2") + val := "red" + dbg.KVFunc("Condition", func() any { return val }) + + code, _ := get(mux, "/debug/", pubIP) + if code != 403 { + t.Fatalf("debug access wasn't denied, got %v", code) + } + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + val = "green" + code, body = get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Condition", "green"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerURL(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.URL("https://www.tailscale.com", "Homepage") + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"https://www.tailscale.com", "Homepage"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerSection(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Section(func(w io.Writer, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + }) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + want := `Test output 100.100.100.100:1234` + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func TestDebuggerHandle(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + })) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"/debug/check", "Consistency check"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + code, _ = get(mux, "/debug/check", pubIP) + if code != 403 { + t.Fatal("/debug/check should be protected, but isn't") + } + + code, body = get(mux, "/debug/check", tsIP) + if code != 200 { + t.Fatal("/debug/check denied debug access") + } + want := "Test output " + tsIP + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func ExampleDebugHandler_Handle() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Registers /debug/flushcache with the given handler, and adds a + // link to /debug/ with the description "Flush caches". + dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) +} + +func ExampleDebugHandler_KV() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds two list items to /debug/, showing that the condition is + // red and there are 42 donuts. + dbg.KV("Condition", "red") + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_KVFunc() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds an count of page renders to /debug/. Note this example + // isn't concurrency-safe. + views := 0 + dbg.KVFunc("Debug pageviews", func() any { + views = views + 1 + return views + }) + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_URL() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Links to the Tailscale website from /debug/. + dbg.URL("https://www.tailscale.com", "Homepage") +} + +func ExampleDebugHandler_Section() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds a section to /debug/ that dumps the HTTP request of the + // visitor. + dbg.Section(func(w io.Writer, r *http.Request) { + io.WriteString(w, "

Dump of your HTTP request

") + fmt.Fprintf(w, "%#v", r) + }) +} diff --git a/tsweb/promvarz/promvarz_test.go b/tsweb/promvarz/promvarz_test.go index a3f4e66f11a42..7f9b3396ed3c9 100644 --- a/tsweb/promvarz/promvarz_test.go +++ b/tsweb/promvarz/promvarz_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package promvarz - -import ( - "expvar" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/testutil" -) - -var ( - testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") - testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) -) - -func TestHandler(t *testing.T) { - testVar1.Set(42) - testVar2.Set(4242) - - svr := httptest.NewServer(http.HandlerFunc(Handler)) - defer svr.Close() - - want := ` - # TYPE promvarz_test_expvar gauge - promvarz_test_expvar 42 - # TYPE promvarz_test_native gauge - promvarz_test_native 4242 - ` - if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package promvarz + +import ( + "expvar" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +var ( + testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") + testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) +) + +func TestHandler(t *testing.T) { + testVar1.Set(42) + testVar2.Set(4242) + + svr := httptest.NewServer(http.HandlerFunc(Handler)) + defer svr.Close() + + want := ` + # TYPE promvarz_test_expvar gauge + promvarz_test_expvar 42 + # TYPE promvarz_test_native gauge + promvarz_test_native 4242 + ` + if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { + t.Error(err) + } +} diff --git a/types/appctype/appconnector_test.go b/types/appctype/appconnector_test.go index 390d1776a3280..8aef135b4a876 100644 --- a/types/appctype/appconnector_test.go +++ b/types/appctype/appconnector_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package appctype - -import ( - "encoding/json" - "net/netip" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tailcfg" - "tailscale.com/util/must" -) - -var golden = `{ - "dnat": { - "opaqueid1": { - "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], - "to": ["example.org"], - "ip": ["*"] - } - }, - "sniProxy": { - "opaqueid2": { - "addrs": ["::"], - "ip": ["tcp:443"], - "allowedDomains": ["*"] - } - }, - "advertiseRoutes": true -}` - -func TestGolden(t *testing.T) { - wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }} - - wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { - Addrs: []netip.Addr{netip.MustParseAddr("::")}, - IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, - AllowedDomains: []string{"*"}, - }} - - var config AppConnectorConfig - if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { - t.Fatalf("failed to decode golden config: %v", err) - } - - if !config.AdvertiseRoutes { - t.Fatalf("expected AdvertiseRoutes to be true, got false") - } - - assertEqual(t, "DNAT", config.DNAT, wantDNAT) - assertEqual(t, "SNI", config.SNIProxy, wantSNI) -} - -func TestRoundTrip(t *testing.T) { - var config AppConnectorConfig - must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) - b := must.Get(json.Marshal(config)) - var config2 AppConnectorConfig - must.Do(json.Unmarshal(b, &config2)) - assertEqual(t, "DNAT", config.DNAT, config2.DNAT) -} - -func assertEqual(t *testing.T, name string, a, b any) { - var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { - return a.Compare(b) == 0 - }) - t.Helper() - if diff := cmp.Diff(a, b, addrComparer); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appctype + +import ( + "encoding/json" + "net/netip" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +var golden = `{ + "dnat": { + "opaqueid1": { + "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], + "to": ["example.org"], + "ip": ["*"] + } + }, + "sniProxy": { + "opaqueid2": { + "addrs": ["::"], + "ip": ["tcp:443"], + "allowedDomains": ["*"] + } + }, + "advertiseRoutes": true +}` + +func TestGolden(t *testing.T) { + wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }} + + wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { + Addrs: []netip.Addr{netip.MustParseAddr("::")}, + IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, + AllowedDomains: []string{"*"}, + }} + + var config AppConnectorConfig + if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { + t.Fatalf("failed to decode golden config: %v", err) + } + + if !config.AdvertiseRoutes { + t.Fatalf("expected AdvertiseRoutes to be true, got false") + } + + assertEqual(t, "DNAT", config.DNAT, wantDNAT) + assertEqual(t, "SNI", config.SNIProxy, wantSNI) +} + +func TestRoundTrip(t *testing.T) { + var config AppConnectorConfig + must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) + b := must.Get(json.Marshal(config)) + var config2 AppConnectorConfig + must.Do(json.Unmarshal(b, &config2)) + assertEqual(t, "DNAT", config.DNAT, config2.DNAT) +} + +func assertEqual(t *testing.T, name string, a, b any) { + var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { + return a.Compare(b) == 0 + }) + t.Helper() + if diff := cmp.Diff(a, b, addrComparer); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } +} diff --git a/types/dnstype/dnstype.go b/types/dnstype/dnstype.go index b7f5b9d02fe47..6cc91c999e8d4 100644 --- a/types/dnstype/dnstype.go +++ b/types/dnstype/dnstype.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dnstype defines types for working with DNS. -package dnstype - -//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true - -import ( - "net/netip" - "slices" -) - -// Resolver is the configuration for one DNS resolver. -type Resolver struct { - // Addr is the address of the DNS resolver, one of: - // - A plain IP address for a "classic" UDP+TCP DNS resolver. - // This is the common format as sent by the control plane. - // - An IP:port, for tests. - // - "https://resolver.com/path" for DNS over HTTPS; currently - // as of 2022-09-08 only used for certain well-known resolvers - // (see the publicdns package) for which the IP addresses to dial DoH are - // known ahead of time, so bootstrap DNS resolution is not required. - // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This - // is implemented in the PeerAPI for exit nodes and app connectors. - // - [TODO] "tls://resolver.com" for DNS over TCP+TLS - Addr string `json:",omitempty"` - - // BootstrapResolution is an optional suggested resolution for the - // DoT/DoH resolver, if the resolver URL does not reference an IP - // address directly. - // BootstrapResolution may be empty, in which case clients should - // look up the DoT/DoH server using their local "classic" DNS - // resolver. - // - // As of 2022-09-08, BootstrapResolution is not yet used. - BootstrapResolution []netip.Addr `json:",omitempty"` -} - -// IPPort returns r.Addr as an IP address and port if either -// r.Addr is an IP address (the common case) or if r.Addr -// is an IP:port (as done in tests). -func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { - if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { - // Fast path to avoid ParseIP error allocation for obviously not IP - // cases. - return - } - if ip, err := netip.ParseAddr(r.Addr); err == nil { - return netip.AddrPortFrom(ip, 53), true - } - if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { - return ipp, true - } - return -} - -// Equal reports whether r and other are equal. -func (r *Resolver) Equal(other *Resolver) bool { - if r == nil || other == nil { - return r == other - } - if r == other { - return true - } - - return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dnstype defines types for working with DNS. +package dnstype + +//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true + +import ( + "net/netip" + "slices" +) + +// Resolver is the configuration for one DNS resolver. +type Resolver struct { + // Addr is the address of the DNS resolver, one of: + // - A plain IP address for a "classic" UDP+TCP DNS resolver. + // This is the common format as sent by the control plane. + // - An IP:port, for tests. + // - "https://resolver.com/path" for DNS over HTTPS; currently + // as of 2022-09-08 only used for certain well-known resolvers + // (see the publicdns package) for which the IP addresses to dial DoH are + // known ahead of time, so bootstrap DNS resolution is not required. + // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This + // is implemented in the PeerAPI for exit nodes and app connectors. + // - [TODO] "tls://resolver.com" for DNS over TCP+TLS + Addr string `json:",omitempty"` + + // BootstrapResolution is an optional suggested resolution for the + // DoT/DoH resolver, if the resolver URL does not reference an IP + // address directly. + // BootstrapResolution may be empty, in which case clients should + // look up the DoT/DoH server using their local "classic" DNS + // resolver. + // + // As of 2022-09-08, BootstrapResolution is not yet used. + BootstrapResolution []netip.Addr `json:",omitempty"` +} + +// IPPort returns r.Addr as an IP address and port if either +// r.Addr is an IP address (the common case) or if r.Addr +// is an IP:port (as done in tests). +func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { + if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { + // Fast path to avoid ParseIP error allocation for obviously not IP + // cases. + return + } + if ip, err := netip.ParseAddr(r.Addr); err == nil { + return netip.AddrPortFrom(ip, 53), true + } + if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { + return ipp, true + } + return +} + +// Equal reports whether r and other are equal. +func (r *Resolver) Equal(other *Resolver) bool { + if r == nil || other == nil { + return r == other + } + if r == other { + return true + } + + return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) +} diff --git a/types/empty/message.go b/types/empty/message.go index dc8eb4cc2dc37..5ada7f40202af 100644 --- a/types/empty/message.go +++ b/types/empty/message.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package empty defines an empty struct type. -package empty - -// Message is an empty message. Its purpose is to be used as pointer -// type where nil and non-nil distinguish whether it's set. This is -// used instead of a bool when we want to marshal it as a JSON empty -// object (or null) for the future ability to add other fields, at -// which point callers would define a new struct and not use -// empty.Message. -type Message struct{} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package empty defines an empty struct type. +package empty + +// Message is an empty message. Its purpose is to be used as pointer +// type where nil and non-nil distinguish whether it's set. This is +// used instead of a bool when we want to marshal it as a JSON empty +// object (or null) for the future ability to add other fields, at +// which point callers would define a new struct and not use +// empty.Message. +type Message struct{} diff --git a/types/flagtype/flagtype.go b/types/flagtype/flagtype.go index be160dee82a21..c76b16353a280 100644 --- a/types/flagtype/flagtype.go +++ b/types/flagtype/flagtype.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flagtype defines flag.Value types. -package flagtype - -import ( - "errors" - "flag" - "fmt" - "math" - "strconv" - "strings" -) - -type portValue struct{ n *uint16 } - -func PortValue(dst *uint16, defaultPort uint16) flag.Value { - *dst = defaultPort - return portValue{dst} -} - -func (p portValue) String() string { - if p.n == nil { - return "" - } - return fmt.Sprint(*p.n) -} -func (p portValue) Set(v string) error { - if v == "" { - return errors.New("can't be the empty string") - } - if strings.Contains(v, ":") { - return errors.New("expecting just a port number, without a colon") - } - n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message - if err != nil { - return fmt.Errorf("not a valid number") - } - if n > math.MaxUint16 { - return errors.New("out of range for port number") - } - *p.n = uint16(n) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flagtype defines flag.Value types. +package flagtype + +import ( + "errors" + "flag" + "fmt" + "math" + "strconv" + "strings" +) + +type portValue struct{ n *uint16 } + +func PortValue(dst *uint16, defaultPort uint16) flag.Value { + *dst = defaultPort + return portValue{dst} +} + +func (p portValue) String() string { + if p.n == nil { + return "" + } + return fmt.Sprint(*p.n) +} +func (p portValue) Set(v string) error { + if v == "" { + return errors.New("can't be the empty string") + } + if strings.Contains(v, ":") { + return errors.New("expecting just a port number, without a colon") + } + n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message + if err != nil { + return fmt.Errorf("not a valid number") + } + if n > math.MaxUint16 { + return errors.New("out of range for port number") + } + *p.n = uint16(n) + return nil +} diff --git a/types/ipproto/ipproto.go b/types/ipproto/ipproto.go index b5333eb56ace0..97fc4f3dd89e8 100644 --- a/types/ipproto/ipproto.go +++ b/types/ipproto/ipproto.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ipproto contains IP Protocol constants. -package ipproto - -import ( - "fmt" - "strconv" - - "tailscale.com/util/nocasemaps" - "tailscale.com/util/vizerror" -) - -// Version describes the IP address version. -type Version uint8 - -// Valid Version values. -const ( - Version4 = 4 - Version6 = 6 -) - -func (p Version) String() string { - switch p { - case Version4: - return "IPv4" - case Version6: - return "IPv6" - default: - return fmt.Sprintf("Version-%d", int(p)) - } -} - -// Proto is an IP subprotocol as defined by the IANA protocol -// numbers list -// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), -// or the special values Unknown or Fragment. -type Proto uint8 - -const ( - // Unknown represents an unknown or unsupported protocol; it's - // deliberately the zero value. Strictly speaking the zero - // value is IPv6 hop-by-hop extensions, but we don't support - // those, so this is still technically correct. - Unknown Proto = 0x00 - - // Values from the IANA registry. - ICMPv4 Proto = 0x01 - IGMP Proto = 0x02 - ICMPv6 Proto = 0x3a - TCP Proto = 0x06 - UDP Proto = 0x11 - DCCP Proto = 0x21 - GRE Proto = 0x2f - SCTP Proto = 0x84 - - // TSMP is the Tailscale Message Protocol (our ICMP-ish - // thing), an IP protocol used only between Tailscale nodes - // (still encrypted by WireGuard) that communicates why things - // failed, etc. - // - // Proto number 99 is reserved for "any private encryption - // scheme". We never accept these from the host OS stack nor - // send them to the host network stack. It's only used between - // nodes. - TSMP Proto = 99 - - // Fragment represents any non-first IP fragment, for which we - // don't have the sub-protocol header (and therefore can't - // figure out what the sub-protocol is). - // - // 0xFF is reserved in the IANA registry, so we steal it for - // internal use. - Fragment Proto = 0xFF -) - -// Deprecated: use MarshalText instead. -func (p Proto) String() string { - switch p { - case Unknown: - return "Unknown" - case Fragment: - return "Frag" - case ICMPv4: - return "ICMPv4" - case IGMP: - return "IGMP" - case ICMPv6: - return "ICMPv6" - case UDP: - return "UDP" - case TCP: - return "TCP" - case SCTP: - return "SCTP" - case TSMP: - return "TSMP" - case GRE: - return "GRE" - case DCCP: - return "DCCP" - default: - return fmt.Sprintf("IPProto-%d", int(p)) - } -} - -// Prefer names from -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// unless otherwise noted. -var ( - // preferredNames is the set of protocol names that re produced by - // MarshalText, and are the preferred representation. - preferredNames = map[Proto]string{ - 51: "ah", - DCCP: "dccp", - 8: "egp", - 50: "esp", - 47: "gre", - ICMPv4: "icmp", - IGMP: "igmp", - 9: "igp", - 4: "ipv4", - ICMPv6: "ipv6-icmp", - SCTP: "sctp", - TCP: "tcp", - UDP: "udp", - } - - // acceptedNames is the set of protocol names that are accepted by - // UnmarshalText. - acceptedNames = map[string]Proto{ - "ah": 51, - "dccp": DCCP, - "egp": 8, - "esp": 50, - "gre": 47, - "icmp": ICMPv4, - "icmpv4": ICMPv4, - "icmpv6": ICMPv6, - "igmp": IGMP, - "igp": 9, - "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" - "ipv4": 4, - "ipv6-icmp": ICMPv6, - "sctp": SCTP, - "tcp": TCP, - "tsmp": TSMP, - "udp": UDP, - } -) - -// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p -// is set to 0. If an error occurs, p is unchanged. -func (p *Proto) UnmarshalText(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - - if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { - *p = Proto(u) - return nil - } - - if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { - *p = newP - return nil - } - - return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) -} - -// MarshalText implements encoding.TextMarshaler. -func (p Proto) MarshalText() ([]byte, error) { - if s, ok := preferredNames[p]; ok { - return []byte(s), nil - } - return []byte(strconv.Itoa(int(p))), nil -} - -// MarshalJSON implements json.Marshaler. -func (p Proto) MarshalJSON() ([]byte, error) { - return []byte(strconv.Itoa(int(p))), nil -} - -// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to -// 0. If an error occurs, p is unchanged. The input must be a JSON number or an -// accepted string name. -func (p *Proto) UnmarshalJSON(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - if b[0] == '"' { - b = b[1 : len(b)-1] - } - return p.UnmarshalText(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipproto contains IP Protocol constants. +package ipproto + +import ( + "fmt" + "strconv" + + "tailscale.com/util/nocasemaps" + "tailscale.com/util/vizerror" +) + +// Version describes the IP address version. +type Version uint8 + +// Valid Version values. +const ( + Version4 = 4 + Version6 = 6 +) + +func (p Version) String() string { + switch p { + case Version4: + return "IPv4" + case Version6: + return "IPv6" + default: + return fmt.Sprintf("Version-%d", int(p)) + } +} + +// Proto is an IP subprotocol as defined by the IANA protocol +// numbers list +// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), +// or the special values Unknown or Fragment. +type Proto uint8 + +const ( + // Unknown represents an unknown or unsupported protocol; it's + // deliberately the zero value. Strictly speaking the zero + // value is IPv6 hop-by-hop extensions, but we don't support + // those, so this is still technically correct. + Unknown Proto = 0x00 + + // Values from the IANA registry. + ICMPv4 Proto = 0x01 + IGMP Proto = 0x02 + ICMPv6 Proto = 0x3a + TCP Proto = 0x06 + UDP Proto = 0x11 + DCCP Proto = 0x21 + GRE Proto = 0x2f + SCTP Proto = 0x84 + + // TSMP is the Tailscale Message Protocol (our ICMP-ish + // thing), an IP protocol used only between Tailscale nodes + // (still encrypted by WireGuard) that communicates why things + // failed, etc. + // + // Proto number 99 is reserved for "any private encryption + // scheme". We never accept these from the host OS stack nor + // send them to the host network stack. It's only used between + // nodes. + TSMP Proto = 99 + + // Fragment represents any non-first IP fragment, for which we + // don't have the sub-protocol header (and therefore can't + // figure out what the sub-protocol is). + // + // 0xFF is reserved in the IANA registry, so we steal it for + // internal use. + Fragment Proto = 0xFF +) + +// Deprecated: use MarshalText instead. +func (p Proto) String() string { + switch p { + case Unknown: + return "Unknown" + case Fragment: + return "Frag" + case ICMPv4: + return "ICMPv4" + case IGMP: + return "IGMP" + case ICMPv6: + return "ICMPv6" + case UDP: + return "UDP" + case TCP: + return "TCP" + case SCTP: + return "SCTP" + case TSMP: + return "TSMP" + case GRE: + return "GRE" + case DCCP: + return "DCCP" + default: + return fmt.Sprintf("IPProto-%d", int(p)) + } +} + +// Prefer names from +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// unless otherwise noted. +var ( + // preferredNames is the set of protocol names that re produced by + // MarshalText, and are the preferred representation. + preferredNames = map[Proto]string{ + 51: "ah", + DCCP: "dccp", + 8: "egp", + 50: "esp", + 47: "gre", + ICMPv4: "icmp", + IGMP: "igmp", + 9: "igp", + 4: "ipv4", + ICMPv6: "ipv6-icmp", + SCTP: "sctp", + TCP: "tcp", + UDP: "udp", + } + + // acceptedNames is the set of protocol names that are accepted by + // UnmarshalText. + acceptedNames = map[string]Proto{ + "ah": 51, + "dccp": DCCP, + "egp": 8, + "esp": 50, + "gre": 47, + "icmp": ICMPv4, + "icmpv4": ICMPv4, + "icmpv6": ICMPv6, + "igmp": IGMP, + "igp": 9, + "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" + "ipv4": 4, + "ipv6-icmp": ICMPv6, + "sctp": SCTP, + "tcp": TCP, + "tsmp": TSMP, + "udp": UDP, + } +) + +// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p +// is set to 0. If an error occurs, p is unchanged. +func (p *Proto) UnmarshalText(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + + if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { + *p = Proto(u) + return nil + } + + if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { + *p = newP + return nil + } + + return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) +} + +// MarshalText implements encoding.TextMarshaler. +func (p Proto) MarshalText() ([]byte, error) { + if s, ok := preferredNames[p]; ok { + return []byte(s), nil + } + return []byte(strconv.Itoa(int(p))), nil +} + +// MarshalJSON implements json.Marshaler. +func (p Proto) MarshalJSON() ([]byte, error) { + return []byte(strconv.Itoa(int(p))), nil +} + +// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to +// 0. If an error occurs, p is unchanged. The input must be a JSON number or an +// accepted string name. +func (p *Proto) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + if b[0] == '"' { + b = b[1 : len(b)-1] + } + return p.UnmarshalText(b) +} diff --git a/types/key/chal.go b/types/key/chal.go index 742ac5479e4a1..da15dd1f8a01d 100644 --- a/types/key/chal.go +++ b/types/key/chal.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "errors" - - "go4.org/mem" - "tailscale.com/types/structs" -) - -const ( - // chalPublicHexPrefix is the prefix used to identify a - // hex-encoded challenge public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - chalPublicHexPrefix = "chalpub:" -) - -// ChallengePrivate is a challenge key, used to test whether clients control a -// key they want to prove ownership of. -// -// A ChallengePrivate is ephemeral and not serialized to the disk or network. -type ChallengePrivate struct { - _ structs.Incomparable // because == isn't constant-time - k [32]byte -} - -// NewChallenge creates and returns a new node private key. -func NewChallenge() ChallengePrivate { - return ChallengePrivate(NewNode()) -} - -// Public returns the ChallengePublic for k. -// Panics if ChallengePublic is zero. -func (k ChallengePrivate) Public() ChallengePublic { - pub := NodePrivate(k).Public() - return ChallengePublic(pub) -} - -// MarshalText implements encoding.TextMarshaler, but by returning an error. -// It shouldn't need to be marshalled anywhere. -func (k ChallengePrivate) MarshalText() ([]byte, error) { - return nil, errors.New("refusing to marshal") -} - -// SealToChallenge is like SealTo, but for a ChallengePublic. -func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { - return k.SealTo(NodePublic(p), cleartext) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by NodePrivate.SealToChallenge, and returns the inner cleartext if -// ciphertext is a valid box from p to k. -func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return NodePrivate(k).OpenFrom(p, ciphertext) -} - -// ChallengePublic is the public portion of a ChallengePrivate. -type ChallengePublic struct { - k [32]byte -} - -// String returns the output of MarshalText as a string. -func (k ChallengePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k ChallengePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (k *ChallengePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) -} - -// IsZero reports whether k is the zero value. -func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "errors" + + "go4.org/mem" + "tailscale.com/types/structs" +) + +const ( + // chalPublicHexPrefix is the prefix used to identify a + // hex-encoded challenge public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + chalPublicHexPrefix = "chalpub:" +) + +// ChallengePrivate is a challenge key, used to test whether clients control a +// key they want to prove ownership of. +// +// A ChallengePrivate is ephemeral and not serialized to the disk or network. +type ChallengePrivate struct { + _ structs.Incomparable // because == isn't constant-time + k [32]byte +} + +// NewChallenge creates and returns a new node private key. +func NewChallenge() ChallengePrivate { + return ChallengePrivate(NewNode()) +} + +// Public returns the ChallengePublic for k. +// Panics if ChallengePublic is zero. +func (k ChallengePrivate) Public() ChallengePublic { + pub := NodePrivate(k).Public() + return ChallengePublic(pub) +} + +// MarshalText implements encoding.TextMarshaler, but by returning an error. +// It shouldn't need to be marshalled anywhere. +func (k ChallengePrivate) MarshalText() ([]byte, error) { + return nil, errors.New("refusing to marshal") +} + +// SealToChallenge is like SealTo, but for a ChallengePublic. +func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { + return k.SealTo(NodePublic(p), cleartext) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by NodePrivate.SealToChallenge, and returns the inner cleartext if +// ciphertext is a valid box from p to k. +func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return NodePrivate(k).OpenFrom(p, ciphertext) +} + +// ChallengePublic is the public portion of a ChallengePrivate. +type ChallengePublic struct { + k [32]byte +} + +// String returns the output of MarshalText as a string. +func (k ChallengePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k ChallengePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (k *ChallengePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) +} + +// IsZero reports whether k is the zero value. +func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } diff --git a/types/key/control.go b/types/key/control.go index 96021249ba047..a84359771bcab 100644 --- a/types/key/control.go +++ b/types/key/control.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import "encoding/json" - -// ControlPrivate is a Tailscale control plane private key. -// -// It is functionally equivalent to a MachinePrivate, but serializes -// to JSON as a byte array rather than a typed string, because our -// control plane database stores the key that way. -// -// Deprecated: this type should only be used in Tailscale's control -// plane, where existing database serializations require this -// less-good serialization format to persist. Other control plane -// implementations can use MachinePrivate with no downsides. -type ControlPrivate struct { - mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need -} - -// NewControl generates and returns a new control plane private key. -func NewControl() ControlPrivate { - return ControlPrivate{NewMachine()} -} - -// IsZero reports whether k is the zero value. -func (k ControlPrivate) IsZero() bool { - return k.mkey.IsZero() -} - -// Public returns the MachinePublic for k. -// Panics if ControlPrivate is zero. -func (k ControlPrivate) Public() MachinePublic { - return k.mkey.Public() -} - -// MarshalJSON implements json.Marshaler. -func (k ControlPrivate) MarshalJSON() ([]byte, error) { - return json.Marshal(k.mkey.k) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { - return json.Unmarshal(bs, &k.mkey.k) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - return k.mkey.SealTo(p, cleartext) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - return k.mkey.SharedKey(p) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return k.mkey.OpenFrom(p, ciphertext) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import "encoding/json" + +// ControlPrivate is a Tailscale control plane private key. +// +// It is functionally equivalent to a MachinePrivate, but serializes +// to JSON as a byte array rather than a typed string, because our +// control plane database stores the key that way. +// +// Deprecated: this type should only be used in Tailscale's control +// plane, where existing database serializations require this +// less-good serialization format to persist. Other control plane +// implementations can use MachinePrivate with no downsides. +type ControlPrivate struct { + mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need +} + +// NewControl generates and returns a new control plane private key. +func NewControl() ControlPrivate { + return ControlPrivate{NewMachine()} +} + +// IsZero reports whether k is the zero value. +func (k ControlPrivate) IsZero() bool { + return k.mkey.IsZero() +} + +// Public returns the MachinePublic for k. +// Panics if ControlPrivate is zero. +func (k ControlPrivate) Public() MachinePublic { + return k.mkey.Public() +} + +// MarshalJSON implements json.Marshaler. +func (k ControlPrivate) MarshalJSON() ([]byte, error) { + return json.Marshal(k.mkey.k) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { + return json.Unmarshal(bs, &k.mkey.k) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + return k.mkey.SealTo(p, cleartext) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + return k.mkey.SharedKey(p) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return k.mkey.OpenFrom(p, ciphertext) +} diff --git a/types/key/control_test.go b/types/key/control_test.go index a98a586f3ba5a..06e0f36d50bcf 100644 --- a/types/key/control_test.go +++ b/types/key/control_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "encoding/json" - "testing" -) - -func TestControlKey(t *testing.T) { - serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` - want := ControlPrivate{ - MachinePrivate{ - k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, - }, - } - - var got struct { - PrivateKey ControlPrivate - } - if err := json.Unmarshal([]byte(serialized), &got); err != nil { - t.Fatalf("decoding serialized ControlPrivate: %v", err) - } - - if !got.PrivateKey.mkey.Equal(want.mkey) { - t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) - } - - bs, err := json.Marshal(got) - if err != nil { - t.Fatalf("json reserialization of ControlPrivate failed: %v", err) - } - - if got, want := string(bs), serialized; got != want { - t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "encoding/json" + "testing" +) + +func TestControlKey(t *testing.T) { + serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` + want := ControlPrivate{ + MachinePrivate{ + k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, + }, + } + + var got struct { + PrivateKey ControlPrivate + } + if err := json.Unmarshal([]byte(serialized), &got); err != nil { + t.Fatalf("decoding serialized ControlPrivate: %v", err) + } + + if !got.PrivateKey.mkey.Equal(want.mkey) { + t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) + } + + bs, err := json.Marshal(got) + if err != nil { + t.Fatalf("json reserialization of ControlPrivate failed: %v", err) + } + + if got, want := string(bs), serialized; got != want { + t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) + } +} diff --git a/types/key/disco_test.go b/types/key/disco_test.go index c62c13cbf8970..c9d60c82874f8 100644 --- a/types/key/disco_test.go +++ b/types/key/disco_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "testing" -) - -func TestDiscoKey(t *testing.T) { - k := NewDisco() - if k.IsZero() { - t.Fatal("DiscoPrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("DiscoPublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if !bytes.HasPrefix(bs, []byte("discokey:")) { - t.Fatalf("serialization of public discokey %s has wrong prefix", p) - } - - z := DiscoPublic{} - if !z.IsZero() { - t.Fatal("IsZero(DiscoPublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestDiscoSerialization(t *testing.T) { - serialized := `{ - "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - pub := DiscoPublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type key struct { - Pub DiscoPublic - } - - var a key - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestDiscoShared(t *testing.T) { - k1, k2 := NewDisco(), NewDisco() - s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) - if !s1.Equal(s2) { - t.Error("k1.Shared(k2) != k2.Shared(k1)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestDiscoKey(t *testing.T) { + k := NewDisco() + if k.IsZero() { + t.Fatal("DiscoPrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("DiscoPublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if !bytes.HasPrefix(bs, []byte("discokey:")) { + t.Fatalf("serialization of public discokey %s has wrong prefix", p) + } + + z := DiscoPublic{} + if !z.IsZero() { + t.Fatal("IsZero(DiscoPublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestDiscoSerialization(t *testing.T) { + serialized := `{ + "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + pub := DiscoPublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type key struct { + Pub DiscoPublic + } + + var a key + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestDiscoShared(t *testing.T) { + k1, k2 := NewDisco(), NewDisco() + s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) + if !s1.Equal(s2) { + t.Error("k1.Shared(k2) != k2.Shared(k1)") + } +} diff --git a/types/key/machine.go b/types/key/machine.go index a05f3cc1f5735..0dc02574c510d 100644 --- a/types/key/machine.go +++ b/types/key/machine.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "crypto/subtle" - "encoding/hex" - - "go4.org/mem" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/nacl/box" - "tailscale.com/types/structs" -) - -const ( - // machinePrivateHexPrefix is the prefix used to identify a - // hex-encoded machine private key. - // - // This prefix name is a little unfortunate, in that it comes from - // WireGuard's own key types. Unfortunately we're stuck with it for - // machine keys, because we serialize them to disk with this prefix. - machinePrivateHexPrefix = "privkey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" -) - -// MachinePrivate is a machine key, used for communication with the -// Tailscale coordination server. -type MachinePrivate struct { - _ structs.Incomparable // == isn't constant-time - k [32]byte -} - -// NewMachine creates and returns a new machine private key. -func NewMachine() MachinePrivate { - var ret MachinePrivate - rand(ret.k[:]) - clamp25519Private(ret.k[:]) - return ret -} - -// IsZero reports whether k is the zero value. -func (k MachinePrivate) IsZero() bool { - return k.Equal(MachinePrivate{}) -} - -// Equal reports whether k and other are the same key. -func (k MachinePrivate) Equal(other MachinePrivate) bool { - return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 -} - -// Public returns the MachinePublic for k. -// Panics if MachinePrivate is zero. -func (k MachinePrivate) Public() MachinePublic { - if k.IsZero() { - panic("can't take the public key of a zero MachinePrivate") - } - var ret MachinePublic - curve25519.ScalarBaseMult(&ret.k, &k.k) - return ret -} - -// AppendText implements encoding.TextAppender. -func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePrivate) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePrivate) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePrivate, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePrivate) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - if k.IsZero() || p.IsZero() { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - var shared MachinePrecomputedSharedKey - box.Precompute(&shared.k, &p.k, &k.k) - return shared -} - -// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. -type MachinePrecomputedSharedKey struct { - k [32]byte -} - -// Seal wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) using the shared key k as generated -// by MachinePrivate.SharedKey. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) -} - -// Open opens the NaCl box ciphertext, which must be a value created by -// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the -// inner cleartext if ciphertext is a valid box for the shared key k. -func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - if k.IsZero() || p.IsZero() { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) -} - -// MachinePublic is the public portion of a a MachinePrivate. -type MachinePublic struct { - k [32]byte -} - -// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. -// -// This should be used only when deserializing a MachinePublic from a -// binary protocol. -func MachinePublicFromRaw32(raw mem.RO) MachinePublic { - if raw.Len() != 32 { - panic("input has wrong size") - } - var ret MachinePublic - raw.Copy(ret.k[:]) - return ret -} - -// ParseMachinePublicUntyped parses an untyped 64-character hex value -// as a MachinePublic. -// -// Deprecated: this function is risky to use, because it cannot verify -// that the hex string was intended to be a MachinePublic. This can -// lead to accidentally decoding one type of key as another. For new -// uses that don't require backwards compatibility with the untyped -// string format, please use MarshalText/UnmarshalText. -func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { - var ret MachinePublic - if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { - return MachinePublic{}, err - } - return ret, nil -} - -// IsZero reports whether k is the zero value. -func (k MachinePublic) IsZero() bool { - return k == MachinePublic{} -} - -// ShortString returns the Tailscale conventional debug representation -// of a public key: the first five base64 digits of the key, in square -// brackets. -func (k MachinePublic) ShortString() string { - return debug32(k.k) -} - -// UntypedHexString returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require backwards -// compatibility with the untyped string format, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedHexString() string { - return hex.EncodeToString(k.k[:]) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// String returns the output of MarshalText as a string. -func (k MachinePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k MachinePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "crypto/subtle" + "encoding/hex" + + "go4.org/mem" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" + "tailscale.com/types/structs" +) + +const ( + // machinePrivateHexPrefix is the prefix used to identify a + // hex-encoded machine private key. + // + // This prefix name is a little unfortunate, in that it comes from + // WireGuard's own key types. Unfortunately we're stuck with it for + // machine keys, because we serialize them to disk with this prefix. + machinePrivateHexPrefix = "privkey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" +) + +// MachinePrivate is a machine key, used for communication with the +// Tailscale coordination server. +type MachinePrivate struct { + _ structs.Incomparable // == isn't constant-time + k [32]byte +} + +// NewMachine creates and returns a new machine private key. +func NewMachine() MachinePrivate { + var ret MachinePrivate + rand(ret.k[:]) + clamp25519Private(ret.k[:]) + return ret +} + +// IsZero reports whether k is the zero value. +func (k MachinePrivate) IsZero() bool { + return k.Equal(MachinePrivate{}) +} + +// Equal reports whether k and other are the same key. +func (k MachinePrivate) Equal(other MachinePrivate) bool { + return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 +} + +// Public returns the MachinePublic for k. +// Panics if MachinePrivate is zero. +func (k MachinePrivate) Public() MachinePublic { + if k.IsZero() { + panic("can't take the public key of a zero MachinePrivate") + } + var ret MachinePublic + curve25519.ScalarBaseMult(&ret.k, &k.k) + return ret +} + +// AppendText implements encoding.TextAppender. +func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePrivate) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePrivate) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePrivate, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePrivate) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + if k.IsZero() || p.IsZero() { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + var shared MachinePrecomputedSharedKey + box.Precompute(&shared.k, &p.k, &k.k) + return shared +} + +// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. +type MachinePrecomputedSharedKey struct { + k [32]byte +} + +// Seal wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) using the shared key k as generated +// by MachinePrivate.SharedKey. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) +} + +// Open opens the NaCl box ciphertext, which must be a value created by +// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the +// inner cleartext if ciphertext is a valid box for the shared key k. +func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + if k.IsZero() || p.IsZero() { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) +} + +// MachinePublic is the public portion of a a MachinePrivate. +type MachinePublic struct { + k [32]byte +} + +// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. +// +// This should be used only when deserializing a MachinePublic from a +// binary protocol. +func MachinePublicFromRaw32(raw mem.RO) MachinePublic { + if raw.Len() != 32 { + panic("input has wrong size") + } + var ret MachinePublic + raw.Copy(ret.k[:]) + return ret +} + +// ParseMachinePublicUntyped parses an untyped 64-character hex value +// as a MachinePublic. +// +// Deprecated: this function is risky to use, because it cannot verify +// that the hex string was intended to be a MachinePublic. This can +// lead to accidentally decoding one type of key as another. For new +// uses that don't require backwards compatibility with the untyped +// string format, please use MarshalText/UnmarshalText. +func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { + var ret MachinePublic + if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { + return MachinePublic{}, err + } + return ret, nil +} + +// IsZero reports whether k is the zero value. +func (k MachinePublic) IsZero() bool { + return k == MachinePublic{} +} + +// ShortString returns the Tailscale conventional debug representation +// of a public key: the first five base64 digits of the key, in square +// brackets. +func (k MachinePublic) ShortString() string { + return debug32(k.k) +} + +// UntypedHexString returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require backwards +// compatibility with the untyped string format, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedHexString() string { + return hex.EncodeToString(k.k[:]) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// String returns the output of MarshalText as a string. +func (k MachinePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k MachinePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) +} diff --git a/types/key/machine_test.go b/types/key/machine_test.go index 157df9e4356b1..f797ff087f090 100644 --- a/types/key/machine_test.go +++ b/types/key/machine_test.go @@ -1,119 +1,119 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "strings" - "testing" -) - -func TestMachineKey(t *testing.T) { - k := NewMachine() - if k.IsZero() { - t.Fatal("MachinePrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("MachinePublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { - t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) - } - - z := MachinePublic{} - if !z.IsZero() { - t.Fatal("IsZero(MachinePublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestMachineSerialization(t *testing.T) { - serialized := `{ - "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", - "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - // Carefully check that the expected serialized data decodes and - // reencodes to the expected keys. These types are serialized to - // disk all over the place and need to be stable. - priv := MachinePrivate{ - k: [32]uint8{ - 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, - 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, - 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, - }, - } - pub := MachinePublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type keypair struct { - Priv MachinePrivate - Pub MachinePublic - } - - var a keypair - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if !a.Priv.Equal(priv) { - t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestSealViaSharedKey(t *testing.T) { - // encrypt a message from a to b - a := NewMachine() - b := NewMachine() - apub, bpub := a.Public(), b.Public() - - shared := a.SharedKey(bpub) - - const clear = "the eagle flies at midnight" - enc := shared.Seal([]byte(clear)) - - back, ok := b.OpenFrom(apub, enc) - if !ok { - t.Fatal("failed to decrypt") - } - if string(back) != clear { - t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) - } - - backShared, ok := shared.Open(enc) - if !ok { - t.Fatal("failed to decrypt from shared key") - } - if string(backShared) != clear { - t.Errorf("Open got %q; want cleartext %q", back, clear) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestMachineKey(t *testing.T) { + k := NewMachine() + if k.IsZero() { + t.Fatal("MachinePrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("MachinePublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { + t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) + } + + z := MachinePublic{} + if !z.IsZero() { + t.Fatal("IsZero(MachinePublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestMachineSerialization(t *testing.T) { + serialized := `{ + "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", + "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + // Carefully check that the expected serialized data decodes and + // reencodes to the expected keys. These types are serialized to + // disk all over the place and need to be stable. + priv := MachinePrivate{ + k: [32]uint8{ + 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, + 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, + 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, + }, + } + pub := MachinePublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type keypair struct { + Priv MachinePrivate + Pub MachinePublic + } + + var a keypair + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if !a.Priv.Equal(priv) { + t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestSealViaSharedKey(t *testing.T) { + // encrypt a message from a to b + a := NewMachine() + b := NewMachine() + apub, bpub := a.Public(), b.Public() + + shared := a.SharedKey(bpub) + + const clear = "the eagle flies at midnight" + enc := shared.Seal([]byte(clear)) + + back, ok := b.OpenFrom(apub, enc) + if !ok { + t.Fatal("failed to decrypt") + } + if string(back) != clear { + t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) + } + + backShared, ok := shared.Open(enc) + if !ok { + t.Fatal("failed to decrypt from shared key") + } + if string(backShared) != clear { + t.Errorf("Open got %q; want cleartext %q", back, clear) + } +} diff --git a/types/key/nl_test.go b/types/key/nl_test.go index 75b7765a19ea1..2e10d04acc58b 100644 --- a/types/key/nl_test.go +++ b/types/key/nl_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "testing" -) - -func TestNLPrivate(t *testing.T) { - p := NewNLPrivate() - - encoded, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - var decoded NLPrivate - if err := decoded.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decoded.k[:], p.k[:]) { - t.Error("decoded and generated NLPrivate bytes differ") - } - - // Test NLPublic - pub := p.Public() - encoded, err = pub.MarshalText() - if err != nil { - t.Fatal(err) - } - var decodedPub NLPublic - if err := decodedPub.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ") - } - - // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' - decodedPub = NLPublic{} - if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "testing" +) + +func TestNLPrivate(t *testing.T) { + p := NewNLPrivate() + + encoded, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + var decoded NLPrivate + if err := decoded.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded.k[:], p.k[:]) { + t.Error("decoded and generated NLPrivate bytes differ") + } + + // Test NLPublic + pub := p.Public() + encoded, err = pub.MarshalText() + if err != nil { + t.Fatal(err) + } + var decodedPub NLPublic + if err := decodedPub.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ") + } + + // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' + decodedPub = NLPublic{} + if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") + } +} diff --git a/types/lazy/unsync.go b/types/lazy/unsync.go index 0f89ce4f6935a..ca46f9c7bbad3 100644 --- a/types/lazy/unsync.go +++ b/types/lazy/unsync.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -// GValue is a lazily computed value. -// -// Use either Get or GetErr, depending on whether your fill function returns an -// error. -// -// Recursive use of a GValue from its own fill function will panic. -// -// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, -// which isn't strictly true if you provide your own synchronization between -// goroutines, but in practice most of our callers have been using it within -// a single goroutine.) -type GValue[T any] struct { - done bool - calling bool - V T - err error -} - -// Set attempts to set z's value to val, and reports whether it succeeded. -// Set only succeeds if none of Get/GetErr/Set have been called before. -func (z *GValue[T]) Set(v T) bool { - if z.done { - return false - } - if z.calling { - panic("Set while Get fill is running") - } - z.V = v - z.done = true - return true -} - -// MustSet sets z's value to val, or panics if z already has a value. -func (z *GValue[T]) MustSet(val T) { - if !z.Set(val) { - panic("Set after already filled") - } -} - -// Get returns z's value, calling fill to compute it if necessary. -// f is called at most once. -func (z *GValue[T]) Get(fill func() T) T { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V = fill() - z.done = true - z.calling = false - } - return z.V -} - -// GetErr returns z's value, calling fill to compute it if necessary. -// f is called at most once, and z remembers both of fill's outputs. -func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V, z.err = fill() - z.done = true - z.calling = false - } - return z.V, z.err -} - -// GFunc wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's result on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFunc[T any](fill func() T) func() T { - var v GValue[T] - return func() T { - return v.Get(fill) - } -} - -// SyncFuncErr wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's results on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFuncErr[T any](fill func() (T, error)) func() (T, error) { - var v GValue[T] - return func() (T, error) { - return v.GetErr(fill) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +// GValue is a lazily computed value. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// Recursive use of a GValue from its own fill function will panic. +// +// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, +// which isn't strictly true if you provide your own synchronization between +// goroutines, but in practice most of our callers have been using it within +// a single goroutine.) +type GValue[T any] struct { + done bool + calling bool + V T + err error +} + +// Set attempts to set z's value to val, and reports whether it succeeded. +// Set only succeeds if none of Get/GetErr/Set have been called before. +func (z *GValue[T]) Set(v T) bool { + if z.done { + return false + } + if z.calling { + panic("Set while Get fill is running") + } + z.V = v + z.done = true + return true +} + +// MustSet sets z's value to val, or panics if z already has a value. +func (z *GValue[T]) MustSet(val T) { + if !z.Set(val) { + panic("Set after already filled") + } +} + +// Get returns z's value, calling fill to compute it if necessary. +// f is called at most once. +func (z *GValue[T]) Get(fill func() T) T { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V = fill() + z.done = true + z.calling = false + } + return z.V +} + +// GetErr returns z's value, calling fill to compute it if necessary. +// f is called at most once, and z remembers both of fill's outputs. +func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V, z.err = fill() + z.done = true + z.calling = false + } + return z.V, z.err +} + +// GFunc wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's result on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFunc[T any](fill func() T) func() T { + var v GValue[T] + return func() T { + return v.Get(fill) + } +} + +// SyncFuncErr wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's results on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFuncErr[T any](fill func() (T, error)) func() (T, error) { + var v GValue[T] + return func() (T, error) { + return v.GetErr(fill) + } +} diff --git a/types/lazy/unsync_test.go b/types/lazy/unsync_test.go index f0d2494d12b6e..d8b870dbeb8a8 100644 --- a/types/lazy/unsync_test.go +++ b/types/lazy/unsync_test.go @@ -1,140 +1,140 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -import ( - "errors" - "testing" -) - -func fortyTwo() int { return 42 } - -func TestGValue(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueErr(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got, err := lt.GetErr(func() (int, error) { - return 42, nil - }) - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - var lterr GValue[int] - wantErr := errors.New("test error") - n = int(testing.AllocsPerRun(1000, func() { - got, err := lterr.GetErr(func() (int, error) { - return 0, wantErr - }) - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueSet(t *testing.T) { - var lt GValue[int] - if !lt.Set(42) { - t.Fatalf("Set failed") - } - if lt.Set(43) { - t.Fatalf("Set succeeded after first Set") - } - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueMustSet(t *testing.T) { - var lt GValue[int] - lt.MustSet(42) - defer func() { - if e := recover(); e == nil { - t.Errorf("unexpected success; want panic") - } - }() - lt.MustSet(43) -} - -func TestGValueRecursivePanic(t *testing.T) { - defer func() { - if e := recover(); e != nil { - t.Logf("got panic, as expected") - } else { - t.Errorf("unexpected success; want panic") - } - }() - v := GValue[int]{} - v.Get(func() int { - return v.Get(func() int { return 42 }) - }) -} - -func TestGFunc(t *testing.T) { - f := GFunc(fortyTwo) - - n := int(testing.AllocsPerRun(1000, func() { - got := f() - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGFuncErr(t *testing.T) { - f := GFuncErr(func() (int, error) { - return 42, nil - }) - n := int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - wantErr := errors.New("test error") - f = GFuncErr(func() (int, error) { - return 0, wantErr - }) - n = int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "testing" +) + +func fortyTwo() int { return 42 } + +func TestGValue(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueErr(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := lt.GetErr(func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var lterr GValue[int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := lterr.GetErr(func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueSet(t *testing.T) { + var lt GValue[int] + if !lt.Set(42) { + t.Fatalf("Set failed") + } + if lt.Set(43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueMustSet(t *testing.T) { + var lt GValue[int] + lt.MustSet(42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + lt.MustSet(43) +} + +func TestGValueRecursivePanic(t *testing.T) { + defer func() { + if e := recover(); e != nil { + t.Logf("got panic, as expected") + } else { + t.Errorf("unexpected success; want panic") + } + }() + v := GValue[int]{} + v.Get(func() int { + return v.Get(func() int { return 42 }) + }) +} + +func TestGFunc(t *testing.T) { + f := GFunc(fortyTwo) + + n := int(testing.AllocsPerRun(1000, func() { + got := f() + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGFuncErr(t *testing.T) { + f := GFuncErr(func() (int, error) { + return 42, nil + }) + n := int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + wantErr := errors.New("test error") + f = GFuncErr(func() (int, error) { + return 0, wantErr + }) + n = int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} diff --git a/types/logger/rusage.go b/types/logger/rusage.go index 3943636d6e255..ebe0e972d7749 100644 --- a/types/logger/rusage.go +++ b/types/logger/rusage.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "fmt" - "runtime" -) - -// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds -// a prefixed log message to each line with the current binary memory usage -// and max RSS. -func RusagePrefixLog(logf Logf) Logf { - return func(f string, argv ...any) { - var m runtime.MemStats - runtime.ReadMemStats(&m) - goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) - maxRSS := rusageMaxRSS() - pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) - logf(pf, argv...) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "fmt" + "runtime" +) + +// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds +// a prefixed log message to each line with the current binary memory usage +// and max RSS. +func RusagePrefixLog(logf Logf) Logf { + return func(f string, argv ...any) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) + maxRSS := rusageMaxRSS() + pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) + logf(pf, argv...) + } +} diff --git a/types/logger/rusage_stub.go b/types/logger/rusage_stub.go index f646f1e1eee7f..a228b086557fb 100644 --- a/types/logger/rusage_stub.go +++ b/types/logger/rusage_stub.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows || wasm || plan9 || tamago - -package logger - -func rusageMaxRSS() float64 { - // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. - return 0 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows || wasm || plan9 || tamago + +package logger + +func rusageMaxRSS() float64 { + // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. + return 0 +} diff --git a/types/logger/rusage_syscall.go b/types/logger/rusage_syscall.go index 2871b66c6bb24..19488aef1e800 100644 --- a/types/logger/rusage_syscall.go +++ b/types/logger/rusage_syscall.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package logger - -import ( - "runtime" - - "golang.org/x/sys/unix" -) - -func rusageMaxRSS() float64 { - var ru unix.Rusage - err := unix.Getrusage(unix.RUSAGE_SELF, &ru) - if err != nil { - return 0 - } - - rss := float64(ru.Maxrss) - if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { - rss /= 1 << 20 // ru_maxrss is bytes on darwin - } else { - // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) - rss /= 1 << 10 - } - return rss -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package logger + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +func rusageMaxRSS() float64 { + var ru unix.Rusage + err := unix.Getrusage(unix.RUSAGE_SELF, &ru) + if err != nil { + return 0 + } + + rss := float64(ru.Maxrss) + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + rss /= 1 << 20 // ru_maxrss is bytes on darwin + } else { + // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) + rss /= 1 << 10 + } + return rss +} diff --git a/types/logger/tokenbucket.go b/types/logger/tokenbucket.go index 83d4059c2af00..2407e01a7abc4 100644 --- a/types/logger/tokenbucket.go +++ b/types/logger/tokenbucket.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "time" -) - -// tokenBucket is a simple token bucket style rate limiter. - -// It's similar in function to golang.org/x/time/rate.Limiter, which we -// can't use because: -// - It doesn't give access to the number of accumulated tokens, which we -// need for implementing hysteresis; -// - It doesn't let us provide our own time function, which we need for -// implementing proper unit tests. -// -// rate.Limiter is also much more complex than necessary, but that wouldn't -// be enough to disqualify it on its own. -// -// Unlike rate.Limiter, this token bucket does not attempt to -// do any locking of its own. Don't try to access it reentrantly. -// That's fine inside this types/logger package because we already have -// locking at a higher level. -type tokenBucket struct { - remaining int - max int - tick time.Duration - t time.Time -} - -func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { - return &tokenBucket{max, max, tick, now} -} - -func (tb *tokenBucket) Get() bool { - if tb.remaining > 0 { - tb.remaining-- - return true - } - return false -} - -func (tb *tokenBucket) Refund(n int) { - b := tb.remaining + n - if b > tb.max { - tb.remaining = tb.max - } else { - tb.remaining = b - } -} - -func (tb *tokenBucket) AdvanceTo(t time.Time) { - diff := t.Sub(tb.t) - - // only use up whole ticks. The remainder will be used up - // next time. - ticks := int(diff / tb.tick) - tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) - - tb.Refund(ticks) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "time" +) + +// tokenBucket is a simple token bucket style rate limiter. + +// It's similar in function to golang.org/x/time/rate.Limiter, which we +// can't use because: +// - It doesn't give access to the number of accumulated tokens, which we +// need for implementing hysteresis; +// - It doesn't let us provide our own time function, which we need for +// implementing proper unit tests. +// +// rate.Limiter is also much more complex than necessary, but that wouldn't +// be enough to disqualify it on its own. +// +// Unlike rate.Limiter, this token bucket does not attempt to +// do any locking of its own. Don't try to access it reentrantly. +// That's fine inside this types/logger package because we already have +// locking at a higher level. +type tokenBucket struct { + remaining int + max int + tick time.Duration + t time.Time +} + +func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { + return &tokenBucket{max, max, tick, now} +} + +func (tb *tokenBucket) Get() bool { + if tb.remaining > 0 { + tb.remaining-- + return true + } + return false +} + +func (tb *tokenBucket) Refund(n int) { + b := tb.remaining + n + if b > tb.max { + tb.remaining = tb.max + } else { + tb.remaining = b + } +} + +func (tb *tokenBucket) AdvanceTo(t time.Time) { + diff := t.Sub(tb.t) + + // only use up whole ticks. The remainder will be used up + // next time. + ticks := int(diff / tb.tick) + tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) + + tb.Refund(ticks) +} diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go index f2fa2bda92366..56002628e94e0 100644 --- a/types/netlogtype/netlogtype.go +++ b/types/netlogtype/netlogtype.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netlogtype defines types for network logging. -package netlogtype - -import ( - "net/netip" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/ipproto" -) - -// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both -// the v1 and v2 "json" packages. - -// Message is the log message that captures network traffic. -type Message struct { - NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" - - Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive - End time.Time `json:"end" cbor:"13,keyasint"` // inclusive - - VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` - SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` - ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` - PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` -} - -const ( - messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` - maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 - maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` - minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` - - // MaxMessageJSONSize is the overhead size of Message when it is - // serialized as JSON assuming that each traffic map is populated. - MaxMessageJSONSize = len(messageJSON) - - maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` - maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort - maxJSONProto = `255` - maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` - maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount - maxJSONCount = `18446744073709551615` - - // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts - // when it is serialized as JSON, assuming no superfluous whitespace. - // It does not include the trailing comma that often appears when - // this object is nested within an array. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsJSONSize = len(maxJSONConnCounts) - - maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" - maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort - maxCBORProto = "\x18\xff" - maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" - maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount - maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" - - // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts - // when it is serialized as CBOR. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsCBORSize = len(maxCBORConnCounts) -) - -// ConnectionCounts is a flattened struct of both a connection and counts. -type ConnectionCounts struct { - Connection - Counts -} - -// Connection is a 5-tuple of proto, source and destination IP and port. -type Connection struct { - Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` - Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` - Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` -} - -func (c Connection) IsZero() bool { return c == Connection{} } - -// Counts are statistics about a particular connection. -type Counts struct { - TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` - TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` - RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` - RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` -} - -func (c Counts) IsZero() bool { return c == Counts{} } - -// Add adds the counts from both c1 and c2. -func (c1 Counts) Add(c2 Counts) Counts { - c1.TxPackets += c2.TxPackets - c1.TxBytes += c2.TxBytes - c1.RxPackets += c2.RxPackets - c1.RxBytes += c2.RxBytes - return c1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netlogtype defines types for network logging. +package netlogtype + +import ( + "net/netip" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" +) + +// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both +// the v1 and v2 "json" packages. + +// Message is the log message that captures network traffic. +type Message struct { + NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" + + Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive + End time.Time `json:"end" cbor:"13,keyasint"` // inclusive + + VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` + SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` + ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` + PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` +} + +const ( + messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 + maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` + minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` + + // MaxMessageJSONSize is the overhead size of Message when it is + // serialized as JSON assuming that each traffic map is populated. + MaxMessageJSONSize = len(messageJSON) + + maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` + maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort + maxJSONProto = `255` + maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` + maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount + maxJSONCount = `18446744073709551615` + + // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts + // when it is serialized as JSON, assuming no superfluous whitespace. + // It does not include the trailing comma that often appears when + // this object is nested within an array. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsJSONSize = len(maxJSONConnCounts) + + maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" + maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort + maxCBORProto = "\x18\xff" + maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount + maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" + + // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts + // when it is serialized as CBOR. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsCBORSize = len(maxCBORConnCounts) +) + +// ConnectionCounts is a flattened struct of both a connection and counts. +type ConnectionCounts struct { + Connection + Counts +} + +// Connection is a 5-tuple of proto, source and destination IP and port. +type Connection struct { + Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` + Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` + Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` +} + +func (c Connection) IsZero() bool { return c == Connection{} } + +// Counts are statistics about a particular connection. +type Counts struct { + TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` + TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` + RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` + RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` +} + +func (c Counts) IsZero() bool { return c == Counts{} } + +// Add adds the counts from both c1 and c2. +func (c1 Counts) Add(c2 Counts) Counts { + c1.TxPackets += c2.TxPackets + c1.TxBytes += c2.TxBytes + c1.RxPackets += c2.RxPackets + c1.RxBytes += c2.RxBytes + return c1 +} diff --git a/types/netlogtype/netlogtype_test.go b/types/netlogtype/netlogtype_test.go index 7f29090c5f757..1fa604b317de4 100644 --- a/types/netlogtype/netlogtype_test.go +++ b/types/netlogtype/netlogtype_test.go @@ -1,39 +1,39 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netlogtype - -import ( - "encoding/json" - "math" - "net/netip" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "tailscale.com/util/must" -) - -func TestMaxSize(t *testing.T) { - maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) - maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) - cc := ConnectionCounts{ - // NOTE: These composite literals are deliberately unkeyed so that - // added fields result in a build failure here. - // Newly added fields should result in an update to both - // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. - Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, - Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, - } - - outJSON := must.Get(json.Marshal(cc)) - if string(outJSON) != maxJSONConnCounts { - t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) - } - - outCBOR := must.Get(cbor.Marshal(cc)) - maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map - if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { - t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netlogtype + +import ( + "encoding/json" + "math" + "net/netip" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "tailscale.com/util/must" +) + +func TestMaxSize(t *testing.T) { + maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) + maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) + cc := ConnectionCounts{ + // NOTE: These composite literals are deliberately unkeyed so that + // added fields result in a build failure here. + // Newly added fields should result in an update to both + // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. + Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, + Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, + } + + outJSON := must.Get(json.Marshal(cc)) + if string(outJSON) != maxJSONConnCounts { + t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) + } + + outCBOR := must.Get(cbor.Marshal(cc)) + maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map + if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { + t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) + } +} diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index e7e2d19575c44..910b6bc21fc8d 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -1,318 +1,318 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmap - -import ( - "encoding/hex" - "net/netip" - "testing" - - "go4.org/mem" - "tailscale.com/net/netaddr" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func testNodeKey(b byte) (ret key.NodePublic) { - var bs [key.NodePublicRawLen]byte - for i := range bs { - bs[i] = b - } - return key.NodePublicFromRaw32(mem.B(bs[:])) -} - -func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { - b, err := hex.DecodeString(hexPrefix) - if err != nil { - panic(err) - } - // this function is used with short hexes, so zero-extend the raw - // value. - var bs [32]byte - copy(bs[:], b) - return key.DiscoPublicFromRaw32(mem.B(bs[:])) -} - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func eps(s ...string) []netip.AddrPort { - var eps []netip.AddrPort - for _, ep := range s { - eps = append(eps, netip.MustParseAddrPort(ep)) - } - return eps -} - -func TestNetworkMapConcise(t *testing.T) { - for _, tt := range []struct { - name string - nm *NetworkMap - want string - }{ - { - name: "basic", - nm: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - Key: testNodeKey(3), - DERP: "127.3.3.40:4", - Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), - }, - }), - }, - want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(1000, func() { - got = tt.nm.Concise() - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestConciseDiffFrom(t *testing.T) { - for _, tt := range []struct { - name string - a, b *NetworkMap - want string - }{ - { - name: "no_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "", - }, - { - name: "header_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(2), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", - }, - { - name: "peer_add", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_remove", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_port_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), - }, - }), - }, - want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", - }, - { - name: "disco_key_only_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("f00f00f00f"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("ba4ba4ba4b"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(50, func() { - got = tt.b.ConciseDiffFrom(tt.a) - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestPeerIndexByNodeID(t *testing.T) { - var nilPtr *NetworkMap - if nilPtr.PeerIndexByNodeID(123) != -1 { - t.Errorf("nil PeerIndexByNodeID should return -1") - } - var nm NetworkMap - const min = 2 - const max = 10000 - const hole = max / 2 - for nid := tailcfg.NodeID(2); nid <= max; nid++ { - if nid == hole { - continue - } - nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) - } - for want, nv := range nm.Peers { - got := nm.PeerIndexByNodeID(nv.ID()) - if got != want { - t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) - } - } - for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { - if got := nm.PeerIndexByNodeID(miss); got != -1 { - t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmap + +import ( + "encoding/hex" + "net/netip" + "testing" + + "go4.org/mem" + "tailscale.com/net/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func testNodeKey(b byte) (ret key.NodePublic) { + var bs [key.NodePublicRawLen]byte + for i := range bs { + bs[i] = b + } + return key.NodePublicFromRaw32(mem.B(bs[:])) +} + +func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { + b, err := hex.DecodeString(hexPrefix) + if err != nil { + panic(err) + } + // this function is used with short hexes, so zero-extend the raw + // value. + var bs [32]byte + copy(bs[:], b) + return key.DiscoPublicFromRaw32(mem.B(bs[:])) +} + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func eps(s ...string) []netip.AddrPort { + var eps []netip.AddrPort + for _, ep := range s { + eps = append(eps, netip.MustParseAddrPort(ep)) + } + return eps +} + +func TestNetworkMapConcise(t *testing.T) { + for _, tt := range []struct { + name string + nm *NetworkMap + want string + }{ + { + name: "basic", + nm: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + Key: testNodeKey(3), + DERP: "127.3.3.40:4", + Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), + }, + }), + }, + want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(1000, func() { + got = tt.nm.Concise() + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestConciseDiffFrom(t *testing.T) { + for _, tt := range []struct { + name string + a, b *NetworkMap + want string + }{ + { + name: "no_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "", + }, + { + name: "header_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(2), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", + }, + { + name: "peer_add", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_remove", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_port_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), + }, + }), + }, + want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", + }, + { + name: "disco_key_only_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("f00f00f00f"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("ba4ba4ba4b"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(50, func() { + got = tt.b.ConciseDiffFrom(tt.a) + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestPeerIndexByNodeID(t *testing.T) { + var nilPtr *NetworkMap + if nilPtr.PeerIndexByNodeID(123) != -1 { + t.Errorf("nil PeerIndexByNodeID should return -1") + } + var nm NetworkMap + const min = 2 + const max = 10000 + const hole = max / 2 + for nid := tailcfg.NodeID(2); nid <= max; nid++ { + if nid == hole { + continue + } + nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) + } + for want, nv := range nm.Peers { + got := nm.PeerIndexByNodeID(nv.ID()) + if got != want { + t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) + } + } + for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { + if got := nm.PeerIndexByNodeID(miss); got != -1 { + t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) + } + } +} diff --git a/types/nettype/nettype.go b/types/nettype/nettype.go index 5d3d303c38a0d..8930c36d845b6 100644 --- a/types/nettype/nettype.go +++ b/types/nettype/nettype.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package nettype defines an interface that doesn't exist in the Go net package. -package nettype - -import ( - "context" - "io" - "net" - "net/netip" - "time" -) - -// PacketListener defines the ListenPacket method as implemented -// by net.ListenConfig, net.ListenPacket, and tstest/natlab. -type PacketListener interface { - ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) -} - -type PacketListenerWithNetIP interface { - ListenPacket(ctx context.Context, network, address string) (PacketConn, error) -} - -// Std implements PacketListener using the Go net package's ListenPacket func. -type Std struct{} - -func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - var conf net.ListenConfig - return conf.ListenPacket(ctx, network, address) -} - -// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort -// write/read methods. -type PacketConn interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) - ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) - io.Closer - LocalAddr() net.Addr - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} - -func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { - return packetListenerAdapter{ln} -} - -type packetListenerAdapter struct { - PacketListener -} - -func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { - pc, err := a.PacketListener.ListenPacket(ctx, network, address) - if err != nil { - return nil, err - } - return pc.(PacketConn), nil -} - -// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. -type ConnPacketConn interface { - net.Conn - net.PacketConn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package nettype defines an interface that doesn't exist in the Go net package. +package nettype + +import ( + "context" + "io" + "net" + "net/netip" + "time" +) + +// PacketListener defines the ListenPacket method as implemented +// by net.ListenConfig, net.ListenPacket, and tstest/natlab. +type PacketListener interface { + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +type PacketListenerWithNetIP interface { + ListenPacket(ctx context.Context, network, address string) (PacketConn, error) +} + +// Std implements PacketListener using the Go net package's ListenPacket func. +type Std struct{} + +func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + var conf net.ListenConfig + return conf.ListenPacket(ctx, network, address) +} + +// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort +// write/read methods. +type PacketConn interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) + ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) + io.Closer + LocalAddr() net.Addr + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { + return packetListenerAdapter{ln} +} + +type packetListenerAdapter struct { + PacketListener +} + +func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { + pc, err := a.PacketListener.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + return pc.(PacketConn), nil +} + +// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. +type ConnPacketConn interface { + net.Conn + net.PacketConn +} diff --git a/types/preftype/netfiltermode.go b/types/preftype/netfiltermode.go index 273e173444365..5756e50968fa5 100644 --- a/types/preftype/netfiltermode.go +++ b/types/preftype/netfiltermode.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package preftype is a leaf package containing types for various -// preferences. -package preftype - -import "fmt" - -// NetfilterMode is the firewall management mode to use when -// programming the Linux network stack. -type NetfilterMode int - -// These numbers are persisted to disk in JSON files and thus can't be -// renumbered or repurposed. -const ( - NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state - NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them - NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains -) - -func ParseNetfilterMode(s string) (NetfilterMode, error) { - switch s { - case "off": - return NetfilterOff, nil - case "nodivert": - return NetfilterNoDivert, nil - case "on": - return NetfilterOn, nil - default: - return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) - } -} - -func (m NetfilterMode) String() string { - switch m { - case NetfilterOff: - return "off" - case NetfilterNoDivert: - return "nodivert" - case NetfilterOn: - return "on" - default: - return "???" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package preftype is a leaf package containing types for various +// preferences. +package preftype + +import "fmt" + +// NetfilterMode is the firewall management mode to use when +// programming the Linux network stack. +type NetfilterMode int + +// These numbers are persisted to disk in JSON files and thus can't be +// renumbered or repurposed. +const ( + NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state + NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them + NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains +) + +func ParseNetfilterMode(s string) (NetfilterMode, error) { + switch s { + case "off": + return NetfilterOff, nil + case "nodivert": + return NetfilterNoDivert, nil + case "on": + return NetfilterOn, nil + default: + return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) + } +} + +func (m NetfilterMode) String() string { + switch m { + case NetfilterOff: + return "off" + case NetfilterNoDivert: + return "nodivert" + case NetfilterOn: + return "on" + default: + return "???" + } +} diff --git a/types/ptr/ptr.go b/types/ptr/ptr.go index beb17bee8ee0e..beb955bf00b61 100644 --- a/types/ptr/ptr.go +++ b/types/ptr/ptr.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ptr contains the ptr.To function. -package ptr - -// To returns a pointer to a shallow copy of v. -func To[T any](v T) *T { - return &v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ptr contains the ptr.To function. +package ptr + +// To returns a pointer to a shallow copy of v. +func To[T any](v T) *T { + return &v +} diff --git a/types/structs/structs.go b/types/structs/structs.go index 47c359f0caa0f..bac6b29917318 100644 --- a/types/structs/structs.go +++ b/types/structs/structs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package structs contains the Incomparable type. -package structs - -// Incomparable is a zero-width incomparable type. If added as the -// first field in a struct, it marks that struct as not comparable -// (can't do == or be a map key) and usually doesn't add any width to -// the struct (unless the struct has only small fields). -// -// Be making a struct incomparable, you can prevent misuse (prevent -// people from using ==), but also you can shrink generated binaries, -// as the compiler can omit equality funcs from the binary. -type Incomparable [0]func() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package structs contains the Incomparable type. +package structs + +// Incomparable is a zero-width incomparable type. If added as the +// first field in a struct, it marks that struct as not comparable +// (can't do == or be a map key) and usually doesn't add any width to +// the struct (unless the struct has only small fields). +// +// Be making a struct incomparable, you can prevent misuse (prevent +// people from using ==), but also you can shrink generated binaries, +// as the compiler can omit equality funcs from the binary. +type Incomparable [0]func() diff --git a/types/tkatype/tkatype.go b/types/tkatype/tkatype.go index 6ad51f6a90240..aca6f144303d0 100644 --- a/types/tkatype/tkatype.go +++ b/types/tkatype/tkatype.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tkatype defines types for working with the tka package. -// -// Do not add extra dependencies to this package unless they are tiny, -// because this package encodes wire types that should be lightweight to use. -package tkatype - -// KeyID references a verification key stored in the key authority. A keyID -// uniquely identifies a key. KeyIDs are all 32 bytes. -// -// For 25519 keys: We just use the 32-byte public key. -// -// Even though this is a 32-byte value, we use a byte slice because -// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. -// Encoding as a byte slice allows us to change the size in the future if we -// ever need to. -type KeyID []byte - -// MarshaledSignature represents a marshaled tka.NodeKeySignature. -type MarshaledSignature []byte - -// MarshaledAUM represents a marshaled tka.AUM. -type MarshaledAUM []byte - -// AUMSigHash represents the BLAKE2s digest of an Authority Update -// Message (AUM), sans any signatures. -type AUMSigHash [32]byte - -// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), -// sans the Signature field if present. -type NKSSigHash [32]byte - -// Signature describes a signature over an AUM, which can be verified -// using the key referenced by KeyID. -type Signature struct { - KeyID KeyID `cbor:"1,keyasint"` - Signature []byte `cbor:"2,keyasint"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tkatype defines types for working with the tka package. +// +// Do not add extra dependencies to this package unless they are tiny, +// because this package encodes wire types that should be lightweight to use. +package tkatype + +// KeyID references a verification key stored in the key authority. A keyID +// uniquely identifies a key. KeyIDs are all 32 bytes. +// +// For 25519 keys: We just use the 32-byte public key. +// +// Even though this is a 32-byte value, we use a byte slice because +// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. +// Encoding as a byte slice allows us to change the size in the future if we +// ever need to. +type KeyID []byte + +// MarshaledSignature represents a marshaled tka.NodeKeySignature. +type MarshaledSignature []byte + +// MarshaledAUM represents a marshaled tka.AUM. +type MarshaledAUM []byte + +// AUMSigHash represents the BLAKE2s digest of an Authority Update +// Message (AUM), sans any signatures. +type AUMSigHash [32]byte + +// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), +// sans the Signature field if present. +type NKSSigHash [32]byte + +// Signature describes a signature over an AUM, which can be verified +// using the key referenced by KeyID. +type Signature struct { + KeyID KeyID `cbor:"1,keyasint"` + Signature []byte `cbor:"2,keyasint"` +} diff --git a/types/tkatype/tkatype_test.go b/types/tkatype/tkatype_test.go index c81891b9ce103..bff90807240e1 100644 --- a/types/tkatype/tkatype_test.go +++ b/types/tkatype/tkatype_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tkatype - -import ( - "encoding/json" - "testing" - - "golang.org/x/crypto/blake2s" -) - -func TestSigHashSize(t *testing.T) { - var sigHash AUMSigHash - if len(sigHash) != blake2s.Size { - t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) - } - - var nksHash NKSSigHash - if len(nksHash) != blake2s.Size { - t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) - } -} - -func TestMarshaledSignatureJSON(t *testing.T) { - sig := MarshaledSignature("abcdef") - j, err := json.Marshal(sig) - if err != nil { - t.Fatal(err) - } - const encoded = `"YWJjZGVm"` - if string(j) != encoded { - t.Errorf("got JSON %q; want %q", j, encoded) - } - - var back MarshaledSignature - if err := json.Unmarshal([]byte(encoded), &back); err != nil { - t.Fatal(err) - } - if string(back) != string(sig) { - t.Errorf("decoded JSON back to %q; want %q", back, sig) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tkatype + +import ( + "encoding/json" + "testing" + + "golang.org/x/crypto/blake2s" +) + +func TestSigHashSize(t *testing.T) { + var sigHash AUMSigHash + if len(sigHash) != blake2s.Size { + t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) + } + + var nksHash NKSSigHash + if len(nksHash) != blake2s.Size { + t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) + } +} + +func TestMarshaledSignatureJSON(t *testing.T) { + sig := MarshaledSignature("abcdef") + j, err := json.Marshal(sig) + if err != nil { + t.Fatal(err) + } + const encoded = `"YWJjZGVm"` + if string(j) != encoded { + t.Errorf("got JSON %q; want %q", j, encoded) + } + + var back MarshaledSignature + if err := json.Unmarshal([]byte(encoded), &back); err != nil { + t.Fatal(err) + } + if string(back) != string(sig) { + t.Errorf("decoded JSON back to %q; want %q", back, sig) + } +} diff --git a/util/cibuild/cibuild.go b/util/cibuild/cibuild.go index c1e337f9a142a..c3dee61548b42 100644 --- a/util/cibuild/cibuild.go +++ b/util/cibuild/cibuild.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cibuild reports runtime CI information. -package cibuild - -import "os" - -// On reports whether the current binary is executing on a CI system. -func On() bool { - // CI env variable is set by GitHub. - // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables - return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cibuild reports runtime CI information. +package cibuild + +import "os" + +// On reports whether the current binary is executing on a CI system. +func On() bool { + // CI env variable is set by GitHub. + // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables + return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" +} diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go index 464dc5dc3cadf..e32c90830e6a7 100644 --- a/util/cstruct/cstruct.go +++ b/util/cstruct/cstruct.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cstruct provides a helper for decoding binary data that is in the -// form of a padded C structure. -package cstruct - -import ( - "errors" - "io" - - "github.com/josharian/native" -) - -// Size of a pointer-typed value, in bits -const pointerSize = 32 << (^uintptr(0) >> 63) - -// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on -// a 16- or 8-bit architecture any time soon. -const is64Bit = pointerSize == 64 - -// Decoder reads and decodes padded fields from a slice of bytes. All fields -// are decoded with native endianness. -// -// Methods of a Decoder do not return errors, but rather store any error within -// the Decoder. The first error can be obtained via the Err method; after the -// first error, methods will return the zero value for their type. -type Decoder struct { - b []byte - off int - err error - dbuf [8]byte // for decoding -} - -// NewDecoder creates a Decoder from a byte slice. -func NewDecoder(b []byte) *Decoder { - return &Decoder{b: b} -} - -var errUnsupportedSize = errors.New("unsupported size") - -func padBytes(offset, size int) int { - if offset == 0 || size == 1 { - return 0 - } - remainder := offset % size - return size - remainder -} - -func (d *Decoder) getField(b []byte) error { - size := len(b) - - // We only support fields that are multiples of 2 (or 1-sized) - if size != 1 && size&1 == 1 { - return errUnsupportedSize - } - - // Fields are aligned to their size - padBytes := padBytes(d.off, size) - if d.off+size+padBytes > len(d.b) { - return io.EOF - } - d.off += padBytes - - copy(b, d.b[d.off:d.off+size]) - d.off += size - return nil -} - -// Err returns the first error that was encountered by this Decoder. -func (d *Decoder) Err() error { - return d.err -} - -// Offset returns the current read offset for data in the buffer. -func (d *Decoder) Offset() int { - return d.off -} - -// Byte returns a single byte from the buffer. -func (d *Decoder) Byte() byte { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:1]); err != nil { - d.err = err - return 0 - } - return d.dbuf[0] -} - -// Byte returns a number of bytes from the buffer based on the size of the -// input slice. No padding is applied. -// -// If an error is encountered or this Decoder has previously encountered an -// error, no changes are made to the provided buffer. -func (d *Decoder) Bytes(b []byte) { - if d.err != nil { - return - } - - // No padding for byte slices - size := len(b) - if d.off+size >= len(d.b) { - d.err = io.EOF - return - } - copy(b, d.b[d.off:d.off+size]) - d.off += size -} - -// Uint16 returns a uint16 decoded from the buffer. -func (d *Decoder) Uint16() uint16 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:2]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint16(d.dbuf[0:2]) -} - -// Uint32 returns a uint32 decoded from the buffer. -func (d *Decoder) Uint32() uint32 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:4]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint32(d.dbuf[0:4]) -} - -// Uint64 returns a uint64 decoded from the buffer. -func (d *Decoder) Uint64() uint64 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:8]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint64(d.dbuf[0:8]) -} - -// Uintptr returns a uintptr decoded from the buffer. -func (d *Decoder) Uintptr() uintptr { - if d.err != nil { - return 0 - } - - if is64Bit { - return uintptr(d.Uint64()) - } else { - return uintptr(d.Uint32()) - } -} - -// Int16 returns a int16 decoded from the buffer. -func (d *Decoder) Int16() int16 { - return int16(d.Uint16()) -} - -// Int32 returns a int32 decoded from the buffer. -func (d *Decoder) Int32() int32 { - return int32(d.Uint32()) -} - -// Int64 returns a int64 decoded from the buffer. -func (d *Decoder) Int64() int64 { - return int64(d.Uint64()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cstruct provides a helper for decoding binary data that is in the +// form of a padded C structure. +package cstruct + +import ( + "errors" + "io" + + "github.com/josharian/native" +) + +// Size of a pointer-typed value, in bits +const pointerSize = 32 << (^uintptr(0) >> 63) + +// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on +// a 16- or 8-bit architecture any time soon. +const is64Bit = pointerSize == 64 + +// Decoder reads and decodes padded fields from a slice of bytes. All fields +// are decoded with native endianness. +// +// Methods of a Decoder do not return errors, but rather store any error within +// the Decoder. The first error can be obtained via the Err method; after the +// first error, methods will return the zero value for their type. +type Decoder struct { + b []byte + off int + err error + dbuf [8]byte // for decoding +} + +// NewDecoder creates a Decoder from a byte slice. +func NewDecoder(b []byte) *Decoder { + return &Decoder{b: b} +} + +var errUnsupportedSize = errors.New("unsupported size") + +func padBytes(offset, size int) int { + if offset == 0 || size == 1 { + return 0 + } + remainder := offset % size + return size - remainder +} + +func (d *Decoder) getField(b []byte) error { + size := len(b) + + // We only support fields that are multiples of 2 (or 1-sized) + if size != 1 && size&1 == 1 { + return errUnsupportedSize + } + + // Fields are aligned to their size + padBytes := padBytes(d.off, size) + if d.off+size+padBytes > len(d.b) { + return io.EOF + } + d.off += padBytes + + copy(b, d.b[d.off:d.off+size]) + d.off += size + return nil +} + +// Err returns the first error that was encountered by this Decoder. +func (d *Decoder) Err() error { + return d.err +} + +// Offset returns the current read offset for data in the buffer. +func (d *Decoder) Offset() int { + return d.off +} + +// Byte returns a single byte from the buffer. +func (d *Decoder) Byte() byte { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:1]); err != nil { + d.err = err + return 0 + } + return d.dbuf[0] +} + +// Byte returns a number of bytes from the buffer based on the size of the +// input slice. No padding is applied. +// +// If an error is encountered or this Decoder has previously encountered an +// error, no changes are made to the provided buffer. +func (d *Decoder) Bytes(b []byte) { + if d.err != nil { + return + } + + // No padding for byte slices + size := len(b) + if d.off+size >= len(d.b) { + d.err = io.EOF + return + } + copy(b, d.b[d.off:d.off+size]) + d.off += size +} + +// Uint16 returns a uint16 decoded from the buffer. +func (d *Decoder) Uint16() uint16 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:2]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint16(d.dbuf[0:2]) +} + +// Uint32 returns a uint32 decoded from the buffer. +func (d *Decoder) Uint32() uint32 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:4]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint32(d.dbuf[0:4]) +} + +// Uint64 returns a uint64 decoded from the buffer. +func (d *Decoder) Uint64() uint64 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:8]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint64(d.dbuf[0:8]) +} + +// Uintptr returns a uintptr decoded from the buffer. +func (d *Decoder) Uintptr() uintptr { + if d.err != nil { + return 0 + } + + if is64Bit { + return uintptr(d.Uint64()) + } else { + return uintptr(d.Uint32()) + } +} + +// Int16 returns a int16 decoded from the buffer. +func (d *Decoder) Int16() int16 { + return int16(d.Uint16()) +} + +// Int32 returns a int32 decoded from the buffer. +func (d *Decoder) Int32() int32 { + return int32(d.Uint32()) +} + +// Int64 returns a int64 decoded from the buffer. +func (d *Decoder) Int64() int64 { + return int64(d.Uint64()) +} diff --git a/util/cstruct/cstruct_example_test.go b/util/cstruct/cstruct_example_test.go index 17032267b9dc6..a36cbf9f0caa3 100644 --- a/util/cstruct/cstruct_example_test.go +++ b/util/cstruct/cstruct_example_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Only built on 64-bit platforms to avoid complexity - -//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 - -package cstruct - -import "fmt" - -// This test provides a semi-realistic example of how you can -// use this package to decode a C structure. -func ExampleDecoder() { - // Our example C structure: - // struct mystruct { - // char *p; - // char c; - // /* implicit: char _pad[3]; */ - // int x; - // }; - // - // The Go structure definition: - type myStruct struct { - Ptr uintptr - Ch byte - Intval uint32 - } - - // Our "in-memory" version of the above structure - buf := []byte{ - 1, 2, 3, 4, 0, 0, 0, 0, // ptr - 5, // ch - 99, 99, 99, // padding - 78, 6, 0, 0, // x - } - d := NewDecoder(buf) - - // Decode the structure; if one of these function returns an error, - // then subsequent decoder functions will return the zero value. - var x myStruct - x.Ptr = d.Uintptr() - x.Ch = d.Byte() - x.Intval = d.Uint32() - - // Note that per the Go language spec: - // [...] when evaluating the operands of an expression, assignment, - // or return statement, all function calls, method calls, and - // (channel) communication operations are evaluated in lexical - // left-to-right order - // - // Since each field is assigned via a function call, one could use the - // following snippet to decode the struct. - // x := myStruct{ - // Ptr: d.Uintptr(), - // Ch: d.Byte(), - // Intval: d.Uint32(), - // } - // - // However, this means that reordering the fields in the initialization - // statement–normally a semantically identical operation–would change - // the way the structure is parsed. Thus we do it as above with - // explicit ordering. - - // After finishing with the decoder, check errors - if err := d.Err(); err != nil { - panic(err) - } - - // Print the decoder offset and structure - fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) - // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Only built on 64-bit platforms to avoid complexity + +//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 + +package cstruct + +import "fmt" + +// This test provides a semi-realistic example of how you can +// use this package to decode a C structure. +func ExampleDecoder() { + // Our example C structure: + // struct mystruct { + // char *p; + // char c; + // /* implicit: char _pad[3]; */ + // int x; + // }; + // + // The Go structure definition: + type myStruct struct { + Ptr uintptr + Ch byte + Intval uint32 + } + + // Our "in-memory" version of the above structure + buf := []byte{ + 1, 2, 3, 4, 0, 0, 0, 0, // ptr + 5, // ch + 99, 99, 99, // padding + 78, 6, 0, 0, // x + } + d := NewDecoder(buf) + + // Decode the structure; if one of these function returns an error, + // then subsequent decoder functions will return the zero value. + var x myStruct + x.Ptr = d.Uintptr() + x.Ch = d.Byte() + x.Intval = d.Uint32() + + // Note that per the Go language spec: + // [...] when evaluating the operands of an expression, assignment, + // or return statement, all function calls, method calls, and + // (channel) communication operations are evaluated in lexical + // left-to-right order + // + // Since each field is assigned via a function call, one could use the + // following snippet to decode the struct. + // x := myStruct{ + // Ptr: d.Uintptr(), + // Ch: d.Byte(), + // Intval: d.Uint32(), + // } + // + // However, this means that reordering the fields in the initialization + // statement–normally a semantically identical operation–would change + // the way the structure is parsed. Thus we do it as above with + // explicit ordering. + + // After finishing with the decoder, check errors + if err := d.Err(); err != nil { + panic(err) + } + + // Print the decoder offset and structure + fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) + // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} +} diff --git a/util/deephash/debug.go b/util/deephash/debug.go index 50b3d5605f327..ff417e5835178 100644 --- a/util/deephash/debug.go +++ b/util/deephash/debug.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build deephash_debug - -package deephash - -import "fmt" - -func (h *hasher) HashBytes(b []byte) { - fmt.Printf("B(%q)+", b) - h.Block512.HashBytes(b) -} -func (h *hasher) HashString(s string) { - fmt.Printf("S(%q)+", s) - h.Block512.HashString(s) -} -func (h *hasher) HashUint8(n uint8) { - fmt.Printf("U8(%d)+", n) - h.Block512.HashUint8(n) -} -func (h *hasher) HashUint16(n uint16) { - fmt.Printf("U16(%d)+", n) - h.Block512.HashUint16(n) -} -func (h *hasher) HashUint32(n uint32) { - fmt.Printf("U32(%d)+", n) - h.Block512.HashUint32(n) -} -func (h *hasher) HashUint64(n uint64) { - fmt.Printf("U64(%d)+", n) - h.Block512.HashUint64(n) -} -func (h *hasher) Sum(b []byte) []byte { - fmt.Println("FIN") - return h.Block512.Sum(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build deephash_debug + +package deephash + +import "fmt" + +func (h *hasher) HashBytes(b []byte) { + fmt.Printf("B(%q)+", b) + h.Block512.HashBytes(b) +} +func (h *hasher) HashString(s string) { + fmt.Printf("S(%q)+", s) + h.Block512.HashString(s) +} +func (h *hasher) HashUint8(n uint8) { + fmt.Printf("U8(%d)+", n) + h.Block512.HashUint8(n) +} +func (h *hasher) HashUint16(n uint16) { + fmt.Printf("U16(%d)+", n) + h.Block512.HashUint16(n) +} +func (h *hasher) HashUint32(n uint32) { + fmt.Printf("U32(%d)+", n) + h.Block512.HashUint32(n) +} +func (h *hasher) HashUint64(n uint64) { + fmt.Printf("U64(%d)+", n) + h.Block512.HashUint64(n) +} +func (h *hasher) Sum(b []byte) []byte { + fmt.Println("FIN") + return h.Block512.Sum(b) +} diff --git a/util/deephash/pointer.go b/util/deephash/pointer.go index aafae47a23673..71b11d7ff1d75 100644 --- a/util/deephash/pointer.go +++ b/util/deephash/pointer.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deephash - -import ( - "net/netip" - "reflect" - "time" - "unsafe" -) - -// unsafePointer is an untyped pointer. -// It is the caller's responsibility to call operations on the correct type. -// -// This pointer only ever points to a small set of kinds or types: -// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, -// or a pointer to memory that is directly hashable. -// -// Arrays are represented as pointers to the first element. -// Structs are represented as pointers to the first field. -// Slices are represented as pointers to a slice header. -// Pointers are represented as pointers to a pointer. -// -// We do not support direct operations on maps and interfaces, and instead -// rely on pointer.asValue to convert the pointer back to a reflect.Value. -// Conversion of an unsafe.Pointer to reflect.Value guarantees that the -// read-only flag in the reflect.Value is unpopulated, avoiding panics that may -// otherwise have occurred since the value was obtained from an unexported field. -type unsafePointer struct{ p unsafe.Pointer } - -func unsafePointerOf(v reflect.Value) unsafePointer { - return unsafePointer{v.UnsafePointer()} -} -func (p unsafePointer) isNil() bool { - return p.p == nil -} - -// pointerElem dereferences a pointer. -// p must point to a pointer. -func (p unsafePointer) pointerElem() unsafePointer { - return unsafePointer{*(*unsafe.Pointer)(p.p)} -} - -// sliceLen returns the slice length. -// p must point to a slice. -func (p unsafePointer) sliceLen() int { - return (*reflect.SliceHeader)(p.p).Len -} - -// sliceArray returns a pointer to the underlying slice array. -// p must point to a slice. -func (p unsafePointer) sliceArray() unsafePointer { - return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} -} - -// arrayIndex returns a pointer to an element in the array. -// p must point to an array. -func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} -} - -// structField returns a pointer to a field in a struct. -// p must pointer to a struct. -func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, offset)} -} - -// asString casts p as a *string. -func (p unsafePointer) asString() *string { - return (*string)(p.p) -} - -// asTime casts p as a *time.Time. -func (p unsafePointer) asTime() *time.Time { - return (*time.Time)(p.p) -} - -// asAddr casts p as a *netip.Addr. -func (p unsafePointer) asAddr() *netip.Addr { - return (*netip.Addr)(p.p) -} - -// asValue casts p as a reflect.Value containing a pointer to value of t. -func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { - return reflect.NewAt(typ, p.p) -} - -// asMemory returns the memory pointer at by p for a specified size. -func (p unsafePointer) asMemory(size uintptr) []byte { - return unsafe.Slice((*byte)(p.p), size) -} - -// visitStack is a stack of pointers visited. -// Pointers are pushed onto the stack when visited, and popped when leaving. -// The integer value is the depth at which the pointer was visited. -// The length of this stack should be zero after every hashing operation. -type visitStack map[unsafe.Pointer]int - -func (v visitStack) seen(p unsafe.Pointer) (int, bool) { - idx, ok := v[p] - return idx, ok -} - -func (v *visitStack) push(p unsafe.Pointer) { - if *v == nil { - *v = make(map[unsafe.Pointer]int) - } - (*v)[p] = len(*v) -} - -func (v visitStack) pop(p unsafe.Pointer) { - delete(v, p) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deephash + +import ( + "net/netip" + "reflect" + "time" + "unsafe" +) + +// unsafePointer is an untyped pointer. +// It is the caller's responsibility to call operations on the correct type. +// +// This pointer only ever points to a small set of kinds or types: +// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, +// or a pointer to memory that is directly hashable. +// +// Arrays are represented as pointers to the first element. +// Structs are represented as pointers to the first field. +// Slices are represented as pointers to a slice header. +// Pointers are represented as pointers to a pointer. +// +// We do not support direct operations on maps and interfaces, and instead +// rely on pointer.asValue to convert the pointer back to a reflect.Value. +// Conversion of an unsafe.Pointer to reflect.Value guarantees that the +// read-only flag in the reflect.Value is unpopulated, avoiding panics that may +// otherwise have occurred since the value was obtained from an unexported field. +type unsafePointer struct{ p unsafe.Pointer } + +func unsafePointerOf(v reflect.Value) unsafePointer { + return unsafePointer{v.UnsafePointer()} +} +func (p unsafePointer) isNil() bool { + return p.p == nil +} + +// pointerElem dereferences a pointer. +// p must point to a pointer. +func (p unsafePointer) pointerElem() unsafePointer { + return unsafePointer{*(*unsafe.Pointer)(p.p)} +} + +// sliceLen returns the slice length. +// p must point to a slice. +func (p unsafePointer) sliceLen() int { + return (*reflect.SliceHeader)(p.p).Len +} + +// sliceArray returns a pointer to the underlying slice array. +// p must point to a slice. +func (p unsafePointer) sliceArray() unsafePointer { + return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} +} + +// arrayIndex returns a pointer to an element in the array. +// p must point to an array. +func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} +} + +// structField returns a pointer to a field in a struct. +// p must pointer to a struct. +func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, offset)} +} + +// asString casts p as a *string. +func (p unsafePointer) asString() *string { + return (*string)(p.p) +} + +// asTime casts p as a *time.Time. +func (p unsafePointer) asTime() *time.Time { + return (*time.Time)(p.p) +} + +// asAddr casts p as a *netip.Addr. +func (p unsafePointer) asAddr() *netip.Addr { + return (*netip.Addr)(p.p) +} + +// asValue casts p as a reflect.Value containing a pointer to value of t. +func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, p.p) +} + +// asMemory returns the memory pointer at by p for a specified size. +func (p unsafePointer) asMemory(size uintptr) []byte { + return unsafe.Slice((*byte)(p.p), size) +} + +// visitStack is a stack of pointers visited. +// Pointers are pushed onto the stack when visited, and popped when leaving. +// The integer value is the depth at which the pointer was visited. +// The length of this stack should be zero after every hashing operation. +type visitStack map[unsafe.Pointer]int + +func (v visitStack) seen(p unsafe.Pointer) (int, bool) { + idx, ok := v[p] + return idx, ok +} + +func (v *visitStack) push(p unsafe.Pointer) { + if *v == nil { + *v = make(map[unsafe.Pointer]int) + } + (*v)[p] = len(*v) +} + +func (v visitStack) pop(p unsafe.Pointer) { + delete(v, p) +} diff --git a/util/deephash/pointer_norace.go b/util/deephash/pointer_norace.go index f98a70f6a18e5..4993720002460 100644 --- a/util/deephash/pointer_norace.go +++ b/util/deephash/pointer_norace.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package deephash - -import "reflect" - -type pointer = unsafePointer - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package deephash + +import "reflect" + +type pointer = unsafePointer + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } diff --git a/util/deephash/pointer_race.go b/util/deephash/pointer_race.go index c638c7d39f393..93a358b6df358 100644 --- a/util/deephash/pointer_race.go +++ b/util/deephash/pointer_race.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package deephash - -import ( - "fmt" - "net/netip" - "reflect" - "time" -) - -// pointer is a typed pointer that performs safety checks for every operation. -type pointer struct { - unsafePointer - t reflect.Type // type of pointed-at value; may be nil - n uintptr // size of valid memory after p -} - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { - assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) - te := v.Type().Elem() - return pointer{unsafePointerOf(v), te, te.Size()} -} - -func (p pointer) pointerElem() pointer { - assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) - te := p.t.Elem() - return pointer{p.unsafePointer.pointerElem(), te, te.Size()} -} - -func (p pointer) sliceLen() int { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - return p.unsafePointer.sliceLen() -} - -func (p pointer) sliceArray() pointer { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - n := p.sliceLen() - assert(n >= 0, "got negative slice length %d", n) - ta := reflect.ArrayOf(n, p.t.Elem()) - return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} -} - -func (p pointer) arrayIndex(index int, size uintptr) pointer { - assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) - assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) - assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) - te := p.t.Elem() - return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} -} - -func (p pointer) structField(index int, offset, size uintptr) pointer { - assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) - assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) - assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) - if index < 0 { - return pointer{p.unsafePointer.structField(index, offset, size), nil, size} - } - sf := p.t.Field(index) - t := sf.Type - assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) - assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) - return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} -} - -func (p pointer) asString() *string { - assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) - return p.unsafePointer.asString() -} - -func (p pointer) asTime() *time.Time { - assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) - return p.unsafePointer.asTime() -} - -func (p pointer) asAddr() *netip.Addr { - assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) - return p.unsafePointer.asAddr() -} - -func (p pointer) asValue(typ reflect.Type) reflect.Value { - assert(p.t == typ, "got %v, want %v", p.t, typ) - return p.unsafePointer.asValue(typ) -} - -func (p pointer) asMemory(size uintptr) []byte { - assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) - return p.unsafePointer.asMemory(size) -} - -func assert(b bool, f string, a ...any) { - if !b { - panic(fmt.Sprintf(f, a...)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package deephash + +import ( + "fmt" + "net/netip" + "reflect" + "time" +) + +// pointer is a typed pointer that performs safety checks for every operation. +type pointer struct { + unsafePointer + t reflect.Type // type of pointed-at value; may be nil + n uintptr // size of valid memory after p +} + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { + assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) + te := v.Type().Elem() + return pointer{unsafePointerOf(v), te, te.Size()} +} + +func (p pointer) pointerElem() pointer { + assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) + te := p.t.Elem() + return pointer{p.unsafePointer.pointerElem(), te, te.Size()} +} + +func (p pointer) sliceLen() int { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + return p.unsafePointer.sliceLen() +} + +func (p pointer) sliceArray() pointer { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + n := p.sliceLen() + assert(n >= 0, "got negative slice length %d", n) + ta := reflect.ArrayOf(n, p.t.Elem()) + return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} +} + +func (p pointer) arrayIndex(index int, size uintptr) pointer { + assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) + assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) + assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) + te := p.t.Elem() + return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} +} + +func (p pointer) structField(index int, offset, size uintptr) pointer { + assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) + assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) + assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) + if index < 0 { + return pointer{p.unsafePointer.structField(index, offset, size), nil, size} + } + sf := p.t.Field(index) + t := sf.Type + assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) + assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) + return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} +} + +func (p pointer) asString() *string { + assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) + return p.unsafePointer.asString() +} + +func (p pointer) asTime() *time.Time { + assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) + return p.unsafePointer.asTime() +} + +func (p pointer) asAddr() *netip.Addr { + assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) + return p.unsafePointer.asAddr() +} + +func (p pointer) asValue(typ reflect.Type) reflect.Value { + assert(p.t == typ, "got %v, want %v", p.t, typ) + return p.unsafePointer.asValue(typ) +} + +func (p pointer) asMemory(size uintptr) []byte { + assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) + return p.unsafePointer.asMemory(size) +} + +func assert(b bool, f string, a ...any) { + if !b { + panic(fmt.Sprintf(f, a...)) + } +} diff --git a/util/deephash/testtype/testtype.go b/util/deephash/testtype/testtype.go index 3c90053d6dfd5..2df38da8777ff 100644 --- a/util/deephash/testtype/testtype.go +++ b/util/deephash/testtype/testtype.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testtype contains types for testing deephash. -package testtype - -import "time" - -type UnexportedAddressableTime struct { - t time.Time -} - -func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { - return &UnexportedAddressableTime{t: t} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testtype contains types for testing deephash. +package testtype + +import "time" + +type UnexportedAddressableTime struct { + t time.Time +} + +func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { + return &UnexportedAddressableTime{t: t} +} diff --git a/util/dirwalk/dirwalk.go b/util/dirwalk/dirwalk.go index 811766892896a..a05ee3553ad90 100644 --- a/util/dirwalk/dirwalk.go +++ b/util/dirwalk/dirwalk.go @@ -1,53 +1,53 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dirwalk contains code to walk a directory. -package dirwalk - -import ( - "io" - "io/fs" - "os" - - "go4.org/mem" -) - -var osWalkShallow func(name mem.RO, fn WalkFunc) error - -// WalkFunc is the callback type used with WalkShallow. -// -// The name and de are only valid for the duration of func's call -// and should not be retained. -type WalkFunc func(name mem.RO, de fs.DirEntry) error - -// WalkShallow reads the entries in the named directory and calls fn for each. -// It does not recurse into subdirectories. -// -// If fn returns an error, iteration stops and WalkShallow returns that value. -// -// On Linux, WalkShallow does not allocate, so long as certain methods on the -// WalkFunc's DirEntry are not called which necessarily allocate. -func WalkShallow(dirName mem.RO, fn WalkFunc) error { - if f := osWalkShallow; f != nil { - return f(dirName, fn) - } - of, err := os.Open(dirName.StringCopy()) - if err != nil { - return err - } - defer of.Close() - for { - fis, err := of.ReadDir(100) - for _, de := range fis { - if err := fn(mem.S(de.Name()), de); err != nil { - return err - } - } - if err != nil { - if err == io.EOF { - return nil - } - return err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dirwalk contains code to walk a directory. +package dirwalk + +import ( + "io" + "io/fs" + "os" + + "go4.org/mem" +) + +var osWalkShallow func(name mem.RO, fn WalkFunc) error + +// WalkFunc is the callback type used with WalkShallow. +// +// The name and de are only valid for the duration of func's call +// and should not be retained. +type WalkFunc func(name mem.RO, de fs.DirEntry) error + +// WalkShallow reads the entries in the named directory and calls fn for each. +// It does not recurse into subdirectories. +// +// If fn returns an error, iteration stops and WalkShallow returns that value. +// +// On Linux, WalkShallow does not allocate, so long as certain methods on the +// WalkFunc's DirEntry are not called which necessarily allocate. +func WalkShallow(dirName mem.RO, fn WalkFunc) error { + if f := osWalkShallow; f != nil { + return f(dirName, fn) + } + of, err := os.Open(dirName.StringCopy()) + if err != nil { + return err + } + defer of.Close() + for { + fis, err := of.ReadDir(100) + for _, de := range fis { + if err := fn(mem.S(de.Name()), de); err != nil { + return err + } + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} diff --git a/util/dirwalk/dirwalk_linux.go b/util/dirwalk/dirwalk_linux.go index 256467ebd8ac5..7147831452d38 100644 --- a/util/dirwalk/dirwalk_linux.go +++ b/util/dirwalk/dirwalk_linux.go @@ -1,167 +1,167 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "io/fs" - "os" - "path/filepath" - "sync" - "syscall" - "unsafe" - - "go4.org/mem" - "golang.org/x/sys/unix" -) - -func init() { - osWalkShallow = linuxWalkShallow -} - -var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} - -func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { - const blockSize = 8 << 10 - buf := make([]byte, blockSize) // stack-allocated; doesn't escape - - nameb := mem.Append(buf[:0], dirName) - nameb = append(nameb, 0) - - fd, err := sysOpen(nameb) - if err != nil { - return err - } - defer syscall.Close(fd) - - bufp := 0 // starting read position in buf - nbuf := 0 // end valid data in buf - - de := dirEntPool.Get().(*linuxDirEnt) - defer de.cleanAndPutInPool() - de.root = dirName - - for { - if bufp >= nbuf { - bufp = 0 - nbuf, err = readDirent(fd, buf) - if err != nil { - return err - } - if nbuf <= 0 { - return nil - } - } - consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) - bufp += consumed - if len(name) == 0 || string(name) == "." || string(name) == ".." { - continue - } - de.name = mem.B(name) - if err := fn(de.name, de); err != nil { - return err - } - } -} - -type linuxDirEnt struct { - root mem.RO - d syscall.Dirent - name mem.RO -} - -func (de *linuxDirEnt) cleanAndPutInPool() { - de.root = mem.RO{} - de.name = mem.RO{} - dirEntPool.Put(de) -} - -func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } -func (de *linuxDirEnt) Info() (fs.FileInfo, error) { - return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) -} -func (de *linuxDirEnt) IsDir() bool { - return de.d.Type == syscall.DT_DIR -} -func (de *linuxDirEnt) Type() fs.FileMode { - switch de.d.Type { - case syscall.DT_BLK: - return fs.ModeDevice // shrug - case syscall.DT_CHR: - return fs.ModeCharDevice - case syscall.DT_DIR: - return fs.ModeDir - case syscall.DT_FIFO: - return fs.ModeNamedPipe - case syscall.DT_LNK: - return fs.ModeSymlink - case syscall.DT_REG: - return 0 - case syscall.DT_SOCK: - return fs.ModeSocket - default: - return fs.ModeIrregular // shrug - } -} - -func direntNamlen(dirent *syscall.Dirent) int { - const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) - limit := dirent.Reclen - fixedHdr - const dirNameLen = 256 // sizeof syscall.Dirent.Name - if limit > dirNameLen { - limit = dirNameLen - } - for i := uint16(0); i < limit; i++ { - if dirent.Name[i] == 0 { - return int(i) - } - } - panic("failed to find terminating 0 byte in dirent") -} - -func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { - // golang.org/issue/37269 - copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) - if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { - panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) - } - if len(buf) < int(dirent.Reclen) { - panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) - } - consumed = int(dirent.Reclen) - if dirent.Ino == 0 { // File absent in directory. - return - } - name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) - return -} - -func sysOpen(name []byte) (fd int, err error) { - if len(name) == 0 || name[len(name)-1] != 0 { - return 0, syscall.EINVAL - } - var dirfd int = unix.AT_FDCWD - for { - r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), - uintptr(unsafe.Pointer(&name[0])), 0) - if e1 == 0 { - return int(r0), nil - } - if e1 == syscall.EINTR { - // Since https://golang.org/doc/go1.14#runtime we - // need to loop on EINTR on more places. - continue - } - return 0, syscall.Errno(e1) - } -} - -func readDirent(fd int, buf []byte) (n int, err error) { - for { - nbuf, err := syscall.ReadDirent(fd, buf) - if err != syscall.EINTR { - return nbuf, err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + "syscall" + "unsafe" + + "go4.org/mem" + "golang.org/x/sys/unix" +) + +func init() { + osWalkShallow = linuxWalkShallow +} + +var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} + +func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { + const blockSize = 8 << 10 + buf := make([]byte, blockSize) // stack-allocated; doesn't escape + + nameb := mem.Append(buf[:0], dirName) + nameb = append(nameb, 0) + + fd, err := sysOpen(nameb) + if err != nil { + return err + } + defer syscall.Close(fd) + + bufp := 0 // starting read position in buf + nbuf := 0 // end valid data in buf + + de := dirEntPool.Get().(*linuxDirEnt) + defer de.cleanAndPutInPool() + de.root = dirName + + for { + if bufp >= nbuf { + bufp = 0 + nbuf, err = readDirent(fd, buf) + if err != nil { + return err + } + if nbuf <= 0 { + return nil + } + } + consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) + bufp += consumed + if len(name) == 0 || string(name) == "." || string(name) == ".." { + continue + } + de.name = mem.B(name) + if err := fn(de.name, de); err != nil { + return err + } + } +} + +type linuxDirEnt struct { + root mem.RO + d syscall.Dirent + name mem.RO +} + +func (de *linuxDirEnt) cleanAndPutInPool() { + de.root = mem.RO{} + de.name = mem.RO{} + dirEntPool.Put(de) +} + +func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } +func (de *linuxDirEnt) Info() (fs.FileInfo, error) { + return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) +} +func (de *linuxDirEnt) IsDir() bool { + return de.d.Type == syscall.DT_DIR +} +func (de *linuxDirEnt) Type() fs.FileMode { + switch de.d.Type { + case syscall.DT_BLK: + return fs.ModeDevice // shrug + case syscall.DT_CHR: + return fs.ModeCharDevice + case syscall.DT_DIR: + return fs.ModeDir + case syscall.DT_FIFO: + return fs.ModeNamedPipe + case syscall.DT_LNK: + return fs.ModeSymlink + case syscall.DT_REG: + return 0 + case syscall.DT_SOCK: + return fs.ModeSocket + default: + return fs.ModeIrregular // shrug + } +} + +func direntNamlen(dirent *syscall.Dirent) int { + const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) + limit := dirent.Reclen - fixedHdr + const dirNameLen = 256 // sizeof syscall.Dirent.Name + if limit > dirNameLen { + limit = dirNameLen + } + for i := uint16(0); i < limit; i++ { + if dirent.Name[i] == 0 { + return int(i) + } + } + panic("failed to find terminating 0 byte in dirent") +} + +func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { + // golang.org/issue/37269 + copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) + if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { + panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) + } + if len(buf) < int(dirent.Reclen) { + panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) + } + consumed = int(dirent.Reclen) + if dirent.Ino == 0 { // File absent in directory. + return + } + name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) + return +} + +func sysOpen(name []byte) (fd int, err error) { + if len(name) == 0 || name[len(name)-1] != 0 { + return 0, syscall.EINVAL + } + var dirfd int = unix.AT_FDCWD + for { + r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), + uintptr(unsafe.Pointer(&name[0])), 0) + if e1 == 0 { + return int(r0), nil + } + if e1 == syscall.EINTR { + // Since https://golang.org/doc/go1.14#runtime we + // need to loop on EINTR on more places. + continue + } + return 0, syscall.Errno(e1) + } +} + +func readDirent(fd int, buf []byte) (n int, err error) { + for { + nbuf, err := syscall.ReadDirent(fd, buf) + if err != syscall.EINTR { + return nbuf, err + } + } +} diff --git a/util/dirwalk/dirwalk_test.go b/util/dirwalk/dirwalk_test.go index 15ebc13dd404d..e2e41f634947e 100644 --- a/util/dirwalk/dirwalk_test.go +++ b/util/dirwalk/dirwalk_test.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "sort" - "testing" - - "go4.org/mem" - "tailscale.com/tstest" -) - -func TestWalkShallowOSSpecific(t *testing.T) { - if osWalkShallow == nil { - t.Skip("no OS-specific implementation") - } - testWalkShallow(t, false) -} - -func TestWalkShallowPortable(t *testing.T) { - testWalkShallow(t, true) -} - -func testWalkShallow(t *testing.T, portable bool) { - if portable { - tstest.Replace(t, &osWalkShallow, nil) - } - d := t.TempDir() - - t.Run("basics", func(t *testing.T) { - if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { - t.Fatal(err) - } - if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { - t.Fatal(err) - } - - var got []string - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { - var size int64 - if fi, err := de.Info(); err != nil { - t.Errorf("Info stat error on %q: %v", de.Name(), err) - } else if !fi.IsDir() { - size = fi.Size() - } - got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) - return nil - }); err != nil { - t.Fatal(err) - } - sort.Strings(got) - want := []string{ - `"bar" "bar" dir=false type=0 size=2`, - `"baz" "baz" dir=true type=2147483648 size=0`, - `"foo" "foo" dir=false type=0 size=1`, - } - if !reflect.DeepEqual(got, want) { - t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) - } - }) - - t.Run("err_not_exist", func(t *testing.T) { - err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { - return nil - }) - if !os.IsNotExist(err) { - t.Errorf("unexpected error: %v", err) - } - }) - - t.Run("allocs", func(t *testing.T) { - allocs := int(testing.AllocsPerRun(1000, func() { - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { - t.Fatal(err) - } - })) - t.Logf("allocs = %v", allocs) - if !portable && runtime.GOOS == "linux" && allocs != 0 { - t.Errorf("unexpected allocs: got %v, want 0", allocs) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "sort" + "testing" + + "go4.org/mem" + "tailscale.com/tstest" +) + +func TestWalkShallowOSSpecific(t *testing.T) { + if osWalkShallow == nil { + t.Skip("no OS-specific implementation") + } + testWalkShallow(t, false) +} + +func TestWalkShallowPortable(t *testing.T) { + testWalkShallow(t, true) +} + +func testWalkShallow(t *testing.T, portable bool) { + if portable { + tstest.Replace(t, &osWalkShallow, nil) + } + d := t.TempDir() + + t.Run("basics", func(t *testing.T) { + if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { + t.Fatal(err) + } + + var got []string + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { + var size int64 + if fi, err := de.Info(); err != nil { + t.Errorf("Info stat error on %q: %v", de.Name(), err) + } else if !fi.IsDir() { + size = fi.Size() + } + got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) + return nil + }); err != nil { + t.Fatal(err) + } + sort.Strings(got) + want := []string{ + `"bar" "bar" dir=false type=0 size=2`, + `"baz" "baz" dir=true type=2147483648 size=0`, + `"foo" "foo" dir=false type=0 size=1`, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) + } + }) + + t.Run("err_not_exist", func(t *testing.T) { + err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { + return nil + }) + if !os.IsNotExist(err) { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("allocs", func(t *testing.T) { + allocs := int(testing.AllocsPerRun(1000, func() { + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { + t.Fatal(err) + } + })) + t.Logf("allocs = %v", allocs) + if !portable && runtime.GOOS == "linux" && allocs != 0 { + t.Errorf("unexpected allocs: got %v, want 0", allocs) + } + }) +} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 9758b07586613..24c61b37cd399 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,93 +1,93 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The goroutines package contains utilities for getting active goroutines. -package goroutines - -import ( - "bytes" - "fmt" - "runtime" - "strconv" -) - -// ScrubbedGoroutineDump returns either the current goroutine's stack or all -// goroutines' stacks, but with the actual values of arguments scrubbed out, -// lest it contain some private key material. -func ScrubbedGoroutineDump(all bool) []byte { - var buf []byte - // Grab stacks multiple times into increasingly larger buffer sizes - // to minimize the risk that we blow past our iOS memory limit. - for size := 1 << 10; size <= 1<<20; size += 1 << 10 { - buf = make([]byte, size) - buf = buf[:runtime.Stack(buf, all)] - if len(buf) < size { - // It fit. - break - } - } - return scrubHex(buf) -} - -func scrubHex(buf []byte) []byte { - saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) - - foreachHexAddress(buf, func(in []byte) { - if string(in) == "0x0" { - return - } - if v, ok := saw[string(in)]; ok { - for i := range in { - in[i] = '_' - } - copy(in, v) - return - } - inStr := string(in) - u64, err := strconv.ParseUint(string(in[2:]), 16, 64) - for i := range in { - in[i] = '_' - } - if err != nil { - in[0] = '?' - return - } - v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) - saw[inStr] = v - copy(in, v) - }) - return buf -} - -var ohx = []byte("0x") - -// foreachHexAddress calls f with each subslice of b that matches -// regexp `0x[0-9a-f]*`. -func foreachHexAddress(b []byte, f func([]byte)) { - for len(b) > 0 { - i := bytes.Index(b, ohx) - if i == -1 { - return - } - b = b[i:] - hx := hexPrefix(b) - f(hx) - b = b[len(hx):] - } -} - -func hexPrefix(b []byte) []byte { - for i, c := range b { - if i < 2 { - continue - } - if !isHexByte(c) { - return b[:i] - } - } - return b -} - -func isHexByte(b byte) bool { - return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The goroutines package contains utilities for getting active goroutines. +package goroutines + +import ( + "bytes" + "fmt" + "runtime" + "strconv" +) + +// ScrubbedGoroutineDump returns either the current goroutine's stack or all +// goroutines' stacks, but with the actual values of arguments scrubbed out, +// lest it contain some private key material. +func ScrubbedGoroutineDump(all bool) []byte { + var buf []byte + // Grab stacks multiple times into increasingly larger buffer sizes + // to minimize the risk that we blow past our iOS memory limit. + for size := 1 << 10; size <= 1<<20; size += 1 << 10 { + buf = make([]byte, size) + buf = buf[:runtime.Stack(buf, all)] + if len(buf) < size { + // It fit. + break + } + } + return scrubHex(buf) +} + +func scrubHex(buf []byte) []byte { + saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) + + foreachHexAddress(buf, func(in []byte) { + if string(in) == "0x0" { + return + } + if v, ok := saw[string(in)]; ok { + for i := range in { + in[i] = '_' + } + copy(in, v) + return + } + inStr := string(in) + u64, err := strconv.ParseUint(string(in[2:]), 16, 64) + for i := range in { + in[i] = '_' + } + if err != nil { + in[0] = '?' + return + } + v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) + saw[inStr] = v + copy(in, v) + }) + return buf +} + +var ohx = []byte("0x") + +// foreachHexAddress calls f with each subslice of b that matches +// regexp `0x[0-9a-f]*`. +func foreachHexAddress(b []byte, f func([]byte)) { + for len(b) > 0 { + i := bytes.Index(b, ohx) + if i == -1 { + return + } + b = b[i:] + hx := hexPrefix(b) + f(hx) + b = b[len(hx):] + } +} + +func hexPrefix(b []byte) []byte { + for i, c := range b { + if i < 2 { + continue + } + if !isHexByte(c) { + return b[:i] + } + } + return b +} + +func isHexByte(b byte) bool { + return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' +} diff --git a/util/goroutines/goroutines_test.go b/util/goroutines/goroutines_test.go index ae17c399ca274..df6560fe5e20b 100644 --- a/util/goroutines/goroutines_test.go +++ b/util/goroutines/goroutines_test.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package goroutines - -import "testing" - -func TestScrubbedGoroutineDump(t *testing.T) { - t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) -} - -func TestScrubHex(t *testing.T) { - tests := []struct { - in, want string - }{ - {"foo", "foo"}, - {"", ""}, - {"0x", "?_"}, - {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, - {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, - {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, - } - for _, tt := range tests { - got := scrubHex([]byte(tt.in)) - if string(got) != tt.want { - t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import "testing" + +func TestScrubbedGoroutineDump(t *testing.T) { + t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) +} + +func TestScrubHex(t *testing.T) { + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"", ""}, + {"0x", "?_"}, + {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, + {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, + {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, + } + for _, tt := range tests { + got := scrubHex([]byte(tt.in)) + if string(got) != tt.want { + t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) + } + } +} diff --git a/util/groupmember/groupmember.go b/util/groupmember/groupmember.go index d604168169022..38431a7ff8791 100644 --- a/util/groupmember/groupmember.go +++ b/util/groupmember/groupmember.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package groupmember verifies group membership of the provided user on the -// local system. -package groupmember - -import ( - "os/user" - "slices" -) - -// IsMemberOfGroup reports whether the provided user is a member of -// the provided system group. -func IsMemberOfGroup(group, userName string) (bool, error) { - u, err := user.Lookup(userName) - if err != nil { - return false, err - } - g, err := user.LookupGroup(group) - if err != nil { - return false, err - } - ugids, err := u.GroupIds() - if err != nil { - return false, err - } - return slices.Contains(ugids, g.Gid), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package groupmember verifies group membership of the provided user on the +// local system. +package groupmember + +import ( + "os/user" + "slices" +) + +// IsMemberOfGroup reports whether the provided user is a member of +// the provided system group. +func IsMemberOfGroup(group, userName string) (bool, error) { + u, err := user.Lookup(userName) + if err != nil { + return false, err + } + g, err := user.LookupGroup(group) + if err != nil { + return false, err + } + ugids, err := u.GroupIds() + if err != nil { + return false, err + } + return slices.Contains(ugids, g.Gid), nil +} diff --git a/util/hashx/block512.go b/util/hashx/block512.go index e637c0c030653..dd69ccd35637c 100644 --- a/util/hashx/block512.go +++ b/util/hashx/block512.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package hashx provides a concrete implementation of [hash.Hash] -// that operates on a particular block size. -package hashx - -import ( - "encoding/binary" - "fmt" - "hash" - "unsafe" -) - -var _ hash.Hash = (*Block512)(nil) - -// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. -// It has efficient methods for hashing fixed-width integers. -// -// A hashing algorithm that operates on 512-bit block sizes should be used. -// The hash still operates correctly even with misaligned block sizes, -// but operates less efficiently. -// -// Example algorithms with 512-bit block sizes include: -// - MD4 (https://golang.org/x/crypto/md4) -// - MD5 (https://golang.org/pkg/crypto/md5) -// - BLAKE2s (https://golang.org/x/crypto/blake2s) -// - BLAKE3 -// - RIPEMD (https://golang.org/x/crypto/ripemd160) -// - SHA-0 -// - SHA-1 (https://golang.org/pkg/crypto/sha1) -// - SHA-2 (https://golang.org/pkg/crypto/sha256) -// - Whirlpool -// -// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters -// for a list of hash functions and their block sizes. -// -// Block512 assumes that [hash.Hash.Write] never fails and -// never allows the provided buffer to escape. -type Block512 struct { - hash.Hash - - x [512 / 8]byte - nx int -} - -// New512 constructs a new Block512 that wraps h. -// -// It reports an error if the block sizes do not match. -// Misaligned block sizes perform poorly, but execute correctly. -// The error may be ignored if performance is not a concern. -func New512(h hash.Hash) (*Block512, error) { - b := &Block512{Hash: h} - if len(b.x)%h.BlockSize() != 0 { - return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) - } - return b, nil -} - -// Write hashes the contents of b. -func (h *Block512) Write(b []byte) (int, error) { - h.HashBytes(b) - return len(b), nil -} - -// Sum appends the current hash to b and returns the resulting slice. -// -// It flushes any partially completed blocks to the underlying [hash.Hash], -// which may cause future operations to be misaligned and less efficient -// until [Block512.Reset] is called. -func (h *Block512) Sum(b []byte) []byte { - if h.nx > 0 { - h.Hash.Write(h.x[:h.nx]) - h.nx = 0 - } - - // Unfortunately hash.Hash.Sum always causes the input to escape since - // escape analysis cannot prove anything past an interface method call. - // Assuming h already escapes, we call Sum with h.x first, - // and then copy the result to b. - sum := h.Hash.Sum(h.x[:0]) - return append(b, sum...) -} - -// Reset resets Block512 to its initial state. -// It recursively resets the underlying [hash.Hash]. -func (h *Block512) Reset() { - h.Hash.Reset() - h.nx = 0 -} - -// HashUint8 hashes n as a 1-byte integer. -func (h *Block512) HashUint8(n uint8) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-1 { - h.x[h.nx] = n - h.nx += 1 - } else { - h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } - -// HashUint16 hashes n as a 2-byte little-endian integer. -func (h *Block512) HashUint16(n uint16) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-2 { - binary.LittleEndian.PutUint16(h.x[h.nx:], n) - h.nx += 2 - } else { - h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } - -// HashUint32 hashes n as a 4-byte little-endian integer. -func (h *Block512) HashUint32(n uint32) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-4 { - binary.LittleEndian.PutUint32(h.x[h.nx:], n) - h.nx += 4 - } else { - h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } - -// HashUint64 hashes n as a 8-byte little-endian integer. -func (h *Block512) HashUint64(n uint64) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-8 { - binary.LittleEndian.PutUint64(h.x[h.nx:], n) - h.nx += 8 - } else { - h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } - -func (h *Block512) hashUint(n uint64, i int) { - for ; i > 0; i-- { - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - h.x[h.nx] = byte(n) - h.nx += 1 - n >>= 8 - } -} - -// HashBytes hashes the contents of b. -// It does not explicitly hash the length separately. -func (h *Block512) HashBytes(b []byte) { - // Nearly identical to sha256.digest.Write. - if h.nx > 0 { - n := copy(h.x[h.nx:], b) - h.nx += n - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - b = b[n:] - } - if len(b) >= len(h.x) { - n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) - h.Hash.Write(b[:n]) - b = b[n:] - } - if len(b) > 0 { - h.nx = copy(h.x[:], b) - } -} - -// HashString hashes the contents of s. -// It does not explicitly hash the length separately. -func (h *Block512) HashString(s string) { - // TODO: Avoid unsafe when standard hashers implement io.StringWriter. - // See https://go.dev/issue/38776. - type stringHeader struct { - p unsafe.Pointer - n int - } - p := (*stringHeader)(unsafe.Pointer(&s)) - b := unsafe.Slice((*byte)(p.p), p.n) - h.HashBytes(b) -} - -// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package hashx provides a concrete implementation of [hash.Hash] +// that operates on a particular block size. +package hashx + +import ( + "encoding/binary" + "fmt" + "hash" + "unsafe" +) + +var _ hash.Hash = (*Block512)(nil) + +// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. +// It has efficient methods for hashing fixed-width integers. +// +// A hashing algorithm that operates on 512-bit block sizes should be used. +// The hash still operates correctly even with misaligned block sizes, +// but operates less efficiently. +// +// Example algorithms with 512-bit block sizes include: +// - MD4 (https://golang.org/x/crypto/md4) +// - MD5 (https://golang.org/pkg/crypto/md5) +// - BLAKE2s (https://golang.org/x/crypto/blake2s) +// - BLAKE3 +// - RIPEMD (https://golang.org/x/crypto/ripemd160) +// - SHA-0 +// - SHA-1 (https://golang.org/pkg/crypto/sha1) +// - SHA-2 (https://golang.org/pkg/crypto/sha256) +// - Whirlpool +// +// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters +// for a list of hash functions and their block sizes. +// +// Block512 assumes that [hash.Hash.Write] never fails and +// never allows the provided buffer to escape. +type Block512 struct { + hash.Hash + + x [512 / 8]byte + nx int +} + +// New512 constructs a new Block512 that wraps h. +// +// It reports an error if the block sizes do not match. +// Misaligned block sizes perform poorly, but execute correctly. +// The error may be ignored if performance is not a concern. +func New512(h hash.Hash) (*Block512, error) { + b := &Block512{Hash: h} + if len(b.x)%h.BlockSize() != 0 { + return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) + } + return b, nil +} + +// Write hashes the contents of b. +func (h *Block512) Write(b []byte) (int, error) { + h.HashBytes(b) + return len(b), nil +} + +// Sum appends the current hash to b and returns the resulting slice. +// +// It flushes any partially completed blocks to the underlying [hash.Hash], +// which may cause future operations to be misaligned and less efficient +// until [Block512.Reset] is called. +func (h *Block512) Sum(b []byte) []byte { + if h.nx > 0 { + h.Hash.Write(h.x[:h.nx]) + h.nx = 0 + } + + // Unfortunately hash.Hash.Sum always causes the input to escape since + // escape analysis cannot prove anything past an interface method call. + // Assuming h already escapes, we call Sum with h.x first, + // and then copy the result to b. + sum := h.Hash.Sum(h.x[:0]) + return append(b, sum...) +} + +// Reset resets Block512 to its initial state. +// It recursively resets the underlying [hash.Hash]. +func (h *Block512) Reset() { + h.Hash.Reset() + h.nx = 0 +} + +// HashUint8 hashes n as a 1-byte integer. +func (h *Block512) HashUint8(n uint8) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-1 { + h.x[h.nx] = n + h.nx += 1 + } else { + h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } + +// HashUint16 hashes n as a 2-byte little-endian integer. +func (h *Block512) HashUint16(n uint16) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-2 { + binary.LittleEndian.PutUint16(h.x[h.nx:], n) + h.nx += 2 + } else { + h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } + +// HashUint32 hashes n as a 4-byte little-endian integer. +func (h *Block512) HashUint32(n uint32) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-4 { + binary.LittleEndian.PutUint32(h.x[h.nx:], n) + h.nx += 4 + } else { + h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } + +// HashUint64 hashes n as a 8-byte little-endian integer. +func (h *Block512) HashUint64(n uint64) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-8 { + binary.LittleEndian.PutUint64(h.x[h.nx:], n) + h.nx += 8 + } else { + h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } + +func (h *Block512) hashUint(n uint64, i int) { + for ; i > 0; i-- { + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + h.x[h.nx] = byte(n) + h.nx += 1 + n >>= 8 + } +} + +// HashBytes hashes the contents of b. +// It does not explicitly hash the length separately. +func (h *Block512) HashBytes(b []byte) { + // Nearly identical to sha256.digest.Write. + if h.nx > 0 { + n := copy(h.x[h.nx:], b) + h.nx += n + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + b = b[n:] + } + if len(b) >= len(h.x) { + n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) + h.Hash.Write(b[:n]) + b = b[n:] + } + if len(b) > 0 { + h.nx = copy(h.x[:], b) + } +} + +// HashString hashes the contents of s. +// It does not explicitly hash the length separately. +func (h *Block512) HashString(s string) { + // TODO: Avoid unsafe when standard hashers implement io.StringWriter. + // See https://go.dev/issue/38776. + type stringHeader struct { + p unsafe.Pointer + n int + } + p := (*stringHeader)(unsafe.Pointer(&s)) + b := unsafe.Slice((*byte)(p.p), p.n) + h.HashBytes(b) +} + +// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? diff --git a/util/httphdr/httphdr.go b/util/httphdr/httphdr.go index 852e28b8fae03..b78b165c65701 100644 --- a/util/httphdr/httphdr.go +++ b/util/httphdr/httphdr.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httphdr implements functionality for parsing and formatting -// standard HTTP headers. -package httphdr - -import ( - "bytes" - "strconv" - "strings" -) - -// Range is a range of bytes within some content. -type Range struct { - // Start is the starting offset. - // It is zero if Length is negative; it must not be negative. - Start int64 - // Length is the length of the content. - // It is zero if the length extends to the end of the content. - // It is negative if the length is relative to the end (e.g., last 5 bytes). - Length int64 -} - -// ows is optional whitespace. -const ows = " \t" // per RFC 7230, section 3.2.3 - -// ParseRange parses a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func ParseRange(hdr string) (ranges []Range, ok bool) { - // Grammar per RFC 7233, appendix D: - // Range = byte-ranges-specifier | other-ranges-specifier - // byte-ranges-specifier = bytes-unit "=" byte-range-set - // bytes-unit = "bytes" - // byte-range-set = - // *("," OWS) - // (byte-range-spec | suffix-byte-range-spec) - // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) - // byte-range-spec = first-byte-pos "-" [last-byte-pos] - // suffix-byte-range-spec = "-" suffix-length - // We do not support other-ranges-specifier. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - units, elems, hasUnits := strings.Cut(hdr, "=") - elems = strings.TrimLeft(elems, ","+ows) - for _, elem := range strings.Split(elems, ",") { - elem = strings.Trim(elem, ows) // per RFC 7230, section 7 - switch { - case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length - n, ok := parseNumber(strings.TrimPrefix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{0, -n}) - case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" - n, ok := parseNumber(strings.TrimSuffix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{n, 0}) - default: // i.e., first-byte-pos "-" last-byte-pos - prefix, suffix, hasDash := strings.Cut(elem, "-") - n, ok2 := parseNumber(prefix) - m, ok3 := parseNumber(suffix) - if !hasDash || !ok2 || !ok3 || m < n { - return ranges, false - } - ranges = append(ranges, Range{n, m - n + 1}) - } - } - return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 -} - -// FormatRange formats a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func FormatRange(ranges []Range) (hdr string, ok bool) { - b := []byte("bytes=") - for _, r := range ranges { - switch { - case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) - b = append(b, ',') - case r.Length == 0: // i.e., first-byte-pos "-" - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = append(b, ',') - case r.Length < 0: // i.e., "-" suffix-length - if r.Start != 0 { - return string(b), false - } - b = append(b, '-') - b = strconv.AppendUint(b, uint64(-r.Length), 10) - b = append(b, ',') - default: - return string(b), false - } - } - return string(bytes.TrimRight(b, ",")), len(ranges) > 0 -} - -// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If only the completeLength is specified, then start and length are both zero. -// -// Otherwise, the parses the start and length and the optional completeLength, -// which is -1 if unspecified. The start is non-negative and the length is positive. -func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { - // Grammar per RFC 7233, appendix D: - // Content-Range = byte-content-range | other-content-range - // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) - // bytes-unit = "bytes" - // byte-range-resp = byte-range "/" (complete-length | "*") - // unsatisfied-range = "*/" complete-length - // byte-range = first-byte-pos "-" last-byte-pos - // We do not support other-content-range. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") - suffix, unsatisfied := strings.CutPrefix(suffix, "*/") - if unsatisfied { // i.e., unsatisfied-range - n, ok := parseNumber(suffix) - if !ok { - return start, length, completeLength, false - } - completeLength = n - } else { // i.e., byte-range "/" (complete-length | "*") - prefix, suffix, hasDash := strings.Cut(suffix, "-") - middle, suffix, hasSlash := strings.Cut(suffix, "/") - n, ok0 := parseNumber(prefix) - m, ok1 := parseNumber(middle) - o, ok2 := parseNumber(suffix) - if suffix == "*" { - o, ok2 = -1, true - } - if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { - return start, length, completeLength, false - } - start = n - length = m - n + 1 - completeLength = o - } - return start, length, completeLength, hasUnits -} - -// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If start and length are non-positive, then it encodes just the completeLength, -// which must be a non-negative value. -// -// Otherwise, it encodes the start and length as a byte-range, -// and optionally emits the complete length if it is non-negative. -// The length must be positive (as RFC 7233 uses inclusive end offsets). -func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { - b := []byte("bytes ") - switch { - case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range - b = append(b, "*/"...) - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = true - case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") - b = strconv.AppendUint(b, uint64(start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(start+length-1), 10) - b = append(b, '/') - if completeLength >= 0 { - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = completeLength >= start+length && start+length > 0 - } else { - b = append(b, '*') - ok = true - } - } - return string(b), ok -} - -// parseNumber parses s as an unsigned decimal integer. -// It parses according to the 1*DIGIT grammar, which allows leading zeros. -func parseNumber(s string) (int64, bool) { - suffix := strings.TrimLeft(s, "0123456789") - prefix := s[:len(s)-len(suffix)] - n, err := strconv.ParseInt(prefix, 10, 64) - return n, suffix == "" && err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httphdr implements functionality for parsing and formatting +// standard HTTP headers. +package httphdr + +import ( + "bytes" + "strconv" + "strings" +) + +// Range is a range of bytes within some content. +type Range struct { + // Start is the starting offset. + // It is zero if Length is negative; it must not be negative. + Start int64 + // Length is the length of the content. + // It is zero if the length extends to the end of the content. + // It is negative if the length is relative to the end (e.g., last 5 bytes). + Length int64 +} + +// ows is optional whitespace. +const ows = " \t" // per RFC 7230, section 3.2.3 + +// ParseRange parses a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func ParseRange(hdr string) (ranges []Range, ok bool) { + // Grammar per RFC 7233, appendix D: + // Range = byte-ranges-specifier | other-ranges-specifier + // byte-ranges-specifier = bytes-unit "=" byte-range-set + // bytes-unit = "bytes" + // byte-range-set = + // *("," OWS) + // (byte-range-spec | suffix-byte-range-spec) + // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) + // byte-range-spec = first-byte-pos "-" [last-byte-pos] + // suffix-byte-range-spec = "-" suffix-length + // We do not support other-ranges-specifier. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + units, elems, hasUnits := strings.Cut(hdr, "=") + elems = strings.TrimLeft(elems, ","+ows) + for _, elem := range strings.Split(elems, ",") { + elem = strings.Trim(elem, ows) // per RFC 7230, section 7 + switch { + case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length + n, ok := parseNumber(strings.TrimPrefix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{0, -n}) + case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" + n, ok := parseNumber(strings.TrimSuffix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{n, 0}) + default: // i.e., first-byte-pos "-" last-byte-pos + prefix, suffix, hasDash := strings.Cut(elem, "-") + n, ok2 := parseNumber(prefix) + m, ok3 := parseNumber(suffix) + if !hasDash || !ok2 || !ok3 || m < n { + return ranges, false + } + ranges = append(ranges, Range{n, m - n + 1}) + } + } + return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 +} + +// FormatRange formats a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func FormatRange(ranges []Range) (hdr string, ok bool) { + b := []byte("bytes=") + for _, r := range ranges { + switch { + case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) + b = append(b, ',') + case r.Length == 0: // i.e., first-byte-pos "-" + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = append(b, ',') + case r.Length < 0: // i.e., "-" suffix-length + if r.Start != 0 { + return string(b), false + } + b = append(b, '-') + b = strconv.AppendUint(b, uint64(-r.Length), 10) + b = append(b, ',') + default: + return string(b), false + } + } + return string(bytes.TrimRight(b, ",")), len(ranges) > 0 +} + +// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If only the completeLength is specified, then start and length are both zero. +// +// Otherwise, the parses the start and length and the optional completeLength, +// which is -1 if unspecified. The start is non-negative and the length is positive. +func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { + // Grammar per RFC 7233, appendix D: + // Content-Range = byte-content-range | other-content-range + // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) + // bytes-unit = "bytes" + // byte-range-resp = byte-range "/" (complete-length | "*") + // unsatisfied-range = "*/" complete-length + // byte-range = first-byte-pos "-" last-byte-pos + // We do not support other-content-range. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") + suffix, unsatisfied := strings.CutPrefix(suffix, "*/") + if unsatisfied { // i.e., unsatisfied-range + n, ok := parseNumber(suffix) + if !ok { + return start, length, completeLength, false + } + completeLength = n + } else { // i.e., byte-range "/" (complete-length | "*") + prefix, suffix, hasDash := strings.Cut(suffix, "-") + middle, suffix, hasSlash := strings.Cut(suffix, "/") + n, ok0 := parseNumber(prefix) + m, ok1 := parseNumber(middle) + o, ok2 := parseNumber(suffix) + if suffix == "*" { + o, ok2 = -1, true + } + if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { + return start, length, completeLength, false + } + start = n + length = m - n + 1 + completeLength = o + } + return start, length, completeLength, hasUnits +} + +// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If start and length are non-positive, then it encodes just the completeLength, +// which must be a non-negative value. +// +// Otherwise, it encodes the start and length as a byte-range, +// and optionally emits the complete length if it is non-negative. +// The length must be positive (as RFC 7233 uses inclusive end offsets). +func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { + b := []byte("bytes ") + switch { + case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range + b = append(b, "*/"...) + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = true + case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") + b = strconv.AppendUint(b, uint64(start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(start+length-1), 10) + b = append(b, '/') + if completeLength >= 0 { + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = completeLength >= start+length && start+length > 0 + } else { + b = append(b, '*') + ok = true + } + } + return string(b), ok +} + +// parseNumber parses s as an unsigned decimal integer. +// It parses according to the 1*DIGIT grammar, which allows leading zeros. +func parseNumber(s string) (int64, bool) { + suffix := strings.TrimLeft(s, "0123456789") + prefix := s[:len(s)-len(suffix)] + n, err := strconv.ParseInt(prefix, 10, 64) + return n, suffix == "" && err == nil +} diff --git a/util/httphdr/httphdr_test.go b/util/httphdr/httphdr_test.go index 81feeaca080d8..77ec0c3247d3e 100644 --- a/util/httphdr/httphdr_test.go +++ b/util/httphdr/httphdr_test.go @@ -1,96 +1,96 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httphdr - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func valOk[T any](v T, ok bool) (out struct { - V T - Ok bool -}) { - out.V = v - out.Ok = ok - return out -} - -func TestRange(t *testing.T) { - tests := []struct { - in string - want []Range - wantOk bool - roundtrip bool - }{ - {"", nil, false, false}, - {"1-3", nil, false, false}, - {"units=1-3", []Range{{1, 3}}, false, false}, - {"bytes=1-3", []Range{{1, 3}}, true, true}, - {"bytes=#-3", nil, false, false}, - {"bytes=#-", nil, false, false}, - {"bytes=13", nil, false, false}, - {"bytes=1-#", nil, false, false}, - {"bytes=-#", nil, false, false}, - {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, - {"bytes=1-1", []Range{{1, 1}}, true, true}, - {"bytes=01-01", []Range{{1, 1}}, true, false}, - {"bytes=1-0", nil, false, false}, - {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, - {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, - {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, - } - - for _, tt := range tests { - got, gotOk := ParseRange(tt.in) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) - } - if tt.roundtrip { - got, gotOk := FormatRange(tt.want) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) - } - } - } -} - -type contentRange struct{ Start, Length, CompleteLength int64 } - -func TestContentRange(t *testing.T) { - tests := []struct { - in string - want contentRange - wantOk bool - roundtrip bool - }{ - {"", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, - {"units 5-6/*", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{}, false, false}, - {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, - {"bytes 5-4/*", contentRange{}, false, false}, - {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, - {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, - {"bytes 5-5/5", contentRange{}, false, false}, - {"bytes #-5/6", contentRange{}, false, false}, - {"bytes 5-#/6", contentRange{}, false, false}, - {"bytes 5-5/#", contentRange{}, false, false}, - } - - for _, tt := range tests { - start, length, completeLength, gotOk := ParseContentRange(tt.in) - got := contentRange{start, length, completeLength} - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) - } - if tt.roundtrip { - got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httphdr + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func valOk[T any](v T, ok bool) (out struct { + V T + Ok bool +}) { + out.V = v + out.Ok = ok + return out +} + +func TestRange(t *testing.T) { + tests := []struct { + in string + want []Range + wantOk bool + roundtrip bool + }{ + {"", nil, false, false}, + {"1-3", nil, false, false}, + {"units=1-3", []Range{{1, 3}}, false, false}, + {"bytes=1-3", []Range{{1, 3}}, true, true}, + {"bytes=#-3", nil, false, false}, + {"bytes=#-", nil, false, false}, + {"bytes=13", nil, false, false}, + {"bytes=1-#", nil, false, false}, + {"bytes=-#", nil, false, false}, + {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, + {"bytes=1-1", []Range{{1, 1}}, true, true}, + {"bytes=01-01", []Range{{1, 1}}, true, false}, + {"bytes=1-0", nil, false, false}, + {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, + {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, + {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, + } + + for _, tt := range tests { + got, gotOk := ParseRange(tt.in) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) + } + if tt.roundtrip { + got, gotOk := FormatRange(tt.want) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) + } + } + } +} + +type contentRange struct{ Start, Length, CompleteLength int64 } + +func TestContentRange(t *testing.T) { + tests := []struct { + in string + want contentRange + wantOk bool + roundtrip bool + }{ + {"", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, + {"units 5-6/*", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{}, false, false}, + {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, + {"bytes 5-4/*", contentRange{}, false, false}, + {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, + {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, + {"bytes 5-5/5", contentRange{}, false, false}, + {"bytes #-5/6", contentRange{}, false, false}, + {"bytes 5-#/6", contentRange{}, false, false}, + {"bytes 5-5/#", contentRange{}, false, false}, + } + + for _, tt := range tests { + start, length, completeLength, gotOk := ParseContentRange(tt.in) + got := contentRange{start, length, completeLength} + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) + } + if tt.roundtrip { + got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) + } + } + } +} diff --git a/util/httpm/httpm.go b/util/httpm/httpm.go index a9a691b8a69e2..05292f0fa1fa2 100644 --- a/util/httpm/httpm.go +++ b/util/httpm/httpm.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httpm has shorter names for HTTP method constants. -// -// Some background: originally Go didn't have http.MethodGet, http.MethodPost -// and life was good and people just wrote readable "GET" and "POST". But then -// in a moment of weakness Brad and others maintaining net/http caved and let -// the http.MethodFoo constants be added and code's been less readable since. -// Now the substance of the method name is hidden away at the end after -// "http.Method" and they all blend together and it's hard to read code using -// them. -// -// This package is a compromise. It provides constants, but shorter and closer -// to how it used to look. It does violate Go style -// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says -// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and -// POST, which are already defined as all caps. -// -// It would be tempting to make these constants be typed but then they wouldn't -// be assignable to things in net/http that just want string. Oh well. -package httpm - -const ( - GET = "GET" - HEAD = "HEAD" - POST = "POST" - PUT = "PUT" - PATCH = "PATCH" - DELETE = "DELETE" - CONNECT = "CONNECT" - OPTIONS = "OPTIONS" - TRACE = "TRACE" - SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html - BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httpm has shorter names for HTTP method constants. +// +// Some background: originally Go didn't have http.MethodGet, http.MethodPost +// and life was good and people just wrote readable "GET" and "POST". But then +// in a moment of weakness Brad and others maintaining net/http caved and let +// the http.MethodFoo constants be added and code's been less readable since. +// Now the substance of the method name is hidden away at the end after +// "http.Method" and they all blend together and it's hard to read code using +// them. +// +// This package is a compromise. It provides constants, but shorter and closer +// to how it used to look. It does violate Go style +// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says +// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and +// POST, which are already defined as all caps. +// +// It would be tempting to make these constants be typed but then they wouldn't +// be assignable to things in net/http that just want string. Oh well. +package httpm + +const ( + GET = "GET" + HEAD = "HEAD" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + CONNECT = "CONNECT" + OPTIONS = "OPTIONS" + TRACE = "TRACE" + SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html + BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 +) diff --git a/util/httpm/httpm_test.go b/util/httpm/httpm_test.go index 0c71edc2f3c42..cbe327d956083 100644 --- a/util/httpm/httpm_test.go +++ b/util/httpm/httpm_test.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httpm - -import ( - "os" - "os/exec" - "path/filepath" - "strings" - "testing" -) - -func TestUsedConsistently(t *testing.T) { - dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - rootDir := filepath.Join(dir, "../..") - - // If we don't have a .git directory, we're not in a git checkout (e.g. - // a downstream package); skip this test. - if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { - t.Skipf("skipping test since .git doesn't exist: %v", err) - } - - cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") - cmd.Dir = rootDir - matches, _ := cmd.Output() - for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { - switch fn { - case "util/httpm/httpm.go", "util/httpm/httpm_test.go": - continue - } - t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httpm + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestUsedConsistently(t *testing.T) { + dir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + rootDir := filepath.Join(dir, "../..") + + // If we don't have a .git directory, we're not in a git checkout (e.g. + // a downstream package); skip this test. + if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { + t.Skipf("skipping test since .git doesn't exist: %v", err) + } + + cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") + cmd.Dir = rootDir + matches, _ := cmd.Output() + for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { + switch fn { + case "util/httpm/httpm.go", "util/httpm/httpm_test.go": + continue + } + t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) + } +} diff --git a/util/jsonutil/types.go b/util/jsonutil/types.go index 057473249f258..2ee53f44a1037 100644 --- a/util/jsonutil/types.go +++ b/util/jsonutil/types.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsonutil - -// Bytes is a byte slice in a json-encoded struct. -// encoding/json assumes that []byte fields are hex-encoded. -// Bytes are not hex-encoded; they are treated the same as strings. -// This can avoid unnecessary allocations due to a round trip through strings. -type Bytes []byte - -func (b *Bytes) UnmarshalText(text []byte) error { - // Copy the contexts of text. - *b = append(*b, text...) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonutil + +// Bytes is a byte slice in a json-encoded struct. +// encoding/json assumes that []byte fields are hex-encoded. +// Bytes are not hex-encoded; they are treated the same as strings. +// This can avoid unnecessary allocations due to a round trip through strings. +type Bytes []byte + +func (b *Bytes) UnmarshalText(text []byte) error { + // Copy the contexts of text. + *b = append(*b, text...) + return nil +} diff --git a/util/jsonutil/unmarshal.go b/util/jsonutil/unmarshal.go index b1eb4ea873e67..13aea0c87ff30 100644 --- a/util/jsonutil/unmarshal.go +++ b/util/jsonutil/unmarshal.go @@ -1,89 +1,89 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsonutil provides utilities to improve JSON performance. -// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs -// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. -package jsonutil - -import ( - "bytes" - "encoding/json" - "sync" -) - -// decoder is a re-usable json decoder. -type decoder struct { - dec *json.Decoder - r *bytes.Reader -} - -var readerPool = sync.Pool{ - New: func() any { - return bytes.NewReader(nil) - }, -} - -var decoderPool = sync.Pool{ - New: func() any { - var d decoder - d.r = readerPool.Get().(*bytes.Reader) - d.dec = json.NewDecoder(d.r) - return &d - }, -} - -// Unmarshal is similar to encoding/json.Unmarshal. -// There are three major differences: -// -// On error, encoding/json.Unmarshal zeros v. -// This Unmarshal may leave partial data in v. -// Always check the error before using v! -// (Future improvements may remove this bug.) -// -// The errors they return don't always match perfectly. -// If you do error matching more precise than err != nil, -// don't use this Unmarshal. -// -// This Unmarshal allocates considerably less memory. -func Unmarshal(b []byte, v any) error { - d := decoderPool.Get().(*decoder) - d.r.Reset(b) - off := d.dec.InputOffset() - err := d.dec.Decode(v) - d.r.Reset(nil) // don't keep a reference to b - // In case of error, report the offset in this byte slice, - // instead of in the totality of all bytes this decoder has processed. - // It is not possible to make all errors match json.Unmarshal exactly, - // but we can at least try. - switch jsonerr := err.(type) { - case *json.SyntaxError: - jsonerr.Offset -= off - case *json.UnmarshalTypeError: - jsonerr.Offset -= off - case nil: - // json.Unmarshal fails if there's any extra junk in the input. - // json.Decoder does not; see https://github.com/golang/go/issues/36225. - // We need to check for anything left over in the buffer. - if d.dec.More() { - // TODO: Provide a better error message. - // Unfortunately, we can't set the msg field. - // The offset doesn't perfectly match json: - // Ours is at the end of the valid data, - // and theirs is at the beginning of the extra data after whitespace. - // Close enough, though. - err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} - - // TODO: zero v. This is hard; see encoding/json.indirect. - } - } - if err == nil { - decoderPool.Put(d) - } else { - // There might be junk left in the decoder's buffer. - // There's no way to flush it, no Reset method. - // Abandoned the decoder but reuse the reader. - readerPool.Put(d.r) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonutil provides utilities to improve JSON performance. +// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs +// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. +package jsonutil + +import ( + "bytes" + "encoding/json" + "sync" +) + +// decoder is a re-usable json decoder. +type decoder struct { + dec *json.Decoder + r *bytes.Reader +} + +var readerPool = sync.Pool{ + New: func() any { + return bytes.NewReader(nil) + }, +} + +var decoderPool = sync.Pool{ + New: func() any { + var d decoder + d.r = readerPool.Get().(*bytes.Reader) + d.dec = json.NewDecoder(d.r) + return &d + }, +} + +// Unmarshal is similar to encoding/json.Unmarshal. +// There are three major differences: +// +// On error, encoding/json.Unmarshal zeros v. +// This Unmarshal may leave partial data in v. +// Always check the error before using v! +// (Future improvements may remove this bug.) +// +// The errors they return don't always match perfectly. +// If you do error matching more precise than err != nil, +// don't use this Unmarshal. +// +// This Unmarshal allocates considerably less memory. +func Unmarshal(b []byte, v any) error { + d := decoderPool.Get().(*decoder) + d.r.Reset(b) + off := d.dec.InputOffset() + err := d.dec.Decode(v) + d.r.Reset(nil) // don't keep a reference to b + // In case of error, report the offset in this byte slice, + // instead of in the totality of all bytes this decoder has processed. + // It is not possible to make all errors match json.Unmarshal exactly, + // but we can at least try. + switch jsonerr := err.(type) { + case *json.SyntaxError: + jsonerr.Offset -= off + case *json.UnmarshalTypeError: + jsonerr.Offset -= off + case nil: + // json.Unmarshal fails if there's any extra junk in the input. + // json.Decoder does not; see https://github.com/golang/go/issues/36225. + // We need to check for anything left over in the buffer. + if d.dec.More() { + // TODO: Provide a better error message. + // Unfortunately, we can't set the msg field. + // The offset doesn't perfectly match json: + // Ours is at the end of the valid data, + // and theirs is at the beginning of the extra data after whitespace. + // Close enough, though. + err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} + + // TODO: zero v. This is hard; see encoding/json.indirect. + } + } + if err == nil { + decoderPool.Put(d) + } else { + // There might be junk left in the decoder's buffer. + // There's no way to flush it, no Reset method. + // Abandoned the decoder but reuse the reader. + readerPool.Put(d.r) + } + return err +} diff --git a/util/lineread/lineread.go b/util/lineread/lineread.go index 6b01d2b69ffd7..2a7486e0a4fec 100644 --- a/util/lineread/lineread.go +++ b/util/lineread/lineread.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package lineread reads lines from files. It's not fancy, but it got repetitive. -package lineread - -import ( - "bufio" - "io" - "os" -) - -// File opens name and calls fn for each line. It returns an error if the Open failed -// or once fn returns an error. -func File(name string, fn func(line []byte) error) error { - f, err := os.Open(name) - if err != nil { - return err - } - defer f.Close() - return Reader(f, fn) -} - -// Reader calls fn for each line. -// If fn returns an error, Reader stops reading and returns that error. -// Reader may also return errors encountered reading and parsing from r. -// To stop reading early, use a sentinel "stop" error value and ignore -// it when returned from Reader. -func Reader(r io.Reader, fn func(line []byte) error) error { - bs := bufio.NewScanner(r) - for bs.Scan() { - if err := fn(bs.Bytes()); err != nil { - return err - } - } - return bs.Err() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineread reads lines from files. It's not fancy, but it got repetitive. +package lineread + +import ( + "bufio" + "io" + "os" +) + +// File opens name and calls fn for each line. It returns an error if the Open failed +// or once fn returns an error. +func File(name string, fn func(line []byte) error) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + return Reader(f, fn) +} + +// Reader calls fn for each line. +// If fn returns an error, Reader stops reading and returns that error. +// Reader may also return errors encountered reading and parsing from r. +// To stop reading early, use a sentinel "stop" error value and ignore +// it when returned from Reader. +func Reader(r io.Reader, fn func(line []byte) error) error { + bs := bufio.NewScanner(r) + for bs.Scan() { + if err := fn(bs.Bytes()); err != nil { + return err + } + } + return bs.Err() +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go index ee2cbd1b227f4..04f179199fb6b 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && linux - -// Package linuxfwtest contains tests for the linuxfw package. Go does not -// support cgo in tests, and we don't want the main package to have a cgo -// dependency, so we put all the tests here and call them from the main package -// in tests intead. -package linuxfwtest - -import ( - "testing" - "unsafe" -) - -/* -#include // socket() -*/ -import "C" - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - want := unsafe.Sizeof(C.socklen_t(0)) - if want != si.SizeofSocklen { - t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && linux + +// Package linuxfwtest contains tests for the linuxfw package. Go does not +// support cgo in tests, and we don't want the main package to have a cgo +// dependency, so we put all the tests here and call them from the main package +// in tests intead. +package linuxfwtest + +import ( + "testing" + "unsafe" +) + +/* +#include // socket() +*/ +import "C" + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + want := unsafe.Sizeof(C.socklen_t(0)) + if want != si.SizeofSocklen { + t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) + } +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go index 6e95699001d4b..d5e297da7b965 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !cgo || !linux - -package linuxfwtest - -import ( - "testing" -) - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - t.Skip("not supported without cgo") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !cgo || !linux + +package linuxfwtest + +import ( + "testing" +) + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + t.Skip("not supported without cgo") +} diff --git a/util/linuxfw/nftables_types.go b/util/linuxfw/nftables_types.go index b6e24d2a67b5b..a8c5a0730dbd3 100644 --- a/util/linuxfw/nftables_types.go +++ b/util/linuxfw/nftables_types.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// TODO(#8502): add support for more architectures -//go:build linux && (arm64 || amd64) - -package linuxfw - -import ( - "github.com/google/nftables/expr" - "github.com/google/nftables/xt" -) - -var metaKeyNames = map[expr.MetaKey]string{ - expr.MetaKeyLEN: "LEN", - expr.MetaKeyPROTOCOL: "PROTOCOL", - expr.MetaKeyPRIORITY: "PRIORITY", - expr.MetaKeyMARK: "MARK", - expr.MetaKeyIIF: "IIF", - expr.MetaKeyOIF: "OIF", - expr.MetaKeyIIFNAME: "IIFNAME", - expr.MetaKeyOIFNAME: "OIFNAME", - expr.MetaKeyIIFTYPE: "IIFTYPE", - expr.MetaKeyOIFTYPE: "OIFTYPE", - expr.MetaKeySKUID: "SKUID", - expr.MetaKeySKGID: "SKGID", - expr.MetaKeyNFTRACE: "NFTRACE", - expr.MetaKeyRTCLASSID: "RTCLASSID", - expr.MetaKeySECMARK: "SECMARK", - expr.MetaKeyNFPROTO: "NFPROTO", - expr.MetaKeyL4PROTO: "L4PROTO", - expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", - expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", - expr.MetaKeyPKTTYPE: "PKTTYPE", - expr.MetaKeyCPU: "CPU", - expr.MetaKeyIIFGROUP: "IIFGROUP", - expr.MetaKeyOIFGROUP: "OIFGROUP", - expr.MetaKeyCGROUP: "CGROUP", - expr.MetaKeyPRANDOM: "PRANDOM", -} - -var cmpOpNames = map[expr.CmpOp]string{ - expr.CmpOpEq: "EQ", - expr.CmpOpNeq: "NEQ", - expr.CmpOpLt: "LT", - expr.CmpOpLte: "LTE", - expr.CmpOpGt: "GT", - expr.CmpOpGte: "GTE", -} - -var verdictNames = map[expr.VerdictKind]string{ - expr.VerdictReturn: "RETURN", - expr.VerdictGoto: "GOTO", - expr.VerdictJump: "JUMP", - expr.VerdictBreak: "BREAK", - expr.VerdictContinue: "CONTINUE", - expr.VerdictDrop: "DROP", - expr.VerdictAccept: "ACCEPT", - expr.VerdictStolen: "STOLEN", - expr.VerdictQueue: "QUEUE", - expr.VerdictRepeat: "REPEAT", - expr.VerdictStop: "STOP", -} - -var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ - expr.PayloadLoad: "LOAD", - expr.PayloadWrite: "WRITE", -} - -var payloadBaseNames = map[expr.PayloadBase]string{ - expr.PayloadBaseLLHeader: "ll-header", - expr.PayloadBaseNetworkHeader: "network-header", - expr.PayloadBaseTransportHeader: "transport-header", -} - -var packetTypeNames = map[int]string{ - 0 /* PACKET_HOST */ : "unicast", - 1 /* PACKET_BROADCAST */ : "broadcast", - 2 /* PACKET_MULTICAST */ : "multicast", -} - -var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ - xt.AddrTypeUnspec: "unspec", - xt.AddrTypeUnicast: "unicast", - xt.AddrTypeLocal: "local", - xt.AddrTypeBroadcast: "broadcast", - xt.AddrTypeAnycast: "anycast", - xt.AddrTypeMulticast: "multicast", - xt.AddrTypeBlackhole: "blackhole", - xt.AddrTypeUnreachable: "unreachable", - xt.AddrTypeProhibit: "prohibit", - xt.AddrTypeThrow: "throw", - xt.AddrTypeNat: "nat", - xt.AddrTypeXresolve: "xresolve", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(#8502): add support for more architectures +//go:build linux && (arm64 || amd64) + +package linuxfw + +import ( + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" +) + +var metaKeyNames = map[expr.MetaKey]string{ + expr.MetaKeyLEN: "LEN", + expr.MetaKeyPROTOCOL: "PROTOCOL", + expr.MetaKeyPRIORITY: "PRIORITY", + expr.MetaKeyMARK: "MARK", + expr.MetaKeyIIF: "IIF", + expr.MetaKeyOIF: "OIF", + expr.MetaKeyIIFNAME: "IIFNAME", + expr.MetaKeyOIFNAME: "OIFNAME", + expr.MetaKeyIIFTYPE: "IIFTYPE", + expr.MetaKeyOIFTYPE: "OIFTYPE", + expr.MetaKeySKUID: "SKUID", + expr.MetaKeySKGID: "SKGID", + expr.MetaKeyNFTRACE: "NFTRACE", + expr.MetaKeyRTCLASSID: "RTCLASSID", + expr.MetaKeySECMARK: "SECMARK", + expr.MetaKeyNFPROTO: "NFPROTO", + expr.MetaKeyL4PROTO: "L4PROTO", + expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", + expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", + expr.MetaKeyPKTTYPE: "PKTTYPE", + expr.MetaKeyCPU: "CPU", + expr.MetaKeyIIFGROUP: "IIFGROUP", + expr.MetaKeyOIFGROUP: "OIFGROUP", + expr.MetaKeyCGROUP: "CGROUP", + expr.MetaKeyPRANDOM: "PRANDOM", +} + +var cmpOpNames = map[expr.CmpOp]string{ + expr.CmpOpEq: "EQ", + expr.CmpOpNeq: "NEQ", + expr.CmpOpLt: "LT", + expr.CmpOpLte: "LTE", + expr.CmpOpGt: "GT", + expr.CmpOpGte: "GTE", +} + +var verdictNames = map[expr.VerdictKind]string{ + expr.VerdictReturn: "RETURN", + expr.VerdictGoto: "GOTO", + expr.VerdictJump: "JUMP", + expr.VerdictBreak: "BREAK", + expr.VerdictContinue: "CONTINUE", + expr.VerdictDrop: "DROP", + expr.VerdictAccept: "ACCEPT", + expr.VerdictStolen: "STOLEN", + expr.VerdictQueue: "QUEUE", + expr.VerdictRepeat: "REPEAT", + expr.VerdictStop: "STOP", +} + +var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ + expr.PayloadLoad: "LOAD", + expr.PayloadWrite: "WRITE", +} + +var payloadBaseNames = map[expr.PayloadBase]string{ + expr.PayloadBaseLLHeader: "ll-header", + expr.PayloadBaseNetworkHeader: "network-header", + expr.PayloadBaseTransportHeader: "transport-header", +} + +var packetTypeNames = map[int]string{ + 0 /* PACKET_HOST */ : "unicast", + 1 /* PACKET_BROADCAST */ : "broadcast", + 2 /* PACKET_MULTICAST */ : "multicast", +} + +var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ + xt.AddrTypeUnspec: "unspec", + xt.AddrTypeUnicast: "unicast", + xt.AddrTypeLocal: "local", + xt.AddrTypeBroadcast: "broadcast", + xt.AddrTypeAnycast: "anycast", + xt.AddrTypeMulticast: "multicast", + xt.AddrTypeBlackhole: "blackhole", + xt.AddrTypeUnreachable: "unreachable", + xt.AddrTypeProhibit: "prohibit", + xt.AddrTypeThrow: "throw", + xt.AddrTypeNat: "nat", + xt.AddrTypeXresolve: "xresolve", +} diff --git a/util/mak/mak.go b/util/mak/mak.go index b421fb0ed5a55..b0d64daa422d4 100644 --- a/util/mak/mak.go +++ b/util/mak/mak.go @@ -1,70 +1,70 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak helps make maps. It contains generic helpers to make/assign -// things, notably to maps, but also slices. -package mak - -import ( - "fmt" - "reflect" -) - -// Set populates an entry in a map, making the map if necessary. -// -// That is, it assigns (*m)[k] = v, making *m if it was nil. -func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { - if *m == nil { - *m = make(map[K]V) - } - (*m)[k] = v -} - -// NonNil takes a pointer to a Go data structure -// (currently only a slice or a map) and makes sure it's non-nil for -// JSON serialization. (In particular, JavaScript clients usually want -// the field to be defined after they decode the JSON.) -// -// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. -func NonNil(ptr any) { - if ptr == nil { - panic("nil interface") - } - rv := reflect.ValueOf(ptr) - if rv.Kind() != reflect.Ptr { - panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) - } - if rv.Pointer() == 0 { - panic("nil pointer") - } - rv = rv.Elem() - if rv.Pointer() != 0 { - return - } - switch rv.Type().Kind() { - case reflect.Slice: - rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) - case reflect.Map: - rv.Set(reflect.MakeMap(rv.Type())) - } -} - -// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { - if *slicePtr != nil { - return - } - *slicePtr = make([]T, 0) -} - -// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { - if *mapPtr != nil { - return - } - *mapPtr = make(M) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak helps make maps. It contains generic helpers to make/assign +// things, notably to maps, but also slices. +package mak + +import ( + "fmt" + "reflect" +) + +// Set populates an entry in a map, making the map if necessary. +// +// That is, it assigns (*m)[k] = v, making *m if it was nil. +func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { + if *m == nil { + *m = make(map[K]V) + } + (*m)[k] = v +} + +// NonNil takes a pointer to a Go data structure +// (currently only a slice or a map) and makes sure it's non-nil for +// JSON serialization. (In particular, JavaScript clients usually want +// the field to be defined after they decode the JSON.) +// +// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. +func NonNil(ptr any) { + if ptr == nil { + panic("nil interface") + } + rv := reflect.ValueOf(ptr) + if rv.Kind() != reflect.Ptr { + panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) + } + if rv.Pointer() == 0 { + panic("nil pointer") + } + rv = rv.Elem() + if rv.Pointer() != 0 { + return + } + switch rv.Type().Kind() { + case reflect.Slice: + rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) + case reflect.Map: + rv.Set(reflect.MakeMap(rv.Type())) + } +} + +// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { + if *slicePtr != nil { + return + } + *slicePtr = make([]T, 0) +} + +// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { + if *mapPtr != nil { + return + } + *mapPtr = make(M) +} diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go index 4de499a9d5040..dc1d7e93d7b19 100644 --- a/util/mak/mak_test.go +++ b/util/mak/mak_test.go @@ -1,88 +1,88 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak contains code to help make things. -package mak - -import ( - "reflect" - "testing" -) - -type M map[string]int - -func TestSet(t *testing.T) { - t.Run("unnamed", func(t *testing.T) { - var m map[string]int - Set(&m, "foo", 42) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := map[string]int{ - "foo": 42, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) - t.Run("named", func(t *testing.T) { - var m M - Set(&m, "foo", 1) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := M{ - "foo": 1, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) -} - -func TestNonNil(t *testing.T) { - var s []string - NonNil(&s) - if len(s) != 0 { - t.Errorf("slice len = %d; want 0", len(s)) - } - if s == nil { - t.Error("slice still nil") - } - - s = append(s, "foo") - NonNil(&s) - if len(s) != 1 { - t.Errorf("len = %d; want 1", len(s)) - } - if s[0] != "foo" { - t.Errorf("value = %q; want foo", s) - } - - var m map[string]string - NonNil(&m) - if len(m) != 0 { - t.Errorf("map len = %d; want 0", len(s)) - } - if m == nil { - t.Error("map still nil") - } -} - -func TestNonNilMapForJSON(t *testing.T) { - type M map[string]int - var m M - NonNilMapForJSON(&m) - if m == nil { - t.Fatal("still nil") - } -} - -func TestNonNilSliceForJSON(t *testing.T) { - type S []int - var s S - NonNilSliceForJSON(&s) - if s == nil { - t.Fatal("still nil") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak contains code to help make things. +package mak + +import ( + "reflect" + "testing" +) + +type M map[string]int + +func TestSet(t *testing.T) { + t.Run("unnamed", func(t *testing.T) { + var m map[string]int + Set(&m, "foo", 42) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := map[string]int{ + "foo": 42, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) + t.Run("named", func(t *testing.T) { + var m M + Set(&m, "foo", 1) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := M{ + "foo": 1, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) +} + +func TestNonNil(t *testing.T) { + var s []string + NonNil(&s) + if len(s) != 0 { + t.Errorf("slice len = %d; want 0", len(s)) + } + if s == nil { + t.Error("slice still nil") + } + + s = append(s, "foo") + NonNil(&s) + if len(s) != 1 { + t.Errorf("len = %d; want 1", len(s)) + } + if s[0] != "foo" { + t.Errorf("value = %q; want foo", s) + } + + var m map[string]string + NonNil(&m) + if len(m) != 0 { + t.Errorf("map len = %d; want 0", len(s)) + } + if m == nil { + t.Error("map still nil") + } +} + +func TestNonNilMapForJSON(t *testing.T) { + type M map[string]int + var m M + NonNilMapForJSON(&m) + if m == nil { + t.Fatal("still nil") + } +} + +func TestNonNilSliceForJSON(t *testing.T) { + type S []int + var s S + NonNilSliceForJSON(&s) + if s == nil { + t.Fatal("still nil") + } +} diff --git a/util/multierr/multierr.go b/util/multierr/multierr.go index 93ca068f56532..5ec36f644b73c 100644 --- a/util/multierr/multierr.go +++ b/util/multierr/multierr.go @@ -1,136 +1,136 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package multierr provides a simple multiple-error type. -// It was inspired by github.com/go-multierror/multierror. -package multierr - -import ( - "errors" - "slices" - "strings" -) - -// An Error represents multiple errors. -type Error struct { - errs []error -} - -// Error implements the error interface. -func (e Error) Error() string { - s := new(strings.Builder) - s.WriteString("multiple errors:") - for _, err := range e.errs { - s.WriteString("\n\t") - s.WriteString(err.Error()) - } - return s.String() -} - -// Errors returns a slice containing all errors in e. -func (e Error) Errors() []error { - return slices.Clone(e.errs) -} - -// Unwrap returns the underlying errors as-is. -func (e Error) Unwrap() []error { - // Do not clone since Unwrap requires callers to not mutate the slice. - // See the documentation in the Go "errors" package. - return e.errs -} - -// New returns an error composed from errs. -// Some errors in errs get special treatment: -// - nil errors are discarded -// - errors of type Error are expanded into the top level -// -// If the resulting slice has length 0, New returns nil. -// If the resulting slice has length 1, New returns that error. -// If the resulting slice has length > 1, New returns that slice as an Error. -func New(errs ...error) error { - // First count the number of errors to avoid allocating. - var n int - var errFirst error - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - n += len(e.errs) - if errFirst == nil && len(e.errs) > 0 { - errFirst = e.errs[0] - } - default: - n++ - if errFirst == nil { - errFirst = e - } - } - } - if n <= 1 { - return errFirst // nil if n == 0 - } - - // More than one error, allocate slice and construct the multi-error. - dst := make([]error, 0, n) - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - dst = append(dst, e.errs...) - default: - dst = append(dst, e) - } - } - return Error{errs: dst} -} - -// Is reports whether any error in e matches target. -func (e Error) Is(target error) bool { - for _, err := range e.errs { - if errors.Is(err, target) { - return true - } - } - return false -} - -// As finds the first error in e that matches target, and if any is found, -// sets target to that error value and returns true. Otherwise, it returns false. -func (e Error) As(target any) bool { - for _, err := range e.errs { - if ok := errors.As(err, target); ok { - return true - } - } - return false -} - -// Range performs a pre-order, depth-first iteration of the error tree -// by successively unwrapping all error values. -// For each iteration it calls fn with the current error value and -// stops iteration if it ever reports false. -func Range(err error, fn func(error) bool) bool { - if err == nil { - return true - } - if !fn(err) { - return false - } - switch err := err.(type) { - case interface{ Unwrap() error }: - if err := err.Unwrap(); err != nil { - if !Range(err, fn) { - return false - } - } - case interface{ Unwrap() []error }: - for _, err := range err.Unwrap() { - if !Range(err, fn) { - return false - } - } - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package multierr provides a simple multiple-error type. +// It was inspired by github.com/go-multierror/multierror. +package multierr + +import ( + "errors" + "slices" + "strings" +) + +// An Error represents multiple errors. +type Error struct { + errs []error +} + +// Error implements the error interface. +func (e Error) Error() string { + s := new(strings.Builder) + s.WriteString("multiple errors:") + for _, err := range e.errs { + s.WriteString("\n\t") + s.WriteString(err.Error()) + } + return s.String() +} + +// Errors returns a slice containing all errors in e. +func (e Error) Errors() []error { + return slices.Clone(e.errs) +} + +// Unwrap returns the underlying errors as-is. +func (e Error) Unwrap() []error { + // Do not clone since Unwrap requires callers to not mutate the slice. + // See the documentation in the Go "errors" package. + return e.errs +} + +// New returns an error composed from errs. +// Some errors in errs get special treatment: +// - nil errors are discarded +// - errors of type Error are expanded into the top level +// +// If the resulting slice has length 0, New returns nil. +// If the resulting slice has length 1, New returns that error. +// If the resulting slice has length > 1, New returns that slice as an Error. +func New(errs ...error) error { + // First count the number of errors to avoid allocating. + var n int + var errFirst error + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + n += len(e.errs) + if errFirst == nil && len(e.errs) > 0 { + errFirst = e.errs[0] + } + default: + n++ + if errFirst == nil { + errFirst = e + } + } + } + if n <= 1 { + return errFirst // nil if n == 0 + } + + // More than one error, allocate slice and construct the multi-error. + dst := make([]error, 0, n) + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + dst = append(dst, e.errs...) + default: + dst = append(dst, e) + } + } + return Error{errs: dst} +} + +// Is reports whether any error in e matches target. +func (e Error) Is(target error) bool { + for _, err := range e.errs { + if errors.Is(err, target) { + return true + } + } + return false +} + +// As finds the first error in e that matches target, and if any is found, +// sets target to that error value and returns true. Otherwise, it returns false. +func (e Error) As(target any) bool { + for _, err := range e.errs { + if ok := errors.As(err, target); ok { + return true + } + } + return false +} + +// Range performs a pre-order, depth-first iteration of the error tree +// by successively unwrapping all error values. +// For each iteration it calls fn with the current error value and +// stops iteration if it ever reports false. +func Range(err error, fn func(error) bool) bool { + if err == nil { + return true + } + if !fn(err) { + return false + } + switch err := err.(type) { + case interface{ Unwrap() error }: + if err := err.Unwrap(); err != nil { + if !Range(err, fn) { + return false + } + } + case interface{ Unwrap() []error }: + for _, err := range err.Unwrap() { + if !Range(err, fn) { + return false + } + } + } + return true +} diff --git a/util/must/must.go b/util/must/must.go index 21965daa9b038..056986fcac915 100644 --- a/util/must/must.go +++ b/util/must/must.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package must assists in calling functions that must succeed. -// -// Example usage: -// -// var target = must.Get(url.Parse(...)) -// must.Do(close()) -package must - -// Do panics if err is non-nil. -func Do(err error) { - if err != nil { - panic(err) - } -} - -// Get returns v as is. It panics if err is non-nil. -func Get[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package must assists in calling functions that must succeed. +// +// Example usage: +// +// var target = must.Get(url.Parse(...)) +// must.Do(close()) +package must + +// Do panics if err is non-nil. +func Do(err error) { + if err != nil { + panic(err) + } +} + +// Get returns v as is. It panics if err is non-nil. +func Get[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} diff --git a/util/osdiag/mksyscall.go b/util/osdiag/mksyscall.go index bcbe113b051cd..f20be7f92da7f 100644 --- a/util/osdiag/mksyscall.go +++ b/util/osdiag/mksyscall.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx -//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW -//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols -//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo -//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx +//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols +//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo +//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath diff --git a/util/osdiag/osdiag_windows_test.go b/util/osdiag/osdiag_windows_test.go index b29b602ccb73c..776852a345f2b 100644 --- a/util/osdiag/osdiag_windows_test.go +++ b/util/osdiag/osdiag_windows_test.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -import ( - "errors" - "fmt" - "maps" - "strings" - "testing" - - "golang.org/x/sys/windows/registry" -) - -func makeLongBinaryValue() []byte { - buf := make([]byte, maxBinaryValueLen*2) - for i, _ := range buf { - buf[i] = byte(i % 0xFF) - } - return buf -} - -var testData = map[string]any{ - "": "I am the default", - "StringEmpty": "", - "StringShort": "Hello", - "StringLong": strings.Repeat("7", initialValueBufLen+1), - "MultiStringEmpty": []string{}, - "MultiStringSingle": []string{"Foo"}, - "MultiStringSingleEmpty": []string{""}, - "MultiString": []string{"Foo", "Bar", "Baz"}, - "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, - "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, - "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, - "DWord": uint32(0x12345678), - "QWord": uint64(0x123456789abcdef0), - "BinaryEmpty": []byte{}, - "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, - "BinaryLong": makeLongBinaryValue(), -} - -const ( - keyNameTest = `SOFTWARE\Tailscale Test` - subKeyNameTest = "SubKey" -) - -func setValues(t *testing.T, k registry.Key) { - for vk, v := range testData { - var err error - switch tv := v.(type) { - case string: - err = k.SetStringValue(vk, tv) - case []string: - err = k.SetStringsValue(vk, tv) - case uint32: - err = k.SetDWordValue(vk, tv) - case uint64: - err = k.SetQWordValue(vk, tv) - case []byte: - err = k.SetBinaryValue(vk, tv) - default: - t.Fatalf("Unknown type") - } - - if err != nil { - t.Fatalf("Error setting %q: %v", vk, err) - } - } -} - -func TestRegistrySupportInfo(t *testing.T) { - // Make sure the key doesn't exist yet - k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) - switch { - case err == nil: - k.Close() - t.Fatalf("Test key already exists") - case !errors.Is(err, registry.ErrNotExist): - t.Fatal(err) - } - - func() { - k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test key: %v", err) - } - defer k.Close() - - setValues(t, k) - - sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test subkey: %v", err) - } - defer sk.Close() - - setValues(t, sk) - }() - - t.Cleanup(func() { - registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) - registry.DeleteKey(registry.CURRENT_USER, keyNameTest) - }) - - wantValuesData := maps.Clone(testData) - wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] - - wantKeyData := make(map[string]any) - maps.Copy(wantKeyData, wantValuesData) - wantSubKeyData := make(map[string]any) - maps.Copy(wantSubKeyData, wantValuesData) - wantKeyData[subKeyNameTest] = wantSubKeyData - - wantData := map[string]any{ - "HKCU\\" + keyNameTest: wantKeyData, - } - - gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) - if err != nil { - t.Errorf("getRegistrySupportInfo error: %v", err) - } - - want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) - if want != got { - t.Errorf("Compare error: want\n%s,\ngot %s", want, got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +import ( + "errors" + "fmt" + "maps" + "strings" + "testing" + + "golang.org/x/sys/windows/registry" +) + +func makeLongBinaryValue() []byte { + buf := make([]byte, maxBinaryValueLen*2) + for i, _ := range buf { + buf[i] = byte(i % 0xFF) + } + return buf +} + +var testData = map[string]any{ + "": "I am the default", + "StringEmpty": "", + "StringShort": "Hello", + "StringLong": strings.Repeat("7", initialValueBufLen+1), + "MultiStringEmpty": []string{}, + "MultiStringSingle": []string{"Foo"}, + "MultiStringSingleEmpty": []string{""}, + "MultiString": []string{"Foo", "Bar", "Baz"}, + "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, + "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, + "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, + "DWord": uint32(0x12345678), + "QWord": uint64(0x123456789abcdef0), + "BinaryEmpty": []byte{}, + "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, + "BinaryLong": makeLongBinaryValue(), +} + +const ( + keyNameTest = `SOFTWARE\Tailscale Test` + subKeyNameTest = "SubKey" +) + +func setValues(t *testing.T, k registry.Key) { + for vk, v := range testData { + var err error + switch tv := v.(type) { + case string: + err = k.SetStringValue(vk, tv) + case []string: + err = k.SetStringsValue(vk, tv) + case uint32: + err = k.SetDWordValue(vk, tv) + case uint64: + err = k.SetQWordValue(vk, tv) + case []byte: + err = k.SetBinaryValue(vk, tv) + default: + t.Fatalf("Unknown type") + } + + if err != nil { + t.Fatalf("Error setting %q: %v", vk, err) + } + } +} + +func TestRegistrySupportInfo(t *testing.T) { + // Make sure the key doesn't exist yet + k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) + switch { + case err == nil: + k.Close() + t.Fatalf("Test key already exists") + case !errors.Is(err, registry.ErrNotExist): + t.Fatal(err) + } + + func() { + k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test key: %v", err) + } + defer k.Close() + + setValues(t, k) + + sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test subkey: %v", err) + } + defer sk.Close() + + setValues(t, sk) + }() + + t.Cleanup(func() { + registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) + registry.DeleteKey(registry.CURRENT_USER, keyNameTest) + }) + + wantValuesData := maps.Clone(testData) + wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] + + wantKeyData := make(map[string]any) + maps.Copy(wantKeyData, wantValuesData) + wantSubKeyData := make(map[string]any) + maps.Copy(wantSubKeyData, wantValuesData) + wantKeyData[subKeyNameTest] = wantSubKeyData + + wantData := map[string]any{ + "HKCU\\" + keyNameTest: wantKeyData, + } + + gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) + if err != nil { + t.Errorf("getRegistrySupportInfo error: %v", err) + } + + want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) + if want != got { + t.Errorf("Compare error: want\n%s,\ngot %s", want, got) + } +} diff --git a/util/osshare/filesharingstatus_noop.go b/util/osshare/filesharingstatus_noop.go index 7f2b131904ea9..6be4131a991d6 100644 --- a/util/osshare/filesharingstatus_noop.go +++ b/util/osshare/filesharingstatus_noop.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package osshare - -import ( - "tailscale.com/types/logger" -) - -func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package osshare + +import ( + "tailscale.com/types/logger" +) + +func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} diff --git a/util/pidowner/pidowner.go b/util/pidowner/pidowner.go index 56bb640b785dd..62ea85d780b07 100644 --- a/util/pidowner/pidowner.go +++ b/util/pidowner/pidowner.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package pidowner handles lookups from process ID to its owning user. -package pidowner - -import ( - "errors" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -var ErrProcessNotFound = errors.New("process not found") - -// OwnerOfPID returns the user ID that owns the given process ID. -// -// The returned user ID is suitable to passing to os/user.LookupId. -// -// The returned error will be ErrNotImplemented for operating systems where -// this isn't supported. -func OwnerOfPID(pid int) (userID string, err error) { - return ownerOfPID(pid) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package pidowner handles lookups from process ID to its owning user. +package pidowner + +import ( + "errors" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +var ErrProcessNotFound = errors.New("process not found") + +// OwnerOfPID returns the user ID that owns the given process ID. +// +// The returned user ID is suitable to passing to os/user.LookupId. +// +// The returned error will be ErrNotImplemented for operating systems where +// this isn't supported. +func OwnerOfPID(pid int) (userID string, err error) { + return ownerOfPID(pid) +} diff --git a/util/pidowner/pidowner_noimpl.go b/util/pidowner/pidowner_noimpl.go index 50add492fda76..a631e3f249896 100644 --- a/util/pidowner/pidowner_noimpl.go +++ b/util/pidowner/pidowner_noimpl.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux - -package pidowner - -func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !linux + +package pidowner + +func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } diff --git a/util/pidowner/pidowner_windows.go b/util/pidowner/pidowner_windows.go index dbf13ac8135f1..c7b2512a497ed 100644 --- a/util/pidowner/pidowner_windows.go +++ b/util/pidowner/pidowner_windows.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "syscall" - - "golang.org/x/sys/windows" -) - -func ownerOfPID(pid int) (userID string, err error) { - procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) - if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist - return "", ErrProcessNotFound - } - if err != nil { - return "", fmt.Errorf("OpenProcess: %T %#v", err, err) - } - defer windows.CloseHandle(procHnd) - - var tok windows.Token - if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { - return "", fmt.Errorf("OpenProcessToken: %w", err) - } - - tokUser, err := tok.GetTokenUser() - if err != nil { - return "", fmt.Errorf("GetTokenUser: %w", err) - } - - sid := tokUser.User.Sid - return sid.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package pidowner + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/windows" +) + +func ownerOfPID(pid int) (userID string, err error) { + procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) + if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist + return "", ErrProcessNotFound + } + if err != nil { + return "", fmt.Errorf("OpenProcess: %T %#v", err, err) + } + defer windows.CloseHandle(procHnd) + + var tok windows.Token + if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { + return "", fmt.Errorf("OpenProcessToken: %w", err) + } + + tokUser, err := tok.GetTokenUser() + if err != nil { + return "", fmt.Errorf("GetTokenUser: %w", err) + } + + sid := tokUser.User.Sid + return sid.String(), nil +} diff --git a/util/precompress/precompress.go b/util/precompress/precompress.go index 6d1a26efdd767..e9bebb333e2af 100644 --- a/util/precompress/precompress.go +++ b/util/precompress/precompress.go @@ -1,129 +1,129 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package precompress provides build- and serving-time support for -// precompressed static resources, to avoid the cost of repeatedly compressing -// unchanging resources. -package precompress - -import ( - "bytes" - "compress/gzip" - "io" - "io/fs" - "net/http" - "os" - "path" - "path/filepath" - - "github.com/andybalholm/brotli" - "golang.org/x/sync/errgroup" - "tailscale.com/tsweb" -) - -// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so -// that they can be later served with OpenPrecompressedFile. -func PrecompressDir(dirPath string, options Options) error { - var eg errgroup.Group - err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - if !compressibleExtensions[filepath.Ext(p)] { - return nil - } - p = path.Join(dirPath, p) - if options.ProgressFn != nil { - options.ProgressFn(p) - } - - eg.Go(func() error { - return Precompress(p, options) - }) - return nil - }) - if err != nil { - return err - } - return eg.Wait() -} - -type Options struct { - // FastCompression controls whether compression should be optimized for - // speed rather than size. - FastCompression bool - // ProgressFn, if non-nil, is invoked when a file in the directory is about - // to be compressed. - ProgressFn func(path string) -} - -// OpenPrecompressedFile opens a file from fs, preferring compressed versions -// generated by PrecompressDir if possible. -func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { - if tsweb.AcceptsEncoding(r, "br") { - if f, err := fs.Open(path + ".br"); err == nil { - w.Header().Set("Content-Encoding", "br") - return f, nil - } - } - if tsweb.AcceptsEncoding(r, "gzip") { - if f, err := fs.Open(path + ".gz"); err == nil { - w.Header().Set("Content-Encoding", "gzip") - return f, nil - } - } - - return fs.Open(path) -} - -var compressibleExtensions = map[string]bool{ - ".js": true, - ".css": true, -} - -func Precompress(path string, options Options) error { - contents, err := os.ReadFile(path) - if err != nil { - return err - } - fi, err := os.Lstat(path) - if err != nil { - return err - } - - gzipLevel := gzip.BestCompression - if options.FastCompression { - gzipLevel = gzip.BestSpeed - } - err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return gzip.NewWriterLevel(w, gzipLevel) - }, path+".gz", fi.Mode()) - if err != nil { - return err - } - brotliLevel := brotli.BestCompression - if options.FastCompression { - brotliLevel = brotli.BestSpeed - } - return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return brotli.NewWriterLevel(w, brotliLevel), nil - }, path+".br", fi.Mode()) -} - -func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { - var buf bytes.Buffer - compressedWriter, err := compressedWriterCreator(&buf) - if err != nil { - return err - } - if _, err := compressedWriter.Write(contents); err != nil { - return err - } - if err := compressedWriter.Close(); err != nil { - return err - } - return os.WriteFile(outputPath, buf.Bytes(), outputMode) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package precompress provides build- and serving-time support for +// precompressed static resources, to avoid the cost of repeatedly compressing +// unchanging resources. +package precompress + +import ( + "bytes" + "compress/gzip" + "io" + "io/fs" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/andybalholm/brotli" + "golang.org/x/sync/errgroup" + "tailscale.com/tsweb" +) + +// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so +// that they can be later served with OpenPrecompressedFile. +func PrecompressDir(dirPath string, options Options) error { + var eg errgroup.Group + err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + if !compressibleExtensions[filepath.Ext(p)] { + return nil + } + p = path.Join(dirPath, p) + if options.ProgressFn != nil { + options.ProgressFn(p) + } + + eg.Go(func() error { + return Precompress(p, options) + }) + return nil + }) + if err != nil { + return err + } + return eg.Wait() +} + +type Options struct { + // FastCompression controls whether compression should be optimized for + // speed rather than size. + FastCompression bool + // ProgressFn, if non-nil, is invoked when a file in the directory is about + // to be compressed. + ProgressFn func(path string) +} + +// OpenPrecompressedFile opens a file from fs, preferring compressed versions +// generated by PrecompressDir if possible. +func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { + if tsweb.AcceptsEncoding(r, "br") { + if f, err := fs.Open(path + ".br"); err == nil { + w.Header().Set("Content-Encoding", "br") + return f, nil + } + } + if tsweb.AcceptsEncoding(r, "gzip") { + if f, err := fs.Open(path + ".gz"); err == nil { + w.Header().Set("Content-Encoding", "gzip") + return f, nil + } + } + + return fs.Open(path) +} + +var compressibleExtensions = map[string]bool{ + ".js": true, + ".css": true, +} + +func Precompress(path string, options Options) error { + contents, err := os.ReadFile(path) + if err != nil { + return err + } + fi, err := os.Lstat(path) + if err != nil { + return err + } + + gzipLevel := gzip.BestCompression + if options.FastCompression { + gzipLevel = gzip.BestSpeed + } + err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriterLevel(w, gzipLevel) + }, path+".gz", fi.Mode()) + if err != nil { + return err + } + brotliLevel := brotli.BestCompression + if options.FastCompression { + brotliLevel = brotli.BestSpeed + } + return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return brotli.NewWriterLevel(w, brotliLevel), nil + }, path+".br", fi.Mode()) +} + +func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { + var buf bytes.Buffer + compressedWriter, err := compressedWriterCreator(&buf) + if err != nil { + return err + } + if _, err := compressedWriter.Write(contents); err != nil { + return err + } + if err := compressedWriter.Close(); err != nil { + return err + } + return os.WriteFile(outputPath, buf.Bytes(), outputMode) +} diff --git a/util/quarantine/quarantine.go b/util/quarantine/quarantine.go index 7ad65a81d69ee..488465ba055bb 100644 --- a/util/quarantine/quarantine.go +++ b/util/quarantine/quarantine.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package quarantine sets platform specific "quarantine" attributes on files -// that are received from other hosts. -package quarantine - -import "os" - -// SetOnFile sets the platform-specific quarantine attribute (if any) on the -// provided file. -func SetOnFile(f *os.File) error { - return setQuarantineAttr(f) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package quarantine sets platform specific "quarantine" attributes on files +// that are received from other hosts. +package quarantine + +import "os" + +// SetOnFile sets the platform-specific quarantine attribute (if any) on the +// provided file. +func SetOnFile(f *os.File) error { + return setQuarantineAttr(f) +} diff --git a/util/quarantine/quarantine_darwin.go b/util/quarantine/quarantine_darwin.go index 35405d9cc7a87..b7757f3346809 100644 --- a/util/quarantine/quarantine_darwin.go +++ b/util/quarantine/quarantine_darwin.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "fmt" - "os" - "strings" - "time" - - "github.com/google/uuid" - "golang.org/x/sys/unix" -) - -func setQuarantineAttr(f *os.File) error { - sc, err := f.SyscallConn() - if err != nil { - return err - } - - now := time.Now() - - // We uppercase the UUID to match what other applications on macOS do - id := strings.ToUpper(uuid.New().String()) - - // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when - // receiving a file. - quarantineType := "0001" - - // This format is under-documented, but the following links contain a - // reasonably comprehensive overview: - // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ - // https://nixhacker.com/security-protection-in-macos-1/ - // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html - attrData := fmt.Sprintf("%s;%x;%s;%s", - quarantineType, // quarantine value - now.Unix(), // time in hex - "Tailscale", // application - id, // UUID - ) - - var innerErr error - err = sc.Control(func(fd uintptr) { - innerErr = unix.Fsetxattr( - int(fd), - "com.apple.quarantine", // attr - []byte(attrData), - 0, - ) - }) - if err != nil { - return err - } - return innerErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/sys/unix" +) + +func setQuarantineAttr(f *os.File) error { + sc, err := f.SyscallConn() + if err != nil { + return err + } + + now := time.Now() + + // We uppercase the UUID to match what other applications on macOS do + id := strings.ToUpper(uuid.New().String()) + + // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when + // receiving a file. + quarantineType := "0001" + + // This format is under-documented, but the following links contain a + // reasonably comprehensive overview: + // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ + // https://nixhacker.com/security-protection-in-macos-1/ + // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html + attrData := fmt.Sprintf("%s;%x;%s;%s", + quarantineType, // quarantine value + now.Unix(), // time in hex + "Tailscale", // application + id, // UUID + ) + + var innerErr error + err = sc.Control(func(fd uintptr) { + innerErr = unix.Fsetxattr( + int(fd), + "com.apple.quarantine", // attr + []byte(attrData), + 0, + ) + }) + if err != nil { + return err + } + return innerErr +} diff --git a/util/quarantine/quarantine_default.go b/util/quarantine/quarantine_default.go index 65954a4d25415..65a14ed26fa97 100644 --- a/util/quarantine/quarantine_default.go +++ b/util/quarantine/quarantine_default.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !darwin && !windows - -package quarantine - -import ( - "os" -) - -func setQuarantineAttr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !darwin && !windows + +package quarantine + +import ( + "os" +) + +func setQuarantineAttr(f *os.File) error { + return nil +} diff --git a/util/quarantine/quarantine_windows.go b/util/quarantine/quarantine_windows.go index 6fdf4e699b75b..3052c2c6dfab5 100644 --- a/util/quarantine/quarantine_windows.go +++ b/util/quarantine/quarantine_windows.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "os" - "strings" -) - -func setQuarantineAttr(f *os.File) error { - // Documentation on this can be found here: - // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 - // - // Additional information can be found at: - // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ - // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 - content := strings.Join([]string{ - "[ZoneTransfer]", - - // "URLZONE_INTERNET" - // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) - "ZoneId=3", - - // TODO(andrew): should/could we add ReferrerUrl or HostUrl? - }, "\r\n") - - return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "os" + "strings" +) + +func setQuarantineAttr(f *os.File) error { + // Documentation on this can be found here: + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 + // + // Additional information can be found at: + // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ + // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 + content := strings.Join([]string{ + "[ZoneTransfer]", + + // "URLZONE_INTERNET" + // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) + "ZoneId=3", + + // TODO(andrew): should/could we add ReferrerUrl or HostUrl? + }, "\r\n") + + return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) +} diff --git a/util/race/race_test.go b/util/race/race_test.go index d3838271226ac..17ea764591503 100644 --- a/util/race/race_test.go +++ b/util/race/race_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package race - -import ( - "context" - "errors" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestRaceSuccess1(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "success" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return want, nil - }, func(context.Context) (string, error) { - t.Fatal("should not be called") - return "", nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceRetry(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return "", errors.New("some error") - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceTimeout(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - // Block forever - <-ctx.Done() - return "", ctx.Err() - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceError(t *testing.T) { - tstest.ResourceCheck(t) - - err1 := errors.New("error 1") - err2 := errors.New("error 2") - - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - return "", err1 - }, func(context.Context) (string, error) { - return "", err2 - }) - - _, err := rh.Start(context.Background()) - if !errors.Is(err, err1) { - t.Errorf("wanted err to contain err1; got %v", err) - } - if !errors.Is(err, err2) { - t.Errorf("wanted err to contain err2; got %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package race + +import ( + "context" + "errors" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestRaceSuccess1(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "success" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return want, nil + }, func(context.Context) (string, error) { + t.Fatal("should not be called") + return "", nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceRetry(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return "", errors.New("some error") + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceTimeout(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + // Block forever + <-ctx.Done() + return "", ctx.Err() + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceError(t *testing.T) { + tstest.ResourceCheck(t) + + err1 := errors.New("error 1") + err2 := errors.New("error 2") + + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + return "", err1 + }, func(context.Context) (string, error) { + return "", err2 + }) + + _, err := rh.Start(context.Background()) + if !errors.Is(err, err1) { + t.Errorf("wanted err to contain err1; got %v", err) + } + if !errors.Is(err, err2) { + t.Errorf("wanted err to contain err2; got %v", err) + } +} diff --git a/util/racebuild/off.go b/util/racebuild/off.go index 8f4fe998fb4bb..a0dba0f32c052 100644 --- a/util/racebuild/off.go +++ b/util/racebuild/off.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package racebuild - -const On = false +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package racebuild + +const On = false diff --git a/util/racebuild/on.go b/util/racebuild/on.go index 69ae2bcae4239..c60bca2e6f8df 100644 --- a/util/racebuild/on.go +++ b/util/racebuild/on.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package racebuild - -const On = true +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package racebuild + +const On = true diff --git a/util/racebuild/racebuild.go b/util/racebuild/racebuild.go index d061276cb8a0a..c1a43eb96a376 100644 --- a/util/racebuild/racebuild.go +++ b/util/racebuild/racebuild.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package racebuild exports a constant about whether the current binary -// was built with the race detector. -package racebuild +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package racebuild exports a constant about whether the current binary +// was built with the race detector. +package racebuild diff --git a/util/rands/rands.go b/util/rands/rands.go index d83e1e55898dc..dcd75c5f37158 100644 --- a/util/rands/rands.go +++ b/util/rands/rands.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package rands contains utility functions for randomness. -package rands - -import ( - crand "crypto/rand" - "encoding/hex" -) - -// HexString returns a string of n cryptographically random lowercase -// hex characters. -// -// That is, HexString(3) returns something like "0fc", containing 12 -// bits of randomness. -func HexString(n int) string { - nb := n / 2 - if n%2 == 1 { - nb++ - } - b := make([]byte, nb) - crand.Read(b) - return hex.EncodeToString(b)[:n] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rands contains utility functions for randomness. +package rands + +import ( + crand "crypto/rand" + "encoding/hex" +) + +// HexString returns a string of n cryptographically random lowercase +// hex characters. +// +// That is, HexString(3) returns something like "0fc", containing 12 +// bits of randomness. +func HexString(n int) string { + nb := n / 2 + if n%2 == 1 { + nb++ + } + b := make([]byte, nb) + crand.Read(b) + return hex.EncodeToString(b)[:n] +} diff --git a/util/rands/rands_test.go b/util/rands/rands_test.go index 5813f2bb46763..ec339f94bace7 100644 --- a/util/rands/rands_test.go +++ b/util/rands/rands_test.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package rands - -import "testing" - -func TestHexString(t *testing.T) { - for i := 0; i <= 8; i++ { - s := HexString(i) - if len(s) != i { - t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rands + +import "testing" + +func TestHexString(t *testing.T) { + for i := 0; i <= 8; i++ { + s := HexString(i) + if len(s) != i { + t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) + } + } +} diff --git a/util/set/handle.go b/util/set/handle.go index 471ceeba2d523..61b4eb93d8b4d 100644 --- a/util/set/handle.go +++ b/util/set/handle.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -// HandleSet is a set of T. -// -// It is not safe for concurrent use. -type HandleSet[T any] map[Handle]T - -// Handle is an opaque comparable value that's used as the map key in a -// HandleSet. The only way to get one is to call HandleSet.Add. -type Handle struct { - v *byte -} - -// Add adds the element (map value) e to the set. -// -// It returns the handle (map key) with which e can be removed, using a map -// delete. -func (s *HandleSet[T]) Add(e T) Handle { - h := Handle{new(byte)} - if *s == nil { - *s = make(HandleSet[T]) - } - (*s)[h] = e - return h -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +// HandleSet is a set of T. +// +// It is not safe for concurrent use. +type HandleSet[T any] map[Handle]T + +// Handle is an opaque comparable value that's used as the map key in a +// HandleSet. The only way to get one is to call HandleSet.Add. +type Handle struct { + v *byte +} + +// Add adds the element (map value) e to the set. +// +// It returns the handle (map key) with which e can be removed, using a map +// delete. +func (s *HandleSet[T]) Add(e T) Handle { + h := Handle{new(byte)} + if *s == nil { + *s = make(HandleSet[T]) + } + (*s)[h] = e + return h +} diff --git a/util/set/slice_test.go b/util/set/slice_test.go index 9134c296292d3..ca57e52e8cbc3 100644 --- a/util/set/slice_test.go +++ b/util/set/slice_test.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -import ( - "testing" - - qt "github.com/frankban/quicktest" -) - -func TestSliceSet(t *testing.T) { - c := qt.New(t) - - var ss Slice[int] - c.Check(len(ss.slice), qt.Equals, 0) - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - c.Check(ss.Contains(1), qt.Equals, true) - c.Check(ss.Contains(2), qt.Equals, false) - - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(2) - ss.Add(3) - ss.Add(4) - ss.Add(5) - ss.Add(6) - ss.Add(7) - ss.Add(8) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(9) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - - ss.Remove(4) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 8) - c.Assert(ss.Contains(4), qt.IsFalse) - - // Ensure that the order of insertion is maintained - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) - ss.Add(4) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - c.Assert(ss.Contains(4), qt.IsTrue) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) - - ss.Add(1, 234, 556) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestSliceSet(t *testing.T) { + c := qt.New(t) + + var ss Slice[int] + c.Check(len(ss.slice), qt.Equals, 0) + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + c.Check(ss.Contains(1), qt.Equals, true) + c.Check(ss.Contains(2), qt.Equals, false) + + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(2) + ss.Add(3) + ss.Add(4) + ss.Add(5) + ss.Add(6) + ss.Add(7) + ss.Add(8) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(9) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + + ss.Remove(4) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 8) + c.Assert(ss.Contains(4), qt.IsFalse) + + // Ensure that the order of insertion is maintained + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) + ss.Add(4) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + c.Assert(ss.Contains(4), qt.IsTrue) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) + + ss.Add(1, 234, 556) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) +} diff --git a/util/sysresources/memory.go b/util/sysresources/memory.go index 7363155cdb2ae..8bf784e13d831 100644 --- a/util/sysresources/memory.go +++ b/util/sysresources/memory.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -// TotalMemory returns the total accessible system memory, in bytes. If the -// value cannot be determined, then 0 will be returned. -func TotalMemory() uint64 { - return totalMemoryImpl() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +// TotalMemory returns the total accessible system memory, in bytes. If the +// value cannot be determined, then 0 will be returned. +func TotalMemory() uint64 { + return totalMemoryImpl() +} diff --git a/util/sysresources/memory_bsd.go b/util/sysresources/memory_bsd.go index 26850dce652ff..39d3a18a972f1 100644 --- a/util/sysresources/memory_bsd.go +++ b/util/sysresources/memory_bsd.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd || openbsd || dragonfly || netbsd - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.physmem") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd || openbsd || dragonfly || netbsd + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.physmem") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_darwin.go b/util/sysresources/memory_darwin.go index e07bac0cd7f9b..2f74b6cecd7f3 100644 --- a/util/sysresources/memory_darwin.go +++ b/util/sysresources/memory_darwin.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.memsize") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.memsize") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_linux.go b/util/sysresources/memory_linux.go index 0239b0e80d62a..f3c51469fcc6c 100644 --- a/util/sysresources/memory_linux.go +++ b/util/sysresources/memory_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - var info unix.Sysinfo_t - - if err := unix.Sysinfo(&info); err != nil { - return 0 - } - - // uint64 casts are required since these might be uint32s - return uint64(info.Totalram) * uint64(info.Unit) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + var info unix.Sysinfo_t + + if err := unix.Sysinfo(&info); err != nil { + return 0 + } + + // uint64 casts are required since these might be uint32s + return uint64(info.Totalram) * uint64(info.Unit) +} diff --git a/util/sysresources/memory_unsupported.go b/util/sysresources/memory_unsupported.go index 0fde256e0543d..f80ef4e6ebfe8 100644 --- a/util/sysresources/memory_unsupported.go +++ b/util/sysresources/memory_unsupported.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) - -package sysresources - -func totalMemoryImpl() uint64 { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) + +package sysresources + +func totalMemoryImpl() uint64 { return 0 } diff --git a/util/sysresources/sysresources.go b/util/sysresources/sysresources.go index 32d972ab15513..1cce164a74730 100644 --- a/util/sysresources/sysresources.go +++ b/util/sysresources/sysresources.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sysresources provides OS-independent methods of determining the -// resources available to the current system. -package sysresources +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sysresources provides OS-independent methods of determining the +// resources available to the current system. +package sysresources diff --git a/util/sysresources/sysresources_test.go b/util/sysresources/sysresources_test.go index 331ad913bfba1..af96620421bae 100644 --- a/util/sysresources/sysresources_test.go +++ b/util/sysresources/sysresources_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -import ( - "runtime" - "testing" -) - -func TestTotalMemory(t *testing.T) { - switch runtime.GOOS { - case "linux": - case "freebsd", "openbsd", "dragonfly", "netbsd": - case "darwin": - default: - t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) - } - - mem := TotalMemory() - if mem == 0 { - t.Fatal("wanted TotalMemory > 0") - } - t.Logf("total memory: %v bytes", mem) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +import ( + "runtime" + "testing" +) + +func TestTotalMemory(t *testing.T) { + switch runtime.GOOS { + case "linux": + case "freebsd", "openbsd", "dragonfly", "netbsd": + case "darwin": + default: + t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) + } + + mem := TotalMemory() + if mem == 0 { + t.Fatal("wanted TotalMemory > 0") + } + t.Logf("total memory: %v bytes", mem) +} diff --git a/util/systemd/doc.go b/util/systemd/doc.go index 0c28e182354ec..296f74e9d4cd6 100644 --- a/util/systemd/doc.go +++ b/util/systemd/doc.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/* -Package systemd contains a minimal wrapper around systemd-notify to enable -applications to signal readiness and status to systemd. - -This package will only have effect on Linux systems running Tailscale in a -systemd unit with the Type=notify flag set. On other operating systems (or -when running in a Linux distro without being run from inside systemd) this -package will become a no-op. -*/ -package systemd +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/* +Package systemd contains a minimal wrapper around systemd-notify to enable +applications to signal readiness and status to systemd. + +This package will only have effect on Linux systems running Tailscale in a +systemd unit with the Type=notify flag set. On other operating systems (or +when running in a Linux distro without being run from inside systemd) this +package will become a no-op. +*/ +package systemd diff --git a/util/systemd/systemd_linux.go b/util/systemd/systemd_linux.go index 909cfcb20ac6e..34d6daff39e3b 100644 --- a/util/systemd/systemd_linux.go +++ b/util/systemd/systemd_linux.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package systemd - -import ( - "errors" - "log" - "os" - "sync" - - "github.com/mdlayher/sdnotify" -) - -var getNotifyOnce struct { - sync.Once - v *sdnotify.Notifier -} - -type logOnce struct { - sync.Once -} - -func (l *logOnce) logf(format string, args ...any) { - l.Once.Do(func() { - log.Printf(format, args...) - }) -} - -var ( - readyOnce = &logOnce{} - statusOnce = &logOnce{} -) - -func notifier() *sdnotify.Notifier { - getNotifyOnce.Do(func() { - var err error - getNotifyOnce.v, err = sdnotify.New() - // Not exist means probably not running under systemd, so don't log. - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Printf("systemd: systemd-notifier error: %v", err) - } - }) - return getNotifyOnce.v -} - -// Ready signals readiness to systemd. This will unblock service dependents from starting. -func Ready() { - err := notifier().Notify(sdnotify.Ready) - if err != nil { - readyOnce.logf("systemd: error notifying: %v", err) - } -} - -// Status sends a single line status update to systemd so that information shows up -// in systemctl output. For example: -// -// $ systemctl status tailscale -// ● tailscale.service - Tailscale client daemon -// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) -// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago -// Main PID: 26741 (.tailscaled-wra) -// Status: "Connected; user@host.domain.tld; 100.101.102.103" -// IP: 0B in, 0B out -// Tasks: 22 (limit: 4915) -// Memory: 30.9M -// CPU: 2min 38.469s -// CGroup: /system.slice/tailscale.service -// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 -func Status(format string, args ...any) { - err := notifier().Notify(sdnotify.Statusf(format, args...)) - if err != nil { - statusOnce.logf("systemd: error notifying: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package systemd + +import ( + "errors" + "log" + "os" + "sync" + + "github.com/mdlayher/sdnotify" +) + +var getNotifyOnce struct { + sync.Once + v *sdnotify.Notifier +} + +type logOnce struct { + sync.Once +} + +func (l *logOnce) logf(format string, args ...any) { + l.Once.Do(func() { + log.Printf(format, args...) + }) +} + +var ( + readyOnce = &logOnce{} + statusOnce = &logOnce{} +) + +func notifier() *sdnotify.Notifier { + getNotifyOnce.Do(func() { + var err error + getNotifyOnce.v, err = sdnotify.New() + // Not exist means probably not running under systemd, so don't log. + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Printf("systemd: systemd-notifier error: %v", err) + } + }) + return getNotifyOnce.v +} + +// Ready signals readiness to systemd. This will unblock service dependents from starting. +func Ready() { + err := notifier().Notify(sdnotify.Ready) + if err != nil { + readyOnce.logf("systemd: error notifying: %v", err) + } +} + +// Status sends a single line status update to systemd so that information shows up +// in systemctl output. For example: +// +// $ systemctl status tailscale +// ● tailscale.service - Tailscale client daemon +// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) +// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago +// Main PID: 26741 (.tailscaled-wra) +// Status: "Connected; user@host.domain.tld; 100.101.102.103" +// IP: 0B in, 0B out +// Tasks: 22 (limit: 4915) +// Memory: 30.9M +// CPU: 2min 38.469s +// CGroup: /system.slice/tailscale.service +// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 +func Status(format string, args ...any) { + err := notifier().Notify(sdnotify.Statusf(format, args...)) + if err != nil { + statusOnce.logf("systemd: error notifying: %v", err) + } +} diff --git a/util/systemd/systemd_nonlinux.go b/util/systemd/systemd_nonlinux.go index 36214020ce566..d8b20665fb7ba 100644 --- a/util/systemd/systemd_nonlinux.go +++ b/util/systemd/systemd_nonlinux.go @@ -1,9 +1,9 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package systemd - -func Ready() {} -func Status(string, ...any) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package systemd + +func Ready() {} +func Status(string, ...any) {} diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 12ada9003052b..02c688803a943 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testenv provides utility functions for tests. It does not depend on -// the `testing` package to allow usage in non-test code. -package testenv - -import ( - "flag" - - "tailscale.com/types/lazy" -) - -var lazyInTest lazy.SyncValue[bool] - -// InTest reports whether the current binary is a test binary. -func InTest() bool { - return lazyInTest.Get(func() bool { - return flag.Lookup("test.v") != nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testenv provides utility functions for tests. It does not depend on +// the `testing` package to allow usage in non-test code. +package testenv + +import ( + "flag" + + "tailscale.com/types/lazy" +) + +var lazyInTest lazy.SyncValue[bool] + +// InTest reports whether the current binary is a test binary. +func InTest() bool { + return lazyInTest.Get(func() bool { + return flag.Lookup("test.v") != nil + }) +} diff --git a/util/truncate/truncate_test.go b/util/truncate/truncate_test.go index c0d9e6e14df99..6ead55a6ae76e 100644 --- a/util/truncate/truncate_test.go +++ b/util/truncate/truncate_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package truncate_test - -import ( - "testing" - - "tailscale.com/util/truncate" -) - -func TestString(t *testing.T) { - tests := []struct { - input string - size int - want string - }{ - {"", 1000, ""}, // n > length - {"abc", 4, "abc"}, // n > length - {"abc", 3, "abc"}, // n == length - {"abcdefg", 4, "abcd"}, // n < length, safe - {"abcdefg", 0, ""}, // n < length, safe - {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary - {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune - {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte - } - - for _, tc := range tests { - got := truncate.String(tc.input, tc.size) - if got != tc.want { - t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package truncate_test + +import ( + "testing" + + "tailscale.com/util/truncate" +) + +func TestString(t *testing.T) { + tests := []struct { + input string + size int + want string + }{ + {"", 1000, ""}, // n > length + {"abc", 4, "abc"}, // n > length + {"abc", 3, "abc"}, // n == length + {"abcdefg", 4, "abcd"}, // n < length, safe + {"abcdefg", 0, ""}, // n < length, safe + {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary + {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune + {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte + } + + for _, tc := range tests { + got := truncate.String(tc.input, tc.size) + if got != tc.want { + t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) + } + } +} diff --git a/util/uniq/slice.go b/util/uniq/slice.go index 4ab933a9d82d1..fb46cc491f5d7 100644 --- a/util/uniq/slice.go +++ b/util/uniq/slice.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package uniq provides removal of adjacent duplicate elements in slices. -// It is similar to the unix command uniq. -package uniq - -// ModifySlice removes adjacent duplicate elements from the given slice. It -// adjusts the length of the slice appropriately and zeros the tail. -// -// ModifySlice does O(len(*slice)) operations. -func ModifySlice[E comparable](slice *[]E) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if (*slice)[i] == (*slice)[dst] { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} - -// ModifySliceFunc is the same as ModifySlice except that it allows using a -// custom comparison function. -// -// eq should report whether the two provided elements are equal. -func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if eq((*slice)[dst], (*slice)[i]) { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package uniq provides removal of adjacent duplicate elements in slices. +// It is similar to the unix command uniq. +package uniq + +// ModifySlice removes adjacent duplicate elements from the given slice. It +// adjusts the length of the slice appropriately and zeros the tail. +// +// ModifySlice does O(len(*slice)) operations. +func ModifySlice[E comparable](slice *[]E) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if (*slice)[i] == (*slice)[dst] { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} + +// ModifySliceFunc is the same as ModifySlice except that it allows using a +// custom comparison function. +// +// eq should report whether the two provided elements are equal. +func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if eq((*slice)[dst], (*slice)[i]) { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} diff --git a/util/winutil/authenticode/mksyscall.go b/util/winutil/authenticode/mksyscall.go index 8b7cabe6e4d7f..7c6b33973de8e 100644 --- a/util/winutil/authenticode/mksyscall.go +++ b/util/winutil/authenticode/mksyscall.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package authenticode - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 -//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 -//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext -//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash -//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext -//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext -//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose -//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam -//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature -//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package authenticode + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 +//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 +//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext +//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash +//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext +//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext +//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose +//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam +//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature +//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW diff --git a/util/winutil/policy/policy_windows.go b/util/winutil/policy/policy_windows.go index 89142951f8bd5..4674696fa101d 100644 --- a/util/winutil/policy/policy_windows.go +++ b/util/winutil/policy/policy_windows.go @@ -1,155 +1,155 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains higher-level abstractions for accessing Windows enterprise policies. -package policy - -import ( - "time" - - "tailscale.com/util/winutil" -) - -// PreferenceOptionPolicy is a policy that governs whether a boolean variable -// is forcibly assigned an administrator-defined value, or allowed to receive -// a user-defined value. -type PreferenceOptionPolicy int - -const ( - showChoiceByPolicy PreferenceOptionPolicy = iota - neverByPolicy - alwaysByPolicy -) - -// Show returns if the UI option that controls the choice administered by this -// policy should be shown. Currently this is true if and only if the policy is -// showChoiceByPolicy. -func (p PreferenceOptionPolicy) Show() bool { - return p == showChoiceByPolicy -} - -// ShouldEnable checks if the choice administered by this policy should be -// enabled. If the administrator has chosen a setting, the administrator's -// setting is returned, otherwise userChoice is returned. -func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { - switch p { - case neverByPolicy: - return false - case alwaysByPolicy: - return true - default: - return userChoice - } -} - -// GetPreferenceOptionPolicy loads a policy from the registry that can be -// managed by an enterprise policy management system and allows administrative -// overrides of users' choices in a way that we do not want tailcontrol to have -// the authority to set. It describes user-decides/always/never options, where -// "always" and "never" remove the user's ability to make a selection. If not -// present or set to a different value, "user-decides" is the default. -func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return showChoiceByPolicy - } - switch opt { - case "always": - return alwaysByPolicy - case "never": - return neverByPolicy - default: - return showChoiceByPolicy - } -} - -// VisibilityPolicy is a policy that controls whether or not a particular -// component of a user interface is to be shown. -type VisibilityPolicy byte - -const ( - visibleByPolicy VisibilityPolicy = 'v' - hiddenByPolicy VisibilityPolicy = 'h' -) - -// Show reports whether the UI option administered by this policy should be shown. -// Currently this is true if and only if the policy is visibleByPolicy. -func (p VisibilityPolicy) Show() bool { - return p == visibleByPolicy -} - -// GetVisibilityPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes show/hide decisions -// for UI elements. The registry value should be a string set to "show" (return -// true) or "hide" (return true). If not present or set to a different value, -// "show" (return false) is the default. -func GetVisibilityPolicy(name string) VisibilityPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return visibleByPolicy - } - switch opt { - case "hide": - return hiddenByPolicy - default: - return visibleByPolicy - } -} - -// GetDurationPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes a duration for some -// action. The registry value should be a string that time.ParseDuration -// understands. If the registry value is "" or can not be processed, -// defaultValue is returned instead. -func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return defaultValue - } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { - return defaultValue - } - return v -} - -// SelectControlURL returns the ControlURL to use based on a value in -// the registry (LoginURL) and the one on disk (in the GUI's -// prefs.conf). If both are empty, it returns a default value. (It -// always return a non-empty value) -// -// See https://github.com/tailscale/tailscale/issues/2798 for some background. -func SelectControlURL(reg, disk string) string { - const def = "https://controlplane.tailscale.com" - - // Prior to Dec 2020's commit 739b02e6, the installer - // wrote a LoginURL value of https://login.tailscale.com to the registry. - const oldRegDef = "https://login.tailscale.com" - - // If they have an explicit value in the registry, use it, - // unless it's an old default value from an old installer. - // Then we have to see which is better. - if reg != "" { - if reg != oldRegDef { - // Something explicit in the registry that we didn't - // set ourselves by the installer. - return reg - } - if disk == "" { - // Something in the registry is better than nothing on disk. - return reg - } - if disk != def && disk != oldRegDef { - // The value in the registry is the old - // default (login.tailscale.com) but the value - // on disk is neither our old nor new default - // value, so it must be some custom thing that - // the user cares about. Prefer the disk value. - return disk - } - } - if disk != "" { - return disk - } - return def -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains higher-level abstractions for accessing Windows enterprise policies. +package policy + +import ( + "time" + + "tailscale.com/util/winutil" +) + +// PreferenceOptionPolicy is a policy that governs whether a boolean variable +// is forcibly assigned an administrator-defined value, or allowed to receive +// a user-defined value. +type PreferenceOptionPolicy int + +const ( + showChoiceByPolicy PreferenceOptionPolicy = iota + neverByPolicy + alwaysByPolicy +) + +// Show returns if the UI option that controls the choice administered by this +// policy should be shown. Currently this is true if and only if the policy is +// showChoiceByPolicy. +func (p PreferenceOptionPolicy) Show() bool { + return p == showChoiceByPolicy +} + +// ShouldEnable checks if the choice administered by this policy should be +// enabled. If the administrator has chosen a setting, the administrator's +// setting is returned, otherwise userChoice is returned. +func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { + switch p { + case neverByPolicy: + return false + case alwaysByPolicy: + return true + default: + return userChoice + } +} + +// GetPreferenceOptionPolicy loads a policy from the registry that can be +// managed by an enterprise policy management system and allows administrative +// overrides of users' choices in a way that we do not want tailcontrol to have +// the authority to set. It describes user-decides/always/never options, where +// "always" and "never" remove the user's ability to make a selection. If not +// present or set to a different value, "user-decides" is the default. +func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return showChoiceByPolicy + } + switch opt { + case "always": + return alwaysByPolicy + case "never": + return neverByPolicy + default: + return showChoiceByPolicy + } +} + +// VisibilityPolicy is a policy that controls whether or not a particular +// component of a user interface is to be shown. +type VisibilityPolicy byte + +const ( + visibleByPolicy VisibilityPolicy = 'v' + hiddenByPolicy VisibilityPolicy = 'h' +) + +// Show reports whether the UI option administered by this policy should be shown. +// Currently this is true if and only if the policy is visibleByPolicy. +func (p VisibilityPolicy) Show() bool { + return p == visibleByPolicy +} + +// GetVisibilityPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes show/hide decisions +// for UI elements. The registry value should be a string set to "show" (return +// true) or "hide" (return true). If not present or set to a different value, +// "show" (return false) is the default. +func GetVisibilityPolicy(name string) VisibilityPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return visibleByPolicy + } + switch opt { + case "hide": + return hiddenByPolicy + default: + return visibleByPolicy + } +} + +// GetDurationPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes a duration for some +// action. The registry value should be a string that time.ParseDuration +// understands. If the registry value is "" or can not be processed, +// defaultValue is returned instead. +func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return defaultValue + } + v, err := time.ParseDuration(opt) + if err != nil || v < 0 { + return defaultValue + } + return v +} + +// SelectControlURL returns the ControlURL to use based on a value in +// the registry (LoginURL) and the one on disk (in the GUI's +// prefs.conf). If both are empty, it returns a default value. (It +// always return a non-empty value) +// +// See https://github.com/tailscale/tailscale/issues/2798 for some background. +func SelectControlURL(reg, disk string) string { + const def = "https://controlplane.tailscale.com" + + // Prior to Dec 2020's commit 739b02e6, the installer + // wrote a LoginURL value of https://login.tailscale.com to the registry. + const oldRegDef = "https://login.tailscale.com" + + // If they have an explicit value in the registry, use it, + // unless it's an old default value from an old installer. + // Then we have to see which is better. + if reg != "" { + if reg != oldRegDef { + // Something explicit in the registry that we didn't + // set ourselves by the installer. + return reg + } + if disk == "" { + // Something in the registry is better than nothing on disk. + return reg + } + if disk != def && disk != oldRegDef { + // The value in the registry is the old + // default (login.tailscale.com) but the value + // on disk is neither our old nor new default + // value, so it must be some custom thing that + // the user cares about. Prefer the disk value. + return disk + } + } + if disk != "" { + return disk + } + return def +} diff --git a/util/winutil/policy/policy_windows_test.go b/util/winutil/policy/policy_windows_test.go index cf2390c568cce..ebfd185deaaf2 100644 --- a/util/winutil/policy/policy_windows_test.go +++ b/util/winutil/policy/policy_windows_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package policy - -import "testing" - -func TestSelectControlURL(t *testing.T) { - tests := []struct { - reg, disk, want string - }{ - // Modern default case. - {"", "", "https://controlplane.tailscale.com"}, - - // For a user who installed prior to Dec 2020, with - // stuff in their registry. - {"https://login.tailscale.com", "", "https://login.tailscale.com"}, - - // Ignore pre-Dec'20 LoginURL from installer if prefs - // prefs overridden manually to an on-prem control - // server. - {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, - - // Something unknown explicitly set in the registry always wins. - {"http://explicit-reg", "", "http://explicit-reg"}, - {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, - {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, - {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, - - // If nothing in the registry, disk wins. - {"", "http://on-prem", "http://on-prem"}, - } - for _, tt := range tests { - if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { - t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package policy + +import "testing" + +func TestSelectControlURL(t *testing.T) { + tests := []struct { + reg, disk, want string + }{ + // Modern default case. + {"", "", "https://controlplane.tailscale.com"}, + + // For a user who installed prior to Dec 2020, with + // stuff in their registry. + {"https://login.tailscale.com", "", "https://login.tailscale.com"}, + + // Ignore pre-Dec'20 LoginURL from installer if prefs + // prefs overridden manually to an on-prem control + // server. + {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, + + // Something unknown explicitly set in the registry always wins. + {"http://explicit-reg", "", "http://explicit-reg"}, + {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, + {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, + {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, + + // If nothing in the registry, disk wins. + {"", "http://on-prem", "http://on-prem"}, + } + for _, tt := range tests { + if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { + t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) + } + } +} diff --git a/version/.gitignore b/version/.gitignore index 58d19bfc27c97..8878450fa4364 100644 --- a/version/.gitignore +++ b/version/.gitignore @@ -1,10 +1,10 @@ -describe.txt -long.txt -short.txt -gitcommit.txt -extragitcommit.txt -version-info.sh -version.h -version.xcconfig -ver.go -version +describe.txt +long.txt +short.txt +gitcommit.txt +extragitcommit.txt +version-info.sh +version.h +version.xcconfig +ver.go +version diff --git a/version/cmdname.go b/version/cmdname.go index 51e065438e3a5..9f85ef96d427f 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -1,139 +1,139 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios - -package version - -import ( - "bytes" - "encoding/hex" - "errors" - "io" - "os" - "path" - "path/filepath" - "strings" -) - -// CmdName returns either the base name of the current binary -// using os.Executable. If os.Executable fails (it shouldn't), then -// "cmd" is returned. -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return cmdName(e) -} - -func cmdName(exe string) string { - // fallbackName, the lowercase basename of the executable, is what we return if - // we can't find the Go module metadata embedded in the file. - fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) - - var ret string - info, err := findModuleInfo(exe) - if err != nil { - return fallbackName - } - // v is like: - // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... - for _, line := range strings.Split(info, "\n") { - if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" - ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath - break - } - } - if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { - // The tailscale-ipn.exe binary for internal build system packaging reasons - // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. - // Ignore that name and use "tailscale-ipn" instead. - return fallbackName - } - if ret == "" { - return fallbackName - } - return ret -} - -// findModuleInfo returns the Go module info from the executable file. -func findModuleInfo(file string) (s string, err error) { - f, err := os.Open(file) - if err != nil { - return "", err - } - defer f.Close() - // Scan through f until we find infoStart. - buf := make([]byte, 65536) - start, err := findOffset(f, buf, infoStart) - if err != nil { - return "", err - } - start += int64(len(infoStart)) - // Seek to the end of infoStart and scan for infoEnd. - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - end, err := findOffset(f, buf, infoEnd) - if err != nil { - return "", err - } - length := end - start - // As of Aug 2021, tailscaled's mod info was about 2k. - if length > int64(len(buf)) { - return "", errors.New("mod info too large") - } - // We have located modinfo. Read it into buf. - buf = buf[:length] - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - _, err = io.ReadFull(f, buf) - if err != nil { - return "", err - } - return string(buf), nil -} - -// findOffset finds the absolute offset of needle in f, -// starting at f's current read position, -// using temporary buffer buf. -func findOffset(f *os.File, buf, needle []byte) (int64, error) { - for { - // Fill buf and look within it. - n, err := f.Read(buf) - if err != nil { - return -1, err - } - i := bytes.Index(buf[:n], needle) - if i < 0 { - // Not found. Rewind a little bit in case we happened to end halfway through needle. - rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) - if err != nil { - return -1, err - } - // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. - _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) - if err == io.EOF { - return -1, err - } - continue - } - // Found! Figure out exactly where. - cur, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return -1, err - } - return cur - int64(n) + int64(i), nil - } -} - -// These constants are taken from rsc.io/goversion. - -var ( - infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") - infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios + +package version + +import ( + "bytes" + "encoding/hex" + "errors" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +// CmdName returns either the base name of the current binary +// using os.Executable. If os.Executable fails (it shouldn't), then +// "cmd" is returned. +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return cmdName(e) +} + +func cmdName(exe string) string { + // fallbackName, the lowercase basename of the executable, is what we return if + // we can't find the Go module metadata embedded in the file. + fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) + + var ret string + info, err := findModuleInfo(exe) + if err != nil { + return fallbackName + } + // v is like: + // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... + for _, line := range strings.Split(info, "\n") { + if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" + ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath + break + } + } + if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { + // The tailscale-ipn.exe binary for internal build system packaging reasons + // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. + // Ignore that name and use "tailscale-ipn" instead. + return fallbackName + } + if ret == "" { + return fallbackName + } + return ret +} + +// findModuleInfo returns the Go module info from the executable file. +func findModuleInfo(file string) (s string, err error) { + f, err := os.Open(file) + if err != nil { + return "", err + } + defer f.Close() + // Scan through f until we find infoStart. + buf := make([]byte, 65536) + start, err := findOffset(f, buf, infoStart) + if err != nil { + return "", err + } + start += int64(len(infoStart)) + // Seek to the end of infoStart and scan for infoEnd. + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + end, err := findOffset(f, buf, infoEnd) + if err != nil { + return "", err + } + length := end - start + // As of Aug 2021, tailscaled's mod info was about 2k. + if length > int64(len(buf)) { + return "", errors.New("mod info too large") + } + // We have located modinfo. Read it into buf. + buf = buf[:length] + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + _, err = io.ReadFull(f, buf) + if err != nil { + return "", err + } + return string(buf), nil +} + +// findOffset finds the absolute offset of needle in f, +// starting at f's current read position, +// using temporary buffer buf. +func findOffset(f *os.File, buf, needle []byte) (int64, error) { + for { + // Fill buf and look within it. + n, err := f.Read(buf) + if err != nil { + return -1, err + } + i := bytes.Index(buf[:n], needle) + if i < 0 { + // Not found. Rewind a little bit in case we happened to end halfway through needle. + rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) + if err != nil { + return -1, err + } + // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. + _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) + if err == io.EOF { + return -1, err + } + continue + } + // Found! Figure out exactly where. + cur, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return -1, err + } + return cur - int64(n) + int64(i), nil + } +} + +// These constants are taken from rsc.io/goversion. + +var ( + infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") + infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") +) diff --git a/version/cmdname_ios.go b/version/cmdname_ios.go index 6bfed38b64226..5e338944c6916 100644 --- a/version/cmdname_ios.go +++ b/version/cmdname_ios.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios - -package version - -import ( - "os" -) - -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return e -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios + +package version + +import ( + "os" +) + +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return e +} diff --git a/version/cmp_test.go b/version/cmp_test.go index e244d5e16fe22..59153f0dd15d0 100644 --- a/version/cmp_test.go +++ b/version/cmp_test.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tstest" - "tailscale.com/version" -) - -func TestParse(t *testing.T) { - parse := version.ExportParse - type parsed = version.ExportParsed - - tests := []struct { - version string - parsed parsed - want bool - }{ - {"1", parsed{Major: 1}, true}, - {"1.2", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, - {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"date.20200612", parsed{Datestamp: 20200612}, true}, - {"borkbork", parsed{}, false}, - {"1a.2.3", parsed{}, false}, - {"", parsed{}, false}, - } - - for _, test := range tests { - gotParsed, got := parse(test.version) - if got != test.want { - t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) - } - if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { - t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) - } - err := tstest.MinAllocsPerRun(t, 0, func() { - gotParsed, got = parse(test.version) - }) - if err != nil { - t.Errorf("parse(%q): %v", test.version, err) - } - } -} - -func TestAtLeast(t *testing.T) { - tests := []struct { - v, m string - want bool - }{ - {"1", "1", true}, - {"1.2", "1", true}, - {"1.2.3", "1", true}, - {"1.2.3-4", "1", true}, - {"0.98-0", "0.98", true}, - {"0.97.1-216", "0.98", false}, - {"0.94", "0.98", false}, - {"0.98", "0.98", true}, - {"0.98.0-0", "0.98", true}, - {"1.2.3-4", "1.2.4-4", false}, - {"1.2.3-4", "1.2.3-4", true}, - {"date.20200612", "date.20200612", true}, - {"date.20200701", "date.20200612", true}, - {"date.20200501", "date.20200612", false}, - } - - for _, test := range tests { - got := version.AtLeast(test.v, test.m) - if got != test.want { - t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tstest" + "tailscale.com/version" +) + +func TestParse(t *testing.T) { + parse := version.ExportParse + type parsed = version.ExportParsed + + tests := []struct { + version string + parsed parsed + want bool + }{ + {"1", parsed{Major: 1}, true}, + {"1.2", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, + {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"date.20200612", parsed{Datestamp: 20200612}, true}, + {"borkbork", parsed{}, false}, + {"1a.2.3", parsed{}, false}, + {"", parsed{}, false}, + } + + for _, test := range tests { + gotParsed, got := parse(test.version) + if got != test.want { + t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) + } + if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { + t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) + } + err := tstest.MinAllocsPerRun(t, 0, func() { + gotParsed, got = parse(test.version) + }) + if err != nil { + t.Errorf("parse(%q): %v", test.version, err) + } + } +} + +func TestAtLeast(t *testing.T) { + tests := []struct { + v, m string + want bool + }{ + {"1", "1", true}, + {"1.2", "1", true}, + {"1.2.3", "1", true}, + {"1.2.3-4", "1", true}, + {"0.98-0", "0.98", true}, + {"0.97.1-216", "0.98", false}, + {"0.94", "0.98", false}, + {"0.98", "0.98", true}, + {"0.98.0-0", "0.98", true}, + {"1.2.3-4", "1.2.4-4", false}, + {"1.2.3-4", "1.2.3-4", true}, + {"date.20200612", "date.20200612", true}, + {"date.20200701", "date.20200612", true}, + {"date.20200501", "date.20200612", false}, + } + + for _, test := range tests { + got := version.AtLeast(test.v, test.m) + if got != test.want { + t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) + } + } +} diff --git a/version/export_test.go b/version/export_test.go index 8e8ce5ecb2129..fabba13e8ba55 100644 --- a/version/export_test.go +++ b/version/export_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -var ( - ExportParse = parse - ExportFindModuleInfo = findModuleInfo - ExportCmdName = cmdName -) - -type ( - ExportParsed = parsed -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +var ( + ExportParse = parse + ExportFindModuleInfo = findModuleInfo + ExportCmdName = cmdName +) + +type ( + ExportParsed = parsed +) diff --git a/version/print.go b/version/print.go index 7d8554279f255..e3bfc38efa16c 100644 --- a/version/print.go +++ b/version/print.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -import ( - "fmt" - "runtime" - "strings" - - "tailscale.com/types/lazy" -) - -var stringLazy = lazy.SyncFunc(func() string { - var ret strings.Builder - ret.WriteString(Short()) - ret.WriteByte('\n') - if IsUnstableBuild() { - fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") - } - if gitCommit() != "" { - fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) - } - if extraGitCommitStamp != "" { - fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) - } - fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) - return strings.TrimSpace(ret.String()) -}) - -func String() string { - return stringLazy() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import ( + "fmt" + "runtime" + "strings" + + "tailscale.com/types/lazy" +) + +var stringLazy = lazy.SyncFunc(func() string { + var ret strings.Builder + ret.WriteString(Short()) + ret.WriteByte('\n') + if IsUnstableBuild() { + fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") + } + if gitCommit() != "" { + fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) + } + if extraGitCommitStamp != "" { + fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) + } + fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + return strings.TrimSpace(ret.String()) +}) + +func String() string { + return stringLazy() +} diff --git a/version/race.go b/version/race.go index e1dc76591ebf4..bc3ca8db6b6dd 100644 --- a/version/race.go +++ b/version/race.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return true } diff --git a/version/race_off.go b/version/race_off.go index 6db901974bb77..d55288d9cc962 100644 --- a/version/race_off.go +++ b/version/race_off.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return false } diff --git a/version/version_test.go b/version/version_test.go index a515650586cc4..4d676f9f5ea1f 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "bytes" - "os" - "testing" - - ts "tailscale.com" - "tailscale.com/version" -) - -func TestAlpineTag(t *testing.T) { - if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } - if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } -} - -func readAlpineTag(t *testing.T, file string) string { - f, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - for _, line := range bytes.Split(f, []byte{'\n'}) { - line = bytes.TrimSpace(line) - _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) - if !ok { - continue - } - return string(suf) - } - return "" -} - -func TestShortAllocs(t *testing.T) { - allocs := int(testing.AllocsPerRun(10000, func() { - _ = version.Short() - })) - if allocs > 0 { - t.Errorf("allocs = %v; want 0", allocs) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "bytes" + "os" + "testing" + + ts "tailscale.com" + "tailscale.com/version" +) + +func TestAlpineTag(t *testing.T) { + if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } + if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } +} + +func readAlpineTag(t *testing.T, file string) string { + f, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + for _, line := range bytes.Split(f, []byte{'\n'}) { + line = bytes.TrimSpace(line) + _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) + if !ok { + continue + } + return string(suf) + } + return "" +} + +func TestShortAllocs(t *testing.T) { + allocs := int(testing.AllocsPerRun(10000, func() { + _ = version.Short() + })) + if allocs > 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} diff --git a/wgengine/bench/bench.go b/wgengine/bench/bench.go index 8695f18d15899..b94930ee50c11 100644 --- a/wgengine/bench/bench.go +++ b/wgengine/bench/bench.go @@ -1,409 +1,409 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "bufio" - "io" - "log" - "net" - "net/http" - "net/http/pprof" - "net/netip" - "os" - "strconv" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const PayloadSize = 1000 -const ICMPMinSize = 24 - -var Addr1 = netip.MustParsePrefix("100.64.1.1/32") -var Addr2 = netip.MustParsePrefix("100.64.1.2/32") - -func main() { - var logf logger.Logf = log.Printf - log.SetFlags(0) - - debugMux := newDebugMux() - go runDebugServer(debugMux, "0.0.0.0:8999") - - mode, err := strconv.Atoi(os.Args[1]) - if err != nil { - log.Fatalf("%q: %v", os.Args[1], err) - } - - traf := NewTrafficGen(nil) - - // Sample test results below are using GOMAXPROCS=2 (for some - // tests, including wireguard-go, higher GOMAXPROCS goes slower) - // on apenwarr's old Linux box: - // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz - // My 2019 Mac Mini is about 20% faster on most tests. - - switch mode { - // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) - case 1: - setupTrivialNoAllocTest(logf, traf) - - // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) - case 2: - setupTrivialTest(logf, traf) - - // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) - case 11: - setupBlockingChannelTest(logf, traf) - - // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) - // (much faster on macOS??) - case 12: - setupNonblockingChannelTest(logf, traf) - - // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) - // (much faster on macOS??) - case 13: - setupDoubleChannelTest(logf, traf) - - // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) - case 21: - setupUDPTest(logf, traf) - - // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) - case 31: - setupBatchTCPTest(logf, traf) - - // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) - case 101: - setupWGTest(nil, logf, traf, Addr1, Addr2) - - default: - log.Fatalf("provide a valid test number (0..n)") - } - - logf("initialized ok.") - traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) - - var cur, prev Snapshot - var pps int64 - i := 0 - for { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec == 0 { - logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) - } else { - logf("%v @%7d pkt/s", d, pps) - } - } - - pps = traf.Adjust() - } -} - -func newDebugMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - return mux -} - -func runDebugServer(mux *http.ServeMux, addr string) { - srv := &http.Server{ - Addr: addr, - Handler: mux, - } - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } -} - -// The absolute minimal test of the traffic generator: have it fill -// a packet buffer, then absorb it again. Zero packet loss. -func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { - go func() { - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Almost the same, but this time allocate a fresh buffer each time -// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. -func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { - go func() { - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Pass packets through a blocking channel between sender and receiver. -// Still zero packet loss since the sender stops when the channel is full. -// Max speed depends on channel length (I'm not sure why). -func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - ch <- b[0 : n+16] - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as setupBlockingChannelTest, but now we drop packets whenever the -// channel is full. Max speed is about the same as the above test, but -// now with nonzero packet loss. -func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as above, but at an intermediate blocking channel and goroutine -// to make things a little more like wireguard-go. Roughly 20% slower than -// the single-channel version. -func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - ch2 := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // intermediary - for b := range ch { - ch2 <- b - } - close(ch2) - }() - - go func() { - // receiver - for b := range ch2 { - traf.GotPacket(b, 16) - } - }() -} - -// Instead of a channel, pass packets through a UDP socket. -func setupUDPTest(logf logger.Logf, traf *TrafficGen) { - la, err := net.ResolveUDPAddr("udp", ":0") - if err != nil { - log.Fatalf("resolve: %v", err) - } - - s1, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen1: %v", err) - } - s2, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen2: %v", err) - } - - a2 := s2.LocalAddr() - - // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, - // which is what returns from .LocalAddr() above. We have to - // force it to localhost instead. - a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") - - s1.SetWriteBuffer(1024 * 1024) - s2.SetReadBuffer(1024 * 1024) - - go func() { - // transmitter - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - s1.WriteTo(b[16:n+16], a2) - } - }() - - go func() { - // receiver - b := make([]byte, 1600) - for traf.Running() { - // Use ReadFrom instead of Read, to be more like - // how wireguard-go does it, even though we're not - // going to actually look at the address. - n, _, err := s2.ReadFrom(b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} - -// Instead of a channel, pass packets through a TCP socket. -// TCP is a single stream, so we can amortize one syscall across -// multiple packets. 10x amortization seems to make it go ~10x faster, -// as expected, getting us close to the speed of the channel tests above. -// There's also zero packet loss. -func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { - sl, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("listen: %v", err) - } - - var slCloseOnce sync.Once - slClose := func() { - slCloseOnce.Do(func() { - sl.Close() - }) - } - - s1, err := net.Dial("tcp", sl.Addr().String()) - if err != nil { - log.Fatalf("dial: %v", err) - } - - s2, err := sl.Accept() - if err != nil { - log.Fatalf("accept: %v", err) - } - - s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) - s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) - - ch := make(chan int) - - go func() { - // transmitter - defer slClose() - defer s1.Close() - - bs1 := bufio.NewWriterSize(s1, 1024*1024) - - b := make([]byte, 1600) - i := 0 - for { - i += 1 - n := traf.Generate(b, 16) - if n == 0 { - break - } - if i == 1 { - ch <- n - } - bs1.Write(b[16 : n+16]) - - // TODO: this is a pretty half-baked batching - // function, which we'd never want to employ in - // a real-life program. - // - // In real life, we'd probably want to flush - // immediately when there are no more packets to - // generate, and queue up only if we fall behind. - // - // In our case however, we just want to see the - // technical benefits of batching 10 syscalls - // into 1, so a fixed ratio makes more sense. - if (i % 10) == 0 { - bs1.Flush() - } - } - }() - - go func() { - // receiver - defer slClose() - defer s2.Close() - - bs2 := bufio.NewReaderSize(s2, 1024*1024) - - // Find out the packet size (we happen to know they're - // all the same size) - packetSize := <-ch - - b := make([]byte, packetSize) - for traf.Running() { - // TODO: can't use ReadFrom() here, which is - // unfair compared to UDP. (ReadFrom for UDP - // apparently allocates memory per packet, which - // this test does not.) - n, err := io.ReadFull(bs2, b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "bufio" + "io" + "log" + "net" + "net/http" + "net/http/pprof" + "net/netip" + "os" + "strconv" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const PayloadSize = 1000 +const ICMPMinSize = 24 + +var Addr1 = netip.MustParsePrefix("100.64.1.1/32") +var Addr2 = netip.MustParsePrefix("100.64.1.2/32") + +func main() { + var logf logger.Logf = log.Printf + log.SetFlags(0) + + debugMux := newDebugMux() + go runDebugServer(debugMux, "0.0.0.0:8999") + + mode, err := strconv.Atoi(os.Args[1]) + if err != nil { + log.Fatalf("%q: %v", os.Args[1], err) + } + + traf := NewTrafficGen(nil) + + // Sample test results below are using GOMAXPROCS=2 (for some + // tests, including wireguard-go, higher GOMAXPROCS goes slower) + // on apenwarr's old Linux box: + // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz + // My 2019 Mac Mini is about 20% faster on most tests. + + switch mode { + // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) + case 1: + setupTrivialNoAllocTest(logf, traf) + + // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) + case 2: + setupTrivialTest(logf, traf) + + // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) + case 11: + setupBlockingChannelTest(logf, traf) + + // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) + // (much faster on macOS??) + case 12: + setupNonblockingChannelTest(logf, traf) + + // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) + // (much faster on macOS??) + case 13: + setupDoubleChannelTest(logf, traf) + + // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) + case 21: + setupUDPTest(logf, traf) + + // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) + case 31: + setupBatchTCPTest(logf, traf) + + // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) + case 101: + setupWGTest(nil, logf, traf, Addr1, Addr2) + + default: + log.Fatalf("provide a valid test number (0..n)") + } + + logf("initialized ok.") + traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) + + var cur, prev Snapshot + var pps int64 + i := 0 + for { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec == 0 { + logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) + } else { + logf("%v @%7d pkt/s", d, pps) + } + } + + pps = traf.Adjust() + } +} + +func newDebugMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func runDebugServer(mux *http.ServeMux, addr string) { + srv := &http.Server{ + Addr: addr, + Handler: mux, + } + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } +} + +// The absolute minimal test of the traffic generator: have it fill +// a packet buffer, then absorb it again. Zero packet loss. +func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { + go func() { + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Almost the same, but this time allocate a fresh buffer each time +// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. +func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { + go func() { + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Pass packets through a blocking channel between sender and receiver. +// Still zero packet loss since the sender stops when the channel is full. +// Max speed depends on channel length (I'm not sure why). +func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + ch <- b[0 : n+16] + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as setupBlockingChannelTest, but now we drop packets whenever the +// channel is full. Max speed is about the same as the above test, but +// now with nonzero packet loss. +func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as above, but at an intermediate blocking channel and goroutine +// to make things a little more like wireguard-go. Roughly 20% slower than +// the single-channel version. +func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + ch2 := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // intermediary + for b := range ch { + ch2 <- b + } + close(ch2) + }() + + go func() { + // receiver + for b := range ch2 { + traf.GotPacket(b, 16) + } + }() +} + +// Instead of a channel, pass packets through a UDP socket. +func setupUDPTest(logf logger.Logf, traf *TrafficGen) { + la, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + log.Fatalf("resolve: %v", err) + } + + s1, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen1: %v", err) + } + s2, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen2: %v", err) + } + + a2 := s2.LocalAddr() + + // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, + // which is what returns from .LocalAddr() above. We have to + // force it to localhost instead. + a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") + + s1.SetWriteBuffer(1024 * 1024) + s2.SetReadBuffer(1024 * 1024) + + go func() { + // transmitter + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + s1.WriteTo(b[16:n+16], a2) + } + }() + + go func() { + // receiver + b := make([]byte, 1600) + for traf.Running() { + // Use ReadFrom instead of Read, to be more like + // how wireguard-go does it, even though we're not + // going to actually look at the address. + n, _, err := s2.ReadFrom(b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} + +// Instead of a channel, pass packets through a TCP socket. +// TCP is a single stream, so we can amortize one syscall across +// multiple packets. 10x amortization seems to make it go ~10x faster, +// as expected, getting us close to the speed of the channel tests above. +// There's also zero packet loss. +func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { + sl, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("listen: %v", err) + } + + var slCloseOnce sync.Once + slClose := func() { + slCloseOnce.Do(func() { + sl.Close() + }) + } + + s1, err := net.Dial("tcp", sl.Addr().String()) + if err != nil { + log.Fatalf("dial: %v", err) + } + + s2, err := sl.Accept() + if err != nil { + log.Fatalf("accept: %v", err) + } + + s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) + s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) + + ch := make(chan int) + + go func() { + // transmitter + defer slClose() + defer s1.Close() + + bs1 := bufio.NewWriterSize(s1, 1024*1024) + + b := make([]byte, 1600) + i := 0 + for { + i += 1 + n := traf.Generate(b, 16) + if n == 0 { + break + } + if i == 1 { + ch <- n + } + bs1.Write(b[16 : n+16]) + + // TODO: this is a pretty half-baked batching + // function, which we'd never want to employ in + // a real-life program. + // + // In real life, we'd probably want to flush + // immediately when there are no more packets to + // generate, and queue up only if we fall behind. + // + // In our case however, we just want to see the + // technical benefits of batching 10 syscalls + // into 1, so a fixed ratio makes more sense. + if (i % 10) == 0 { + bs1.Flush() + } + } + }() + + go func() { + // receiver + defer slClose() + defer s2.Close() + + bs2 := bufio.NewReaderSize(s2, 1024*1024) + + // Find out the packet size (we happen to know they're + // all the same size) + packetSize := <-ch + + b := make([]byte, packetSize) + for traf.Running() { + // TODO: can't use ReadFrom() here, which is + // unfair compared to UDP. (ReadFrom for UDP + // apparently allocates memory per packet, which + // this test does not.) + n, err := io.ReadFull(bs2, b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} diff --git a/wgengine/bench/bench_test.go b/wgengine/bench/bench_test.go index 4fae86c0580ba..42571d0557115 100644 --- a/wgengine/bench/bench_test.go +++ b/wgengine/bench/bench_test.go @@ -1,108 +1,108 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "fmt" - "testing" - "time" - - "tailscale.com/types/logger" -) - -func BenchmarkTrivialNoAlloc(b *testing.B) { - run(b, setupTrivialNoAllocTest) -} -func BenchmarkTrivial(b *testing.B) { - run(b, setupTrivialTest) -} - -func BenchmarkBlockingChannel(b *testing.B) { - run(b, setupBlockingChannelTest) -} - -func BenchmarkNonblockingChannel(b *testing.B) { - run(b, setupNonblockingChannelTest) -} - -func BenchmarkDoubleChannel(b *testing.B) { - run(b, setupDoubleChannelTest) -} - -func BenchmarkUDP(b *testing.B) { - run(b, setupUDPTest) -} - -func BenchmarkBatchTCP(b *testing.B) { - run(b, setupBatchTCPTest) -} - -func BenchmarkWireGuardTest(b *testing.B) { - b.Skip("https://github.com/tailscale/tailscale/issues/2716") - run(b, func(logf logger.Logf, traf *TrafficGen) { - setupWGTest(b, logf, traf, Addr1, Addr2) - }) -} - -type SetupFunc func(logger.Logf, *TrafficGen) - -func run(b *testing.B, setup SetupFunc) { - sizes := []int{ - ICMPMinSize + 8, - ICMPMinSize + 100, - ICMPMinSize + 1000, - } - - for _, size := range sizes { - b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { - runOnce(b, setup, size) - }) - } -} - -func runOnce(b *testing.B, setup SetupFunc, payload int) { - b.StopTimer() - b.ReportAllocs() - - var logf logger.Logf = b.Logf - if !testing.Verbose() { - logf = logger.Discard - } - - traf := NewTrafficGen(b.StartTimer) - setup(logf, traf) - - logf("initialized. (n=%v)", b.N) - b.SetBytes(int64(payload)) - - traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) - - var cur, prev Snapshot - var pps int64 - i := 0 - for traf.Running() { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec != 0 { - logf("%v @%7d pkt/sec", d, pps) - } - } - - pps = traf.Adjust() - } - - cur = traf.Snap() - d := cur.Sub(prev) - loss := float64(d.LostPackets) / float64(d.RxPackets) - - b.ReportMetric(loss*100, "%lost") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "fmt" + "testing" + "time" + + "tailscale.com/types/logger" +) + +func BenchmarkTrivialNoAlloc(b *testing.B) { + run(b, setupTrivialNoAllocTest) +} +func BenchmarkTrivial(b *testing.B) { + run(b, setupTrivialTest) +} + +func BenchmarkBlockingChannel(b *testing.B) { + run(b, setupBlockingChannelTest) +} + +func BenchmarkNonblockingChannel(b *testing.B) { + run(b, setupNonblockingChannelTest) +} + +func BenchmarkDoubleChannel(b *testing.B) { + run(b, setupDoubleChannelTest) +} + +func BenchmarkUDP(b *testing.B) { + run(b, setupUDPTest) +} + +func BenchmarkBatchTCP(b *testing.B) { + run(b, setupBatchTCPTest) +} + +func BenchmarkWireGuardTest(b *testing.B) { + b.Skip("https://github.com/tailscale/tailscale/issues/2716") + run(b, func(logf logger.Logf, traf *TrafficGen) { + setupWGTest(b, logf, traf, Addr1, Addr2) + }) +} + +type SetupFunc func(logger.Logf, *TrafficGen) + +func run(b *testing.B, setup SetupFunc) { + sizes := []int{ + ICMPMinSize + 8, + ICMPMinSize + 100, + ICMPMinSize + 1000, + } + + for _, size := range sizes { + b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { + runOnce(b, setup, size) + }) + } +} + +func runOnce(b *testing.B, setup SetupFunc, payload int) { + b.StopTimer() + b.ReportAllocs() + + var logf logger.Logf = b.Logf + if !testing.Verbose() { + logf = logger.Discard + } + + traf := NewTrafficGen(b.StartTimer) + setup(logf, traf) + + logf("initialized. (n=%v)", b.N) + b.SetBytes(int64(payload)) + + traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) + + var cur, prev Snapshot + var pps int64 + i := 0 + for traf.Running() { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec != 0 { + logf("%v @%7d pkt/sec", d, pps) + } + } + + pps = traf.Adjust() + } + + cur = traf.Snap() + d := cur.Sub(prev) + loss := float64(d.LostPackets) / float64(d.RxPackets) + + b.ReportMetric(loss*100, "%lost") +} diff --git a/wgengine/bench/trafficgen.go b/wgengine/bench/trafficgen.go index ce79c616f86ed..9de3c2e6bbc4b 100644 --- a/wgengine/bench/trafficgen.go +++ b/wgengine/bench/trafficgen.go @@ -1,259 +1,259 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/binary" - "fmt" - "log" - "net/netip" - "sync" - "time" - - "tailscale.com/net/packet" - "tailscale.com/types/ipproto" -) - -type Snapshot struct { - WhenNsec int64 // current time - timeAcc int64 // accumulated time (+NSecPerTx per transmit) - - LastSeqTx int64 // last sequence number sent - LastSeqRx int64 // last sequence number received - TotalLost int64 // packets out-of-order or lost so far - TotalOOO int64 // packets out-of-order so far - TotalBytesRx int64 // total bytes received so far -} - -type Delta struct { - DurationNsec int64 - TxPackets int64 - RxPackets int64 - LostPackets int64 - OOOPackets int64 - Bytes int64 -} - -func (b Snapshot) Sub(a Snapshot) Delta { - return Delta{ - DurationNsec: b.WhenNsec - a.WhenNsec, - TxPackets: b.LastSeqTx - a.LastSeqTx, - RxPackets: (b.LastSeqRx - a.LastSeqRx) - - (b.TotalLost - a.TotalLost) + - (b.TotalOOO - a.TotalOOO), - LostPackets: b.TotalLost - a.TotalLost, - OOOPackets: b.TotalOOO - a.TotalOOO, - Bytes: b.TotalBytesRx - a.TotalBytesRx, - } -} - -func (d Delta) String() string { - return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", - d.TxPackets, d.RxPackets, d.LostPackets, - float64(d.LostPackets)*100/float64(d.TxPackets), - d.OOOPackets, - float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) -} - -type TrafficGen struct { - mu sync.Mutex - cur, prev Snapshot // snapshots used for rate control - buf []byte // pre-generated packet buffer - done bool // true if the test has completed - - onFirstPacket func() // function to call on first received packet - - // maxPackets is the max packets to receive (not send) before - // ending the test. If it's zero, the test runs forever. - maxPackets int64 - - // nsPerPacket is the target average nanoseconds between packets. - // It's initially zero, which means transmit as fast as the - // caller wants to go. - nsPerPacket int64 - - // ppsHistory is the observed packets-per-second from recent - // samples. - ppsHistory [5]int64 -} - -// NewTrafficGen creates a new, initially locked, TrafficGen. -// Until Start() is called, Generate() will block forever. -func NewTrafficGen(onFirstPacket func()) *TrafficGen { - t := TrafficGen{ - onFirstPacket: onFirstPacket, - } - - // initially locked, until first Start() - t.mu.Lock() - - return &t -} - -// Start starts the traffic generator. It assumes mu is already locked, -// and unlocks it. -func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { - h12 := packet.ICMP4Header{ - IP4Header: packet.IP4Header{ - IPProto: ipproto.ICMPv4, - IPID: 0, - Src: src, - Dst: dst, - }, - Type: packet.ICMP4EchoRequest, - Code: packet.ICMP4NoCode, - } - - // ensure there's room for ICMP header plus sequence number - if bytesPerPacket < ICMPMinSize+8 { - log.Fatalf("bytesPerPacket must be > 24+8") - } - - t.maxPackets = maxPackets - - payload := make([]byte, bytesPerPacket-ICMPMinSize) - t.buf = packet.Generate(h12, payload) - - t.mu.Unlock() -} - -func (t *TrafficGen) Snap() Snapshot { - t.mu.Lock() - defer t.mu.Unlock() - - t.cur.WhenNsec = time.Now().UnixNano() - return t.cur -} - -func (t *TrafficGen) Running() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return !t.done -} - -// Generate produces the next packet in the sequence. It sleeps if -// it's too soon for the next packet to be sent. -// -// The generated packet is placed into buf at offset ofs, for compatibility -// with the wireguard-go conventions. -// -// The return value is the number of bytes generated in the packet, or 0 -// if the test has finished running. -func (t *TrafficGen) Generate(b []byte, ofs int) int { - t.mu.Lock() - - now := time.Now().UnixNano() - if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { - t.cur.timeAcc = now - 1 - } - if t.cur.timeAcc >= now { - // too soon - t.mu.Unlock() - time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) - t.mu.Lock() - - now = t.cur.timeAcc - } - if t.done { - t.mu.Unlock() - return 0 - } - - t.cur.timeAcc += t.nsPerPacket - t.cur.LastSeqTx += 1 - t.cur.WhenNsec = now - seq := t.cur.LastSeqTx - - t.mu.Unlock() - - copy(b[ofs:], t.buf) - binary.BigEndian.PutUint64( - b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], - uint64(seq)) - - return len(t.buf) -} - -// GotPacket processes a packet that came back on the receive side. -func (t *TrafficGen) GotPacket(b []byte, ofs int) { - t.mu.Lock() - defer t.mu.Unlock() - - s := &t.cur - seq := int64(binary.BigEndian.Uint64( - b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) - if seq > s.LastSeqRx { - if s.LastSeqRx > 0 { - // only count lost packets after the very first - // successful one. - s.TotalLost += seq - s.LastSeqRx - 1 - } - s.LastSeqRx = seq - } else { - s.TotalOOO += 1 - } - - // +1 packet since we only start counting after the first one - if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { - t.done = true - } - s.TotalBytesRx += int64(len(b) - ofs) - - f := t.onFirstPacket - t.onFirstPacket = nil - if f != nil { - f() - } -} - -// Adjust tunes the transmit rate based on the received packets. -// The goal is to converge on the fastest transmit rate that still has -// minimal packet loss. Returns the new target rate in packets/sec. -// -// We need to play this guessing game in order to balance out tx and rx -// rates when there's a lossy network between them. Otherwise we can end -// up using 99% of the CPU to blast out transmitted packets and leaving only -// 1% to receive them, leading to a misleading throughput calculation. -// -// Call this function multiple times per second. -func (t *TrafficGen) Adjust() (pps int64) { - t.mu.Lock() - defer t.mu.Unlock() - - d := t.cur.Sub(t.prev) - - // don't adjust rate until the first full period *after* receiving - // the first packet. This skips any handshake time in the underlying - // transport. - if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { - t.prev = t.cur - return 0 // no estimate yet, continue at max speed - } - - pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) - - // We use a rate selection algorithm based loosely on TCP BBR. - // Basically, we set the transmit rate to be a bit higher than - // the best observed transmit rate in the last several time - // periods. This guarantees some packet loss, but should converge - // quickly on a rate near the sustainable maximum. - bestPPS := pps - for _, p := range t.ppsHistory { - if p > bestPPS { - bestPPS = p - } - } - if pps > 0 && t.prev.WhenNsec > 0 { - copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) - t.ppsHistory[0] = pps - } - if bestPPS > 0 { - pps = bestPPS * 103 / 100 - t.nsPerPacket = int64(1e9 / pps) - } - t.prev = t.cur - - return pps -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/binary" + "fmt" + "log" + "net/netip" + "sync" + "time" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +type Snapshot struct { + WhenNsec int64 // current time + timeAcc int64 // accumulated time (+NSecPerTx per transmit) + + LastSeqTx int64 // last sequence number sent + LastSeqRx int64 // last sequence number received + TotalLost int64 // packets out-of-order or lost so far + TotalOOO int64 // packets out-of-order so far + TotalBytesRx int64 // total bytes received so far +} + +type Delta struct { + DurationNsec int64 + TxPackets int64 + RxPackets int64 + LostPackets int64 + OOOPackets int64 + Bytes int64 +} + +func (b Snapshot) Sub(a Snapshot) Delta { + return Delta{ + DurationNsec: b.WhenNsec - a.WhenNsec, + TxPackets: b.LastSeqTx - a.LastSeqTx, + RxPackets: (b.LastSeqRx - a.LastSeqRx) - + (b.TotalLost - a.TotalLost) + + (b.TotalOOO - a.TotalOOO), + LostPackets: b.TotalLost - a.TotalLost, + OOOPackets: b.TotalOOO - a.TotalOOO, + Bytes: b.TotalBytesRx - a.TotalBytesRx, + } +} + +func (d Delta) String() string { + return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", + d.TxPackets, d.RxPackets, d.LostPackets, + float64(d.LostPackets)*100/float64(d.TxPackets), + d.OOOPackets, + float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) +} + +type TrafficGen struct { + mu sync.Mutex + cur, prev Snapshot // snapshots used for rate control + buf []byte // pre-generated packet buffer + done bool // true if the test has completed + + onFirstPacket func() // function to call on first received packet + + // maxPackets is the max packets to receive (not send) before + // ending the test. If it's zero, the test runs forever. + maxPackets int64 + + // nsPerPacket is the target average nanoseconds between packets. + // It's initially zero, which means transmit as fast as the + // caller wants to go. + nsPerPacket int64 + + // ppsHistory is the observed packets-per-second from recent + // samples. + ppsHistory [5]int64 +} + +// NewTrafficGen creates a new, initially locked, TrafficGen. +// Until Start() is called, Generate() will block forever. +func NewTrafficGen(onFirstPacket func()) *TrafficGen { + t := TrafficGen{ + onFirstPacket: onFirstPacket, + } + + // initially locked, until first Start() + t.mu.Lock() + + return &t +} + +// Start starts the traffic generator. It assumes mu is already locked, +// and unlocks it. +func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { + h12 := packet.ICMP4Header{ + IP4Header: packet.IP4Header{ + IPProto: ipproto.ICMPv4, + IPID: 0, + Src: src, + Dst: dst, + }, + Type: packet.ICMP4EchoRequest, + Code: packet.ICMP4NoCode, + } + + // ensure there's room for ICMP header plus sequence number + if bytesPerPacket < ICMPMinSize+8 { + log.Fatalf("bytesPerPacket must be > 24+8") + } + + t.maxPackets = maxPackets + + payload := make([]byte, bytesPerPacket-ICMPMinSize) + t.buf = packet.Generate(h12, payload) + + t.mu.Unlock() +} + +func (t *TrafficGen) Snap() Snapshot { + t.mu.Lock() + defer t.mu.Unlock() + + t.cur.WhenNsec = time.Now().UnixNano() + return t.cur +} + +func (t *TrafficGen) Running() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return !t.done +} + +// Generate produces the next packet in the sequence. It sleeps if +// it's too soon for the next packet to be sent. +// +// The generated packet is placed into buf at offset ofs, for compatibility +// with the wireguard-go conventions. +// +// The return value is the number of bytes generated in the packet, or 0 +// if the test has finished running. +func (t *TrafficGen) Generate(b []byte, ofs int) int { + t.mu.Lock() + + now := time.Now().UnixNano() + if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { + t.cur.timeAcc = now - 1 + } + if t.cur.timeAcc >= now { + // too soon + t.mu.Unlock() + time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) + t.mu.Lock() + + now = t.cur.timeAcc + } + if t.done { + t.mu.Unlock() + return 0 + } + + t.cur.timeAcc += t.nsPerPacket + t.cur.LastSeqTx += 1 + t.cur.WhenNsec = now + seq := t.cur.LastSeqTx + + t.mu.Unlock() + + copy(b[ofs:], t.buf) + binary.BigEndian.PutUint64( + b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], + uint64(seq)) + + return len(t.buf) +} + +// GotPacket processes a packet that came back on the receive side. +func (t *TrafficGen) GotPacket(b []byte, ofs int) { + t.mu.Lock() + defer t.mu.Unlock() + + s := &t.cur + seq := int64(binary.BigEndian.Uint64( + b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) + if seq > s.LastSeqRx { + if s.LastSeqRx > 0 { + // only count lost packets after the very first + // successful one. + s.TotalLost += seq - s.LastSeqRx - 1 + } + s.LastSeqRx = seq + } else { + s.TotalOOO += 1 + } + + // +1 packet since we only start counting after the first one + if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { + t.done = true + } + s.TotalBytesRx += int64(len(b) - ofs) + + f := t.onFirstPacket + t.onFirstPacket = nil + if f != nil { + f() + } +} + +// Adjust tunes the transmit rate based on the received packets. +// The goal is to converge on the fastest transmit rate that still has +// minimal packet loss. Returns the new target rate in packets/sec. +// +// We need to play this guessing game in order to balance out tx and rx +// rates when there's a lossy network between them. Otherwise we can end +// up using 99% of the CPU to blast out transmitted packets and leaving only +// 1% to receive them, leading to a misleading throughput calculation. +// +// Call this function multiple times per second. +func (t *TrafficGen) Adjust() (pps int64) { + t.mu.Lock() + defer t.mu.Unlock() + + d := t.cur.Sub(t.prev) + + // don't adjust rate until the first full period *after* receiving + // the first packet. This skips any handshake time in the underlying + // transport. + if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { + t.prev = t.cur + return 0 // no estimate yet, continue at max speed + } + + pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) + + // We use a rate selection algorithm based loosely on TCP BBR. + // Basically, we set the transmit rate to be a bit higher than + // the best observed transmit rate in the last several time + // periods. This guarantees some packet loss, but should converge + // quickly on a rate near the sustainable maximum. + bestPPS := pps + for _, p := range t.ppsHistory { + if p > bestPPS { + bestPPS = p + } + } + if pps > 0 && t.prev.WhenNsec > 0 { + copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) + t.ppsHistory[0] = pps + } + if bestPPS > 0 { + pps = bestPPS * 103 / 100 + t.nsPerPacket = int64(1e9 / pps) + } + t.prev = t.cur + + return pps +} diff --git a/wgengine/capture/capture.go b/wgengine/capture/capture.go index 6ea5a9549b4f1..01f79ea9f5485 100644 --- a/wgengine/capture/capture.go +++ b/wgengine/capture/capture.go @@ -1,238 +1,238 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package capture formats packet logging into a debug pcap stream. -package capture - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net/http" - "sync" - "time" - - _ "embed" - - "tailscale.com/net/packet" - "tailscale.com/util/set" -) - -//go:embed ts-dissector.lua -var DissectorLua string - -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) - -var bufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -const flushPeriod = 100 * time.Millisecond - -func writePcapHeader(w io.Writer) { - binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number - binary.Write(w, binary.LittleEndian, uint16(2)) // version major - binary.Write(w, binary.LittleEndian, uint16(4)) // version minor - binary.Write(w, binary.LittleEndian, uint32(0)) // this zone - binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures - binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len - binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 -} - -func writePktHeader(w *bytes.Buffer, when time.Time, length int) { - s := when.Unix() - us := when.UnixMicro() - (s * 1000000) - - binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds - binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds - binary.Write(w, binary.LittleEndian, uint32(length)) // length present - binary.Write(w, binary.LittleEndian, uint32(length)) // total length -} - -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. -const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { - ctx, c := context.WithCancel(context.Background()) - return &Sink{ - ctx: ctx, - ctxCancel: c, - } -} - -// Type Sink handles callbacks with packets to be logged, -// formatting them into a pcap stream which is mirrored to -// all registered outputs. -type Sink struct { - ctx context.Context - ctxCancel context.CancelFunc - - mu sync.Mutex - outputs set.HandleSet[io.Writer] - flushTimer *time.Timer // or nil if none running -} - -// RegisterOutput connects an output to this sink, which -// will be written to with a pcap stream as packets are logged. -// A function is returned which unregisters the output when -// called. -// -// If w implements io.Closer, it will be closed upon error -// or when the sink is closed. If w implements http.Flusher, -// it will be flushed periodically. -func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { - select { - case <-s.ctx.Done(): - return func() {} - default: - } - - writePcapHeader(w) - s.mu.Lock() - hnd := s.outputs.Add(w) - s.mu.Unlock() - - return func() { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.outputs, hnd) - } -} - -// NumOutputs returns the number of outputs registered with the sink. -func (s *Sink) NumOutputs() int { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.outputs) -} - -// Close shuts down the sink. Future calls to LogPacket -// are ignored, and any registered output that implements -// io.Closer is closed. -func (s *Sink) Close() error { - s.ctxCancel() - s.mu.Lock() - defer s.mu.Unlock() - if s.flushTimer != nil { - s.flushTimer.Stop() - s.flushTimer = nil - } - - for _, o := range s.outputs { - if o, ok := o.(io.Closer); ok { - o.Close() - } - } - s.outputs = nil - return nil -} - -// WaitCh returns a channel which blocks until -// the sink is closed. -func (s *Sink) WaitCh() <-chan struct{} { - return s.ctx.Done() -} - -func customDataLen(meta packet.CaptureMeta) int { - length := 4 - if meta.DidSNAT { - length += meta.OriginalSrc.Addr().BitLen() / 8 - } - if meta.DidDNAT { - length += meta.OriginalDst.Addr().BitLen() / 8 - } - return length -} - -// LogPacket is called to insert a packet into the capture. -// -// This function does not take ownership of the provided data slice. -func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { - select { - case <-s.ctx.Done(): - return - default: - } - - extraLen := customDataLen(meta) - b := bufferPool.Get().(*bytes.Buffer) - b.Reset() - b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) - defer bufferPool.Put(b) - - writePktHeader(b, when, len(data)+extraLen) - - // Custom tailscale debugging data - binary.Write(b, binary.LittleEndian, uint16(path)) - if meta.DidSNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) - b.Write(meta.OriginalSrc.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 - } - if meta.DidDNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) - b.Write(meta.OriginalDst.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 - } - - b.Write(data) - - s.mu.Lock() - defer s.mu.Unlock() - - var hadError []set.Handle - for hnd, o := range s.outputs { - if _, err := o.Write(b.Bytes()); err != nil { - hadError = append(hadError, hnd) - continue - } - } - for _, hnd := range hadError { - if o, ok := s.outputs[hnd].(io.Closer); ok { - o.Close() - } - delete(s.outputs, hnd) - } - - if s.flushTimer == nil { - s.flushTimer = time.AfterFunc(flushPeriod, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, o := range s.outputs { - if f, ok := o.(http.Flusher); ok { - f.Flush() - } - } - s.flushTimer = nil - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package capture formats packet logging into a debug pcap stream. +package capture + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net/http" + "sync" + "time" + + _ "embed" + + "tailscale.com/net/packet" + "tailscale.com/util/set" +) + +//go:embed ts-dissector.lua +var DissectorLua string + +// Callback describes a function which is called to +// record packets when debugging packet-capture. +// Such callbacks must not take ownership of the +// provided data slice: it may only copy out of it +// within the lifetime of the function. +type Callback func(Path, time.Time, []byte, packet.CaptureMeta) + +var bufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +const flushPeriod = 100 * time.Millisecond + +func writePcapHeader(w io.Writer) { + binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number + binary.Write(w, binary.LittleEndian, uint16(2)) // version major + binary.Write(w, binary.LittleEndian, uint16(4)) // version minor + binary.Write(w, binary.LittleEndian, uint32(0)) // this zone + binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures + binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len + binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 +} + +func writePktHeader(w *bytes.Buffer, when time.Time, length int) { + s := when.Unix() + us := when.UnixMicro() - (s * 1000000) + + binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds + binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds + binary.Write(w, binary.LittleEndian, uint32(length)) // length present + binary.Write(w, binary.LittleEndian, uint32(length)) // total length +} + +// Path describes where in the data path the packet was captured. +type Path uint8 + +// Valid Path values. +const ( + // FromLocal indicates the packet was logged as it traversed the FromLocal path: + // i.e.: A packet from the local system into the TUN. + FromLocal Path = 0 + // FromPeer indicates the packet was logged upon reception from a remote peer. + FromPeer Path = 1 + // SynthesizedToLocal indicates the packet was generated from within tailscaled, + // and is being routed to the local machine's network stack. + SynthesizedToLocal Path = 2 + // SynthesizedToPeer indicates the packet was generated from within tailscaled, + // and is being routed to a remote Wireguard peer. + SynthesizedToPeer Path = 3 + + // PathDisco indicates the packet is information about a disco frame. + PathDisco Path = 254 +) + +// New creates a new capture sink. +func New() *Sink { + ctx, c := context.WithCancel(context.Background()) + return &Sink{ + ctx: ctx, + ctxCancel: c, + } +} + +// Type Sink handles callbacks with packets to be logged, +// formatting them into a pcap stream which is mirrored to +// all registered outputs. +type Sink struct { + ctx context.Context + ctxCancel context.CancelFunc + + mu sync.Mutex + outputs set.HandleSet[io.Writer] + flushTimer *time.Timer // or nil if none running +} + +// RegisterOutput connects an output to this sink, which +// will be written to with a pcap stream as packets are logged. +// A function is returned which unregisters the output when +// called. +// +// If w implements io.Closer, it will be closed upon error +// or when the sink is closed. If w implements http.Flusher, +// it will be flushed periodically. +func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { + select { + case <-s.ctx.Done(): + return func() {} + default: + } + + writePcapHeader(w) + s.mu.Lock() + hnd := s.outputs.Add(w) + s.mu.Unlock() + + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.outputs, hnd) + } +} + +// NumOutputs returns the number of outputs registered with the sink. +func (s *Sink) NumOutputs() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.outputs) +} + +// Close shuts down the sink. Future calls to LogPacket +// are ignored, and any registered output that implements +// io.Closer is closed. +func (s *Sink) Close() error { + s.ctxCancel() + s.mu.Lock() + defer s.mu.Unlock() + if s.flushTimer != nil { + s.flushTimer.Stop() + s.flushTimer = nil + } + + for _, o := range s.outputs { + if o, ok := o.(io.Closer); ok { + o.Close() + } + } + s.outputs = nil + return nil +} + +// WaitCh returns a channel which blocks until +// the sink is closed. +func (s *Sink) WaitCh() <-chan struct{} { + return s.ctx.Done() +} + +func customDataLen(meta packet.CaptureMeta) int { + length := 4 + if meta.DidSNAT { + length += meta.OriginalSrc.Addr().BitLen() / 8 + } + if meta.DidDNAT { + length += meta.OriginalDst.Addr().BitLen() / 8 + } + return length +} + +// LogPacket is called to insert a packet into the capture. +// +// This function does not take ownership of the provided data slice. +func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { + select { + case <-s.ctx.Done(): + return + default: + } + + extraLen := customDataLen(meta) + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) + defer bufferPool.Put(b) + + writePktHeader(b, when, len(data)+extraLen) + + // Custom tailscale debugging data + binary.Write(b, binary.LittleEndian, uint16(path)) + if meta.DidSNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) + b.Write(meta.OriginalSrc.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 + } + if meta.DidDNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) + b.Write(meta.OriginalDst.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 + } + + b.Write(data) + + s.mu.Lock() + defer s.mu.Unlock() + + var hadError []set.Handle + for hnd, o := range s.outputs { + if _, err := o.Write(b.Bytes()); err != nil { + hadError = append(hadError, hnd) + continue + } + } + for _, hnd := range hadError { + if o, ok := s.outputs[hnd].(io.Closer); ok { + o.Close() + } + delete(s.outputs, hnd) + } + + if s.flushTimer == nil { + s.flushTimer = time.AfterFunc(flushPeriod, func() { + s.mu.Lock() + defer s.mu.Unlock() + for _, o := range s.outputs { + if f, ok := o.(http.Flusher); ok { + f.Flush() + } + } + s.flushTimer = nil + }) + } +} diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index f2e85dcd57002..58359acdd51f2 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "errors" - "net" - "net/netip" - "sync" - "syscall" - "time" -) - -// blockForeverConn is a net.PacketConn whose reads block until it is closed. -type blockForeverConn struct { - mu sync.Mutex - cond *sync.Cond - closed bool -} - -func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, netip.AddrPort{}, net.ErrClosed -} - -func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { - // Silently drop writes. - return len(p), nil -} - -func (c *blockForeverConn) LocalAddr() net.Addr { - // Return a *net.UDPAddr because lots of code assumes that it will. - return new(net.UDPAddr) -} - -func (c *blockForeverConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - c.closed = true - c.cond.Broadcast() - return nil -} - -func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "syscall" + "time" +) + +// blockForeverConn is a net.PacketConn whose reads block until it is closed. +type blockForeverConn struct { + mu sync.Mutex + cond *sync.Cond + closed bool +} + +func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, netip.AddrPort{}, net.ErrClosed +} + +func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { + // Silently drop writes. + return len(p), nil +} + +func (c *blockForeverConn) LocalAddr() net.Addr { + // Return a *net.UDPAddr because lots of code assumes that it will. + return new(net.UDPAddr) +} + +func (c *blockForeverConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return net.ErrClosed + } + c.closed = true + c.cond.Broadcast() + return nil +} + +func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } diff --git a/wgengine/magicsock/endpoint_default.go b/wgengine/magicsock/endpoint_default.go index 1ed6e5e0e2399..9ffeef5f8a7bf 100644 --- a/wgengine/magicsock/endpoint_default.go +++ b/wgengine/magicsock/endpoint_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !wasm && !plan9 - -package magicsock - -import ( - "errors" - "syscall" -) - -// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to -// errors.Is while avoiding an allocation per call. -var errHOSTUNREACH error = syscall.EHOSTUNREACH - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, and for unknown -// errors always reports false. -func isBadEndpointErr(err error) bool { - return errors.Is(err, errHOSTUNREACH) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm && !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to +// errors.Is while avoiding an allocation per call. +var errHOSTUNREACH error = syscall.EHOSTUNREACH + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, and for unknown +// errors always reports false. +func isBadEndpointErr(err error) bool { + return errors.Is(err, errHOSTUNREACH) +} diff --git a/wgengine/magicsock/endpoint_stub.go b/wgengine/magicsock/endpoint_stub.go index a209c352bfe5e..9a5c9d937560c 100644 --- a/wgengine/magicsock/endpoint_stub.go +++ b/wgengine/magicsock/endpoint_stub.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 - -package magicsock - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, but covers known -// cases. -func isBadEndpointErr(err error) bool { - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 + +package magicsock + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, but covers known +// cases. +func isBadEndpointErr(err error) bool { + return false +} diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index 5caddd1a06960..e2ac926b43060 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -1,248 +1,248 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "net/netip" - "slices" - "sync" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tempfork/heap" - "tailscale.com/util/mak" - "tailscale.com/util/set" -) - -const ( - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second - - // endpointTrackerMaxPerAddr is how many cached addresses we track for - // a given netip.Addr. This allows e.g. restricting the number of STUN - // endpoints we cache (which usually have the same netip.Addr but - // different ports). - // - // The value of 6 is chosen because we can advertise up to 3 endpoints - // based on the STUN IP: - // 1. The STUN endpoint itself (EndpointSTUN) - // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) - // 3. The STUN IP with a portmapped port (EndpointPortmapped) - // - // Storing 6 endpoints in the cache means we can store up to 2 previous - // sets of endpoints. - endpointTrackerMaxPerAddr = 6 -) - -// endpointTrackerEntry is an entry in an endpointHeap that stores the state of -// a given cached endpoint. -type endpointTrackerEntry struct { - // endpoint is the cached endpoint. - endpoint tailcfg.Endpoint - // until is the time until which this endpoint is being cached. - until time.Time - // index is the index within the containing endpointHeap. - index int -} - -// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in -// ascending order by the 'until' expiry time (i.e. oldest first). -type endpointHeap []*endpointTrackerEntry - -var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) - -// Len implements heap.Interface. -func (eh endpointHeap) Len() int { return len(eh) } - -// Less implements heap.Interface. -func (eh endpointHeap) Less(i, j int) bool { - // We want to store items so that the lowest item in the heap is the - // oldest, so that heap.Pop()-ing from the endpointHeap will remove the - // oldest entry. - return eh[i].until.Before(eh[j].until) -} - -// Swap implements heap.Interface. -func (eh endpointHeap) Swap(i, j int) { - eh[i], eh[j] = eh[j], eh[i] - eh[i].index = i - eh[j].index = j -} - -// Push implements heap.Interface. -func (eh *endpointHeap) Push(item *endpointTrackerEntry) { - n := len(*eh) - item.index = n - *eh = append(*eh, item) -} - -// Pop implements heap.Interface. -func (eh *endpointHeap) Pop() *endpointTrackerEntry { - old := *eh - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.index = -1 // for safety - *eh = old[0 : n-1] - return item -} - -// Min returns a pointer to the minimum element in the heap, without removing -// it. Since this is a min-heap ordered by the 'until' field, this returns the -// chronologically "earliest" element in the heap. -// -// Len() must be non-zero. -func (eh endpointHeap) Min() *endpointTrackerEntry { - return eh[0] -} - -// endpointTracker caches endpoints that are advertised to peers. This allows -// peers to still reach this node if there's a temporary endpoint flap; rather -// than withdrawing an endpoint and then re-advertising it the next time we run -// a netcheck, we keep advertising the endpoint until it's not present for a -// defined timeout. -// -// See tailscale/tailscale#7877 for more information. -type endpointTracker struct { - mu sync.Mutex - endpoints map[netip.Addr]*endpointHeap -} - -// update takes as input the current sent of discovered endpoints and the -// current time, and returns the set of endpoints plus any previous-cached and -// non-expired endpoints that should be advertised to peers. -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Extend endpoints that already exist in the cache. We do this before - // we remove expired endpoints, below, so we don't remove something - // that would otherwise have survived by extending. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.extendLocked(ep, until) - } - - // Now that we've extended existing endpoints, remove everything that - // has expired. - et.removeExpiredLocked(now) - - // Add entries from the input set of endpoints into the cache; we do - // this after removing expired ones so that we can store as many as - // possible, with space freed by the entries removed after expiry. - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Finally, add entries to the return array that aren't already there. - epsPlusCached = eps - for _, heap := range et.endpoints { - for _, ep := range *heap { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(ep.endpoint.Addr) { - continue - } else if now.After(ep.until) { - // Defense-in-depth; should never happen since - // we removed expired entries above, but ignore - // it anyway. - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - } - - return epsPlusCached -} - -// extendLocked will update the expiry time of the provided endpoint in the -// cache, if it is present. If it is not present, nothing will be done. -// -// et.mu must be held. -func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - epHeap, found := et.endpoints[key] - if !found { - return - } - - // Find the entry for this exact address; this loop is quick since we - // bound the number of items in the heap. - // - // TODO(andrew): this means we iterate over the entire heap once per - // endpoint; even if the heap is small, if we have a lot of input - // endpoints this can be expensive? - for i, entry := range *epHeap { - if entry.endpoint == ep { - entry.until = until - heap.Fix(epHeap, i) - return - } - } -} - -// addLocked will store the provided endpoint(s) in the cache for a fixed -// period of time, ensuring that the size of the endpoint cache remains below -// the maximum. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - - // Create or get the heap for this endpoint's addr - epHeap := et.endpoints[key] - if epHeap == nil { - epHeap = new(endpointHeap) - mak.Set(&et.endpoints, key, epHeap) - } - - // Find the entry for this exact address; this loop is quick - // since we bound the number of items in the heap. - found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { - return v.endpoint == ep - }) - if !found { - // Add address to heap; either the endpoint is new, or the heap - // was newly-created and thus empty. - heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) - } - - // Now that we've added everything, pop from our heap until we're below - // the limit. This is a min-heap, so popping removes the lowest (and - // thus oldest) endpoint. - for epHeap.Len() > endpointTrackerMaxPerAddr { - heap.Pop(epHeap) - } -} - -// removeExpired will remove all expired entries from the cache. -// -// et.mu must be held. -func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, epHeap := range et.endpoints { - // The minimum element is oldest/earliest endpoint; repeatedly - // pop from the heap while it's in the past. - for epHeap.Len() > 0 { - minElem := epHeap.Min() - if now.After(minElem.until) { - heap.Pop(epHeap) - } else { - break - } - } - - if epHeap.Len() == 0 { - // Free up space in the map by removing the empty heap. - delete(et.endpoints, k) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "slices" + "sync" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/heap" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + // endpointTrackerLifetime is how long we continue advertising an + // endpoint after we last see it. This is intentionally chosen to be + // slightly longer than a full netcheck period. + endpointTrackerLifetime = 5*time.Minute + 10*time.Second + + // endpointTrackerMaxPerAddr is how many cached addresses we track for + // a given netip.Addr. This allows e.g. restricting the number of STUN + // endpoints we cache (which usually have the same netip.Addr but + // different ports). + // + // The value of 6 is chosen because we can advertise up to 3 endpoints + // based on the STUN IP: + // 1. The STUN endpoint itself (EndpointSTUN) + // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) + // 3. The STUN IP with a portmapped port (EndpointPortmapped) + // + // Storing 6 endpoints in the cache means we can store up to 2 previous + // sets of endpoints. + endpointTrackerMaxPerAddr = 6 +) + +// endpointTrackerEntry is an entry in an endpointHeap that stores the state of +// a given cached endpoint. +type endpointTrackerEntry struct { + // endpoint is the cached endpoint. + endpoint tailcfg.Endpoint + // until is the time until which this endpoint is being cached. + until time.Time + // index is the index within the containing endpointHeap. + index int +} + +// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in +// ascending order by the 'until' expiry time (i.e. oldest first). +type endpointHeap []*endpointTrackerEntry + +var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) + +// Len implements heap.Interface. +func (eh endpointHeap) Len() int { return len(eh) } + +// Less implements heap.Interface. +func (eh endpointHeap) Less(i, j int) bool { + // We want to store items so that the lowest item in the heap is the + // oldest, so that heap.Pop()-ing from the endpointHeap will remove the + // oldest entry. + return eh[i].until.Before(eh[j].until) +} + +// Swap implements heap.Interface. +func (eh endpointHeap) Swap(i, j int) { + eh[i], eh[j] = eh[j], eh[i] + eh[i].index = i + eh[j].index = j +} + +// Push implements heap.Interface. +func (eh *endpointHeap) Push(item *endpointTrackerEntry) { + n := len(*eh) + item.index = n + *eh = append(*eh, item) +} + +// Pop implements heap.Interface. +func (eh *endpointHeap) Pop() *endpointTrackerEntry { + old := *eh + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *eh = old[0 : n-1] + return item +} + +// Min returns a pointer to the minimum element in the heap, without removing +// it. Since this is a min-heap ordered by the 'until' field, this returns the +// chronologically "earliest" element in the heap. +// +// Len() must be non-zero. +func (eh endpointHeap) Min() *endpointTrackerEntry { + return eh[0] +} + +// endpointTracker caches endpoints that are advertised to peers. This allows +// peers to still reach this node if there's a temporary endpoint flap; rather +// than withdrawing an endpoint and then re-advertising it the next time we run +// a netcheck, we keep advertising the endpoint until it's not present for a +// defined timeout. +// +// See tailscale/tailscale#7877 for more information. +type endpointTracker struct { + mu sync.Mutex + endpoints map[netip.Addr]*endpointHeap +} + +// update takes as input the current sent of discovered endpoints and the +// current time, and returns the set of endpoints plus any previous-cached and +// non-expired endpoints that should be advertised to peers. +func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { + var inputEps set.Slice[netip.AddrPort] + for _, ep := range eps { + inputEps.Add(ep.Addr) + } + + et.mu.Lock() + defer et.mu.Unlock() + + // Extend endpoints that already exist in the cache. We do this before + // we remove expired endpoints, below, so we don't remove something + // that would otherwise have survived by extending. + until := now.Add(endpointTrackerLifetime) + for _, ep := range eps { + et.extendLocked(ep, until) + } + + // Now that we've extended existing endpoints, remove everything that + // has expired. + et.removeExpiredLocked(now) + + // Add entries from the input set of endpoints into the cache; we do + // this after removing expired ones so that we can store as many as + // possible, with space freed by the entries removed after expiry. + for _, ep := range eps { + et.addLocked(now, ep, until) + } + + // Finally, add entries to the return array that aren't already there. + epsPlusCached = eps + for _, heap := range et.endpoints { + for _, ep := range *heap { + // If the endpoint was in the input list, or has expired, skip it. + if inputEps.Contains(ep.endpoint.Addr) { + continue + } else if now.After(ep.until) { + // Defense-in-depth; should never happen since + // we removed expired entries above, but ignore + // it anyway. + continue + } + + // We haven't seen this endpoint; add to the return array + epsPlusCached = append(epsPlusCached, ep.endpoint) + } + } + + return epsPlusCached +} + +// extendLocked will update the expiry time of the provided endpoint in the +// cache, if it is present. If it is not present, nothing will be done. +// +// et.mu must be held. +func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + epHeap, found := et.endpoints[key] + if !found { + return + } + + // Find the entry for this exact address; this loop is quick since we + // bound the number of items in the heap. + // + // TODO(andrew): this means we iterate over the entire heap once per + // endpoint; even if the heap is small, if we have a lot of input + // endpoints this can be expensive? + for i, entry := range *epHeap { + if entry.endpoint == ep { + entry.until = until + heap.Fix(epHeap, i) + return + } + } +} + +// addLocked will store the provided endpoint(s) in the cache for a fixed +// period of time, ensuring that the size of the endpoint cache remains below +// the maximum. +// +// et.mu must be held. +func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + + // Create or get the heap for this endpoint's addr + epHeap := et.endpoints[key] + if epHeap == nil { + epHeap = new(endpointHeap) + mak.Set(&et.endpoints, key, epHeap) + } + + // Find the entry for this exact address; this loop is quick + // since we bound the number of items in the heap. + found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { + return v.endpoint == ep + }) + if !found { + // Add address to heap; either the endpoint is new, or the heap + // was newly-created and thus empty. + heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) + } + + // Now that we've added everything, pop from our heap until we're below + // the limit. This is a min-heap, so popping removes the lowest (and + // thus oldest) endpoint. + for epHeap.Len() > endpointTrackerMaxPerAddr { + heap.Pop(epHeap) + } +} + +// removeExpired will remove all expired entries from the cache. +// +// et.mu must be held. +func (et *endpointTracker) removeExpiredLocked(now time.Time) { + for k, epHeap := range et.endpoints { + // The minimum element is oldest/earliest endpoint; repeatedly + // pop from the heap while it's in the past. + for epHeap.Len() > 0 { + minElem := epHeap.Min() + if now.After(minElem.until) { + heap.Pop(epHeap) + } else { + break + } + } + + if epHeap.Len() == 0 { + // Free up space in the map by removing the empty heap. + delete(et.endpoints, k) + } + } +} diff --git a/wgengine/magicsock/magicsock_unix_test.go b/wgengine/magicsock/magicsock_unix_test.go index b0700a8ebe870..9ad8cab93330b 100644 --- a/wgengine/magicsock/magicsock_unix_test.go +++ b/wgengine/magicsock/magicsock_unix_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package magicsock - -import ( - "net" - "syscall" - "testing" - - "tailscale.com/types/nettype" -) - -func TestTrySetSocketBuffer(t *testing.T) { - c, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) - } - defer c.Close() - - rc, err := c.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - getBufs := func() (int, int) { - var rcv, snd int - rc.Control(func(fd uintptr) { - rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - t.Errorf("getsockopt(SO_RCVBUF): %v", err) - } - snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - t.Errorf("getsockopt(SO_SNDBUF): %v", err) - } - }) - return rcv, snd - } - - curRcv, curSnd := getBufs() - - trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) - - newRcv, newSnd := getBufs() - - if curRcv > newRcv { - t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) - } - if curSnd > newSnd { - t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) - } - - // On many systems we may not increase the value, particularly running as a - // regular user, so log the information for manual verification. - t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) - t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package magicsock + +import ( + "net" + "syscall" + "testing" + + "tailscale.com/types/nettype" +) + +func TestTrySetSocketBuffer(t *testing.T) { + c, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rc, err := c.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + getBufs := func() (int, int) { + var rcv, snd int + rc.Control(func(fd uintptr) { + rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) + if err != nil { + t.Errorf("getsockopt(SO_RCVBUF): %v", err) + } + snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + t.Errorf("getsockopt(SO_SNDBUF): %v", err) + } + }) + return rcv, snd + } + + curRcv, curSnd := getBufs() + + trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) + + newRcv, newSnd := getBufs() + + if curRcv > newRcv { + t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) + } + if curSnd > newSnd { + t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) + } + + // On many systems we may not increase the value, particularly running as a + // regular user, so log the information for manual verification. + t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) + t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) +} diff --git a/wgengine/magicsock/peermtu_darwin.go b/wgengine/magicsock/peermtu_darwin.go index a0a1aacb55f5f..b2a1ed217b2b8 100644 --- a/wgengine/magicsock/peermtu_darwin.go +++ b/wgengine/magicsock/peermtu_darwin.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package magicsock - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return unix.IP_DONTFRAG - } - return unix.IPV6_DONTFRAG -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := 1 - if enable == false { - optArg = 0 - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == 1 { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package magicsock + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return unix.IP_DONTFRAG + } + return unix.IPV6_DONTFRAG +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := 1 + if enable == false { + optArg = 0 + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == 1 { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_linux.go b/wgengine/magicsock/peermtu_linux.go index b76f30f081042..d32ead0991953 100644 --- a/wgengine/magicsock/peermtu_linux.go +++ b/wgengine/magicsock/peermtu_linux.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !android - -package magicsock - -import ( - "syscall" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return syscall.IP_MTU_DISCOVER - } - return syscall.IPV6_MTU_DISCOVER -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := syscall.IP_PMTUDISC_DO - if enable == false { - optArg = syscall.IP_PMTUDISC_DONT - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == syscall.IP_PMTUDISC_DO { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +package magicsock + +import ( + "syscall" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return syscall.IP_MTU_DISCOVER + } + return syscall.IPV6_MTU_DISCOVER +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := syscall.IP_PMTUDISC_DO + if enable == false { + optArg = syscall.IP_PMTUDISC_DONT + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == syscall.IP_PMTUDISC_DO { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_unix.go b/wgengine/magicsock/peermtu_unix.go index eec3d744f3ded..59e808ee75e34 100644 --- a/wgengine/magicsock/peermtu_unix.go +++ b/wgengine/magicsock/peermtu_unix.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (darwin && !ios) || (linux && !android) - -package magicsock - -import ( - "syscall" -) - -// getIPProto returns the value of the get/setsockopt proto argument necessary -// to set an IP sockopt that corresponds with the string network, which must be -// "udp4" or "udp6". -func getIPProto(network string) int { - if network == "udp4" { - return syscall.IPPROTO_IP - } - return syscall.IPPROTO_IPV6 -} - -// connControl allows the caller to run a system call on the socket underlying -// Conn specified by the string network, which must be "udp4" or "udp6". If the -// pconn type implements the syscall method, this function returns the value of -// of the system call fn called with the fd of the socket as its arg (or the -// error from rc.Control() if that fails). Otherwise it returns the error -// errUnsupportedConnType. -func (c *Conn) connControl(network string, fn func(fd uintptr)) error { - pconn := c.pconn4.pconn - if network == "udp6" { - pconn = c.pconn6.pconn - } - sc, ok := pconn.(syscall.Conn) - if !ok { - return errUnsupportedConnType - } - rc, err := sc.SyscallConn() - if err != nil { - return err - } - return rc.Control(fn) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (darwin && !ios) || (linux && !android) + +package magicsock + +import ( + "syscall" +) + +// getIPProto returns the value of the get/setsockopt proto argument necessary +// to set an IP sockopt that corresponds with the string network, which must be +// "udp4" or "udp6". +func getIPProto(network string) int { + if network == "udp4" { + return syscall.IPPROTO_IP + } + return syscall.IPPROTO_IPV6 +} + +// connControl allows the caller to run a system call on the socket underlying +// Conn specified by the string network, which must be "udp4" or "udp6". If the +// pconn type implements the syscall method, this function returns the value of +// of the system call fn called with the fd of the socket as its arg (or the +// error from rc.Control() if that fails). Otherwise it returns the error +// errUnsupportedConnType. +func (c *Conn) connControl(network string, fn func(fd uintptr)) error { + pconn := c.pconn4.pconn + if network == "udp6" { + pconn = c.pconn6.pconn + } + sc, ok := pconn.(syscall.Conn) + if !ok { + return errUnsupportedConnType + } + rc, err := sc.SyscallConn() + if err != nil { + return err + } + return rc.Control(fn) +} diff --git a/wgengine/mem_ios.go b/wgengine/mem_ios.go index cc266ea3aadc8..975dfca611fbb 100644 --- a/wgengine/mem_ios.go +++ b/wgengine/mem_ios.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgengine - -import ( - "github.com/tailscale/wireguard-go/device" -) - -// iOS has a very restrictive memory limit on network extensions. -// Reduce the maximum amount of memory that wireguard-go can allocate -// to avoid getting killed. - -func init() { - device.QueueStagedSize = 64 - device.QueueOutboundSize = 64 - device.QueueInboundSize = 64 - device.QueueHandshakeSize = 64 - device.PreallocatedBuffersPerPool = 64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" +) + +// iOS has a very restrictive memory limit on network extensions. +// Reduce the maximum amount of memory that wireguard-go can allocate +// to avoid getting killed. + +func init() { + device.QueueStagedSize = 64 + device.QueueOutboundSize = 64 + device.QueueInboundSize = 64 + device.QueueHandshakeSize = 64 + device.PreallocatedBuffersPerPool = 64 +} diff --git a/wgengine/netstack/netstack_linux.go b/wgengine/netstack/netstack_linux.go index a0bfb44567da7..9e27b7819dc4d 100644 --- a/wgengine/netstack/netstack_linux.go +++ b/wgengine/netstack/netstack_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstack - -import ( - "os/exec" - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - setAmbientCapsRaw = func(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_RAW}, - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstack + +import ( + "os/exec" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + setAmbientCapsRaw = func(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_RAW}, + } + } +} diff --git a/wgengine/router/runner.go b/wgengine/router/runner.go index 8fa068e335e66..7ba633344f601 100644 --- a/wgengine/router/runner.go +++ b/wgengine/router/runner.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package router - -import ( - "errors" - "fmt" - "os" - "os/exec" - "strconv" - "strings" - "syscall" - - "golang.org/x/sys/unix" -) - -// commandRunner abstracts helpers to run OS commands. It exists -// purely to swap out osCommandRunner (below) with a fake runner in -// tests. -type commandRunner interface { - run(...string) error - output(...string) ([]byte, error) -} - -type osCommandRunner struct { - // ambientCapNetAdmin determines whether commands are executed with - // CAP_NET_ADMIN. - // CAP_NET_ADMIN is required when running as non-root and executing cmds - // like `ip rule`. Even if our process has the capability, we need to - // explicitly grant it to the new process. - // We specifically need this for Synology DSM7 where tailscaled no longer - // runs as root. - ambientCapNetAdmin bool -} - -// errCode extracts and returns the process exit code from err, or -// zero if err is nil. -func errCode(err error) int { - if err == nil { - return 0 - } - var e *exec.ExitError - if ok := errors.As(err, &e); ok { - return e.ExitCode() - } - s := err.Error() - if strings.HasPrefix(s, "exitcode:") { - code, err := strconv.Atoi(s[9:]) - if err == nil { - return code - } - } - return -42 -} - -func (o osCommandRunner) run(args ...string) error { - _, err := o.output(args...) - return err -} - -func (o osCommandRunner) output(args ...string) ([]byte, error) { - if len(args) == 0 { - return nil, errors.New("cmd: no argv[0]") - } - - cmd := exec.Command(args[0], args[1:]...) - cmd.Env = append(os.Environ(), "LC_ALL=C") - if o.ambientCapNetAdmin { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, - } - } - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) - } - - return out, nil -} - -type runGroup struct { - OkCode []int // error codes that are acceptable, other than 0, if any - Runner commandRunner // the runner that actually runs our commands - ErrAcc error // first error encountered, if any -} - -func newRunGroup(okCode []int, runner commandRunner) *runGroup { - return &runGroup{ - OkCode: okCode, - Runner: runner, - } -} - -func (rg *runGroup) okCode(err error) bool { - got := errCode(err) - for _, want := range rg.OkCode { - if got == want { - return true - } - } - return false -} - -func (rg *runGroup) Output(args ...string) []byte { - b, err := rg.Runner.output(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } - return b -} - -func (rg *runGroup) Run(args ...string) { - err := rg.Runner.run(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package router + +import ( + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + "syscall" + + "golang.org/x/sys/unix" +) + +// commandRunner abstracts helpers to run OS commands. It exists +// purely to swap out osCommandRunner (below) with a fake runner in +// tests. +type commandRunner interface { + run(...string) error + output(...string) ([]byte, error) +} + +type osCommandRunner struct { + // ambientCapNetAdmin determines whether commands are executed with + // CAP_NET_ADMIN. + // CAP_NET_ADMIN is required when running as non-root and executing cmds + // like `ip rule`. Even if our process has the capability, we need to + // explicitly grant it to the new process. + // We specifically need this for Synology DSM7 where tailscaled no longer + // runs as root. + ambientCapNetAdmin bool +} + +// errCode extracts and returns the process exit code from err, or +// zero if err is nil. +func errCode(err error) int { + if err == nil { + return 0 + } + var e *exec.ExitError + if ok := errors.As(err, &e); ok { + return e.ExitCode() + } + s := err.Error() + if strings.HasPrefix(s, "exitcode:") { + code, err := strconv.Atoi(s[9:]) + if err == nil { + return code + } + } + return -42 +} + +func (o osCommandRunner) run(args ...string) error { + _, err := o.output(args...) + return err +} + +func (o osCommandRunner) output(args ...string) ([]byte, error) { + if len(args) == 0 { + return nil, errors.New("cmd: no argv[0]") + } + + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "LC_ALL=C") + if o.ambientCapNetAdmin { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, + } + } + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) + } + + return out, nil +} + +type runGroup struct { + OkCode []int // error codes that are acceptable, other than 0, if any + Runner commandRunner // the runner that actually runs our commands + ErrAcc error // first error encountered, if any +} + +func newRunGroup(okCode []int, runner commandRunner) *runGroup { + return &runGroup{ + OkCode: okCode, + Runner: runner, + } +} + +func (rg *runGroup) okCode(err error) bool { + got := errCode(err) + for _, want := range rg.OkCode { + if got == want { + return true + } + } + return false +} + +func (rg *runGroup) Output(args ...string) []byte { + b, err := rg.Runner.output(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } + return b +} + +func (rg *runGroup) Run(args ...string) { + err := rg.Runner.run(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } +} diff --git a/wgengine/watchdog_js.go b/wgengine/watchdog_js.go index 872ce36d5fd5d..9dcb29c4ee556 100644 --- a/wgengine/watchdog_js.go +++ b/wgengine/watchdog_js.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js - -package wgengine - -import "tailscale.com/net/dns/resolver" - -type watchdogEngine struct { - Engine - wrap Engine -} - -func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { - return nil, false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js + +package wgengine + +import "tailscale.com/net/dns/resolver" + +type watchdogEngine struct { + Engine + wrap Engine +} + +func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { + return nil, false +} diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index 80fa159e38972..9b83998cb4232 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "io" - "sort" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" - "tailscale.com/util/multierr" -) - -// NewDevice returns a wireguard-go Device configured for Tailscale use. -func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { - ret := device.NewDevice(tunDev, bind, logger) - ret.DisableSomeRoamingForBrokenMobileSemantics() - return ret -} - -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := multierr.New(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - -// ReconfigDevice replaces the existing device configuration with cfg. -func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { - defer func() { - if err != nil { - logf("wgcfg.Reconfig failed: %v", err) - } - }() - - prev, err := DeviceConfig(d) - if err != nil { - return err - } - - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() - - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return multierr.New(setErr, toErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "io" + "sort" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +// NewDevice returns a wireguard-go Device configured for Tailscale use. +func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { + ret := device.NewDevice(tunDev, bind, logger) + ret.DisableSomeRoamingForBrokenMobileSemantics() + return ret +} + +func DeviceConfig(d *device.Device) (*Config, error) { + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcGetOperation(w) + w.Close() + }() + cfg, fromErr := FromUAPI(r) + r.Close() + getErr := <-errc + err := multierr.New(getErr, fromErr) + if err != nil { + return nil, err + } + sort.Slice(cfg.Peers, func(i, j int) bool { + return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) + }) + return cfg, nil +} + +// ReconfigDevice replaces the existing device configuration with cfg. +func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { + defer func() { + if err != nil { + logf("wgcfg.Reconfig failed: %v", err) + } + }() + + prev, err := DeviceConfig(d) + if err != nil { + return err + } + + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcSetOperation(r) + r.Close() + }() + + toErr := cfg.ToUAPI(logf, w, prev) + w.Close() + setErr := <-errc + return multierr.New(setErr, toErr) +} diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index d54282e4bdf04..c54ad16d9e8b2 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -1,261 +1,261 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "os" - "sort" - "strings" - "sync" - "testing" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - k2, pk2 := newK() - ip2 := netip.MustParsePrefix("10.0.0.2/32") - - k3, _ := newK() - ip3 := netip.MustParsePrefix("10.0.0.3/32") - - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } - - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { - t.Fatal(err) - } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 add new peer", func(t *testing.T) { - cfg1.Peers = append(cfg1.Peers, Peer{ - PublicKey: k3, - AllowedIPs: []netip.Prefix{ip3}, - }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") - } - }) - - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) -} - -// TODO: replace with a loopback tunnel -type nilTun struct { - events chan tun.Event - closed chan struct{} -} - -func newNilTun() tun.Device { - return &nilTun{ - events: make(chan tun.Event), - closed: make(chan struct{}), - } -} - -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() <-chan tun.Event { return t.events } - -func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Write(data [][]byte, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Close() error { - close(t.events) - close(t.closed) - return nil -} - -func (t *nilTun) BatchSize() int { return 1 } - -// A noopBind is a conn.Bind that does no actual binding work. -type noopBind struct{} - -func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - return nil, 1, nil -} -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } -func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return dummyEndpoint(s), nil -} -func (noopBind) BatchSize() int { return 1 } - -// A dummyEndpoint is a string holding the endpoint destination. -type dummyEndpoint string - -func (e dummyEndpoint) ClearSrc() {} -func (e dummyEndpoint) SrcToString() string { return "" } -func (e dummyEndpoint) DstToString() string { return string(e) } -func (e dummyEndpoint) DstToBytes() []byte { return nil } -func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "bytes" + "io" + "net/netip" + "os" + "sort" + "strings" + "sync" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestDeviceConfig(t *testing.T) { + newK := func() (key.NodePublic, key.NodePrivate) { + t.Helper() + k := key.NewNode() + return k.Public(), k + } + k1, pk1 := newK() + ip1 := netip.MustParsePrefix("10.0.0.1/32") + + k2, pk2 := newK() + ip2 := netip.MustParsePrefix("10.0.0.2/32") + + k3, _ := newK() + ip3 := netip.MustParsePrefix("10.0.0.3/32") + + cfg1 := &Config{ + PrivateKey: pk1, + Peers: []Peer{{ + PublicKey: k2, + AllowedIPs: []netip.Prefix{ip2}, + }}, + } + + cfg2 := &Config{ + PrivateKey: pk2, + Peers: []Peer{{ + PublicKey: k1, + AllowedIPs: []netip.Prefix{ip1}, + PersistentKeepalive: 5, + }}, + } + + device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) + device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) + defer device1.Close() + defer device2.Close() + + cmp := func(t *testing.T, d *device.Device, want *Config) { + t.Helper() + got, err := DeviceConfig(d) + if err != nil { + t.Fatal(err) + } + prev := new(Config) + gotbuf := new(strings.Builder) + err = got.ToUAPI(t.Logf, gotbuf, prev) + gotStr := gotbuf.String() + if err != nil { + t.Errorf("got.ToUAPI(): error: %v", err) + return + } + wantbuf := new(strings.Builder) + err = want.ToUAPI(t.Logf, wantbuf, prev) + wantStr := wantbuf.String() + if err != nil { + t.Errorf("want.ToUAPI(): error: %v", err) + return + } + if gotStr != wantStr { + buf := new(bytes.Buffer) + w := bufio.NewWriter(buf) + if err := d.IpcGetOperation(w); err != nil { + t.Errorf("on error, could not IpcGetOperation: %v", err) + } + w.Flush() + t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) + } + } + + t.Run("device1 config", func(t *testing.T) { + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device2 config", func(t *testing.T) { + if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device2, cfg2) + }) + + // This is only to test that Config and Reconfig are properly synchronized. + t.Run("device2 config/reconfig", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + ReconfigDevice(device2, cfg2, t.Logf) + wg.Done() + }() + + go func() { + DeviceConfig(device2) + wg.Done() + }() + + wg.Wait() + }) + + t.Run("device1 modify peer", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 replace endpoint", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 add new peer", func(t *testing.T) { + cfg1.Peers = append(cfg1.Peers, Peer{ + PublicKey: k3, + AllowedIPs: []netip.Prefix{ip3}, + }) + sort.Slice(cfg1.Peers, func(i, j int) bool { + return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) + }) + + origCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + peer0 := func(cfg *Config) Peer { + p, ok := cfg.PeerWithKey(k2) + if !ok { + t.Helper() + t.Fatal("failed to look up peer 2") + } + return p + } + peersEqual := func(p, q Peer) bool { + return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) + } + if !peersEqual(peer0(origCfg), peer0(newCfg)) { + t.Error("reconfig modified old peer") + } + }) + + t.Run("device1 remove peer", func(t *testing.T) { + removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey + cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + _, ok := newCfg.PeerWithKey(removeKey) + if ok { + t.Error("reconfig failed to remove peer") + } + }) +} + +// TODO: replace with a loopback tunnel +type nilTun struct { + events chan tun.Event + closed chan struct{} +} + +func newNilTun() tun.Device { + return &nilTun{ + events: make(chan tun.Event), + closed: make(chan struct{}), + } +} + +func (t *nilTun) File() *os.File { return nil } +func (t *nilTun) Flush() error { return nil } +func (t *nilTun) MTU() (int, error) { return 1420, nil } +func (t *nilTun) Name() (string, error) { return "niltun", nil } +func (t *nilTun) Events() <-chan tun.Event { return t.events } + +func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Write(data [][]byte, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Close() error { + close(t.events) + close(t.closed) + return nil +} + +func (t *nilTun) BatchSize() int { return 1 } + +// A noopBind is a conn.Bind that does no actual binding work. +type noopBind struct{} + +func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 1, nil +} +func (noopBind) Close() error { return nil } +func (noopBind) SetMark(mark uint32) error { return nil } +func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } +func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return dummyEndpoint(s), nil +} +func (noopBind) BatchSize() int { return 1 } + +// A dummyEndpoint is a string holding the endpoint destination. +type dummyEndpoint string + +func (e dummyEndpoint) ClearSrc() {} +func (e dummyEndpoint) SrcToString() string { return "" } +func (e dummyEndpoint) DstToString() string { return string(e) } +func (e dummyEndpoint) DstToBytes() []byte { return nil } +func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } +func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index ec3d008f7de97..553aaecbb7171 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. -func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. - var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. - peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" + + "go4.org/mem" + "tailscale.com/types/key" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: %q", e.why, e.offender) +} + +func parseEndpoint(s string) (host string, port uint16, err error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return "", 0, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return "", 0, &ParseError{"Invalid endpoint host", host} + } + uport, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return "", 0, err + } + } else { + return "", 0, err + } + host = host[1 : len(host)-1] + } + return host, uint16(uport), nil +} + +// memROCut separates a mem.RO at the separator if it exists, otherwise +// it returns two empty ROs and reports that it was not found. +func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { + if i := mem.IndexByte(s, sep); i >= 0 { + return s.SliceTo(i), s.SliceFrom(i + 1), true + } + found = false + return +} + +// FromUAPI generates a Config from r. +// r should be generated by calling device.IpcGetOperation; +// it is not compatible with other uapi streams. +func FromUAPI(r io.Reader) (*Config, error) { + cfg := new(Config) + var peer *Peer // current peer being operated on + deviceConfig := true + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := mem.B(scanner.Bytes()) + if line.Len() == 0 { + continue + } + key, value, ok := memROCut(line, '=') + if !ok { + return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) + } + valueBytes := scanner.Bytes()[key.Len()+1:] + + if key.EqualString("public_key") { + if deviceConfig { + deviceConfig = false + } + // Load/create the peer we are now configuring. + var err error + peer, err = cfg.handlePublicKeyLine(valueBytes) + if err != nil { + return nil, err + } + continue + } + + var err error + if deviceConfig { + err = cfg.handleDeviceLine(key, value, valueBytes) + } else { + err = cfg.handlePeerLine(peer, key, value, valueBytes) + } + if err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return cfg, nil +} + +func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("private_key"): + // wireguard-go guarantees not to send zero value; private keys are already clamped. + var err error + cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) + if err != nil { + return err + } + case k.EqualString("listen_port") || k.EqualString("fwmark"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} + +func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { + p := Peer{} + var err error + p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) + if err != nil { + return nil, err + } + cfg.Peers = append(cfg.Peers, p) + return &cfg.Peers[len(cfg.Peers)-1], nil +} + +func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("endpoint"): + nk, err := key.ParseNodePublicUntyped(value) + if err != nil { + return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) + } + // nk ought to equal peer.PublicKey. + // Under some rare circumstances, it might not. See corp issue #3016. + // Even if that happens, don't stop early, so that we can recover from it. + // Instead, note the value of nk so we can fix as needed. + peer.WGEndpoint = nk + case k.EqualString("persistent_keepalive_interval"): + n, err := mem.ParseUint(value, 10, 16) + if err != nil { + return err + } + peer.PersistentKeepalive = uint16(n) + case k.EqualString("allowed_ip"): + ipp := netip.Prefix{} + err := ipp.UnmarshalText(valueBytes) + if err != nil { + return err + } + peer.AllowedIPs = append(peer.AllowedIPs, ipp) + case k.EqualString("protocol_version"): + if !value.EqualString("1") { + return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) + } + case k.EqualString("replace_allowed_ips") || + k.EqualString("preshared_key") || + k.EqualString("last_handshake_time_sec") || + k.EqualString("last_handshake_time_nsec") || + k.EqualString("tx_bytes") || + k.EqualString("rx_bytes"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} diff --git a/wgengine/winnet/winnet_windows.go b/wgengine/winnet/winnet_windows.go index 283ce5ad17b68..01e38517d2d64 100644 --- a/wgengine/winnet/winnet_windows.go +++ b/wgengine/winnet/winnet_windows.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package winnet - -import ( - "fmt" - "syscall" - "unsafe" - - "github.com/go-ole/go-ole" -) - -func (v *INetworkConnection) GetAdapterId() (string, error) { - buf := ole.GUID{} - hr, _, _ := syscall.Syscall( - v.VTable().GetAdapterId, - 2, - uintptr(unsafe.Pointer(v)), - uintptr(unsafe.Pointer(&buf)), - 0) - if hr != 0 { - return "", fmt.Errorf("GetAdapterId failed: %08x", hr) - } - return buf.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winnet + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/go-ole/go-ole" +) + +func (v *INetworkConnection) GetAdapterId() (string, error) { + buf := ole.GUID{} + hr, _, _ := syscall.Syscall( + v.VTable().GetAdapterId, + 2, + uintptr(unsafe.Pointer(v)), + uintptr(unsafe.Pointer(&buf)), + 0) + if hr != 0 { + return "", fmt.Errorf("GetAdapterId failed: %08x", hr) + } + return buf.String(), nil +} diff --git a/words/words.go b/words/words.go index b373ffef6541f..18efb75d77506 100644 --- a/words/words.go +++ b/words/words.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package words contains accessors for some nice words. -package words - -import ( - "bytes" - _ "embed" - "strings" - "sync" -) - -//go:embed tails.txt -var tailsTxt []byte - -//go:embed scales.txt -var scalesTxt []byte - -var ( - once sync.Once - tails, scales []string -) - -// Tails returns words about tails. -func Tails() []string { - once.Do(initWords) - return tails -} - -// Scales returns words about scales. -func Scales() []string { - once.Do(initWords) - return scales -} - -func initWords() { - tails = parseWords(tailsTxt) - scales = parseWords(scalesTxt) -} - -func parseWords(txt []byte) []string { - n := bytes.Count(txt, []byte{'\n'}) - ret := make([]string, 0, n) - for len(txt) > 0 { - word := txt - i := bytes.IndexByte(txt, '\n') - if i != -1 { - word, txt = word[:i], txt[i+1:] - } else { - txt = nil - } - if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { - ret = append(ret, word) - } - } - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package words contains accessors for some nice words. +package words + +import ( + "bytes" + _ "embed" + "strings" + "sync" +) + +//go:embed tails.txt +var tailsTxt []byte + +//go:embed scales.txt +var scalesTxt []byte + +var ( + once sync.Once + tails, scales []string +) + +// Tails returns words about tails. +func Tails() []string { + once.Do(initWords) + return tails +} + +// Scales returns words about scales. +func Scales() []string { + once.Do(initWords) + return scales +} + +func initWords() { + tails = parseWords(tailsTxt) + scales = parseWords(scalesTxt) +} + +func parseWords(txt []byte) []string { + n := bytes.Count(txt, []byte{'\n'}) + ret := make([]string, 0, n) + for len(txt) > 0 { + word := txt + i := bytes.IndexByte(txt, '\n') + if i != -1 { + word, txt = word[:i], txt[i+1:] + } else { + txt = nil + } + if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { + ret = append(ret, word) + } + } + return ret +} diff --git a/words/words_test.go b/words/words_test.go index a9691792a5c00..e96c234d7b84b 100644 --- a/words/words_test.go +++ b/words/words_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package words - -import ( - "strings" - "testing" -) - -func TestWords(t *testing.T) { - test := func(t *testing.T, words []string) { - t.Helper() - if len(words) == 0 { - t.Error("no words") - } - seen := map[string]bool{} - for _, w := range words { - if seen[w] { - t.Errorf("dup word %q", w) - } - seen[w] = true - if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { - t.Errorf("malformed word %q", w) - } - } - } - t.Run("tails", func(t *testing.T) { test(t, Tails()) }) - t.Run("scales", func(t *testing.T) { test(t, Scales()) }) - t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) -} - -func nonASCIILower(r rune) bool { - if 'a' <= r && r <= 'z' { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package words + +import ( + "strings" + "testing" +) + +func TestWords(t *testing.T) { + test := func(t *testing.T, words []string) { + t.Helper() + if len(words) == 0 { + t.Error("no words") + } + seen := map[string]bool{} + for _, w := range words { + if seen[w] { + t.Errorf("dup word %q", w) + } + seen[w] = true + if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { + t.Errorf("malformed word %q", w) + } + } + } + t.Run("tails", func(t *testing.T) { test(t, Tails()) }) + t.Run("scales", func(t *testing.T) { test(t, Scales()) }) + t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) +} + +func nonASCIILower(r rune) bool { + if 'a' <= r && r <= 'z' { + return false + } + return true +} From 2690b4762f9f6eded9857acfd70ab4a913aebcc1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Dec 2024 15:25:42 -0800 Subject: [PATCH 172/179] Revert "VERSION.txt: this is v1.78.0" This reverts commit 0267fe83b200f1702a2fa0a395442c02a053fadb. Reason: it converted the tree to Windows line endings. Updates #14299 Change-Id: I2271a61d43e99bd0bbcf9f4831e8783e570ba08a Signed-off-by: Brad Fitzpatrick --- .bencher/config.yaml | 2 +- .gitattributes | 4 +- .github/ISSUE_TEMPLATE/bug_report.yml | 162 +- .github/ISSUE_TEMPLATE/config.yml | 14 +- .github/ISSUE_TEMPLATE/feature_request.yml | 84 +- .github/dependabot.yml | 42 +- AUTHORS | 34 +- CODEOWNERS | 2 +- CODE_OF_CONDUCT.md | 270 ++-- LICENSE | 56 +- PATENTS | 48 +- SECURITY.md | 16 +- VERSION.txt | 2 +- atomicfile/atomicfile.go | 102 +- atomicfile/atomicfile_test.go | 94 +- chirp/chirp.go | 326 ++-- chirp/chirp_test.go | 384 ++--- client/tailscale/apitype/controltype.go | 38 +- client/tailscale/dns.go | 466 +++--- client/tailscale/example/servetls/servetls.go | 56 +- client/tailscale/keys.go | 332 ++-- client/tailscale/routes.go | 190 +-- client/tailscale/tailnet.go | 84 +- client/web/qnap.go | 254 +-- client/web/src/assets/icons/arrow-right.svg | 8 +- .../web/src/assets/icons/arrow-up-circle.svg | 10 +- client/web/src/assets/icons/check-circle.svg | 8 +- client/web/src/assets/icons/check.svg | 6 +- client/web/src/assets/icons/chevron-down.svg | 6 +- client/web/src/assets/icons/eye.svg | 22 +- client/web/src/assets/icons/search.svg | 8 +- .../web/src/assets/icons/tailscale-icon.svg | 36 +- .../web/src/assets/icons/tailscale-logo.svg | 40 +- client/web/src/assets/icons/user.svg | 8 +- client/web/src/assets/icons/x-circle.svg | 10 +- client/web/synology.go | 118 +- clientupdate/distsign/distsign.go | 972 +++++------ clientupdate/distsign/roots.go | 108 +- clientupdate/distsign/roots/crawshaw-root.pem | 6 +- .../roots/distsign-prod-root-1-pub.pem | 6 +- clientupdate/distsign/roots_test.go | 32 +- cmd/addlicense/main.go | 146 +- cmd/cloner/cloner_test.go | 120 +- cmd/containerboot/test_tailscale.sh | 16 +- cmd/containerboot/test_tailscaled.sh | 76 +- cmd/get-authkey/.gitignore | 2 +- cmd/gitops-pusher/.gitignore | 2 +- cmd/gitops-pusher/README.md | 96 +- cmd/gitops-pusher/cache.go | 132 +- cmd/gitops-pusher/gitops-pusher_test.go | 110 +- cmd/k8s-operator/deploy/chart/.helmignore | 46 +- cmd/k8s-operator/deploy/chart/Chart.yaml | 58 +- .../chart/templates/apiserverproxy-rbac.yaml | 52 +- .../deploy/chart/templates/oauth-secret.yaml | 26 +- .../deploy/manifests/authproxy-rbac.yaml | 46 +- cmd/mkmanifest/main.go | 102 +- cmd/mkpkg/main.go | 268 ++-- cmd/mkversion/mkversion.go | 88 +- cmd/nardump/README.md | 14 +- cmd/nardump/nardump.go | 368 ++--- cmd/nginx-auth/.gitignore | 8 +- cmd/nginx-auth/README.md | 322 ++-- cmd/nginx-auth/deb/postinst.sh | 28 +- cmd/nginx-auth/deb/postrm.sh | 38 +- cmd/nginx-auth/deb/prerm.sh | 16 +- cmd/nginx-auth/mkdeb.sh | 64 +- cmd/nginx-auth/nginx-auth.go | 256 +-- cmd/nginx-auth/rpm/postrm.sh | 18 +- cmd/nginx-auth/rpm/prerm.sh | 18 +- cmd/nginx-auth/tailscale.nginx-auth.service | 22 +- cmd/nginx-auth/tailscale.nginx-auth.socket | 16 +- cmd/pgproxy/README.md | 84 +- cmd/printdep/printdep.go | 82 +- cmd/sniproxy/.gitignore | 2 +- cmd/sniproxy/handlers_test.go | 318 ++-- cmd/sniproxy/server.go | 654 ++++---- cmd/sniproxy/server_test.go | 190 +-- cmd/sniproxy/sniproxy.go | 582 +++---- cmd/speedtest/speedtest.go | 242 +-- cmd/ssh-auth-none-demo/ssh-auth-none-demo.go | 374 ++--- cmd/sync-containers/main.go | 428 ++--- cmd/tailscale/cli/diag.go | 148 +- cmd/tailscale/cli/diag_other.go | 30 +- cmd/tailscale/cli/set_test.go | 262 +-- cmd/tailscale/cli/ssh_exec.go | 48 +- cmd/tailscale/cli/ssh_exec_js.go | 32 +- cmd/tailscale/cli/ssh_exec_windows.go | 74 +- cmd/tailscale/cli/ssh_unix.go | 98 +- cmd/tailscale/cli/web_test.go | 90 +- cmd/tailscale/generate.go | 16 +- cmd/tailscale/tailscale.go | 52 +- cmd/tailscale/windows-manifest.xml | 26 +- cmd/tailscaled/childproc/childproc.go | 38 +- cmd/tailscaled/generate.go | 16 +- cmd/tailscaled/install_darwin.go | 398 ++--- cmd/tailscaled/install_windows.go | 248 +-- cmd/tailscaled/proxy.go | 160 +- cmd/tailscaled/sigpipe.go | 24 +- cmd/tailscaled/tailscaled.defaults | 16 +- cmd/tailscaled/tailscaled.openrc | 50 +- cmd/tailscaled/tailscaled_bird.go | 34 +- cmd/tailscaled/tailscaled_notwindows.go | 28 +- cmd/tailscaled/windows-manifest.xml | 26 +- cmd/tailscaled/with_cli.go | 46 +- cmd/testwrapper/args_test.go | 194 +-- cmd/testwrapper/flakytest/flakytest.go | 88 +- cmd/testwrapper/flakytest/flakytest_test.go | 86 +- cmd/tsconnect/.gitignore | 6 +- cmd/tsconnect/README.md | 98 +- cmd/tsconnect/README.pkg.md | 6 +- cmd/tsconnect/build-pkg.go | 198 +-- cmd/tsconnect/dev-pkg.go | 36 +- cmd/tsconnect/dev.go | 36 +- cmd/tsconnect/dist/placeholder | 4 +- cmd/tsconnect/index.html | 40 +- cmd/tsconnect/package.json | 50 +- cmd/tsconnect/package.json.tmpl | 32 +- cmd/tsconnect/serve.go | 288 ++-- cmd/tsconnect/src/app/app.tsx | 294 ++-- cmd/tsconnect/src/app/go-panic-display.tsx | 40 +- cmd/tsconnect/src/app/header.tsx | 74 +- cmd/tsconnect/src/app/index.css | 148 +- cmd/tsconnect/src/app/index.ts | 72 +- cmd/tsconnect/src/app/ssh.tsx | 314 ++-- cmd/tsconnect/src/app/url-display.tsx | 62 +- cmd/tsconnect/src/lib/js-state-store.ts | 26 +- cmd/tsconnect/src/pkg/pkg.css | 16 +- cmd/tsconnect/src/pkg/pkg.ts | 80 +- cmd/tsconnect/src/types/esbuild.d.ts | 28 +- cmd/tsconnect/src/types/wasm_js.d.ts | 206 +-- cmd/tsconnect/tailwind.config.js | 16 +- cmd/tsconnect/tsconfig.json | 30 +- cmd/tsconnect/tsconnect.go | 142 +- cmd/tsconnect/yarn.lock | 1426 ++++++++--------- cmd/tsshd/tsshd.go | 24 +- control/controlbase/conn.go | 816 +++++----- control/controlbase/handshake.go | 988 ++++++------ control/controlbase/interop_test.go | 512 +++--- control/controlbase/messages.go | 174 +- control/controlclient/sign.go | 84 +- control/controlclient/sign_supported_test.go | 472 +++--- control/controlclient/sign_unsupported.go | 32 +- control/controlclient/status.go | 250 +-- control/controlhttp/client_common.go | 34 +- derp/README.md | 120 +- derp/testdata/example_ss.txt | 16 +- disco/disco_fuzzer.go | 34 +- disco/disco_test.go | 236 +-- disco/pcap.go | 80 +- docs/bird/sample_bird.conf | 32 +- docs/bird/tailscale_bird.conf | 8 +- docs/k8s/Makefile | 50 +- docs/k8s/rolebinding.yaml | 26 +- docs/k8s/sa.yaml | 12 +- docs/sysv/tailscale.init | 126 +- doctor/doctor.go | 158 +- doctor/doctor_test.go | 98 +- doctor/permissions/permissions_bsd.go | 46 +- doctor/permissions/permissions_linux.go | 124 +- doctor/permissions/permissions_other.go | 34 +- doctor/permissions/permissions_test.go | 24 +- doctor/routetable/routetable.go | 68 +- envknob/envknob_nottest.go | 32 +- envknob/envknob_testable.go | 46 +- envknob/logknob/logknob.go | 170 +- envknob/logknob/logknob_test.go | 204 +-- gomod_test.go | 50 +- hostinfo/hostinfo_darwin.go | 42 +- hostinfo/hostinfo_freebsd.go | 128 +- hostinfo/hostinfo_test.go | 102 +- hostinfo/hostinfo_uname.go | 76 +- hostinfo/wol.go | 212 +-- ipn/ipnlocal/breaktcp_darwin.go | 60 +- ipn/ipnlocal/breaktcp_linux.go | 60 +- ipn/ipnlocal/expiry_test.go | 602 +++---- ipn/ipnlocal/peerapi_h2c.go | 40 +- ipn/ipnlocal/testdata/example.com-key.pem | 54 +- ipn/ipnlocal/testdata/example.com.pem | 50 +- ipn/ipnlocal/testdata/rootCA.pem | 58 +- ipn/ipnserver/proxyconnect_js.go | 20 +- ipn/ipnserver/server_test.go | 92 +- ipn/localapi/disabled_stubs.go | 30 +- ipn/localapi/pprof.go | 56 +- ipn/policy/policy.go | 94 +- ipn/store/awsstore/store_aws.go | 372 ++--- ipn/store/awsstore/store_aws_stub.go | 36 +- ipn/store/awsstore/store_aws_test.go | 328 ++-- ipn/store/stores_test.go | 358 ++--- ipn/store_test.go | 96 +- jsondb/db.go | 114 +- jsondb/db_test.go | 110 +- licenses/licenses.go | 42 +- log/filelogger/log.go | 456 +++--- log/filelogger/log_test.go | 54 +- logpolicy/logpolicy_test.go | 72 +- logtail/.gitignore | 12 +- logtail/README.md | 18 +- logtail/api.md | 388 ++--- logtail/example/logreprocess/demo.sh | 172 +- logtail/example/logreprocess/logreprocess.go | 230 +-- logtail/example/logtail/logtail.go | 92 +- logtail/filch/filch.go | 568 +++---- logtail/filch/filch_stub.go | 46 +- logtail/filch/filch_unix.go | 60 +- logtail/filch/filch_windows.go | 86 +- metrics/fds_linux.go | 82 +- metrics/fds_notlinux.go | 16 +- metrics/metrics.go | 326 ++-- net/art/art_test.go | 40 +- net/art/table.go | 1282 +++++++-------- net/dns/debian_resolvconf.go | 368 ++--- net/dns/direct_notlinux.go | 20 +- net/dns/flush_default.go | 20 +- net/dns/ini.go | 60 +- net/dns/ini_test.go | 76 +- net/dns/noop.go | 34 +- net/dns/resolvconf-workaround.sh | 124 +- net/dns/resolvconf.go | 60 +- net/dns/resolvconffile/resolvconffile.go | 248 +-- net/dns/resolvconfpath_default.go | 22 +- net/dns/resolvconfpath_gokrazy.go | 22 +- net/dns/resolver/doh_test.go | 198 +-- net/dns/resolver/macios_ext.go | 52 +- net/dns/resolver/tsdns_server_test.go | 666 ++++---- net/dns/utf.go | 110 +- net/dns/utf_test.go | 48 +- net/dnscache/dnscache_test.go | 484 +++--- net/dnscache/messagecache_test.go | 582 +++---- net/dnsfallback/update-dns-fallbacks.go | 90 +- net/memnet/conn.go | 228 +-- net/memnet/conn_test.go | 42 +- net/memnet/listener.go | 200 +-- net/memnet/listener_test.go | 66 +- net/memnet/memnet.go | 16 +- net/memnet/pipe.go | 488 +++--- net/memnet/pipe_test.go | 234 +-- net/netaddr/netaddr.go | 98 +- net/neterror/neterror.go | 164 +- net/neterror/neterror_linux.go | 52 +- net/neterror/neterror_linux_test.go | 108 +- net/neterror/neterror_windows.go | 32 +- net/netkernelconf/netkernelconf.go | 10 +- net/netknob/netknob.go | 58 +- net/netmon/netmon_darwin_test.go | 54 +- net/netmon/netmon_freebsd.go | 112 +- net/netmon/netmon_linux.go | 580 +++---- net/netmon/netmon_polling.go | 42 +- net/netmon/polling.go | 172 +- net/netns/netns_android.go | 150 +- net/netns/netns_default.go | 44 +- net/netns/netns_linux_test.go | 28 +- net/netns/netns_test.go | 156 +- net/netns/socks.go | 38 +- net/netstat/netstat.go | 70 +- net/netstat/netstat_noimpl.go | 28 +- net/netstat/netstat_test.go | 42 +- net/packet/doc.go | 30 +- net/packet/header.go | 132 +- net/packet/icmp.go | 56 +- net/packet/icmp6_test.go | 158 +- net/packet/ip4.go | 232 +-- net/packet/ip6.go | 152 +- net/packet/tsmp_test.go | 146 +- net/packet/udp4.go | 116 +- net/packet/udp6.go | 108 +- net/ping/ping.go | 686 ++++---- net/ping/ping_test.go | 700 ++++---- net/portmapper/pcp_test.go | 124 +- net/proxymux/mux.go | 288 ++-- net/routetable/routetable_darwin.go | 72 +- net/routetable/routetable_freebsd.go | 56 +- net/routetable/routetable_other.go | 34 +- net/sockstats/sockstats.go | 242 +-- net/sockstats/sockstats_noop.go | 76 +- net/sockstats/sockstats_tsgo_darwin.go | 60 +- net/speedtest/speedtest.go | 174 +- net/speedtest/speedtest_client.go | 82 +- net/speedtest/speedtest_server.go | 292 ++-- net/speedtest/speedtest_test.go | 166 +- net/stun/stun.go | 624 ++++---- net/stun/stun_fuzzer.go | 24 +- net/tcpinfo/tcpinfo.go | 102 +- net/tcpinfo/tcpinfo_darwin.go | 66 +- net/tcpinfo/tcpinfo_linux.go | 66 +- net/tcpinfo/tcpinfo_other.go | 30 +- net/tlsdial/deps_test.go | 16 +- net/tsdial/dnsmap_test.go | 250 +-- net/tsdial/dohclient.go | 200 +-- net/tsdial/dohclient_test.go | 62 +- net/tshttpproxy/mksyscall.go | 22 +- net/tshttpproxy/tshttpproxy_linux.go | 48 +- net/tshttpproxy/tshttpproxy_synology_test.go | 752 ++++----- net/tshttpproxy/tshttpproxy_windows.go | 552 +++---- net/tstun/fake.go | 116 +- net/tstun/ifstatus_noop.go | 36 +- net/tstun/ifstatus_windows.go | 218 +-- net/tstun/linkattrs_linux.go | 126 +- net/tstun/linkattrs_notlinux.go | 24 +- net/tstun/mtu.go | 322 ++-- net/tstun/mtu_test.go | 198 +-- net/tstun/tun_linux.go | 206 +-- net/tstun/tun_macos.go | 50 +- net/tstun/tun_notwindows.go | 24 +- packages/deb/deb.go | 364 ++--- packages/deb/deb_test.go | 410 ++--- paths/migrate.go | 116 +- paths/paths.go | 184 +-- paths/paths_windows.go | 200 +-- portlist/clean.go | 72 +- portlist/clean_test.go | 114 +- portlist/netstat_test.go | 184 +-- portlist/poller.go | 244 +-- portlist/portlist.go | 160 +- portlist/portlist_macos.go | 460 +++--- portlist/portlist_windows.go | 206 +-- posture/serialnumber_macos.go | 148 +- posture/serialnumber_notmacos_test.go | 76 +- posture/serialnumber_test.go | 32 +- pull-toolchain.sh | 32 +- release/deb/debian.postrm.sh | 34 +- release/deb/debian.prerm.sh | 14 +- release/dist/memoize.go | 172 +- release/dist/synology/files/Tailscale.sc | 10 +- release/dist/synology/files/config | 22 +- release/dist/synology/files/index.cgi | 4 +- release/dist/synology/files/logrotate-dsm6 | 16 +- release/dist/synology/files/logrotate-dsm7 | 16 +- release/dist/synology/files/privilege-dsm6 | 14 +- release/dist/synology/files/privilege-dsm7 | 14 +- .../files/privilege-dsm7.for-package-center | 26 +- release/dist/synology/files/resource | 20 +- .../dist/synology/files/scripts/postupgrade | 4 +- .../dist/synology/files/scripts/preupgrade | 4 +- .../synology/files/scripts/start-stop-status | 258 +-- release/dist/unixpkgs/pkgs.go | 944 +++++------ release/dist/unixpkgs/targets.go | 254 +-- release/release.go | 30 +- release/rpm/rpm.postinst.sh | 82 +- release/rpm/rpm.postrm.sh | 16 +- release/rpm/rpm.prerm.sh | 16 +- safesocket/safesocket_test.go | 24 +- smallzstd/testdata | 28 +- smallzstd/zstd.go | 156 +- syncs/locked.go | 64 +- syncs/locked_test.go | 240 +-- syncs/shardedmap.go | 276 ++-- syncs/shardedmap_test.go | 162 +- tailcfg/proto_port_range.go | 374 ++--- tailcfg/proto_port_range_test.go | 262 +-- tailcfg/tka.go | 528 +++--- taildrop/delete.go | 410 ++--- taildrop/delete_test.go | 304 ++-- taildrop/resume_test.go | 148 +- taildrop/retrieve.go | 356 ++-- tempfork/gliderlabs/ssh/LICENSE | 54 +- tempfork/gliderlabs/ssh/README.md | 192 +-- tempfork/gliderlabs/ssh/agent.go | 166 +- tempfork/gliderlabs/ssh/conn.go | 110 +- tempfork/gliderlabs/ssh/context.go | 328 ++-- tempfork/gliderlabs/ssh/context_test.go | 98 +- tempfork/gliderlabs/ssh/doc.go | 90 +- tempfork/gliderlabs/ssh/example_test.go | 100 +- tempfork/gliderlabs/ssh/options.go | 168 +- tempfork/gliderlabs/ssh/options_test.go | 222 +-- tempfork/gliderlabs/ssh/server.go | 918 +++++------ tempfork/gliderlabs/ssh/server_test.go | 256 +-- tempfork/gliderlabs/ssh/session.go | 772 ++++----- tempfork/gliderlabs/ssh/session_test.go | 880 +++++----- tempfork/gliderlabs/ssh/ssh.go | 312 ++-- tempfork/gliderlabs/ssh/ssh_test.go | 34 +- tempfork/gliderlabs/ssh/tcpip.go | 386 ++--- tempfork/gliderlabs/ssh/tcpip_test.go | 170 +- tempfork/gliderlabs/ssh/util.go | 314 ++-- tempfork/gliderlabs/ssh/wrap.go | 66 +- tempfork/heap/heap.go | 242 +-- tka/aum_test.go | 506 +++--- tka/builder.go | 360 ++--- tka/builder_test.go | 540 +++---- tka/deeplink.go | 442 ++--- tka/deeplink_test.go | 104 +- tka/key.go | 318 ++-- tka/key_test.go | 194 +-- tka/state.go | 630 ++++---- tka/state_test.go | 520 +++--- tka/sync_test.go | 754 ++++----- tka/tailchonk_test.go | 1386 ++++++++-------- tka/tka_test.go | 1308 +++++++-------- tool/binaryen.rev | 2 +- tool/go | 14 +- tool/gocross/env.go | 262 +-- tool/gocross/env_test.go | 198 +-- tool/gocross/exec_other.go | 40 +- tool/gocross/exec_unix.go | 24 +- tool/helm | 138 +- tool/helm.rev | 2 +- tool/node | 130 +- tool/wasm-opt | 148 +- tool/yarn | 86 +- tool/yarn.rev | 2 +- tsnet/example/tshello/tshello.go | 120 +- .../tsnet-http-client/tsnet-http-client.go | 88 +- tsnet/example/web-client/web-client.go | 92 +- tsnet/example_tshello_test.go | 144 +- tstest/allocs.go | 100 +- tstest/archtest/qemu_test.go | 146 +- tstest/clock.go | 1388 ++++++++-------- tstest/deptest/deptest_test.go | 20 +- tstest/integration/gen_deps.go | 130 +- tstest/integration/vms/README.md | 190 +-- tstest/integration/vms/distros.hujson | 78 +- tstest/integration/vms/distros_test.go | 28 +- tstest/integration/vms/dns_tester.go | 108 +- tstest/integration/vms/doc.go | 12 +- tstest/integration/vms/harness_test.go | 484 +++--- tstest/integration/vms/nixos_test.go | 462 +++--- tstest/integration/vms/regex_flag.go | 58 +- tstest/integration/vms/regex_flag_test.go | 42 +- tstest/integration/vms/runner.nix | 178 +- tstest/integration/vms/squid.conf | 76 +- tstest/integration/vms/top_level_test.go | 248 +-- tstest/integration/vms/udp_tester.go | 154 +- tstest/log_test.go | 94 +- tstest/natlab/firewall.go | 312 ++-- tstest/natlab/nat.go | 504 +++--- tstest/tstest.go | 190 +-- tstest/tstest_test.go | 48 +- tstime/mono/mono.go | 254 +-- tstime/rate/rate.go | 180 +-- tstime/tstime.go | 370 ++--- tstime/tstime_test.go | 72 +- tsweb/debug_test.go | 416 ++--- tsweb/promvarz/promvarz_test.go | 76 +- types/appctype/appconnector_test.go | 156 +- types/dnstype/dnstype.go | 136 +- types/empty/message.go | 26 +- types/flagtype/flagtype.go | 90 +- types/ipproto/ipproto.go | 398 ++--- types/key/chal.go | 182 +-- types/key/control.go | 136 +- types/key/control_test.go | 76 +- types/key/disco_test.go | 166 +- types/key/machine.go | 528 +++--- types/key/machine_test.go | 238 +-- types/key/nl_test.go | 96 +- types/lazy/unsync.go | 198 +-- types/lazy/unsync_test.go | 280 ++-- types/logger/rusage.go | 46 +- types/logger/rusage_stub.go | 22 +- types/logger/rusage_syscall.go | 58 +- types/logger/tokenbucket.go | 126 +- types/netlogtype/netlogtype.go | 200 +-- types/netlogtype/netlogtype_test.go | 78 +- types/netmap/netmap_test.go | 636 ++++---- types/nettype/nettype.go | 130 +- types/preftype/netfiltermode.go | 92 +- types/ptr/ptr.go | 20 +- types/structs/structs.go | 30 +- types/tkatype/tkatype.go | 80 +- types/tkatype/tkatype_test.go | 86 +- util/cibuild/cibuild.go | 28 +- util/cstruct/cstruct.go | 356 ++-- util/cstruct/cstruct_example_test.go | 146 +- util/deephash/debug.go | 74 +- util/deephash/pointer.go | 228 +-- util/deephash/pointer_norace.go | 26 +- util/deephash/pointer_race.go | 198 +-- util/deephash/testtype/testtype.go | 30 +- util/dirwalk/dirwalk.go | 106 +- util/dirwalk/dirwalk_linux.go | 334 ++-- util/dirwalk/dirwalk_test.go | 182 +-- util/goroutines/goroutines.go | 186 +-- util/goroutines/goroutines_test.go | 58 +- util/groupmember/groupmember.go | 58 +- util/hashx/block512.go | 394 ++--- util/httphdr/httphdr.go | 394 ++--- util/httphdr/httphdr_test.go | 192 +-- util/httpm/httpm.go | 72 +- util/httpm/httpm_test.go | 74 +- util/jsonutil/types.go | 32 +- util/jsonutil/unmarshal.go | 178 +- util/lineread/lineread.go | 74 +- util/linuxfw/linuxfwtest/linuxfwtest.go | 62 +- .../linuxfwtest/linuxfwtest_unsupported.go | 36 +- util/linuxfw/nftables_types.go | 190 +-- util/mak/mak.go | 140 +- util/mak/mak_test.go | 176 +- util/multierr/multierr.go | 272 ++-- util/must/must.go | 50 +- util/osdiag/mksyscall.go | 26 +- util/osdiag/osdiag_windows_test.go | 256 +-- util/osshare/filesharingstatus_noop.go | 24 +- util/pidowner/pidowner.go | 48 +- util/pidowner/pidowner_noimpl.go | 16 +- util/pidowner/pidowner_windows.go | 70 +- util/precompress/precompress.go | 258 +-- util/quarantine/quarantine.go | 28 +- util/quarantine/quarantine_darwin.go | 112 +- util/quarantine/quarantine_default.go | 28 +- util/quarantine/quarantine_windows.go | 58 +- util/race/race_test.go | 198 +-- util/racebuild/off.go | 16 +- util/racebuild/on.go | 16 +- util/racebuild/racebuild.go | 12 +- util/rands/rands.go | 50 +- util/rands/rands_test.go | 30 +- util/set/handle.go | 56 +- util/set/slice_test.go | 112 +- util/sysresources/memory.go | 20 +- util/sysresources/memory_bsd.go | 32 +- util/sysresources/memory_darwin.go | 32 +- util/sysresources/memory_linux.go | 38 +- util/sysresources/memory_unsupported.go | 16 +- util/sysresources/sysresources.go | 12 +- util/sysresources/sysresources_test.go | 50 +- util/systemd/doc.go | 26 +- util/systemd/systemd_linux.go | 154 +- util/systemd/systemd_nonlinux.go | 18 +- util/testenv/testenv.go | 42 +- util/truncate/truncate_test.go | 72 +- util/uniq/slice.go | 124 +- util/winutil/authenticode/mksyscall.go | 36 +- util/winutil/policy/policy_windows.go | 310 ++-- util/winutil/policy/policy_windows_test.go | 76 +- version/.gitignore | 20 +- version/cmdname.go | 278 ++-- version/cmdname_ios.go | 36 +- version/cmp_test.go | 164 +- version/export_test.go | 28 +- version/print.go | 66 +- version/race.go | 20 +- version/race_off.go | 20 +- version/version_test.go | 102 +- wgengine/bench/bench.go | 818 +++++----- wgengine/bench/bench_test.go | 216 +-- wgengine/bench/trafficgen.go | 518 +++--- wgengine/capture/capture.go | 476 +++--- wgengine/magicsock/blockforever_conn.go | 110 +- wgengine/magicsock/endpoint_default.go | 44 +- wgengine/magicsock/endpoint_stub.go | 26 +- wgengine/magicsock/endpoint_tracker.go | 496 +++--- wgengine/magicsock/magicsock_unix_test.go | 120 +- wgengine/magicsock/peermtu_darwin.go | 102 +- wgengine/magicsock/peermtu_linux.go | 98 +- wgengine/magicsock/peermtu_unix.go | 84 +- wgengine/mem_ios.go | 40 +- wgengine/netstack/netstack_linux.go | 38 +- wgengine/router/runner.go | 240 +-- wgengine/watchdog_js.go | 34 +- wgengine/wgcfg/device.go | 136 +- wgengine/wgcfg/device_test.go | 522 +++--- wgengine/wgcfg/parser.go | 372 ++--- wgengine/winnet/winnet_windows.go | 52 +- words/words.go | 116 +- words/words_test.go | 76 +- 554 files changed, 44582 insertions(+), 44582 deletions(-) diff --git a/.bencher/config.yaml b/.bencher/config.yaml index b60c5c352d48a..220bd9d3b7dc0 100644 --- a/.bencher/config.yaml +++ b/.bencher/config.yaml @@ -1 +1 @@ -suppress_failure_on_regression: true +suppress_failure_on_regression: true diff --git a/.gitattributes b/.gitattributes index 38a6b06a3147f..3eb52878271f3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ -go.mod filter=go-mod -*.go diff=golang +go.mod filter=go-mod +*.go diff=golang diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 688de14440a46..9163171c90248 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,81 +1,81 @@ -name: Bug report -description: File a bug report. If you need help, contact support instead -labels: [needs-triage, bug] -body: - - type: markdown - attributes: - value: | - Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. - Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. - - type: textarea - id: what-happened - attributes: - label: What is the issue? - description: What happened? What did you expect to happen? - validations: - required: true - - type: textarea - id: steps - attributes: - label: Steps to reproduce - description: What are the steps you took that hit this issue? - validations: - required: false - - type: textarea - id: changes - attributes: - label: Are there any recent changes that introduced the issue? - description: If so, what are those changes? - validations: - required: false - - type: dropdown - id: os - attributes: - label: OS - description: What OS are you using? You may select more than one. - multiple: true - options: - - Linux - - macOS - - Windows - - iOS - - Android - - Synology - - Other - validations: - required: false - - type: input - id: os-version - attributes: - label: OS version - description: What OS version are you using? - placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 - validations: - required: false - - type: input - id: ts-version - attributes: - label: Tailscale version - description: What Tailscale version are you using? - placeholder: e.g., 1.14.4 - validations: - required: false - - type: textarea - id: other-software - attributes: - label: Other software - description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? - validations: - required: false - - type: input - id: bug-report - attributes: - label: Bug report - description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. - placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a bug report! +name: Bug report +description: File a bug report. If you need help, contact support instead +labels: [needs-triage, bug] +body: + - type: markdown + attributes: + value: | + Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. + Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. + - type: textarea + id: what-happened + attributes: + label: What is the issue? + description: What happened? What did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce + description: What are the steps you took that hit this issue? + validations: + required: false + - type: textarea + id: changes + attributes: + label: Are there any recent changes that introduced the issue? + description: If so, what are those changes? + validations: + required: false + - type: dropdown + id: os + attributes: + label: OS + description: What OS are you using? You may select more than one. + multiple: true + options: + - Linux + - macOS + - Windows + - iOS + - Android + - Synology + - Other + validations: + required: false + - type: input + id: os-version + attributes: + label: OS version + description: What OS version are you using? + placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 + validations: + required: false + - type: input + id: ts-version + attributes: + label: Tailscale version + description: What Tailscale version are you using? + placeholder: e.g., 1.14.4 + validations: + required: false + - type: textarea + id: other-software + attributes: + label: Other software + description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? + validations: + required: false + - type: input + id: bug-report + attributes: + label: Bug report + description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. + placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a bug report! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index e3c44b6a1ab0a..3f4a31534b7d7 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,8 +1,8 @@ -blank_issues_enabled: true -contact_links: - - name: Support - url: https://tailscale.com/contact/support/ - about: Contact us for support - - name: Troubleshooting - url: https://tailscale.com/kb/1023/troubleshooting +blank_issues_enabled: true +contact_links: + - name: Support + url: https://tailscale.com/contact/support/ + about: Contact us for support + - name: Troubleshooting + url: https://tailscale.com/kb/1023/troubleshooting about: See the troubleshooting guide for help addressing common issues \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 02ecae13c5acd..f7538627483ab 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,42 +1,42 @@ -name: Feature request -description: Propose a new feature -title: "FR: " -labels: [needs-triage, fr] -body: - - type: markdown - attributes: - value: | - Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). - Tell us about your idea! - - type: textarea - id: problem - attributes: - label: What are you trying to do? - description: Tell us about the problem you're trying to solve. - validations: - required: false - - type: textarea - id: solution - attributes: - label: How should we solve this? - description: If you have an idea of how you'd like to see this feature work, let us know. - validations: - required: false - - type: textarea - id: alternative - attributes: - label: What is the impact of not solving this? - description: (How) Are you currently working around the issue? - validations: - required: false - - type: textarea - id: context - attributes: - label: Anything else? - description: Any additional context to share, e.g., links - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a feature request! +name: Feature request +description: Propose a new feature +title: "FR: " +labels: [needs-triage, fr] +body: + - type: markdown + attributes: + value: | + Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). + Tell us about your idea! + - type: textarea + id: problem + attributes: + label: What are you trying to do? + description: Tell us about the problem you're trying to solve. + validations: + required: false + - type: textarea + id: solution + attributes: + label: How should we solve this? + description: If you have an idea of how you'd like to see this feature work, let us know. + validations: + required: false + - type: textarea + id: alternative + attributes: + label: What is the impact of not solving this? + description: (How) Are you currently working around the issue? + validations: + required: false + - type: textarea + id: context + attributes: + label: Anything else? + description: Any additional context to share, e.g., links + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a feature request! diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 225132e5485c0..14c912905363e 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,21 +1,21 @@ -# Documentation for this file can be found at: -# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates -version: 2 -updates: - ## Disabled between releases. We reenable it briefly after every - ## stable release, pull in all changes, and close it again so that - ## the tree remains more stable during development and the upstream - ## changes have time to soak before the next release. - # - package-ecosystem: "gomod" - # directory: "/" - # schedule: - # interval: "daily" - # commit-message: - # prefix: "go.mod:" - # open-pull-requests-limit: 100 - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: "weekly" - commit-message: - prefix: ".github:" +# Documentation for this file can be found at: +# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates +version: 2 +updates: + ## Disabled between releases. We reenable it briefly after every + ## stable release, pull in all changes, and close it again so that + ## the tree remains more stable during development and the upstream + ## changes have time to soak before the next release. + # - package-ecosystem: "gomod" + # directory: "/" + # schedule: + # interval: "daily" + # commit-message: + # prefix: "go.mod:" + # open-pull-requests-limit: 100 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + commit-message: + prefix: ".github:" diff --git a/AUTHORS b/AUTHORS index 3fafc44923b2c..03d5932c04746 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,17 +1,17 @@ -# This is the official list of Tailscale -# authors for copyright purposes. -# -# Names should be added to this file as one of -# Organization's name -# Individual's name -# Individual's name -# -# Please keep the list sorted. -# -# You do not need to add entries to this list, and we don't actively -# populate this list. If you do want to be acknowledged explicitly as -# a copyright holder, though, then please send a PR referencing your -# earlier contributions and clarifying whether it's you or your -# company that owns the rights to your contribution. - -Tailscale Inc. +# This is the official list of Tailscale +# authors for copyright purposes. +# +# Names should be added to this file as one of +# Organization's name +# Individual's name +# Individual's name +# +# Please keep the list sorted. +# +# You do not need to add entries to this list, and we don't actively +# populate this list. If you do want to be acknowledged explicitly as +# a copyright holder, though, then please send a PR referencing your +# earlier contributions and clarifying whether it's you or your +# company that owns the rights to your contribution. + +Tailscale Inc. diff --git a/CODEOWNERS b/CODEOWNERS index 76edf10061958..af9b0d9f95928 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -/tailcfg/ @tailscale/control-protocol-owners +/tailcfg/ @tailscale/control-protocol-owners diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cf4e6ddbe4c31..be5564ef4a3de 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,135 +1,135 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation -in our community a harassment-free experience for everyone, regardless -of age, body size, visible or invisible disability, ethnicity, sex -characteristics, gender identity and expression, level of experience, -education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -We pledge to act and interact in ways that contribute to an open, -welcoming, diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for -our community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our - mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or - political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in - a professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our -standards of acceptable behavior and will take appropriate and fair -corrective action in response to any behavior that they deem -inappropriate, threatening, offensive, or harmful. - -Community leaders have the right and responsibility to remove, edit, -or reject comments, commits, code, wiki edits, issues, and other -contributions that are not aligned to this Code of Conduct, and will -communicate reasons for moderation decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also -applies when an individual is officially representing the community in -public spaces. Examples of representing our community include using an -official e-mail address, posting via an official social media account, -or acting as an appointed representative at an online or offline -event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported to the community leaders responsible for enforcement -at [info@tailscale.com](mailto:info@tailscale.com). All complaints -will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and -security of the reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in -determining the consequences for any action they deem in violation of -this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior -deemed unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, -providing clarity around the nature of the violation and an -explanation of why the behavior was inappropriate. A public apology -may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued -behavior. No interaction with the people involved, including -unsolicited interaction with those enforcing the Code of Conduct, for -a specified period of time. This includes avoiding interactions in -community spaces as well as external channels like social -media. Violating these terms may lead to a temporary or permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, -including sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or -public communication with the community for a specified period of -time. No public or private interaction with the people involved, -including unsolicited interaction with those enforcing the Code of -Conduct, is allowed during this period. Violating these terms may lead -to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of -community standards, including sustained inappropriate behavior, -harassment of an individual, or aggression toward or disparagement of -classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction -within the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor -Covenant][homepage], version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of -conduct enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the -FAQ at https://www.contributor-covenant.org/faq. Translations are -available at https://www.contributor-covenant.org/translations. - +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation +in our community a harassment-free experience for everyone, regardless +of age, body size, visible or invisible disability, ethnicity, sex +characteristics, gender identity and expression, level of experience, +education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, +welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for +our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our + mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or + political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in + a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our +standards of acceptable behavior and will take appropriate and fair +corrective action in response to any behavior that they deem +inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, +or reject comments, commits, code, wiki edits, issues, and other +contributions that are not aligned to this Code of Conduct, and will +communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also +applies when an individual is officially representing the community in +public spaces. Examples of representing our community include using an +official e-mail address, posting via an official social media account, +or acting as an appointed representative at an online or offline +event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior +may be reported to the community leaders responsible for enforcement +at [info@tailscale.com](mailto:info@tailscale.com). All complaints +will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and +security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in +determining the consequences for any action they deem in violation of +this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior +deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, +providing clarity around the nature of the violation and an +explanation of why the behavior was inappropriate. A public apology +may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued +behavior. No interaction with the people involved, including +unsolicited interaction with those enforcing the Code of Conduct, for +a specified period of time. This includes avoiding interactions in +community spaces as well as external channels like social +media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, +including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or +public communication with the community for a specified period of +time. No public or private interaction with the people involved, +including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead +to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of +community standards, including sustained inappropriate behavior, +harassment of an individual, or aggression toward or disparagement of +classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction +within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor +Covenant][homepage], version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of +conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the +FAQ at https://www.contributor-covenant.org/faq. Translations are +available at https://www.contributor-covenant.org/translations. + diff --git a/LICENSE b/LICENSE index 3d511c30c1ff5..394db19e4aa5c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,28 +1,28 @@ -BSD 3-Clause License - -Copyright (c) 2020 Tailscale Inc & AUTHORS. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +BSD 3-Clause License + +Copyright (c) 2020 Tailscale Inc & AUTHORS. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/PATENTS b/PATENTS index b001fb9c1b0b2..560a2b8f0e401 100644 --- a/PATENTS +++ b/PATENTS @@ -1,24 +1,24 @@ -Additional IP Rights Grant (Patents) - -"This implementation" means the copyrightable works distributed by -Tailscale Inc. as part of the Tailscale project. - -Tailscale Inc. hereby grants to You a perpetual, worldwide, -non-exclusive, no-charge, royalty-free, irrevocable (except as stated -in this section) patent license to make, have made, use, offer to -sell, sell, import, transfer and otherwise run, modify and propagate -the contents of this implementation of Tailscale, where such license -applies only to those patent claims, both currently owned or -controlled by Tailscale Inc. and acquired in the future, licensable -by Tailscale Inc. that are necessarily infringed by this -implementation of Tailscale. This grant does not include claims that -would be infringed only as a consequence of further modification of -this implementation. If you or your agent or exclusive licensee -institute or order or agree to the institution of patent litigation -against any entity (including a cross-claim or counterclaim in a -lawsuit) alleging that this implementation of Tailscale or any code -incorporated within this implementation of Tailscale constitutes -direct or contributory patent infringement, or inducement of patent -infringement, then any patent rights granted to you under this License -for this implementation of Tailscale shall terminate as of the date -such litigation is filed. +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Tailscale Inc. as part of the Tailscale project. + +Tailscale Inc. hereby grants to You a perpetual, worldwide, +non-exclusive, no-charge, royalty-free, irrevocable (except as stated +in this section) patent license to make, have made, use, offer to +sell, sell, import, transfer and otherwise run, modify and propagate +the contents of this implementation of Tailscale, where such license +applies only to those patent claims, both currently owned or +controlled by Tailscale Inc. and acquired in the future, licensable +by Tailscale Inc. that are necessarily infringed by this +implementation of Tailscale. This grant does not include claims that +would be infringed only as a consequence of further modification of +this implementation. If you or your agent or exclusive licensee +institute or order or agree to the institution of patent litigation +against any entity (including a cross-claim or counterclaim in a +lawsuit) alleging that this implementation of Tailscale or any code +incorporated within this implementation of Tailscale constitutes +direct or contributory patent infringement, or inducement of patent +infringement, then any patent rights granted to you under this License +for this implementation of Tailscale shall terminate as of the date +such litigation is filed. diff --git a/SECURITY.md b/SECURITY.md index e8cd9a326c787..26702b14143c3 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,8 +1,8 @@ -# Security Policy - -## Reporting a Vulnerability - -You can report vulnerabilities privately to -[security@tailscale.com](mailto:security@tailscale.com). Tailscale -staff will triage the issue, and work with you on a coordinated -disclosure timeline. +# Security Policy + +## Reporting a Vulnerability + +You can report vulnerabilities privately to +[security@tailscale.com](mailto:security@tailscale.com). Tailscale +staff will triage the issue, and work with you on a coordinated +disclosure timeline. diff --git a/VERSION.txt b/VERSION.txt index 54227249d1ff9..79e15fd49370a 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.78.0 +1.77.0 diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index b95c7cbe14964..5c18e85a896eb 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package atomicfile contains code related to writing to filesystems -// atomically. -// -// This package should be considered internal; its API is not stable. -package atomicfile // import "tailscale.com/atomicfile" - -import ( - "fmt" - "os" - "path/filepath" - "runtime" -) - -// WriteFile writes data to filename+some suffix, then renames it into filename. -// The perm argument is ignored on Windows. If the target filename already -// exists but is not a regular file, WriteFile returns an error. -func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { - fi, err := os.Stat(filename) - if err == nil && !fi.Mode().IsRegular() { - return fmt.Errorf("%s already exists and is not a regular file", filename) - } - f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") - if err != nil { - return err - } - tmpName := f.Name() - defer func() { - if err != nil { - f.Close() - os.Remove(tmpName) - } - }() - if _, err := f.Write(data); err != nil { - return err - } - if runtime.GOOS != "windows" { - if err := f.Chmod(perm); err != nil { - return err - } - } - if err := f.Sync(); err != nil { - return err - } - if err := f.Close(); err != nil { - return err - } - return os.Rename(tmpName, filename) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package atomicfile contains code related to writing to filesystems +// atomically. +// +// This package should be considered internal; its API is not stable. +package atomicfile // import "tailscale.com/atomicfile" + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +// WriteFile writes data to filename+some suffix, then renames it into filename. +// The perm argument is ignored on Windows. If the target filename already +// exists but is not a regular file, WriteFile returns an error. +func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { + fi, err := os.Stat(filename) + if err == nil && !fi.Mode().IsRegular() { + return fmt.Errorf("%s already exists and is not a regular file", filename) + } + f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") + if err != nil { + return err + } + tmpName := f.Name() + defer func() { + if err != nil { + f.Close() + os.Remove(tmpName) + } + }() + if _, err := f.Write(data); err != nil { + return err + } + if runtime.GOOS != "windows" { + if err := f.Chmod(perm); err != nil { + return err + } + } + if err := f.Sync(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return os.Rename(tmpName, filename) +} diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go index b7a78765b745e..78c93e664f738 100644 --- a/atomicfile/atomicfile_test.go +++ b/atomicfile/atomicfile_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package atomicfile - -import ( - "net" - "os" - "path/filepath" - "runtime" - "strings" - "testing" -) - -func TestDoesNotOverwriteIrregularFiles(t *testing.T) { - // Per tailscale/tailscale#7658 as one example, almost any imagined use of - // atomicfile.Write should likely not attempt to overwrite an irregular file - // such as a device node, socket, or named pipe. - - const filename = "TestDoesNotOverwriteIrregularFiles" - var path string - // macOS private temp does not allow unix socket creation, but /tmp does. - if runtime.GOOS == "darwin" { - path = filepath.Join("/tmp", filename) - t.Cleanup(func() { os.Remove(path) }) - } else { - path = filepath.Join(t.TempDir(), filename) - } - - // The least troublesome thing to make that is not a file is a unix socket. - // Making a null device sadly requires root. - l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - err = WriteFile(path, []byte("hello"), 0644) - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), "is not a regular file") { - t.Fatalf("unexpected error: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package atomicfile + +import ( + "net" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestDoesNotOverwriteIrregularFiles(t *testing.T) { + // Per tailscale/tailscale#7658 as one example, almost any imagined use of + // atomicfile.Write should likely not attempt to overwrite an irregular file + // such as a device node, socket, or named pipe. + + const filename = "TestDoesNotOverwriteIrregularFiles" + var path string + // macOS private temp does not allow unix socket creation, but /tmp does. + if runtime.GOOS == "darwin" { + path = filepath.Join("/tmp", filename) + t.Cleanup(func() { os.Remove(path) }) + } else { + path = filepath.Join(t.TempDir(), filename) + } + + // The least troublesome thing to make that is not a file is a unix socket. + // Making a null device sadly requires root. + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + err = WriteFile(path, []byte("hello"), 0644) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "is not a regular file") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/chirp/chirp.go b/chirp/chirp.go index 1b448f2394106..9653877221778 100644 --- a/chirp/chirp.go +++ b/chirp/chirp.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package chirp implements a client to communicate with the BIRD Internet -// Routing Daemon. -package chirp - -import ( - "bufio" - "fmt" - "net" - "strings" - "time" -) - -const ( - // Maximum amount of time we should wait when reading a response from BIRD. - responseTimeout = 10 * time.Second -) - -// New creates a BIRDClient. -func New(socket string) (*BIRDClient, error) { - return newWithTimeout(socket, responseTimeout) -} - -func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { - conn, err := net.Dial("unix", socket) - if err != nil { - return nil, fmt.Errorf("failed to connect to BIRD: %w", err) - } - defer func() { - if err != nil { - conn.Close() - } - }() - - b := &BIRDClient{ - socket: socket, - conn: conn, - scanner: bufio.NewScanner(conn), - timeNow: time.Now, - timeout: timeout, - } - // Read and discard the first line as that is the welcome message. - if _, err := b.readResponse(); err != nil { - return nil, err - } - return b, nil -} - -// BIRDClient handles communication with the BIRD Internet Routing Daemon. -type BIRDClient struct { - socket string - conn net.Conn - scanner *bufio.Scanner - timeNow func() time.Time - timeout time.Duration -} - -// Close closes the underlying connection to BIRD. -func (b *BIRDClient) Close() error { return b.conn.Close() } - -// DisableProtocol disables the provided protocol. -func (b *BIRDClient) DisableProtocol(protocol string) error { - out, err := b.exec("disable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { - return nil - } - return fmt.Errorf("failed to disable %s: %v", protocol, out) -} - -// EnableProtocol enables the provided protocol. -func (b *BIRDClient) EnableProtocol(protocol string) error { - out, err := b.exec("enable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { - return nil - } - return fmt.Errorf("failed to enable %s: %v", protocol, out) -} - -// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 - -// Each session of the CLI consists of a sequence of request and replies, -// slightly resembling the FTP and SMTP protocols. -// Requests are commands encoded as a single line of text, -// replies are sequences of lines starting with a four-digit code -// followed by either a space (if it's the last line of the reply) or -// a minus sign (when the reply is going to continue with the next line), -// the rest of the line contains a textual message semantics of which depends on the numeric code. -// If a reply line has the same code as the previous one and it's a continuation line, -// the whole prefix can be replaced by a single white space character. -// -// Reply codes starting with 0 stand for ‘action successfully completed’ messages, -// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. - -func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { - if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { - return "", err - } - if _, err := fmt.Fprintln(b.conn); err != nil { - return "", err - } - return b.readResponse() -} - -// hasResponseCode reports whether the provided byte slice is -// prefixed with a BIRD response code. -// Equivalent regex: `^\d{4}[ -]`. -func hasResponseCode(s []byte) bool { - if len(s) < 5 { - return false - } - for _, b := range s[:4] { - if '0' <= b && b <= '9' { - continue - } - return false - } - return s[4] == ' ' || s[4] == '-' -} - -func (b *BIRDClient) readResponse() (string, error) { - // Set the read timeout before we start reading anything. - if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - - var resp strings.Builder - var done bool - for !done { - if !b.scanner.Scan() { - if err := b.scanner.Err(); err != nil { - return "", err - } - - return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) - } - out := b.scanner.Bytes() - if _, err := resp.Write(out); err != nil { - return "", err - } - if hasResponseCode(out) { - done = out[4] == ' ' - } - if !done { - resp.WriteRune('\n') - } - } - return resp.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package chirp implements a client to communicate with the BIRD Internet +// Routing Daemon. +package chirp + +import ( + "bufio" + "fmt" + "net" + "strings" + "time" +) + +const ( + // Maximum amount of time we should wait when reading a response from BIRD. + responseTimeout = 10 * time.Second +) + +// New creates a BIRDClient. +func New(socket string) (*BIRDClient, error) { + return newWithTimeout(socket, responseTimeout) +} + +func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to connect to BIRD: %w", err) + } + defer func() { + if err != nil { + conn.Close() + } + }() + + b := &BIRDClient{ + socket: socket, + conn: conn, + scanner: bufio.NewScanner(conn), + timeNow: time.Now, + timeout: timeout, + } + // Read and discard the first line as that is the welcome message. + if _, err := b.readResponse(); err != nil { + return nil, err + } + return b, nil +} + +// BIRDClient handles communication with the BIRD Internet Routing Daemon. +type BIRDClient struct { + socket string + conn net.Conn + scanner *bufio.Scanner + timeNow func() time.Time + timeout time.Duration +} + +// Close closes the underlying connection to BIRD. +func (b *BIRDClient) Close() error { return b.conn.Close() } + +// DisableProtocol disables the provided protocol. +func (b *BIRDClient) DisableProtocol(protocol string) error { + out, err := b.exec("disable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { + return nil + } + return fmt.Errorf("failed to disable %s: %v", protocol, out) +} + +// EnableProtocol enables the provided protocol. +func (b *BIRDClient) EnableProtocol(protocol string) error { + out, err := b.exec("enable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { + return nil + } + return fmt.Errorf("failed to enable %s: %v", protocol, out) +} + +// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 + +// Each session of the CLI consists of a sequence of request and replies, +// slightly resembling the FTP and SMTP protocols. +// Requests are commands encoded as a single line of text, +// replies are sequences of lines starting with a four-digit code +// followed by either a space (if it's the last line of the reply) or +// a minus sign (when the reply is going to continue with the next line), +// the rest of the line contains a textual message semantics of which depends on the numeric code. +// If a reply line has the same code as the previous one and it's a continuation line, +// the whole prefix can be replaced by a single white space character. +// +// Reply codes starting with 0 stand for ‘action successfully completed’ messages, +// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. + +func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { + if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { + return "", err + } + if _, err := fmt.Fprintln(b.conn); err != nil { + return "", err + } + return b.readResponse() +} + +// hasResponseCode reports whether the provided byte slice is +// prefixed with a BIRD response code. +// Equivalent regex: `^\d{4}[ -]`. +func hasResponseCode(s []byte) bool { + if len(s) < 5 { + return false + } + for _, b := range s[:4] { + if '0' <= b && b <= '9' { + continue + } + return false + } + return s[4] == ' ' || s[4] == '-' +} + +func (b *BIRDClient) readResponse() (string, error) { + // Set the read timeout before we start reading anything. + if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + + var resp strings.Builder + var done bool + for !done { + if !b.scanner.Scan() { + if err := b.scanner.Err(); err != nil { + return "", err + } + + return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) + } + out := b.scanner.Bytes() + if _, err := resp.Write(out); err != nil { + return "", err + } + if hasResponseCode(out) { + done = out[4] == ' ' + } + if !done { + resp.WriteRune('\n') + } + } + return resp.String(), nil +} diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index b8947a796c996..2549c163fd819 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -1,192 +1,192 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package chirp - -import ( - "bufio" - "errors" - "fmt" - "net" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" -) - -type fakeBIRD struct { - net.Listener - protocolsEnabled map[string]bool - sock string -} - -func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - pe := make(map[string]bool) - for _, p := range protocols { - pe[p] = false - } - return &fakeBIRD{ - Listener: l, - protocolsEnabled: pe, - sock: sock, - } -} - -func (fb *fakeBIRD) listen() error { - for { - c, err := fb.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - go fb.handle(c) - } -} - -func (fb *fakeBIRD) handle(c net.Conn) { - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - sc := bufio.NewScanner(c) - for sc.Scan() { - cmd := sc.Text() - args := strings.Split(cmd, " ") - switch args[0] { - case "enable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if en { - fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) - } else { - fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = true - case "disable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if !en { - fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) - } else { - fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = false - } - } -} - -func TestChirp(t *testing.T) { - fb := newFakeBIRD(t, "tailscale") - defer fb.Close() - go fb.listen() - c, err := New(fb.sock) - if err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("rando"); err == nil { - t.Fatalf("enabling %q succeeded", "rando") - } - if err := c.DisableProtocol("rando"); err == nil { - t.Fatalf("disabling %q succeeded", "rando") - } -} - -type hangingListener struct { - net.Listener - t *testing.T - done chan struct{} - wg sync.WaitGroup - sock string -} - -func newHangingListener(t *testing.T) *hangingListener { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - return &hangingListener{ - Listener: l, - t: t, - done: make(chan struct{}), - sock: sock, - } -} - -func (hl *hangingListener) Stop() { - hl.Close() - close(hl.done) - hl.wg.Wait() -} - -func (hl *hangingListener) listen() error { - for { - c, err := hl.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - hl.wg.Add(1) - go hl.handle(c) - } -} - -func (hl *hangingListener) handle(c net.Conn) { - defer hl.wg.Done() - - // Write our fake first line of response so that we get into the read loop - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - hl.t.Logf("connection still hanging") - case <-hl.done: - return - } - } -} - -func TestChirpTimeout(t *testing.T) { - fb := newHangingListener(t) - defer fb.Stop() - go fb.listen() - - c, err := newWithTimeout(fb.sock, 500*time.Millisecond) - if err != nil { - t.Fatal(err) - } - - err = c.EnableProtocol("tailscale") - if err == nil { - t.Fatal("got err=nil, want timeout") - } - if !os.IsTimeout(err) { - t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package chirp + +import ( + "bufio" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +type fakeBIRD struct { + net.Listener + protocolsEnabled map[string]bool + sock string +} + +func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + pe := make(map[string]bool) + for _, p := range protocols { + pe[p] = false + } + return &fakeBIRD{ + Listener: l, + protocolsEnabled: pe, + sock: sock, + } +} + +func (fb *fakeBIRD) listen() error { + for { + c, err := fb.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + go fb.handle(c) + } +} + +func (fb *fakeBIRD) handle(c net.Conn) { + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + sc := bufio.NewScanner(c) + for sc.Scan() { + cmd := sc.Text() + args := strings.Split(cmd, " ") + switch args[0] { + case "enable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if en { + fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) + } else { + fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = true + case "disable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if !en { + fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) + } else { + fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = false + } + } +} + +func TestChirp(t *testing.T) { + fb := newFakeBIRD(t, "tailscale") + defer fb.Close() + go fb.listen() + c, err := New(fb.sock) + if err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("rando"); err == nil { + t.Fatalf("enabling %q succeeded", "rando") + } + if err := c.DisableProtocol("rando"); err == nil { + t.Fatalf("disabling %q succeeded", "rando") + } +} + +type hangingListener struct { + net.Listener + t *testing.T + done chan struct{} + wg sync.WaitGroup + sock string +} + +func newHangingListener(t *testing.T) *hangingListener { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + return &hangingListener{ + Listener: l, + t: t, + done: make(chan struct{}), + sock: sock, + } +} + +func (hl *hangingListener) Stop() { + hl.Close() + close(hl.done) + hl.wg.Wait() +} + +func (hl *hangingListener) listen() error { + for { + c, err := hl.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + hl.wg.Add(1) + go hl.handle(c) + } +} + +func (hl *hangingListener) handle(c net.Conn) { + defer hl.wg.Done() + + // Write our fake first line of response so that we get into the read loop + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hl.t.Logf("connection still hanging") + case <-hl.done: + return + } + } +} + +func TestChirpTimeout(t *testing.T) { + fb := newHangingListener(t) + defer fb.Stop() + go fb.listen() + + c, err := newWithTimeout(fb.sock, 500*time.Millisecond) + if err != nil { + t.Fatal(err) + } + + err = c.EnableProtocol("tailscale") + if err == nil { + t.Fatal("got err=nil, want timeout") + } + if !os.IsTimeout(err) { + t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) + } +} diff --git a/client/tailscale/apitype/controltype.go b/client/tailscale/apitype/controltype.go index a9a76065f711e..9a623be319606 100644 --- a/client/tailscale/apitype/controltype.go +++ b/client/tailscale/apitype/controltype.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package apitype - -type DNSConfig struct { - Resolvers []DNSResolver `json:"resolvers"` - FallbackResolvers []DNSResolver `json:"fallbackResolvers"` - Routes map[string][]DNSResolver `json:"routes"` - Domains []string `json:"domains"` - Nameservers []string `json:"nameservers"` - Proxied bool `json:"proxied"` - TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` -} - -type DNSResolver struct { - Addr string `json:"addr"` - BootstrapResolution []string `json:"bootstrapResolution,omitempty"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package apitype + +type DNSConfig struct { + Resolvers []DNSResolver `json:"resolvers"` + FallbackResolvers []DNSResolver `json:"fallbackResolvers"` + Routes map[string][]DNSResolver `json:"routes"` + Domains []string `json:"domains"` + Nameservers []string `json:"nameservers"` + Proxied bool `json:"proxied"` + TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` +} + +type DNSResolver struct { + Addr string `json:"addr"` + BootstrapResolution []string `json:"bootstrapResolution,omitempty"` +} diff --git a/client/tailscale/dns.go b/client/tailscale/dns.go index 12b9e15c8b7a5..f198742b3ca51 100644 --- a/client/tailscale/dns.go +++ b/client/tailscale/dns.go @@ -1,233 +1,233 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - - "tailscale.com/client/tailscale/apitype" -) - -// DNSNameServers is returned when retrieving the list of nameservers. -// It is also the structure provided when setting nameservers. -type DNSNameServers struct { - DNS []string `json:"dns"` // DNS name servers -} - -// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. -// -// It includes the MagicDNS status since nameservers changes may affect MagicDNS. -type DNSNameServersPostResponse struct { - DNS []string `json:"dns"` // DNS name servers - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -// DNSSearchpaths is the list of search paths for a given domain. -type DNSSearchPaths struct { - SearchPaths []string `json:"searchPaths"` // DNS search paths -} - -// DNSPreferences is the preferences set for a given tailnet. -// -// It includes MagicDNS which can be turned on or off. To enable MagicDNS, -// there must be at least one nameserver. When all nameservers are removed, -// MagicDNS is disabled. -type DNSPreferences struct { - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - data, err := json.Marshal(&postData) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - req.Header.Set("Content-Type", "application/json") - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -// DNSConfig retrieves the DNSConfig settings for a domain. -func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSConfig: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "config") - if err != nil { - return nil, err - } - var dnsResp apitype.DNSConfig - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) - } - }() - var dnsResp apitype.DNSConfig - b, err := c.dnsPOSTRequest(ctx, "config", cfg) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -// NameServers retrieves the list of nameservers set for a domain. -func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.NameServers: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "nameservers") - if err != nil { - return nil, err - } - var dnsResp DNSNameServers - err = json.Unmarshal(b, &dnsResp) - return dnsResp.DNS, err -} - -// SetNameServers sets the list of nameservers for a tailnet to the list provided -// by the user. -// -// It returns the new list of nameservers and the MagicDNS status in case it was -// affected by the change. For example, removing all nameservers will turn off -// MagicDNS. -func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetNameServers: %w", err) - } - }() - dnsReq := DNSNameServers{DNS: nameservers} - b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// DNSPreferences retrieves the DNS preferences set for a tailnet. -// -// It returns the status of MagicDNS. -func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSPreferences: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "preferences") - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SetDNSPreferences sets the DNS preferences for a tailnet. -// -// MagicDNS can only be enabled when there is at least one nameserver provided. -// When all nameservers are removed, MagicDNS is disabled and will stay disabled, -// unless explicitly enabled by a user again. -func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) - } - }() - dnsReq := DNSPreferences{MagicDNS: magicDNS} - b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) - if err != nil { - return - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SearchPaths retrieves the list of searchpaths set for a tailnet. -func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SearchPaths: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "searchpaths") - if err != nil { - return nil, err - } - var dnsResp *DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} - -// SetSearchPaths sets the list of searchpaths for a tailnet. -func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) - } - }() - dnsReq := DNSSearchPaths{SearchPaths: searchpaths} - b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) - if err != nil { - return nil, err - } - var dnsResp DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/client/tailscale/apitype" +) + +// DNSNameServers is returned when retrieving the list of nameservers. +// It is also the structure provided when setting nameservers. +type DNSNameServers struct { + DNS []string `json:"dns"` // DNS name servers +} + +// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. +// +// It includes the MagicDNS status since nameservers changes may affect MagicDNS. +type DNSNameServersPostResponse struct { + DNS []string `json:"dns"` // DNS name servers + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +// DNSSearchpaths is the list of search paths for a given domain. +type DNSSearchPaths struct { + SearchPaths []string `json:"searchPaths"` // DNS search paths +} + +// DNSPreferences is the preferences set for a given tailnet. +// +// It includes MagicDNS which can be turned on or off. To enable MagicDNS, +// there must be at least one nameserver. When all nameservers are removed, +// MagicDNS is disabled. +type DNSPreferences struct { + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + data, err := json.Marshal(&postData) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + req.Header.Set("Content-Type", "application/json") + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +// DNSConfig retrieves the DNSConfig settings for a domain. +func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSConfig: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "config") + if err != nil { + return nil, err + } + var dnsResp apitype.DNSConfig + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) + } + }() + var dnsResp apitype.DNSConfig + b, err := c.dnsPOSTRequest(ctx, "config", cfg) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +// NameServers retrieves the list of nameservers set for a domain. +func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.NameServers: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "nameservers") + if err != nil { + return nil, err + } + var dnsResp DNSNameServers + err = json.Unmarshal(b, &dnsResp) + return dnsResp.DNS, err +} + +// SetNameServers sets the list of nameservers for a tailnet to the list provided +// by the user. +// +// It returns the new list of nameservers and the MagicDNS status in case it was +// affected by the change. For example, removing all nameservers will turn off +// MagicDNS. +func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetNameServers: %w", err) + } + }() + dnsReq := DNSNameServers{DNS: nameservers} + b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// DNSPreferences retrieves the DNS preferences set for a tailnet. +// +// It returns the status of MagicDNS. +func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSPreferences: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "preferences") + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SetDNSPreferences sets the DNS preferences for a tailnet. +// +// MagicDNS can only be enabled when there is at least one nameserver provided. +// When all nameservers are removed, MagicDNS is disabled and will stay disabled, +// unless explicitly enabled by a user again. +func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) + } + }() + dnsReq := DNSPreferences{MagicDNS: magicDNS} + b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) + if err != nil { + return + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SearchPaths retrieves the list of searchpaths set for a tailnet. +func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SearchPaths: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "searchpaths") + if err != nil { + return nil, err + } + var dnsResp *DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} + +// SetSearchPaths sets the list of searchpaths for a tailnet. +func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) + } + }() + dnsReq := DNSSearchPaths{SearchPaths: searchpaths} + b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) + if err != nil { + return nil, err + } + var dnsResp DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} diff --git a/client/tailscale/example/servetls/servetls.go b/client/tailscale/example/servetls/servetls.go index e426cbea2b375..f48e90d163527 100644 --- a/client/tailscale/example/servetls/servetls.go +++ b/client/tailscale/example/servetls/servetls.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The servetls program shows how to run an HTTPS server -// using a Tailscale cert via LetsEncrypt. -package main - -import ( - "crypto/tls" - "io" - "log" - "net/http" - - "tailscale.com/client/tailscale" -) - -func main() { - s := &http.Server{ - TLSConfig: &tls.Config{ - GetCertificate: tailscale.GetCertificate, - }, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "

Hello from Tailscale!

It works.") - }), - } - log.Printf("Running TLS server on :443 ...") - log.Fatal(s.ListenAndServeTLS("", "")) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The servetls program shows how to run an HTTPS server +// using a Tailscale cert via LetsEncrypt. +package main + +import ( + "crypto/tls" + "io" + "log" + "net/http" + + "tailscale.com/client/tailscale" +) + +func main() { + s := &http.Server{ + TLSConfig: &tls.Config{ + GetCertificate: tailscale.GetCertificate, + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "

Hello from Tailscale!

It works.") + }), + } + log.Printf("Running TLS server on :443 ...") + log.Fatal(s.ListenAndServeTLS("", "")) +} diff --git a/client/tailscale/keys.go b/client/tailscale/keys.go index ae5f721b74d6d..84bcdfae6aeeb 100644 --- a/client/tailscale/keys.go +++ b/client/tailscale/keys.go @@ -1,166 +1,166 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "time" -) - -// Key represents a Tailscale API or auth key. -type Key struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires"` - Capabilities KeyCapabilities `json:"capabilities"` -} - -// KeyCapabilities are the capabilities of a Key. -type KeyCapabilities struct { - Devices KeyDeviceCapabilities `json:"devices,omitempty"` -} - -// KeyDeviceCapabilities are the device-related capabilities of a Key. -type KeyDeviceCapabilities struct { - Create KeyDeviceCreateCapabilities `json:"create"` -} - -// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. -type KeyDeviceCreateCapabilities struct { - Reusable bool `json:"reusable"` - Ephemeral bool `json:"ephemeral"` - Preauthorized bool `json:"preauthorized"` - Tags []string `json:"tags,omitempty"` -} - -// Keys returns the list of keys for the current user. -func (c *Client) Keys(ctx context.Context) ([]string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var keys struct { - Keys []*Key `json:"keys"` - } - if err := json.Unmarshal(b, &keys); err != nil { - return nil, err - } - ret := make([]string, 0, len(keys.Keys)) - for _, k := range keys.Keys { - ret = append(ret, k.ID) - } - return ret, nil -} - -// CreateKey creates a new key for the current user. Currently, only auth keys -// can be created. It returns the secret key itself, which cannot be retrieved again -// later, and the key metadata. -// -// To create a key with a specific expiry, use CreateKeyWithExpiry. -func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { - return c.CreateKeyWithExpiry(ctx, caps, 0) -} - -// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. -// -// The time is truncated to a whole number of seconds. If zero, that means no expiration. -func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { - - // convert expirySeconds to an int64 (seconds) - expirySeconds := int64(expiry.Seconds()) - if expirySeconds < 0 { - return "", nil, fmt.Errorf("expiry must be positive") - } - if expirySeconds == 0 && expiry != 0 { - return "", nil, fmt.Errorf("non-zero expiry must be at least one second") - } - - keyRequest := struct { - Capabilities KeyCapabilities `json:"capabilities"` - ExpirySeconds int64 `json:"expirySeconds,omitempty"` - }{caps, int64(expirySeconds)} - bs, err := json.Marshal(keyRequest) - if err != nil { - return "", nil, err - } - - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) - if err != nil { - return "", nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return "", nil, err - } - if resp.StatusCode != http.StatusOK { - return "", nil, handleErrorResponse(b, resp) - } - - var key struct { - Key - Secret string `json:"key"` - } - if err := json.Unmarshal(b, &key); err != nil { - return "", nil, err - } - return key.Secret, &key.Key, nil -} - -// Key returns the metadata for the given key ID. Currently, capabilities are -// only returned for auth keys, API keys only return general metadata. -func (c *Client) Key(ctx context.Context, id string) (*Key, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var key Key - if err := json.Unmarshal(b, &key); err != nil { - return nil, err - } - return &key, nil -} - -// DeleteKey deletes the key with the given ID. -func (c *Client) DeleteKey(ctx context.Context, id string) error { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) - if err != nil { - return err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +// Key represents a Tailscale API or auth key. +type Key struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` + Capabilities KeyCapabilities `json:"capabilities"` +} + +// KeyCapabilities are the capabilities of a Key. +type KeyCapabilities struct { + Devices KeyDeviceCapabilities `json:"devices,omitempty"` +} + +// KeyDeviceCapabilities are the device-related capabilities of a Key. +type KeyDeviceCapabilities struct { + Create KeyDeviceCreateCapabilities `json:"create"` +} + +// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. +type KeyDeviceCreateCapabilities struct { + Reusable bool `json:"reusable"` + Ephemeral bool `json:"ephemeral"` + Preauthorized bool `json:"preauthorized"` + Tags []string `json:"tags,omitempty"` +} + +// Keys returns the list of keys for the current user. +func (c *Client) Keys(ctx context.Context) ([]string, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var keys struct { + Keys []*Key `json:"keys"` + } + if err := json.Unmarshal(b, &keys); err != nil { + return nil, err + } + ret := make([]string, 0, len(keys.Keys)) + for _, k := range keys.Keys { + ret = append(ret, k.ID) + } + return ret, nil +} + +// CreateKey creates a new key for the current user. Currently, only auth keys +// can be created. It returns the secret key itself, which cannot be retrieved again +// later, and the key metadata. +// +// To create a key with a specific expiry, use CreateKeyWithExpiry. +func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { + return c.CreateKeyWithExpiry(ctx, caps, 0) +} + +// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. +// +// The time is truncated to a whole number of seconds. If zero, that means no expiration. +func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { + + // convert expirySeconds to an int64 (seconds) + expirySeconds := int64(expiry.Seconds()) + if expirySeconds < 0 { + return "", nil, fmt.Errorf("expiry must be positive") + } + if expirySeconds == 0 && expiry != 0 { + return "", nil, fmt.Errorf("non-zero expiry must be at least one second") + } + + keyRequest := struct { + Capabilities KeyCapabilities `json:"capabilities"` + ExpirySeconds int64 `json:"expirySeconds,omitempty"` + }{caps, int64(expirySeconds)} + bs, err := json.Marshal(keyRequest) + if err != nil { + return "", nil, err + } + + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) + if err != nil { + return "", nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return "", nil, err + } + if resp.StatusCode != http.StatusOK { + return "", nil, handleErrorResponse(b, resp) + } + + var key struct { + Key + Secret string `json:"key"` + } + if err := json.Unmarshal(b, &key); err != nil { + return "", nil, err + } + return key.Secret, &key.Key, nil +} + +// Key returns the metadata for the given key ID. Currently, capabilities are +// only returned for auth keys, API keys only return general metadata. +func (c *Client) Key(ctx context.Context, id string) (*Key, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var key Key + if err := json.Unmarshal(b, &key); err != nil { + return nil, err + } + return &key, nil +} + +// DeleteKey deletes the key with the given ID. +func (c *Client) DeleteKey(ctx context.Context, id string) error { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) + if err != nil { + return err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + return nil +} diff --git a/client/tailscale/routes.go b/client/tailscale/routes.go index 41415d1b44c29..5912fc46c09a6 100644 --- a/client/tailscale/routes.go +++ b/client/tailscale/routes.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/netip" -) - -// Routes contains the lists of subnet routes that are currently advertised by a device, -// as well as the subnets that are enabled to be routed by the device. -type Routes struct { - AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` - EnabledRoutes []netip.Prefix `json:"enabledRoutes"` -} - -// Routes retrieves the list of subnet routes that have been enabled for a device. -// The routes that are returned are not necessarily advertised by the device, -// they have only been preapproved. -func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.Routes: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var sr Routes - err = json.Unmarshal(b, &sr) - return &sr, err -} - -type postRoutesParams struct { - Routes []netip.Prefix `json:"routes"` -} - -// SetRoutes updates the list of subnets that are enabled for a device. -// Subnets must be parsable by net/netip.ParsePrefix. -// Subnets do not have to be currently advertised by a device, they may be pre-enabled. -// Returns the updated list of enabled and advertised subnet routes in a *Routes object. -func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetRoutes: %w", err) - } - }() - params := &postRoutesParams{Routes: subnets} - data, err := json.Marshal(params) - if err != nil { - return nil, err - } - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var srr *Routes - if err := json.Unmarshal(b, &srr); err != nil { - return nil, err - } - return srr, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/netip" +) + +// Routes contains the lists of subnet routes that are currently advertised by a device, +// as well as the subnets that are enabled to be routed by the device. +type Routes struct { + AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` + EnabledRoutes []netip.Prefix `json:"enabledRoutes"` +} + +// Routes retrieves the list of subnet routes that have been enabled for a device. +// The routes that are returned are not necessarily advertised by the device, +// they have only been preapproved. +func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.Routes: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var sr Routes + err = json.Unmarshal(b, &sr) + return &sr, err +} + +type postRoutesParams struct { + Routes []netip.Prefix `json:"routes"` +} + +// SetRoutes updates the list of subnets that are enabled for a device. +// Subnets must be parsable by net/netip.ParsePrefix. +// Subnets do not have to be currently advertised by a device, they may be pre-enabled. +// Returns the updated list of enabled and advertised subnet routes in a *Routes object. +func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetRoutes: %w", err) + } + }() + params := &postRoutesParams{Routes: subnets} + data, err := json.Marshal(params) + if err != nil { + return nil, err + } + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var srr *Routes + if err := json.Unmarshal(b, &srr); err != nil { + return nil, err + } + return srr, err +} diff --git a/client/tailscale/tailnet.go b/client/tailscale/tailnet.go index eef2dca2014ad..2539e7f235b0e 100644 --- a/client/tailscale/tailnet.go +++ b/client/tailscale/tailnet.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "context" - "fmt" - "net/http" - "net/url" - - "tailscale.com/util/httpm" -) - -// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. -func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) - req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) - if err != nil { - return err - } - - c.setAuth(req) - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "tailscale.com/util/httpm" +) + +// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. +func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) + req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) + if err != nil { + return err + } + + c.setAuth(req) + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + + return nil +} diff --git a/client/web/qnap.go b/client/web/qnap.go index 8fa5ee174bae6..9bde64bf5885b 100644 --- a/client/web/qnap.go +++ b/client/web/qnap.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// qnap.go contains handlers and logic, such as authentication, -// that is specific to running the web client on QNAP. - -package web - -import ( - "crypto/tls" - "encoding/xml" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" -) - -// authorizeQNAP authenticates the logged-in QNAP user and verifies that they -// are authorized to use the web client. -// If the user is not authorized to use the client, an error is returned. -func authorizeQNAP(r *http.Request) (authorized bool, err error) { - _, resp, err := qnapAuthn(r) - if err != nil { - return false, err - } - if resp.IsAdmin == 0 { - return false, errors.New("user is not an admin") - } - - return true, nil -} - -type qnapAuthResponse struct { - AuthPassed int `xml:"authPassed"` - IsAdmin int `xml:"isAdmin"` - AuthSID string `xml:"authSid"` - ErrorValue int `xml:"errorValue"` -} - -func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { - user, err := r.Cookie("NAS_USER") - if err != nil { - return "", nil, err - } - token, err := r.Cookie("qtoken") - if err == nil { - return qnapAuthnQtoken(r, user.Value, token.Value) - } - sid, err := r.Cookie("NAS_SID") - if err == nil { - return qnapAuthnSid(r, user.Value, sid.Value) - } - return "", nil, fmt.Errorf("not authenticated by any mechanism") -} - -// qnapAuthnURL returns the auth URL to use by inferring where the UI is -// running based on the request URL. This is necessary because QNAP has so -// many options, see https://github.com/tailscale/tailscale/issues/7108 -// and https://github.com/tailscale/tailscale/issues/6903 -func qnapAuthnURL(requestUrl string, query url.Values) string { - in, err := url.Parse(requestUrl) - scheme := "" - host := "" - if err != nil || in.Scheme == "" { - log.Printf("Cannot parse QNAP login URL %v", err) - - // try localhost and hope for the best - scheme = "http" - host = "localhost" - } else { - scheme = in.Scheme - host = in.Host - } - - u := url.URL{ - Scheme: scheme, - Host: host, - Path: "/cgi-bin/authLogin.cgi", - RawQuery: query.Encode(), - } - - return u.String() -} - -func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "qtoken": []string{token}, - "user": []string{user}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "sid": []string{sid}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { - // QNAP Force HTTPS mode uses a self-signed certificate. Even importing - // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor - // SAN. See https://github.com/tailscale/tailscale/issues/6903 - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{Transport: tr} - resp, err := client.Get(url) - if err != nil { - return "", nil, err - } - defer resp.Body.Close() - out, err := io.ReadAll(resp.Body) - if err != nil { - return "", nil, err - } - authResp := &qnapAuthResponse{} - if err := xml.Unmarshal(out, authResp); err != nil { - return "", nil, err - } - if authResp.AuthPassed == 0 { - return "", nil, fmt.Errorf("not authenticated") - } - return user, authResp, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// qnap.go contains handlers and logic, such as authentication, +// that is specific to running the web client on QNAP. + +package web + +import ( + "crypto/tls" + "encoding/xml" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" +) + +// authorizeQNAP authenticates the logged-in QNAP user and verifies that they +// are authorized to use the web client. +// If the user is not authorized to use the client, an error is returned. +func authorizeQNAP(r *http.Request) (authorized bool, err error) { + _, resp, err := qnapAuthn(r) + if err != nil { + return false, err + } + if resp.IsAdmin == 0 { + return false, errors.New("user is not an admin") + } + + return true, nil +} + +type qnapAuthResponse struct { + AuthPassed int `xml:"authPassed"` + IsAdmin int `xml:"isAdmin"` + AuthSID string `xml:"authSid"` + ErrorValue int `xml:"errorValue"` +} + +func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { + user, err := r.Cookie("NAS_USER") + if err != nil { + return "", nil, err + } + token, err := r.Cookie("qtoken") + if err == nil { + return qnapAuthnQtoken(r, user.Value, token.Value) + } + sid, err := r.Cookie("NAS_SID") + if err == nil { + return qnapAuthnSid(r, user.Value, sid.Value) + } + return "", nil, fmt.Errorf("not authenticated by any mechanism") +} + +// qnapAuthnURL returns the auth URL to use by inferring where the UI is +// running based on the request URL. This is necessary because QNAP has so +// many options, see https://github.com/tailscale/tailscale/issues/7108 +// and https://github.com/tailscale/tailscale/issues/6903 +func qnapAuthnURL(requestUrl string, query url.Values) string { + in, err := url.Parse(requestUrl) + scheme := "" + host := "" + if err != nil || in.Scheme == "" { + log.Printf("Cannot parse QNAP login URL %v", err) + + // try localhost and hope for the best + scheme = "http" + host = "localhost" + } else { + scheme = in.Scheme + host = in.Host + } + + u := url.URL{ + Scheme: scheme, + Host: host, + Path: "/cgi-bin/authLogin.cgi", + RawQuery: query.Encode(), + } + + return u.String() +} + +func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "qtoken": []string{token}, + "user": []string{user}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "sid": []string{sid}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { + // QNAP Force HTTPS mode uses a self-signed certificate. Even importing + // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor + // SAN. See https://github.com/tailscale/tailscale/issues/6903 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + resp, err := client.Get(url) + if err != nil { + return "", nil, err + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, err + } + authResp := &qnapAuthResponse{} + if err := xml.Unmarshal(out, authResp); err != nil { + return "", nil, err + } + if authResp.AuthPassed == 0 { + return "", nil, fmt.Errorf("not authenticated") + } + return user, authResp, nil +} diff --git a/client/web/src/assets/icons/arrow-right.svg b/client/web/src/assets/icons/arrow-right.svg index 0a32ef4844395..fbc4bb7ae3b7a 100644 --- a/client/web/src/assets/icons/arrow-right.svg +++ b/client/web/src/assets/icons/arrow-right.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/arrow-up-circle.svg b/client/web/src/assets/icons/arrow-up-circle.svg index e64c836be71c9..e9d009eb6bf65 100644 --- a/client/web/src/assets/icons/arrow-up-circle.svg +++ b/client/web/src/assets/icons/arrow-up-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/src/assets/icons/check-circle.svg b/client/web/src/assets/icons/check-circle.svg index 6c5ee519e6d35..4daeed514d1ff 100644 --- a/client/web/src/assets/icons/check-circle.svg +++ b/client/web/src/assets/icons/check-circle.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/check.svg b/client/web/src/assets/icons/check.svg index 70027536a6960..efa11685d772c 100644 --- a/client/web/src/assets/icons/check.svg +++ b/client/web/src/assets/icons/check.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/chevron-down.svg b/client/web/src/assets/icons/chevron-down.svg index 993744c2fa287..afc98f255d4e5 100644 --- a/client/web/src/assets/icons/chevron-down.svg +++ b/client/web/src/assets/icons/chevron-down.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/eye.svg b/client/web/src/assets/icons/eye.svg index e277674777814..b0b21ed3f701c 100644 --- a/client/web/src/assets/icons/eye.svg +++ b/client/web/src/assets/icons/eye.svg @@ -1,11 +1,11 @@ - - - - - - - - - - - + + + + + + + + + + + diff --git a/client/web/src/assets/icons/search.svg b/client/web/src/assets/icons/search.svg index 08eb2d3dc3b8f..782cd90eee1d8 100644 --- a/client/web/src/assets/icons/search.svg +++ b/client/web/src/assets/icons/search.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/tailscale-icon.svg b/client/web/src/assets/icons/tailscale-icon.svg index de3c975ce1d53..d6052fe5e7cd6 100644 --- a/client/web/src/assets/icons/tailscale-icon.svg +++ b/client/web/src/assets/icons/tailscale-icon.svg @@ -1,18 +1,18 @@ - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/tailscale-logo.svg b/client/web/src/assets/icons/tailscale-logo.svg index 94a9cc4ee906e..6d5c7ce0caae3 100644 --- a/client/web/src/assets/icons/tailscale-logo.svg +++ b/client/web/src/assets/icons/tailscale-logo.svg @@ -1,20 +1,20 @@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/user.svg b/client/web/src/assets/icons/user.svg index 7fa3d26034d8c..29d86f0499956 100644 --- a/client/web/src/assets/icons/user.svg +++ b/client/web/src/assets/icons/user.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/x-circle.svg b/client/web/src/assets/icons/x-circle.svg index d6259c9177672..49afc5a0366fe 100644 --- a/client/web/src/assets/icons/x-circle.svg +++ b/client/web/src/assets/icons/x-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/synology.go b/client/web/synology.go index 5480263834893..922489d78af16 100644 --- a/client/web/synology.go +++ b/client/web/synology.go @@ -1,59 +1,59 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// synology.go contains handlers and logic, such as authentication, -// that is specific to running the web client on Synology. - -package web - -import ( - "errors" - "fmt" - "net/http" - "os/exec" - "strings" - - "tailscale.com/util/groupmember" -) - -// authorizeSynology authenticates the logged-in Synology user and verifies -// that they are authorized to use the web client. -// If the user is authenticated, but not authorized to use the client, an error is returned. -func authorizeSynology(r *http.Request) (authorized bool, err error) { - if !hasSynoToken(r) { - return false, nil - } - - // authenticate the Synology user - cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") - out, err := cmd.CombinedOutput() - if err != nil { - return false, fmt.Errorf("auth: %v: %s", err, out) - } - user := strings.TrimSpace(string(out)) - - // check if the user is in the administrators group - isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) - if err != nil { - return false, err - } - if !isAdmin { - return false, errors.New("not a member of administrators group") - } - - return true, nil -} - -// hasSynoToken returns true if the request include a SynoToken used for synology auth. -func hasSynoToken(r *http.Request) bool { - if r.Header.Get("X-Syno-Token") != "" { - return true - } - if r.URL.Query().Get("SynoToken") != "" { - return true - } - if r.Method == "POST" && r.FormValue("SynoToken") != "" { - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// synology.go contains handlers and logic, such as authentication, +// that is specific to running the web client on Synology. + +package web + +import ( + "errors" + "fmt" + "net/http" + "os/exec" + "strings" + + "tailscale.com/util/groupmember" +) + +// authorizeSynology authenticates the logged-in Synology user and verifies +// that they are authorized to use the web client. +// If the user is authenticated, but not authorized to use the client, an error is returned. +func authorizeSynology(r *http.Request) (authorized bool, err error) { + if !hasSynoToken(r) { + return false, nil + } + + // authenticate the Synology user + cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") + out, err := cmd.CombinedOutput() + if err != nil { + return false, fmt.Errorf("auth: %v: %s", err, out) + } + user := strings.TrimSpace(string(out)) + + // check if the user is in the administrators group + isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) + if err != nil { + return false, err + } + if !isAdmin { + return false, errors.New("not a member of administrators group") + } + + return true, nil +} + +// hasSynoToken returns true if the request include a SynoToken used for synology auth. +func hasSynoToken(r *http.Request) bool { + if r.Header.Get("X-Syno-Token") != "" { + return true + } + if r.URL.Query().Get("SynoToken") != "" { + return true + } + if r.Method == "POST" && r.FormValue("SynoToken") != "" { + return true + } + return false +} diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index aae6201539c59..eba4b9267b119 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -1,486 +1,486 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package distsign implements signature and validation of arbitrary -// distributable files. -// -// There are 3 parties in this exchange: -// - builder, which creates files, signs them with signing keys and publishes -// to server -// - server, which distributes public signing keys, files and signatures -// - client, which downloads files and signatures from server, and validates -// the signatures -// -// There are 2 types of keys: -// - signing keys, that sign individual distributable files on the builder -// - root keys, that sign signing keys and are kept offline -// -// root keys -(sign)-> signing keys -(sign)-> files -// -// All keys are asymmetric Ed25519 key pairs. -// -// The server serves static files under some known prefix. The kinds of files are: -// - distsign.pub - bundle of PEM-encoded public signing keys -// - distsign.pub.sig - signature of distsign.pub using one of the root keys -// - $file - any distributable file -// - $file.sig - signature of $file using any of the signing keys -// -// The root public keys are baked into the client software at compile time. -// These keys are long-lived and prove the validity of current signing keys -// from distsign.pub. To rotate root keys, a new client release must be -// published, they are not rotated dynamically. There are multiple root keys in -// different locations specifically to allow this rotation without using the -// discarded root key for any new signatures. -// -// The signing public keys are fetched by the client dynamically before every -// download and can be rotated more readily, assuming that most deployed -// clients trust the root keys used to issue fresh signing keys. -package distsign - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "encoding/pem" - "errors" - "fmt" - "hash" - "io" - "log" - "net/http" - "net/url" - "os" - "time" - - "github.com/hdevalence/ed25519consensus" - "golang.org/x/crypto/blake2s" - "tailscale.com/net/tshttpproxy" - "tailscale.com/types/logger" - "tailscale.com/util/httpm" - "tailscale.com/util/must" -) - -const ( - pemTypeRootPrivate = "ROOT PRIVATE KEY" - pemTypeRootPublic = "ROOT PUBLIC KEY" - pemTypeSigningPrivate = "SIGNING PRIVATE KEY" - pemTypeSigningPublic = "SIGNING PUBLIC KEY" - - downloadSizeLimit = 1 << 29 // 512MB - signingKeysSizeLimit = 1 << 20 // 1MB - signatureSizeLimit = ed25519.SignatureSize -) - -// RootKey is a root key used to sign signing keys. -type RootKey struct { - k ed25519.PrivateKey -} - -// GenerateRootKey generates a new root key pair and encodes it as PEM. -func GenerateRootKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseRootKey parses the PEM-encoded private root key. The key must be in the -// same format as returned by GenerateRootKey. -func ParseRootKey(privKey []byte) (*RootKey, error) { - k, err := parsePrivateKey(privKey, pemTypeRootPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &RootKey{k: k}, nil -} - -// SignSigningKeys signs the bundle of public signing keys. The bundle must be -// a sequence of PEM blocks joined with newlines. -func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { - if _, err := ParseSigningKeyBundle(pubBundle); err != nil { - return nil, err - } - return ed25519.Sign(r.k, pubBundle), nil -} - -// SigningKey is a signing key used to sign packages. -type SigningKey struct { - k ed25519.PrivateKey -} - -// GenerateSigningKey generates a new signing key pair and encodes it as PEM. -func GenerateSigningKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseSigningKey parses the PEM-encoded private signing key. The key must be -// in the same format as returned by GenerateSigningKey. -func ParseSigningKey(privKey []byte) (*SigningKey, error) { - k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &SigningKey{k: k}, nil -} - -// SignPackageHash signs the hash and the length of a package. Use PackageHash -// to compute the inputs. -func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { - if len <= 0 { - return nil, fmt.Errorf("package length must be positive, got %d", len) - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - return ed25519.Sign(s.k, msg), nil -} - -// PackageHash is a hash.Hash that counts the number of bytes written. Use it -// to get the hash and length inputs to SigningKey.SignPackageHash. -type PackageHash struct { - hash.Hash - len int64 -} - -// NewPackageHash returns an initialized PackageHash using BLAKE2s. -func NewPackageHash() *PackageHash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen with a nil key passed to blake2s. - panic(err) - } - return &PackageHash{Hash: h} -} - -func (ph *PackageHash) Write(b []byte) (int, error) { - ph.len += int64(len(b)) - return ph.Hash.Write(b) -} - -// Reset the PackageHash to its initial state. -func (ph *PackageHash) Reset() { - ph.len = 0 - ph.Hash.Reset() -} - -// Len returns the total number of bytes written. -func (ph *PackageHash) Len() int64 { return ph.len } - -// Client downloads and validates files from a distribution server. -type Client struct { - logf logger.Logf - roots []ed25519.PublicKey - pkgsAddr *url.URL -} - -// NewClient returns a new client for distribution server located at pkgsAddr, -// and uses embedded root keys from the roots/ subdirectory of this package. -func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { - if logf == nil { - logf = log.Printf - } - u, err := url.Parse(pkgsAddr) - if err != nil { - return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) - } - return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil -} - -func (c *Client) url(path string) string { - return c.pkgsAddr.JoinPath(path).String() -} - -// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. -// The file is downloaded to dstPath and its signature is validated using the -// embedded root keys. Download returns an error if anything goes wrong with -// the actual file download or with signature validation. -func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcPath) - sigURL := srcURL + ".sig" - - c.logf("Downloading %q", srcURL) - dstPathUnverified := dstPath + ".unverified" - hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) - if err != nil { - return err - } - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return err - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - if !VerifyAny(sigPub, msg, sig) { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) - } - c.logf("Signature OK") - - if err := os.Rename(dstPathUnverified, dstPath); err != nil { - return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) - } - - return nil -} - -// ValidateLocalBinary fetches the latest signature associated with the binary -// at srcURLPath and uses it to validate the file located on disk via -// localFilePath. ValidateLocalBinary returns an error if anything goes wrong -// with the signature download or with signature validation. -func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcURLPath) - sigURL := srcURL + ".sig" - - localFile, err := os.Open(localFilePath) - if err != nil { - return err - } - defer localFile.Close() - - h := NewPackageHash() - _, err = io.Copy(h, localFile) - if err != nil { - return err - } - hash, hashLen := h.Sum(nil), h.Len() - - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return err - } - - msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) - if !VerifyAny(sigPub, msg, sig) { - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) - } - c.logf("Signature OK") - - return nil -} - -// signingKeys fetches current signing keys from the server and validates them -// against the roots. Should be called before validation of any downloaded file -// to get the fresh keys. -func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { - keyURL := c.url("distsign.pub") - sigURL := keyURL + ".sig" - raw, err := fetch(keyURL, signingKeysSizeLimit) - if err != nil { - return nil, err - } - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return nil, err - } - if !VerifyAny(c.roots, raw, sig) { - return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) - } - - keys, err := ParseSigningKeyBundle(raw) - if err != nil { - return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) - } - return keys, nil -} - -// fetch reads the response body from url into memory, up to limit bytes. -func fetch(url string, limit int64) ([]byte, error) { - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(io.LimitReader(resp.Body, limit)) -} - -// download writes the response body of url into a local file at dst, up to -// limit bytes. On success, the returned value is a BLAKE2s hash of the file. -func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) - - res, err := hc.Do(headReq) - if err != nil { - return nil, 0, err - } - if res.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) - } - if res.ContentLength <= 0 { - return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) - } - c.logf("Download size: %v", res.ContentLength) - - dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) - dlRes, err := hc.Do(dlReq) - if err != nil { - return nil, 0, err - } - defer dlRes.Body.Close() - // TODO(bradfitz): resume from existing partial file on disk - if dlRes.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) - } - - of, err := os.Create(dst) - if err != nil { - return nil, 0, err - } - defer of.Close() - pw := &progressWriter{total: res.ContentLength, logf: c.logf} - h := NewPackageHash() - n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) - if err != nil { - return nil, n, err - } - if n != res.ContentLength { - return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) - } - if err := dlRes.Body.Close(); err != nil { - return nil, n, err - } - if err := of.Close(); err != nil { - return nil, n, err - } - pw.print() - - return h.Sum(nil), h.Len(), nil -} - -type progressWriter struct { - done int64 - total int64 - lastPrint time.Time - logf logger.Logf -} - -func (pw *progressWriter) Write(p []byte) (n int, err error) { - pw.done += int64(len(p)) - if time.Since(pw.lastPrint) > 2*time.Second { - pw.print() - } - return len(p), nil -} - -func (pw *progressWriter) print() { - pw.lastPrint = time.Now() - pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) -} - -func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, errors.New("failed to decode PEM data") - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - if b.Type != typeTag { - return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PrivateKeySize { - return nil, errors.New("private key has incorrect length for an Ed25519 private key") - } - return ed25519.PrivateKey(b.Bytes), nil -} - -// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. -func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeSigningPublic) -} - -// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. -func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeRootPublic) -} - -func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { - var keys []ed25519.PublicKey - for len(bundle) > 0 { - pub, rest, err := parsePublicKey(bundle, typeTag) - if err != nil { - return nil, err - } - keys = append(keys, pub) - bundle = rest - } - if len(keys) == 0 { - return nil, errors.New("no signing keys found in the bundle") - } - return keys, nil -} - -func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { - pub, rest, err := parsePublicKey(data, typeTag) - if err != nil { - return nil, err - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - return pub, err -} - -func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, nil, errors.New("failed to decode PEM data") - } - if b.Type != typeTag { - return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PublicKeySize { - return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") - } - return ed25519.PublicKey(b.Bytes), rest, nil -} - -// VerifyAny verifies whether sig is valid for msg using any of the keys. -// VerifyAny will panic if any of the keys have the wrong size for Ed25519. -func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { - for _, k := range keys { - if ed25519consensus.Verify(k, msg, sig) { - return true - } - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package distsign implements signature and validation of arbitrary +// distributable files. +// +// There are 3 parties in this exchange: +// - builder, which creates files, signs them with signing keys and publishes +// to server +// - server, which distributes public signing keys, files and signatures +// - client, which downloads files and signatures from server, and validates +// the signatures +// +// There are 2 types of keys: +// - signing keys, that sign individual distributable files on the builder +// - root keys, that sign signing keys and are kept offline +// +// root keys -(sign)-> signing keys -(sign)-> files +// +// All keys are asymmetric Ed25519 key pairs. +// +// The server serves static files under some known prefix. The kinds of files are: +// - distsign.pub - bundle of PEM-encoded public signing keys +// - distsign.pub.sig - signature of distsign.pub using one of the root keys +// - $file - any distributable file +// - $file.sig - signature of $file using any of the signing keys +// +// The root public keys are baked into the client software at compile time. +// These keys are long-lived and prove the validity of current signing keys +// from distsign.pub. To rotate root keys, a new client release must be +// published, they are not rotated dynamically. There are multiple root keys in +// different locations specifically to allow this rotation without using the +// discarded root key for any new signatures. +// +// The signing public keys are fetched by the client dynamically before every +// download and can be rotated more readily, assuming that most deployed +// clients trust the root keys used to issue fresh signing keys. +package distsign + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "log" + "net/http" + "net/url" + "os" + "time" + + "github.com/hdevalence/ed25519consensus" + "golang.org/x/crypto/blake2s" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/logger" + "tailscale.com/util/httpm" + "tailscale.com/util/must" +) + +const ( + pemTypeRootPrivate = "ROOT PRIVATE KEY" + pemTypeRootPublic = "ROOT PUBLIC KEY" + pemTypeSigningPrivate = "SIGNING PRIVATE KEY" + pemTypeSigningPublic = "SIGNING PUBLIC KEY" + + downloadSizeLimit = 1 << 29 // 512MB + signingKeysSizeLimit = 1 << 20 // 1MB + signatureSizeLimit = ed25519.SignatureSize +) + +// RootKey is a root key used to sign signing keys. +type RootKey struct { + k ed25519.PrivateKey +} + +// GenerateRootKey generates a new root key pair and encodes it as PEM. +func GenerateRootKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseRootKey parses the PEM-encoded private root key. The key must be in the +// same format as returned by GenerateRootKey. +func ParseRootKey(privKey []byte) (*RootKey, error) { + k, err := parsePrivateKey(privKey, pemTypeRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &RootKey{k: k}, nil +} + +// SignSigningKeys signs the bundle of public signing keys. The bundle must be +// a sequence of PEM blocks joined with newlines. +func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { + if _, err := ParseSigningKeyBundle(pubBundle); err != nil { + return nil, err + } + return ed25519.Sign(r.k, pubBundle), nil +} + +// SigningKey is a signing key used to sign packages. +type SigningKey struct { + k ed25519.PrivateKey +} + +// GenerateSigningKey generates a new signing key pair and encodes it as PEM. +func GenerateSigningKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseSigningKey parses the PEM-encoded private signing key. The key must be +// in the same format as returned by GenerateSigningKey. +func ParseSigningKey(privKey []byte) (*SigningKey, error) { + k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &SigningKey{k: k}, nil +} + +// SignPackageHash signs the hash and the length of a package. Use PackageHash +// to compute the inputs. +func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { + if len <= 0 { + return nil, fmt.Errorf("package length must be positive, got %d", len) + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + return ed25519.Sign(s.k, msg), nil +} + +// PackageHash is a hash.Hash that counts the number of bytes written. Use it +// to get the hash and length inputs to SigningKey.SignPackageHash. +type PackageHash struct { + hash.Hash + len int64 +} + +// NewPackageHash returns an initialized PackageHash using BLAKE2s. +func NewPackageHash() *PackageHash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen with a nil key passed to blake2s. + panic(err) + } + return &PackageHash{Hash: h} +} + +func (ph *PackageHash) Write(b []byte) (int, error) { + ph.len += int64(len(b)) + return ph.Hash.Write(b) +} + +// Reset the PackageHash to its initial state. +func (ph *PackageHash) Reset() { + ph.len = 0 + ph.Hash.Reset() +} + +// Len returns the total number of bytes written. +func (ph *PackageHash) Len() int64 { return ph.len } + +// Client downloads and validates files from a distribution server. +type Client struct { + logf logger.Logf + roots []ed25519.PublicKey + pkgsAddr *url.URL +} + +// NewClient returns a new client for distribution server located at pkgsAddr, +// and uses embedded root keys from the roots/ subdirectory of this package. +func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { + if logf == nil { + logf = log.Printf + } + u, err := url.Parse(pkgsAddr) + if err != nil { + return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) + } + return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil +} + +func (c *Client) url(path string) string { + return c.pkgsAddr.JoinPath(path).String() +} + +// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. +// The file is downloaded to dstPath and its signature is validated using the +// embedded root keys. Download returns an error if anything goes wrong with +// the actual file download or with signature validation. +func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcPath) + sigURL := srcURL + ".sig" + + c.logf("Downloading %q", srcURL) + dstPathUnverified := dstPath + ".unverified" + hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) + if err != nil { + return err + } + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return err + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + if !VerifyAny(sigPub, msg, sig) { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) + } + c.logf("Signature OK") + + if err := os.Rename(dstPathUnverified, dstPath); err != nil { + return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) + } + + return nil +} + +// ValidateLocalBinary fetches the latest signature associated with the binary +// at srcURLPath and uses it to validate the file located on disk via +// localFilePath. ValidateLocalBinary returns an error if anything goes wrong +// with the signature download or with signature validation. +func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcURLPath) + sigURL := srcURL + ".sig" + + localFile, err := os.Open(localFilePath) + if err != nil { + return err + } + defer localFile.Close() + + h := NewPackageHash() + _, err = io.Copy(h, localFile) + if err != nil { + return err + } + hash, hashLen := h.Sum(nil), h.Len() + + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return err + } + + msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) + if !VerifyAny(sigPub, msg, sig) { + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) + } + c.logf("Signature OK") + + return nil +} + +// signingKeys fetches current signing keys from the server and validates them +// against the roots. Should be called before validation of any downloaded file +// to get the fresh keys. +func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { + keyURL := c.url("distsign.pub") + sigURL := keyURL + ".sig" + raw, err := fetch(keyURL, signingKeysSizeLimit) + if err != nil { + return nil, err + } + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return nil, err + } + if !VerifyAny(c.roots, raw, sig) { + return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) + } + + keys, err := ParseSigningKeyBundle(raw) + if err != nil { + return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) + } + return keys, nil +} + +// fetch reads the response body from url into memory, up to limit bytes. +func fetch(url string, limit int64) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(io.LimitReader(resp.Body, limit)) +} + +// download writes the response body of url into a local file at dst, up to +// limit bytes. On success, the returned value is a BLAKE2s hash of the file. +func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) + + res, err := hc.Do(headReq) + if err != nil { + return nil, 0, err + } + if res.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) + } + if res.ContentLength <= 0 { + return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) + } + c.logf("Download size: %v", res.ContentLength) + + dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) + dlRes, err := hc.Do(dlReq) + if err != nil { + return nil, 0, err + } + defer dlRes.Body.Close() + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) + } + + of, err := os.Create(dst) + if err != nil { + return nil, 0, err + } + defer of.Close() + pw := &progressWriter{total: res.ContentLength, logf: c.logf} + h := NewPackageHash() + n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) + if err != nil { + return nil, n, err + } + if n != res.ContentLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) + } + if err := dlRes.Body.Close(); err != nil { + return nil, n, err + } + if err := of.Close(); err != nil { + return nil, n, err + } + pw.print() + + return h.Sum(nil), h.Len(), nil +} + +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time + logf logger.Logf +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second { + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) +} + +func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PrivateKeySize { + return nil, errors.New("private key has incorrect length for an Ed25519 private key") + } + return ed25519.PrivateKey(b.Bytes), nil +} + +// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. +func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeSigningPublic) +} + +// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. +func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeRootPublic) +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { + var keys []ed25519.PublicKey + for len(bundle) > 0 { + pub, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, pub) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no signing keys found in the bundle") + } + return keys, nil +} + +func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { + pub, rest, err := parsePublicKey(data, typeTag) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + return pub, err +} + +func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PublicKeySize { + return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") + } + return ed25519.PublicKey(b.Bytes), rest, nil +} + +// VerifyAny verifies whether sig is valid for msg using any of the keys. +// VerifyAny will panic if any of the keys have the wrong size for Ed25519. +func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { + for _, k := range keys { + if ed25519consensus.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/clientupdate/distsign/roots.go b/clientupdate/distsign/roots.go index df86557979ecd..d5b47b7b62e92 100644 --- a/clientupdate/distsign/roots.go +++ b/clientupdate/distsign/roots.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import ( - "crypto/ed25519" - "embed" - "errors" - "fmt" - "path" - "path/filepath" - "sync" -) - -//go:embed roots -var rootsFS embed.FS - -var roots = sync.OnceValue(func() []ed25519.PublicKey { - roots, err := parseRoots() - if err != nil { - panic(err) - } - return roots -}) - -func parseRoots() ([]ed25519.PublicKey, error) { - files, err := rootsFS.ReadDir("roots") - if err != nil { - return nil, err - } - var keys []ed25519.PublicKey - for _, f := range files { - if !f.Type().IsRegular() { - continue - } - if filepath.Ext(f.Name()) != ".pem" { - continue - } - raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) - if err != nil { - return nil, err - } - key, err := parseSinglePublicKey(raw, pemTypeRootPublic) - if err != nil { - return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) - } - keys = append(keys, key) - } - if len(keys) == 0 { - return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") - } - return keys, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import ( + "crypto/ed25519" + "embed" + "errors" + "fmt" + "path" + "path/filepath" + "sync" +) + +//go:embed roots +var rootsFS embed.FS + +var roots = sync.OnceValue(func() []ed25519.PublicKey { + roots, err := parseRoots() + if err != nil { + panic(err) + } + return roots +}) + +func parseRoots() ([]ed25519.PublicKey, error) { + files, err := rootsFS.ReadDir("roots") + if err != nil { + return nil, err + } + var keys []ed25519.PublicKey + for _, f := range files { + if !f.Type().IsRegular() { + continue + } + if filepath.Ext(f.Name()) != ".pem" { + continue + } + raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) + if err != nil { + return nil, err + } + key, err := parseSinglePublicKey(raw, pemTypeRootPublic) + if err != nil { + return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) + } + keys = append(keys, key) + } + if len(keys) == 0 { + return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") + } + return keys, nil +} diff --git a/clientupdate/distsign/roots/crawshaw-root.pem b/clientupdate/distsign/roots/crawshaw-root.pem index 897a38295b6b0..f80b9aec78b11 100755 --- a/clientupdate/distsign/roots/crawshaw-root.pem +++ b/clientupdate/distsign/roots/crawshaw-root.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem index e2f937ed3b0d1..d5d6516ab0368 100644 --- a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem +++ b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots_test.go b/clientupdate/distsign/roots_test.go index ae0dfbc22d5bd..7a94529538ef1 100644 --- a/clientupdate/distsign/roots_test.go +++ b/clientupdate/distsign/roots_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import "testing" - -func TestParseRoots(t *testing.T) { - roots, err := parseRoots() - if err != nil { - t.Fatal(err) - } - if len(roots) == 0 { - t.Error("parseRoots returned no root keys") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import "testing" + +func TestParseRoots(t *testing.T) { + roots, err := parseRoots() + if err != nil { + t.Fatal(err) + } + if len(roots) == 0 { + t.Error("parseRoots returned no root keys") + } +} diff --git a/cmd/addlicense/main.go b/cmd/addlicense/main.go index 58ef7a4711c93..a8fd9dd4ab96a 100644 --- a/cmd/addlicense/main.go +++ b/cmd/addlicense/main.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program addlicense adds a license header to a file. -// It is intended for use with 'go generate', -// so it has a slightly weird usage. -package main - -import ( - "flag" - "fmt" - "os" - "os/exec" -) - -var ( - file = flag.String("file", "", "file to modify") -) - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: addlicense -file FILE -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` -addlicense adds a Tailscale license to the beginning of file. - -It is intended for use with 'go generate', so it also runs a subcommand, -which presumably creates the file. - -Sample usage: - -addlicense -file pull_strings.go stringer -type=pull -`[1:]) - os.Exit(2) -} - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - } - cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err := cmd.Run() - check(err) - b, err := os.ReadFile(*file) - check(err) - f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) - check(err) - _, err = fmt.Fprint(f, license) - check(err) - _, err = f.Write(b) - check(err) - err = f.Close() - check(err) -} - -func check(err error) { - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} - -var license = ` -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -`[1:] +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program addlicense adds a license header to a file. +// It is intended for use with 'go generate', +// so it has a slightly weird usage. +package main + +import ( + "flag" + "fmt" + "os" + "os/exec" +) + +var ( + file = flag.String("file", "", "file to modify") +) + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: addlicense -file FILE +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` +addlicense adds a Tailscale license to the beginning of file. + +It is intended for use with 'go generate', so it also runs a subcommand, +which presumably creates the file. + +Sample usage: + +addlicense -file pull_strings.go stringer -type=pull +`[1:]) + os.Exit(2) +} + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + } + cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + check(err) + b, err := os.ReadFile(*file) + check(err) + f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) + check(err) + _, err = fmt.Fprint(f, license) + check(err) + _, err = f.Write(b) + check(err) + err = f.Close() + check(err) +} + +func check(err error) { + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +var license = ` +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +`[1:] diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index 83d33ab0e615b..d8d5df3cb040c 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "reflect" - "testing" - - "tailscale.com/cmd/cloner/clonerex" -) - -func TestSliceContainer(t *testing.T) { - num := 5 - examples := []struct { - name string - in *clonerex.SliceContainer - }{ - { - name: "nil", - in: nil, - }, - { - name: "zero", - in: &clonerex.SliceContainer{}, - }, - { - name: "empty", - in: &clonerex.SliceContainer{ - Slice: []*int{}, - }, - }, - { - name: "nils", - in: &clonerex.SliceContainer{ - Slice: []*int{nil, nil, nil, nil, nil}, - }, - }, - { - name: "one", - in: &clonerex.SliceContainer{ - Slice: []*int{&num}, - }, - }, - { - name: "several", - in: &clonerex.SliceContainer{ - Slice: []*int{&num, &num, &num, &num, &num}, - }, - }, - } - - for _, ex := range examples { - t.Run(ex.name, func(t *testing.T) { - out := ex.in.Clone() - if !reflect.DeepEqual(ex.in, out) { - t.Errorf("Clone() = %v, want %v", out, ex.in) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "reflect" + "testing" + + "tailscale.com/cmd/cloner/clonerex" +) + +func TestSliceContainer(t *testing.T) { + num := 5 + examples := []struct { + name string + in *clonerex.SliceContainer + }{ + { + name: "nil", + in: nil, + }, + { + name: "zero", + in: &clonerex.SliceContainer{}, + }, + { + name: "empty", + in: &clonerex.SliceContainer{ + Slice: []*int{}, + }, + }, + { + name: "nils", + in: &clonerex.SliceContainer{ + Slice: []*int{nil, nil, nil, nil, nil}, + }, + }, + { + name: "one", + in: &clonerex.SliceContainer{ + Slice: []*int{&num}, + }, + }, + { + name: "several", + in: &clonerex.SliceContainer{ + Slice: []*int{&num, &num, &num, &num, &num}, + }, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + out := ex.in.Clone() + if !reflect.DeepEqual(ex.in, out) { + t.Errorf("Clone() = %v, want %v", out, ex.in) + } + }) + } +} diff --git a/cmd/containerboot/test_tailscale.sh b/cmd/containerboot/test_tailscale.sh index dd56adf044bd4..1fa10abb18185 100644 --- a/cmd/containerboot/test_tailscale.sh +++ b/cmd/containerboot/test_tailscale.sh @@ -1,8 +1,8 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale CLI (and also iptables and ip6tables) that -# records its arguments and exits successfully. -# -# It is used by main_test.go to test the behavior of containerboot. - -echo $0 $@ >>$TS_TEST_RECORD_ARGS +#!/usr/bin/env bash +# +# This is a fake tailscale CLI (and also iptables and ip6tables) that +# records its arguments and exits successfully. +# +# It is used by main_test.go to test the behavior of containerboot. + +echo $0 $@ >>$TS_TEST_RECORD_ARGS diff --git a/cmd/containerboot/test_tailscaled.sh b/cmd/containerboot/test_tailscaled.sh index b7404a0a9d368..335e2cb0dcfd1 100644 --- a/cmd/containerboot/test_tailscaled.sh +++ b/cmd/containerboot/test_tailscaled.sh @@ -1,38 +1,38 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale daemon that records its arguments, symlinks a -# fake LocalAPI socket into place, and does nothing until terminated. -# -# It is used by main_test.go to test the behavior of containerboot. - -set -eu - -echo $0 $@ >>$TS_TEST_RECORD_ARGS - -socket="" -while [[ $# -gt 0 ]]; do - case $1 in - --socket=*) - socket="${1#--socket=}" - shift - ;; - --socket) - shift - socket="$1" - shift - ;; - *) - shift - ;; - esac -done - -if [[ -z "$socket" ]]; then - echo "didn't find socket path in args" - exit 1 -fi - -ln -s "$TS_TEST_SOCKET" "$socket" -trap 'rm -f "$socket"' EXIT - -while sleep 10; do :; done +#!/usr/bin/env bash +# +# This is a fake tailscale daemon that records its arguments, symlinks a +# fake LocalAPI socket into place, and does nothing until terminated. +# +# It is used by main_test.go to test the behavior of containerboot. + +set -eu + +echo $0 $@ >>$TS_TEST_RECORD_ARGS + +socket="" +while [[ $# -gt 0 ]]; do + case $1 in + --socket=*) + socket="${1#--socket=}" + shift + ;; + --socket) + shift + socket="$1" + shift + ;; + *) + shift + ;; + esac +done + +if [[ -z "$socket" ]]; then + echo "didn't find socket path in args" + exit 1 +fi + +ln -s "$TS_TEST_SOCKET" "$socket" +trap 'rm -f "$socket"' EXIT + +while sleep 10; do :; done diff --git a/cmd/get-authkey/.gitignore b/cmd/get-authkey/.gitignore index e00856fa12524..3f9c9fb90e68e 100644 --- a/cmd/get-authkey/.gitignore +++ b/cmd/get-authkey/.gitignore @@ -1 +1 @@ -get-authkey +get-authkey diff --git a/cmd/gitops-pusher/.gitignore b/cmd/gitops-pusher/.gitignore index eeed6e4bf5b1a..5044522494b23 100644 --- a/cmd/gitops-pusher/.gitignore +++ b/cmd/gitops-pusher/.gitignore @@ -1 +1 @@ -version-cache.json +version-cache.json diff --git a/cmd/gitops-pusher/README.md b/cmd/gitops-pusher/README.md index b08125397a1ec..9f77ea970e033 100644 --- a/cmd/gitops-pusher/README.md +++ b/cmd/gitops-pusher/README.md @@ -1,48 +1,48 @@ -# gitops-pusher - -This is a small tool to help people achieve a -[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL -changes. This tool is intended to be used in a CI flow that looks like this: - -```yaml -name: Tailscale ACL syncing - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -jobs: - acls: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Setup Go environment - uses: actions/setup-go@v3.2.0 - - - name: Install gitops-pusher - run: go install tailscale.com/cmd/gitops-pusher@latest - - - name: Deploy ACL - if: github.event_name == 'push' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply - - - name: ACL tests - if: github.event_name == 'pull_request' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson test -``` - -Change the value of the `--policy-file` flag to point to the policy file on -disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) -format. +# gitops-pusher + +This is a small tool to help people achieve a +[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL +changes. This tool is intended to be used in a CI flow that looks like this: + +```yaml +name: Tailscale ACL syncing + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + acls: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Setup Go environment + uses: actions/setup-go@v3.2.0 + + - name: Install gitops-pusher + run: go install tailscale.com/cmd/gitops-pusher@latest + + - name: Deploy ACL + if: github.event_name == 'push' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply + + - name: ACL tests + if: github.event_name == 'pull_request' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson test +``` + +Change the value of the `--policy-file` flag to point to the policy file on +disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) +format. diff --git a/cmd/gitops-pusher/cache.go b/cmd/gitops-pusher/cache.go index 89225e6f86309..6792e5e63e9cc 100644 --- a/cmd/gitops-pusher/cache.go +++ b/cmd/gitops-pusher/cache.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/json" - "os" -) - -// Cache contains cached information about the last time this tool was run. -// -// This is serialized to a JSON file that should NOT be checked into git. -// It should be managed with either CI cache tools or stored locally somehow. The -// exact mechanism is irrelevant as long as it is consistent. -// -// This allows gitops-pusher to detect external ACL changes. I'm not sure what to -// call this problem, so I've been calling it the "three version problem" in my -// notes. The basic problem is that at any given time we only have two versions -// of the ACL file at any given point. In order to check if there has been -// tampering of the ACL files in the admin panel, we need to have a _third_ version -// to compare against. -// -// In this case I am not storing the old ACL entirely (though that could be a -// reasonable thing to add in the future), but only its sha256sum. This allows -// us to detect if the shasum in control matches the shasum we expect, and if that -// expectation fails, then we can react accordingly. -type Cache struct { - PrevETag string // Stores the previous ETag of the ACL to allow -} - -// Save persists the cache to a given file. -func (c *Cache) Save(fname string) error { - os.Remove(fname) - fout, err := os.Create(fname) - if err != nil { - return err - } - defer fout.Close() - - return json.NewEncoder(fout).Encode(c) -} - -// LoadCache loads the cache from a given file. -func LoadCache(fname string) (*Cache, error) { - var result Cache - - fin, err := os.Open(fname) - if err != nil { - return nil, err - } - defer fin.Close() - - err = json.NewDecoder(fin).Decode(&result) - if err != nil { - return nil, err - } - - return &result, nil -} - -// Shuck removes the first and last character of a string, analogous to -// shucking off the husk of an ear of corn. -func Shuck(s string) string { - return s[1 : len(s)-1] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/json" + "os" +) + +// Cache contains cached information about the last time this tool was run. +// +// This is serialized to a JSON file that should NOT be checked into git. +// It should be managed with either CI cache tools or stored locally somehow. The +// exact mechanism is irrelevant as long as it is consistent. +// +// This allows gitops-pusher to detect external ACL changes. I'm not sure what to +// call this problem, so I've been calling it the "three version problem" in my +// notes. The basic problem is that at any given time we only have two versions +// of the ACL file at any given point. In order to check if there has been +// tampering of the ACL files in the admin panel, we need to have a _third_ version +// to compare against. +// +// In this case I am not storing the old ACL entirely (though that could be a +// reasonable thing to add in the future), but only its sha256sum. This allows +// us to detect if the shasum in control matches the shasum we expect, and if that +// expectation fails, then we can react accordingly. +type Cache struct { + PrevETag string // Stores the previous ETag of the ACL to allow +} + +// Save persists the cache to a given file. +func (c *Cache) Save(fname string) error { + os.Remove(fname) + fout, err := os.Create(fname) + if err != nil { + return err + } + defer fout.Close() + + return json.NewEncoder(fout).Encode(c) +} + +// LoadCache loads the cache from a given file. +func LoadCache(fname string) (*Cache, error) { + var result Cache + + fin, err := os.Open(fname) + if err != nil { + return nil, err + } + defer fin.Close() + + err = json.NewDecoder(fin).Decode(&result) + if err != nil { + return nil, err + } + + return &result, nil +} + +// Shuck removes the first and last character of a string, analogous to +// shucking off the husk of an ear of corn. +func Shuck(s string) string { + return s[1 : len(s)-1] +} diff --git a/cmd/gitops-pusher/gitops-pusher_test.go b/cmd/gitops-pusher/gitops-pusher_test.go index 1beb049c67d5a..b050761d9832d 100644 --- a/cmd/gitops-pusher/gitops-pusher_test.go +++ b/cmd/gitops-pusher/gitops-pusher_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "encoding/json" - "strings" - "testing" - - "tailscale.com/client/tailscale" -) - -func TestEmbeddedTypeUnmarshal(t *testing.T) { - var gitopsErr ACLGitopsTestError - gitopsErr.Message = "gitops response error" - gitopsErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "GitopsError", - Errors: []string{"this was initially created as a gitops error"}, - }, - } - - var aclTestErr tailscale.ACLTestError - aclTestErr.Message = "native ACL response error" - aclTestErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "ACLError", - Errors: []string{"this was initially created as an ACL error"}, - }, - } - - t.Run("unmarshal gitops type from acl type", func(t *testing.T) { - b, _ := json.Marshal(aclTestErr) - var e ACLGitopsTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't - t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) - } - }) - t.Run("unmarshal acl type from gitops type", func(t *testing.T) { - b, _ := json.Marshal(gitopsErr) - var e tailscale.ACLTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` - if e.Error() != expectedErr { - t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "encoding/json" + "strings" + "testing" + + "tailscale.com/client/tailscale" +) + +func TestEmbeddedTypeUnmarshal(t *testing.T) { + var gitopsErr ACLGitopsTestError + gitopsErr.Message = "gitops response error" + gitopsErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "GitopsError", + Errors: []string{"this was initially created as a gitops error"}, + }, + } + + var aclTestErr tailscale.ACLTestError + aclTestErr.Message = "native ACL response error" + aclTestErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "ACLError", + Errors: []string{"this was initially created as an ACL error"}, + }, + } + + t.Run("unmarshal gitops type from acl type", func(t *testing.T) { + b, _ := json.Marshal(aclTestErr) + var e ACLGitopsTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't + t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) + } + }) + t.Run("unmarshal acl type from gitops type", func(t *testing.T) { + b, _ := json.Marshal(gitopsErr) + var e tailscale.ACLTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` + if e.Error() != expectedErr { + t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) + } + }) +} diff --git a/cmd/k8s-operator/deploy/chart/.helmignore b/cmd/k8s-operator/deploy/chart/.helmignore index f82e96d46779c..0e8a0eb36f4ca 100644 --- a/cmd/k8s-operator/deploy/chart/.helmignore +++ b/cmd/k8s-operator/deploy/chart/.helmignore @@ -1,23 +1,23 @@ -# Patterns to ignore when building packages. -# This supports shell glob matching, relative path matching, and -# negation (prefixed with !). Only one pattern per line. -.DS_Store -# Common VCS dirs -.git/ -.gitignore -.bzr/ -.bzrignore -.hg/ -.hgignore -.svn/ -# Common backup files -*.swp -*.bak -*.tmp -*.orig -*~ -# Various IDEs -.project -.idea/ -*.tmproj -.vscode/ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/cmd/k8s-operator/deploy/chart/Chart.yaml b/cmd/k8s-operator/deploy/chart/Chart.yaml index 472850c415200..363d87d15954a 100644 --- a/cmd/k8s-operator/deploy/chart/Chart.yaml +++ b/cmd/k8s-operator/deploy/chart/Chart.yaml @@ -1,29 +1,29 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: v2 -name: tailscale-operator -description: A Helm chart for Tailscale Kubernetes operator -home: https://github.com/tailscale/tailscale - -keywords: - - "tailscale" - - "vpn" - - "ingress" - - "egress" - - "wireguard" - -sources: -- https://github.com/tailscale/tailscale - -type: application - -maintainers: - - name: tailscale-maintainers - url: https://tailscale.com/ - -# version will be set to Tailscale repo tag (without 'v') at release time. -version: 0.1.0 - -# appVersion will be set to Tailscale repo tag at release time. -appVersion: "unstable" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: v2 +name: tailscale-operator +description: A Helm chart for Tailscale Kubernetes operator +home: https://github.com/tailscale/tailscale + +keywords: + - "tailscale" + - "vpn" + - "ingress" + - "egress" + - "wireguard" + +sources: +- https://github.com/tailscale/tailscale + +type: application + +maintainers: + - name: tailscale-maintainers + url: https://tailscale.com/ + +# version will be set to Tailscale repo tag (without 'v') at release time. +version: 0.1.0 + +# appVersion will be set to Tailscale repo tag at release time. +appVersion: "unstable" diff --git a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml index 488c87d8a09c5..072ecf6d22e2f 100644 --- a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml @@ -1,26 +1,26 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if eq .Values.apiServerProxyConfig.mode "true" }} -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: {{ .Release.Namespace }} -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy - apiGroup: rbac.authorization.k8s.io -{{ end }} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if eq .Values.apiServerProxyConfig.mode "true" }} +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: {{ .Release.Namespace }} +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy + apiGroup: rbac.authorization.k8s.io +{{ end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index bde64b7f625eb..b44fde0a17b49 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if and .Values.oauth .Values.oauth.clientId -}} -apiVersion: v1 -kind: Secret -metadata: - name: operator-oauth - namespace: {{ .Release.Namespace }} -stringData: - client_id: {{ .Values.oauth.clientId }} - client_secret: {{ .Values.oauth.clientSecret }} -{{- end -}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if and .Values.oauth .Values.oauth.clientId -}} +apiVersion: v1 +kind: Secret +metadata: + name: operator-oauth + namespace: {{ .Release.Namespace }} +stringData: + client_id: {{ .Values.oauth.clientId }} + client_secret: {{ .Values.oauth.clientSecret }} +{{- end -}} diff --git a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml index d957260eb513f..ddbdda32e476e 100644 --- a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml @@ -1,24 +1,24 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: tailscale -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: tailscale +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/cmd/mkmanifest/main.go b/cmd/mkmanifest/main.go index 22cd150262cbb..fb3c729f12d21 100644 --- a/cmd/mkmanifest/main.go +++ b/cmd/mkmanifest/main.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The mkmanifest command is a simple helper utility to create a '.syso' file -// that contains a Windows manifest file. -package main - -import ( - "log" - "os" - - "github.com/tc-hib/winres" -) - -func main() { - if len(os.Args) != 4 { - log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) - } - - arch := winres.Arch(os.Args[1]) - switch arch { - case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: - default: - log.Fatalf("unsupported arch: %s", arch) - } - - manifest, err := os.ReadFile(os.Args[2]) - if err != nil { - log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) - } - - out := os.Args[3] - - // Start by creating an empty resource set - rs := winres.ResourceSet{} - - // Add resources - rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) - - // Compile to a COFF object file - f, err := os.Create(out) - if err != nil { - log.Fatalf("error creating output file %q: %v", out, err) - } - if err := rs.WriteObject(f, arch); err != nil { - log.Fatalf("error writing object: %v", err) - } - if err := f.Close(); err != nil { - log.Fatalf("error writing output file %q: %v", out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The mkmanifest command is a simple helper utility to create a '.syso' file +// that contains a Windows manifest file. +package main + +import ( + "log" + "os" + + "github.com/tc-hib/winres" +) + +func main() { + if len(os.Args) != 4 { + log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) + } + + arch := winres.Arch(os.Args[1]) + switch arch { + case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: + default: + log.Fatalf("unsupported arch: %s", arch) + } + + manifest, err := os.ReadFile(os.Args[2]) + if err != nil { + log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) + } + + out := os.Args[3] + + // Start by creating an empty resource set + rs := winres.ResourceSet{} + + // Add resources + rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) + + // Compile to a COFF object file + f, err := os.Create(out) + if err != nil { + log.Fatalf("error creating output file %q: %v", out, err) + } + if err := rs.WriteObject(f, arch); err != nil { + log.Fatalf("error writing object: %v", err) + } + if err := f.Close(); err != nil { + log.Fatalf("error writing output file %q: %v", out, err) + } +} diff --git a/cmd/mkpkg/main.go b/cmd/mkpkg/main.go index e942c0162a4fd..5e26b07f8f9f8 100644 --- a/cmd/mkpkg/main.go +++ b/cmd/mkpkg/main.go @@ -1,134 +1,134 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkpkg builds the Tailscale rpm and deb packages. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "strings" - - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" - "github.com/goreleaser/nfpm/v2/files" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -// parseFiles parses a comma-separated list of colon-separated pairs -// into files.Contents format. -func parseFiles(s string, typ string) (files.Contents, error) { - if len(s) == 0 { - return nil, nil - } - var contents files.Contents - for _, f := range strings.Split(s, ",") { - fs := strings.Split(f, ":") - if len(fs) != 2 { - return nil, fmt.Errorf("unparseable file field %q", f) - } - contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) - } - return contents, nil -} - -func parseEmptyDirs(s string) files.Contents { - // strings.Split("", ",") would return []string{""}, which is not suitable: - // this would create an empty dir record with path "", breaking the package - if s == "" { - return nil - } - var contents files.Contents - for _, d := range strings.Split(s, ",") { - contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) - } - return contents -} - -func main() { - out := flag.String("out", "", "output file to write") - name := flag.String("name", "tailscale", "package name") - description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") - goarch := flag.String("arch", "amd64", "GOARCH this package is for") - pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") - regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") - configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") - emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") - version := flag.String("version", "0.0.0", "version of the package") - postinst := flag.String("postinst", "", "debian postinst script path") - prerm := flag.String("prerm", "", "debian prerm script path") - postrm := flag.String("postrm", "", "debian postrm script path") - replaces := flag.String("replaces", "", "package which this package replaces, if any") - depends := flag.String("depends", "", "comma-separated list of packages this package depends on") - recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") - flag.Parse() - - filesList, err := parseFiles(*regularFiles, files.TypeFile) - if err != nil { - log.Fatalf("Parsing --files: %v", err) - } - configsList, err := parseFiles(*configFiles, files.TypeConfig) - if err != nil { - log.Fatalf("Parsing --configs: %v", err) - } - emptyDirList := parseEmptyDirs(*emptyDirs) - contents := append(filesList, append(configsList, emptyDirList...)...) - contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) - if err != nil { - log.Fatalf("Building package contents: %v", err) - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: *name, - Arch: *goarch, - Platform: "linux", - Version: *version, - Maintainer: "Tailscale Inc ", - Description: *description, - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: *postinst, - PreRemove: *prerm, - PostRemove: *postrm, - }, - }, - }) - - if len(*depends) != 0 { - info.Overridables.Depends = strings.Split(*depends, ",") - } - if len(*recommends) != 0 { - info.Overridables.Recommends = strings.Split(*recommends, ",") - } - if *replaces != "" { - info.Overridables.Replaces = []string{*replaces} - info.Overridables.Conflicts = []string{*replaces} - } - - switch *pkgType { - case "deb": - info.Section = "net" - info.Priority = "extra" - case "rpm": - info.Overridables.RPM.Group = "Network" - } - - pkg, err := nfpm.Get(*pkgType) - if err != nil { - log.Fatalf("Getting packager for %q: %v", *pkgType, err) - } - - f, err := os.Create(*out) - if err != nil { - log.Fatalf("Creating output file %q: %v", *out, err) - } - defer f.Close() - - if err := pkg.Package(info, f); err != nil { - log.Fatalf("Creating package %q: %v", *out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkpkg builds the Tailscale rpm and deb packages. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" + "github.com/goreleaser/nfpm/v2/files" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +// parseFiles parses a comma-separated list of colon-separated pairs +// into files.Contents format. +func parseFiles(s string, typ string) (files.Contents, error) { + if len(s) == 0 { + return nil, nil + } + var contents files.Contents + for _, f := range strings.Split(s, ",") { + fs := strings.Split(f, ":") + if len(fs) != 2 { + return nil, fmt.Errorf("unparseable file field %q", f) + } + contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) + } + return contents, nil +} + +func parseEmptyDirs(s string) files.Contents { + // strings.Split("", ",") would return []string{""}, which is not suitable: + // this would create an empty dir record with path "", breaking the package + if s == "" { + return nil + } + var contents files.Contents + for _, d := range strings.Split(s, ",") { + contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) + } + return contents +} + +func main() { + out := flag.String("out", "", "output file to write") + name := flag.String("name", "tailscale", "package name") + description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") + goarch := flag.String("arch", "amd64", "GOARCH this package is for") + pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") + regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") + configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") + emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") + version := flag.String("version", "0.0.0", "version of the package") + postinst := flag.String("postinst", "", "debian postinst script path") + prerm := flag.String("prerm", "", "debian prerm script path") + postrm := flag.String("postrm", "", "debian postrm script path") + replaces := flag.String("replaces", "", "package which this package replaces, if any") + depends := flag.String("depends", "", "comma-separated list of packages this package depends on") + recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") + flag.Parse() + + filesList, err := parseFiles(*regularFiles, files.TypeFile) + if err != nil { + log.Fatalf("Parsing --files: %v", err) + } + configsList, err := parseFiles(*configFiles, files.TypeConfig) + if err != nil { + log.Fatalf("Parsing --configs: %v", err) + } + emptyDirList := parseEmptyDirs(*emptyDirs) + contents := append(filesList, append(configsList, emptyDirList...)...) + contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) + if err != nil { + log.Fatalf("Building package contents: %v", err) + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: *name, + Arch: *goarch, + Platform: "linux", + Version: *version, + Maintainer: "Tailscale Inc ", + Description: *description, + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: *postinst, + PreRemove: *prerm, + PostRemove: *postrm, + }, + }, + }) + + if len(*depends) != 0 { + info.Overridables.Depends = strings.Split(*depends, ",") + } + if len(*recommends) != 0 { + info.Overridables.Recommends = strings.Split(*recommends, ",") + } + if *replaces != "" { + info.Overridables.Replaces = []string{*replaces} + info.Overridables.Conflicts = []string{*replaces} + } + + switch *pkgType { + case "deb": + info.Section = "net" + info.Priority = "extra" + case "rpm": + info.Overridables.RPM.Group = "Network" + } + + pkg, err := nfpm.Get(*pkgType) + if err != nil { + log.Fatalf("Getting packager for %q: %v", *pkgType, err) + } + + f, err := os.Create(*out) + if err != nil { + log.Fatalf("Creating output file %q: %v", *out, err) + } + defer f.Close() + + if err := pkg.Package(info, f); err != nil { + log.Fatalf("Creating package %q: %v", *out, err) + } +} diff --git a/cmd/mkversion/mkversion.go b/cmd/mkversion/mkversion.go index 6a6a18a50d090..c8c8bf17930f6 100644 --- a/cmd/mkversion/mkversion.go +++ b/cmd/mkversion/mkversion.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkversion gets version info from git and outputs a bunch of shell variables -// that get used elsewhere in the build system to embed version numbers into -// binaries. -package main - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/version/mkversion" -) - -func main() { - prefix := "" - if len(os.Args) > 1 { - if os.Args[1] == "--export" { - prefix = "export " - } else { - fmt.Println("usage: mkversion [--export|-h|--help]") - os.Exit(1) - } - } - - var b bytes.Buffer - io.WriteString(&b, mkversion.Info().String()) - // Copyright and the client capability are not part of the version - // information, but similarly used in Xcode builds to embed in the metadata, - // thus generate them now. - copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) - fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) - fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) - s := bufio.NewScanner(&b) - for s.Scan() { - fmt.Println(prefix + s.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkversion gets version info from git and outputs a bunch of shell variables +// that get used elsewhere in the build system to embed version numbers into +// binaries. +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/version/mkversion" +) + +func main() { + prefix := "" + if len(os.Args) > 1 { + if os.Args[1] == "--export" { + prefix = "export " + } else { + fmt.Println("usage: mkversion [--export|-h|--help]") + os.Exit(1) + } + } + + var b bytes.Buffer + io.WriteString(&b, mkversion.Info().String()) + // Copyright and the client capability are not part of the version + // information, but similarly used in Xcode builds to embed in the metadata, + // thus generate them now. + copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) + fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) + fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) + s := bufio.NewScanner(&b) + for s.Scan() { + fmt.Println(prefix + s.Text()) + } +} diff --git a/cmd/nardump/README.md b/cmd/nardump/README.md index 6c73ff9b0f399..6fa7fc2f1d345 100644 --- a/cmd/nardump/README.md +++ b/cmd/nardump/README.md @@ -1,7 +1,7 @@ -# nardump - -nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, -but focused on being reproducible) to stdout or to a hash with the --sri flag. - -It lets us calculate the Nix sha256 in shell.nix without the person running -git-pull-oss.sh having Nix available. +# nardump + +nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, +but focused on being reproducible) to stdout or to a hash with the --sri flag. + +It lets us calculate the Nix sha256 in shell.nix without the person running +git-pull-oss.sh having Nix available. diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index 241475537c418..05be7b65a7e37 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// nardump is like nix-store --dump, but in Go, writing a NAR -// file (tar-like, but focused on being reproducible) to stdout -// or to a hash with the --sri flag. -// -// It lets us calculate a Nix sha256 without the person running -// git-pull-oss.sh having Nix available. -package main - -// For the format, see: -// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 - -import ( - "bufio" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "flag" - "fmt" - "io" - "io/fs" - "log" - "os" - "path" - "sort" -) - -var sri = flag.Bool("sri", false, "print SRI") - -func main() { - flag.Parse() - if flag.NArg() != 1 { - log.Fatal("usage: nardump ") - } - arg := flag.Arg(0) - if err := os.Chdir(arg); err != nil { - log.Fatal(err) - } - if *sri { - hash := sha256.New() - if err := writeNAR(hash, os.DirFS(".")); err != nil { - log.Fatal(err) - } - fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) - return - } - bw := bufio.NewWriter(os.Stdout) - if err := writeNAR(bw, os.DirFS(".")); err != nil { - log.Fatal(err) - } - bw.Flush() -} - -// writeNARError is a sentinel panic type that's recovered by writeNAR -// and converted into the wrapped error. -type writeNARError struct{ err error } - -// narWriter writes NAR files. -type narWriter struct { - w io.Writer - fs fs.FS -} - -// writeNAR writes a NAR file to w from the root of fs. -func writeNAR(w io.Writer, fs fs.FS) (err error) { - defer func() { - if e := recover(); e != nil { - if we, ok := e.(writeNARError); ok { - err = we.err - return - } - panic(e) - } - }() - nw := &narWriter{w: w, fs: fs} - nw.str("nix-archive-1") - return nw.writeDir(".") -} - -func (nw *narWriter) writeDir(dirPath string) error { - ents, err := fs.ReadDir(nw.fs, dirPath) - if err != nil { - return err - } - sort.Slice(ents, func(i, j int) bool { - return ents[i].Name() < ents[j].Name() - }) - nw.str("(") - nw.str("type") - nw.str("directory") - for _, ent := range ents { - nw.str("entry") - nw.str("(") - nw.str("name") - nw.str(ent.Name()) - nw.str("node") - mode := ent.Type() - sub := path.Join(dirPath, ent.Name()) - var err error - switch { - case mode.IsRegular(): - err = nw.writeRegular(sub) - case mode.IsDir(): - err = nw.writeDir(sub) - default: - // TODO(bradfitz): symlink, but requires fighting io/fs a bit - // to get at Readlink or the osFS via fs. But for now - // we don't need symlinks because they're not in Go's archive. - return fmt.Errorf("unsupported file type %v at %q", sub, mode) - } - if err != nil { - return err - } - nw.str(")") - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeRegular(path string) error { - nw.str("(") - nw.str("type") - nw.str("regular") - fi, err := fs.Stat(nw.fs, path) - if err != nil { - return err - } - if fi.Mode()&0111 != 0 { - nw.str("executable") - nw.str("") - } - contents, err := fs.ReadFile(nw.fs, path) - if err != nil { - return err - } - nw.str("contents") - if err := writeBytes(nw.w, contents); err != nil { - return err - } - nw.str(")") - return nil -} - -func (nw *narWriter) str(s string) { - if err := writeString(nw.w, s); err != nil { - panic(writeNARError{err}) - } -} - -func writeString(w io.Writer, s string) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := io.WriteString(w, s); err != nil { - return err - } - return writePad(w, len(s)) -} - -func writeBytes(w io.Writer, b []byte) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := w.Write(b); err != nil { - return err - } - return writePad(w, len(b)) -} - -func writePad(w io.Writer, n int) error { - pad := n % 8 - if pad == 0 { - return nil - } - var zeroes [8]byte - _, err := w.Write(zeroes[:8-pad]) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// nardump is like nix-store --dump, but in Go, writing a NAR +// file (tar-like, but focused on being reproducible) to stdout +// or to a hash with the --sri flag. +// +// It lets us calculate a Nix sha256 without the person running +// git-pull-oss.sh having Nix available. +package main + +// For the format, see: +// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 + +import ( + "bufio" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "flag" + "fmt" + "io" + "io/fs" + "log" + "os" + "path" + "sort" +) + +var sri = flag.Bool("sri", false, "print SRI") + +func main() { + flag.Parse() + if flag.NArg() != 1 { + log.Fatal("usage: nardump ") + } + arg := flag.Arg(0) + if err := os.Chdir(arg); err != nil { + log.Fatal(err) + } + if *sri { + hash := sha256.New() + if err := writeNAR(hash, os.DirFS(".")); err != nil { + log.Fatal(err) + } + fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) + return + } + bw := bufio.NewWriter(os.Stdout) + if err := writeNAR(bw, os.DirFS(".")); err != nil { + log.Fatal(err) + } + bw.Flush() +} + +// writeNARError is a sentinel panic type that's recovered by writeNAR +// and converted into the wrapped error. +type writeNARError struct{ err error } + +// narWriter writes NAR files. +type narWriter struct { + w io.Writer + fs fs.FS +} + +// writeNAR writes a NAR file to w from the root of fs. +func writeNAR(w io.Writer, fs fs.FS) (err error) { + defer func() { + if e := recover(); e != nil { + if we, ok := e.(writeNARError); ok { + err = we.err + return + } + panic(e) + } + }() + nw := &narWriter{w: w, fs: fs} + nw.str("nix-archive-1") + return nw.writeDir(".") +} + +func (nw *narWriter) writeDir(dirPath string) error { + ents, err := fs.ReadDir(nw.fs, dirPath) + if err != nil { + return err + } + sort.Slice(ents, func(i, j int) bool { + return ents[i].Name() < ents[j].Name() + }) + nw.str("(") + nw.str("type") + nw.str("directory") + for _, ent := range ents { + nw.str("entry") + nw.str("(") + nw.str("name") + nw.str(ent.Name()) + nw.str("node") + mode := ent.Type() + sub := path.Join(dirPath, ent.Name()) + var err error + switch { + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode.IsDir(): + err = nw.writeDir(sub) + default: + // TODO(bradfitz): symlink, but requires fighting io/fs a bit + // to get at Readlink or the osFS via fs. But for now + // we don't need symlinks because they're not in Go's archive. + return fmt.Errorf("unsupported file type %v at %q", sub, mode) + } + if err != nil { + return err + } + nw.str(")") + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeRegular(path string) error { + nw.str("(") + nw.str("type") + nw.str("regular") + fi, err := fs.Stat(nw.fs, path) + if err != nil { + return err + } + if fi.Mode()&0111 != 0 { + nw.str("executable") + nw.str("") + } + contents, err := fs.ReadFile(nw.fs, path) + if err != nil { + return err + } + nw.str("contents") + if err := writeBytes(nw.w, contents); err != nil { + return err + } + nw.str(")") + return nil +} + +func (nw *narWriter) str(s string) { + if err := writeString(nw.w, s); err != nil { + panic(writeNARError{err}) + } +} + +func writeString(w io.Writer, s string) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := io.WriteString(w, s); err != nil { + return err + } + return writePad(w, len(s)) +} + +func writeBytes(w io.Writer, b []byte) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return writePad(w, len(b)) +} + +func writePad(w io.Writer, n int) error { + pad := n % 8 + if pad == 0 { + return nil + } + var zeroes [8]byte + _, err := w.Write(zeroes[:8-pad]) + return err +} diff --git a/cmd/nginx-auth/.gitignore b/cmd/nginx-auth/.gitignore index 255276578b60d..3c608aeb1eede 100644 --- a/cmd/nginx-auth/.gitignore +++ b/cmd/nginx-auth/.gitignore @@ -1,4 +1,4 @@ -nga.sock -*.deb -*.rpm -tailscale.nginx-auth +nga.sock +*.deb +*.rpm +tailscale.nginx-auth diff --git a/cmd/nginx-auth/README.md b/cmd/nginx-auth/README.md index 869b1487bf57b..858f9ab81a83e 100644 --- a/cmd/nginx-auth/README.md +++ b/cmd/nginx-auth/README.md @@ -1,161 +1,161 @@ -# nginx-auth - -[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) - -This is a tool that allows users to use Tailscale Whois authentication with -NGINX as a reverse proxy. This allows users that already have a bunch of -services hosted on an internal NGINX server to point those domains to the -Tailscale IP of the NGINX server and then seamlessly use Tailscale for -authentication. - -Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on -Twitter for introducing the basic idea and offering some sample code. This -program is based on that sample code with security enhancements. Namely: - -* This listens over a UNIX socket instead of a TCP socket, to prevent - leakage to the network -* This uses systemd socket activation so that systemd owns the socket - and can then lock down the service to the bare minimum required to do - its job without having to worry about dropping permissions -* This provides additional information in HTTP response headers that can - be useful for integrating with various services - -## Configuration - -In order to protect a service with this tool, do the following in the respective -`server` block: - -Create an authentication location with the `internal` flag set: - -```nginx -location /auth { - internal; - - proxy_pass http://unix:/run/tailscale.nginx-auth.sock; - proxy_pass_request_body off; - - proxy_set_header Host $http_host; - proxy_set_header Remote-Addr $remote_addr; - proxy_set_header Remote-Port $remote_port; - proxy_set_header Original-URI $request_uri; -} -``` - -Then add the following to the `location /` block: - -``` -auth_request /auth; -auth_request_set $auth_user $upstream_http_tailscale_user; -auth_request_set $auth_name $upstream_http_tailscale_name; -auth_request_set $auth_login $upstream_http_tailscale_login; -auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; -auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; - -proxy_set_header X-Webauth-User "$auth_user"; -proxy_set_header X-Webauth-Name "$auth_name"; -proxy_set_header X-Webauth-Login "$auth_login"; -proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; -proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; -``` - -When this configuration is used with a Go HTTP handler such as this: - -```go -http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - e := json.NewEncoder(w) - e.SetIndent("", " ") - e.Encode(r.Header) -}) -``` - -You will get output like this: - -```json -{ - "Accept": [ - "*/*" - ], - "Connection": [ - "upgrade" - ], - "User-Agent": [ - "curl/7.82.0" - ], - "X-Webauth-Login": [ - "Xe" - ], - "X-Webauth-Name": [ - "Xe Iaso" - ], - "X-Webauth-Profile-Picture": [ - "https://avatars.githubusercontent.com/u/529003?v=4" - ], - "X-Webauth-Tailnet": [ - "cetacean.org.github" - ] - "X-Webauth-User": [ - "Xe@github" - ] -} -``` - -## Headers - -The authentication service provides the following headers to decorate your -proxied requests: - -| Header | Example Value | Description | -| :------ | :-------------- | :---------- | -| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | -| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | -| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | -| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | -| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | - -Most of the time you can set `X-Webauth-User` to the contents of the -`Tailscale-User` header, but some services may not accept a username with an `@` -symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` -header. - -The `Tailscale-Tailnet` header can help you identify which tailnet the session -is coming from. If you are using node sharing, this can help you make sure that -you aren't giving administrative access to people outside your tailnet. - -### Allow Requests From Only One Tailnet - -If you want to prevent node sharing from allowing users to access a service, add -the `Expected-Tailnet` header to your auth request: - -```nginx -location /auth { - # ... - proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; -} -``` - -If a user from a different tailnet tries to use that service, this will return a -generic "forbidden" error page: - -```html - -403 Forbidden - -

403 Forbidden

-
nginx/1.18.0 (Ubuntu)
- - -``` - -You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). - -## Building - -Install `cmd/mkpkg`: - -``` -cd .. && go install ./mkpkg -``` - -Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 -machines (Linux uname flag: `x86_64`). You can add these to your deployment -methods as you see fit. +# nginx-auth + +[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) + +This is a tool that allows users to use Tailscale Whois authentication with +NGINX as a reverse proxy. This allows users that already have a bunch of +services hosted on an internal NGINX server to point those domains to the +Tailscale IP of the NGINX server and then seamlessly use Tailscale for +authentication. + +Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on +Twitter for introducing the basic idea and offering some sample code. This +program is based on that sample code with security enhancements. Namely: + +* This listens over a UNIX socket instead of a TCP socket, to prevent + leakage to the network +* This uses systemd socket activation so that systemd owns the socket + and can then lock down the service to the bare minimum required to do + its job without having to worry about dropping permissions +* This provides additional information in HTTP response headers that can + be useful for integrating with various services + +## Configuration + +In order to protect a service with this tool, do the following in the respective +`server` block: + +Create an authentication location with the `internal` flag set: + +```nginx +location /auth { + internal; + + proxy_pass http://unix:/run/tailscale.nginx-auth.sock; + proxy_pass_request_body off; + + proxy_set_header Host $http_host; + proxy_set_header Remote-Addr $remote_addr; + proxy_set_header Remote-Port $remote_port; + proxy_set_header Original-URI $request_uri; +} +``` + +Then add the following to the `location /` block: + +``` +auth_request /auth; +auth_request_set $auth_user $upstream_http_tailscale_user; +auth_request_set $auth_name $upstream_http_tailscale_name; +auth_request_set $auth_login $upstream_http_tailscale_login; +auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; +auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; + +proxy_set_header X-Webauth-User "$auth_user"; +proxy_set_header X-Webauth-Name "$auth_name"; +proxy_set_header X-Webauth-Login "$auth_login"; +proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; +proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; +``` + +When this configuration is used with a Go HTTP handler such as this: + +```go +http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + e := json.NewEncoder(w) + e.SetIndent("", " ") + e.Encode(r.Header) +}) +``` + +You will get output like this: + +```json +{ + "Accept": [ + "*/*" + ], + "Connection": [ + "upgrade" + ], + "User-Agent": [ + "curl/7.82.0" + ], + "X-Webauth-Login": [ + "Xe" + ], + "X-Webauth-Name": [ + "Xe Iaso" + ], + "X-Webauth-Profile-Picture": [ + "https://avatars.githubusercontent.com/u/529003?v=4" + ], + "X-Webauth-Tailnet": [ + "cetacean.org.github" + ] + "X-Webauth-User": [ + "Xe@github" + ] +} +``` + +## Headers + +The authentication service provides the following headers to decorate your +proxied requests: + +| Header | Example Value | Description | +| :------ | :-------------- | :---------- | +| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | +| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | +| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | +| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | +| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | + +Most of the time you can set `X-Webauth-User` to the contents of the +`Tailscale-User` header, but some services may not accept a username with an `@` +symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` +header. + +The `Tailscale-Tailnet` header can help you identify which tailnet the session +is coming from. If you are using node sharing, this can help you make sure that +you aren't giving administrative access to people outside your tailnet. + +### Allow Requests From Only One Tailnet + +If you want to prevent node sharing from allowing users to access a service, add +the `Expected-Tailnet` header to your auth request: + +```nginx +location /auth { + # ... + proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; +} +``` + +If a user from a different tailnet tries to use that service, this will return a +generic "forbidden" error page: + +```html + +403 Forbidden + +

403 Forbidden

+
nginx/1.18.0 (Ubuntu)
+ + +``` + +You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). + +## Building + +Install `cmd/mkpkg`: + +``` +cd .. && go install ./mkpkg +``` + +Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 +machines (Linux uname flag: `x86_64`). You can add these to your deployment +methods as you see fit. diff --git a/cmd/nginx-auth/deb/postinst.sh b/cmd/nginx-auth/deb/postinst.sh index e692ced0757e3..d352a84885403 100755 --- a/cmd/nginx-auth/deb/postinst.sh +++ b/cmd/nginx-auth/deb/postinst.sh @@ -1,14 +1,14 @@ -if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then - deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true - else - deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true - fi - - if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then - systemctl --system daemon-reload >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then + deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true + else + deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true + fi + + if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then + systemctl --system daemon-reload >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/postrm.sh b/cmd/nginx-auth/deb/postrm.sh index 7870efd18fb39..4bce86139c6c2 100755 --- a/cmd/nginx-auth/deb/postrm.sh +++ b/cmd/nginx-auth/deb/postrm.sh @@ -1,19 +1,19 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/prerm.sh b/cmd/nginx-auth/deb/prerm.sh index 22be23387c37e..e4becd17039ba 100755 --- a/cmd/nginx-auth/deb/prerm.sh +++ b/cmd/nginx-auth/deb/prerm.sh @@ -1,8 +1,8 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/mkdeb.sh b/cmd/nginx-auth/mkdeb.sh index 6a57210937f87..59f43230d0817 100755 --- a/cmd/nginx-auth/mkdeb.sh +++ b/cmd/nginx-auth/mkdeb.sh @@ -1,32 +1,32 @@ -#!/usr/bin/env bash - -set -e - -VERSION=0.1.3 -for ARCH in amd64 arm64; do - CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=deb \ - --arch=${ARCH} \ - --postinst=deb/postinst.sh \ - --postrm=deb/postrm.sh \ - --prerm=deb/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=rpm \ - --arch=${ARCH} \ - --postinst=rpm/postinst.sh \ - --postrm=rpm/postrm.sh \ - --prerm=rpm/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md -done +#!/usr/bin/env bash + +set -e + +VERSION=0.1.3 +for ARCH in amd64 arm64; do + CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=deb \ + --arch=${ARCH} \ + --postinst=deb/postinst.sh \ + --postrm=deb/postrm.sh \ + --prerm=deb/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=rpm \ + --arch=${ARCH} \ + --postinst=rpm/postinst.sh \ + --postrm=rpm/postrm.sh \ + --prerm=rpm/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md +done diff --git a/cmd/nginx-auth/nginx-auth.go b/cmd/nginx-auth/nginx-auth.go index befcb6d6c0423..09da74da1d3c8 100644 --- a/cmd/nginx-auth/nginx-auth.go +++ b/cmd/nginx-auth/nginx-auth.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -// Command nginx-auth is a tool that allows users to use Tailscale Whois -// authentication with NGINX as a reverse proxy. This allows users that -// already have a bunch of services hosted on an internal NGINX server -// to point those domains to the Tailscale IP of the NGINX server and -// then seamlessly use Tailscale for authentication. -package main - -import ( - "flag" - "log" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "strings" - - "github.com/coreos/go-systemd/activation" - "tailscale.com/client/tailscale" -) - -var ( - sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") -) - -func main() { - flag.Parse() - - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - remoteHost := r.Header.Get("Remote-Addr") - remotePort := r.Header.Get("Remote-Port") - if remoteHost == "" || remotePort == "" { - w.WriteHeader(http.StatusBadRequest) - log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") - return - } - - remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) - remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("remote address and port are not valid: %v", err) - return - } - - info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't look up %s: %v", remoteAddr, err) - return - } - - if info.Node.IsTagged() { - w.WriteHeader(http.StatusForbidden) - log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) - return - } - - // tailnet of connected node. When accessing shared nodes, this - // will be empty because the tailnet of the sharee is not exposed. - var tailnet string - - if !info.Node.Hostinfo.ShareeNode() { - var ok bool - _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") - if !ok { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) - return - } - tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") - } - - if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { - w.WriteHeader(http.StatusForbidden) - log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) - return - } - - h := w.Header() - h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) - h.Set("Tailscale-User", info.UserProfile.LoginName) - h.Set("Tailscale-Name", info.UserProfile.DisplayName) - h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) - h.Set("Tailscale-Tailnet", tailnet) - w.WriteHeader(http.StatusNoContent) - }) - - if *sockPath != "" { - _ = os.Remove(*sockPath) // ignore error, this file may not already exist - ln, err := net.Listen("unix", *sockPath) - if err != nil { - log.Fatalf("can't listen on %s: %v", *sockPath, err) - } - defer ln.Close() - - log.Printf("listening on %s", *sockPath) - log.Fatal(http.Serve(ln, mux)) - } - - listeners, err := activation.Listeners() - if err != nil { - log.Fatalf("no sockets passed to this service with systemd: %v", err) - } - - // NOTE(Xe): normally you'd want to make a waitgroup here and then register - // each listener with it. In this case I want this to blow up horribly if - // any of the listeners stop working. systemd will restart it due to the - // socket activation at play. - // - // TL;DR: Let it crash, it will come back - for _, ln := range listeners { - go func(ln net.Listener) { - log.Printf("listening on %s", ln.Addr()) - log.Fatal(http.Serve(ln, mux)) - }(ln) - } - - for { - select {} - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +// Command nginx-auth is a tool that allows users to use Tailscale Whois +// authentication with NGINX as a reverse proxy. This allows users that +// already have a bunch of services hosted on an internal NGINX server +// to point those domains to the Tailscale IP of the NGINX server and +// then seamlessly use Tailscale for authentication. +package main + +import ( + "flag" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strings" + + "github.com/coreos/go-systemd/activation" + "tailscale.com/client/tailscale" +) + +var ( + sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") +) + +func main() { + flag.Parse() + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + remoteHost := r.Header.Get("Remote-Addr") + remotePort := r.Header.Get("Remote-Port") + if remoteHost == "" || remotePort == "" { + w.WriteHeader(http.StatusBadRequest) + log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") + return + } + + remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) + remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("remote address and port are not valid: %v", err) + return + } + + info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't look up %s: %v", remoteAddr, err) + return + } + + if info.Node.IsTagged() { + w.WriteHeader(http.StatusForbidden) + log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) + return + } + + // tailnet of connected node. When accessing shared nodes, this + // will be empty because the tailnet of the sharee is not exposed. + var tailnet string + + if !info.Node.Hostinfo.ShareeNode() { + var ok bool + _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") + if !ok { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) + return + } + tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") + } + + if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { + w.WriteHeader(http.StatusForbidden) + log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) + return + } + + h := w.Header() + h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) + h.Set("Tailscale-User", info.UserProfile.LoginName) + h.Set("Tailscale-Name", info.UserProfile.DisplayName) + h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) + h.Set("Tailscale-Tailnet", tailnet) + w.WriteHeader(http.StatusNoContent) + }) + + if *sockPath != "" { + _ = os.Remove(*sockPath) // ignore error, this file may not already exist + ln, err := net.Listen("unix", *sockPath) + if err != nil { + log.Fatalf("can't listen on %s: %v", *sockPath, err) + } + defer ln.Close() + + log.Printf("listening on %s", *sockPath) + log.Fatal(http.Serve(ln, mux)) + } + + listeners, err := activation.Listeners() + if err != nil { + log.Fatalf("no sockets passed to this service with systemd: %v", err) + } + + // NOTE(Xe): normally you'd want to make a waitgroup here and then register + // each listener with it. In this case I want this to blow up horribly if + // any of the listeners stop working. systemd will restart it due to the + // socket activation at play. + // + // TL;DR: Let it crash, it will come back + for _, ln := range listeners { + go func(ln net.Listener) { + log.Printf("listening on %s", ln.Addr()) + log.Fatal(http.Serve(ln, mux)) + }(ln) + } + + for { + select {} + } +} diff --git a/cmd/nginx-auth/rpm/postrm.sh b/cmd/nginx-auth/rpm/postrm.sh index d8d36893fd931..3d0abfb199137 100755 --- a/cmd/nginx-auth/rpm/postrm.sh +++ b/cmd/nginx-auth/rpm/postrm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : - systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : + systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/rpm/prerm.sh b/cmd/nginx-auth/rpm/prerm.sh index 2e47a53ed9356..1f198d8292bc5 100755 --- a/cmd/nginx-auth/rpm/prerm.sh +++ b/cmd/nginx-auth/rpm/prerm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/tailscale.nginx-auth.service b/cmd/nginx-auth/tailscale.nginx-auth.service index 8534e25c1048d..086f6c7741d88 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.service +++ b/cmd/nginx-auth/tailscale.nginx-auth.service @@ -1,11 +1,11 @@ -[Unit] -Description=Tailscale NGINX Authentication service -After=nginx.service -Wants=nginx.service - -[Service] -ExecStart=/usr/sbin/tailscale.nginx-auth -DynamicUser=yes - -[Install] -WantedBy=default.target +[Unit] +Description=Tailscale NGINX Authentication service +After=nginx.service +Wants=nginx.service + +[Service] +ExecStart=/usr/sbin/tailscale.nginx-auth +DynamicUser=yes + +[Install] +WantedBy=default.target diff --git a/cmd/nginx-auth/tailscale.nginx-auth.socket b/cmd/nginx-auth/tailscale.nginx-auth.socket index 53e3e8d83edf3..7e5641ff3a2f5 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.socket +++ b/cmd/nginx-auth/tailscale.nginx-auth.socket @@ -1,9 +1,9 @@ -[Unit] -Description=Tailscale NGINX Authentication socket -PartOf=tailscale.nginx-auth.service - -[Socket] -ListenStream=/var/run/tailscale.nginx-auth.sock - -[Install] +[Unit] +Description=Tailscale NGINX Authentication socket +PartOf=tailscale.nginx-auth.service + +[Socket] +ListenStream=/var/run/tailscale.nginx-auth.sock + +[Install] WantedBy=sockets.target \ No newline at end of file diff --git a/cmd/pgproxy/README.md b/cmd/pgproxy/README.md index a867ad8cad9de..2e013072a1900 100644 --- a/cmd/pgproxy/README.md +++ b/cmd/pgproxy/README.md @@ -1,42 +1,42 @@ -# pgproxy - -The pgproxy server is a proxy for the Postgres wire protocol. [Read -more in our blog -post](https://tailscale.com/blog/introducing-pgproxy/) about it! - -The proxy runs an in-process Tailscale instance, accepts postgres -client connections over Tailscale only, and proxies them to the -configured upstream postgres server. - -This proxy exists because postgres clients default to very insecure -connection settings: either they "prefer" but do not require TLS; or -they set sslmode=require, which merely requires that a TLS handshake -took place, but don't verify the server's TLS certificate or the -presented TLS hostname. In other words, sslmode=require enforces that -a TLS session is created, but that session can trivially be -machine-in-the-middled to steal credentials, data, inject malicious -queries, and so forth. - -Because this flaw is in the client's validation of the TLS session, -you have no way of reliably detecting the misconfiguration -server-side. You could fix the configuration of all the clients you -know of, but the default makes it very easy to accidentally regress. - -Instead of trying to verify client configuration over time, this proxy -removes the need for postgres clients to be configured correctly: the -upstream database is configured to only accept connections from the -proxy, and the proxy is only available to clients over Tailscale. - -Therefore, clients must use the proxy to connect to the database. The -client<>proxy connection is secured end-to-end by Tailscale, which the -proxy enforces by verifying that the connecting client is a known -current Tailscale peer. The proxy<>server connection is established by -the proxy itself, using strict TLS verification settings, and the -client is only allowed to communicate with the server once we've -established that the upstream connection is safe to use. - -A couple side benefits: because clients can only connect via -Tailscale, you can use Tailscale ACLs as an extra layer of defense on -top of the postgres user/password authentication. And, the proxy can -maintain an audit log of who connected to the database, complete with -the strongly authenticated Tailscale identity of the client. +# pgproxy + +The pgproxy server is a proxy for the Postgres wire protocol. [Read +more in our blog +post](https://tailscale.com/blog/introducing-pgproxy/) about it! + +The proxy runs an in-process Tailscale instance, accepts postgres +client connections over Tailscale only, and proxies them to the +configured upstream postgres server. + +This proxy exists because postgres clients default to very insecure +connection settings: either they "prefer" but do not require TLS; or +they set sslmode=require, which merely requires that a TLS handshake +took place, but don't verify the server's TLS certificate or the +presented TLS hostname. In other words, sslmode=require enforces that +a TLS session is created, but that session can trivially be +machine-in-the-middled to steal credentials, data, inject malicious +queries, and so forth. + +Because this flaw is in the client's validation of the TLS session, +you have no way of reliably detecting the misconfiguration +server-side. You could fix the configuration of all the clients you +know of, but the default makes it very easy to accidentally regress. + +Instead of trying to verify client configuration over time, this proxy +removes the need for postgres clients to be configured correctly: the +upstream database is configured to only accept connections from the +proxy, and the proxy is only available to clients over Tailscale. + +Therefore, clients must use the proxy to connect to the database. The +client<>proxy connection is secured end-to-end by Tailscale, which the +proxy enforces by verifying that the connecting client is a known +current Tailscale peer. The proxy<>server connection is established by +the proxy itself, using strict TLS verification settings, and the +client is only allowed to communicate with the server once we've +established that the upstream connection is safe to use. + +A couple side benefits: because clients can only connect via +Tailscale, you can use Tailscale ACLs as an extra layer of defense on +top of the postgres user/password authentication. And, the proxy can +maintain an audit log of who connected to the database, complete with +the strongly authenticated Tailscale identity of the client. diff --git a/cmd/printdep/printdep.go b/cmd/printdep/printdep.go index 0790a8b813cc6..044283209c08c 100644 --- a/cmd/printdep/printdep.go +++ b/cmd/printdep/printdep.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The printdep command is a build system tool for printing out information -// about dependencies. -package main - -import ( - "flag" - "fmt" - "log" - "runtime" - "strings" - - ts "tailscale.com" -) - -var ( - goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") - goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") - alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") -) - -func main() { - flag.Parse() - if *alpine { - fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) - return - } - if *goToolchain { - fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) - } - if *goToolchainURL { - switch runtime.GOOS { - case "linux", "darwin": - default: - log.Fatalf("unsupported GOOS %q", runtime.GOOS) - } - fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The printdep command is a build system tool for printing out information +// about dependencies. +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "strings" + + ts "tailscale.com" +) + +var ( + goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") + goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") + alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") +) + +func main() { + flag.Parse() + if *alpine { + fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) + return + } + if *goToolchain { + fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) + } + if *goToolchainURL { + switch runtime.GOOS { + case "linux", "darwin": + default: + log.Fatalf("unsupported GOOS %q", runtime.GOOS) + } + fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) + } +} diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore index 0bca339122774..b1399c88167d4 100644 --- a/cmd/sniproxy/.gitignore +++ b/cmd/sniproxy/.gitignore @@ -1 +1 @@ -sniproxy +sniproxy diff --git a/cmd/sniproxy/handlers_test.go b/cmd/sniproxy/handlers_test.go index 8ec5b097c9b3c..4f9fc6a34b184 100644 --- a/cmd/sniproxy/handlers_test.go +++ b/cmd/sniproxy/handlers_test.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "bytes" - "context" - "encoding/hex" - "io" - "net" - "net/netip" - "strings" - "testing" - - "tailscale.com/net/memnet" -) - -func echoConnOnce(conn net.Conn) { - defer conn.Close() - - b := make([]byte, 256) - n, err := conn.Read(b) - if err != nil { - return - } - - if _, err := conn.Write(b[:n]); err != nil { - return - } -} - -func TestTCPRoundRobinHandler(t *testing.T) { - h := tcpRoundRobinHandler{ - To: []string{"yeet.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "yeet.com:22" { - t.Errorf("addr = %s, want %s", addr, "yeet.com:22") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) - h.Handle(sSock) - - // Test data write and read, the other end will echo back - // a single stanza - want := "hello" - if _, err := io.WriteString(cSock, want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if string(got) != want { - t.Errorf("got %q, want %q", got, want) - } - - // The other end closed the socket after the first echo, so - // any following read should error. - io.WriteString(cSock, "deadass heres some data on god fr") - if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { - t.Error("read succeeded on closed socket") - } -} - -// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com -const tlsStart = `45000239ff1840004006f9f5c0a801f2 -c726b5efcf9e01bbe803b21394e3b752 -801801f641dc00000101080ade3474f2 -2fb93ee71603010200010001fc030303 -c3acbd19d2624765bb19af4bce03365e -1d197f5bb939cdadeff26b0f8e7a0620 -295b04127b82bae46aac4ff58cffef25 -eba75a4b7a6de729532c411bd9dd0d2c -00203a3a130113021303c02bc02fc02c -c030cca9cca8c013c014009c009d002f -003501000193caca0000000a000a0008 -1a1a001d001700180010000e000c0268 -3208687474702f312e31002b0007062a -2a03040303ff01000100000d00120010 -04030804040105030805050108060601 -000b00020100002300000033002b0029 -1a1a000100001d0020d3c76bef062979 -a812ce935cfb4dbe6b3a84dc5ba9226f -23b0f34af9d1d03b4a001b0003020002 -00120000446900050003026832000000 -170015000012706b67732e7461696c73 -63616c652e636f6d002d000201010005 -00050100000000001700003a3a000100 -0015002d000000000000000000000000 -00000000000000000000000000000000 -00000000000000000000000000000000 -0000290094006f0069e76f2016f963ad -38c8632d1f240cd75e00e25fdef295d4 -7042b26f3a9a543b1c7dc74939d77803 -20527d423ff996997bda2c6383a14f49 -219eeef8a053e90a32228df37ddbe126 -eccf6b085c93890d08341d819aea6111 -0d909f4cd6b071d9ea40618e74588a33 -90d494bbb5c3002120d5a164a16c9724 -c9ef5e540d8d6f007789a7acf9f5f16f -bf6a1907a6782ed02b` - -func fakeSNIHeader() []byte { - b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) - if err != nil { - panic(err) - } - return b[0x34:] // trim IP + TCP header -} - -func TestTCPSNIHandler(t *testing.T) { - h := tcpSNIHandler{ - Allowlist: []string{"pkgs.tailscale.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "pkgs.tailscale.com:443" { - t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) - h.Handle(sSock) - - // Fake a TLS handshake record with an SNI in it. - if _, err := cSock.Write(fakeSNIHeader()); err != nil { - t.Fatal(err) - } - - // Test read, the other end will echo back - // a single stanza, which is at least the beginning of the SNI header. - want := fakeSNIHeader()[:5] - if _, err := cSock.Write(want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if !bytes.Equal(got, want) { - t.Errorf("got %q, want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "context" + "encoding/hex" + "io" + "net" + "net/netip" + "strings" + "testing" + + "tailscale.com/net/memnet" +) + +func echoConnOnce(conn net.Conn) { + defer conn.Close() + + b := make([]byte, 256) + n, err := conn.Read(b) + if err != nil { + return + } + + if _, err := conn.Write(b[:n]); err != nil { + return + } +} + +func TestTCPRoundRobinHandler(t *testing.T) { + h := tcpRoundRobinHandler{ + To: []string{"yeet.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "yeet.com:22" { + t.Errorf("addr = %s, want %s", addr, "yeet.com:22") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) + h.Handle(sSock) + + // Test data write and read, the other end will echo back + // a single stanza + want := "hello" + if _, err := io.WriteString(cSock, want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } + + // The other end closed the socket after the first echo, so + // any following read should error. + io.WriteString(cSock, "deadass heres some data on god fr") + if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { + t.Error("read succeeded on closed socket") + } +} + +// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com +const tlsStart = `45000239ff1840004006f9f5c0a801f2 +c726b5efcf9e01bbe803b21394e3b752 +801801f641dc00000101080ade3474f2 +2fb93ee71603010200010001fc030303 +c3acbd19d2624765bb19af4bce03365e +1d197f5bb939cdadeff26b0f8e7a0620 +295b04127b82bae46aac4ff58cffef25 +eba75a4b7a6de729532c411bd9dd0d2c +00203a3a130113021303c02bc02fc02c +c030cca9cca8c013c014009c009d002f +003501000193caca0000000a000a0008 +1a1a001d001700180010000e000c0268 +3208687474702f312e31002b0007062a +2a03040303ff01000100000d00120010 +04030804040105030805050108060601 +000b00020100002300000033002b0029 +1a1a000100001d0020d3c76bef062979 +a812ce935cfb4dbe6b3a84dc5ba9226f +23b0f34af9d1d03b4a001b0003020002 +00120000446900050003026832000000 +170015000012706b67732e7461696c73 +63616c652e636f6d002d000201010005 +00050100000000001700003a3a000100 +0015002d000000000000000000000000 +00000000000000000000000000000000 +00000000000000000000000000000000 +0000290094006f0069e76f2016f963ad +38c8632d1f240cd75e00e25fdef295d4 +7042b26f3a9a543b1c7dc74939d77803 +20527d423ff996997bda2c6383a14f49 +219eeef8a053e90a32228df37ddbe126 +eccf6b085c93890d08341d819aea6111 +0d909f4cd6b071d9ea40618e74588a33 +90d494bbb5c3002120d5a164a16c9724 +c9ef5e540d8d6f007789a7acf9f5f16f +bf6a1907a6782ed02b` + +func fakeSNIHeader() []byte { + b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) + if err != nil { + panic(err) + } + return b[0x34:] // trim IP + TCP header +} + +func TestTCPSNIHandler(t *testing.T) { + h := tcpSNIHandler{ + Allowlist: []string{"pkgs.tailscale.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "pkgs.tailscale.com:443" { + t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) + h.Handle(sSock) + + // Fake a TLS handshake record with an SNI in it. + if _, err := cSock.Write(fakeSNIHeader()); err != nil { + t.Fatal(err) + } + + // Test read, the other end will echo back + // a single stanza, which is at least the beginning of the SNI header. + want := fakeSNIHeader()[:5] + if _, err := cSock.Write(want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/cmd/sniproxy/server.go b/cmd/sniproxy/server.go index c894206613f4a..b322b6f4b1137 100644 --- a/cmd/sniproxy/server.go +++ b/cmd/sniproxy/server.go @@ -1,327 +1,327 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "expvar" - "log" - "net" - "net/netip" - "sync" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/metrics" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/clientmetric" - "tailscale.com/util/mak" -) - -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") - -// target describes the predicates which route some inbound -// traffic to the app connector to a specific handler. -type target struct { - Dest netip.Prefix - Matching tailcfg.ProtoPortRange -} - -// Server implements an App Connector as expressed in sniproxy. -type Server struct { - mu sync.RWMutex // mu guards following fields - connectors map[appctype.ConfigID]connector -} - -type appcMetrics struct { - dnsResponses expvar.Int - dnsFailures expvar.Int - tcpConns expvar.Int - sniConns expvar.Int - unhandledConns expvar.Int -} - -var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { - m := appcMetrics{} - - stats := new(metrics.Set) - stats.Set("tls_sessions", &m.sniConns) - clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) - stats.Set("tcp_sessions", &m.tcpConns) - clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) - stats.Set("dns_responses", &m.dnsResponses) - clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) - stats.Set("dns_failed", &m.dnsFailures) - clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) - expvar.Publish("sniproxy", stats) - - return &m -}) - -// Configure applies the provided configuration to the app connector. -func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { - s.mu.Lock() - defer s.mu.Unlock() - s.connectors = makeConnectorsFromConfig(cfg) - log.Printf("installed app connector config: %+v", s.connectors) -} - -// HandleTCPFlow implements tsnet.FallbackTCPHandler. -func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { - m := getMetrics() - s.mu.RLock() - defer s.mu.RUnlock() - - for _, c := range s.connectors { - if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { - return handler, intercept - } - } - - return nil, false -} - -// HandleDNS handles a DNS request to the app connector. -func (s *Server) HandleDNS(c nettype.ConnPacketConn) { - defer c.Close() - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - m := getMetrics() - - buf := make([]byte, 1500) - n, err := c.Read(buf) - if err != nil { - log.Printf("HandleDNS: read failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - addrPortStr := c.LocalAddr().String() - host, _, err := net.SplitHostPort(addrPortStr) - if err != nil { - log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) - m.dnsFailures.Add(1) - return - } - localAddr, err := netip.ParseAddr(host) - if err != nil { - log.Printf("HandleDNS: bogus local address %q", host) - m.dnsFailures.Add(1) - return - } - - var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) - if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - for _, connector := range s.connectors { - resp, err := connector.handleDNS(&msg, localAddr) - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - if len(resp) > 0 { - // This connector handled the DNS request - _, err = c.Write(resp) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - - m.dnsResponses.Add(1) - return - } - } -} - -// connector describes a logical collection of -// services which need to be proxied. -type connector struct { - Handlers map[target]handler -} - -// handleTCPFlow implements tsnet.FallbackTCPHandler. -func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { - for t, h := range c.Handlers { - if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { - continue - } - if !t.Dest.Contains(dst.Addr()) { - continue - } - if !t.Matching.Ports.Contains(dst.Port()) { - continue - } - - switch h.(type) { - case *tcpSNIHandler: - m.sniConns.Add(1) - case *tcpRoundRobinHandler: - m.tcpConns.Add(1) - default: - log.Printf("handleTCPFlow: unhandled handler type %T", h) - } - - return h.Handle, true - } - - m.unhandledConns.Add(1) - return nil, false -} - -// handleDNS returns the DNS response to the given query. If this -// connector is unable to handle the request, nil is returned. -func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { - for t, h := range c.Handlers { - if t.Dest.Contains(localAddr) { - return makeDNSResponse(req, h.ReachableOn()) - } - } - - // Did not match, signal 'not handled' to caller - return nil, nil -} - -func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { - resp := dnsmessage.NewBuilder(response, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, - }) - resp.EnableCompression() - - if len(req.Questions) == 0 { - response, _ = resp.Finish() - return response, nil - } - q := req.Questions[0] - err = resp.StartQuestions() - if err != nil { - return - } - resp.Question(q) - - err = resp.StartAnswers() - if err != nil { - return - } - - switch q.Type { - case dnsmessage.TypeAAAA: - for _, ip := range reachableIPs { - if ip.Is6() { - err = resp.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, - ) - } - } - - case dnsmessage.TypeA: - for _, ip := range reachableIPs { - if ip.Is4() { - err = resp.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AResource{A: ip.As4()}, - ) - } - } - - case dnsmessage.TypeSOA: - err = resp.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ) - case dnsmessage.TypeNS: - err = resp.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ) - } - - if err != nil { - return nil, err - } - return resp.Finish() -} - -type handler interface { - // Handle handles the given socket. - Handle(c net.Conn) - - // ReachableOn returns the IP addresses this handler is reachable on. - ReachableOn() []netip.Addr -} - -func installDNATHandler(d *appctype.DNATConfig, out *connector) { - // These handlers don't actually do DNAT, they just - // proxy the data over the connection. - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpRoundRobinHandler{ - To: d.To, - DialContext: dialer.DialContext, - ReachableIPs: d.Addrs, - } - - for _, addr := range d.Addrs { - for _, protoPort := range d.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpSNIHandler{ - Allowlist: c.AllowedDomains, - DialContext: dialer.DialContext, - ReachableIPs: c.Addrs, - } - - for _, addr := range c.Addrs { - for _, protoPort := range c.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { - var connectors map[appctype.ConfigID]connector - - for cID, d := range cfg.DNAT { - c := connectors[cID] - installDNATHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - for cID, d := range cfg.SNIProxy { - c := connectors[cID] - installSNIHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - - return connectors -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "expvar" + "log" + "net" + "net/netip" + "sync" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/metrics" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" +) + +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + +// target describes the predicates which route some inbound +// traffic to the app connector to a specific handler. +type target struct { + Dest netip.Prefix + Matching tailcfg.ProtoPortRange +} + +// Server implements an App Connector as expressed in sniproxy. +type Server struct { + mu sync.RWMutex // mu guards following fields + connectors map[appctype.ConfigID]connector +} + +type appcMetrics struct { + dnsResponses expvar.Int + dnsFailures expvar.Int + tcpConns expvar.Int + sniConns expvar.Int + unhandledConns expvar.Int +} + +var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { + m := appcMetrics{} + + stats := new(metrics.Set) + stats.Set("tls_sessions", &m.sniConns) + clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) + stats.Set("tcp_sessions", &m.tcpConns) + clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) + stats.Set("dns_responses", &m.dnsResponses) + clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) + stats.Set("dns_failed", &m.dnsFailures) + clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) + expvar.Publish("sniproxy", stats) + + return &m +}) + +// Configure applies the provided configuration to the app connector. +func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { + s.mu.Lock() + defer s.mu.Unlock() + s.connectors = makeConnectorsFromConfig(cfg) + log.Printf("installed app connector config: %+v", s.connectors) +} + +// HandleTCPFlow implements tsnet.FallbackTCPHandler. +func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + m := getMetrics() + s.mu.RLock() + defer s.mu.RUnlock() + + for _, c := range s.connectors { + if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { + return handler, intercept + } + } + + return nil, false +} + +// HandleDNS handles a DNS request to the app connector. +func (s *Server) HandleDNS(c nettype.ConnPacketConn) { + defer c.Close() + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + m := getMetrics() + + buf := make([]byte, 1500) + n, err := c.Read(buf) + if err != nil { + log.Printf("HandleDNS: read failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + addrPortStr := c.LocalAddr().String() + host, _, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) + m.dnsFailures.Add(1) + return + } + localAddr, err := netip.ParseAddr(host) + if err != nil { + log.Printf("HandleDNS: bogus local address %q", host) + m.dnsFailures.Add(1) + return + } + + var msg dnsmessage.Message + err = msg.Unpack(buf[:n]) + if err != nil { + log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, connector := range s.connectors { + resp, err := connector.handleDNS(&msg, localAddr) + if err != nil { + log.Printf("HandleDNS: connector handling failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + if len(resp) > 0 { + // This connector handled the DNS request + _, err = c.Write(resp) + if err != nil { + log.Printf("HandleDNS: write failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + + m.dnsResponses.Add(1) + return + } + } +} + +// connector describes a logical collection of +// services which need to be proxied. +type connector struct { + Handlers map[target]handler +} + +// handleTCPFlow implements tsnet.FallbackTCPHandler. +func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { + for t, h := range c.Handlers { + if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { + continue + } + if !t.Dest.Contains(dst.Addr()) { + continue + } + if !t.Matching.Ports.Contains(dst.Port()) { + continue + } + + switch h.(type) { + case *tcpSNIHandler: + m.sniConns.Add(1) + case *tcpRoundRobinHandler: + m.tcpConns.Add(1) + default: + log.Printf("handleTCPFlow: unhandled handler type %T", h) + } + + return h.Handle, true + } + + m.unhandledConns.Add(1) + return nil, false +} + +// handleDNS returns the DNS response to the given query. If this +// connector is unable to handle the request, nil is returned. +func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { + for t, h := range c.Handlers { + if t.Dest.Contains(localAddr) { + return makeDNSResponse(req, h.ReachableOn()) + } + } + + // Did not match, signal 'not handled' to caller + return nil, nil +} + +func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { + resp := dnsmessage.NewBuilder(response, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + response, _ = resp.Finish() + return response, nil + } + q := req.Questions[0] + err = resp.StartQuestions() + if err != nil { + return + } + resp.Question(q) + + err = resp.StartAnswers() + if err != nil { + return + } + + switch q.Type { + case dnsmessage.TypeAAAA: + for _, ip := range reachableIPs { + if ip.Is6() { + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip.As16()}, + ) + } + } + + case dnsmessage.TypeA: + for _, ip := range reachableIPs { + if ip.Is4() { + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip.As4()}, + ) + } + } + + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return nil, err + } + return resp.Finish() +} + +type handler interface { + // Handle handles the given socket. + Handle(c net.Conn) + + // ReachableOn returns the IP addresses this handler is reachable on. + ReachableOn() []netip.Addr +} + +func installDNATHandler(d *appctype.DNATConfig, out *connector) { + // These handlers don't actually do DNAT, they just + // proxy the data over the connection. + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpRoundRobinHandler{ + To: d.To, + DialContext: dialer.DialContext, + ReachableIPs: d.Addrs, + } + + for _, addr := range d.Addrs { + for _, protoPort := range d.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpSNIHandler{ + Allowlist: c.AllowedDomains, + DialContext: dialer.DialContext, + ReachableIPs: c.Addrs, + } + + for _, addr := range c.Addrs { + for _, protoPort := range c.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { + var connectors map[appctype.ConfigID]connector + + for cID, d := range cfg.DNAT { + c := connectors[cID] + installDNATHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + for cID, d := range cfg.SNIProxy { + c := connectors[cID] + installSNIHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + + return connectors +} diff --git a/cmd/sniproxy/server_test.go b/cmd/sniproxy/server_test.go index 2a51c874c81b0..d56f2aa754f85 100644 --- a/cmd/sniproxy/server_test.go +++ b/cmd/sniproxy/server_test.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" -) - -func TestMakeConnectorsFromConfig(t *testing.T) { - tcs := []struct { - name string - input *appctype.AppConnectorConfig - want map[appctype.ConfigID]connector - }{ - { - "empty", - &appctype.AppConnectorConfig{}, - nil, - }, - { - "DNAT", - &appctype.AppConnectorConfig{ - DNAT: map[appctype.ConfigID]appctype.DNATConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - { - "SNIProxy", - &appctype.AppConnectorConfig{ - SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - AllowedDomains: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - connectors := makeConnectorsFromConfig(tc.input) - - if diff := cmp.Diff(connectors, tc.want, - cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), - cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), - cmp.Comparer(func(x, y netip.Addr) bool { - return x == y - })); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" +) + +func TestMakeConnectorsFromConfig(t *testing.T) { + tcs := []struct { + name string + input *appctype.AppConnectorConfig + want map[appctype.ConfigID]connector + }{ + { + "empty", + &appctype.AppConnectorConfig{}, + nil, + }, + { + "DNAT", + &appctype.AppConnectorConfig{ + DNAT: map[appctype.ConfigID]appctype.DNATConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + { + "SNIProxy", + &appctype.AppConnectorConfig{ + SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + AllowedDomains: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + connectors := makeConnectorsFromConfig(tc.input) + + if diff := cmp.Diff(connectors, tc.want, + cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), + cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), + cmp.Comparer(func(x, y netip.Addr) bool { + return x == y + })); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index c048c8e7e2792..fa83aaf4ab44e 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The sniproxy is an outbound SNI proxy. It receives TLS connections over -// Tailscale on one or more TCP ports and sends them out to the same SNI -// hostname & port on the internet. It can optionally forward one or more -// TCP ports to a specific destination. It only does TCP. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "sort" - "strconv" - "strings" - - "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" - "tailscale.com/hostinfo" - "tailscale.com/ipn" - "tailscale.com/tailcfg" - "tailscale.com/tsnet" - "tailscale.com/tsweb" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/mak" -) - -const configCapKey = "tailscale.com/sniproxy" - -// portForward is the state for a single port forwarding entry, as passed to the --forward flag. -type portForward struct { - Port int - Proto string - Destination string -} - -// parseForward takes a proto/port/destination tuple as an input, as would be passed -// to the --forward command line flag, and returns a *portForward struct of those parameters. -func parseForward(value string) (*portForward, error) { - parts := strings.Split(value, "/") - if len(parts) != 3 { - return nil, errors.New("cannot parse: " + value) - } - - proto := parts[0] - if proto != "tcp" { - return nil, errors.New("unsupported forwarding protocol: " + proto) - } - port, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return nil, errors.New("bad forwarding port: " + parts[1]) - } - host := parts[2] - if host == "" { - return nil, errors.New("bad destination: " + value) - } - - return &portForward{Port: int(port), Proto: proto, Destination: host}, nil -} - -func main() { - // Parse flags - fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) - var ( - ports = fs.String("ports", "443", "comma-separated list of ports to proxy") - forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") - wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") - promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") - debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") - hostname = fs.String("hostname", "", "Hostname to register the service under") - ) - err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) - if err != nil { - log.Fatal("ff.Parse") - } - - var ts tsnet.Server - defer ts.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) -} - -// run actually runs the sniproxy. Its separate from main() to assist in testing. -func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { - // Wire up Tailscale node + app connector server - hostinfo.SetApp("sniproxy") - var s sniproxy - s.ts = ts - - s.ts.Port = uint16(wgPort) - s.ts.Hostname = hostname - - lc, err := s.ts.LocalClient() - if err != nil { - log.Fatalf("LocalClient() failed: %v", err) - } - s.lc = lc - s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) - - // Start special-purpose listeners: dns, http promotion, debug server - ln, err := s.ts.Listen("udp", ":53") - if err != nil { - log.Fatalf("failed listening on port 53: %v", err) - } - defer ln.Close() - go s.serveDNS(ln) - if promoteHTTPS { - ln, err := s.ts.Listen("tcp", ":80") - if err != nil { - log.Fatalf("failed listening on port 80: %v", err) - } - defer ln.Close() - log.Printf("Promoting HTTP to HTTPS ...") - go s.promoteHTTPS(ln) - } - if debugPort != 0 { - mux := http.NewServeMux() - tsweb.Debugger(mux) - dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) - if err != nil { - log.Fatalf("failed listening on debug port: %v", err) - } - defer dln.Close() - go func() { - log.Fatalf("debug serve: %v", http.Serve(dln, mux)) - }() - } - - // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) - if err != nil { - log.Fatalf("watching IPN bus: %v", err) - } - defer bus.Close() - for { - msg, err := bus.Next() - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - log.Fatalf("reading IPN bus: %v", err) - } - - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } - - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } - } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) - } - } -} - -type sniproxy struct { - srv Server - ts *tsnet.Server - lc *tailscale.LocalClient -} - -func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { - // Collect the set of addresses to advertise, using a map - // to avoid duplicate entries. - addrs := map[netip.Addr]struct{}{} - for _, c := range c.SNIProxy { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - for _, c := range c.DNAT { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - - var routes []netip.Prefix - for a := range addrs { - routes = append(routes, netip.PrefixFrom(a, a.BitLen())) - } - sort.SliceStable(routes, func(i, j int) bool { - return routes[i].Addr().Less(routes[j].Addr()) // determinism r us - }) - - _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - AdvertiseRoutes: routes, - }, - AdvertiseRoutesSet: true, - }) - return err -} - -func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { - ip4, ip6 := s.ts.TailscaleIPs() - - sniConfigFromFlags := appctype.SNIProxyConfig{ - Addrs: []netip.Addr{ip4, ip6}, - } - if ports != "" { - for _, portStr := range strings.Split(ports, ",") { - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - log.Fatalf("invalid port: %s", portStr) - } - sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, - }) - } - } - - var forwardConfigFromFlags []appctype.DNATConfig - for _, forwStr := range strings.Split(forwards, ",") { - if forwStr == "" { - continue - } - forw, err := parseForward(forwStr) - if err != nil { - log.Printf("invalid forwarding spec: %v", err) - continue - } - - forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ - Addrs: []netip.Addr{ip4, ip6}, - To: []string{forw.Destination}, - IP: []tailcfg.ProtoPortRange{ - { - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, - }, - }, - }) - } - - if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { - return // no config specified on the command line - } - - mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) - for i, forward := range forwardConfigFromFlags { - mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) - } -} - -func (s *sniproxy) serveDNS(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Printf("serveDNS accept: %v", err) - return - } - go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) - } -} - -func (s *sniproxy) promoteHTTPS(ln net.Listener) { - err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) - })) - log.Fatalf("promoteHTTPS http.Serve: %v", err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The sniproxy is an outbound SNI proxy. It receives TLS connections over +// Tailscale on one or more TCP ports and sends them out to the same SNI +// hostname & port on the internet. It can optionally forward one or more +// TCP ports to a specific destination. It only does TCP. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "sort" + "strconv" + "strings" + + "github.com/peterbourgon/ff/v3" + "tailscale.com/client/tailscale" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tsweb" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/mak" +) + +const configCapKey = "tailscale.com/sniproxy" + +// portForward is the state for a single port forwarding entry, as passed to the --forward flag. +type portForward struct { + Port int + Proto string + Destination string +} + +// parseForward takes a proto/port/destination tuple as an input, as would be passed +// to the --forward command line flag, and returns a *portForward struct of those parameters. +func parseForward(value string) (*portForward, error) { + parts := strings.Split(value, "/") + if len(parts) != 3 { + return nil, errors.New("cannot parse: " + value) + } + + proto := parts[0] + if proto != "tcp" { + return nil, errors.New("unsupported forwarding protocol: " + proto) + } + port, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return nil, errors.New("bad forwarding port: " + parts[1]) + } + host := parts[2] + if host == "" { + return nil, errors.New("bad destination: " + value) + } + + return &portForward{Port: int(port), Proto: proto, Destination: host}, nil +} + +func main() { + // Parse flags + fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) + var ( + ports = fs.String("ports", "443", "comma-separated list of ports to proxy") + forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") + wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") + promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") + debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") + hostname = fs.String("hostname", "", "Hostname to register the service under") + ) + err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) + if err != nil { + log.Fatal("ff.Parse") + } + + var ts tsnet.Server + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) +} + +// run actually runs the sniproxy. Its separate from main() to assist in testing. +func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { + // Wire up Tailscale node + app connector server + hostinfo.SetApp("sniproxy") + var s sniproxy + s.ts = ts + + s.ts.Port = uint16(wgPort) + s.ts.Hostname = hostname + + lc, err := s.ts.LocalClient() + if err != nil { + log.Fatalf("LocalClient() failed: %v", err) + } + s.lc = lc + s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) + + // Start special-purpose listeners: dns, http promotion, debug server + ln, err := s.ts.Listen("udp", ":53") + if err != nil { + log.Fatalf("failed listening on port 53: %v", err) + } + defer ln.Close() + go s.serveDNS(ln) + if promoteHTTPS { + ln, err := s.ts.Listen("tcp", ":80") + if err != nil { + log.Fatalf("failed listening on port 80: %v", err) + } + defer ln.Close() + log.Printf("Promoting HTTP to HTTPS ...") + go s.promoteHTTPS(ln) + } + if debugPort != 0 { + mux := http.NewServeMux() + tsweb.Debugger(mux) + dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) + if err != nil { + log.Fatalf("failed listening on debug port: %v", err) + } + defer dln.Close() + go func() { + log.Fatalf("debug serve: %v", http.Serve(dln, mux)) + }() + } + + // Finally, start mainloop to configure app connector based on information + // in the netmap. + // We set the NotifyInitialNetMap flag so we will always get woken with the + // current netmap, before only being woken on changes. + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + if err != nil { + log.Fatalf("watching IPN bus: %v", err) + } + defer bus.Close() + for { + msg, err := bus.Next() + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + log.Fatalf("reading IPN bus: %v", err) + } + + // NetMap contains app-connector configuration + if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { + sn := nm.SelfNode.AsStruct() + + var c appctype.AppConnectorConfig + nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } + + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) + } + } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.srv.Configure(&c) + } + } +} + +type sniproxy struct { + srv Server + ts *tsnet.Server + lc *tailscale.LocalClient +} + +func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { + // Collect the set of addresses to advertise, using a map + // to avoid duplicate entries. + addrs := map[netip.Addr]struct{}{} + for _, c := range c.SNIProxy { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + for _, c := range c.DNAT { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + + var routes []netip.Prefix + for a := range addrs { + routes = append(routes, netip.PrefixFrom(a, a.BitLen())) + } + sort.SliceStable(routes, func(i, j int) bool { + return routes[i].Addr().Less(routes[j].Addr()) // determinism r us + }) + + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AdvertiseRoutes: routes, + }, + AdvertiseRoutesSet: true, + }) + return err +} + +func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { + ip4, ip6 := s.ts.TailscaleIPs() + + sniConfigFromFlags := appctype.SNIProxyConfig{ + Addrs: []netip.Addr{ip4, ip6}, + } + if ports != "" { + for _, portStr := range strings.Split(ports, ",") { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + log.Fatalf("invalid port: %s", portStr) + } + sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, + }) + } + } + + var forwardConfigFromFlags []appctype.DNATConfig + for _, forwStr := range strings.Split(forwards, ",") { + if forwStr == "" { + continue + } + forw, err := parseForward(forwStr) + if err != nil { + log.Printf("invalid forwarding spec: %v", err) + continue + } + + forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ + Addrs: []netip.Addr{ip4, ip6}, + To: []string{forw.Destination}, + IP: []tailcfg.ProtoPortRange{ + { + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, + }, + }, + }) + } + + if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { + return // no config specified on the command line + } + + mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) + for i, forward := range forwardConfigFromFlags { + mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) + } +} + +func (s *sniproxy) serveDNS(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + log.Printf("serveDNS accept: %v", err) + return + } + go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) + } +} + +func (s *sniproxy) promoteHTTPS(ln net.Listener) { + err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) + })) + log.Fatalf("promoteHTTPS http.Serve: %v", err) +} diff --git a/cmd/speedtest/speedtest.go b/cmd/speedtest/speedtest.go index 1555c0dcc0b7a..9a457ed6c7486 100644 --- a/cmd/speedtest/speedtest.go +++ b/cmd/speedtest/speedtest.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program speedtest provides the speedtest command. The reason to keep it separate from -// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. -// It will be included in the tailscale cli after it has been added to tailscaled. - -// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s -// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. -// Example usage for server command: go run cmd/speedtest -s -host :20333 -// This will start a speedtest server on port 20333. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "os" - "strconv" - "text/tabwriter" - "time" - - "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/net/speedtest" -) - -// Runs the speedtest command as a commandline program -func main() { - args := os.Args[1:] - if err := speedtestCmd.Parse(args); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } - - err := speedtestCmd.Run(context.Background()) - if errors.Is(err, flag.ErrHelp) { - fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) - os.Exit(2) - } - if err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } -} - -// speedtestCmd is the root command. It runs either the server or client depending on the -// flags passed to it. -var speedtestCmd = &ffcli.Command{ - Name: "speedtest", - ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", - ShortHelp: "Run a speed test", - FlagSet: (func() *flag.FlagSet { - fs := flag.NewFlagSet("speedtest", flag.ExitOnError) - fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") - fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") - fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") - fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") - return fs - })(), - Exec: runSpeedtest, -} - -var speedtestArgs struct { - host string - testDuration time.Duration - runServer bool - reverse bool -} - -func runSpeedtest(ctx context.Context, args []string) error { - - if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { - var addrErr *net.AddrError - if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { - // if no port is provided, append the default port - speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) - } - } - - if speedtestArgs.runServer { - listener, err := net.Listen("tcp", speedtestArgs.host) - if err != nil { - return err - } - - fmt.Printf("listening on %v\n", listener.Addr()) - - return speedtest.Serve(listener) - } - - // Ensure the duration is within the allowed range - if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { - return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) - } - - dir := speedtest.Download - if speedtestArgs.reverse { - dir = speedtest.Upload - } - - fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) - results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) - if err != nil { - return err - } - - w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) - fmt.Println("Results:") - fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") - startTime := results[0].IntervalStart - for _, r := range results { - if r.Total { - fmt.Fprintln(w, "-------------------------------------------------------------------------") - } - fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) - } - w.Flush() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program speedtest provides the speedtest command. The reason to keep it separate from +// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. +// It will be included in the tailscale cli after it has been added to tailscaled. + +// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s +// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. +// Example usage for server command: go run cmd/speedtest -s -host :20333 +// This will start a speedtest server on port 20333. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "os" + "strconv" + "text/tabwriter" + "time" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/net/speedtest" +) + +// Runs the speedtest command as a commandline program +func main() { + args := os.Args[1:] + if err := speedtestCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + err := speedtestCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) + os.Exit(2) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +// speedtestCmd is the root command. It runs either the server or client depending on the +// flags passed to it. +var speedtestCmd = &ffcli.Command{ + Name: "speedtest", + ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", + ShortHelp: "Run a speed test", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("speedtest", flag.ExitOnError) + fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") + fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") + fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") + fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") + return fs + })(), + Exec: runSpeedtest, +} + +var speedtestArgs struct { + host string + testDuration time.Duration + runServer bool + reverse bool +} + +func runSpeedtest(ctx context.Context, args []string) error { + + if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + // if no port is provided, append the default port + speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) + } + } + + if speedtestArgs.runServer { + listener, err := net.Listen("tcp", speedtestArgs.host) + if err != nil { + return err + } + + fmt.Printf("listening on %v\n", listener.Addr()) + + return speedtest.Serve(listener) + } + + // Ensure the duration is within the allowed range + if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { + return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) + } + + dir := speedtest.Download + if speedtestArgs.reverse { + dir = speedtest.Upload + } + + fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) + results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) + if err != nil { + return err + } + + w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) + fmt.Println("Results:") + fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") + startTime := results[0].IntervalStart + for _, r := range results { + if r.Total { + fmt.Fprintln(w, "-------------------------------------------------------------------------") + } + fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) + } + w.Flush() + return nil +} diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index ade272c4ba811..ee929299a4273 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// ssh-auth-none-demo is a demo SSH server that's meant to run on the -// public internet (at 188.166.70.128 port 2222) and -// highlight the unique parts of the Tailscale SSH server so SSH -// client authors can hit it easily and fix their SSH clients without -// needing to set up Tailscale and Tailscale SSH. -package main - -import ( - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "flag" - "fmt" - "io" - "log" - "os" - "path/filepath" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" - "tailscale.com/tempfork/gliderlabs/ssh" -) - -// keyTypes are the SSH key types that we either try to read from the -// system's OpenSSH keys. -var keyTypes = []string{"rsa", "ecdsa", "ed25519"} - -var ( - addr = flag.String("addr", ":2222", "address to listen on") -) - -func main() { - flag.Parse() - - cacheDir, err := os.UserCacheDir() - if err != nil { - log.Fatal(err) - } - dir := filepath.Join(cacheDir, "ssh-auth-none-demo") - if err := os.MkdirAll(dir, 0700); err != nil { - log.Fatal(err) - } - - keys, err := getHostKeys(dir) - if err != nil { - log.Fatal(err) - } - if len(keys) == 0 { - log.Fatal("no host keys") - } - - srv := &ssh.Server{ - Addr: *addr, - Version: "Tailscale", - Handler: handleSessionPostSSHAuth, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - start := time.Now() - return &gossh.ServerConfig{ - NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { - return []string{"tailscale"} - }, - NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { - cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) - - totalBanners := 2 - if cm.User() == "banners" { - totalBanners = 5 - } - for banner := 2; banner <= totalBanners; banner++ { - time.Sleep(time.Second) - if banner == totalBanners { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) - } else { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) - } - } - return nil, nil - }, - BannerCallback: func(cm gossh.ConnMetadata) string { - log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) - return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) - }, - } - }, - } - - for _, signer := range keys { - srv.AddHostKey(signer) - } - - log.Printf("Running on %s ...", srv.Addr) - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } - log.Printf("done") -} - -func handleSessionPostSSHAuth(s ssh.Session) { - log.Printf("Started session from user %q", s.User()) - fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) - - // Abort the session on Control-C or Control-D. - go func() { - buf := make([]byte, 1024) - for { - n, err := s.Read(buf) - for _, b := range buf[:n] { - if b <= 4 { // abort on Control-C (3) or Control-D (4) - io.WriteString(s, "bye\n") - s.Exit(1) - } - } - if err != nil { - return - } - } - }() - - for i := 10; i > 0; i-- { - fmt.Fprintf(s, "%v ...\n", i) - time.Sleep(time.Second) - } - s.Exit(0) -} - -func getHostKeys(dir string) (ret []ssh.Signer, err error) { - for _, typ := range keyTypes { - hostKey, err := hostKeyFileOrCreate(dir, typ) - if err != nil { - return nil, err - } - signer, err := gossh.ParsePrivateKey(hostKey) - if err != nil { - return nil, err - } - ret = append(ret, signer) - } - return ret, nil -} - -func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { - path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") - v, err := os.ReadFile(path) - if err == nil { - return v, nil - } - if !os.IsNotExist(err) { - return nil, err - } - var priv any - switch typ { - default: - return nil, fmt.Errorf("unsupported key type %q", typ) - case "ed25519": - _, priv, err = ed25519.GenerateKey(rand.Reader) - case "ecdsa": - // curve is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - curve := elliptic.P256() - priv, err = ecdsa.GenerateKey(curve, rand.Reader) - case "rsa": - // keySize is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - const keySize = 2048 - priv, err = rsa.GenerateKey(rand.Reader, keySize) - } - if err != nil { - return nil, err - } - mk, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - return nil, err - } - pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) - err = os.WriteFile(path, pemGen, 0700) - return pemGen, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ssh-auth-none-demo is a demo SSH server that's meant to run on the +// public internet (at 188.166.70.128 port 2222) and +// highlight the unique parts of the Tailscale SSH server so SSH +// client authors can hit it easily and fix their SSH clients without +// needing to set up Tailscale and Tailscale SSH. +package main + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" + "tailscale.com/tempfork/gliderlabs/ssh" +) + +// keyTypes are the SSH key types that we either try to read from the +// system's OpenSSH keys. +var keyTypes = []string{"rsa", "ecdsa", "ed25519"} + +var ( + addr = flag.String("addr", ":2222", "address to listen on") +) + +func main() { + flag.Parse() + + cacheDir, err := os.UserCacheDir() + if err != nil { + log.Fatal(err) + } + dir := filepath.Join(cacheDir, "ssh-auth-none-demo") + if err := os.MkdirAll(dir, 0700); err != nil { + log.Fatal(err) + } + + keys, err := getHostKeys(dir) + if err != nil { + log.Fatal(err) + } + if len(keys) == 0 { + log.Fatal("no host keys") + } + + srv := &ssh.Server{ + Addr: *addr, + Version: "Tailscale", + Handler: handleSessionPostSSHAuth, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + start := time.Now() + return &gossh.ServerConfig{ + NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { + return []string{"tailscale"} + }, + NoClientAuth: true, // required for the NoClientAuthCallback to run + NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + + totalBanners := 2 + if cm.User() == "banners" { + totalBanners = 5 + } + for banner := 2; banner <= totalBanners; banner++ { + time.Sleep(time.Second) + if banner == totalBanners { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) + } else { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) + } + } + return nil, nil + }, + BannerCallback: func(cm gossh.ConnMetadata) string { + log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) + return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) + }, + } + }, + } + + for _, signer := range keys { + srv.AddHostKey(signer) + } + + log.Printf("Running on %s ...", srv.Addr) + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } + log.Printf("done") +} + +func handleSessionPostSSHAuth(s ssh.Session) { + log.Printf("Started session from user %q", s.User()) + fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) + + // Abort the session on Control-C or Control-D. + go func() { + buf := make([]byte, 1024) + for { + n, err := s.Read(buf) + for _, b := range buf[:n] { + if b <= 4 { // abort on Control-C (3) or Control-D (4) + io.WriteString(s, "bye\n") + s.Exit(1) + } + } + if err != nil { + return + } + } + }() + + for i := 10; i > 0; i-- { + fmt.Fprintf(s, "%v ...\n", i) + time.Sleep(time.Second) + } + s.Exit(0) +} + +func getHostKeys(dir string) (ret []ssh.Signer, err error) { + for _, typ := range keyTypes { + hostKey, err := hostKeyFileOrCreate(dir, typ) + if err != nil { + return nil, err + } + signer, err := gossh.ParsePrivateKey(hostKey) + if err != nil { + return nil, err + } + ret = append(ret, signer) + } + return ret, nil +} + +func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { + path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") + v, err := os.ReadFile(path) + if err == nil { + return v, nil + } + if !os.IsNotExist(err) { + return nil, err + } + var priv any + switch typ { + default: + return nil, fmt.Errorf("unsupported key type %q", typ) + case "ed25519": + _, priv, err = ed25519.GenerateKey(rand.Reader) + case "ecdsa": + // curve is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + curve := elliptic.P256() + priv, err = ecdsa.GenerateKey(curve, rand.Reader) + case "rsa": + // keySize is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + const keySize = 2048 + priv, err = rsa.GenerateKey(rand.Reader, keySize) + } + if err != nil { + return nil, err + } + mk, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, err + } + pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) + err = os.WriteFile(path, pemGen, 0700) + return pemGen, err +} diff --git a/cmd/sync-containers/main.go b/cmd/sync-containers/main.go index 68308cfeb3eda..6317b4943ae82 100644 --- a/cmd/sync-containers/main.go +++ b/cmd/sync-containers/main.go @@ -1,214 +1,214 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The sync-containers command synchronizes container image tags from one -// registry to another. -// -// It is intended as a workaround for ghcr.io's lack of good push credentials: -// you can either authorize "classic" Personal Access Tokens in your org (which -// are a common vector of very bad compromise), or you can get a short-lived -// credential in a Github action. -// -// Since we publish to both Docker Hub and ghcr.io, we use this program in a -// Github action to effectively rsync from docker hub into ghcr.io, so that we -// can continue to forbid dangerous Personal Access Tokens in the tailscale org. -package main - -import ( - "context" - "flag" - "fmt" - "log" - "sort" - "strings" - - "github.com/google/go-containerregistry/pkg/authn" - "github.com/google/go-containerregistry/pkg/authn/github" - "github.com/google/go-containerregistry/pkg/name" - v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/types" -) - -var ( - src = flag.String("src", "", "Source image") - dst = flag.String("dst", "", "Destination image") - max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") - dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") -) - -func main() { - flag.Parse() - - if *src == "" { - log.Fatalf("--src is required") - } - if *dst == "" { - log.Fatalf("--dst is required") - } - - keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) - opts := []remote.Option{ - remote.WithAuthFromKeychain(keychain), - remote.WithContext(context.Background()), - } - - stags, err := listTags(*src, opts...) - if err != nil { - log.Fatalf("listing source tags: %v", err) - } - dtags, err := listTags(*dst, opts...) - if err != nil { - log.Fatalf("listing destination tags: %v", err) - } - - add, remove := diffTags(stags, dtags) - if l := len(add); l > 0 { - log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) - if *max > 0 && l > *max { - log.Printf("Limiting sync to %d tags", *max) - add = add[:*max] - } - } - for _, tag := range add { - if !*dryRun { - log.Printf("Syncing tag %q", tag) - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Syncing tag %q: progress error: %v", tag, err) - } - } else { - log.Printf("Dry run: would sync tag %q", tag) - } - } - - if len(remove) > 0 { - log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) - log.Printf("Not removing any tags for safety.\n") - } - - var wellKnown = [...]string{"latest", "stable"} - for _, tag := range wellKnown { - if needsUpdate(*src, *dst, tag) { - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Updating tag %q: progress error: %v", tag, err) - } - } - } -} - -func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return err - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return err - } - - desc, err := remote.Get(src) - if err != nil { - return err - } - - ch := make(chan v1.Update, 10) - opts = append(opts, remote.WithProgress(ch)) - progressDone := make(chan struct{}) - - go func() { - defer close(progressDone) - for p := range ch { - fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) - if p.Error != nil { - fmt.Printf("error: %v\n", p.Error) - } - } - }() - - switch desc.MediaType { - case types.OCIManifestSchema1, types.DockerManifestSchema2: - img, err := desc.Image() - if err != nil { - return err - } - if err := remote.Write(dst, img, opts...); err != nil { - return err - } - case types.OCIImageIndex, types.DockerManifestList: - idx, err := desc.ImageIndex() - if err != nil { - return err - } - if err := remote.WriteIndex(dst, idx, opts...); err != nil { - return err - } - } - - <-progressDone - return nil -} - -func listTags(repoStr string, opts ...remote.Option) ([]string, error) { - repo, err := name.NewRepository(repoStr) - if err != nil { - return nil, err - } - - tags, err := remote.List(repo, opts...) - if err != nil { - return nil, err - } - - sort.Strings(tags) - return tags, nil -} - -func diffTags(src, dst []string) (add, remove []string) { - srcd := make(map[string]bool) - for _, tag := range src { - srcd[tag] = true - } - dstd := make(map[string]bool) - for _, tag := range dst { - dstd[tag] = true - } - - for _, tag := range src { - if !dstd[tag] { - add = append(add, tag) - } - } - for _, tag := range dst { - if !srcd[tag] { - remove = append(remove, tag) - } - } - sort.Strings(add) - sort.Strings(remove) - return add, remove -} - -func needsUpdate(srcStr, dstStr, tag string) bool { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return false - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return false - } - - srcDesc, err := remote.Get(src) - if err != nil { - return false - } - - dstDesc, err := remote.Get(dst) - if err != nil { - return true - } - - return srcDesc.Digest != dstDesc.Digest -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The sync-containers command synchronizes container image tags from one +// registry to another. +// +// It is intended as a workaround for ghcr.io's lack of good push credentials: +// you can either authorize "classic" Personal Access Tokens in your org (which +// are a common vector of very bad compromise), or you can get a short-lived +// credential in a Github action. +// +// Since we publish to both Docker Hub and ghcr.io, we use this program in a +// Github action to effectively rsync from docker hub into ghcr.io, so that we +// can continue to forbid dangerous Personal Access Tokens in the tailscale org. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "sort" + "strings" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/authn/github" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/types" +) + +var ( + src = flag.String("src", "", "Source image") + dst = flag.String("dst", "", "Destination image") + max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") + dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") +) + +func main() { + flag.Parse() + + if *src == "" { + log.Fatalf("--src is required") + } + if *dst == "" { + log.Fatalf("--dst is required") + } + + keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) + opts := []remote.Option{ + remote.WithAuthFromKeychain(keychain), + remote.WithContext(context.Background()), + } + + stags, err := listTags(*src, opts...) + if err != nil { + log.Fatalf("listing source tags: %v", err) + } + dtags, err := listTags(*dst, opts...) + if err != nil { + log.Fatalf("listing destination tags: %v", err) + } + + add, remove := diffTags(stags, dtags) + if l := len(add); l > 0 { + log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) + if *max > 0 && l > *max { + log.Printf("Limiting sync to %d tags", *max) + add = add[:*max] + } + } + for _, tag := range add { + if !*dryRun { + log.Printf("Syncing tag %q", tag) + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Syncing tag %q: progress error: %v", tag, err) + } + } else { + log.Printf("Dry run: would sync tag %q", tag) + } + } + + if len(remove) > 0 { + log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) + log.Printf("Not removing any tags for safety.\n") + } + + var wellKnown = [...]string{"latest", "stable"} + for _, tag := range wellKnown { + if needsUpdate(*src, *dst, tag) { + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Updating tag %q: progress error: %v", tag, err) + } + } + } +} + +func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return err + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return err + } + + desc, err := remote.Get(src) + if err != nil { + return err + } + + ch := make(chan v1.Update, 10) + opts = append(opts, remote.WithProgress(ch)) + progressDone := make(chan struct{}) + + go func() { + defer close(progressDone) + for p := range ch { + fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) + if p.Error != nil { + fmt.Printf("error: %v\n", p.Error) + } + } + }() + + switch desc.MediaType { + case types.OCIManifestSchema1, types.DockerManifestSchema2: + img, err := desc.Image() + if err != nil { + return err + } + if err := remote.Write(dst, img, opts...); err != nil { + return err + } + case types.OCIImageIndex, types.DockerManifestList: + idx, err := desc.ImageIndex() + if err != nil { + return err + } + if err := remote.WriteIndex(dst, idx, opts...); err != nil { + return err + } + } + + <-progressDone + return nil +} + +func listTags(repoStr string, opts ...remote.Option) ([]string, error) { + repo, err := name.NewRepository(repoStr) + if err != nil { + return nil, err + } + + tags, err := remote.List(repo, opts...) + if err != nil { + return nil, err + } + + sort.Strings(tags) + return tags, nil +} + +func diffTags(src, dst []string) (add, remove []string) { + srcd := make(map[string]bool) + for _, tag := range src { + srcd[tag] = true + } + dstd := make(map[string]bool) + for _, tag := range dst { + dstd[tag] = true + } + + for _, tag := range src { + if !dstd[tag] { + add = append(add, tag) + } + } + for _, tag := range dst { + if !srcd[tag] { + remove = append(remove, tag) + } + } + sort.Strings(add) + sort.Strings(remove) + return add, remove +} + +func needsUpdate(srcStr, dstStr, tag string) bool { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return false + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return false + } + + srcDesc, err := remote.Get(src) + if err != nil { + return false + } + + dstDesc, err := remote.Get(dst) + if err != nil { + return true + } + + return srcDesc.Digest != dstDesc.Digest +} diff --git a/cmd/tailscale/cli/diag.go b/cmd/tailscale/cli/diag.go index a1616f851e142..ebf26985fe0bd 100644 --- a/cmd/tailscale/cli/diag.go +++ b/cmd/tailscale/cli/diag.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || windows || darwin - -package cli - -import ( - "fmt" - "os/exec" - "path/filepath" - "runtime" - "strings" - - ps "github.com/mitchellh/go-ps" - "tailscale.com/version/distro" -) - -// fixTailscaledConnectError is called when the local tailscaled has -// been determined unreachable due to the provided origErr value. It -// returns either the same error or a better one to help the user -// understand why tailscaled isn't running for their platform. -func fixTailscaledConnectError(origErr error) error { - procs, err := ps.Processes() - if err != nil { - return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") - } - var foundProc ps.Process - for _, proc := range procs { - base := filepath.Base(proc.Executable()) - if base == "tailscaled" { - foundProc = proc - break - } - if runtime.GOOS == "darwin" && base == "IPNExtension" { - foundProc = proc - break - } - if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { - foundProc = proc - break - } - } - if foundProc == nil { - switch runtime.GOOS { - case "windows": - return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") - case "darwin": - return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") - case "linux": - var hint string - if isSystemdSystem() { - hint = " (sudo systemctl start tailscaled ?)" - } - return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) - } - return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") - } - return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) -} - -// isSystemdSystem reports whether the current machine uses systemd -// and in particular whether the systemctl command is available. -func isSystemdSystem() bool { - if runtime.GOOS != "linux" { - return false - } - switch distro.Get() { - case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: - return false - } - _, err := exec.LookPath("systemctl") - return err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || windows || darwin + +package cli + +import ( + "fmt" + "os/exec" + "path/filepath" + "runtime" + "strings" + + ps "github.com/mitchellh/go-ps" + "tailscale.com/version/distro" +) + +// fixTailscaledConnectError is called when the local tailscaled has +// been determined unreachable due to the provided origErr value. It +// returns either the same error or a better one to help the user +// understand why tailscaled isn't running for their platform. +func fixTailscaledConnectError(origErr error) error { + procs, err := ps.Processes() + if err != nil { + return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") + } + var foundProc ps.Process + for _, proc := range procs { + base := filepath.Base(proc.Executable()) + if base == "tailscaled" { + foundProc = proc + break + } + if runtime.GOOS == "darwin" && base == "IPNExtension" { + foundProc = proc + break + } + if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { + foundProc = proc + break + } + } + if foundProc == nil { + switch runtime.GOOS { + case "windows": + return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") + case "darwin": + return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") + case "linux": + var hint string + if isSystemdSystem() { + hint = " (sudo systemctl start tailscaled ?)" + } + return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) + } + return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") + } + return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) +} + +// isSystemdSystem reports whether the current machine uses systemd +// and in particular whether the systemctl command is available. +func isSystemdSystem() bool { + if runtime.GOOS != "linux" { + return false + } + switch distro.Get() { + case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: + return false + } + _, err := exec.LookPath("systemctl") + return err == nil +} diff --git a/cmd/tailscale/cli/diag_other.go b/cmd/tailscale/cli/diag_other.go index 82058ef7a139c..ece10cc79a822 100644 --- a/cmd/tailscale/cli/diag_other.go +++ b/cmd/tailscale/cli/diag_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package cli - -import "fmt" - -// The github.com/mitchellh/go-ps package doesn't work on all platforms, -// so just don't diagnose connect failures. - -func fixTailscaledConnectError(origErr error) error { - return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package cli + +import "fmt" + +// The github.com/mitchellh/go-ps package doesn't work on all platforms, +// so just don't diagnose connect failures. + +func fixTailscaledConnectError(origErr error) error { + return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) +} diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 06ef8503f048e..15305c3ce3ed3 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/ipn" - "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" -) - -func TestCalcAdvertiseRoutesForSet(t *testing.T) { - pfx := netip.MustParsePrefix - tests := []struct { - name string - setExit *bool - setRoutes *string - was []netip.Prefix - want []netip.Prefix - }{ - { - name: "empty", - }, - { - name: "advertise-exit", - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-exit/already-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setExit: ptr.To(true), - want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "stop-advertise-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(false), - want: nil, - }, - { - name: "stop-advertise-exit/with-routes", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setExit: ptr.To(false), - want: []netip.Prefix{pfx("34.0.0.0/16")}, - }, - { - name: "advertise-routes", - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "advertise-routes/already-exit", - was: tsaddr.ExitRoutes(), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes/already-diff-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "stop-advertise-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To(""), - want: nil, - }, - { - name: "stop-advertise-routes/already-exit", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setRoutes: ptr.To(""), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-routes-and-exit", - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-routes", - was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - curPrefs := &ipn.Prefs{ - AdvertiseRoutes: tc.was, - } - sa := setArgsT{} - if tc.setExit != nil { - sa.advertiseDefaultRoute = *tc.setExit - } - if tc.setRoutes != nil { - sa.advertiseRoutes = *tc.setRoutes - } - got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) - if err != nil { - t.Fatal(err) - } - tsaddr.SortPrefixes(got) - tsaddr.SortPrefixes(tc.want) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("got %v, want %v", got, tc.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/tsaddr" + "tailscale.com/types/ptr" +) + +func TestCalcAdvertiseRoutesForSet(t *testing.T) { + pfx := netip.MustParsePrefix + tests := []struct { + name string + setExit *bool + setRoutes *string + was []netip.Prefix + want []netip.Prefix + }{ + { + name: "empty", + }, + { + name: "advertise-exit", + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-exit/already-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setExit: ptr.To(true), + want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "stop-advertise-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(false), + want: nil, + }, + { + name: "stop-advertise-exit/with-routes", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setExit: ptr.To(false), + want: []netip.Prefix{pfx("34.0.0.0/16")}, + }, + { + name: "advertise-routes", + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "advertise-routes/already-exit", + was: tsaddr.ExitRoutes(), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes/already-diff-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "stop-advertise-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To(""), + want: nil, + }, + { + name: "stop-advertise-routes/already-exit", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setRoutes: ptr.To(""), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-routes-and-exit", + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-routes", + was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + curPrefs := &ipn.Prefs{ + AdvertiseRoutes: tc.was, + } + sa := setArgsT{} + if tc.setExit != nil { + sa.advertiseDefaultRoute = *tc.setExit + } + if tc.setRoutes != nil { + sa.advertiseRoutes = *tc.setRoutes + } + got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) + if err != nil { + t.Fatal(err) + } + tsaddr.SortPrefixes(got) + tsaddr.SortPrefixes(tc.want) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} diff --git a/cmd/tailscale/cli/ssh_exec.go b/cmd/tailscale/cli/ssh_exec.go index 7f7d2a4d5cfe0..10e52903dea64 100644 --- a/cmd/tailscale/cli/ssh_exec.go +++ b/cmd/tailscale/cli/ssh_exec.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package cli - -import ( - "errors" - "os" - "os/exec" - "syscall" -) - -func findSSH() (string, error) { - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { - return err - } - return errors.New("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package cli + +import ( + "errors" + "os" + "os/exec" + "syscall" +) + +func findSSH() (string, error) { + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { + return err + } + return errors.New("unreachable") +} diff --git a/cmd/tailscale/cli/ssh_exec_js.go b/cmd/tailscale/cli/ssh_exec_js.go index aa0c09e89ab66..40effc7cafc7e 100644 --- a/cmd/tailscale/cli/ssh_exec_js.go +++ b/cmd/tailscale/cli/ssh_exec_js.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" -) - -func findSSH() (string, error) { - return "", errors.New("Not implemented") -} - -func execSSH(ssh string, argv []string) error { - return errors.New("Not implemented") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" +) + +func findSSH() (string, error) { + return "", errors.New("Not implemented") +} + +func execSSH(ssh string, argv []string) error { + return errors.New("Not implemented") +} diff --git a/cmd/tailscale/cli/ssh_exec_windows.go b/cmd/tailscale/cli/ssh_exec_windows.go index 30ab70d046dd4..e249afe667401 100644 --- a/cmd/tailscale/cli/ssh_exec_windows.go +++ b/cmd/tailscale/cli/ssh_exec_windows.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" - "os" - "os/exec" - "path/filepath" -) - -func findSSH() (string, error) { - // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior - // occurred with ssh.exe provided by msys2/cygwin and other environments. - if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { - exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") - if st, err := os.Stat(exe); err == nil && !st.IsDir() { - return exe, nil - } - } - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - // Don't use syscall.Exec on Windows, it's not fully implemented. - cmd := exec.Command(ssh, argv[1:]...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - var ee *exec.ExitError - err := cmd.Run() - if errors.As(err, &ee) { - os.Exit(ee.ExitCode()) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" + "os" + "os/exec" + "path/filepath" +) + +func findSSH() (string, error) { + // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior + // occurred with ssh.exe provided by msys2/cygwin and other environments. + if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { + exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") + if st, err := os.Stat(exe); err == nil && !st.IsDir() { + return exe, nil + } + } + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + // Don't use syscall.Exec on Windows, it's not fully implemented. + cmd := exec.Command(ssh, argv[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + var ee *exec.ExitError + err := cmd.Run() + if errors.As(err, &ee) { + os.Exit(ee.ExitCode()) + } + return err +} diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go index 07423b69fa9e6..71c0caaa69ad5 100644 --- a/cmd/tailscale/cli/ssh_unix.go +++ b/cmd/tailscale/cli/ssh_unix.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !wasm && !windows && !plan9 - -package cli - -import ( - "bytes" - "os" - "path/filepath" - "runtime" - "strconv" - - "golang.org/x/sys/unix" -) - -func init() { - getSSHClientEnvVar = func() string { - if os.Getenv("SUDO_USER") == "" { - // No sudo, just check the env. - return os.Getenv("SSH_CLIENT") - } - if runtime.GOOS != "linux" { - // TODO(maisem): implement this for other platforms. It's not clear - // if there is a way to get the environment for a given process on - // darwin and bsd. - return "" - } - // SID is the session ID of the user's login session. - // It is also the process ID of the original shell that the user logged in with. - // We only need to check the environment of that process. - sid, err := unix.Getsid(os.Getpid()) - if err != nil { - return "" - } - b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) - if err != nil { - return "" - } - prefix := []byte("SSH_CLIENT=") - for _, env := range bytes.Split(b, []byte{0}) { - if bytes.HasPrefix(env, prefix) { - return string(env[len(prefix):]) - } - } - return "" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !wasm && !windows && !plan9 + +package cli + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func init() { + getSSHClientEnvVar = func() string { + if os.Getenv("SUDO_USER") == "" { + // No sudo, just check the env. + return os.Getenv("SSH_CLIENT") + } + if runtime.GOOS != "linux" { + // TODO(maisem): implement this for other platforms. It's not clear + // if there is a way to get the environment for a given process on + // darwin and bsd. + return "" + } + // SID is the session ID of the user's login session. + // It is also the process ID of the original shell that the user logged in with. + // We only need to check the environment of that process. + sid, err := unix.Getsid(os.Getpid()) + if err != nil { + return "" + } + b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) + if err != nil { + return "" + } + prefix := []byte("SSH_CLIENT=") + for _, env := range bytes.Split(b, []byte{0}) { + if bytes.HasPrefix(env, prefix) { + return string(env[len(prefix):]) + } + } + return "" + } +} diff --git a/cmd/tailscale/cli/web_test.go b/cmd/tailscale/cli/web_test.go index f1880597e5c53..f2470b364c41e 100644 --- a/cmd/tailscale/cli/web_test.go +++ b/cmd/tailscale/cli/web_test.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "testing" -) - -func TestUrlOfListenAddr(t *testing.T) { - tests := []struct { - name string - in, want string - }{ - { - name: "TestLocalhost", - in: "localhost:8088", - want: "http://localhost:8088", - }, - { - name: "TestNoHost", - in: ":8088", - want: "http://127.0.0.1:8088", - }, - { - name: "TestExplicitHost", - in: "127.0.0.2:8088", - want: "http://127.0.0.2:8088", - }, - { - name: "TestIPv6", - in: "[::1]:8088", - want: "http://[::1]:8088", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - u := urlOfListenAddr(tt.in) - if u != tt.want { - t.Errorf("expected url: %q, got: %q", tt.want, u) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "testing" +) + +func TestUrlOfListenAddr(t *testing.T) { + tests := []struct { + name string + in, want string + }{ + { + name: "TestLocalhost", + in: "localhost:8088", + want: "http://localhost:8088", + }, + { + name: "TestNoHost", + in: ":8088", + want: "http://127.0.0.1:8088", + }, + { + name: "TestExplicitHost", + in: "127.0.0.2:8088", + want: "http://127.0.0.2:8088", + }, + { + name: "TestIPv6", + in: "[::1]:8088", + want: "http://[::1]:8088", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := urlOfListenAddr(tt.in) + if u != tt.want { + t.Errorf("expected url: %q, got: %q", tt.want, u) + } + }) + } +} diff --git a/cmd/tailscale/generate.go b/cmd/tailscale/generate.go index fa38b370417aa..5c2e9be915980 100644 --- a/cmd/tailscale/generate.go +++ b/cmd/tailscale/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscale/tailscale.go b/cmd/tailscale/tailscale.go index 1848d65088c3d..f6adb6c197071 100644 --- a/cmd/tailscale/tailscale.go +++ b/cmd/tailscale/tailscale.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tailscale command is the Tailscale command-line client. It interacts -// with the tailscaled node agent. -package main // import "tailscale.com/cmd/tailscale" - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "tailscale.com/cmd/tailscale/cli" -) - -func main() { - args := os.Args[1:] - if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { - args = []string{"web", "-cgi"} - } - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tailscale command is the Tailscale command-line client. It interacts +// with the tailscaled node agent. +package main // import "tailscale.com/cmd/tailscale" + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "tailscale.com/cmd/tailscale/cli" +) + +func main() { + args := os.Args[1:] + if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { + args = []string{"web", "-cgi"} + } + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/tailscale/windows-manifest.xml b/cmd/tailscale/windows-manifest.xml index 5eaa54fa514e3..6c5f46058387f 100644 --- a/cmd/tailscale/windows-manifest.xml +++ b/cmd/tailscale/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/childproc/childproc.go b/cmd/tailscaled/childproc/childproc.go index 068015c59f3eb..cc83a06c6ee7c 100644 --- a/cmd/tailscaled/childproc/childproc.go +++ b/cmd/tailscaled/childproc/childproc.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package childproc allows other packages to register "tailscaled be-child" -// child process hook code. This avoids duplicating build tags in the -// tailscaled package. Instead, the code that needs to fork/exec the self -// executable (when it's tailscaled) can instead register the code -// they want to run. -package childproc - -var Code = map[string]func([]string) error{} - -// Add registers code f to run as 'tailscaled be-child [args]'. -func Add(typ string, f func(args []string) error) { - if _, dup := Code[typ]; dup { - panic("dup hook " + typ) - } - Code[typ] = f -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package childproc allows other packages to register "tailscaled be-child" +// child process hook code. This avoids duplicating build tags in the +// tailscaled package. Instead, the code that needs to fork/exec the self +// executable (when it's tailscaled) can instead register the code +// they want to run. +package childproc + +var Code = map[string]func([]string) error{} + +// Add registers code f to run as 'tailscaled be-child [args]'. +func Add(typ string, f func(args []string) error) { + if _, dup := Code[typ]; dup { + panic("dup hook " + typ) + } + Code[typ] = f +} diff --git a/cmd/tailscaled/generate.go b/cmd/tailscaled/generate.go index fa38b370417aa..5c2e9be915980 100644 --- a/cmd/tailscaled/generate.go +++ b/cmd/tailscaled/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscaled/install_darwin.go b/cmd/tailscaled/install_darwin.go index 9013b39ba3567..05e5eaed8af90 100644 --- a/cmd/tailscaled/install_darwin.go +++ b/cmd/tailscaled/install_darwin.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "errors" - "fmt" - "io" - "io/fs" - "os" - "os/exec" - "path/filepath" -) - -func init() { - installSystemDaemon = installSystemDaemonDarwin - uninstallSystemDaemon = uninstallSystemDaemonDarwin -} - -// darwinLaunchdPlist is the launchd.plist that's written to -// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the -// future) a user-specific location. -// -// See man launchd.plist. -const darwinLaunchdPlist = ` - - - - - - Label - com.tailscale.tailscaled - - ProgramArguments - - /usr/local/bin/tailscaled - - - RunAtLoad - - - - -` - -const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" -const targetBin = "/usr/local/bin/tailscaled" -const service = "com.tailscale.tailscaled" - -func uninstallSystemDaemonDarwin(args []string) (ret error) { - if len(args) > 0 { - return errors.New("uninstall subcommand takes no arguments") - } - - plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() - _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. - running := err == nil - - if running { - out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() - if err != nil { - fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) - ret = err - } - out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() - if err != nil { - fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) - if ret == nil { - ret = err - } - } - } - - if err := os.Remove(sysPlist); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - - // Do not delete targetBin if it's a symlink, which happens if it was installed via - // Homebrew. - if isSymlink(targetBin) { - return ret - } - - if err := os.Remove(targetBin); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - return ret -} - -func installSystemDaemonDarwin(args []string) (err error) { - if len(args) > 0 { - return errors.New("install subcommand takes no arguments") - } - defer func() { - if err != nil && os.Getuid() != 0 { - err = fmt.Errorf("%w; try running tailscaled with sudo", err) - } - }() - - // Best effort: - uninstallSystemDaemonDarwin(nil) - - exe, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to find our own executable path: %w", err) - } - - same, err := sameFile(exe, targetBin) - if err != nil { - return err - } - - // Do not overwrite targetBin with the binary file if it it's already - // pointing to it. This is primarily to handle Homebrew that writes - // /usr/local/bin/tailscaled is a symlink to the actual binary. - if !same { - if err := copyBinary(exe, targetBin); err != nil { - return err - } - } - if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { - return err - } - - if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) - } - - if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) - } - - return nil -} - -// copyBinary copies binary file `src` into `dst`. -func copyBinary(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - tmpBin := dst + ".tmp" - f, err := os.Create(tmpBin) - if err != nil { - return err - } - srcf, err := os.Open(src) - if err != nil { - f.Close() - return err - } - _, err = io.Copy(f, srcf) - srcf.Close() - if err != nil { - f.Close() - return err - } - if err := f.Close(); err != nil { - return err - } - if err := os.Chmod(tmpBin, 0755); err != nil { - return err - } - if err := os.Rename(tmpBin, dst); err != nil { - return err - } - - return nil -} - -func isSymlink(path string) bool { - fi, err := os.Lstat(path) - return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) -} - -// sameFile returns true if both file paths exist and resolve to the same file. -func sameFile(path1, path2 string) (bool, error) { - dst1, err := filepath.EvalSymlinks(path1) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) - } - dst2, err := filepath.EvalSymlinks(path2) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) - } - return dst1 == dst2, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "os/exec" + "path/filepath" +) + +func init() { + installSystemDaemon = installSystemDaemonDarwin + uninstallSystemDaemon = uninstallSystemDaemonDarwin +} + +// darwinLaunchdPlist is the launchd.plist that's written to +// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the +// future) a user-specific location. +// +// See man launchd.plist. +const darwinLaunchdPlist = ` + + + + + + Label + com.tailscale.tailscaled + + ProgramArguments + + /usr/local/bin/tailscaled + + + RunAtLoad + + + + +` + +const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" +const targetBin = "/usr/local/bin/tailscaled" +const service = "com.tailscale.tailscaled" + +func uninstallSystemDaemonDarwin(args []string) (ret error) { + if len(args) > 0 { + return errors.New("uninstall subcommand takes no arguments") + } + + plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() + _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. + running := err == nil + + if running { + out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() + if err != nil { + fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) + ret = err + } + out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() + if err != nil { + fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) + if ret == nil { + ret = err + } + } + } + + if err := os.Remove(sysPlist); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + + // Do not delete targetBin if it's a symlink, which happens if it was installed via + // Homebrew. + if isSymlink(targetBin) { + return ret + } + + if err := os.Remove(targetBin); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + return ret +} + +func installSystemDaemonDarwin(args []string) (err error) { + if len(args) > 0 { + return errors.New("install subcommand takes no arguments") + } + defer func() { + if err != nil && os.Getuid() != 0 { + err = fmt.Errorf("%w; try running tailscaled with sudo", err) + } + }() + + // Best effort: + uninstallSystemDaemonDarwin(nil) + + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to find our own executable path: %w", err) + } + + same, err := sameFile(exe, targetBin) + if err != nil { + return err + } + + // Do not overwrite targetBin with the binary file if it it's already + // pointing to it. This is primarily to handle Homebrew that writes + // /usr/local/bin/tailscaled is a symlink to the actual binary. + if !same { + if err := copyBinary(exe, targetBin); err != nil { + return err + } + } + if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { + return err + } + + if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) + } + + if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) + } + + return nil +} + +// copyBinary copies binary file `src` into `dst`. +func copyBinary(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + tmpBin := dst + ".tmp" + f, err := os.Create(tmpBin) + if err != nil { + return err + } + srcf, err := os.Open(src) + if err != nil { + f.Close() + return err + } + _, err = io.Copy(f, srcf) + srcf.Close() + if err != nil { + f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + if err := os.Chmod(tmpBin, 0755); err != nil { + return err + } + if err := os.Rename(tmpBin, dst); err != nil { + return err + } + + return nil +} + +func isSymlink(path string) bool { + fi, err := os.Lstat(path) + return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) +} + +// sameFile returns true if both file paths exist and resolve to the same file. +func sameFile(path1, path2 string) (bool, error) { + dst1, err := filepath.EvalSymlinks(path1) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) + } + dst2, err := filepath.EvalSymlinks(path2) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) + } + return dst1 == dst2, nil +} diff --git a/cmd/tailscaled/install_windows.go b/cmd/tailscaled/install_windows.go index 9e39c8ab37074..c36418642d2b4 100644 --- a/cmd/tailscaled/install_windows.go +++ b/cmd/tailscaled/install_windows.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "context" - "errors" - "fmt" - "os" - "time" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/mgr" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/osshare" -) - -func init() { - installSystemDaemon = installSystemDaemonWindows - uninstallSystemDaemon = uninstallSystemDaemonWindows -} - -func installSystemDaemonWindows(args []string) (err error) { - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - - service, err := m.OpenService(serviceName) - if err == nil { - service.Close() - return fmt.Errorf("service %q is already installed", serviceName) - } - - // no such service; proceed to install the service. - - exe, err := os.Executable() - if err != nil { - return err - } - - c := mgr.Config{ - ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, - StartType: mgr.StartAutomatic, - ErrorControl: mgr.ErrorNormal, - DisplayName: serviceName, - Description: "Connects this computer to others on the Tailscale network.", - } - - service, err = m.CreateService(serviceName, exe, c) - if err != nil { - return fmt.Errorf("failed to create %q service: %v", serviceName, err) - } - defer service.Close() - - // Exponential backoff is often too aggressive, so use (mostly) - // squares instead. - ra := []mgr.RecoveryAction{ - {mgr.ServiceRestart, 1 * time.Second}, - {mgr.ServiceRestart, 2 * time.Second}, - {mgr.ServiceRestart, 4 * time.Second}, - {mgr.ServiceRestart, 9 * time.Second}, - {mgr.ServiceRestart, 16 * time.Second}, - {mgr.ServiceRestart, 25 * time.Second}, - {mgr.ServiceRestart, 36 * time.Second}, - {mgr.ServiceRestart, 49 * time.Second}, - {mgr.ServiceRestart, 64 * time.Second}, - } - const resetPeriodSecs = 60 - err = service.SetRecoveryActions(ra, resetPeriodSecs) - if err != nil { - return fmt.Errorf("failed to set service recovery actions: %v", err) - } - - return nil -} - -func uninstallSystemDaemonWindows(args []string) (ret error) { - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) - - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - defer m.Disconnect() - - service, err := m.OpenService(serviceName) - if err != nil { - return fmt.Errorf("failed to open %q service: %v", serviceName, err) - } - - st, err := service.Query() - if err != nil { - service.Close() - return fmt.Errorf("failed to query service state: %v", err) - } - if st.State != svc.Stopped { - service.Control(svc.Stop) - } - err = service.Delete() - service.Close() - if err != nil { - return fmt.Errorf("failed to delete service: %v", err) - } - - bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) - end := time.Now().Add(15 * time.Second) - for time.Until(end) > 0 { - service, err = m.OpenService(serviceName) - if err != nil { - // service is no longer openable; success! - break - } - service.Close() - bo.BackOff(context.Background(), errors.New("service not deleted")) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/osshare" +) + +func init() { + installSystemDaemon = installSystemDaemonWindows + uninstallSystemDaemon = uninstallSystemDaemonWindows +} + +func installSystemDaemonWindows(args []string) (err error) { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + + service, err := m.OpenService(serviceName) + if err == nil { + service.Close() + return fmt.Errorf("service %q is already installed", serviceName) + } + + // no such service; proceed to install the service. + + exe, err := os.Executable() + if err != nil { + return err + } + + c := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceName, + Description: "Connects this computer to others on the Tailscale network.", + } + + service, err = m.CreateService(serviceName, exe, c) + if err != nil { + return fmt.Errorf("failed to create %q service: %v", serviceName, err) + } + defer service.Close() + + // Exponential backoff is often too aggressive, so use (mostly) + // squares instead. + ra := []mgr.RecoveryAction{ + {mgr.ServiceRestart, 1 * time.Second}, + {mgr.ServiceRestart, 2 * time.Second}, + {mgr.ServiceRestart, 4 * time.Second}, + {mgr.ServiceRestart, 9 * time.Second}, + {mgr.ServiceRestart, 16 * time.Second}, + {mgr.ServiceRestart, 25 * time.Second}, + {mgr.ServiceRestart, 36 * time.Second}, + {mgr.ServiceRestart, 49 * time.Second}, + {mgr.ServiceRestart, 64 * time.Second}, + } + const resetPeriodSecs = 60 + err = service.SetRecoveryActions(ra, resetPeriodSecs) + if err != nil { + return fmt.Errorf("failed to set service recovery actions: %v", err) + } + + return nil +} + +func uninstallSystemDaemonWindows(args []string) (ret error) { + // Remove file sharing from Windows shell (noop in non-windows) + osshare.SetFileSharingEnabled(false, logger.Discard) + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + defer m.Disconnect() + + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("failed to open %q service: %v", serviceName, err) + } + + st, err := service.Query() + if err != nil { + service.Close() + return fmt.Errorf("failed to query service state: %v", err) + } + if st.State != svc.Stopped { + service.Control(svc.Stop) + } + err = service.Delete() + service.Close() + if err != nil { + return fmt.Errorf("failed to delete service: %v", err) + } + + bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) + end := time.Now().Add(15 * time.Second) + for time.Until(end) > 0 { + service, err = m.OpenService(serviceName) + if err != nil { + // service is no longer openable; success! + break + } + service.Close() + bo.BackOff(context.Background(), errors.New("service not deleted")) + } + return nil +} diff --git a/cmd/tailscaled/proxy.go b/cmd/tailscaled/proxy.go index 109ad029d3aaf..a91c62bfa44ac 100644 --- a/cmd/tailscaled/proxy.go +++ b/cmd/tailscaled/proxy.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -// HTTP proxy code - -package main - -import ( - "context" - "io" - "net" - "net/http" - "net/http/httputil" - "strings" -) - -// httpProxyHandler returns an HTTP proxy http.Handler using the -// provided backend dialer. -func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { - rp := &httputil.ReverseProxy{ - Director: func(r *http.Request) {}, // no change - Transport: &http.Transport{ - DialContext: dialer, - }, - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - backURL := r.RequestURI - if strings.HasPrefix(backURL, "/") || backURL == "*" { - http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) - return - } - rp.ServeHTTP(w, r) - return - } - - // CONNECT support: - - dst := r.RequestURI - c, err := dialer(r.Context(), "tcp", dst) - if err != nil { - w.Header().Set("Tailscale-Connect-Error", err.Error()) - http.Error(w, err.Error(), 500) - return - } - defer c.Close() - - cc, ccbuf, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), 500) - return - } - defer cc.Close() - - io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") - - var clientSrc io.Reader = ccbuf - if ccbuf.Reader.Buffered() == 0 { - // In the common case (with no - // buffered data), read directly from - // the underlying client connection to - // save some memory, letting the - // bufio.Reader/Writer get GC'ed. - clientSrc = cc - } - - errc := make(chan error, 1) - go func() { - _, err := io.Copy(cc, c) - errc <- err - }() - go func() { - _, err := io.Copy(c, clientSrc) - errc <- err - }() - <-errc - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +// HTTP proxy code + +package main + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" +) + +// httpProxyHandler returns an HTTP proxy http.Handler using the +// provided backend dialer. +func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { + rp := &httputil.ReverseProxy{ + Director: func(r *http.Request) {}, // no change + Transport: &http.Transport{ + DialContext: dialer, + }, + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + backURL := r.RequestURI + if strings.HasPrefix(backURL, "/") || backURL == "*" { + http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) + return + } + rp.ServeHTTP(w, r) + return + } + + // CONNECT support: + + dst := r.RequestURI + c, err := dialer(r.Context(), "tcp", dst) + if err != nil { + w.Header().Set("Tailscale-Connect-Error", err.Error()) + http.Error(w, err.Error(), 500) + return + } + defer c.Close() + + cc, ccbuf, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer cc.Close() + + io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") + + var clientSrc io.Reader = ccbuf + if ccbuf.Reader.Buffered() == 0 { + // In the common case (with no + // buffered data), read directly from + // the underlying client connection to + // save some memory, letting the + // bufio.Reader/Writer get GC'ed. + clientSrc = cc + } + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(cc, c) + errc <- err + }() + go func() { + _, err := io.Copy(c, clientSrc) + errc <- err + }() + <-errc + }) +} diff --git a/cmd/tailscaled/sigpipe.go b/cmd/tailscaled/sigpipe.go index 695a880248bc0..2fcdab2a4660e 100644 --- a/cmd/tailscaled/sigpipe.go +++ b/cmd/tailscaled/sigpipe.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.21 && !plan9 - -package main - -import "syscall" - -func init() { - sigPipe = syscall.SIGPIPE -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.21 && !plan9 + +package main + +import "syscall" + +func init() { + sigPipe = syscall.SIGPIPE +} diff --git a/cmd/tailscaled/tailscaled.defaults b/cmd/tailscaled/tailscaled.defaults index 693a6190bfac8..e8384a4f82097 100644 --- a/cmd/tailscaled/tailscaled.defaults +++ b/cmd/tailscaled/tailscaled.defaults @@ -1,8 +1,8 @@ -# Set the port to listen on for incoming VPN packets. -# Remote nodes will automatically be informed about the new port number, -# but you might want to configure this in order to set external firewall -# settings. -PORT="41641" - -# Extra flags you might want to pass to tailscaled. -FLAGS="" +# Set the port to listen on for incoming VPN packets. +# Remote nodes will automatically be informed about the new port number, +# but you might want to configure this in order to set external firewall +# settings. +PORT="41641" + +# Extra flags you might want to pass to tailscaled. +FLAGS="" diff --git a/cmd/tailscaled/tailscaled.openrc b/cmd/tailscaled/tailscaled.openrc index 6193247ce3131..309d70f23a26f 100755 --- a/cmd/tailscaled/tailscaled.openrc +++ b/cmd/tailscaled/tailscaled.openrc @@ -1,25 +1,25 @@ -#!/sbin/openrc-run - -set -a -source /etc/default/tailscaled -set +a - -command="/usr/sbin/tailscaled" -command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" -command_background=true -pidfile="/run/tailscaled.pid" -start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" - -depend() { - need net -} - -start_pre() { - mkdir -p /var/run/tailscale - mkdir -p /var/lib/tailscale - $command --cleanup -} - -stop_post() { - $command --cleanup -} +#!/sbin/openrc-run + +set -a +source /etc/default/tailscaled +set +a + +command="/usr/sbin/tailscaled" +command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" +command_background=true +pidfile="/run/tailscaled.pid" +start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" + +depend() { + need net +} + +start_pre() { + mkdir -p /var/run/tailscale + mkdir -p /var/lib/tailscale + $command --cleanup +} + +stop_post() { + $command --cleanup +} diff --git a/cmd/tailscaled/tailscaled_bird.go b/cmd/tailscaled/tailscaled_bird.go index 885f552cb8f50..c76f77bec6e36 100644 --- a/cmd/tailscaled/tailscaled_bird.go +++ b/cmd/tailscaled/tailscaled_bird.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird - -package main - -import ( - "tailscale.com/chirp" - "tailscale.com/wgengine" -) - -func init() { - createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { - return chirp.New(ctlSocket) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird + +package main + +import ( + "tailscale.com/chirp" + "tailscale.com/wgengine" +) + +func init() { + createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { + return chirp.New(ctlSocket) + } +} diff --git a/cmd/tailscaled/tailscaled_notwindows.go b/cmd/tailscaled/tailscaled_notwindows.go index b0a7c159833f5..d5361cf286d3d 100644 --- a/cmd/tailscaled/tailscaled_notwindows.go +++ b/cmd/tailscaled/tailscaled_notwindows.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && go1.19 - -package main // import "tailscale.com/cmd/tailscaled" - -import "tailscale.com/logpolicy" - -func isWindowsService() bool { return false } - -func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } - -func beWindowsSubprocess() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && go1.19 + +package main // import "tailscale.com/cmd/tailscaled" + +import "tailscale.com/logpolicy" + +func isWindowsService() bool { return false } + +func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } + +func beWindowsSubprocess() bool { return false } diff --git a/cmd/tailscaled/windows-manifest.xml b/cmd/tailscaled/windows-manifest.xml index 5eaa54fa514e3..6c5f46058387f 100644 --- a/cmd/tailscaled/windows-manifest.xml +++ b/cmd/tailscaled/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/with_cli.go b/cmd/tailscaled/with_cli.go index f191fdb45b288..a8554eb8ce9dc 100644 --- a/cmd/tailscaled/with_cli.go +++ b/cmd/tailscaled/with_cli.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_include_cli - -package main - -import ( - "fmt" - "os" - - "tailscale.com/cmd/tailscale/cli" -) - -func init() { - beCLI = func() { - args := os.Args[1:] - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_include_cli + +package main + +import ( + "fmt" + "os" + + "tailscale.com/cmd/tailscale/cli" +) + +func init() { + beCLI = func() { + args := os.Args[1:] + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + } +} diff --git a/cmd/testwrapper/args_test.go b/cmd/testwrapper/args_test.go index f7f30a7eb2fa5..10063d7bcf6e1 100644 --- a/cmd/testwrapper/args_test.go +++ b/cmd/testwrapper/args_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "slices" - "testing" -) - -func TestSplitArgs(t *testing.T) { - tests := []struct { - name string - in []string - pre, pkgs, post []string - }{ - { - name: "empty", - }, - { - name: "all", - in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, - pre: []string{"-v"}, - pkgs: []string{"pkg1", "pkg2"}, - post: []string{"-run", "TestFoo", "-timeout=20s"}, - }, - { - name: "only_pkgs", - in: []string{"./..."}, - pkgs: []string{"./..."}, - }, - { - name: "pkgs_and_post", - in: []string{"pkg1", "-run", "TestFoo"}, - pkgs: []string{"pkg1"}, - post: []string{"-run", "TestFoo"}, - }, - { - name: "pkgs_and_post", - in: []string{"-v", "pkg2"}, - pre: []string{"-v"}, - pkgs: []string{"pkg2"}, - }, - { - name: "only_args", - in: []string{"-v", "-run=TestFoo"}, - pre: []string{"-run", "TestFoo", "-v"}, // sorted - }, - { - name: "space_in_pre_arg", - in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, - pre: []string{"-run", "TestFoo"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "space_in_arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "test-arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - { - name: "dupe-args", - in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-race", "-v"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pre, pkgs, post, err := splitArgs(tt.in) - if err != nil { - t.Fatal(err) - } - if !slices.Equal(pre, tt.pre) { - t.Errorf("pre = %q; want %q", pre, tt.pre) - } - if !slices.Equal(pkgs, tt.pkgs) { - t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) - } - if !slices.Equal(post, tt.post) { - t.Errorf("post = %q; want %q", post, tt.post) - } - if t.Failed() { - t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "slices" + "testing" +) + +func TestSplitArgs(t *testing.T) { + tests := []struct { + name string + in []string + pre, pkgs, post []string + }{ + { + name: "empty", + }, + { + name: "all", + in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, + pre: []string{"-v"}, + pkgs: []string{"pkg1", "pkg2"}, + post: []string{"-run", "TestFoo", "-timeout=20s"}, + }, + { + name: "only_pkgs", + in: []string{"./..."}, + pkgs: []string{"./..."}, + }, + { + name: "pkgs_and_post", + in: []string{"pkg1", "-run", "TestFoo"}, + pkgs: []string{"pkg1"}, + post: []string{"-run", "TestFoo"}, + }, + { + name: "pkgs_and_post", + in: []string{"-v", "pkg2"}, + pre: []string{"-v"}, + pkgs: []string{"pkg2"}, + }, + { + name: "only_args", + in: []string{"-v", "-run=TestFoo"}, + pre: []string{"-run", "TestFoo", "-v"}, // sorted + }, + { + name: "space_in_pre_arg", + in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, + pre: []string{"-run", "TestFoo"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "space_in_arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "test-arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + { + name: "dupe-args", + in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-race", "-v"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pre, pkgs, post, err := splitArgs(tt.in) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(pre, tt.pre) { + t.Errorf("pre = %q; want %q", pre, tt.pre) + } + if !slices.Equal(pkgs, tt.pkgs) { + t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) + } + if !slices.Equal(post, tt.post) { + t.Errorf("post = %q; want %q", post, tt.post) + } + if t.Failed() { + t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) + } + }) + } +} diff --git a/cmd/testwrapper/flakytest/flakytest.go b/cmd/testwrapper/flakytest/flakytest.go index e5e21dd2159ba..494ed080b26a1 100644 --- a/cmd/testwrapper/flakytest/flakytest.go +++ b/cmd/testwrapper/flakytest/flakytest.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flakytest contains test helpers for marking a test as flaky. For -// tests run using cmd/testwrapper, a failed flaky test will cause tests to be -// re-run a few time until they succeed or exceed our iteration limit. -package flakytest - -import ( - "fmt" - "os" - "regexp" - "testing" -) - -// FlakyTestLogMessage is a sentinel value that is printed to stderr when a -// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests -// and retry them. -const FlakyTestLogMessage = "flakytest: this is a known flaky test" - -// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper -// when a flaky test is being (re)tried. It contains the attempt number, -// starting at 1. -const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" - -var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) - -// Mark sets the current test as a flaky test, such that if it fails, it will -// be retried a few times on failure. issue must be a GitHub issue that tracks -// the status of the flaky test being marked, of the format: -// -// https://github.com/tailscale/myRepo-H3re/issues/12345 -func Mark(t testing.TB, issue string) { - if !issueRegexp.MatchString(issue) { - t.Fatalf("bad issue format: %q", issue) - } - if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { - // We're being run under cmd/testwrapper so send our sentinel message - // to stderr. (We avoid doing this when the env is absent to avoid - // spamming people running tests without the wrapper) - fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) - } - t.Logf("flakytest: issue tracking this flaky test: %s", issue) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flakytest contains test helpers for marking a test as flaky. For +// tests run using cmd/testwrapper, a failed flaky test will cause tests to be +// re-run a few time until they succeed or exceed our iteration limit. +package flakytest + +import ( + "fmt" + "os" + "regexp" + "testing" +) + +// FlakyTestLogMessage is a sentinel value that is printed to stderr when a +// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests +// and retry them. +const FlakyTestLogMessage = "flakytest: this is a known flaky test" + +// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper +// when a flaky test is being (re)tried. It contains the attempt number, +// starting at 1. +const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" + +var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) + +// Mark sets the current test as a flaky test, such that if it fails, it will +// be retried a few times on failure. issue must be a GitHub issue that tracks +// the status of the flaky test being marked, of the format: +// +// https://github.com/tailscale/myRepo-H3re/issues/12345 +func Mark(t testing.TB, issue string) { + if !issueRegexp.MatchString(issue) { + t.Fatalf("bad issue format: %q", issue) + } + if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { + // We're being run under cmd/testwrapper so send our sentinel message + // to stderr. (We avoid doing this when the env is absent to avoid + // spamming people running tests without the wrapper) + fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) + } + t.Logf("flakytest: issue tracking this flaky test: %s", issue) +} diff --git a/cmd/testwrapper/flakytest/flakytest_test.go b/cmd/testwrapper/flakytest/flakytest_test.go index 551352f6ad8ea..85e77a939c75d 100644 --- a/cmd/testwrapper/flakytest/flakytest_test.go +++ b/cmd/testwrapper/flakytest/flakytest_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package flakytest - -import ( - "os" - "testing" -) - -func TestIssueFormat(t *testing.T) { - testCases := []struct { - issue string - want bool - }{ - {"https://github.com/tailscale/cOrp/issues/1234", true}, - {"https://github.com/otherproject/corp/issues/1234", false}, - {"https://github.com/tailscale/corp/issues/", false}, - } - for _, testCase := range testCases { - if issueRegexp.MatchString(testCase.issue) != testCase.want { - ss := "" - if !testCase.want { - ss = " not" - } - t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) - } - } -} - -// TestFlakeRun is a test that fails when run in the testwrapper -// for the first time, but succeeds on the second run. -// It's used to test whether the testwrapper retries flaky tests. -func TestFlakeRun(t *testing.T) { - Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue - e := os.Getenv(FlakeAttemptEnv) - if e == "" { - t.Skip("not running in testwrapper") - } - if e == "1" { - t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package flakytest + +import ( + "os" + "testing" +) + +func TestIssueFormat(t *testing.T) { + testCases := []struct { + issue string + want bool + }{ + {"https://github.com/tailscale/cOrp/issues/1234", true}, + {"https://github.com/otherproject/corp/issues/1234", false}, + {"https://github.com/tailscale/corp/issues/", false}, + } + for _, testCase := range testCases { + if issueRegexp.MatchString(testCase.issue) != testCase.want { + ss := "" + if !testCase.want { + ss = " not" + } + t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) + } + } +} + +// TestFlakeRun is a test that fails when run in the testwrapper +// for the first time, but succeeds on the second run. +// It's used to test whether the testwrapper retries flaky tests. +func TestFlakeRun(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue + e := os.Getenv(FlakeAttemptEnv) + if e == "" { + t.Skip("not running in testwrapper") + } + if e == "1" { + t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") + } +} diff --git a/cmd/tsconnect/.gitignore b/cmd/tsconnect/.gitignore index b791f8e64b14e..13615d1213d63 100644 --- a/cmd/tsconnect/.gitignore +++ b/cmd/tsconnect/.gitignore @@ -1,3 +1,3 @@ -node_modules/ -/dist -/pkg +node_modules/ +/dist +/pkg diff --git a/cmd/tsconnect/README.md b/cmd/tsconnect/README.md index f518f932e07eb..536cd7bbf562c 100644 --- a/cmd/tsconnect/README.md +++ b/cmd/tsconnect/README.md @@ -1,49 +1,49 @@ -# tsconnect - -The tsconnect command builds and serves the static site that is generated for -the Tailscale Connect JS/WASM client. - -## Development - -To start the development server: - -``` -./tool/go run ./cmd/tsconnect dev -``` - -The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). - -## Deployment - -To build the static assets necessary for serving, run: - -``` -./tool/go run ./cmd/tsconnect build -``` - -To serve them, run: - -``` -./tool/go run ./cmd/tsconnect serve -``` - -By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. - -# Library / NPM Package - -The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: - -``` -./tool/go run ./cmd/tsconnect build-pkg -``` - -That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). - -To do two-sided development (on both the NPM package and code that uses it), run: - -``` -./tool/go run ./cmd/tsconnect dev-pkg - -``` - -This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. +# tsconnect + +The tsconnect command builds and serves the static site that is generated for +the Tailscale Connect JS/WASM client. + +## Development + +To start the development server: + +``` +./tool/go run ./cmd/tsconnect dev +``` + +The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). + +## Deployment + +To build the static assets necessary for serving, run: + +``` +./tool/go run ./cmd/tsconnect build +``` + +To serve them, run: + +``` +./tool/go run ./cmd/tsconnect serve +``` + +By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. + +# Library / NPM Package + +The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: + +``` +./tool/go run ./cmd/tsconnect build-pkg +``` + +That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). + +To do two-sided development (on both the NPM package and code that uses it), run: + +``` +./tool/go run ./cmd/tsconnect dev-pkg + +``` + +This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. diff --git a/cmd/tsconnect/README.pkg.md b/cmd/tsconnect/README.pkg.md index df5799578d5e7..df8d66789894d 100644 --- a/cmd/tsconnect/README.pkg.md +++ b/cmd/tsconnect/README.pkg.md @@ -1,3 +1,3 @@ -# @tailscale/connect - -NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. +# @tailscale/connect + +NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. diff --git a/cmd/tsconnect/build-pkg.go b/cmd/tsconnect/build-pkg.go index 2b6cc9b1fcbc9..047504858ae0c 100644 --- a/cmd/tsconnect/build-pkg.go +++ b/cmd/tsconnect/build-pkg.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path" - - "github.com/tailscale/hujson" - "tailscale.com/util/precompress" - "tailscale.com/version" -) - -func runBuildPkg() { - buildOptions, err := commonPkgSetup(prodMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - - log.Printf("Linting...\n") - if err := runYarn("lint"); err != nil { - log.Fatalf("Linting failed: %v", err) - } - - if err := cleanDir(*pkgDir); err != nil { - log.Fatalf("Cannot clean %s: %v", *pkgDir, err) - } - - buildOptions.Write = true - buildOptions.MinifyWhitespace = true - buildOptions.MinifyIdentifiers = true - buildOptions.MinifySyntax = true - - runEsbuild(*buildOptions) - - if err := precompressWasm(); err != nil { - log.Fatalf("Could not pre-recompress wasm: %v", err) - } - - log.Printf("Generating types...\n") - if err := runYarn("pkg-types"); err != nil { - log.Fatalf("Type generation failed: %v", err) - } - - if err := updateVersion(); err != nil { - log.Fatalf("Cannot update version: %v", err) - } - - if err := copyReadme(); err != nil { - log.Fatalf("Cannot copy readme: %v", err) - } - - log.Printf("Built package version %s", version.Long()) -} - -func precompressWasm() error { - log.Printf("Pre-compressing main.wasm...\n") - return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ - FastCompression: *fastCompression, - }) -} - -func updateVersion() error { - packageJSONBytes, err := os.ReadFile("package.json.tmpl") - if err != nil { - return fmt.Errorf("Could not read package.json: %w", err) - } - - var packageJSON map[string]any - packageJSONBytes, err = hujson.Standardize(packageJSONBytes) - if err != nil { - return fmt.Errorf("Could not standardize template package.json: %w", err) - } - if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { - return fmt.Errorf("Could not unmarshal package.json: %w", err) - } - packageJSON["version"] = version.Long() - - packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") - if err != nil { - return fmt.Errorf("Could not marshal package.json: %w", err) - } - - return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) -} - -func copyReadme() error { - readmeBytes, err := os.ReadFile("README.pkg.md") - if err != nil { - return fmt.Errorf("Could not read README.pkg.md: %w", err) - } - return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path" + + "github.com/tailscale/hujson" + "tailscale.com/util/precompress" + "tailscale.com/version" +) + +func runBuildPkg() { + buildOptions, err := commonPkgSetup(prodMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + + log.Printf("Linting...\n") + if err := runYarn("lint"); err != nil { + log.Fatalf("Linting failed: %v", err) + } + + if err := cleanDir(*pkgDir); err != nil { + log.Fatalf("Cannot clean %s: %v", *pkgDir, err) + } + + buildOptions.Write = true + buildOptions.MinifyWhitespace = true + buildOptions.MinifyIdentifiers = true + buildOptions.MinifySyntax = true + + runEsbuild(*buildOptions) + + if err := precompressWasm(); err != nil { + log.Fatalf("Could not pre-recompress wasm: %v", err) + } + + log.Printf("Generating types...\n") + if err := runYarn("pkg-types"); err != nil { + log.Fatalf("Type generation failed: %v", err) + } + + if err := updateVersion(); err != nil { + log.Fatalf("Cannot update version: %v", err) + } + + if err := copyReadme(); err != nil { + log.Fatalf("Cannot copy readme: %v", err) + } + + log.Printf("Built package version %s", version.Long()) +} + +func precompressWasm() error { + log.Printf("Pre-compressing main.wasm...\n") + return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ + FastCompression: *fastCompression, + }) +} + +func updateVersion() error { + packageJSONBytes, err := os.ReadFile("package.json.tmpl") + if err != nil { + return fmt.Errorf("Could not read package.json: %w", err) + } + + var packageJSON map[string]any + packageJSONBytes, err = hujson.Standardize(packageJSONBytes) + if err != nil { + return fmt.Errorf("Could not standardize template package.json: %w", err) + } + if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { + return fmt.Errorf("Could not unmarshal package.json: %w", err) + } + packageJSON["version"] = version.Long() + + packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") + if err != nil { + return fmt.Errorf("Could not marshal package.json: %w", err) + } + + return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) +} + +func copyReadme() error { + readmeBytes, err := os.ReadFile("README.pkg.md") + if err != nil { + return fmt.Errorf("Could not read README.pkg.md: %w", err) + } + return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) +} diff --git a/cmd/tsconnect/dev-pkg.go b/cmd/tsconnect/dev-pkg.go index cb5ebf39ef657..de534c3b20625 100644 --- a/cmd/tsconnect/dev-pkg.go +++ b/cmd/tsconnect/dev-pkg.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDevPkg() { - buildOptions, err := commonPkgSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDevPkg() { + buildOptions, err := commonPkgSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dev.go b/cmd/tsconnect/dev.go index 161eb3b866a00..87b10adaf49c8 100644 --- a/cmd/tsconnect/dev.go +++ b/cmd/tsconnect/dev.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDev() { - buildOptions, err := commonSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDev() { + buildOptions, err := commonSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dist/placeholder b/cmd/tsconnect/dist/placeholder index dddaba4d76687..4af99d997207f 100644 --- a/cmd/tsconnect/dist/placeholder +++ b/cmd/tsconnect/dist/placeholder @@ -1,2 +1,2 @@ -This is here to make sure the dist/ directory exists for the go:embed command -in serve.go. +This is here to make sure the dist/ directory exists for the go:embed command +in serve.go. diff --git a/cmd/tsconnect/index.html b/cmd/tsconnect/index.html index 39aa7571add71..3db45fdef2bca 100644 --- a/cmd/tsconnect/index.html +++ b/cmd/tsconnect/index.html @@ -1,20 +1,20 @@ - - - - - - Tailscale Connect - - - - - -
-
-

Tailscale Connect

-
Loading…
-
-
- - + + + + + + Tailscale Connect + + + + + +
+
+

Tailscale Connect

+
Loading…
+
+
+ + diff --git a/cmd/tsconnect/package.json b/cmd/tsconnect/package.json index 8ea726cc670b8..bf4eb7c099aac 100644 --- a/cmd/tsconnect/package.json +++ b/cmd/tsconnect/package.json @@ -1,25 +1,25 @@ -{ - "name": "tsconnect", - "version": "0.0.1", - "license": "BSD-3-Clause", - "devDependencies": { - "@types/golang-wasm-exec": "^1.15.0", - "@types/qrcode": "^1.4.2", - "dts-bundle-generator": "^6.12.0", - "preact": "^10.10.0", - "qrcode": "^1.5.0", - "tailwindcss": "^3.1.6", - "typescript": "^4.7.4", - "xterm": "^5.1.0", - "xterm-addon-fit": "^0.7.0", - "xterm-addon-web-links": "^0.8.0" - }, - "scripts": { - "lint": "tsc --noEmit", - "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" - }, - "prettier": { - "semi": false, - "printWidth": 80 - } -} +{ + "name": "tsconnect", + "version": "0.0.1", + "license": "BSD-3-Clause", + "devDependencies": { + "@types/golang-wasm-exec": "^1.15.0", + "@types/qrcode": "^1.4.2", + "dts-bundle-generator": "^6.12.0", + "preact": "^10.10.0", + "qrcode": "^1.5.0", + "tailwindcss": "^3.1.6", + "typescript": "^4.7.4", + "xterm": "^5.1.0", + "xterm-addon-fit": "^0.7.0", + "xterm-addon-web-links": "^0.8.0" + }, + "scripts": { + "lint": "tsc --noEmit", + "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" + }, + "prettier": { + "semi": false, + "printWidth": 80 + } +} diff --git a/cmd/tsconnect/package.json.tmpl b/cmd/tsconnect/package.json.tmpl index 0263bf48118dd..404b896eaf89e 100644 --- a/cmd/tsconnect/package.json.tmpl +++ b/cmd/tsconnect/package.json.tmpl @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Template for the package.json that is generated by the build-pkg command. -// The version number will be replaced by the current Tailscale client version -// number. -{ - "author": "Tailscale Inc.", - "description": "Tailscale Connect SDK", - "license": "BSD-3-Clause", - "name": "@tailscale/connect", - "type": "module", - "main": "./pkg.js", - "types": "./pkg.d.ts", - "version": "AUTO_GENERATED" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Template for the package.json that is generated by the build-pkg command. +// The version number will be replaced by the current Tailscale client version +// number. +{ + "author": "Tailscale Inc.", + "description": "Tailscale Connect SDK", + "license": "BSD-3-Clause", + "name": "@tailscale/connect", + "type": "module", + "main": "./pkg.js", + "types": "./pkg.d.ts", + "version": "AUTO_GENERATED" +} diff --git a/cmd/tsconnect/serve.go b/cmd/tsconnect/serve.go index 80844bea74b6e..d780bdd57c3e3 100644 --- a/cmd/tsconnect/serve.go +++ b/cmd/tsconnect/serve.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "bytes" - "embed" - "encoding/json" - "fmt" - "io" - "io/fs" - "log" - "net/http" - "os" - "path" - "time" - - "tailscale.com/tsweb" - "tailscale.com/util/precompress" -) - -//go:embed index.html -var embeddedFS embed.FS - -//go:embed dist/* -var embeddedDistFS embed.FS - -var serveStartTime = time.Now() - -func runServe() { - mux := http.NewServeMux() - - var distFS fs.FS - if *distDir == "./dist" { - var err error - distFS, err = fs.Sub(embeddedDistFS, "dist") - if err != nil { - log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) - } - } else { - distFS = os.DirFS(*distDir) - } - - indexBytes, err := generateServeIndex(distFS) - if err != nil { - log.Fatalf("Could not generate index.html: %v", err) - } - mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) - })) - mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handleServeDist(w, r, distFS) - }))) - tsweb.Debugger(mux) - - log.Printf("Listening on %s", *addr) - err = http.ListenAndServe(*addr, mux) - if err != nil { - log.Fatal(err) - } -} - -func generateServeIndex(distFS fs.FS) ([]byte, error) { - log.Printf("Generating index.html...\n") - rawIndexBytes, err := embeddedFS.ReadFile("index.html") - if err != nil { - return nil, fmt.Errorf("Could not read index.html: %w", err) - } - - esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") - if err != nil { - return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) - } - defer esbuildMetadataFile.Close() - esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) - if err != nil { - return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) - } - var esbuildMetadata EsbuildMetadata - if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { - return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) - } - entryPointsToHashedDistPaths := make(map[string]string) - mainWasmPath := "" - for outputPath, output := range esbuildMetadata.Outputs { - if output.EntryPoint != "" { - entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) - } - if path.Ext(outputPath) == ".wasm" { - for input := range output.Inputs { - if input == "src/main.wasm" { - mainWasmPath = path.Join("dist", outputPath) - break - } - } - } - } - - indexBytes := rawIndexBytes - for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { - hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] - if hashedDistPath != "" { - indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) - } - } - if mainWasmPath != "" { - mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) - indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) - } - - return indexBytes, nil -} - -var entryPointsToDefaultDistPaths = map[string]string{ - "src/app/index.css": "dist/index.css", - "src/app/index.ts": "dist/index.js", -} - -func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { - path := r.URL.Path - f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - defer f.Close() - - // fs.File does not claim to implement Seeker, but in practice it does. - fSeeker, ok := f.(io.ReadSeeker) - if !ok { - http.Error(w, "Not seekable", http.StatusInternalServerError) - return - } - - // Aggressively cache static assets, since we cache-bust our assets with - // hashed filenames. - w.Header().Set("Cache-Control", "public, max-age=31535996") - w.Header().Set("Vary", "Accept-Encoding") - - http.ServeContent(w, r, path, serveStartTime, fSeeker) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "bytes" + "embed" + "encoding/json" + "fmt" + "io" + "io/fs" + "log" + "net/http" + "os" + "path" + "time" + + "tailscale.com/tsweb" + "tailscale.com/util/precompress" +) + +//go:embed index.html +var embeddedFS embed.FS + +//go:embed dist/* +var embeddedDistFS embed.FS + +var serveStartTime = time.Now() + +func runServe() { + mux := http.NewServeMux() + + var distFS fs.FS + if *distDir == "./dist" { + var err error + distFS, err = fs.Sub(embeddedDistFS, "dist") + if err != nil { + log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) + } + } else { + distFS = os.DirFS(*distDir) + } + + indexBytes, err := generateServeIndex(distFS) + if err != nil { + log.Fatalf("Could not generate index.html: %v", err) + } + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) + })) + mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleServeDist(w, r, distFS) + }))) + tsweb.Debugger(mux) + + log.Printf("Listening on %s", *addr) + err = http.ListenAndServe(*addr, mux) + if err != nil { + log.Fatal(err) + } +} + +func generateServeIndex(distFS fs.FS) ([]byte, error) { + log.Printf("Generating index.html...\n") + rawIndexBytes, err := embeddedFS.ReadFile("index.html") + if err != nil { + return nil, fmt.Errorf("Could not read index.html: %w", err) + } + + esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") + if err != nil { + return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) + } + defer esbuildMetadataFile.Close() + esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) + if err != nil { + return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) + } + var esbuildMetadata EsbuildMetadata + if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { + return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) + } + entryPointsToHashedDistPaths := make(map[string]string) + mainWasmPath := "" + for outputPath, output := range esbuildMetadata.Outputs { + if output.EntryPoint != "" { + entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) + } + if path.Ext(outputPath) == ".wasm" { + for input := range output.Inputs { + if input == "src/main.wasm" { + mainWasmPath = path.Join("dist", outputPath) + break + } + } + } + } + + indexBytes := rawIndexBytes + for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { + hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] + if hashedDistPath != "" { + indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) + } + } + if mainWasmPath != "" { + mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) + indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) + } + + return indexBytes, nil +} + +var entryPointsToDefaultDistPaths = map[string]string{ + "src/app/index.css": "dist/index.css", + "src/app/index.ts": "dist/index.js", +} + +func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { + path := r.URL.Path + f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + defer f.Close() + + // fs.File does not claim to implement Seeker, but in practice it does. + fSeeker, ok := f.(io.ReadSeeker) + if !ok { + http.Error(w, "Not seekable", http.StatusInternalServerError) + return + } + + // Aggressively cache static assets, since we cache-bust our assets with + // hashed filenames. + w.Header().Set("Cache-Control", "public, max-age=31535996") + w.Header().Set("Vary", "Accept-Encoding") + + http.ServeContent(w, r, path, serveStartTime, fSeeker) +} diff --git a/cmd/tsconnect/src/app/app.tsx b/cmd/tsconnect/src/app/app.tsx index c0aa7a5e88f63..ee538eaeac506 100644 --- a/cmd/tsconnect/src/app/app.tsx +++ b/cmd/tsconnect/src/app/app.tsx @@ -1,147 +1,147 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { render, Component } from "preact" -import { URLDisplay } from "./url-display" -import { Header } from "./header" -import { GoPanicDisplay } from "./go-panic-display" -import { SSH } from "./ssh" - -type AppState = { - ipn?: IPN - ipnState: IPNState - netMap?: IPNNetMap - browseToURL?: string - goPanicError?: string -} - -class App extends Component<{}, AppState> { - state: AppState = { ipnState: "NoState" } - #goPanicTimeout?: number - - render() { - const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state - - let goPanicDisplay - if (goPanicError) { - goPanicDisplay = ( - - ) - } - - let urlDisplay - if (browseToURL) { - urlDisplay = - } - - let machineAuthInstructions - if (ipnState === "NeedsMachineAuth") { - machineAuthInstructions = ( -
- An administrator needs to approve this device. -
- ) - } - - const lockedOut = netMap?.lockedOut - let lockedOutInstructions - if (lockedOut) { - lockedOutInstructions = ( -
-

This instance of Tailscale Connect needs to be signed, due to - {" "}tailnet lock{" "} - being enabled on this domain. -

- -

- Run the following command on a device with a trusted tailnet lock key: -

tailscale lock sign {netMap.self.nodeKey}
-

-
- ) - } - - let ssh - if (ipn && ipnState === "Running" && netMap && !lockedOut) { - ssh = - } - - return ( - <> -
- {goPanicDisplay} -
- {urlDisplay} - {machineAuthInstructions} - {lockedOutInstructions} - {ssh} -
- - ) - } - - runWithIPN(ipn: IPN) { - this.setState({ ipn }, () => { - ipn.run({ - notifyState: this.handleIPNState, - notifyNetMap: this.handleNetMap, - notifyBrowseToURL: this.handleBrowseToURL, - notifyPanicRecover: this.handleGoPanic, - }) - }) - } - - handleIPNState = (state: IPNState) => { - const { ipn } = this.state - this.setState({ ipnState: state }) - if (state === "NeedsLogin") { - ipn?.login() - } else if (["Running", "NeedsMachineAuth"].includes(state)) { - this.setState({ browseToURL: undefined }) - } - } - - handleNetMap = (netMapStr: string) => { - const netMap = JSON.parse(netMapStr) as IPNNetMap - if (DEBUG) { - console.log("Received net map: " + JSON.stringify(netMap, null, 2)) - } - this.setState({ netMap }) - } - - handleBrowseToURL = (url: string) => { - if (this.state.ipnState === "Running") { - // Ignore URL requests if we're already running -- it's most likely an - // SSH check mode trigger and we already linkify the displayed URL - // in the terminal. - return - } - this.setState({ browseToURL: url }) - } - - handleGoPanic = (error: string) => { - if (DEBUG) { - console.error("Go panic", error) - } - this.setState({ goPanicError: error }) - if (this.#goPanicTimeout) { - window.clearTimeout(this.#goPanicTimeout) - } - this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) - } - - clearGoPanic = () => { - window.clearTimeout(this.#goPanicTimeout) - this.#goPanicTimeout = undefined - this.setState({ goPanicError: undefined }) - } -} - -export function renderApp(): Promise { - return new Promise((resolve) => { - render( - (app ? resolve(app) : undefined)} />, - document.body - ) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { render, Component } from "preact" +import { URLDisplay } from "./url-display" +import { Header } from "./header" +import { GoPanicDisplay } from "./go-panic-display" +import { SSH } from "./ssh" + +type AppState = { + ipn?: IPN + ipnState: IPNState + netMap?: IPNNetMap + browseToURL?: string + goPanicError?: string +} + +class App extends Component<{}, AppState> { + state: AppState = { ipnState: "NoState" } + #goPanicTimeout?: number + + render() { + const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state + + let goPanicDisplay + if (goPanicError) { + goPanicDisplay = ( + + ) + } + + let urlDisplay + if (browseToURL) { + urlDisplay = + } + + let machineAuthInstructions + if (ipnState === "NeedsMachineAuth") { + machineAuthInstructions = ( +
+ An administrator needs to approve this device. +
+ ) + } + + const lockedOut = netMap?.lockedOut + let lockedOutInstructions + if (lockedOut) { + lockedOutInstructions = ( +
+

This instance of Tailscale Connect needs to be signed, due to + {" "}tailnet lock{" "} + being enabled on this domain. +

+ +

+ Run the following command on a device with a trusted tailnet lock key: +

tailscale lock sign {netMap.self.nodeKey}
+

+
+ ) + } + + let ssh + if (ipn && ipnState === "Running" && netMap && !lockedOut) { + ssh = + } + + return ( + <> +
+ {goPanicDisplay} +
+ {urlDisplay} + {machineAuthInstructions} + {lockedOutInstructions} + {ssh} +
+ + ) + } + + runWithIPN(ipn: IPN) { + this.setState({ ipn }, () => { + ipn.run({ + notifyState: this.handleIPNState, + notifyNetMap: this.handleNetMap, + notifyBrowseToURL: this.handleBrowseToURL, + notifyPanicRecover: this.handleGoPanic, + }) + }) + } + + handleIPNState = (state: IPNState) => { + const { ipn } = this.state + this.setState({ ipnState: state }) + if (state === "NeedsLogin") { + ipn?.login() + } else if (["Running", "NeedsMachineAuth"].includes(state)) { + this.setState({ browseToURL: undefined }) + } + } + + handleNetMap = (netMapStr: string) => { + const netMap = JSON.parse(netMapStr) as IPNNetMap + if (DEBUG) { + console.log("Received net map: " + JSON.stringify(netMap, null, 2)) + } + this.setState({ netMap }) + } + + handleBrowseToURL = (url: string) => { + if (this.state.ipnState === "Running") { + // Ignore URL requests if we're already running -- it's most likely an + // SSH check mode trigger and we already linkify the displayed URL + // in the terminal. + return + } + this.setState({ browseToURL: url }) + } + + handleGoPanic = (error: string) => { + if (DEBUG) { + console.error("Go panic", error) + } + this.setState({ goPanicError: error }) + if (this.#goPanicTimeout) { + window.clearTimeout(this.#goPanicTimeout) + } + this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) + } + + clearGoPanic = () => { + window.clearTimeout(this.#goPanicTimeout) + this.#goPanicTimeout = undefined + this.setState({ goPanicError: undefined }) + } +} + +export function renderApp(): Promise { + return new Promise((resolve) => { + render( + (app ? resolve(app) : undefined)} />, + document.body + ) + }) +} diff --git a/cmd/tsconnect/src/app/go-panic-display.tsx b/cmd/tsconnect/src/app/go-panic-display.tsx index aab35c4d55e9c..5dd7095a27c7d 100644 --- a/cmd/tsconnect/src/app/go-panic-display.tsx +++ b/cmd/tsconnect/src/app/go-panic-display.tsx @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function GoPanicDisplay({ - error, - dismiss, -}: { - error: string - dismiss: () => void -}) { - return ( -
- Tailscale has encountered an error. -
Click to reload
-
- ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function GoPanicDisplay({ + error, + dismiss, +}: { + error: string + dismiss: () => void +}) { + return ( +
+ Tailscale has encountered an error. +
Click to reload
+
+ ) +} diff --git a/cmd/tsconnect/src/app/header.tsx b/cmd/tsconnect/src/app/header.tsx index 8449f4563689d..099ff2f8c2f7d 100644 --- a/cmd/tsconnect/src/app/header.tsx +++ b/cmd/tsconnect/src/app/header.tsx @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { - const stateText = STATE_LABELS[state] - - let logoutButton - if (state === "Running") { - logoutButton = ( - - ) - } - return ( -
-
-

Tailscale Connect

-
{stateText}
- {logoutButton} -
-
- ) -} - -const STATE_LABELS = { - NoState: "Initializing…", - InUseOtherUser: "In-use by another user", - NeedsLogin: "Needs login", - NeedsMachineAuth: "Needs approval", - Stopped: "Stopped", - Starting: "Starting…", - Running: "Running", -} as const +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { + const stateText = STATE_LABELS[state] + + let logoutButton + if (state === "Running") { + logoutButton = ( + + ) + } + return ( +
+
+

Tailscale Connect

+
{stateText}
+ {logoutButton} +
+
+ ) +} + +const STATE_LABELS = { + NoState: "Initializing…", + InUseOtherUser: "In-use by another user", + NeedsLogin: "Needs login", + NeedsMachineAuth: "Needs approval", + Stopped: "Stopped", + Starting: "Starting…", + Running: "Running", +} as const diff --git a/cmd/tsconnect/src/app/index.css b/cmd/tsconnect/src/app/index.css index 848b83d12b5c9..751b313d9f362 100644 --- a/cmd/tsconnect/src/app/index.css +++ b/cmd/tsconnect/src/app/index.css @@ -1,74 +1,74 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; - -.link { - @apply text-blue-600; -} - -.link:hover { - @apply underline; -} - -.button { - @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; - transition-property: background-color, border-color, color, box-shadow; - transition-duration: 120ms; - box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); - min-width: 80px; -} -.button:focus { - @apply outline-none ring; -} -.button:disabled { - @apply pointer-events-none select-none; -} - -.input { - @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; - height: 2.375rem; -} - -.input::placeholder { - @apply text-gray-400; -} - -.input:disabled { - @apply border-gray-200; - @apply bg-gray-50; - @apply cursor-not-allowed; -} - -.input:focus { - @apply outline-none ring border-transparent; -} - -.select { - @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; -} - -.select-with-arrow { - @apply relative; -} - -.select-with-arrow .select { - width: 100%; -} - -.select-with-arrow::after { - @apply absolute; - content: ""; - top: 50%; - right: 0.5rem; - transform: translate(-0.3em, -0.15em); - width: 0.6em; - height: 0.4em; - opacity: 0.6; - background-color: currentColor; - clip-path: polygon(100% 0%, 0 0%, 50% 100%); -} +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; + +.link { + @apply text-blue-600; +} + +.link:hover { + @apply underline; +} + +.button { + @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; + transition-property: background-color, border-color, color, box-shadow; + transition-duration: 120ms; + box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); + min-width: 80px; +} +.button:focus { + @apply outline-none ring; +} +.button:disabled { + @apply pointer-events-none select-none; +} + +.input { + @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; + height: 2.375rem; +} + +.input::placeholder { + @apply text-gray-400; +} + +.input:disabled { + @apply border-gray-200; + @apply bg-gray-50; + @apply cursor-not-allowed; +} + +.input:focus { + @apply outline-none ring border-transparent; +} + +.select { + @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; +} + +.select-with-arrow { + @apply relative; +} + +.select-with-arrow .select { + width: 100%; +} + +.select-with-arrow::after { + @apply absolute; + content: ""; + top: 50%; + right: 0.5rem; + transform: translate(-0.3em, -0.15em); + width: 0.6em; + height: 0.4em; + opacity: 0.6; + background-color: currentColor; + clip-path: polygon(100% 0%, 0 0%, 50% 100%); +} diff --git a/cmd/tsconnect/src/app/index.ts b/cmd/tsconnect/src/app/index.ts index 1432188aec1a1..24ca4543921ae 100644 --- a/cmd/tsconnect/src/app/index.ts +++ b/cmd/tsconnect/src/app/index.ts @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import "../wasm_exec" -import wasmUrl from "./main.wasm" -import { sessionStateStorage } from "../lib/js-state-store" -import { renderApp } from "./app" - -async function main() { - const app = await renderApp() - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(`./dist/${wasmUrl}`), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - app.handleGoPanic("Unexpected shutdown") - ) - - const params = new URLSearchParams(window.location.search) - const authKey = params.get("authkey") ?? undefined - - const ipn = newIPN({ - // Persist IPN state in sessionStorage in development, so that we don't need - // to re-authorize every time we reload the page. - stateStorage: DEBUG ? sessionStateStorage : undefined, - // authKey allows for an auth key to be - // specified as a url param which automatically - // authorizes the client for use. - authKey: DEBUG ? authKey : undefined, - }) - app.runWithIPN(ipn) -} - -main() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import "../wasm_exec" +import wasmUrl from "./main.wasm" +import { sessionStateStorage } from "../lib/js-state-store" +import { renderApp } from "./app" + +async function main() { + const app = await renderApp() + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(`./dist/${wasmUrl}`), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + app.handleGoPanic("Unexpected shutdown") + ) + + const params = new URLSearchParams(window.location.search) + const authKey = params.get("authkey") ?? undefined + + const ipn = newIPN({ + // Persist IPN state in sessionStorage in development, so that we don't need + // to re-authorize every time we reload the page. + stateStorage: DEBUG ? sessionStateStorage : undefined, + // authKey allows for an auth key to be + // specified as a url param which automatically + // authorizes the client for use. + authKey: DEBUG ? authKey : undefined, + }) + app.runWithIPN(ipn) +} + +main() diff --git a/cmd/tsconnect/src/app/ssh.tsx b/cmd/tsconnect/src/app/ssh.tsx index 1534fd5db643f..df81745bd3fd7 100644 --- a/cmd/tsconnect/src/app/ssh.tsx +++ b/cmd/tsconnect/src/app/ssh.tsx @@ -1,157 +1,157 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" -import { createPortal } from "preact/compat" -import type { VNode } from "preact" -import { runSSHSession, SSHSessionDef } from "../lib/ssh" - -export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { - const [sshSessionDef, setSSHSessionDef] = useState( - null - ) - const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) - if (sshSessionDef) { - const sshSession = ( - - ) - if (sshSessionDef.newWindow) { - return {sshSession} - } - return sshSession - } - const sshPeers = netMap.peers.filter( - (p) => p.tailscaleSSHEnabled && p.online !== false - ) - - if (sshPeers.length == 0) { - return - } - - return -} - -type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } - -function SSHSession({ - def, - ipn, - onDone, -}: { - def: SSHSessionDef - ipn: IPN - onDone: () => void -}) { - const ref = useRef(null) - useEffect(() => { - if (ref.current) { - runSSHSession(ref.current, def, ipn, { - onConnectionProgress: (p) => console.log("Connection progress", p), - onConnected() {}, - onError: (err) => console.error(err), - onDone, - }) - } - }, [ref]) - - return
-} - -function NoSSHPeers() { - return ( -
- None of your machines have{" "} - - Tailscale SSH - - {" "}enabled. Give it a try! -
- ) -} - -function SSHForm({ - sshPeers, - onSubmit, -}: { - sshPeers: IPNNetMapPeerNode[] - onSubmit: (def: SSHFormSessionDef) => void -}) { - sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) - const [username, setUsername] = useState("") - const [hostname, setHostname] = useState(sshPeers[0].name) - return ( -
{ - e.preventDefault() - onSubmit({ username, hostname }) - }} - > - setUsername(e.currentTarget.value)} - /> -
- -
- { - if (e.altKey) { - e.preventDefault() - e.stopPropagation() - onSubmit({ username, hostname, newWindow: true }) - } - }} - /> -
- ) -} - -const NewWindow = ({ - children, - close, -}: { - children: VNode - close: () => void -}) => { - const newWindow = useMemo(() => { - const newWindow = window.open(undefined, undefined, "width=600,height=400") - if (newWindow) { - const containerNode = newWindow.document.createElement("div") - containerNode.className = "h-screen flex flex-col overflow-hidden" - newWindow.document.body.appendChild(containerNode) - - for (const linkNode of document.querySelectorAll( - "head link[rel=stylesheet]" - )) { - const newLink = document.createElement("link") - newLink.rel = "stylesheet" - newLink.href = (linkNode as HTMLLinkElement).href - newWindow.document.head.appendChild(newLink) - } - } - return newWindow - }, []) - if (!newWindow) { - console.error("Could not open window") - return null - } - newWindow.onbeforeunload = () => { - close() - } - - useEffect(() => () => newWindow.close(), []) - return createPortal(children, newWindow.document.body.lastChild as Element) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" +import { createPortal } from "preact/compat" +import type { VNode } from "preact" +import { runSSHSession, SSHSessionDef } from "../lib/ssh" + +export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { + const [sshSessionDef, setSSHSessionDef] = useState( + null + ) + const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) + if (sshSessionDef) { + const sshSession = ( + + ) + if (sshSessionDef.newWindow) { + return {sshSession} + } + return sshSession + } + const sshPeers = netMap.peers.filter( + (p) => p.tailscaleSSHEnabled && p.online !== false + ) + + if (sshPeers.length == 0) { + return + } + + return +} + +type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } + +function SSHSession({ + def, + ipn, + onDone, +}: { + def: SSHSessionDef + ipn: IPN + onDone: () => void +}) { + const ref = useRef(null) + useEffect(() => { + if (ref.current) { + runSSHSession(ref.current, def, ipn, { + onConnectionProgress: (p) => console.log("Connection progress", p), + onConnected() {}, + onError: (err) => console.error(err), + onDone, + }) + } + }, [ref]) + + return
+} + +function NoSSHPeers() { + return ( +
+ None of your machines have{" "} + + Tailscale SSH + + {" "}enabled. Give it a try! +
+ ) +} + +function SSHForm({ + sshPeers, + onSubmit, +}: { + sshPeers: IPNNetMapPeerNode[] + onSubmit: (def: SSHFormSessionDef) => void +}) { + sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) + const [username, setUsername] = useState("") + const [hostname, setHostname] = useState(sshPeers[0].name) + return ( +
{ + e.preventDefault() + onSubmit({ username, hostname }) + }} + > + setUsername(e.currentTarget.value)} + /> +
+ +
+ { + if (e.altKey) { + e.preventDefault() + e.stopPropagation() + onSubmit({ username, hostname, newWindow: true }) + } + }} + /> +
+ ) +} + +const NewWindow = ({ + children, + close, +}: { + children: VNode + close: () => void +}) => { + const newWindow = useMemo(() => { + const newWindow = window.open(undefined, undefined, "width=600,height=400") + if (newWindow) { + const containerNode = newWindow.document.createElement("div") + containerNode.className = "h-screen flex flex-col overflow-hidden" + newWindow.document.body.appendChild(containerNode) + + for (const linkNode of document.querySelectorAll( + "head link[rel=stylesheet]" + )) { + const newLink = document.createElement("link") + newLink.rel = "stylesheet" + newLink.href = (linkNode as HTMLLinkElement).href + newWindow.document.head.appendChild(newLink) + } + } + return newWindow + }, []) + if (!newWindow) { + console.error("Could not open window") + return null + } + newWindow.onbeforeunload = () => { + close() + } + + useEffect(() => () => newWindow.close(), []) + return createPortal(children, newWindow.document.body.lastChild as Element) +} diff --git a/cmd/tsconnect/src/app/url-display.tsx b/cmd/tsconnect/src/app/url-display.tsx index c9b59018108bc..fc82c7fb91b3c 100644 --- a/cmd/tsconnect/src/app/url-display.tsx +++ b/cmd/tsconnect/src/app/url-display.tsx @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState } from "preact/hooks" -import * as qrcode from "qrcode" - -export function URLDisplay({ url }: { url: string }) { - const [dataURL, setDataURL] = useState("") - qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { - if (err) { - console.error("Error generating QR code", err) - } else { - setDataURL(dataURL) - } - }) - - return ( - - ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState } from "preact/hooks" +import * as qrcode from "qrcode" + +export function URLDisplay({ url }: { url: string }) { + const [dataURL, setDataURL] = useState("") + qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { + if (err) { + console.error("Error generating QR code", err) + } else { + setDataURL(dataURL) + } + }) + + return ( + + ) +} diff --git a/cmd/tsconnect/src/lib/js-state-store.ts b/cmd/tsconnect/src/lib/js-state-store.ts index 7685e28a9de7c..e57dfd98efabd 100644 --- a/cmd/tsconnect/src/lib/js-state-store.ts +++ b/cmd/tsconnect/src/lib/js-state-store.ts @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ - -export const sessionStateStorage: IPNStateStorage = { - setState(id, value) { - window.sessionStorage[`ipn-state-${id}`] = value - }, - getState(id) { - return window.sessionStorage[`ipn-state-${id}`] || "" - }, -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ + +export const sessionStateStorage: IPNStateStorage = { + setState(id, value) { + window.sessionStorage[`ipn-state-${id}`] = value + }, + getState(id) { + return window.sessionStorage[`ipn-state-${id}`] || "" + }, +} diff --git a/cmd/tsconnect/src/pkg/pkg.css b/cmd/tsconnect/src/pkg/pkg.css index 60146d5b7cca9..76ea21f5b53b2 100644 --- a/cmd/tsconnect/src/pkg/pkg.css +++ b/cmd/tsconnect/src/pkg/pkg.css @@ -1,8 +1,8 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/cmd/tsconnect/src/pkg/pkg.ts b/cmd/tsconnect/src/pkg/pkg.ts index c0dcb5652ec62..4d535cb404015 100644 --- a/cmd/tsconnect/src/pkg/pkg.ts +++ b/cmd/tsconnect/src/pkg/pkg.ts @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Type definitions need to be manually imported for dts-bundle-generator to -// discover them. -/// -/// - -import "../wasm_exec" -import wasmURL from "./main.wasm" - -/** - * Superset of the IPNConfig type, with additional configuration that is - * needed for the package to function. - */ -type IPNPackageConfig = IPNConfig & { - // Auth key used to initialize the Tailscale client (required) - authKey: string - // URL of the main.wasm file that is included in the page, if it is not - // accessible via a relative URL. - wasmURL?: string - // Function invoked if the Go process panics or unexpectedly exits. - panicHandler: (err: string) => void -} - -export async function createIPN(config: IPNPackageConfig): Promise { - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(config.wasmURL ?? wasmURL), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - config.panicHandler("Unexpected shutdown") - ) - - return newIPN(config) -} - -export { runSSHSession } from "../lib/ssh" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Type definitions need to be manually imported for dts-bundle-generator to +// discover them. +/// +/// + +import "../wasm_exec" +import wasmURL from "./main.wasm" + +/** + * Superset of the IPNConfig type, with additional configuration that is + * needed for the package to function. + */ +type IPNPackageConfig = IPNConfig & { + // Auth key used to initialize the Tailscale client (required) + authKey: string + // URL of the main.wasm file that is included in the page, if it is not + // accessible via a relative URL. + wasmURL?: string + // Function invoked if the Go process panics or unexpectedly exits. + panicHandler: (err: string) => void +} + +export async function createIPN(config: IPNPackageConfig): Promise { + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(config.wasmURL ?? wasmURL), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + config.panicHandler("Unexpected shutdown") + ) + + return newIPN(config) +} + +export { runSSHSession } from "../lib/ssh" diff --git a/cmd/tsconnect/src/types/esbuild.d.ts b/cmd/tsconnect/src/types/esbuild.d.ts index 7153b4244e7c5..ef28f7b1cf556 100644 --- a/cmd/tsconnect/src/types/esbuild.d.ts +++ b/cmd/tsconnect/src/types/esbuild.d.ts @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types generated by the esbuild build - * process. - */ - -declare module "*.wasm" { - const path: string - export default path -} - -declare const DEBUG: boolean +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types generated by the esbuild build + * process. + */ + +declare module "*.wasm" { + const path: string + export default path +} + +declare const DEBUG: boolean diff --git a/cmd/tsconnect/src/types/wasm_js.d.ts b/cmd/tsconnect/src/types/wasm_js.d.ts index 82822c508040e..492197ccb1a9b 100644 --- a/cmd/tsconnect/src/types/wasm_js.d.ts +++ b/cmd/tsconnect/src/types/wasm_js.d.ts @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types exported by the wasm_js.go Go - * module. - */ - -declare global { - function newIPN(config: IPNConfig): IPN - - interface IPN { - run(callbacks: IPNCallbacks): void - login(): void - logout(): void - ssh( - host: string, - username: string, - termConfig: { - writeFn: (data: string) => void - writeErrorFn: (err: string) => void - setReadFn: (readFn: (data: string) => void) => void - rows: number - cols: number - /** Defaults to 5 seconds */ - timeoutSeconds?: number - onConnectionProgress: (message: string) => void - onConnected: () => void - onDone: () => void - } - ): IPNSSHSession - fetch(url: string): Promise<{ - status: number - statusText: string - text: () => Promise - }> - } - - interface IPNSSHSession { - resize(rows: number, cols: number): boolean - close(): boolean - } - - interface IPNStateStorage { - setState(id: string, value: string): void - getState(id: string): string - } - - type IPNConfig = { - stateStorage?: IPNStateStorage - authKey?: string - controlURL?: string - hostname?: string - } - - type IPNCallbacks = { - notifyState: (state: IPNState) => void - notifyNetMap: (netMapStr: string) => void - notifyBrowseToURL: (url: string) => void - notifyPanicRecover: (err: string) => void - } - - type IPNNetMap = { - self: IPNNetMapSelfNode - peers: IPNNetMapPeerNode[] - lockedOut: boolean - } - - type IPNNetMapNode = { - name: string - addresses: string[] - machineKey: string - nodeKey: string - } - - type IPNNetMapSelfNode = IPNNetMapNode & { - machineStatus: IPNMachineStatus - } - - type IPNNetMapPeerNode = IPNNetMapNode & { - online?: boolean - tailscaleSSHEnabled: boolean - } - - /** Mirrors values from ipn/backend.go */ - type IPNState = - | "NoState" - | "InUseOtherUser" - | "NeedsLogin" - | "NeedsMachineAuth" - | "Stopped" - | "Starting" - | "Running" - - /** Mirrors values from MachineStatus in tailcfg.go */ - type IPNMachineStatus = - | "MachineUnknown" - | "MachineUnauthorized" - | "MachineAuthorized" - | "MachineInvalid" -} - -export {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types exported by the wasm_js.go Go + * module. + */ + +declare global { + function newIPN(config: IPNConfig): IPN + + interface IPN { + run(callbacks: IPNCallbacks): void + login(): void + logout(): void + ssh( + host: string, + username: string, + termConfig: { + writeFn: (data: string) => void + writeErrorFn: (err: string) => void + setReadFn: (readFn: (data: string) => void) => void + rows: number + cols: number + /** Defaults to 5 seconds */ + timeoutSeconds?: number + onConnectionProgress: (message: string) => void + onConnected: () => void + onDone: () => void + } + ): IPNSSHSession + fetch(url: string): Promise<{ + status: number + statusText: string + text: () => Promise + }> + } + + interface IPNSSHSession { + resize(rows: number, cols: number): boolean + close(): boolean + } + + interface IPNStateStorage { + setState(id: string, value: string): void + getState(id: string): string + } + + type IPNConfig = { + stateStorage?: IPNStateStorage + authKey?: string + controlURL?: string + hostname?: string + } + + type IPNCallbacks = { + notifyState: (state: IPNState) => void + notifyNetMap: (netMapStr: string) => void + notifyBrowseToURL: (url: string) => void + notifyPanicRecover: (err: string) => void + } + + type IPNNetMap = { + self: IPNNetMapSelfNode + peers: IPNNetMapPeerNode[] + lockedOut: boolean + } + + type IPNNetMapNode = { + name: string + addresses: string[] + machineKey: string + nodeKey: string + } + + type IPNNetMapSelfNode = IPNNetMapNode & { + machineStatus: IPNMachineStatus + } + + type IPNNetMapPeerNode = IPNNetMapNode & { + online?: boolean + tailscaleSSHEnabled: boolean + } + + /** Mirrors values from ipn/backend.go */ + type IPNState = + | "NoState" + | "InUseOtherUser" + | "NeedsLogin" + | "NeedsMachineAuth" + | "Stopped" + | "Starting" + | "Running" + + /** Mirrors values from MachineStatus in tailcfg.go */ + type IPNMachineStatus = + | "MachineUnknown" + | "MachineUnauthorized" + | "MachineAuthorized" + | "MachineInvalid" +} + +export {} diff --git a/cmd/tsconnect/tailwind.config.js b/cmd/tsconnect/tailwind.config.js index 38bc5b97b714e..31823000b6139 100644 --- a/cmd/tsconnect/tailwind.config.js +++ b/cmd/tsconnect/tailwind.config.js @@ -1,8 +1,8 @@ -/** @type {import('tailwindcss').Config} */ -module.exports = { - content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], - theme: { - extend: {}, - }, - plugins: [], -} +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], + theme: { + extend: {}, + }, + plugins: [], +} diff --git a/cmd/tsconnect/tsconfig.json b/cmd/tsconnect/tsconfig.json index 1148e2ef0c43a..52c25c7271f7c 100644 --- a/cmd/tsconnect/tsconfig.json +++ b/cmd/tsconnect/tsconfig.json @@ -1,15 +1,15 @@ -{ - "compilerOptions": { - "target": "ES2017", - "module": "ES2020", - "moduleResolution": "node", - "isolatedModules": true, - "strict": true, - "forceConsistentCasingInFileNames": true, - "sourceMap": true, - "jsx": "react-jsx", - "jsxImportSource": "preact" - }, - "include": ["src/**/*"], - "exclude": ["node_modules"] -} +{ + "compilerOptions": { + "target": "ES2017", + "module": "ES2020", + "moduleResolution": "node", + "isolatedModules": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "sourceMap": true, + "jsx": "react-jsx", + "jsxImportSource": "preact" + }, + "include": ["src/**/*"], + "exclude": ["node_modules"] +} diff --git a/cmd/tsconnect/tsconnect.go b/cmd/tsconnect/tsconnect.go index 60ea6ef822d99..4c8a0a52ece34 100644 --- a/cmd/tsconnect/tsconnect.go +++ b/cmd/tsconnect/tsconnect.go @@ -1,71 +1,71 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The tsconnect command builds and serves the static site that is generated for -// the Tailscale Connect JS/WASM client. Can be run in 3 modes: -// - dev: builds the site and serves it. JS and CSS changes can be picked up -// with a reload. -// - build: builds the site and writes it to dist/ -// - serve: serves the site from dist/ (embedded in the binary) -package main // import "tailscale.com/cmd/tsconnect" - -import ( - "flag" - "fmt" - "log" - "os" -) - -var ( - addr = flag.String("addr", ":9090", "address to listen on") - distDir = flag.String("distdir", "./dist", "path of directory to place build output in") - pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") - yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") - fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") - devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") - rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") -) - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) != 1 { - flag.Usage() - } - - switch flag.Arg(0) { - case "dev": - runDev() - case "dev-pkg": - runDevPkg() - case "build": - runBuild() - case "build-pkg": - runBuildPkg() - case "serve": - runServe() - default: - log.Printf("Unknown command: %s", flag.Arg(0)) - flag.Usage() - } -} - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: tsconnect {dev|build|serve} -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` - -tsconnect implements development/build/serving workflows for Tailscale Connect. -It can be invoked with one of three subcommands: - -- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. -- build: Run in production build mode (generating static assets) -- serve: Run in production serve mode (serving static assets) -`[1:]) - os.Exit(2) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The tsconnect command builds and serves the static site that is generated for +// the Tailscale Connect JS/WASM client. Can be run in 3 modes: +// - dev: builds the site and serves it. JS and CSS changes can be picked up +// with a reload. +// - build: builds the site and writes it to dist/ +// - serve: serves the site from dist/ (embedded in the binary) +package main // import "tailscale.com/cmd/tsconnect" + +import ( + "flag" + "fmt" + "log" + "os" +) + +var ( + addr = flag.String("addr", ":9090", "address to listen on") + distDir = flag.String("distdir", "./dist", "path of directory to place build output in") + pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") + yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") + fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") + devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") + rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") +) + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) != 1 { + flag.Usage() + } + + switch flag.Arg(0) { + case "dev": + runDev() + case "dev-pkg": + runDevPkg() + case "build": + runBuild() + case "build-pkg": + runBuildPkg() + case "serve": + runServe() + default: + log.Printf("Unknown command: %s", flag.Arg(0)) + flag.Usage() + } +} + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: tsconnect {dev|build|serve} +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` + +tsconnect implements development/build/serving workflows for Tailscale Connect. +It can be invoked with one of three subcommands: + +- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. +- build: Run in production build mode (generating static assets) +- serve: Run in production serve mode (serving static assets) +`[1:]) + os.Exit(2) +} diff --git a/cmd/tsconnect/yarn.lock b/cmd/tsconnect/yarn.lock index 914b4e6d041f7..663a1244ebf69 100644 --- a/cmd/tsconnect/yarn.lock +++ b/cmd/tsconnect/yarn.lock @@ -1,713 +1,713 @@ -# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. -# yarn lockfile v1 - - -"@nodelib/fs.scandir@2.1.5": - version "2.1.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" - integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== - dependencies: - "@nodelib/fs.stat" "2.0.5" - run-parallel "^1.1.9" - -"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": - version "2.0.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" - integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== - -"@nodelib/fs.walk@^1.2.3": - version "1.2.8" - resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" - integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== - dependencies: - "@nodelib/fs.scandir" "2.1.5" - fastq "^1.6.0" - -"@types/golang-wasm-exec@^1.15.0": - version "1.15.0" - resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" - integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== - -"@types/node@*": - version "18.6.1" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" - integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== - -"@types/qrcode@^1.4.2": - version "1.4.2" - resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" - integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== - dependencies: - "@types/node" "*" - -acorn-node@^1.8.2: - version "1.8.2" - resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" - integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== - dependencies: - acorn "^7.0.0" - acorn-walk "^7.0.0" - xtend "^4.0.2" - -acorn-walk@^7.0.0: - version "7.2.0" - resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" - integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== - -acorn@^7.0.0: - version "7.4.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" - integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== - -ansi-regex@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" - integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== - -ansi-styles@^4.0.0: - version "4.3.0" - resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" - integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== - dependencies: - color-convert "^2.0.1" - -anymatch@~3.1.2: - version "3.1.2" - resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" - integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== - dependencies: - normalize-path "^3.0.0" - picomatch "^2.0.4" - -arg@^5.0.2: - version "5.0.2" - resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" - integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== - -binary-extensions@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" - integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== - -braces@^3.0.2, braces@~3.0.2: - version "3.0.2" - resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" - integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== - dependencies: - fill-range "^7.0.1" - -camelcase-css@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" - integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== - -camelcase@^5.0.0: - version "5.3.1" - resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" - integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== - -chokidar@^3.5.3: - version "3.5.3" - resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" - integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== - dependencies: - anymatch "~3.1.2" - braces "~3.0.2" - glob-parent "~5.1.2" - is-binary-path "~2.1.0" - is-glob "~4.0.1" - normalize-path "~3.0.0" - readdirp "~3.6.0" - optionalDependencies: - fsevents "~2.3.2" - -cliui@^6.0.0: - version "6.0.0" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" - integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^6.2.0" - -cliui@^7.0.2: - version "7.0.4" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" - integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^7.0.0" - -color-convert@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" - integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== - dependencies: - color-name "~1.1.4" - -color-name@^1.1.4, color-name@~1.1.4: - version "1.1.4" - resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" - integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== - -cssesc@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" - integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== - -decamelize@^1.2.0: - version "1.2.0" - resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" - integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= - -defined@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" - integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== - -detective@^5.2.1: - version "5.2.1" - resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" - integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== - dependencies: - acorn-node "^1.8.2" - defined "^1.0.0" - minimist "^1.2.6" - -didyoumean@^1.2.2: - version "1.2.2" - resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" - integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== - -dijkstrajs@^1.0.1: - version "1.0.2" - resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" - integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== - -dlv@^1.1.3: - version "1.1.3" - resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" - integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== - -dts-bundle-generator@^6.12.0: - version "6.12.0" - resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" - integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== - dependencies: - typescript ">=3.0.1" - yargs "^17.2.1" - -emoji-regex@^8.0.0: - version "8.0.0" - resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" - integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== - -encode-utf8@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" - integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== - -escalade@^3.1.1: - version "3.1.1" - resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" - integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== - -fast-glob@^3.2.11: - version "3.2.11" - resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" - integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== - dependencies: - "@nodelib/fs.stat" "^2.0.2" - "@nodelib/fs.walk" "^1.2.3" - glob-parent "^5.1.2" - merge2 "^1.3.0" - micromatch "^4.0.4" - -fastq@^1.6.0: - version "1.13.0" - resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" - integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== - dependencies: - reusify "^1.0.4" - -fill-range@^7.0.1: - version "7.0.1" - resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" - integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== - dependencies: - to-regex-range "^5.0.1" - -find-up@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" - integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== - dependencies: - locate-path "^5.0.0" - path-exists "^4.0.0" - -fsevents@~2.3.2: - version "2.3.2" - resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" - integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== - -function-bind@^1.1.1: - version "1.1.1" - resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" - integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== - -get-caller-file@^2.0.1, get-caller-file@^2.0.5: - version "2.0.5" - resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" - integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== - -glob-parent@^5.1.2, glob-parent@~5.1.2: - version "5.1.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" - integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== - dependencies: - is-glob "^4.0.1" - -glob-parent@^6.0.2: - version "6.0.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" - integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== - dependencies: - is-glob "^4.0.3" - -has@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" - integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== - dependencies: - function-bind "^1.1.1" - -is-binary-path@~2.1.0: - version "2.1.0" - resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" - integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== - dependencies: - binary-extensions "^2.0.0" - -is-core-module@^2.9.0: - version "2.9.0" - resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" - integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== - dependencies: - has "^1.0.3" - -is-extglob@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" - integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== - -is-fullwidth-code-point@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" - integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== - -is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: - version "4.0.3" - resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" - integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== - dependencies: - is-extglob "^2.1.1" - -is-number@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" - integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== - -lilconfig@^2.0.5: - version "2.0.6" - resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" - integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== - -locate-path@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" - integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== - dependencies: - p-locate "^4.1.0" - -merge2@^1.3.0: - version "1.4.1" - resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" - integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== - -micromatch@^4.0.4: - version "4.0.5" - resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" - integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== - dependencies: - braces "^3.0.2" - picomatch "^2.3.1" - -minimist@^1.2.6: - version "1.2.6" - resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" - integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== - -nanoid@^3.3.4: - version "3.3.4" - resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" - integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== - -normalize-path@^3.0.0, normalize-path@~3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" - integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== - -object-hash@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" - integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== - -p-limit@^2.2.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" - integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== - dependencies: - p-try "^2.0.0" - -p-locate@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" - integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== - dependencies: - p-limit "^2.2.0" - -p-try@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" - integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== - -path-exists@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" - integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== - -path-parse@^1.0.7: - version "1.0.7" - resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" - integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== - -picocolors@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" - integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== - -picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: - version "2.3.1" - resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" - integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== - -pify@^2.3.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" - integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== - -pngjs@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" - integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== - -postcss-import@^14.1.0: - version "14.1.0" - resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" - integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== - dependencies: - postcss-value-parser "^4.0.0" - read-cache "^1.0.0" - resolve "^1.1.7" - -postcss-js@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" - integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== - dependencies: - camelcase-css "^2.0.1" - -postcss-load-config@^3.1.4: - version "3.1.4" - resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" - integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== - dependencies: - lilconfig "^2.0.5" - yaml "^1.10.2" - -postcss-nested@5.0.6: - version "5.0.6" - resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" - integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== - dependencies: - postcss-selector-parser "^6.0.6" - -postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: - version "6.0.10" - resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" - integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== - dependencies: - cssesc "^3.0.0" - util-deprecate "^1.0.2" - -postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: - version "4.2.0" - resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" - integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== - -postcss@^8.4.14: - version "8.4.14" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" - integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== - dependencies: - nanoid "^3.3.4" - picocolors "^1.0.0" - source-map-js "^1.0.2" - -preact@^10.10.0: - version "10.10.0" - resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" - integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== - -qrcode@^1.5.0: - version "1.5.0" - resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" - integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== - dependencies: - dijkstrajs "^1.0.1" - encode-utf8 "^1.0.3" - pngjs "^5.0.0" - yargs "^15.3.1" - -queue-microtask@^1.2.2: - version "1.2.3" - resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" - integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== - -quick-lru@^5.1.1: - version "5.1.1" - resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" - integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== - -read-cache@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" - integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== - dependencies: - pify "^2.3.0" - -readdirp@~3.6.0: - version "3.6.0" - resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" - integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== - dependencies: - picomatch "^2.2.1" - -require-directory@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" - integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= - -require-main-filename@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" - integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== - -resolve@^1.1.7, resolve@^1.22.1: - version "1.22.1" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" - integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== - dependencies: - is-core-module "^2.9.0" - path-parse "^1.0.7" - supports-preserve-symlinks-flag "^1.0.0" - -reusify@^1.0.4: - version "1.0.4" - resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" - integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== - -run-parallel@^1.1.9: - version "1.2.0" - resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" - integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== - dependencies: - queue-microtask "^1.2.2" - -set-blocking@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" - integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= - -source-map-js@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" - integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== - -string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: - version "4.2.3" - resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" - integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== - dependencies: - emoji-regex "^8.0.0" - is-fullwidth-code-point "^3.0.0" - strip-ansi "^6.0.1" - -strip-ansi@^6.0.0, strip-ansi@^6.0.1: - version "6.0.1" - resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" - integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== - dependencies: - ansi-regex "^5.0.1" - -supports-preserve-symlinks-flag@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" - integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== - -tailwindcss@^3.1.6: - version "3.1.6" - resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" - integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== - dependencies: - arg "^5.0.2" - chokidar "^3.5.3" - color-name "^1.1.4" - detective "^5.2.1" - didyoumean "^1.2.2" - dlv "^1.1.3" - fast-glob "^3.2.11" - glob-parent "^6.0.2" - is-glob "^4.0.3" - lilconfig "^2.0.5" - normalize-path "^3.0.0" - object-hash "^3.0.0" - picocolors "^1.0.0" - postcss "^8.4.14" - postcss-import "^14.1.0" - postcss-js "^4.0.0" - postcss-load-config "^3.1.4" - postcss-nested "5.0.6" - postcss-selector-parser "^6.0.10" - postcss-value-parser "^4.2.0" - quick-lru "^5.1.1" - resolve "^1.22.1" - -to-regex-range@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" - integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== - dependencies: - is-number "^7.0.0" - -typescript@>=3.0.1, typescript@^4.7.4: - version "4.7.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" - integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== - -util-deprecate@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" - integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== - -which-module@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" - integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= - -wrap-ansi@^6.2.0: - version "6.2.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" - integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -wrap-ansi@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" - integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -xtend@^4.0.2: - version "4.0.2" - resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" - integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== - -xterm-addon-fit@^0.7.0: - version "0.7.0" - resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" - integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== - -xterm-addon-web-links@^0.8.0: - version "0.8.0" - resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" - integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== - -xterm@^5.1.0: - version "5.1.0" - resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" - integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== - -y18n@^4.0.0: - version "4.0.3" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" - integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== - -y18n@^5.0.5: - version "5.0.8" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" - integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== - -yaml@^1.10.2: - version "1.10.2" - resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" - integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== - -yargs-parser@^18.1.2: - version "18.1.3" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" - integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== - dependencies: - camelcase "^5.0.0" - decamelize "^1.2.0" - -yargs-parser@^21.0.0: - version "21.1.1" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" - integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== - -yargs@^15.3.1: - version "15.4.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" - integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== - dependencies: - cliui "^6.0.0" - decamelize "^1.2.0" - find-up "^4.1.0" - get-caller-file "^2.0.1" - require-directory "^2.1.1" - require-main-filename "^2.0.0" - set-blocking "^2.0.0" - string-width "^4.2.0" - which-module "^2.0.0" - y18n "^4.0.0" - yargs-parser "^18.1.2" - -yargs@^17.2.1: - version "17.5.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" - integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== - dependencies: - cliui "^7.0.2" - escalade "^3.1.1" - get-caller-file "^2.0.5" - require-directory "^2.1.1" - string-width "^4.2.3" - y18n "^5.0.5" - yargs-parser "^21.0.0" +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@nodelib/fs.scandir@2.1.5": + version "2.1.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" + integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== + dependencies: + "@nodelib/fs.stat" "2.0.5" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": + version "2.0.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" + integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.8" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" + integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== + dependencies: + "@nodelib/fs.scandir" "2.1.5" + fastq "^1.6.0" + +"@types/golang-wasm-exec@^1.15.0": + version "1.15.0" + resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" + integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== + +"@types/node@*": + version "18.6.1" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" + integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== + +"@types/qrcode@^1.4.2": + version "1.4.2" + resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" + integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== + dependencies: + "@types/node" "*" + +acorn-node@^1.8.2: + version "1.8.2" + resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" + integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== + dependencies: + acorn "^7.0.0" + acorn-walk "^7.0.0" + xtend "^4.0.2" + +acorn-walk@^7.0.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" + integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== + +acorn@^7.0.0: + version "7.4.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" + integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== + +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + +anymatch@~3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" + integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + +arg@^5.0.2: + version "5.0.2" + resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" + integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== + +binary-extensions@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" + integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== + +braces@^3.0.2, braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +camelcase-css@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" + integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== + +camelcase@^5.0.0: + version "5.3.1" + resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" + integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== + +chokidar@^3.5.3: + version "3.5.3" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" + integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== + dependencies: + anymatch "~3.1.2" + braces "~3.0.2" + glob-parent "~5.1.2" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.6.0" + optionalDependencies: + fsevents "~2.3.2" + +cliui@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" + integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^6.2.0" + +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@^1.1.4, color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +cssesc@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" + integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== + +decamelize@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" + integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= + +defined@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" + integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== + +detective@^5.2.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" + integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== + dependencies: + acorn-node "^1.8.2" + defined "^1.0.0" + minimist "^1.2.6" + +didyoumean@^1.2.2: + version "1.2.2" + resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" + integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== + +dijkstrajs@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" + integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== + +dlv@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" + integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== + +dts-bundle-generator@^6.12.0: + version "6.12.0" + resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" + integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== + dependencies: + typescript ">=3.0.1" + yargs "^17.2.1" + +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + +encode-utf8@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" + integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== + +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + +fast-glob@^3.2.11: + version "3.2.11" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" + integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.2" + merge2 "^1.3.0" + micromatch "^4.0.4" + +fastq@^1.6.0: + version "1.13.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" + integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== + dependencies: + reusify "^1.0.4" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + +find-up@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" + integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== + dependencies: + locate-path "^5.0.0" + path-exists "^4.0.0" + +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +get-caller-file@^2.0.1, get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob-parent@^5.1.2, glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +glob-parent@^6.0.2: + version "6.0.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" + integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== + dependencies: + is-glob "^4.0.3" + +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-core-module@^2.9.0: + version "2.9.0" + resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" + integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== + dependencies: + has "^1.0.3" + +is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== + +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + +is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: + version "4.0.3" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" + integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== + dependencies: + is-extglob "^2.1.1" + +is-number@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" + integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== + +lilconfig@^2.0.5: + version "2.0.6" + resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" + integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== + +locate-path@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" + integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== + dependencies: + p-locate "^4.1.0" + +merge2@^1.3.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + +micromatch@^4.0.4: + version "4.0.5" + resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" + integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== + dependencies: + braces "^3.0.2" + picomatch "^2.3.1" + +minimist@^1.2.6: + version "1.2.6" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" + integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== + +nanoid@^3.3.4: + version "3.3.4" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" + integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== + +normalize-path@^3.0.0, normalize-path@~3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" + integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== + +object-hash@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" + integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== + +p-limit@^2.2.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" + integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== + dependencies: + p-try "^2.0.0" + +p-locate@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" + integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== + dependencies: + p-limit "^2.2.0" + +p-try@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" + integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + +path-parse@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" + integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== + +picocolors@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" + integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== + +picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: + version "2.3.1" + resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" + integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== + +pify@^2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" + integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== + +pngjs@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" + integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== + +postcss-import@^14.1.0: + version "14.1.0" + resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" + integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== + dependencies: + postcss-value-parser "^4.0.0" + read-cache "^1.0.0" + resolve "^1.1.7" + +postcss-js@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" + integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== + dependencies: + camelcase-css "^2.0.1" + +postcss-load-config@^3.1.4: + version "3.1.4" + resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" + integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== + dependencies: + lilconfig "^2.0.5" + yaml "^1.10.2" + +postcss-nested@5.0.6: + version "5.0.6" + resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" + integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== + dependencies: + postcss-selector-parser "^6.0.6" + +postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: + version "6.0.10" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" + integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== + dependencies: + cssesc "^3.0.0" + util-deprecate "^1.0.2" + +postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" + integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== + +postcss@^8.4.14: + version "8.4.14" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" + integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== + dependencies: + nanoid "^3.3.4" + picocolors "^1.0.0" + source-map-js "^1.0.2" + +preact@^10.10.0: + version "10.10.0" + resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" + integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== + +qrcode@^1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" + integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== + dependencies: + dijkstrajs "^1.0.1" + encode-utf8 "^1.0.3" + pngjs "^5.0.0" + yargs "^15.3.1" + +queue-microtask@^1.2.2: + version "1.2.3" + resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" + integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== + +quick-lru@^5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" + integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== + +read-cache@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" + integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== + dependencies: + pify "^2.3.0" + +readdirp@~3.6.0: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" + integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== + dependencies: + picomatch "^2.2.1" + +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= + +require-main-filename@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" + integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== + +resolve@^1.1.7, resolve@^1.22.1: + version "1.22.1" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" + integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== + dependencies: + is-core-module "^2.9.0" + path-parse "^1.0.7" + supports-preserve-symlinks-flag "^1.0.0" + +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + +run-parallel@^1.1.9: + version "1.2.0" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" + integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== + dependencies: + queue-microtask "^1.2.2" + +set-blocking@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" + integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= + +source-map-js@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" + integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== + +string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +supports-preserve-symlinks-flag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" + integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== + +tailwindcss@^3.1.6: + version "3.1.6" + resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" + integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== + dependencies: + arg "^5.0.2" + chokidar "^3.5.3" + color-name "^1.1.4" + detective "^5.2.1" + didyoumean "^1.2.2" + dlv "^1.1.3" + fast-glob "^3.2.11" + glob-parent "^6.0.2" + is-glob "^4.0.3" + lilconfig "^2.0.5" + normalize-path "^3.0.0" + object-hash "^3.0.0" + picocolors "^1.0.0" + postcss "^8.4.14" + postcss-import "^14.1.0" + postcss-js "^4.0.0" + postcss-load-config "^3.1.4" + postcss-nested "5.0.6" + postcss-selector-parser "^6.0.10" + postcss-value-parser "^4.2.0" + quick-lru "^5.1.1" + resolve "^1.22.1" + +to-regex-range@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" + integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== + dependencies: + is-number "^7.0.0" + +typescript@>=3.0.1, typescript@^4.7.4: + version "4.7.4" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" + integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== + +util-deprecate@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== + +which-module@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" + integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= + +wrap-ansi@^6.2.0: + version "6.2.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" + integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +xtend@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" + integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== + +xterm-addon-fit@^0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" + integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== + +xterm-addon-web-links@^0.8.0: + version "0.8.0" + resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" + integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== + +xterm@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" + integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== + +y18n@^4.0.0: + version "4.0.3" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" + integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== + +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + +yaml@^1.10.2: + version "1.10.2" + resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" + integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== + +yargs-parser@^18.1.2: + version "18.1.3" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" + integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== + dependencies: + camelcase "^5.0.0" + decamelize "^1.2.0" + +yargs-parser@^21.0.0: + version "21.1.1" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" + integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== + +yargs@^15.3.1: + version "15.4.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" + integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== + dependencies: + cliui "^6.0.0" + decamelize "^1.2.0" + find-up "^4.1.0" + get-caller-file "^2.0.1" + require-directory "^2.1.1" + require-main-filename "^2.0.0" + set-blocking "^2.0.0" + string-width "^4.2.0" + which-module "^2.0.0" + y18n "^4.0.0" + yargs-parser "^18.1.2" + +yargs@^17.2.1: + version "17.5.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" + integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.3" + y18n "^5.0.5" + yargs-parser "^21.0.0" diff --git a/cmd/tsshd/tsshd.go b/cmd/tsshd/tsshd.go index 1ec09a0d47611..950eb661cdb23 100644 --- a/cmd/tsshd/tsshd.go +++ b/cmd/tsshd/tsshd.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// The tsshd binary was an experimental SSH server that accepts connections -// from anybody on the same Tailscale network. -// -// Its functionality moved into tailscaled. -// -// See https://github.com/tailscale/tailscale/issues/3802 -package main +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// The tsshd binary was an experimental SSH server that accepts connections +// from anybody on the same Tailscale network. +// +// Its functionality moved into tailscaled. +// +// See https://github.com/tailscale/tailscale/issues/3802 +package main diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index b6fc53b3a40f3..dc22212e887cb 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -1,408 +1,408 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package controlbase implements the base transport of the Tailscale -// 2021 control protocol. -// -// The base transport implements Noise IK, instantiated with -// Curve25519, ChaCha20Poly1305 and BLAKE2s. -package controlbase - -import ( - "crypto/cipher" - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "tailscale.com/types/key" -) - -const ( - // maxMessageSize is the maximum size of a protocol frame on the - // wire, including header and payload. - maxMessageSize = 4096 - // maxCiphertextSize is the maximum amount of ciphertext bytes - // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - 3 - // maxPlaintextSize is the maximum amount of plaintext bytes that - // one protocol frame can carry, after encryption and framing. - maxPlaintextSize = maxCiphertextSize - chp.Overhead -) - -// A Conn is a secured Noise connection. It implements the net.Conn -// interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) causes all future writes to -// fail. -type Conn struct { - conn net.Conn - version uint16 - peer key.MachinePublic - handshakeHash [blake2s.Size]byte - rx rxState - tx txState -} - -// rxState is all the Conn state that Read uses. -type rxState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - buf *maxMsgBuffer // or nil when reads exhausted - n int // number of valid bytes in buf - next int // offset of next undecrypted packet - plaintext []byte // slice into buf of decrypted bytes - hdrBuf [headerLen]byte // small buffer used when buf is nil -} - -// txState is all the Conn state that Write uses. -type txState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - err error // records the first partial write error for all future calls -} - -// ProtocolVersion returns the protocol version that was used to -// establish this Conn. -func (c *Conn) ProtocolVersion() int { - return int(c.version) -} - -// HandshakeHash returns the Noise handshake hash for the connection, -// which can be used to bind other messages to this connection -// (i.e. to ensure that the message wasn't replayed from a different -// connection). -func (c *Conn) HandshakeHash() [blake2s.Size]byte { - return c.handshakeHash -} - -// Peer returns the peer's long-term public key. -func (c *Conn) Peer() key.MachinePublic { - return c.peer -} - -// readNLocked reads into c.rx.buf until buf contains at least total -// bytes. Returns a slice of the total bytes in rxBuf, or an -// error if fewer than total bytes are available. -// -// It may be called with a nil c.rx.buf only if total == headerLen. -// -// On success, c.rx.buf will be non-nil. -func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxMessageSize { - return nil, errReadTooBig{total} - } - for { - if total <= c.rx.n { - return c.rx.buf[:total], nil - } - var n int - var err error - if c.rx.buf == nil { - if c.rx.n != 0 || total != headerLen { - panic("unexpected") - } - // Optimization to reduce memory usage. - // Most connections are blocked forever waiting for - // a read, so we don't want c.rx.buf to be allocated until - // we know there's data to read. Instead, when we're - // waiting for data to arrive here, read into the - // 3 byte hdrBuf: - n, err = c.conn.Read(c.rx.hdrBuf[:]) - if n > 0 { - c.rx.buf = getMaxMsgBuffer() - copy(c.rx.buf[:], c.rx.hdrBuf[:n]) - } - } else { - n, err = c.conn.Read(c.rx.buf[c.rx.n:]) - } - c.rx.n += n - if err != nil { - return nil, err - } - } -} - -// decryptLocked decrypts msg (which is header+ciphertext) in-place -// and sets c.rx.plaintext to the decrypted bytes. -func (c *Conn) decryptLocked(msg []byte) (err error) { - if msgType := msg[0]; msgType != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) - } - // We don't check the length field here, because the caller - // already did in order to figure out how big the msg slice should - // be. - ciphertext := msg[headerLen:] - - if !c.rx.nonce.Valid() { - return errCipherExhausted{} - } - - c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) - c.rx.nonce.Increment() - - if err != nil { - // Once a decryption has failed, our Conn is no longer - // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. Future - // read attempts will return net.ErrClosed. - c.rx.cipher = nil - } - return err -} - -// encryptLocked encrypts plaintext into buf (including the -// packet header) and returns a slice of the ciphertext, or an error -// if the cipher is exhausted (i.e. can no longer be used safely). -func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { - if !c.tx.nonce.Valid() { - // Received 2^64-1 messages on this cipher state. Connection - // is no longer usable. - return nil, errCipherExhausted{} - } - - buf[0] = msgTypeRecord - binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) - ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) - c.tx.nonce.Increment() - - return ret, nil -} - -// wholeMessageLocked returns a slice of one whole Noise transport -// message from c.rx.buf, if one whole message is available, and -// advances the read state to the next Noise message in the -// buffer. Returns nil without advancing read state if there isn't one -// whole message in c.rx.buf. -func (c *Conn) wholeMessageLocked() []byte { - available := c.rx.n - c.rx.next - if available < headerLen { - return nil - } - bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - if len(bs) < totalSize { - return nil - } - c.rx.next += totalSize - return bs[:totalSize] -} - -// decryptOneLocked decrypts one Noise transport message, reading from -// c.conn as needed, and sets c.rx.plaintext to point to the decrypted -// bytes. c.rx.plaintext is only valid if err == nil. -func (c *Conn) decryptOneLocked() error { - c.rx.plaintext = nil - - // Fast path: do we have one whole ciphertext frame buffered - // already? - if bs := c.wholeMessageLocked(); bs != nil { - return c.decryptLocked(bs) - } - - if c.rx.next != 0 { - // To simplify the read logic, move the remainder of the - // buffered bytes back to the head of the buffer, so we can - // grow it without worrying about wraparound. - c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.next = 0 - } - - // Return our buffer to the pool if it's empty, lest we be - // blocked in a long Read call, reading the 3 byte header. We - // don't to keep that buffer unnecessarily alive. - if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { - bufPool.Put(c.rx.buf) - c.rx.buf = nil - } - - bs, err := c.readNLocked(headerLen) - if err != nil { - return err - } - // The rest of the header (besides the length field) gets verified - // in decryptLocked, not here. - messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - bs, err = c.readNLocked(messageLen) - if err != nil { - return err - } - - c.rx.next = len(bs) - - return c.decryptLocked(bs) -} - -// Read implements io.Reader. -func (c *Conn) Read(bs []byte) (int, error) { - c.rx.Lock() - defer c.rx.Unlock() - - if c.rx.cipher == nil { - return 0, net.ErrClosed - } - // If no plaintext is buffered, decrypt incoming frames until we - // have some plaintext. Zero-byte Noise frames are allowed in this - // protocol, which is why we have to loop here rather than decrypt - // a single additional frame. - for len(c.rx.plaintext) == 0 { - if err := c.decryptOneLocked(); err != nil { - return 0, err - } - } - n := copy(bs, c.rx.plaintext) - c.rx.plaintext = c.rx.plaintext[n:] - - // Lose slice's underlying array pointer to unneeded memory so - // GC can collect more. - if len(c.rx.plaintext) == 0 { - c.rx.plaintext = nil - } - return n, nil -} - -// Write implements io.Writer. -func (c *Conn) Write(bs []byte) (n int, err error) { - c.tx.Lock() - defer c.tx.Unlock() - - if c.tx.err != nil { - return 0, c.tx.err - } - defer func() { - if err != nil { - // All write errors are fatal for this conn, so clear the - // cipher state whenever an error happens. - c.tx.cipher = nil - } - if c.tx.err == nil { - // Only set c.tx.err if not nil so that we can return one - // error on the first failure, and a different one for - // subsequent calls. See the error handling around Write - // below for why. - c.tx.err = err - } - }() - - if c.tx.cipher == nil { - return 0, net.ErrClosed - } - - buf := getMaxMsgBuffer() - defer bufPool.Put(buf) - - var sent int - for len(bs) > 0 { - toSend := bs - if len(toSend) > maxPlaintextSize { - toSend = bs[:maxPlaintextSize] - } - bs = bs[len(toSend):] - - ciphertext, err := c.encryptLocked(toSend, buf) - if err != nil { - return sent, err - } - if _, err := c.conn.Write(ciphertext); err != nil { - // Return the raw error on the Write that actually - // failed. For future writes, return that error wrapped in - // a desync error. - c.tx.err = errPartialWrite{err} - return sent, err - } - sent += len(toSend) - } - return sent, nil -} - -// Close implements io.Closer. -func (c *Conn) Close() error { - closeErr := c.conn.Close() // unblocks any waiting reads or writes - - // Remove references to live cipher state. Strictly speaking this - // is unnecessary, but we want to try and hand the active cipher - // state to the garbage collector promptly, to preserve perfect - // forward secrecy as much as we can. - c.rx.Lock() - c.rx.cipher = nil - c.rx.Unlock() - c.tx.Lock() - c.tx.cipher = nil - c.tx.Unlock() - return closeErr -} - -func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } - -// errCipherExhausted is the error returned when we run out of nonces -// on a cipher. -type errCipherExhausted struct{} - -func (errCipherExhausted) Error() string { - return "cipher exhausted, no more nonces available for current key" -} -func (errCipherExhausted) Timeout() bool { return false } -func (errCipherExhausted) Temporary() bool { return false } - -// errPartialWrite is the error returned when the cipher state has -// become unusable due to a past partial write. -type errPartialWrite struct { - err error -} - -func (e errPartialWrite) Error() string { - return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) -} -func (e errPartialWrite) Unwrap() error { return e.err } -func (e errPartialWrite) Temporary() bool { return false } -func (e errPartialWrite) Timeout() bool { return false } - -// errReadTooBig is the error returned when the peer sent an -// unacceptably large Noise frame. -type errReadTooBig struct { - requested int -} - -func (e errReadTooBig) Error() string { - return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) -} -func (e errReadTooBig) Temporary() bool { - // permanent error because this error only occurs when our peer - // sends us a frame so large we're unwilling to ever decode it. - return false -} -func (e errReadTooBig) Timeout() bool { return false } - -type nonce [chp.NonceSize]byte - -func (n *nonce) Valid() bool { - return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce -} - -func (n *nonce) Increment() { - if !n.Valid() { - panic("increment of invalid nonce") - } - binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) -} - -type maxMsgBuffer [maxMessageSize]byte - -// bufPool holds the temporary buffers for Conn.Read & Write. -var bufPool = &sync.Pool{ - New: func() any { - return new(maxMsgBuffer) - }, -} - -func getMaxMsgBuffer() *maxMsgBuffer { - return bufPool.Get().(*maxMsgBuffer) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlbase implements the base transport of the Tailscale +// 2021 control protocol. +// +// The base transport implements Noise IK, instantiated with +// Curve25519, ChaCha20Poly1305 and BLAKE2s. +package controlbase + +import ( + "crypto/cipher" + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "tailscale.com/types/key" +) + +const ( + // maxMessageSize is the maximum size of a protocol frame on the + // wire, including header and payload. + maxMessageSize = 4096 + // maxCiphertextSize is the maximum amount of ciphertext bytes + // that one protocol frame can carry, after framing. + maxCiphertextSize = maxMessageSize - 3 + // maxPlaintextSize is the maximum amount of plaintext bytes that + // one protocol frame can carry, after encryption and framing. + maxPlaintextSize = maxCiphertextSize - chp.Overhead +) + +// A Conn is a secured Noise connection. It implements the net.Conn +// interface, with the unusual trait that any write error (including a +// SetWriteDeadline induced i/o timeout) causes all future writes to +// fail. +type Conn struct { + conn net.Conn + version uint16 + peer key.MachinePublic + handshakeHash [blake2s.Size]byte + rx rxState + tx txState +} + +// rxState is all the Conn state that Read uses. +type rxState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + buf *maxMsgBuffer // or nil when reads exhausted + n int // number of valid bytes in buf + next int // offset of next undecrypted packet + plaintext []byte // slice into buf of decrypted bytes + hdrBuf [headerLen]byte // small buffer used when buf is nil +} + +// txState is all the Conn state that Write uses. +type txState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + err error // records the first partial write error for all future calls +} + +// ProtocolVersion returns the protocol version that was used to +// establish this Conn. +func (c *Conn) ProtocolVersion() int { + return int(c.version) +} + +// HandshakeHash returns the Noise handshake hash for the connection, +// which can be used to bind other messages to this connection +// (i.e. to ensure that the message wasn't replayed from a different +// connection). +func (c *Conn) HandshakeHash() [blake2s.Size]byte { + return c.handshakeHash +} + +// Peer returns the peer's long-term public key. +func (c *Conn) Peer() key.MachinePublic { + return c.peer +} + +// readNLocked reads into c.rx.buf until buf contains at least total +// bytes. Returns a slice of the total bytes in rxBuf, or an +// error if fewer than total bytes are available. +// +// It may be called with a nil c.rx.buf only if total == headerLen. +// +// On success, c.rx.buf will be non-nil. +func (c *Conn) readNLocked(total int) ([]byte, error) { + if total > maxMessageSize { + return nil, errReadTooBig{total} + } + for { + if total <= c.rx.n { + return c.rx.buf[:total], nil + } + var n int + var err error + if c.rx.buf == nil { + if c.rx.n != 0 || total != headerLen { + panic("unexpected") + } + // Optimization to reduce memory usage. + // Most connections are blocked forever waiting for + // a read, so we don't want c.rx.buf to be allocated until + // we know there's data to read. Instead, when we're + // waiting for data to arrive here, read into the + // 3 byte hdrBuf: + n, err = c.conn.Read(c.rx.hdrBuf[:]) + if n > 0 { + c.rx.buf = getMaxMsgBuffer() + copy(c.rx.buf[:], c.rx.hdrBuf[:n]) + } + } else { + n, err = c.conn.Read(c.rx.buf[c.rx.n:]) + } + c.rx.n += n + if err != nil { + return nil, err + } + } +} + +// decryptLocked decrypts msg (which is header+ciphertext) in-place +// and sets c.rx.plaintext to the decrypted bytes. +func (c *Conn) decryptLocked(msg []byte) (err error) { + if msgType := msg[0]; msgType != msgTypeRecord { + return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) + } + // We don't check the length field here, because the caller + // already did in order to figure out how big the msg slice should + // be. + ciphertext := msg[headerLen:] + + if !c.rx.nonce.Valid() { + return errCipherExhausted{} + } + + c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) + c.rx.nonce.Increment() + + if err != nil { + // Once a decryption has failed, our Conn is no longer + // synchronized with our peer. Nuke the cipher state to be + // safe, so that no further decryptions are attempted. Future + // read attempts will return net.ErrClosed. + c.rx.cipher = nil + } + return err +} + +// encryptLocked encrypts plaintext into buf (including the +// packet header) and returns a slice of the ciphertext, or an error +// if the cipher is exhausted (i.e. can no longer be used safely). +func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { + if !c.tx.nonce.Valid() { + // Received 2^64-1 messages on this cipher state. Connection + // is no longer usable. + return nil, errCipherExhausted{} + } + + buf[0] = msgTypeRecord + binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) + ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) + c.tx.nonce.Increment() + + return ret, nil +} + +// wholeMessageLocked returns a slice of one whole Noise transport +// message from c.rx.buf, if one whole message is available, and +// advances the read state to the next Noise message in the +// buffer. Returns nil without advancing read state if there isn't one +// whole message in c.rx.buf. +func (c *Conn) wholeMessageLocked() []byte { + available := c.rx.n - c.rx.next + if available < headerLen { + return nil + } + bs := c.rx.buf[c.rx.next:c.rx.n] + totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + if len(bs) < totalSize { + return nil + } + c.rx.next += totalSize + return bs[:totalSize] +} + +// decryptOneLocked decrypts one Noise transport message, reading from +// c.conn as needed, and sets c.rx.plaintext to point to the decrypted +// bytes. c.rx.plaintext is only valid if err == nil. +func (c *Conn) decryptOneLocked() error { + c.rx.plaintext = nil + + // Fast path: do we have one whole ciphertext frame buffered + // already? + if bs := c.wholeMessageLocked(); bs != nil { + return c.decryptLocked(bs) + } + + if c.rx.next != 0 { + // To simplify the read logic, move the remainder of the + // buffered bytes back to the head of the buffer, so we can + // grow it without worrying about wraparound. + c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) + c.rx.next = 0 + } + + // Return our buffer to the pool if it's empty, lest we be + // blocked in a long Read call, reading the 3 byte header. We + // don't to keep that buffer unnecessarily alive. + if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { + bufPool.Put(c.rx.buf) + c.rx.buf = nil + } + + bs, err := c.readNLocked(headerLen) + if err != nil { + return err + } + // The rest of the header (besides the length field) gets verified + // in decryptLocked, not here. + messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + bs, err = c.readNLocked(messageLen) + if err != nil { + return err + } + + c.rx.next = len(bs) + + return c.decryptLocked(bs) +} + +// Read implements io.Reader. +func (c *Conn) Read(bs []byte) (int, error) { + c.rx.Lock() + defer c.rx.Unlock() + + if c.rx.cipher == nil { + return 0, net.ErrClosed + } + // If no plaintext is buffered, decrypt incoming frames until we + // have some plaintext. Zero-byte Noise frames are allowed in this + // protocol, which is why we have to loop here rather than decrypt + // a single additional frame. + for len(c.rx.plaintext) == 0 { + if err := c.decryptOneLocked(); err != nil { + return 0, err + } + } + n := copy(bs, c.rx.plaintext) + c.rx.plaintext = c.rx.plaintext[n:] + + // Lose slice's underlying array pointer to unneeded memory so + // GC can collect more. + if len(c.rx.plaintext) == 0 { + c.rx.plaintext = nil + } + return n, nil +} + +// Write implements io.Writer. +func (c *Conn) Write(bs []byte) (n int, err error) { + c.tx.Lock() + defer c.tx.Unlock() + + if c.tx.err != nil { + return 0, c.tx.err + } + defer func() { + if err != nil { + // All write errors are fatal for this conn, so clear the + // cipher state whenever an error happens. + c.tx.cipher = nil + } + if c.tx.err == nil { + // Only set c.tx.err if not nil so that we can return one + // error on the first failure, and a different one for + // subsequent calls. See the error handling around Write + // below for why. + c.tx.err = err + } + }() + + if c.tx.cipher == nil { + return 0, net.ErrClosed + } + + buf := getMaxMsgBuffer() + defer bufPool.Put(buf) + + var sent int + for len(bs) > 0 { + toSend := bs + if len(toSend) > maxPlaintextSize { + toSend = bs[:maxPlaintextSize] + } + bs = bs[len(toSend):] + + ciphertext, err := c.encryptLocked(toSend, buf) + if err != nil { + return sent, err + } + if _, err := c.conn.Write(ciphertext); err != nil { + // Return the raw error on the Write that actually + // failed. For future writes, return that error wrapped in + // a desync error. + c.tx.err = errPartialWrite{err} + return sent, err + } + sent += len(toSend) + } + return sent, nil +} + +// Close implements io.Closer. +func (c *Conn) Close() error { + closeErr := c.conn.Close() // unblocks any waiting reads or writes + + // Remove references to live cipher state. Strictly speaking this + // is unnecessary, but we want to try and hand the active cipher + // state to the garbage collector promptly, to preserve perfect + // forward secrecy as much as we can. + c.rx.Lock() + c.rx.cipher = nil + c.rx.Unlock() + c.tx.Lock() + c.tx.cipher = nil + c.tx.Unlock() + return closeErr +} + +func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +// errCipherExhausted is the error returned when we run out of nonces +// on a cipher. +type errCipherExhausted struct{} + +func (errCipherExhausted) Error() string { + return "cipher exhausted, no more nonces available for current key" +} +func (errCipherExhausted) Timeout() bool { return false } +func (errCipherExhausted) Temporary() bool { return false } + +// errPartialWrite is the error returned when the cipher state has +// become unusable due to a past partial write. +type errPartialWrite struct { + err error +} + +func (e errPartialWrite) Error() string { + return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) +} +func (e errPartialWrite) Unwrap() error { return e.err } +func (e errPartialWrite) Temporary() bool { return false } +func (e errPartialWrite) Timeout() bool { return false } + +// errReadTooBig is the error returned when the peer sent an +// unacceptably large Noise frame. +type errReadTooBig struct { + requested int +} + +func (e errReadTooBig) Error() string { + return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) +} +func (e errReadTooBig) Temporary() bool { + // permanent error because this error only occurs when our peer + // sends us a frame so large we're unwilling to ever decode it. + return false +} +func (e errReadTooBig) Timeout() bool { return false } + +type nonce [chp.NonceSize]byte + +func (n *nonce) Valid() bool { + return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce +} + +func (n *nonce) Increment() { + if !n.Valid() { + panic("increment of invalid nonce") + } + binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) +} + +type maxMsgBuffer [maxMessageSize]byte + +// bufPool holds the temporary buffers for Conn.Read & Write. +var bufPool = &sync.Pool{ + New: func() any { + return new(maxMsgBuffer) + }, +} + +func getMaxMsgBuffer() *maxMsgBuffer { + return bufPool.Get().(*maxMsgBuffer) +} diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 937969a3078a8..765a4620b876f 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "net" + "strconv" + "time" + + "go4.org/mem" + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" + "tailscale.com/types/key" +) + +const ( + // protocolName is the name of the specific instantiation of Noise + // that the control protocol uses. This string's value is fixed by + // the Noise spec, and shouldn't be changed unless we're updating + // the control protocol to use a different Noise instance. + protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" + // protocolVersion is the version of the control protocol that + // Client will use when initiating a handshake. + //protocolVersion uint16 = 1 + // protocolVersionPrefix is the name portion of the protocol + // name+version string that gets mixed into the handshake as a + // prologue. + // + // This mixing verifies that both clients agree that they're + // executing the control protocol at a specific version that + // matches the advertised version in the cleartext packet header. + protocolVersionPrefix = "Tailscale Control Protocol v" + invalidNonce = ^uint64(0) +) + +func protocolVersionPrologue(version uint16) []byte { + ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. + ret = append(ret, protocolVersionPrefix...) + return strconv.AppendUint(ret, uint64(version), 10) +} + +// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn +// is assumed to have already sent the client>server handshake +// initiation message. +type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) + +// ClientDeferred initiates a control client handshake, returning the +// initial message to send to the server and a continuation to +// finalize the handshake. +// +// ClientDeferred is split in this way for RTT reduction: we run this +// protocol after negotiating a protocol switch from HTTP/HTTPS. If we +// completely serialized the negotiation followed by the handshake, +// we'd pay an extra RTT to transmit the handshake initiation after +// protocol switching. By splitting the handshake into an initial +// message and a continuation, we can embed the handshake initiation +// into the HTTP protocol switching request and avoid a bit of delay. +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { + var s symmetricState + s.Initialize() + + // prologue + s.MixHash(protocolVersionPrologue(protocolVersion)) + + // <- s + // ... + s.MixHash(controlKey.UntypedBytes()) + + // -> e, es, s, ss + init := mkInitiationMessage(protocolVersion) + machineEphemeral := key.NewMachine() + machineEphemeralPub := machineEphemeral.Public() + copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(machineEphemeral, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing es: %w", err) + } + machineKeyPub := machineKey.Public() + s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) + cipher, err = s.MixDH(machineKey, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing ss: %w", err) + } + s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload + + cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) + } + return init[:], cont, nil +} + +// Client wraps ClientDeferred and immediately invokes the returned +// continuation with conn. +// +// This is a helper for when you don't need the fancy +// continuation-style handshake, and just want to synchronously +// upgrade a net.Conn to a secure transport. +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) + if err != nil { + return nil, err + } + if _, err := conn.Write(init); err != nil { + return nil, err + } + return cont(ctx, conn) +} + +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + // No matter what, this function can only run once per s. Ensure + // attempted reuse causes a panic. + defer func() { + s.finished = true + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Read in the payload and look for errors/protocol violations from the server. + var resp responseMessage + if _, err := io.ReadFull(conn, resp.Header()); err != nil { + return nil, fmt.Errorf("reading response header: %w", err) + } + if resp.Type() != msgTypeResponse { + if resp.Type() != msgTypeError { + return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) + } + msg := make([]byte, resp.Length()) + if _, err := io.ReadFull(conn, msg); err != nil { + return nil, err + } + return nil, fmt.Errorf("server error: %q", msg) + } + if resp.Length() != len(resp.Payload()) { + return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) + } + if _, err := io.ReadFull(conn, resp.Payload()); err != nil { + return nil, err + } + + // <- e, ee, se + controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err := s.MixDH(machineKey, controlEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { + return nil, fmt.Errorf("decrypting payload: %w", err) + } + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + c := &Conn{ + conn: conn, + version: protocolVersion, + peer: controlKey, + handshakeHash: s.h, + tx: txState{ + cipher: c1, + }, + rx: rxState{ + cipher: c2, + }, + } + return c, nil +} + +// Server initiates a control server handshake, returning the resulting +// control connection. +// +// optionalInit can be the client's initial handshake message as +// returned by ClientDeferred, or nil in which case the initial +// message is read from conn. +// +// The context deadline, if any, covers the entire handshaking +// process. +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Deliberately does not support formatting, so that we don't echo + // attacker-controlled input back to them. + sendErr := func(msg string) error { + if len(msg) >= 1<<16 { + msg = msg[:1<<16] + } + var hdr [headerLen]byte + hdr[0] = msgTypeError + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) + if _, err := conn.Write(hdr[:]); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + if _, err := io.WriteString(conn, msg); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + return fmt.Errorf("refused client handshake: %q", msg) + } + + var s symmetricState + s.Initialize() + + var init initiationMessage + if optionalInit != nil { + if len(optionalInit) != len(init) { + return nil, sendErr("wrong handshake initiation size") + } + copy(init[:], optionalInit) + } else if _, err := io.ReadFull(conn, init.Header()); err != nil { + return nil, err + } + // Just a rename to make it more obvious what the value is. In the + // current implementation we don't need to block any protocol + // versions at this layer, it's safe to let the handshake proceed + // and then let the caller make decisions based on the agreed-upon + // protocol version. + clientVersion := init.Version() + if init.Type() != msgTypeInitiation { + return nil, sendErr("unexpected handshake message type") + } + if init.Length() != len(init.Payload()) { + return nil, sendErr("wrong handshake initiation length") + } + // if optionalInit was provided, we have the payload already. + if optionalInit == nil { + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } + } + + // prologue. Can only do this once we at least think the client is + // handshaking using a supported version. + s.MixHash(protocolVersionPrologue(clientVersion)) + + // <- s + // ... + controlKeyPub := controlKey.Public() + s.MixHash(controlKeyPub.UntypedBytes()) + + // -> e, es, s, ss + machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(controlKey, machineEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing es: %w", err) + } + var machineKeyBytes [32]byte + if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { + return nil, fmt.Errorf("decrypting machine key: %w", err) + } + machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) + cipher, err = s.MixDH(controlKey, machineKey) + if err != nil { + return nil, fmt.Errorf("computing ss: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { + return nil, fmt.Errorf("decrypting initiation tag: %w", err) + } + + // <- e, ee, se + resp := mkResponseMessage() + controlEphemeral := key.NewMachine() + controlEphemeralPub := controlEphemeral.Public() + copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err = s.MixDH(controlEphemeral, machineKey) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + if _, err := conn.Write(resp[:]); err != nil { + return nil, err + } + + c := &Conn{ + conn: conn, + version: clientVersion, + peer: machineKey, + handshakeHash: s.h, + tx: txState{ + cipher: c2, + }, + rx: rxState{ + cipher: c1, + }, + } + return c, nil +} + +// symmetricState contains the state of an in-flight handshake. +type symmetricState struct { + finished bool + + h [blake2s.Size]byte // hash of currently-processed handshake state + ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake +} + +func (s *symmetricState) checkFinished() { + if s.finished { + panic("attempted to use symmetricState after Split was called") + } +} + +// Initialize sets s to the initial handshake state, prior to +// processing any handshake messages. +func (s *symmetricState) Initialize() { + s.checkFinished() + s.h = blake2s.Sum256([]byte(protocolName)) + s.ck = s.h +} + +// MixHash updates s.h to be BLAKE2s(s.h || data), where || is +// concatenation. +func (s *symmetricState) MixHash(data []byte) { + s.checkFinished() + h := newBLAKE2s() + h.Write(s.h[:]) + h.Write(data) + h.Sum(s.h[:0]) +} + +// MixDH updates s.ck with the result of X25519(priv, pub) and returns +// a singleUseCHP that can be used to encrypt or decrypt handshake +// data. +// +// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing +// it as a single function allows for strongly-typed arguments that +// reduce the risk of error in the caller (e.g. invoking X25519 with +// two private keys, or two public keys), and thus producing the wrong +// calculation. +func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { + s.checkFinished() + keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) + if err != nil { + return nil, fmt.Errorf("computing X25519: %w", err) + } + + r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) + if _, err := io.ReadFull(r, s.ck[:]); err != nil { + return nil, fmt.Errorf("extracting ck: %w", err) + } + var k [chp.KeySize]byte + if _, err := io.ReadFull(r, k[:]); err != nil { + return nil, fmt.Errorf("extracting k: %w", err) + } + return newSingleUseCHP(k), nil +} + +// EncryptAndHash encrypts plaintext into ciphertext (which must be +// the correct size to hold the encrypted plaintext) using cipher, +// mixes the ciphertext into s.h, and returns the ciphertext. +func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + panic("ciphertext is wrong size for given plaintext") + } + ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) + s.MixHash(ret) +} + +// DecryptAndHash decrypts the given ciphertext into plaintext (which +// must be the correct size to hold the decrypted ciphertext) using +// cipher. If decryption is successful, it mixes the ciphertext into +// s.h. +func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + return errors.New("plaintext is wrong size for given ciphertext") + } + if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { + return err + } + s.MixHash(ciphertext) + return nil +} + +// Split returns two ChaCha20Poly1305 ciphers with keys derived from +// the current handshake state. Methods on s cannot be used again +// after calling Split. +func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { + s.finished = true + + var k1, k2 [chp.KeySize]byte + r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) + if _, err := io.ReadFull(r, k1[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k1: %w", err) + } + if _, err := io.ReadFull(r, k2[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k2: %w", err) + } + c1, err = chp.New(k1[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) + } + c2, err = chp.New(k2[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) + } + return c1, c2, nil +} + +// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on +// error. +func newBLAKE2s() hash.Hash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen, errors only happen when using BLAKE2s + // in MAC mode with a key. + panic(err) + } + return h +} + +// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or +// panics on error. +func newCHP(key [chp.KeySize]byte) cipher.AEAD { + aead, err := chp.New(key[:]) + if err != nil { + // Can only happen if we passed a key of the wrong length. The + // function signature prevents that. + panic(err) + } + return aead +} + +// singleUseCHP is an instance of ChaCha20Poly1305 that can be used +// only once, either for encrypting or decrypting, but not both. The +// chosen operation is always executed with an all-zeros +// nonce. Subsequent calls to either Seal or Open panic. +type singleUseCHP struct { + c cipher.AEAD +} + +func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { + return &singleUseCHP{newCHP(key)} +} + +func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Seal(dst, nonce[:], plaintext, additionalData) +} + +func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Open(dst, nonce[:], ciphertext, additionalData) +} diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index d11c0414911f3..c41fbf4dd4950 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -1,256 +1,256 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - - "tailscale.com/net/memnet" - "tailscale.com/types/key" -) - -// Can a reference Noise IK client talk to our server? -func TestInteropClient(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - serverErr = make(chan error, 2) - serverBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - server, err := Server(context.Background(), s2, controlKey, nil) - serverErr <- err - if err != nil { - return - } - var buf [1024]byte - _, err = io.ReadFull(server, buf[:len(c2s)]) - serverBytes <- buf[:len(c2s)] - if err != nil { - serverErr <- err - return - } - _, err = server.Write([]byte(s2c)) - serverErr <- err - }() - - gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) - if err != nil { - t.Fatalf("failed client interop: %v", err) - } - if string(gotS2C) != s2c { - t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) - } - - if err := <-serverErr; err != nil { - t.Fatalf("server handshake failed: %v", err) - } - if err := <-serverErr; err != nil { - t.Fatalf("server read/write failed: %v", err) - } - if got := string(<-serverBytes); got != c2s { - t.Fatalf("server received %q, want %q", got, c2s) - } -} - -// Can our client talk to a reference Noise IK server? -func TestInteropServer(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - clientErr = make(chan error, 2) - clientBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) - clientErr <- err - if err != nil { - return - } - _, err = client.Write([]byte(c2s)) - if err != nil { - clientErr <- err - return - } - var buf [1024]byte - _, err = io.ReadFull(client, buf[:len(s2c)]) - clientBytes <- buf[:len(s2c)] - clientErr <- err - }() - - gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) - if err != nil { - t.Fatalf("failed server interop: %v", err) - } - if string(gotC2S) != c2s { - t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) - } - - if err := <-clientErr; err != nil { - t.Fatalf("client handshake failed: %v", err) - } - if err := <-clientErr; err != nil { - t.Fatalf("client read/write failed: %v", err) - } - if got := string(<-clientBytes); got != s2c { - t.Fatalf("client received %q, want %q", got, s2c) - } -} - -// noiseExplorerClient uses the Noise Explorer implementation of Noise -// IK to handshake as a Noise client on conn, transmit payload, and -// read+return a payload from the peer. -func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], machineKey.UntypedBytes()) - copy(mk.public_key[:], machineKey.Public().UntypedBytes()) - var peerKey [32]byte - copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) - - _, msg1 := SendMessage(&session, nil) - var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) - hdr[2] = msgTypeInitiation - binary.BigEndian.PutUint16(hdr[3:5], 96) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ns); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ciphertext); err != nil { - return nil, err - } - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:51]); err != nil { - return nil, err - } - // ignore the header for this test, we're only checking the noise - // implementation. - msg2 := messagebuffer{ - ciphertext: buf[35:51], - } - copy(msg2.ne[:], buf[3:35]) - _, p, valid := RecvMessage(&session, &msg2) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg3 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(hdr[:3]); err != nil { - return nil, err - } - if _, err := conn.Write(msg3.ciphertext); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - // Ignore all of the header except the payload length - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg4 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg4) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - return p, nil -} - -func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], controlKey.UntypedBytes()) - copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:101]); err != nil { - return nil, err - } - // Ignore the header, we're just checking the noise implementation. - msg1 := messagebuffer{ - ns: buf[37:85], - ciphertext: buf[85:101], - } - copy(msg1.ne[:], buf[5:37]) - _, p, valid := RecvMessage(&session, &msg1) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg2 := SendMessage(&session, nil) - var hdr [headerLen]byte - hdr[0] = msgTypeResponse - binary.BigEndian.PutUint16(hdr[1:3], 48) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ciphertext[:]); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg3 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg3) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - _, msg4 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg4.ciphertext); err != nil { - return nil, err - } - - return p, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + + "tailscale.com/net/memnet" + "tailscale.com/types/key" +) + +// Can a reference Noise IK client talk to our server? +func TestInteropClient(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + serverErr = make(chan error, 2) + serverBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + server, err := Server(context.Background(), s2, controlKey, nil) + serverErr <- err + if err != nil { + return + } + var buf [1024]byte + _, err = io.ReadFull(server, buf[:len(c2s)]) + serverBytes <- buf[:len(c2s)] + if err != nil { + serverErr <- err + return + } + _, err = server.Write([]byte(s2c)) + serverErr <- err + }() + + gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) + if err != nil { + t.Fatalf("failed client interop: %v", err) + } + if string(gotS2C) != s2c { + t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) + } + + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server read/write failed: %v", err) + } + if got := string(<-serverBytes); got != c2s { + t.Fatalf("server received %q, want %q", got, c2s) + } +} + +// Can our client talk to a reference Noise IK server? +func TestInteropServer(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + clientErr = make(chan error, 2) + clientBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) + clientErr <- err + if err != nil { + return + } + _, err = client.Write([]byte(c2s)) + if err != nil { + clientErr <- err + return + } + var buf [1024]byte + _, err = io.ReadFull(client, buf[:len(s2c)]) + clientBytes <- buf[:len(s2c)] + clientErr <- err + }() + + gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) + if err != nil { + t.Fatalf("failed server interop: %v", err) + } + if string(gotC2S) != c2s { + t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) + } + + if err := <-clientErr; err != nil { + t.Fatalf("client handshake failed: %v", err) + } + if err := <-clientErr; err != nil { + t.Fatalf("client read/write failed: %v", err) + } + if got := string(<-clientBytes); got != s2c { + t.Fatalf("client received %q, want %q", got, s2c) + } +} + +// noiseExplorerClient uses the Noise Explorer implementation of Noise +// IK to handshake as a Noise client on conn, transmit payload, and +// read+return a payload from the peer. +func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], machineKey.UntypedBytes()) + copy(mk.public_key[:], machineKey.Public().UntypedBytes()) + var peerKey [32]byte + copy(peerKey[:], controlKey.UntypedBytes()) + session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) + + _, msg1 := SendMessage(&session, nil) + var hdr [initiationHeaderLen]byte + binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) + hdr[2] = msgTypeInitiation + binary.BigEndian.PutUint16(hdr[3:5], 96) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ns); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ciphertext); err != nil { + return nil, err + } + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:51]); err != nil { + return nil, err + } + // ignore the header for this test, we're only checking the noise + // implementation. + msg2 := messagebuffer{ + ciphertext: buf[35:51], + } + copy(msg2.ne[:], buf[3:35]) + _, p, valid := RecvMessage(&session, &msg2) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg3 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) + if _, err := conn.Write(hdr[:3]); err != nil { + return nil, err + } + if _, err := conn.Write(msg3.ciphertext); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + // Ignore all of the header except the payload length + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg4 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg4) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + return p, nil +} + +func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], controlKey.UntypedBytes()) + copy(mk.public_key[:], controlKey.Public().UntypedBytes()) + session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:101]); err != nil { + return nil, err + } + // Ignore the header, we're just checking the noise implementation. + msg1 := messagebuffer{ + ns: buf[37:85], + ciphertext: buf[85:101], + } + copy(msg1.ne[:], buf[5:37]) + _, p, valid := RecvMessage(&session, &msg1) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg2 := SendMessage(&session, nil) + var hdr [headerLen]byte + hdr[0] = msgTypeResponse + binary.BigEndian.PutUint16(hdr[1:3], 48) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ciphertext[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg3 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg3) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + _, msg4 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg4.ciphertext); err != nil { + return nil, err + } + + return p, nil +} diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 8993786819b6c..59073088f5e81 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import "encoding/binary" - -const ( - // msgTypeInitiation frames carry a Noise IK handshake initiation message. - msgTypeInitiation = 1 - // msgTypeResponse frames carry a Noise IK handshake response message. - msgTypeResponse = 2 - // msgTypeError frames carry an unauthenticated human-readable - // error message. - // - // Errors reported in this message type must be treated as public - // hints only. They are not encrypted or authenticated, and so can - // be seen and tampered with on the wire. - msgTypeError = 3 - // msgTypeRecord frames carry session data bytes. - msgTypeRecord = 4 - - // headerLen is the size of the header on all messages except msgTypeInitiation. - headerLen = 3 - // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. - initiationHeaderLen = 5 -) - -// initiationMessage is the protocol message sent from a client -// machine to a control server. -// -// 2b: protocol version -// 1b: message type (0x01) -// 2b: payload length (96) -// 5b: header (see headerLen for fields) -// 32b: client ephemeral public key (cleartext) -// 48b: client machine public key (encrypted) -// 16b: message tag (authenticates the whole message) -type initiationMessage [101]byte - -func mkInitiationMessage(protocolVersion uint16) initiationMessage { - var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) - return ret -} - -func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } -func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } - -func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } -func (m *initiationMessage) Type() byte { return m[2] } -func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } - -func (m *initiationMessage) EphemeralPub() []byte { - return m[initiationHeaderLen : initiationHeaderLen+32] -} -func (m *initiationMessage) MachinePub() []byte { - return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] -} -func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } - -// responseMessage is the protocol message sent from a control server -// to a client machine. -// -// 1b: message type (0x02) -// 2b: payload length (48) -// 32b: control ephemeral public key (cleartext) -// 16b: message tag (authenticates the whole message) -type responseMessage [51]byte - -func mkResponseMessage() responseMessage { - var ret responseMessage - ret[0] = msgTypeResponse - binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) - return ret -} - -func (m *responseMessage) Header() []byte { return m[:headerLen] } -func (m *responseMessage) Payload() []byte { return m[headerLen:] } - -func (m *responseMessage) Type() byte { return m[0] } -func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } - -func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import "encoding/binary" + +const ( + // msgTypeInitiation frames carry a Noise IK handshake initiation message. + msgTypeInitiation = 1 + // msgTypeResponse frames carry a Noise IK handshake response message. + msgTypeResponse = 2 + // msgTypeError frames carry an unauthenticated human-readable + // error message. + // + // Errors reported in this message type must be treated as public + // hints only. They are not encrypted or authenticated, and so can + // be seen and tampered with on the wire. + msgTypeError = 3 + // msgTypeRecord frames carry session data bytes. + msgTypeRecord = 4 + + // headerLen is the size of the header on all messages except msgTypeInitiation. + headerLen = 3 + // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. + initiationHeaderLen = 5 +) + +// initiationMessage is the protocol message sent from a client +// machine to a control server. +// +// 2b: protocol version +// 1b: message type (0x01) +// 2b: payload length (96) +// 5b: header (see headerLen for fields) +// 32b: client ephemeral public key (cleartext) +// 48b: client machine public key (encrypted) +// 16b: message tag (authenticates the whole message) +type initiationMessage [101]byte + +func mkInitiationMessage(protocolVersion uint16) initiationMessage { + var ret initiationMessage + binary.BigEndian.PutUint16(ret[:2], protocolVersion) + ret[2] = msgTypeInitiation + binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) + return ret +} + +func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } +func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } + +func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } +func (m *initiationMessage) Type() byte { return m[2] } +func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } + +func (m *initiationMessage) EphemeralPub() []byte { + return m[initiationHeaderLen : initiationHeaderLen+32] +} +func (m *initiationMessage) MachinePub() []byte { + return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] +} +func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } + +// responseMessage is the protocol message sent from a control server +// to a client machine. +// +// 1b: message type (0x02) +// 2b: payload length (48) +// 32b: control ephemeral public key (cleartext) +// 16b: message tag (authenticates the whole message) +type responseMessage [51]byte + +func mkResponseMessage() responseMessage { + var ret responseMessage + ret[0] = msgTypeResponse + binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) + return ret +} + +func (m *responseMessage) Header() []byte { return m[:headerLen] } +func (m *responseMessage) Payload() []byte { return m[headerLen:] } + +func (m *responseMessage) Type() byte { return m[0] } +func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } + +func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } +func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } diff --git a/control/controlclient/sign.go b/control/controlclient/sign.go index 5e72f1cf4b2b6..e3a479c283c62 100644 --- a/control/controlclient/sign.go +++ b/control/controlclient/sign.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "crypto" - "errors" - "fmt" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -var ( - errNoCertStore = errors.New("no certificate store") - errCertificateNotConfigured = errors.New("no certificate subject configured") - errUnsupportedSignatureVersion = errors.New("unsupported signature version") -) - -// HashRegisterRequest generates the hash required sign or verify a -// tailcfg.RegisterRequest. -func HashRegisterRequest( - version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, - serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { - h := crypto.SHA256.New() - - // hash.Hash.Write never returns an error, so we don't check for one here. - switch version { - case tailcfg.SignatureV1: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) - case tailcfg.SignatureV2: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) - default: - return nil, errUnsupportedSignatureVersion - } - - return h.Sum(nil), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "crypto" + "errors" + "fmt" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var ( + errNoCertStore = errors.New("no certificate store") + errCertificateNotConfigured = errors.New("no certificate subject configured") + errUnsupportedSignatureVersion = errors.New("unsupported signature version") +) + +// HashRegisterRequest generates the hash required sign or verify a +// tailcfg.RegisterRequest. +func HashRegisterRequest( + version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, + serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { + h := crypto.SHA256.New() + + // hash.Hash.Write never returns an error, so we don't check for one here. + switch version { + case tailcfg.SignatureV1: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) + case tailcfg.SignatureV2: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) + default: + return nil, errUnsupportedSignatureVersion + } + + return h.Sum(nil), nil +} diff --git a/control/controlclient/sign_supported_test.go b/control/controlclient/sign_supported_test.go index ca41794d11775..e20349a4e82c3 100644 --- a/control/controlclient/sign_supported_test.go +++ b/control/controlclient/sign_supported_test.go @@ -1,236 +1,236 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows && cgo - -package controlclient - -import ( - "crypto" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "reflect" - "testing" - "time" - - "github.com/tailscale/certstore" -) - -const ( - testRootCommonName = "testroot" - testRootSubject = "CN=testroot" -) - -type testIdentity struct { - chain []*x509.Certificate -} - -func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { - return []*x509.Certificate{ - { - NotBefore: notBefore, - NotAfter: notAfter, - PublicKeyAlgorithm: x509.RSA, - }, - { - Subject: pkix.Name{ - CommonName: rootCommonName, - }, - PublicKeyAlgorithm: x509.RSA, - }, - } -} - -func (t *testIdentity) Certificate() (*x509.Certificate, error) { - return t.chain[0], nil -} - -func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { - return t.chain, nil -} - -func (t *testIdentity) Signer() (crypto.Signer, error) { - return nil, errors.New("not implemented") -} - -func (t *testIdentity) Delete() error { - return errors.New("not implemented") -} - -func (t *testIdentity) Close() {} - -func TestSelectIdentityFromSlice(t *testing.T) { - var times []time.Time - for _, ts := range []string{ - "2000-01-01T00:00:00Z", - "2001-01-01T00:00:00Z", - "2002-01-01T00:00:00Z", - "2003-01-01T00:00:00Z", - } { - tm, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatal(err) - } - times = append(times, tm) - } - - tests := []struct { - name string - subject string - ids []certstore.Identity - now time.Time - // wantIndex is an index into ids, or -1 for nil. - wantIndex int - }{ - { - name: "single unexpired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 0, - }, - { - name: "single expired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[2]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 1, - }, - { - name: "expired with unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[3]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two certs both unexpired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two unexpired one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) - - if gotId == nil && gotChain != nil { - t.Error("id is nil: got non-nil chain, want nil chain") - return - } - if gotId != nil && gotChain == nil { - t.Error("id is not nil: got nil chain, want non-nil chain") - return - } - if tt.wantIndex == -1 { - if gotId != nil { - t.Error("got non-nil id, want nil id") - } - return - } - if gotId == nil { - t.Error("got nil id, want non-nil id") - return - } - if gotId != tt.ids[tt.wantIndex] { - found := -1 - for i := range tt.ids { - if tt.ids[i] == gotId { - found = i - break - } - } - if found == -1 { - t.Errorf("got unknown id, want id at index %v", tt.wantIndex) - } else { - t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) - } - } - - tid, ok := tt.ids[tt.wantIndex].(*testIdentity) - if !ok { - t.Error("got non-testIdentity, want testIdentity") - return - } - - if !reflect.DeepEqual(tid.chain, gotChain) { - t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows && cgo + +package controlclient + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "reflect" + "testing" + "time" + + "github.com/tailscale/certstore" +) + +const ( + testRootCommonName = "testroot" + testRootSubject = "CN=testroot" +) + +type testIdentity struct { + chain []*x509.Certificate +} + +func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { + return []*x509.Certificate{ + { + NotBefore: notBefore, + NotAfter: notAfter, + PublicKeyAlgorithm: x509.RSA, + }, + { + Subject: pkix.Name{ + CommonName: rootCommonName, + }, + PublicKeyAlgorithm: x509.RSA, + }, + } +} + +func (t *testIdentity) Certificate() (*x509.Certificate, error) { + return t.chain[0], nil +} + +func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { + return t.chain, nil +} + +func (t *testIdentity) Signer() (crypto.Signer, error) { + return nil, errors.New("not implemented") +} + +func (t *testIdentity) Delete() error { + return errors.New("not implemented") +} + +func (t *testIdentity) Close() {} + +func TestSelectIdentityFromSlice(t *testing.T) { + var times []time.Time + for _, ts := range []string{ + "2000-01-01T00:00:00Z", + "2001-01-01T00:00:00Z", + "2002-01-01T00:00:00Z", + "2003-01-01T00:00:00Z", + } { + tm, err := time.Parse(time.RFC3339, ts) + if err != nil { + t.Fatal(err) + } + times = append(times, tm) + } + + tests := []struct { + name string + subject string + ids []certstore.Identity + now time.Time + // wantIndex is an index into ids, or -1 for nil. + wantIndex int + }{ + { + name: "single unexpired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 0, + }, + { + name: "single expired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[2]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 1, + }, + { + name: "expired with unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[3]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two certs both unexpired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two unexpired one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) + + if gotId == nil && gotChain != nil { + t.Error("id is nil: got non-nil chain, want nil chain") + return + } + if gotId != nil && gotChain == nil { + t.Error("id is not nil: got nil chain, want non-nil chain") + return + } + if tt.wantIndex == -1 { + if gotId != nil { + t.Error("got non-nil id, want nil id") + } + return + } + if gotId == nil { + t.Error("got nil id, want non-nil id") + return + } + if gotId != tt.ids[tt.wantIndex] { + found := -1 + for i := range tt.ids { + if tt.ids[i] == gotId { + found = i + break + } + } + if found == -1 { + t.Errorf("got unknown id, want id at index %v", tt.wantIndex) + } else { + t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) + } + } + + tid, ok := tt.ids[tt.wantIndex].(*testIdentity) + if !ok { + t.Error("got non-testIdentity, want testIdentity") + return + } + + if !reflect.DeepEqual(tid.chain, gotChain) { + t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) + } + }) + } +} diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index 4ec40d502773f..5e161dcbce453 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package controlclient - -import ( - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// signRegisterRequest on non-supported platforms always returns errNoCertStore. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { - return errNoCertStore -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package controlclient + +import ( + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// signRegisterRequest on non-supported platforms always returns errNoCertStore. +func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { + return errNoCertStore +} diff --git a/control/controlclient/status.go b/control/controlclient/status.go index 7dba14d3f5015..d0fdf80d745e3 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "encoding/json" - "fmt" - "reflect" - - "tailscale.com/types/netmap" - "tailscale.com/types/persist" - "tailscale.com/types/structs" -) - -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. -type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - -type Status struct { - _ structs.Incomparable - - // Err, if non-nil, is an error that occurred while logging in. - // - // If it's of type UserVisibleError then it's meant to be shown to users in - // their Tailscale client. Otherwise it's just logged to tailscaled's logs. - Err error - - // URL, if non-empty, is the interactive URL to visit to finish logging in. - URL string - - // NetMap is the latest server-pushed state of the tailnet network. - NetMap *netmap.NetworkMap - - // Persist, when Valid, is the locally persisted configuration. - // - // TODO(bradfitz,maisem): clarify this. - Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State -} - -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - -// Equal reports whether s and s2 are equal. -func (s *Status) Equal(s2 *Status) bool { - if s == nil && s2 == nil { - return true - } - return s != nil && s2 != nil && - s.Err == s2.Err && - s.URL == s2.URL && - s.state == s2.state && - reflect.DeepEqual(s.Persist, s2.Persist) && - reflect.DeepEqual(s.NetMap, s2.NetMap) -} - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "encoding/json" + "fmt" + "reflect" + + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/types/structs" +) + +// State is the high-level state of the client. It is used only in +// unit tests for proper sequencing, don't depend on it anywhere else. +// +// TODO(apenwarr): eliminate the state, as it's now obsolete. +// +// apenwarr: Historical note: controlclient.Auto was originally +// intended to be the state machine for the whole tailscale client, but that +// turned out to not be the right abstraction layer, and it moved to +// ipn.Backend. Since ipn.Backend now has a state machine, it would be +// much better if controlclient could be a simple stateless API. But the +// current server-side API (two interlocking polling https calls) makes that +// very hard to implement. A server side API change could untangle this and +// remove all the statefulness. +type State int + +const ( + StateNew = State(iota) + StateNotAuthenticated + StateAuthenticating + StateURLVisitRequired + StateAuthenticated + StateSynchronized // connected and received map update +) + +func (s State) AppendText(b []byte) ([]byte, error) { + return append(b, s.String()...), nil +} + +func (s State) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s State) String() string { + switch s { + case StateNew: + return "state:new" + case StateNotAuthenticated: + return "state:not-authenticated" + case StateAuthenticating: + return "state:authenticating" + case StateURLVisitRequired: + return "state:url-visit-required" + case StateAuthenticated: + return "state:authenticated" + case StateSynchronized: + return "state:synchronized" + default: + return fmt.Sprintf("state:unknown:%d", int(s)) + } +} + +type Status struct { + _ structs.Incomparable + + // Err, if non-nil, is an error that occurred while logging in. + // + // If it's of type UserVisibleError then it's meant to be shown to users in + // their Tailscale client. Otherwise it's just logged to tailscaled's logs. + Err error + + // URL, if non-empty, is the interactive URL to visit to finish logging in. + URL string + + // NetMap is the latest server-pushed state of the tailnet network. + NetMap *netmap.NetworkMap + + // Persist, when Valid, is the locally persisted configuration. + // + // TODO(bradfitz,maisem): clarify this. + Persist persist.PersistView + + // state is the internal state. It should not be exposed outside this + // package, but we have some automated tests elsewhere that need to + // use it via the StateForTest accessor. + // TODO(apenwarr): Unexport or remove these. + state State +} + +// LoginFinished reports whether the controlclient is in its "StateAuthenticated" +// state where it's in a happy register state but not yet in a map poll. +// +// TODO(bradfitz): delete this and everything around Status.state. +func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } + +// StateForTest returns the internal state of s for tests only. +func (s *Status) StateForTest() State { return s.state } + +// SetStateForTest sets the internal state of s for tests only. +func (s *Status) SetStateForTest(state State) { s.state = state } + +// Equal reports whether s and s2 are equal. +func (s *Status) Equal(s2 *Status) bool { + if s == nil && s2 == nil { + return true + } + return s != nil && s2 != nil && + s.Err == s2.Err && + s.URL == s2.URL && + s.state == s2.state && + reflect.DeepEqual(s.Persist, s2.Persist) && + reflect.DeepEqual(s.NetMap, s2.NetMap) +} + +func (s Status) String() string { + b, err := json.MarshalIndent(s, "", "\t") + if err != nil { + panic(err) + } + return s.state.String() + " " + string(b) +} diff --git a/control/controlhttp/client_common.go b/control/controlhttp/client_common.go index 72a89e3cdbbed..dd94e93cdc3cf 100644 --- a/control/controlhttp/client_common.go +++ b/control/controlhttp/client_common.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlhttp - -import ( - "tailscale.com/control/controlbase" -) - -// ClientConn is a Tailscale control client as returned by the Dialer. -// -// It's effectively just a *controlbase.Conn (which it embeds) with -// optional metadata. -type ClientConn struct { - // Conn is the noise connection. - *controlbase.Conn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlhttp + +import ( + "tailscale.com/control/controlbase" +) + +// ClientConn is a Tailscale control client as returned by the Dialer. +// +// It's effectively just a *controlbase.Conn (which it embeds) with +// optional metadata. +type ClientConn struct { + // Conn is the noise connection. + *controlbase.Conn +} diff --git a/derp/README.md b/derp/README.md index acd986ea9cf08..16877020d465e 100644 --- a/derp/README.md +++ b/derp/README.md @@ -1,61 +1,61 @@ -# DERP - -This directory (and subdirectories) contain the DERP code. The server itself is -in `../cmd/derper`. - -DERP is a packet relay system (client and servers) where peers are addressed -using WireGuard public keys instead of IP addresses. - -It relays two types of packets: - -* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT - traversal](https://tailscale.com/blog/how-nat-traversal-works/). - -* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked - or NAT traversal fails. - -## DERP Map - -Each client receives a "[DERP -Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination -server describing the DERP servers the client should try to use. - -The client picks its home "DERP home" based on latency. This is done to keep -costs low by avoid using cloud load balancers (pricey) or anycast, which would -necessarily require server-side routing between DERP regions. - -Clients pick their DERP home and report it to the coordination server which -shares it to all the peers in the tailnet. When a peer wants to send a packet -and it doesn't already have a WireGuard session open, it sends disco messages -(some direct, and some over DERP), trying to do the NAT traversal. The client -will make connections to multiple DERP regions as needed. Only the DERP home -region connection needs to be alive forever. - -## DERP Regions - -Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various -geographic regions to make sure users have low latency to their DERP home. - -Regions generally have multiple nodes per region "meshed" (routing to each -other) together for redundancy: it allows for cloud failures or upgrades without -kicking users out to a higher latency region. Instead, clients will reconnect to -the next node in the region. Each node in the region is required to to be meshed -with every other node in the region and forward packets to the other nodes in -the region. Packets are forwarded only one hop within the region. There is no -routing between regions. The assumption is that the mesh TCP connections are -over a VPC that's very fast, low latency, and not charged per byte. The -coordination server assigns the list of nodes in a region as a function of the -tailnet, so all nodes within a tailnet should generally be on the same node and -not require forwarding. Only after a failure do clients of a particular tailnet -get split between nodes in a region and require inter-node forwarding. But over -time it balances back out. There's also an admin-only DERP frame type to force -close the TCP connection of a particular client to force them to reconnect to -their primary if the operator wants to force things to balance out sooner. -(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's -internal rarely-used `cmd/derpprune` maintenance tool) - -We generally run a minimum of three nodes in a region not for quorum reasons -(there's no voting) but just because two is too uncomfortably few for cascading -failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and -then one fails, that makes the second one fail. With three or more nodes, you +# DERP + +This directory (and subdirectories) contain the DERP code. The server itself is +in `../cmd/derper`. + +DERP is a packet relay system (client and servers) where peers are addressed +using WireGuard public keys instead of IP addresses. + +It relays two types of packets: + +* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT + traversal](https://tailscale.com/blog/how-nat-traversal-works/). + +* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked + or NAT traversal fails. + +## DERP Map + +Each client receives a "[DERP +Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination +server describing the DERP servers the client should try to use. + +The client picks its home "DERP home" based on latency. This is done to keep +costs low by avoid using cloud load balancers (pricey) or anycast, which would +necessarily require server-side routing between DERP regions. + +Clients pick their DERP home and report it to the coordination server which +shares it to all the peers in the tailnet. When a peer wants to send a packet +and it doesn't already have a WireGuard session open, it sends disco messages +(some direct, and some over DERP), trying to do the NAT traversal. The client +will make connections to multiple DERP regions as needed. Only the DERP home +region connection needs to be alive forever. + +## DERP Regions + +Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various +geographic regions to make sure users have low latency to their DERP home. + +Regions generally have multiple nodes per region "meshed" (routing to each +other) together for redundancy: it allows for cloud failures or upgrades without +kicking users out to a higher latency region. Instead, clients will reconnect to +the next node in the region. Each node in the region is required to to be meshed +with every other node in the region and forward packets to the other nodes in +the region. Packets are forwarded only one hop within the region. There is no +routing between regions. The assumption is that the mesh TCP connections are +over a VPC that's very fast, low latency, and not charged per byte. The +coordination server assigns the list of nodes in a region as a function of the +tailnet, so all nodes within a tailnet should generally be on the same node and +not require forwarding. Only after a failure do clients of a particular tailnet +get split between nodes in a region and require inter-node forwarding. But over +time it balances back out. There's also an admin-only DERP frame type to force +close the TCP connection of a particular client to force them to reconnect to +their primary if the operator wants to force things to balance out sooner. +(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's +internal rarely-used `cmd/derpprune` maintenance tool) + +We generally run a minimum of three nodes in a region not for quorum reasons +(there's no voting) but just because two is too uncomfortably few for cascading +failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and +then one fails, that makes the second one fail. With three or more nodes, you can run each node a bit hotter. \ No newline at end of file diff --git a/derp/testdata/example_ss.txt b/derp/testdata/example_ss.txt index ae25003b22856..2885f1bc15a16 100644 --- a/derp/testdata/example_ss.txt +++ b/derp/testdata/example_ss.txt @@ -1,8 +1,8 @@ -ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https - cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 -ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 - cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 -ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https - cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 -ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https - cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 +ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https + cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 +ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 + cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 +ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https + cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 +ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https + cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 diff --git a/disco/disco_fuzzer.go b/disco/disco_fuzzer.go index 0deede05018d3..b9ffabfb00906 100644 --- a/disco/disco_fuzzer.go +++ b/disco/disco_fuzzer.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package disco - -func Fuzz(data []byte) int { - m, _ := Parse(data) - - newBytes := m.AppendMarshal(data) - parsedMarshall, _ := Parse(newBytes) - - if m != parsedMarshall { - panic("Parsing error") - } - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package disco + +func Fuzz(data []byte) int { + m, _ := Parse(data) + + newBytes := m.AppendMarshal(data) + parsedMarshall, _ := Parse(newBytes) + + if m != parsedMarshall { + panic("Parsing error") + } + return 1 +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 045425eb722df..1a56324a5a423 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -1,118 +1,118 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "fmt" - "net/netip" - "reflect" - "strings" - "testing" - - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestMarshalAndParse(t *testing.T) { - tests := []struct { - name string - want string - m Message - }{ - { - name: "ping", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", - }, - { - name: "ping_with_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", - }, - { - name: "ping_with_padding", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", - }, - { - name: "ping_with_padding_and_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", - }, - { - name: "pong", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("2.3.4.5:1234"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", - }, - { - name: "pongv6", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("[fed0::12]:6666"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", - }, - { - name: "call_me_maybe", - m: &CallMeMaybe{}, - want: "03 00", - }, - { - name: "call_me_maybe_endpoints", - m: &CallMeMaybe{ - MyNumber: []netip.AddrPort{ - netip.MustParseAddrPort("1.2.3.4:567"), - netip.MustParseAddrPort("[2001::3456]:789"), - }, - }, - want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - foo := []byte("foo") - got := string(tt.m.AppendMarshal(foo)) - got, ok := strings.CutPrefix(got, "foo") - if !ok { - t.Fatalf("didn't start with foo: got %q", got) - } - - gotHex := fmt.Sprintf("% x", got) - if gotHex != tt.want { - t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) - } - - back, err := Parse([]byte(got)) - if err != nil { - t.Fatalf("parse back: %v", err) - } - if !reflect.DeepEqual(back, tt.m) { - t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) - } - }) - } -} - -func mustIPPort(s string) netip.AddrPort { - ipp, err := netip.ParseAddrPort(s) - if err != nil { - panic(err) - } - return ipp -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "fmt" + "net/netip" + "reflect" + "strings" + "testing" + + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestMarshalAndParse(t *testing.T) { + tests := []struct { + name string + want string + m Message + }{ + { + name: "ping", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", + }, + { + name: "ping_with_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + }, + { + name: "ping_with_padding", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", + }, + { + name: "ping_with_padding_and_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", + }, + { + name: "pong", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("2.3.4.5:1234"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", + }, + { + name: "pongv6", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("[fed0::12]:6666"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", + }, + { + name: "call_me_maybe", + m: &CallMeMaybe{}, + want: "03 00", + }, + { + name: "call_me_maybe_endpoints", + m: &CallMeMaybe{ + MyNumber: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + }, + want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + foo := []byte("foo") + got := string(tt.m.AppendMarshal(foo)) + got, ok := strings.CutPrefix(got, "foo") + if !ok { + t.Fatalf("didn't start with foo: got %q", got) + } + + gotHex := fmt.Sprintf("% x", got) + if gotHex != tt.want { + t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) + } + + back, err := Parse([]byte(got)) + if err != nil { + t.Fatalf("parse back: %v", err) + } + if !reflect.DeepEqual(back, tt.m) { + t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) + } + }) + } +} + +func mustIPPort(s string) netip.AddrPort { + ipp, err := netip.ParseAddrPort(s) + if err != nil { + panic(err) + } + return ipp +} diff --git a/disco/pcap.go b/disco/pcap.go index 5d60ceb28eeef..71035424868e8 100644 --- a/disco/pcap.go +++ b/disco/pcap.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "bytes" - "encoding/binary" - "net/netip" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. -// -// Warning: Alloc garbage. Acceptable while capturing. -func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { - var ( - b bytes.Buffer - flag uint8 - ) - b.Grow(128) // Most disco frames will probably be smaller than this. - - if src.Addr() == tailcfg.DerpMagicIPAddr { - flag |= 0x01 - } - b.WriteByte(flag) // 1b: flag - - derpSrc := derpNodeSrc.Raw32() - b.Write(derpSrc[:]) // 32b: derp public key - binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port - addr, _ := src.Addr().MarshalBinary() - binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) - b.Write(addr) // Xb: addr - binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) - b.Write(payload) // Xb: payload - - return b.Bytes() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "bytes" + "encoding/binary" + "net/netip" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. +// +// Warning: Alloc garbage. Acceptable while capturing. +func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { + var ( + b bytes.Buffer + flag uint8 + ) + b.Grow(128) // Most disco frames will probably be smaller than this. + + if src.Addr() == tailcfg.DerpMagicIPAddr { + flag |= 0x01 + } + b.WriteByte(flag) // 1b: flag + + derpSrc := derpNodeSrc.Raw32() + b.Write(derpSrc[:]) // 32b: derp public key + binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port + addr, _ := src.Addr().MarshalBinary() + binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) + b.Write(addr) // Xb: addr + binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) + b.Write(payload) // Xb: payload + + return b.Bytes() +} diff --git a/docs/bird/sample_bird.conf b/docs/bird/sample_bird.conf index 87222c59af0e6..ed38e66c5c0a2 100644 --- a/docs/bird/sample_bird.conf +++ b/docs/bird/sample_bird.conf @@ -1,16 +1,16 @@ -log syslog all; - -protocol device { - scan time 10; -} - -protocol bgp { - local as 64001; - neighbor 10.40.2.101 as 64002; - ipv4 { - import none; - export all; - }; -} - -include "tailscale_bird.conf"; +log syslog all; + +protocol device { + scan time 10; +} + +protocol bgp { + local as 64001; + neighbor 10.40.2.101 as 64002; + ipv4 { + import none; + export all; + }; +} + +include "tailscale_bird.conf"; diff --git a/docs/bird/tailscale_bird.conf b/docs/bird/tailscale_bird.conf index a5f4307479b79..8211a50a3c58e 100644 --- a/docs/bird/tailscale_bird.conf +++ b/docs/bird/tailscale_bird.conf @@ -1,4 +1,4 @@ -protocol static tailscale { - ipv4; - route 100.64.0.0/10 via "tailscale0"; -} +protocol static tailscale { + ipv4; + route 100.64.0.0/10 via "tailscale0"; +} diff --git a/docs/k8s/Makefile b/docs/k8s/Makefile index 107c1c1361c61..55804c857c049 100644 --- a/docs/k8s/Makefile +++ b/docs/k8s/Makefile @@ -1,25 +1,25 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -TS_ROUTES ?= "" -SA_NAME ?= tailscale -TS_KUBE_SECRET ?= tailscale - -rbac: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml - -sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -userspace-sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -proxy: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" - -subnet-router: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +TS_ROUTES ?= "" +SA_NAME ?= tailscale +TS_KUBE_SECRET ?= tailscale + +rbac: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml + +sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +userspace-sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +proxy: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" + +subnet-router: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" diff --git a/docs/k8s/rolebinding.yaml b/docs/k8s/rolebinding.yaml index b32e66b984510..3b18ba8d35e57 100644 --- a/docs/k8s/rolebinding.yaml +++ b/docs/k8s/rolebinding.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - name: tailscale -subjects: -- kind: ServiceAccount - name: "{{SA_NAME}}" -roleRef: - kind: Role - name: tailscale - apiGroup: rbac.authorization.k8s.io +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: tailscale +subjects: +- kind: ServiceAccount + name: "{{SA_NAME}}" +roleRef: + kind: Role + name: tailscale + apiGroup: rbac.authorization.k8s.io diff --git a/docs/k8s/sa.yaml b/docs/k8s/sa.yaml index 85b56bd24a7fe..edd3944ba8987 100644 --- a/docs/k8s/sa.yaml +++ b/docs/k8s/sa.yaml @@ -1,6 +1,6 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: v1 -kind: ServiceAccount -metadata: - name: {{SA_NAME}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{SA_NAME}} diff --git a/docs/sysv/tailscale.init b/docs/sysv/tailscale.init index fc22088b16a5b..ca21033df7b27 100755 --- a/docs/sysv/tailscale.init +++ b/docs/sysv/tailscale.init @@ -1,63 +1,63 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -### BEGIN INIT INFO -# Provides: tailscaled -# Required-Start: -# Required-Stop: -# Default-Start: -# Default-Stop: -# Short-Description: Tailscale Mesh Wireguard VPN -### END INIT INFO - -set -e - -# /etc/init.d/tailscale: start and stop the Tailscale VPN service - -test -x /usr/sbin/tailscaled || exit 0 - -umask 022 - -. /lib/lsb/init-functions - -# Are we running from init? -run_by_init() { - ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] -} - -export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" - -case "$1" in - start) - log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true - if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ - --exec /usr/sbin/tailscaled -- \ - --state=/var/lib/tailscale/tailscaled.state \ - --socket=/run/tailscale/tailscaled.sock \ - --port 41641; - then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - stop) - log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true - if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - - status) - status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? - ;; - - *) - log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true - exit 1 -esac - -exit 0 +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +### BEGIN INIT INFO +# Provides: tailscaled +# Required-Start: +# Required-Stop: +# Default-Start: +# Default-Stop: +# Short-Description: Tailscale Mesh Wireguard VPN +### END INIT INFO + +set -e + +# /etc/init.d/tailscale: start and stop the Tailscale VPN service + +test -x /usr/sbin/tailscaled || exit 0 + +umask 022 + +. /lib/lsb/init-functions + +# Are we running from init? +run_by_init() { + ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] +} + +export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" + +case "$1" in + start) + log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true + if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ + --exec /usr/sbin/tailscaled -- \ + --state=/var/lib/tailscale/tailscaled.state \ + --socket=/run/tailscale/tailscaled.sock \ + --port 41641; + then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + stop) + log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true + if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + + status) + status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? + ;; + + *) + log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true + exit 1 +esac + +exit 0 diff --git a/doctor/doctor.go b/doctor/doctor.go index 96af39f5f3eb9..7c3047e12b62d 100644 --- a/doctor/doctor.go +++ b/doctor/doctor.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package doctor contains more in-depth healthchecks that can be run to aid in -// diagnosing Tailscale issues. -package doctor - -import ( - "context" - "sync" - - "tailscale.com/types/logger" -) - -// Check is the interface defining a singular check. -// -// A check should log information that it gathers using the provided log -// function, and should attempt to make as much progress as possible in error -// conditions. -type Check interface { - // Name should return a name describing this check, in lower-kebab-case - // (i.e. "my-check", not "MyCheck" or "my_check"). - Name() string - // Run executes the check, logging diagnostic information to the - // provided logger function. - Run(context.Context, logger.Logf) error -} - -// RunChecks runs a list of checks in parallel, and logs any returned errors -// after all checks have returned. -func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { - if len(checks) == 0 { - return - } - - type namedErr struct { - name string - err error - } - errs := make(chan namedErr, len(checks)) - - var wg sync.WaitGroup - wg.Add(len(checks)) - for _, check := range checks { - go func(c Check) { - defer wg.Done() - - plog := logger.WithPrefix(log, c.Name()+": ") - errs <- namedErr{ - name: c.Name(), - err: c.Run(ctx, plog), - } - }(check) - } - - wg.Wait() - close(errs) - - for n := range errs { - if n.err == nil { - continue - } - - log("check %s: %v", n.name, n.err) - } -} - -// CheckFunc creates a Check from a name and a function. -func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { - return checkFunc{name, run} -} - -type checkFunc struct { - name string - run func(context.Context, logger.Logf) error -} - -func (c checkFunc) Name() string { return c.name } -func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package doctor contains more in-depth healthchecks that can be run to aid in +// diagnosing Tailscale issues. +package doctor + +import ( + "context" + "sync" + + "tailscale.com/types/logger" +) + +// Check is the interface defining a singular check. +// +// A check should log information that it gathers using the provided log +// function, and should attempt to make as much progress as possible in error +// conditions. +type Check interface { + // Name should return a name describing this check, in lower-kebab-case + // (i.e. "my-check", not "MyCheck" or "my_check"). + Name() string + // Run executes the check, logging diagnostic information to the + // provided logger function. + Run(context.Context, logger.Logf) error +} + +// RunChecks runs a list of checks in parallel, and logs any returned errors +// after all checks have returned. +func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { + if len(checks) == 0 { + return + } + + type namedErr struct { + name string + err error + } + errs := make(chan namedErr, len(checks)) + + var wg sync.WaitGroup + wg.Add(len(checks)) + for _, check := range checks { + go func(c Check) { + defer wg.Done() + + plog := logger.WithPrefix(log, c.Name()+": ") + errs <- namedErr{ + name: c.Name(), + err: c.Run(ctx, plog), + } + }(check) + } + + wg.Wait() + close(errs) + + for n := range errs { + if n.err == nil { + continue + } + + log("check %s: %v", n.name, n.err) + } +} + +// CheckFunc creates a Check from a name and a function. +func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { + return checkFunc{name, run} +} + +type checkFunc struct { + name string + run func(context.Context, logger.Logf) error +} + +func (c checkFunc) Name() string { return c.name } +func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } diff --git a/doctor/doctor_test.go b/doctor/doctor_test.go index dab7afa38a5fc..87250f10ed00a 100644 --- a/doctor/doctor_test.go +++ b/doctor/doctor_test.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package doctor - -import ( - "context" - "fmt" - "sync" - "testing" - - qt "github.com/frankban/quicktest" - "tailscale.com/types/logger" -) - -func TestRunChecks(t *testing.T) { - c := qt.New(t) - var ( - mu sync.Mutex - lines []string - ) - logf := func(format string, args ...any) { - mu.Lock() - defer mu.Unlock() - lines = append(lines, fmt.Sprintf(format, args...)) - } - - ctx := context.Background() - RunChecks(ctx, logf, - testCheck1{}, - CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { - log("check 2") - return nil - }), - ) - - mu.Lock() - defer mu.Unlock() - c.Assert(lines, qt.Contains, "testcheck1: check 1") - c.Assert(lines, qt.Contains, "testcheck2: check 2") -} - -type testCheck1 struct{} - -func (t testCheck1) Name() string { return "testcheck1" } -func (t testCheck1) Run(_ context.Context, log logger.Logf) error { - log("check 1") - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package doctor + +import ( + "context" + "fmt" + "sync" + "testing" + + qt "github.com/frankban/quicktest" + "tailscale.com/types/logger" +) + +func TestRunChecks(t *testing.T) { + c := qt.New(t) + var ( + mu sync.Mutex + lines []string + ) + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + lines = append(lines, fmt.Sprintf(format, args...)) + } + + ctx := context.Background() + RunChecks(ctx, logf, + testCheck1{}, + CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { + log("check 2") + return nil + }), + ) + + mu.Lock() + defer mu.Unlock() + c.Assert(lines, qt.Contains, "testcheck1: check 1") + c.Assert(lines, qt.Contains, "testcheck2: check 2") +} + +type testCheck1 struct{} + +func (t testCheck1) Name() string { return "testcheck1" } +func (t testCheck1) Run(_ context.Context, log logger.Logf) error { + log("check 1") + return nil +} diff --git a/doctor/permissions/permissions_bsd.go b/doctor/permissions/permissions_bsd.go index 4031af7221cd5..8b034cfff1af3 100644 --- a/doctor/permissions/permissions_bsd.go +++ b/doctor/permissions/permissions_bsd.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin || freebsd || openbsd - -package permissions - -import ( - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - groups, _ := unix.Getgroups() - logf("uid=%s euid=%s gid=%s egid=%s groups=%s", - formatUserID(unix.Getuid()), - formatUserID(unix.Geteuid()), - formatGroupID(unix.Getgid()), - formatGroupID(unix.Getegid()), - formatGroups(groups), - ) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin || freebsd || openbsd + +package permissions + +import ( + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + groups, _ := unix.Getgroups() + logf("uid=%s euid=%s gid=%s egid=%s groups=%s", + formatUserID(unix.Getuid()), + formatUserID(unix.Geteuid()), + formatGroupID(unix.Getgid()), + formatGroupID(unix.Getegid()), + formatGroups(groups), + ) + return nil +} diff --git a/doctor/permissions/permissions_linux.go b/doctor/permissions/permissions_linux.go index ef0a97056f411..12bb393d53383 100644 --- a/doctor/permissions/permissions_linux.go +++ b/doctor/permissions/permissions_linux.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package permissions - -import ( - "fmt" - "strings" - "unsafe" - - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - // NOTE: getresuid and getresgid never fail unless passed an - // invalid address. - var ruid, euid, suid uint64 - unix.Syscall(unix.SYS_GETRESUID, - uintptr(unsafe.Pointer(&ruid)), - uintptr(unsafe.Pointer(&euid)), - uintptr(unsafe.Pointer(&suid)), - ) - - var rgid, egid, sgid uint64 - unix.Syscall(unix.SYS_GETRESGID, - uintptr(unsafe.Pointer(&rgid)), - uintptr(unsafe.Pointer(&egid)), - uintptr(unsafe.Pointer(&sgid)), - ) - - groups, _ := unix.Getgroups() - - var buf strings.Builder - fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", - formatUserID(ruid), formatUserID(euid), formatUserID(suid), - formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), - formatGroups(groups), - ) - - // Get process capabilities - var ( - capHeader = unix.CapUserHeader{ - Version: unix.LINUX_CAPABILITY_VERSION_3, - Pid: 0, // 0 means 'ourselves' - } - capData unix.CapUserData - ) - - if err := unix.Capget(&capHeader, &capData); err != nil { - fmt.Fprintf(&buf, " caperr=%v", err) - } else { - fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", - capData.Effective, capData.Permitted, capData.Inheritable, - ) - } - - logf("%s", buf.String()) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package permissions + +import ( + "fmt" + "strings" + "unsafe" + + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + // NOTE: getresuid and getresgid never fail unless passed an + // invalid address. + var ruid, euid, suid uint64 + unix.Syscall(unix.SYS_GETRESUID, + uintptr(unsafe.Pointer(&ruid)), + uintptr(unsafe.Pointer(&euid)), + uintptr(unsafe.Pointer(&suid)), + ) + + var rgid, egid, sgid uint64 + unix.Syscall(unix.SYS_GETRESGID, + uintptr(unsafe.Pointer(&rgid)), + uintptr(unsafe.Pointer(&egid)), + uintptr(unsafe.Pointer(&sgid)), + ) + + groups, _ := unix.Getgroups() + + var buf strings.Builder + fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", + formatUserID(ruid), formatUserID(euid), formatUserID(suid), + formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), + formatGroups(groups), + ) + + // Get process capabilities + var ( + capHeader = unix.CapUserHeader{ + Version: unix.LINUX_CAPABILITY_VERSION_3, + Pid: 0, // 0 means 'ourselves' + } + capData unix.CapUserData + ) + + if err := unix.Capget(&capHeader, &capData); err != nil { + fmt.Fprintf(&buf, " caperr=%v", err) + } else { + fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", + capData.Effective, capData.Permitted, capData.Inheritable, + ) + } + + logf("%s", buf.String()) + return nil +} diff --git a/doctor/permissions/permissions_other.go b/doctor/permissions/permissions_other.go index 5e310b98e361e..7e6912b4928cf 100644 --- a/doctor/permissions/permissions_other.go +++ b/doctor/permissions/permissions_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd) - -package permissions - -import ( - "runtime" - - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd) + +package permissions + +import ( + "runtime" + + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) + return nil +} diff --git a/doctor/permissions/permissions_test.go b/doctor/permissions/permissions_test.go index 9b71c3be1cfe3..941d406ef8318 100644 --- a/doctor/permissions/permissions_test.go +++ b/doctor/permissions/permissions_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package permissions - -import "testing" - -func TestPermissionsImpl(t *testing.T) { - if err := permissionsImpl(t.Logf); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package permissions + +import "testing" + +func TestPermissionsImpl(t *testing.T) { + if err := permissionsImpl(t.Logf); err != nil { + t.Error(err) + } +} diff --git a/doctor/routetable/routetable.go b/doctor/routetable/routetable.go index 1ebf294ce1474..76e4ef949b9af 100644 --- a/doctor/routetable/routetable.go +++ b/doctor/routetable/routetable.go @@ -1,34 +1,34 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package routetable provides a doctor.Check that dumps the current system's -// route table to the log. -package routetable - -import ( - "context" - - "tailscale.com/net/routetable" - "tailscale.com/types/logger" -) - -// MaxRoutes is the maximum number of routes that will be displayed. -const MaxRoutes = 1000 - -// Check implements the doctor.Check interface. -type Check struct{} - -func (Check) Name() string { - return "routetable" -} - -func (Check) Run(_ context.Context, logf logger.Logf) error { - rs, err := routetable.Get(MaxRoutes) - if err != nil { - return err - } - for _, r := range rs { - logf("%s", r) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package routetable provides a doctor.Check that dumps the current system's +// route table to the log. +package routetable + +import ( + "context" + + "tailscale.com/net/routetable" + "tailscale.com/types/logger" +) + +// MaxRoutes is the maximum number of routes that will be displayed. +const MaxRoutes = 1000 + +// Check implements the doctor.Check interface. +type Check struct{} + +func (Check) Name() string { + return "routetable" +} + +func (Check) Run(_ context.Context, logf logger.Logf) error { + rs, err := routetable.Get(MaxRoutes) + if err != nil { + return err + } + for _, r := range rs { + logf("%s", r) + } + return nil +} diff --git a/envknob/envknob_nottest.go b/envknob/envknob_nottest.go index b21266f1377ca..0dd900cc8104e 100644 --- a/envknob/envknob_nottest.go +++ b/envknob/envknob_nottest.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_not_in_tests - -package envknob - -import "runtime" - -func GOOS() string { - // When the "ts_not_in_tests" build tag is used, we define this func to just - // return a simple constant so callers optimize just as if the knob were not - // present. We can then build production/optimized builds with the - // "ts_not_in_tests" build tag. - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_not_in_tests + +package envknob + +import "runtime" + +func GOOS() string { + // When the "ts_not_in_tests" build tag is used, we define this func to just + // return a simple constant so callers optimize just as if the knob were not + // present. We can then build production/optimized builds with the + // "ts_not_in_tests" build tag. + return runtime.GOOS +} diff --git a/envknob/envknob_testable.go b/envknob/envknob_testable.go index 53687d732d493..e7f038336c4f3 100644 --- a/envknob/envknob_testable.go +++ b/envknob/envknob_testable.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ts_not_in_tests - -package envknob - -import "runtime" - -// GOOS reports the effective runtime.GOOS to run as. -// -// In practice this returns just runtime.GOOS, unless overridden by -// test TS_DEBUG_FAKE_GOOS. -// -// This allows changing OS-specific stuff like the IPN server behavior -// for tests so we can e.g. test Windows-specific behaviors on Linux. -// This isn't universally used. -func GOOS() string { - if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { - return v - } - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_not_in_tests + +package envknob + +import "runtime" + +// GOOS reports the effective runtime.GOOS to run as. +// +// In practice this returns just runtime.GOOS, unless overridden by +// test TS_DEBUG_FAKE_GOOS. +// +// This allows changing OS-specific stuff like the IPN server behavior +// for tests so we can e.g. test Windows-specific behaviors on Linux. +// This isn't universally used. +func GOOS() string { + if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { + return v + } + return runtime.GOOS +} diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go index a7b0a05e8b1b8..350384b8626e3 100644 --- a/envknob/logknob/logknob.go +++ b/envknob/logknob/logknob.go @@ -1,85 +1,85 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package logknob provides a helpful wrapper that allows enabling logging -// based on either an envknob or other methods of enablement. -package logknob - -import ( - "sync/atomic" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/logger" - "tailscale.com/types/views" -) - -// TODO(andrew-d): should we have a package-global registry of logknobs? It -// would allow us to update from a netmap in a central location, which might be -// reason enough to do it... - -// LogKnob allows configuring verbose logging, with multiple ways to enable. It -// supports enabling logging via envknob, via atomic boolean (for use in e.g. -// c2n log level changes), and via capabilities from a NetMap (so users can -// enable logging via the ACL JSON). -type LogKnob struct { - capName tailcfg.NodeCapability - cap atomic.Bool - env func() bool - manual atomic.Bool -} - -// NewLogKnob creates a new LogKnob, with the provided environment variable -// name and/or NetMap capability. -func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { - if env == "" && cap == "" { - panic("must provide either an environment variable or capability") - } - - lk := &LogKnob{ - capName: cap, - } - if env != "" { - lk.env = envknob.RegisterBool(env) - } else { - lk.env = func() bool { return false } - } - return lk -} - -// Set will cause logs to be printed when called with Set(true). When called -// with Set(false), logs will not be printed due to an earlier call of -// Set(true), but may be printed due to either the envknob and/or capability of -// this LogKnob. -func (lk *LogKnob) Set(v bool) { - lk.manual.Store(v) -} - -// NetMap is an interface for the parts of netmap.NetworkMap that we care -// about; we use this rather than a concrete type to avoid a circular -// dependency. -type NetMap interface { - SelfCapabilities() views.Slice[tailcfg.NodeCapability] -} - -// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap -// contains the capability provided for this LogKnob. -func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { - if lk.capName == "" { - return - } - - lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) -} - -// Do will call log with the provided format and arguments if any of the -// configured methods for enabling logging are true. -func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { - if lk.shouldLog() { - log(format, args...) - } -} - -func (lk *LogKnob) shouldLog() bool { - return lk.manual.Load() || lk.env() || lk.cap.Load() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package logknob provides a helpful wrapper that allows enabling logging +// based on either an envknob or other methods of enablement. +package logknob + +import ( + "sync/atomic" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/views" +) + +// TODO(andrew-d): should we have a package-global registry of logknobs? It +// would allow us to update from a netmap in a central location, which might be +// reason enough to do it... + +// LogKnob allows configuring verbose logging, with multiple ways to enable. It +// supports enabling logging via envknob, via atomic boolean (for use in e.g. +// c2n log level changes), and via capabilities from a NetMap (so users can +// enable logging via the ACL JSON). +type LogKnob struct { + capName tailcfg.NodeCapability + cap atomic.Bool + env func() bool + manual atomic.Bool +} + +// NewLogKnob creates a new LogKnob, with the provided environment variable +// name and/or NetMap capability. +func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { + if env == "" && cap == "" { + panic("must provide either an environment variable or capability") + } + + lk := &LogKnob{ + capName: cap, + } + if env != "" { + lk.env = envknob.RegisterBool(env) + } else { + lk.env = func() bool { return false } + } + return lk +} + +// Set will cause logs to be printed when called with Set(true). When called +// with Set(false), logs will not be printed due to an earlier call of +// Set(true), but may be printed due to either the envknob and/or capability of +// this LogKnob. +func (lk *LogKnob) Set(v bool) { + lk.manual.Store(v) +} + +// NetMap is an interface for the parts of netmap.NetworkMap that we care +// about; we use this rather than a concrete type to avoid a circular +// dependency. +type NetMap interface { + SelfCapabilities() views.Slice[tailcfg.NodeCapability] +} + +// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap +// contains the capability provided for this LogKnob. +func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { + if lk.capName == "" { + return + } + + lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) +} + +// Do will call log with the provided format and arguments if any of the +// configured methods for enabling logging are true. +func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { + if lk.shouldLog() { + log(format, args...) + } +} + +func (lk *LogKnob) shouldLog() bool { + return lk.manual.Load() || lk.env() || lk.cap.Load() +} diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go index c9eed5612379a..b2a376a25b371 100644 --- a/envknob/logknob/logknob_test.go +++ b/envknob/logknob/logknob_test.go @@ -1,102 +1,102 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logknob - -import ( - "bytes" - "fmt" - "testing" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -var testKnob = NewLogKnob( - "TS_TEST_LOGKNOB", - "https://tailscale.com/cap/testing", -) - -// Static type assertion for our interface type. -var _ NetMap = &netmap.NetworkMap{} - -func TestLogKnob(t *testing.T) { - t.Run("Default", func(t *testing.T) { - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - assertNoLogs(t) - }) - t.Run("Manual", func(t *testing.T) { - t.Cleanup(func() { testKnob.Set(false) }) - - assertNoLogs(t) - testKnob.Set(true) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("Env", func(t *testing.T) { - t.Cleanup(func() { - envknob.Setenv("TS_TEST_LOGKNOB", "") - }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - envknob.Setenv("TS_TEST_LOGKNOB", "true") - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("NetMap", func(t *testing.T) { - t.Cleanup(func() { testKnob.cap.Store(false) }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/testing", - }, - }).View(), - }) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) -} - -func assertLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - const want = "hello world" - if got := buf.String(); got != want { - t.Errorf("got %q, want %q", got, want) - } -} - -func assertNoLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - if got := buf.String(); got != "" { - t.Errorf("expected no logs, but got: %q", got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logknob + +import ( + "bytes" + "fmt" + "testing" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +var testKnob = NewLogKnob( + "TS_TEST_LOGKNOB", + "https://tailscale.com/cap/testing", +) + +// Static type assertion for our interface type. +var _ NetMap = &netmap.NetworkMap{} + +func TestLogKnob(t *testing.T) { + t.Run("Default", func(t *testing.T) { + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + assertNoLogs(t) + }) + t.Run("Manual", func(t *testing.T) { + t.Cleanup(func() { testKnob.Set(false) }) + + assertNoLogs(t) + testKnob.Set(true) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("Env", func(t *testing.T) { + t.Cleanup(func() { + envknob.Setenv("TS_TEST_LOGKNOB", "") + }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + envknob.Setenv("TS_TEST_LOGKNOB", "true") + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("NetMap", func(t *testing.T) { + t.Cleanup(func() { testKnob.cap.Store(false) }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + testKnob.UpdateFromNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Capabilities: []tailcfg.NodeCapability{ + "https://tailscale.com/cap/testing", + }, + }).View(), + }) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) +} + +func assertLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + const want = "hello world" + if got := buf.String(); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func assertNoLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + if got := buf.String(); got != "" { + t.Errorf("expected no logs, but got: %q", got) + } +} diff --git a/gomod_test.go b/gomod_test.go index 52fdd463910c4..f984b5d6f27a5 100644 --- a/gomod_test.go +++ b/gomod_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscaleroot - -import ( - "os" - "testing" - - "golang.org/x/mod/modfile" -) - -func TestGoMod(t *testing.T) { - goMod, err := os.ReadFile("go.mod") - if err != nil { - t.Fatal(err) - } - f, err := modfile.Parse("go.mod", goMod, nil) - if err != nil { - t.Fatal(err) - } - if len(f.Replace) > 0 { - t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "os" + "testing" + + "golang.org/x/mod/modfile" +) + +func TestGoMod(t *testing.T) { + goMod, err := os.ReadFile("go.mod") + if err != nil { + t.Fatal(err) + } + f, err := modfile.Parse("go.mod", goMod, nil) + if err != nil { + t.Fatal(err) + } + if len(f.Replace) > 0 { + t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) + } +} diff --git a/hostinfo/hostinfo_darwin.go b/hostinfo/hostinfo_darwin.go index a61d95b32c907..0b1774e7712d7 100644 --- a/hostinfo/hostinfo_darwin.go +++ b/hostinfo/hostinfo_darwin.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package hostinfo - -import ( - "os" - "path/filepath" -) - -func init() { - packageType = packageTypeDarwin -} - -func packageTypeDarwin() string { - // Using tailscaled or IPNExtension? - exe, _ := os.Executable() - return filepath.Base(exe) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package hostinfo + +import ( + "os" + "path/filepath" +) + +func init() { + packageType = packageTypeDarwin +} + +func packageTypeDarwin() string { + // Using tailscaled or IPNExtension? + exe, _ := os.Executable() + return filepath.Base(exe) +} diff --git a/hostinfo/hostinfo_freebsd.go b/hostinfo/hostinfo_freebsd.go index 15c7783aa4e4c..3661b13229ac5 100644 --- a/hostinfo/hostinfo_freebsd.go +++ b/hostinfo/hostinfo_freebsd.go @@ -1,64 +1,64 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package hostinfo - -import ( - "bytes" - "os" - "os/exec" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" - "tailscale.com/version/distro" -) - -func init() { - osVersion = lazyOSVersion.Get - distroName = distroNameFreeBSD - distroVersion = distroVersionFreeBSD -} - -var ( - lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} -) - -func distroNameFreeBSD() string { - return lazyVersionMeta.Get().DistroName -} - -func distroVersionFreeBSD() string { - return lazyVersionMeta.Get().DistroVersion -} - -type versionMeta struct { - DistroName string - DistroVersion string - DistroCodeName string -} - -func osVersionFreeBSD() string { - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Release[:]) -} - -func freebsdVersionMeta() (meta versionMeta) { - d := distro.Get() - meta.DistroName = string(d) - switch d { - case distro.Pfsense: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.OPNsense: - b, _ := exec.Command("opnsense-version").Output() - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.TrueNAS: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - } - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package hostinfo + +import ( + "bytes" + "os" + "os/exec" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" + "tailscale.com/version/distro" +) + +func init() { + osVersion = lazyOSVersion.Get + distroName = distroNameFreeBSD + distroVersion = distroVersionFreeBSD +} + +var ( + lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} + lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} +) + +func distroNameFreeBSD() string { + return lazyVersionMeta.Get().DistroName +} + +func distroVersionFreeBSD() string { + return lazyVersionMeta.Get().DistroVersion +} + +type versionMeta struct { + DistroName string + DistroVersion string + DistroCodeName string +} + +func osVersionFreeBSD() string { + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Release[:]) +} + +func freebsdVersionMeta() (meta versionMeta) { + d := distro.Get() + meta.DistroName = string(d) + switch d { + case distro.Pfsense: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.OPNsense: + b, _ := exec.Command("opnsense-version").Output() + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.TrueNAS: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + } + return +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 76282ebf56733..9fe32e0449be1 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestNew(t *testing.T) { - hi := New() - if hi == nil { - t.Fatal("no Hostinfo") - } - j, err := json.MarshalIndent(hi, " ", "") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %s", j) -} - -func TestOSVersion(t *testing.T) { - if osVersion == nil { - t.Skip("not available for OS") - } - t.Logf("Got: %#q", osVersion()) -} - -func TestEtcAptSourceFileIsDisabled(t *testing.T) { - tests := []struct { - name string - in string - want bool - }{ - {"empty", "", false}, - {"normal", "deb foo\n", false}, - {"normal-commented", "# deb foo\n", false}, - {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, - {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestNew(t *testing.T) { + hi := New() + if hi == nil { + t.Fatal("no Hostinfo") + } + j, err := json.MarshalIndent(hi, " ", "") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %s", j) +} + +func TestOSVersion(t *testing.T) { + if osVersion == nil { + t.Skip("not available for OS") + } + t.Logf("Got: %#q", osVersion()) +} + +func TestEtcAptSourceFileIsDisabled(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + {"empty", "", false}, + {"normal", "deb foo\n", false}, + {"normal-commented", "# deb foo\n", false}, + {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, + {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/hostinfo/hostinfo_uname.go b/hostinfo/hostinfo_uname.go index 10995c1c78652..32b733a03bcb3 100644 --- a/hostinfo/hostinfo_uname.go +++ b/hostinfo/hostinfo_uname.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd || darwin - -package hostinfo - -import ( - "runtime" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" -) - -func init() { - unameMachine = lazyUnameMachine.Get -} - -var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} - -func unameMachineUnix() string { - switch runtime.GOOS { - case "android": - // Don't call on Android for now. We're late in the 1.36 release cycle - // and don't want to test syscall filters on various Android versions to - // see what's permitted. Notably, the hostinfo_linux.go file has build - // tag !android, so maybe Uname is verboten. - return "" - case "ios": - // For similar reasons, don't call on iOS. There aren't many iOS devices - // and we know their CPU properties so calling this is only risk and no - // reward. - return "" - } - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Machine[:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd || darwin + +package hostinfo + +import ( + "runtime" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" +) + +func init() { + unameMachine = lazyUnameMachine.Get +} + +var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} + +func unameMachineUnix() string { + switch runtime.GOOS { + case "android": + // Don't call on Android for now. We're late in the 1.36 release cycle + // and don't want to test syscall filters on various Android versions to + // see what's permitted. Notably, the hostinfo_linux.go file has build + // tag !android, so maybe Uname is verboten. + return "" + case "ios": + // For similar reasons, don't call on iOS. There aren't many iOS devices + // and we know their CPU properties so calling this is only risk and no + // reward. + return "" + } + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Machine[:]) +} diff --git a/hostinfo/wol.go b/hostinfo/wol.go index b6fc81a8b2482..3a30af2fe3a37 100644 --- a/hostinfo/wol.go +++ b/hostinfo/wol.go @@ -1,106 +1,106 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "log" - "net" - "runtime" - "strings" - "unicode" - - "tailscale.com/envknob" -) - -// TODO(bradfitz): this is all too simplistic and static. It needs to run -// continuously in response to netmon events (USB ethernet adapaters might get -// plugged in) and look for the media type/status/etc. Right now on macOS it -// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't -// have any media. We should only report the one that's actually connected. -// But it works for now (2023-10-05) for fleshing out the rest. - -var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 - -// getWoLMACs returns up to 10 MAC address of the local machine to send -// wake-on-LAN packets to in order to wake it up. The returned MACs are in -// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). -// -// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS -// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is -// set to a MAC address, that sole MAC address is returned. -func getWoLMACs() (macs []string) { - switch runtime.GOOS { - case "ios", "android": - return nil - } - if s := wakeMAC(); s != "" { - switch s { - case "auto": - ifs, _ := net.Interfaces() - for _, iface := range ifs { - if iface.Flags&net.FlagLoopback != 0 { - continue - } - if iface.Flags&net.FlagBroadcast == 0 || - iface.Flags&net.FlagRunning == 0 || - iface.Flags&net.FlagUp == 0 { - continue - } - if keepMAC(iface.Name, iface.HardwareAddr) { - macs = append(macs, iface.HardwareAddr.String()) - } - if len(macs) == 10 { - break - } - } - return macs - case "false", "off": // fast path before ParseMAC error - return nil - } - mac, err := net.ParseMAC(s) - if err != nil { - log.Printf("invalid MAC %q", s) - return nil - } - return []string{mac.String()} - } - return nil -} - -var ignoreWakeOUI = map[[3]byte]bool{ - {0x00, 0x15, 0x5d}: true, // Hyper-V - {0x00, 0x50, 0x56}: true, // VMware - {0x00, 0x1c, 0x14}: true, // VMware - {0x00, 0x05, 0x69}: true, // VMware - {0x00, 0x0c, 0x29}: true, // VMware - {0x00, 0x1c, 0x42}: true, // Parallels - {0x08, 0x00, 0x27}: true, // VirtualBox - {0x00, 0x21, 0xf6}: true, // VirtualBox - {0x00, 0x14, 0x4f}: true, // VirtualBox - {0x00, 0x0f, 0x4b}: true, // VirtualBox - {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant -} - -func keepMAC(ifName string, mac []byte) bool { - if len(mac) != 6 { - return false - } - base := strings.TrimRightFunc(ifName, unicode.IsNumber) - switch runtime.GOOS { - case "darwin": - switch base { - case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": - return false - } - } - if mac[0] == 0x02 && mac[1] == 0x42 { - // Docker container. - return false - } - oui := [3]byte{mac[0], mac[1], mac[2]} - if ignoreWakeOUI[oui] { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "log" + "net" + "runtime" + "strings" + "unicode" + + "tailscale.com/envknob" +) + +// TODO(bradfitz): this is all too simplistic and static. It needs to run +// continuously in response to netmon events (USB ethernet adapaters might get +// plugged in) and look for the media type/status/etc. Right now on macOS it +// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't +// have any media. We should only report the one that's actually connected. +// But it works for now (2023-10-05) for fleshing out the rest. + +var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 + +// getWoLMACs returns up to 10 MAC address of the local machine to send +// wake-on-LAN packets to in order to wake it up. The returned MACs are in +// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). +// +// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS +// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is +// set to a MAC address, that sole MAC address is returned. +func getWoLMACs() (macs []string) { + switch runtime.GOOS { + case "ios", "android": + return nil + } + if s := wakeMAC(); s != "" { + switch s { + case "auto": + ifs, _ := net.Interfaces() + for _, iface := range ifs { + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagBroadcast == 0 || + iface.Flags&net.FlagRunning == 0 || + iface.Flags&net.FlagUp == 0 { + continue + } + if keepMAC(iface.Name, iface.HardwareAddr) { + macs = append(macs, iface.HardwareAddr.String()) + } + if len(macs) == 10 { + break + } + } + return macs + case "false", "off": // fast path before ParseMAC error + return nil + } + mac, err := net.ParseMAC(s) + if err != nil { + log.Printf("invalid MAC %q", s) + return nil + } + return []string{mac.String()} + } + return nil +} + +var ignoreWakeOUI = map[[3]byte]bool{ + {0x00, 0x15, 0x5d}: true, // Hyper-V + {0x00, 0x50, 0x56}: true, // VMware + {0x00, 0x1c, 0x14}: true, // VMware + {0x00, 0x05, 0x69}: true, // VMware + {0x00, 0x0c, 0x29}: true, // VMware + {0x00, 0x1c, 0x42}: true, // Parallels + {0x08, 0x00, 0x27}: true, // VirtualBox + {0x00, 0x21, 0xf6}: true, // VirtualBox + {0x00, 0x14, 0x4f}: true, // VirtualBox + {0x00, 0x0f, 0x4b}: true, // VirtualBox + {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant +} + +func keepMAC(ifName string, mac []byte) bool { + if len(mac) != 6 { + return false + } + base := strings.TrimRightFunc(ifName, unicode.IsNumber) + switch runtime.GOOS { + case "darwin": + switch base { + case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": + return false + } + } + if mac[0] == 0x02 && mac[1] == 0x42 { + // Docker container. + return false + } + oui := [3]byte{mac[0], mac[1], mac[2]} + if ignoreWakeOUI[oui] { + return false + } + return true +} diff --git a/ipn/ipnlocal/breaktcp_darwin.go b/ipn/ipnlocal/breaktcp_darwin.go index 289e760e194a4..13566198ce9fc 100644 --- a/ipn/ipnlocal/breaktcp_darwin.go +++ b/ipn/ipnlocal/breaktcp_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsDarwin -} - -func breakTCPConnsDarwin() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsDarwin +} + +func breakTCPConnsDarwin() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/breaktcp_linux.go b/ipn/ipnlocal/breaktcp_linux.go index d078103cf5388..b82f6521246f0 100644 --- a/ipn/ipnlocal/breaktcp_linux.go +++ b/ipn/ipnlocal/breaktcp_linux.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsLinux -} - -func breakTCPConnsLinux() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsLinux +} + +func breakTCPConnsLinux() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index efc18133f556d..af1aa337bbe0c 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -1,301 +1,301 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "fmt" - "reflect" - "strings" - "testing" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/types/key" - "tailscale.com/types/netmap" -) - -func TestFlagExpiredPeers(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1673373129, 0) - - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - - timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) - if now.Before(timeBeforeEpoch) { - panic("current time in test cannot be before epoch") - } - - var expiredKey key.NodePublic - if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { - panic(err) - } - - tests := []struct { - name string - controlTime *time.Time - netmap *netmap.NetworkMap - want []tailcfg.NodeView - }{ - { - name: "no_expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - { - name: "expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast, func(n *tailcfg.Node) { - n.Expired = true - n.Key = expiredKey - }), - }), - }, - { - name: "bad_ControlTime", - // controlTime here is intentionally before our hardcoded epoch - controlTime: &timeBeforeEpoch, - - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch - }), - }, - { - name: "tagged_node", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // tagged node; zero expiry - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // not expired - }), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - if tt.controlTime != nil { - em.onControlTime(*tt.controlTime) - } - em.flagExpiredPeers(tt.netmap, now) - if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { - t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) - } - }) - } -} - -func TestNextPeerExpiry(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1675725516, 0) - - noExpiry := time.Time{} - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - timeInMoreFuture := now.Add(2 * time.Hour) - - tests := []struct { - name string - netmap *netmap.NetworkMap - want time.Time - }{ - { - name: "no_expiry", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: noExpiry, - }, - { - name: "future_expiry_from_peer", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", timeInFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_multiple_peers", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInMoreFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_peer_and_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInMoreFuture), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "only_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{}), - SelfNode: n(1, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "peer_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "self_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: timeInFuture, - }, - { - name: "all_nodes_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: noExpiry, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - got := em.nextPeerExpiry(tt.netmap, now) - if !got.Equal(tt.want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) - } else if !got.IsZero() && got.Before(now) { - t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) - } - }) - } - - t.Run("ClockSkew", func(t *testing.T) { - t.Logf("local time: %q", now.Format(time.RFC3339)) - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - - // The local clock is "running fast"; our clock skew is -2h - em.clockDelta.Store(-2 * time.Hour) - t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) - - // If we don't adjust for the local time, this would return a - // time in the past. - nm := &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - } - got := em.nextPeerExpiry(nm, now) - want := now.Add(30 * time.Second) - if !got.Equal(want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) - } - }) -} - -func formatNodes(nodes []tailcfg.NodeView) string { - var sb strings.Builder - for i, n := range nodes { - if i > 0 { - sb.WriteString(", ") - } - fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) - } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) - } - if n.Key() != (key.NodePublic{}) { - fmt.Fprintf(&sb, ", key=%v", n.Key().String()) - } - if n.Expired() { - fmt.Fprintf(&sb, ", expired=true") - } - sb.WriteString(")") - } - return sb.String() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/netmap" +) + +func TestFlagExpiredPeers(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1673373129, 0) + + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + + timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) + if now.Before(timeBeforeEpoch) { + panic("current time in test cannot be before epoch") + } + + var expiredKey key.NodePublic + if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { + panic(err) + } + + tests := []struct { + name string + controlTime *time.Time + netmap *netmap.NetworkMap + want []tailcfg.NodeView + }{ + { + name: "no_expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + { + name: "expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast, func(n *tailcfg.Node) { + n.Expired = true + n.Key = expiredKey + }), + }), + }, + { + name: "bad_ControlTime", + // controlTime here is intentionally before our hardcoded epoch + controlTime: &timeBeforeEpoch, + + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch + }), + }, + { + name: "tagged_node", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // tagged node; zero expiry + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // not expired + }), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + if tt.controlTime != nil { + em.onControlTime(*tt.controlTime) + } + em.flagExpiredPeers(tt.netmap, now) + if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { + t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) + } + }) + } +} + +func TestNextPeerExpiry(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1675725516, 0) + + noExpiry := time.Time{} + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + timeInMoreFuture := now.Add(2 * time.Hour) + + tests := []struct { + name string + netmap *netmap.NetworkMap + want time.Time + }{ + { + name: "no_expiry", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: noExpiry, + }, + { + name: "future_expiry_from_peer", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", timeInFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_multiple_peers", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInMoreFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_peer_and_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInMoreFuture), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "only_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{}), + SelfNode: n(1, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "peer_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "self_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: timeInFuture, + }, + { + name: "all_nodes_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: noExpiry, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + got := em.nextPeerExpiry(tt.netmap, now) + if !got.Equal(tt.want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) + } else if !got.IsZero() && got.Before(now) { + t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) + } + }) + } + + t.Run("ClockSkew", func(t *testing.T) { + t.Logf("local time: %q", now.Format(time.RFC3339)) + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + + // The local clock is "running fast"; our clock skew is -2h + em.clockDelta.Store(-2 * time.Hour) + t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) + + // If we don't adjust for the local time, this would return a + // time in the past. + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + } + got := em.nextPeerExpiry(nm, now) + want := now.Add(30 * time.Second) + if !got.Equal(want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) + } + }) +} + +func formatNodes(nodes []tailcfg.NodeView) string { + var sb strings.Builder + for i, n := range nodes { + if i > 0 { + sb.WriteString(", ") + } + fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) + + if n.Online() != nil { + fmt.Fprintf(&sb, ", online=%v", *n.Online()) + } + if n.LastSeen() != nil { + fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + } + if n.Key() != (key.NodePublic{}) { + fmt.Fprintf(&sb, ", key=%v", n.Key().String()) + } + if n.Expired() { + fmt.Fprintf(&sb, ", expired=true") + } + sb.WriteString(")") + } + return sb.String() +} diff --git a/ipn/ipnlocal/peerapi_h2c.go b/ipn/ipnlocal/peerapi_h2c.go index e6335fe2be5b6..fbfa8639808ae 100644 --- a/ipn/ipnlocal/peerapi_h2c.go +++ b/ipn/ipnlocal/peerapi_h2c.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -package ipnlocal - -import ( - "net/http" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" -) - -func init() { - addH2C = func(s *http.Server) { - h2s := &http2.Server{} - s.Handler = h2c.NewHandler(s.Handler, h2s) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +package ipnlocal + +import ( + "net/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func init() { + addH2C = func(s *http.Server) { + h2s := &http2.Server{} + s.Handler = h2c.NewHandler(s.Handler, h2s) + } +} diff --git a/ipn/ipnlocal/testdata/example.com-key.pem b/ipn/ipnlocal/testdata/example.com-key.pem index 9020553f1829b..06902f4c9c314 100644 --- a/ipn/ipnlocal/testdata/example.com-key.pem +++ b/ipn/ipnlocal/testdata/example.com-key.pem @@ -1,28 +1,28 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE -QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX -xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn -ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 -Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg -w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE -K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W -7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq -0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D -Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf -TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l -kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S -PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO -+W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD -Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he -1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m -m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq -FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox -KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf -TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J -Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u -tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL -knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme -uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF -ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP -c/bJnTXEEVsWjLZTwYaq0Us= +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE +QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX +xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn +ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 +Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg +w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE +K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W +7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq +0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D +Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf +TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l +kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S +PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO ++W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD +Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he +1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m +m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq +FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox +KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf +TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J +Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u +tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL +knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme +uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF +ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP +c/bJnTXEEVsWjLZTwYaq0Us= -----END PRIVATE KEY----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/example.com.pem b/ipn/ipnlocal/testdata/example.com.pem index 65e7110a8d1ae..588850813b102 100644 --- a/ipn/ipnlocal/testdata/example.com.pem +++ b/ipn/ipnlocal/testdata/example.com.pem @@ -1,26 +1,26 @@ ------BEGIN CERTIFICATE----- -MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 -WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 -BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv -bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr -JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV -YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL -pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G -M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp -MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN -NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr -BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE -DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB -CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv -bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M -NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN -deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 -ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA -xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW -pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI -/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo +-----BEGIN CERTIFICATE----- +MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 +WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 +BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv +bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr +JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV +YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL +pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G +M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp +MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN +NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr +BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE +DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB +CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv +bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M +NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN +deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 +ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA +xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW +pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI +/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/rootCA.pem b/ipn/ipnlocal/testdata/rootCA.pem index 28bd25467f07f..88a16f47a8ac9 100644 --- a/ipn/ipnlocal/testdata/rootCA.pem +++ b/ipn/ipnlocal/testdata/rootCA.pem @@ -1,30 +1,30 @@ ------BEGIN CERTIFICATE----- -MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 -WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm -cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp -MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj -aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC -ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk -Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv -/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl -AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB -ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF -zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk -cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA -RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s -ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB -ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD -ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW -ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 -ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh -8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi -Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP -YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K -va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se -vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I -MHctBg== +-----BEGIN CERTIFICATE----- +MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 +WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm +cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp +MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj +aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC +ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk +Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv +/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl +AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB +ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF +zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk +cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA +RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s +ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB +ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD +ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW +ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 +ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh +8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi +Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP +YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K +va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se +vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I +MHctBg== -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnserver/proxyconnect_js.go b/ipn/ipnserver/proxyconnect_js.go index 27448fa0dcce6..368221e2269c8 100644 --- a/ipn/ipnserver/proxyconnect_js.go +++ b/ipn/ipnserver/proxyconnect_js.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import "net/http" - -func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { - panic("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import "net/http" + +func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { + panic("unreachable") +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index 49fb4d01f3ae0..b7d5ea144c408 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import ( - "context" - "sync" - "testing" -) - -func TestWaiterSet(t *testing.T) { - var s waiterSet - - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) - } - } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) - - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") - - select { - case <-ready: - t.Fatal("should not be ready") - default: - } - s.wakeAll() - <-ready - - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") - - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/localapi/disabled_stubs.go b/ipn/localapi/disabled_stubs.go index 230553c145840..c744f34d5f5c5 100644 --- a/ipn/localapi/disabled_stubs.go +++ b/ipn/localapi/disabled_stubs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios || android || js - -package localapi - -import ( - "net/http" - "runtime" -) - -func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { - http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios || android || js + +package localapi + +import ( + "net/http" + "runtime" +) + +func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { + http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) +} diff --git a/ipn/localapi/pprof.go b/ipn/localapi/pprof.go index 5cc4daca1cf39..8c9429b31385a 100644 --- a/ipn/localapi/pprof.go +++ b/ipn/localapi/pprof.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -// We don't include it on mobile where we're more memory constrained and -// there's no CLI to get at the results anyway. - -package localapi - -import ( - "net/http" - "net/http/pprof" -) - -func init() { - servePprofFunc = servePprof -} - -func servePprof(w http.ResponseWriter, r *http.Request) { - name := r.FormValue("name") - switch name { - case "profile": - pprof.Profile(w, r) - default: - pprof.Handler(name).ServeHTTP(w, r) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +// We don't include it on mobile where we're more memory constrained and +// there's no CLI to get at the results anyway. + +package localapi + +import ( + "net/http" + "net/http/pprof" +) + +func init() { + servePprofFunc = servePprof +} + +func servePprof(w http.ResponseWriter, r *http.Request) { + name := r.FormValue("name") + switch name { + case "profile": + pprof.Profile(w, r) + default: + pprof.Handler(name).ServeHTTP(w, r) + } +} diff --git a/ipn/policy/policy.go b/ipn/policy/policy.go index 834706f31a389..494a0dc408819 100644 --- a/ipn/policy/policy.go +++ b/ipn/policy/policy.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains various policy decisions that need to be -// shared between the node client & control server. -package policy - -import ( - "tailscale.com/tailcfg" -) - -// IsInterestingService reports whether service s on the given operating -// system (a version.OS value) is an interesting enough port to report -// to our peer nodes for discovery purposes. -func IsInterestingService(s tailcfg.Service, os string) bool { - switch s.Proto { - case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: - return true - } - if s.Proto != tailcfg.TCP { - return false - } - if os != "windows" { - // For non-Windows machines, assume all TCP listeners - // are interesting enough. We don't see listener spam - // there. - return true - } - // Windows has tons of TCP listeners. We need to move to a denylist - // model later, but for now we just allow some common ones: - switch s.Port { - case 22, // ssh - 80, // http - 443, // https (but no hostname, so little useless) - 3389, // rdp - 5900, // vnc - 32400, // plex - - // And now some arbitrary HTTP dev server ports: - // Eventually we'll remove this and make all ports - // work, once we nicely filter away noisy system - // ports. - 8000, 8080, 8443, 8888: - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains various policy decisions that need to be +// shared between the node client & control server. +package policy + +import ( + "tailscale.com/tailcfg" +) + +// IsInterestingService reports whether service s on the given operating +// system (a version.OS value) is an interesting enough port to report +// to our peer nodes for discovery purposes. +func IsInterestingService(s tailcfg.Service, os string) bool { + switch s.Proto { + case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: + return true + } + if s.Proto != tailcfg.TCP { + return false + } + if os != "windows" { + // For non-Windows machines, assume all TCP listeners + // are interesting enough. We don't see listener spam + // there. + return true + } + // Windows has tons of TCP listeners. We need to move to a denylist + // model later, but for now we just allow some common ones: + switch s.Port { + case 22, // ssh + 80, // http + 443, // https (but no hostname, so little useless) + 3389, // rdp + 5900, // vnc + 32400, // plex + + // And now some arbitrary HTTP dev server ports: + // Eventually we'll remove this and make all ports + // work, once we nicely filter away noisy system + // ports. + 8000, 8080, 8443, 8888: + return true + } + return false +} diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 84059af67c57d..0fb78d45a6a53 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !ts_omit_aws - -// Package awsstore contains an ipn.StateStore implementation using AWS SSM. -package awsstore - -import ( - "context" - "errors" - "fmt" - "regexp" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/types/logger" -) - -const ( - parameterNameRxStr = `^parameter(/.*)` -) - -var parameterNameRx = regexp.MustCompile(parameterNameRxStr) - -// awsSSMClient is an interface allowing us to mock the couple of -// API calls we are leveraging with the AWSStore provider -type awsSSMClient interface { - GetParameter(ctx context.Context, - params *ssm.GetParameterInput, - optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) - - PutParameter(ctx context.Context, - params *ssm.PutParameterInput, - optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) -} - -// store is a store which leverages AWS SSM parameter store -// to persist the state -type awsStore struct { - ssmClient awsSSMClient - ssmARN arn.ARN - - memory mem.Store -} - -// New returns a new ipn.StateStore using the AWS SSM storage -// location given by ssmARN. -// -// Note that we store the entire store in a single parameter -// key, therefore if the state is above 8kb, it can cause -// Tailscaled to only only store new state in-memory and -// restarting Tailscaled can fail until you delete your state -// from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) -} - -// newStore is NewStore, but for tests. If client is non-nil, it's -// used instead of making one. -func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { - s := &awsStore{ - ssmClient: client, - } - - var err error - - // Parse the ARN - if s.ssmARN, err = arn.Parse(ssmARN); err != nil { - return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) - } - - // Validate the ARN corresponds to the SSM service - if s.ssmARN.Service != "ssm" { - return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) - } - - // Validate the ARN corresponds to a parameter store resource - if !parameterNameRx.MatchString(s.ssmARN.Resource) { - return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) - } - - if s.ssmClient == nil { - var cfg aws.Config - if cfg, err = config.LoadDefaultConfig( - context.TODO(), - config.WithRegion(s.ssmARN.Region), - ); err != nil { - return nil, err - } - s.ssmClient = ssm.NewFromConfig(cfg) - } - - // Hydrate cache with the potentially current state - if err := s.LoadState(); err != nil { - return nil, err - } - return s, nil - -} - -// LoadState attempts to read the state from AWS SSM parameter store key. -func (s *awsStore) LoadState() error { - param, err := s.ssmClient.GetParameter( - context.TODO(), - &ssm.GetParameterInput{ - Name: aws.String(s.ParameterName()), - WithDecryption: aws.Bool(true), - }, - ) - - if err != nil { - var pnf *ssmTypes.ParameterNotFound - if errors.As(err, &pnf) { - // Create the parameter as it does not exist yet - // and return directly as it is defacto empty - return s.persistState() - } - return err - } - - // Load the content in-memory - return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) -} - -// ParameterName returns the parameter name extracted from -// the provided ARN -func (s *awsStore) ParameterName() (name string) { - values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) - if len(values) == 2 { - name = values[1] - } - return -} - -// String returns the awsStore and the ARN of the SSM parameter store -// configured to store the state -func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } - -// ReadState implements the Store interface. -func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { - return s.memory.ReadState(id) -} - -// WriteState implements the Store interface. -func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { - // Write the state in-memory - if err = s.memory.WriteState(id, bs); err != nil { - return - } - - // Persist the state in AWS SSM parameter store - return s.persistState() -} - -// PersistState saves the states into the AWS SSM parameter store -func (s *awsStore) persistState() error { - // Generate JSON from in-memory cache - bs, err := s.memory.ExportToJSON() - if err != nil { - return err - } - - // Store in AWS SSM parameter store. - // - // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering - // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering - // doubling the capacity to 8kb per the following docs: - // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_aws + +// Package awsstore contains an ipn.StateStore implementation using AWS SSM. +package awsstore + +import ( + "context" + "errors" + "fmt" + "regexp" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logger" +) + +const ( + parameterNameRxStr = `^parameter(/.*)` +) + +var parameterNameRx = regexp.MustCompile(parameterNameRxStr) + +// awsSSMClient is an interface allowing us to mock the couple of +// API calls we are leveraging with the AWSStore provider +type awsSSMClient interface { + GetParameter(ctx context.Context, + params *ssm.GetParameterInput, + optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) + + PutParameter(ctx context.Context, + params *ssm.PutParameterInput, + optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) +} + +// store is a store which leverages AWS SSM parameter store +// to persist the state +type awsStore struct { + ssmClient awsSSMClient + ssmARN arn.ARN + + memory mem.Store +} + +// New returns a new ipn.StateStore using the AWS SSM storage +// location given by ssmARN. +// +// Note that we store the entire store in a single parameter +// key, therefore if the state is above 8kb, it can cause +// Tailscaled to only only store new state in-memory and +// restarting Tailscaled can fail until you delete your state +// from the AWS Parameter Store. +func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { + return newStore(ssmARN, nil) +} + +// newStore is NewStore, but for tests. If client is non-nil, it's +// used instead of making one. +func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { + s := &awsStore{ + ssmClient: client, + } + + var err error + + // Parse the ARN + if s.ssmARN, err = arn.Parse(ssmARN); err != nil { + return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) + } + + // Validate the ARN corresponds to the SSM service + if s.ssmARN.Service != "ssm" { + return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) + } + + // Validate the ARN corresponds to a parameter store resource + if !parameterNameRx.MatchString(s.ssmARN.Resource) { + return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) + } + + if s.ssmClient == nil { + var cfg aws.Config + if cfg, err = config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(s.ssmARN.Region), + ); err != nil { + return nil, err + } + s.ssmClient = ssm.NewFromConfig(cfg) + } + + // Hydrate cache with the potentially current state + if err := s.LoadState(); err != nil { + return nil, err + } + return s, nil + +} + +// LoadState attempts to read the state from AWS SSM parameter store key. +func (s *awsStore) LoadState() error { + param, err := s.ssmClient.GetParameter( + context.TODO(), + &ssm.GetParameterInput{ + Name: aws.String(s.ParameterName()), + WithDecryption: aws.Bool(true), + }, + ) + + if err != nil { + var pnf *ssmTypes.ParameterNotFound + if errors.As(err, &pnf) { + // Create the parameter as it does not exist yet + // and return directly as it is defacto empty + return s.persistState() + } + return err + } + + // Load the content in-memory + return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) +} + +// ParameterName returns the parameter name extracted from +// the provided ARN +func (s *awsStore) ParameterName() (name string) { + values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) + if len(values) == 2 { + name = values[1] + } + return +} + +// String returns the awsStore and the ARN of the SSM parameter store +// configured to store the state +func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } + +// ReadState implements the Store interface. +func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { + return s.memory.ReadState(id) +} + +// WriteState implements the Store interface. +func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { + // Write the state in-memory + if err = s.memory.WriteState(id, bs); err != nil { + return + } + + // Persist the state in AWS SSM parameter store + return s.persistState() +} + +// PersistState saves the states into the AWS SSM parameter store +func (s *awsStore) persistState() error { + // Generate JSON from in-memory cache + bs, err := s.memory.ExportToJSON() + if err != nil { + return err + } + + // Store in AWS SSM parameter store. + // + // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering + // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering + // doubling the capacity to 8kb per the following docs: + // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ + _, err = s.ssmClient.PutParameter( + context.TODO(), + &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + }, + ) + return err +} diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go index 7be8b858d752f..8d2156ce948d5 100644 --- a/ipn/store/awsstore/store_aws_stub.go +++ b/ipn/store/awsstore/store_aws_stub.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux || ts_omit_aws + +package awsstore + +import ( + "fmt" + "runtime" + + "tailscale.com/ipn" + "tailscale.com/types/logger" +) + +func New(logger.Logf, string) (ipn.StateStore, error) { + return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) +} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index 54e6e18cb4115..f6c8fedb32dc9 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,164 +1,164 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package awsstore - -import ( - "context" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/tstest" -) - -type mockedAWSSSMClient struct { - value string -} - -func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { - output := new(ssm.GetParameterOutput) - if sp.value == "" { - return output, &ssmTypes.ParameterNotFound{} - } - - output.Parameter = &ssmTypes.Parameter{ - Value: aws.String(sp.value), - } - - return output, nil -} - -func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { - sp.value = *input.Value - return new(ssm.PutParameterOutput), nil -} - -func TestAWSStoreString(t *testing.T) { - store := &awsStore{ - ssmARN: arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - }, - } - want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" - if got := store.String(); got != want { - t.Errorf("AWSStore.String = %q; want %q", got, want) - } -} - -func TestNewAWSStore(t *testing.T) { - tstest.PanicOnLog() - - mc := &mockedAWSSSMClient{} - storeParameterARN := arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - } - - s, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating aws store failed: %v", err) - } - testStoreSemantics(t, s) - - // Build a brand new file store and check that both IDs written - // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating second aws store failed: %v", err) - } - store2 := s.(*awsStore) - - // This is specific to the test, with the non-mocked API, LoadState() should - // have been already called and successful as no err is returned from NewAWSStore() - s2.(*awsStore).LoadState() - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for id, want := range expected { - bs, err := store2.ReadState(id) - if err != nil { - t.Errorf("reading %q (2nd store): %v", id, err) - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) - } - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package awsstore + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/tstest" +) + +type mockedAWSSSMClient struct { + value string +} + +func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + output := new(ssm.GetParameterOutput) + if sp.value == "" { + return output, &ssmTypes.ParameterNotFound{} + } + + output.Parameter = &ssmTypes.Parameter{ + Value: aws.String(sp.value), + } + + return output, nil +} + +func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + sp.value = *input.Value + return new(ssm.PutParameterOutput), nil +} + +func TestAWSStoreString(t *testing.T) { + store := &awsStore{ + ssmARN: arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + }, + } + want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" + if got := store.String(); got != want { + t.Errorf("AWSStore.String = %q; want %q", got, want) + } +} + +func TestNewAWSStore(t *testing.T) { + tstest.PanicOnLog() + + mc := &mockedAWSSSMClient{} + storeParameterARN := arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + } + + s, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating aws store failed: %v", err) + } + testStoreSemantics(t, s) + + // Build a brand new file store and check that both IDs written + // above are still there. + s2, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating second aws store failed: %v", err) + } + store2 := s.(*awsStore) + + // This is specific to the test, with the non-mocked API, LoadState() should + // have been already called and successful as no err is returned from NewAWSStore() + s2.(*awsStore).LoadState() + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for id, want := range expected { + bs, err := store2.ReadState(id) + if err != nil { + t.Errorf("reading %q (2nd store): %v", id, err) + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) + } + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} diff --git a/ipn/store/stores_test.go b/ipn/store/stores_test.go index 69aa791938747..ea09e6ea63ae4 100644 --- a/ipn/store/stores_test.go +++ b/ipn/store/stores_test.go @@ -1,179 +1,179 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package store - -import ( - "path/filepath" - "testing" - - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/types/logger" -) - -func TestNewStore(t *testing.T) { - regOnce.Do(registerDefaultStores) - t.Cleanup(func() { - knownStores = map[string]Provider{} - registerDefaultStores() - }) - knownStores = map[string]Provider{} - - type store1 struct { - ipn.StateStore - path string - } - - type store2 struct { - ipn.StateStore - path string - } - - Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store1{new(mem.Store), path}, nil - }) - Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store2{new(mem.Store), path}, nil - }) - Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return new(mem.Store), nil - }) - - path := "mem:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*mem.Store); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) - } - - path = "arn:foo" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store1); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) - } - - path = "kube:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store2); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) - } - - path = filepath.Join(t.TempDir(), "state") - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*FileStore); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} - -func TestMemoryStore(t *testing.T) { - tstest.PanicOnLog() - - store := new(mem.Store) - testStoreSemantics(t, store) -} - -func TestFileStore(t *testing.T) { - tstest.PanicOnLog() - - dir := t.TempDir() - path := filepath.Join(dir, "test-file-store.conf") - - store, err := NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating file store failed: %v", err) - } - - testStoreSemantics(t, store) - - // Build a brand new file store and check that both IDs written - // above are still there. - store, err = NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating second file store failed: %v", err) - } - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for key, want := range expected { - bs, err := store.ReadState(key) - if err != nil { - t.Errorf("reading %q (2nd store): %v", key, err) - continue - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package store + +import ( + "path/filepath" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" +) + +func TestNewStore(t *testing.T) { + regOnce.Do(registerDefaultStores) + t.Cleanup(func() { + knownStores = map[string]Provider{} + registerDefaultStores() + }) + knownStores = map[string]Provider{} + + type store1 struct { + ipn.StateStore + path string + } + + type store2 struct { + ipn.StateStore + path string + } + + Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store1{new(mem.Store), path}, nil + }) + Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store2{new(mem.Store), path}, nil + }) + Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return new(mem.Store), nil + }) + + path := "mem:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*mem.Store); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) + } + + path = "arn:foo" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store1); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) + } + + path = "kube:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store2); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) + } + + path = filepath.Join(t.TempDir(), "state") + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*FileStore); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} + +func TestMemoryStore(t *testing.T) { + tstest.PanicOnLog() + + store := new(mem.Store) + testStoreSemantics(t, store) +} + +func TestFileStore(t *testing.T) { + tstest.PanicOnLog() + + dir := t.TempDir() + path := filepath.Join(dir, "test-file-store.conf") + + store, err := NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating file store failed: %v", err) + } + + testStoreSemantics(t, store) + + // Build a brand new file store and check that both IDs written + // above are still there. + store, err = NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating second file store failed: %v", err) + } + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for key, want := range expected { + bs, err := store.ReadState(key) + if err != nil { + t.Errorf("reading %q (2nd store): %v", key, err) + continue + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) + } + } +} diff --git a/ipn/store_test.go b/ipn/store_test.go index 330f67969085b..fcc082d8a8a87 100644 --- a/ipn/store_test.go +++ b/ipn/store_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipn - -import ( - "bytes" - "sync" - "testing" - - "tailscale.com/util/mak" -) - -type memStore struct { - mu sync.Mutex - writes int - m map[StateKey][]byte -} - -func (s *memStore) ReadState(k StateKey) ([]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - return bytes.Clone(s.m[k]), nil -} - -func (s *memStore) WriteState(k StateKey, v []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - mak.Set(&s.m, k, bytes.Clone(v)) - s.writes++ - return nil -} - -func TestWriteState(t *testing.T) { - var ss StateStore = new(memStore) - WriteState(ss, "foo", []byte("bar")) - WriteState(ss, "foo", []byte("bar")) - got, err := ss.ReadState("foo") - if err != nil { - t.Fatal(err) - } - if want := []byte("bar"); !bytes.Equal(got, want) { - t.Errorf("got %q; want %q", got, want) - } - if got, want := ss.(*memStore).writes, 1; got != want { - t.Errorf("got %d writes; want %d", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "bytes" + "sync" + "testing" + + "tailscale.com/util/mak" +) + +type memStore struct { + mu sync.Mutex + writes int + m map[StateKey][]byte +} + +func (s *memStore) ReadState(k StateKey) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + return bytes.Clone(s.m[k]), nil +} + +func (s *memStore) WriteState(k StateKey, v []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.m, k, bytes.Clone(v)) + s.writes++ + return nil +} + +func TestWriteState(t *testing.T) { + var ss StateStore = new(memStore) + WriteState(ss, "foo", []byte("bar")) + WriteState(ss, "foo", []byte("bar")) + got, err := ss.ReadState("foo") + if err != nil { + t.Fatal(err) + } + if want := []byte("bar"); !bytes.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } + if got, want := ss.(*memStore).writes, 1; got != want { + t.Errorf("got %d writes; want %d", got, want) + } +} diff --git a/jsondb/db.go b/jsondb/db.go index c45c1f819ca05..68bb05af45e8e 100644 --- a/jsondb/db.go +++ b/jsondb/db.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsondb provides a trivial "database": a Go object saved to -// disk as JSON. -package jsondb - -import ( - "encoding/json" - "errors" - "io/fs" - "os" - - "tailscale.com/atomicfile" -) - -// DB is a database backed by a JSON file. -type DB[T any] struct { - // Data is the contents of the database. - Data *T - - path string -} - -// Open opens the database at path, creating it with a zero value if -// necessary. -func Open[T any](path string) (*DB[T], error) { - bs, err := os.ReadFile(path) - if errors.Is(err, fs.ErrNotExist) { - return &DB[T]{ - Data: new(T), - path: path, - }, nil - } else if err != nil { - return nil, err - } - - var val T - if err := json.Unmarshal(bs, &val); err != nil { - return nil, err - } - - return &DB[T]{ - Data: &val, - path: path, - }, nil -} - -// Save writes db.Data back to disk. -func (db *DB[T]) Save() error { - bs, err := json.Marshal(db.Data) - if err != nil { - return err - } - - return atomicfile.WriteFile(db.path, bs, 0600) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsondb provides a trivial "database": a Go object saved to +// disk as JSON. +package jsondb + +import ( + "encoding/json" + "errors" + "io/fs" + "os" + + "tailscale.com/atomicfile" +) + +// DB is a database backed by a JSON file. +type DB[T any] struct { + // Data is the contents of the database. + Data *T + + path string +} + +// Open opens the database at path, creating it with a zero value if +// necessary. +func Open[T any](path string) (*DB[T], error) { + bs, err := os.ReadFile(path) + if errors.Is(err, fs.ErrNotExist) { + return &DB[T]{ + Data: new(T), + path: path, + }, nil + } else if err != nil { + return nil, err + } + + var val T + if err := json.Unmarshal(bs, &val); err != nil { + return nil, err + } + + return &DB[T]{ + Data: &val, + path: path, + }, nil +} + +// Save writes db.Data back to disk. +func (db *DB[T]) Save() error { + bs, err := json.Marshal(db.Data) + if err != nil { + return err + } + + return atomicfile.WriteFile(db.path, bs, 0600) +} diff --git a/jsondb/db_test.go b/jsondb/db_test.go index a78b15b4f32c7..655754f38e1a9 100644 --- a/jsondb/db_test.go +++ b/jsondb/db_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsondb - -import ( - "log" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestDB(t *testing.T) { - dir, err := os.MkdirTemp("", "db-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - - path := filepath.Join(dir, "db.json") - db, err := Open[testDB](path) - if err != nil { - t.Fatalf("creating empty DB: %v", err) - } - - if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) - } - db.Data.MyString = "test" - db.Data.unexported = "don't keep" - db.Data.AnInt = 42 - if err := db.Save(); err != nil { - t.Fatalf("saving database: %v", err) - } - - db2, err := Open[testDB](path) - if err != nil { - log.Fatalf("opening DB again: %v", err) - } - want := &testDB{ - MyString: "test", - AnInt: 42, - } - if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) - } -} - -type testDB struct { - MyString string - unexported string - AnInt int64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsondb + +import ( + "log" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestDB(t *testing.T) { + dir, err := os.MkdirTemp("", "db-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + path := filepath.Join(dir, "db.json") + db, err := Open[testDB](path) + if err != nil { + t.Fatalf("creating empty DB: %v", err) + } + + if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) + } + db.Data.MyString = "test" + db.Data.unexported = "don't keep" + db.Data.AnInt = 42 + if err := db.Save(); err != nil { + t.Fatalf("saving database: %v", err) + } + + db2, err := Open[testDB](path) + if err != nil { + log.Fatalf("opening DB again: %v", err) + } + want := &testDB{ + MyString: "test", + AnInt: 42, + } + if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) + } +} + +type testDB struct { + MyString string + unexported string + AnInt int64 +} diff --git a/licenses/licenses.go b/licenses/licenses.go index 3ec7013214bb5..5e59edb9f7b75 100644 --- a/licenses/licenses.go +++ b/licenses/licenses.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package licenses provides utilities for working with open source licenses. -package licenses - -import "runtime" - -// LicensesURL returns the absolute URL containing open source license information for the current platform. -func LicensesURL() string { - switch runtime.GOOS { - case "android": - return "https://tailscale.com/licenses/android" - case "darwin", "ios": - return "https://tailscale.com/licenses/apple" - case "windows": - return "https://tailscale.com/licenses/windows" - default: - return "https://tailscale.com/licenses/tailscale" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package licenses provides utilities for working with open source licenses. +package licenses + +import "runtime" + +// LicensesURL returns the absolute URL containing open source license information for the current platform. +func LicensesURL() string { + switch runtime.GOOS { + case "android": + return "https://tailscale.com/licenses/android" + case "darwin", "ios": + return "https://tailscale.com/licenses/apple" + case "windows": + return "https://tailscale.com/licenses/windows" + default: + return "https://tailscale.com/licenses/tailscale" + } +} diff --git a/log/filelogger/log.go b/log/filelogger/log.go index 9d7097eb83e84..599e5237b3e22 100644 --- a/log/filelogger/log.go +++ b/log/filelogger/log.go @@ -1,228 +1,228 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filelogger provides localdisk log writing & rotation, primarily for Windows -// clients. (We get this for free on other platforms.) -package filelogger - -import ( - "bytes" - "fmt" - "log" - "os" - "path/filepath" - "runtime" - "strings" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const ( - maxSize = 100 << 20 - maxFiles = 50 -) - -// New returns a logf wrapper that appends to local disk log -// files on Windows, rotating old log files as needed to stay under -// file count & byte limits. -func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { - if runtime.GOOS != "windows" { - panic("not yet supported on any platform except Windows") - } - if logf == nil { - panic("nil logf") - } - dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") - - if err := os.MkdirAll(dir, 0700); err != nil { - log.Printf("failed to create local log directory; not writing logs to disk: %v", err) - return logf - } - logf("local disk logdir: %v", dir) - lfw := &logFileWriter{ - fileBasePrefix: fileBasePrefix, - logID: logID, - dir: dir, - wrappedLogf: logf, - } - return lfw.Logf -} - -// logFileWriter is the state for the log writer & rotator. -type logFileWriter struct { - dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` - logID string // hex logID - fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" - wrappedLogf logger.Logf // underlying logger to send to - - mu sync.Mutex // guards following - buf bytes.Buffer // scratch buffer to avoid allocs - fday civilDay // day that f was opened; zero means no file yet open - f *os.File // file currently opened for append -} - -// civilDay is a year, month, and day in the local timezone. -// It's a comparable value type. -type civilDay struct { - year int - month time.Month - day int -} - -func dayOf(t time.Time) civilDay { - return civilDay{t.Year(), t.Month(), t.Day()} -} - -func (w *logFileWriter) Logf(format string, a ...any) { - w.mu.Lock() - defer w.mu.Unlock() - - w.buf.Reset() - fmt.Fprintf(&w.buf, format, a...) - if w.buf.Len() == 0 { - return - } - out := w.buf.Bytes() - w.wrappedLogf("%s", out) - - // Make sure there's a final newline before we write to the log file. - if out[len(out)-1] != '\n' { - w.buf.WriteByte('\n') - out = w.buf.Bytes() - } - - w.appendToFileLocked(out) -} - -// out should end in a newline. -// w.mu must be held. -func (w *logFileWriter) appendToFileLocked(out []byte) { - now := time.Now() - day := dayOf(now) - if w.fday != day { - w.startNewFileLocked() - } - out = removeDatePrefix(out) - if w.f != nil { - // RFC3339Nano but with a fixed number (3) of nanosecond digits: - const formatPre = "2006-01-02T15:04:05" - const formatPost = "Z07:00" - fmt.Fprintf(w.f, "%s.%03d%s: %s", - now.Format(formatPre), - now.Nanosecond()/int(time.Millisecond/time.Nanosecond), - now.Format(formatPost), - out) - } -} - -func isNum(b byte) bool { return '0' <= b && b <= '9' } - -// removeDatePrefix returns a subslice of v with the log package's -// standard datetime prefix format removed, if present. -func removeDatePrefix(v []byte) []byte { - const format = "2009/01/23 01:23:23 " - if len(v) < len(format) { - return v - } - for i, b := range v[:len(format)] { - fb := format[i] - if isNum(fb) { - if !isNum(b) { - return v - } - continue - } - if b != fb { - return v - } - } - return v[len(format):] -} - -// startNewFileLocked opens a new log file for writing -// and also cleans up any old files. -// -// w.mu must be held. -func (w *logFileWriter) startNewFileLocked() { - var oldName string - if w.f != nil { - oldName = filepath.Base(w.f.Name()) - w.f.Close() - w.f = nil - w.fday = civilDay{} - } - w.cleanLocked() - - now := time.Now() - day := dayOf(now) - name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", - w.fileBasePrefix, - day.year, - day.month, - day.day, - now.Hour(), - now.Minute(), - now.Second(), - now.Unix())) - var err error - w.f, err = os.Create(name) - if err != nil { - w.wrappedLogf("failed to create log file: %v", err) - return - } - if oldName != "" { - fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) - } else { - fmt.Fprintf(w.f, "(logID %q)\n", w.logID) - } - w.fday = day -} - -// cleanLocked cleans up old log files. -// -// w.mu must be held. -func (w *logFileWriter) cleanLocked() { - entries, _ := os.ReadDir(w.dir) - prefix := w.fileBasePrefix + "-" - fileSize := map[string]int64{} - var files []string - var sumSize int64 - for _, entry := range entries { - fi, err := entry.Info() - if err != nil { - w.wrappedLogf("error getting log file info: %v", err) - continue - } - - baseName := filepath.Base(fi.Name()) - if !strings.HasPrefix(baseName, prefix) { - continue - } - size := fi.Size() - fileSize[baseName] = size - sumSize += size - files = append(files, baseName) - } - if sumSize > maxSize { - w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) - } - if len(files) > maxFiles { - w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) - } - for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { - target := files[0] - files = files[1:] - - targetSize := fileSize[target] - targetFull := filepath.Join(w.dir, target) - err := os.Remove(targetFull) - if err != nil { - w.wrappedLogf("error cleaning log file: %v", err) - } else { - sumSize -= targetSize - w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filelogger provides localdisk log writing & rotation, primarily for Windows +// clients. (We get this for free on other platforms.) +package filelogger + +import ( + "bytes" + "fmt" + "log" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const ( + maxSize = 100 << 20 + maxFiles = 50 +) + +// New returns a logf wrapper that appends to local disk log +// files on Windows, rotating old log files as needed to stay under +// file count & byte limits. +func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { + if runtime.GOOS != "windows" { + panic("not yet supported on any platform except Windows") + } + if logf == nil { + panic("nil logf") + } + dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") + + if err := os.MkdirAll(dir, 0700); err != nil { + log.Printf("failed to create local log directory; not writing logs to disk: %v", err) + return logf + } + logf("local disk logdir: %v", dir) + lfw := &logFileWriter{ + fileBasePrefix: fileBasePrefix, + logID: logID, + dir: dir, + wrappedLogf: logf, + } + return lfw.Logf +} + +// logFileWriter is the state for the log writer & rotator. +type logFileWriter struct { + dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` + logID string // hex logID + fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" + wrappedLogf logger.Logf // underlying logger to send to + + mu sync.Mutex // guards following + buf bytes.Buffer // scratch buffer to avoid allocs + fday civilDay // day that f was opened; zero means no file yet open + f *os.File // file currently opened for append +} + +// civilDay is a year, month, and day in the local timezone. +// It's a comparable value type. +type civilDay struct { + year int + month time.Month + day int +} + +func dayOf(t time.Time) civilDay { + return civilDay{t.Year(), t.Month(), t.Day()} +} + +func (w *logFileWriter) Logf(format string, a ...any) { + w.mu.Lock() + defer w.mu.Unlock() + + w.buf.Reset() + fmt.Fprintf(&w.buf, format, a...) + if w.buf.Len() == 0 { + return + } + out := w.buf.Bytes() + w.wrappedLogf("%s", out) + + // Make sure there's a final newline before we write to the log file. + if out[len(out)-1] != '\n' { + w.buf.WriteByte('\n') + out = w.buf.Bytes() + } + + w.appendToFileLocked(out) +} + +// out should end in a newline. +// w.mu must be held. +func (w *logFileWriter) appendToFileLocked(out []byte) { + now := time.Now() + day := dayOf(now) + if w.fday != day { + w.startNewFileLocked() + } + out = removeDatePrefix(out) + if w.f != nil { + // RFC3339Nano but with a fixed number (3) of nanosecond digits: + const formatPre = "2006-01-02T15:04:05" + const formatPost = "Z07:00" + fmt.Fprintf(w.f, "%s.%03d%s: %s", + now.Format(formatPre), + now.Nanosecond()/int(time.Millisecond/time.Nanosecond), + now.Format(formatPost), + out) + } +} + +func isNum(b byte) bool { return '0' <= b && b <= '9' } + +// removeDatePrefix returns a subslice of v with the log package's +// standard datetime prefix format removed, if present. +func removeDatePrefix(v []byte) []byte { + const format = "2009/01/23 01:23:23 " + if len(v) < len(format) { + return v + } + for i, b := range v[:len(format)] { + fb := format[i] + if isNum(fb) { + if !isNum(b) { + return v + } + continue + } + if b != fb { + return v + } + } + return v[len(format):] +} + +// startNewFileLocked opens a new log file for writing +// and also cleans up any old files. +// +// w.mu must be held. +func (w *logFileWriter) startNewFileLocked() { + var oldName string + if w.f != nil { + oldName = filepath.Base(w.f.Name()) + w.f.Close() + w.f = nil + w.fday = civilDay{} + } + w.cleanLocked() + + now := time.Now() + day := dayOf(now) + name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", + w.fileBasePrefix, + day.year, + day.month, + day.day, + now.Hour(), + now.Minute(), + now.Second(), + now.Unix())) + var err error + w.f, err = os.Create(name) + if err != nil { + w.wrappedLogf("failed to create log file: %v", err) + return + } + if oldName != "" { + fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) + } else { + fmt.Fprintf(w.f, "(logID %q)\n", w.logID) + } + w.fday = day +} + +// cleanLocked cleans up old log files. +// +// w.mu must be held. +func (w *logFileWriter) cleanLocked() { + entries, _ := os.ReadDir(w.dir) + prefix := w.fileBasePrefix + "-" + fileSize := map[string]int64{} + var files []string + var sumSize int64 + for _, entry := range entries { + fi, err := entry.Info() + if err != nil { + w.wrappedLogf("error getting log file info: %v", err) + continue + } + + baseName := filepath.Base(fi.Name()) + if !strings.HasPrefix(baseName, prefix) { + continue + } + size := fi.Size() + fileSize[baseName] = size + sumSize += size + files = append(files, baseName) + } + if sumSize > maxSize { + w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) + } + if len(files) > maxFiles { + w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) + } + for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { + target := files[0] + files = files[1:] + + targetSize := fileSize[target] + targetFull := filepath.Join(w.dir, target) + err := os.Remove(targetFull) + if err != nil { + w.wrappedLogf("error cleaning log file: %v", err) + } else { + sumSize -= targetSize + w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) + } + } +} diff --git a/log/filelogger/log_test.go b/log/filelogger/log_test.go index 27f80ab0ae37a..dfa489637f720 100644 --- a/log/filelogger/log_test.go +++ b/log/filelogger/log_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filelogger - -import "testing" - -func TestRemoveDatePrefix(t *testing.T) { - tests := []struct { - in, want string - }{ - {"", ""}, - {"\n", "\n"}, - {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, - {"2009/01/23 01:23:23 \n", "\n"}, - {"2009/01/23 01:23:23 foo\n", "foo\n"}, - {"9999/01/23 01:23:23 foo\n", "foo\n"}, - {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, - } - for i, tt := range tests { - got := removeDatePrefix([]byte(tt.in)) - if string(got) != tt.want { - t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) - } - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filelogger + +import "testing" + +func TestRemoveDatePrefix(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"\n", "\n"}, + {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, + {"2009/01/23 01:23:23 \n", "\n"}, + {"2009/01/23 01:23:23 foo\n", "foo\n"}, + {"9999/01/23 01:23:23 foo\n", "foo\n"}, + {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, + } + for i, tt := range tests { + got := removeDatePrefix([]byte(tt.in)) + if string(got) != tt.want { + t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) + } + } + +} diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index c0cdfb965c80e..fdbfe4506e038 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logpolicy - -import ( - "os" - "reflect" - "testing" -) - -func TestLogHost(t *testing.T) { - v := reflect.ValueOf(&getLogTargetOnce).Elem() - reset := func() { - v.Set(reflect.Zero(v.Type())) - } - defer reset() - - tests := []struct { - env string - want string - }{ - {"", "log.tailscale.io"}, - {"http://foo.com", "foo.com"}, - {"https://foo.com", "foo.com"}, - {"https://foo.com/", "foo.com"}, - {"https://foo.com:123/", "foo.com"}, - } - for _, tt := range tests { - reset() - os.Setenv("TS_LOG_TARGET", tt.env) - if got := LogHost(); got != tt.want { - t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logpolicy + +import ( + "os" + "reflect" + "testing" +) + +func TestLogHost(t *testing.T) { + v := reflect.ValueOf(&getLogTargetOnce).Elem() + reset := func() { + v.Set(reflect.Zero(v.Type())) + } + defer reset() + + tests := []struct { + env string + want string + }{ + {"", "log.tailscale.io"}, + {"http://foo.com", "foo.com"}, + {"https://foo.com", "foo.com"}, + {"https://foo.com/", "foo.com"}, + {"https://foo.com:123/", "foo.com"}, + } + for _, tt := range tests { + reset() + os.Setenv("TS_LOG_TARGET", tt.env) + if got := LogHost(); got != tt.want { + t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) + } + } +} diff --git a/logtail/.gitignore b/logtail/.gitignore index b262949a827d0..0b29b4aca8ef3 100644 --- a/logtail/.gitignore +++ b/logtail/.gitignore @@ -1,6 +1,6 @@ -*~ -*.out -/example/logadopt/logadopt -/example/logreprocess/logreprocess -/example/logtail/logtail -/logtail +*~ +*.out +/example/logadopt/logadopt +/example/logreprocess/logreprocess +/example/logtail/logtail +/logtail diff --git a/logtail/README.md b/logtail/README.md index b7b2ada34e985..20d22c3501432 100644 --- a/logtail/README.md +++ b/logtail/README.md @@ -1,10 +1,10 @@ -# Tailscale Logs Service - -This github repository contains libraries, documentation, and examples -for working with the public API of the tailscale logs service. - -For a very quick introduction to the core features, read the -[API docs](api.md) and peruse the -[logs reprocessing](./example/logreprocess/demo.sh) example. - +# Tailscale Logs Service + +This github repository contains libraries, documentation, and examples +for working with the public API of the tailscale logs service. + +For a very quick introduction to the core features, read the +[API docs](api.md) and peruse the +[logs reprocessing](./example/logreprocess/demo.sh) example. + For more information, write to info@tailscale.io. \ No newline at end of file diff --git a/logtail/api.md b/logtail/api.md index 296913ce4985b..8ec0b69c0f331 100644 --- a/logtail/api.md +++ b/logtail/api.md @@ -1,195 +1,195 @@ -# Tailscale Logs Service - -The Tailscale Logs Service defines a REST interface for configuring, storing, -retrieving, and processing log entries. - -# Overview - -HTTP requests are received at the service **base URL** -[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded -responses using standard HTTP response codes. - -Authorization for the configuration and retrieval APIs is done with a secret -API key passed as the HTTP basic auth username. Secret keys are generated via -the web UI at base URL. An example of using basic auth with curl: - - curl -u : https://log.tailscale.io/collections - -In the future, an HTTP header will allow using MessagePack instead of JSON. - -## Collections - -Logs are organized into collections. Inside each collection is any number of -instances. - -A collection is a domain name. It is a grouping of related logs. As a -guideline, create one collection per product using subdomains of your -company's domain name. Collections must be registered with the logs service -before any attempt is made to store logs. - -## Instances - -Each collection is a set of instances. There is one instance per machine -writing logs. - -An instance has a name and a number. An instance has a **private** and -**public** ID. The private ID is a 32-byte random number encoded as hex. -The public ID is the SHA-256 hash of the private ID, encoded as hex. - -The private ID is used to write logs. The only copy of the private ID -should be on the machine sending logs. Ideally it is generated on the -machine. Logs can be written as soon as a private ID is generated. - -The public ID is used to read and adopt logs. It is designed to be sent -to a service that also holds a logs service API key. - -The tailscale logs service will store any logs for a short period of time. -To enable logs retention, the log can be **adopted** using the public ID -and a logs service API key. -Once this is done, logs will be retained long-term (for the configured -retention period). - -Unadopted instance logs are stored temporarily to help with debugging: -a misconfigured machine writing logs with a bad ID can be spotted by -reading the logs. -If a public ID is not adopted, storage is tightly capped and logs are -deleted after 12 hours. - -# APIs - -## Storage - -### `POST /c//` — send a log - -The body of the request is JSON. - -A **single message** is an object with properties: - -`{ }` - -The client may send any properties it wants in the JSON message, except -for the `logtail` property which has special meaning. Inside the logtail -object the client may only set the following properties: - -- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" - -A future version of the logs service API will also support: - -- `client_time_offset` a integer of nanoseconds since the client was reset -- `client_time_reset` a boolean if set to true resets the time offset counter - -On receipt by the server the `client_time_offset` is transformed into a -`client_time` based on the `server_time` when the first (or -client_time_reset) event was received. - -If any other properties are set in the logtail object they are moved into -the "error" field, the message is saved and a 4xx status code is returned. - -A **batch of messages** is a JSON array filled with single message objects: - -`[ { }, { }, ... ]` - -If any of the array entries are not objects, the content is converted -into a message with a `"logtail": { "error": ...}` property, saved, and -a 4xx status code is returned. - -Similarly any other request content not matching one of these formats is -saved in a logtail error field, and a 4xx status code is returned. - -An invalid collection name returns `{"error": "invalid collection name"}` -along with a 403 status code. - -Clients are encouraged to: - -- POST as rapidly as possible (if not battery constrained). This minimizes - both the time necessary to see logs in a log viewer and the chance of - losing logs. -- Use HTTP/2 when streaming logs, as it does a much better job of - maintaining a TLS connection to minimize overhead for subsequent posts. - -A future version of logs service API will support sending requests with -`Content-Encoding: zstd`. - -## Retrieval - -### `GET /collections` — query the set of collections and instances - -Returns a JSON object listing all of the named collections. - -The caller can query-encode the following fields: - -- `collection-name` — limit the results to one collection - - ``` - { - "collections": { - "collection1.yourcompany.com": { - "instances": { - "" :{ - "first-seen": "timestamp", - "size": 4096 - }, - "" :{ - "first-seen": "timestamp", - "size": 512000, - "orphan": true, - } - } - } - } - } - ``` - -### `GET /c/` — query stored logs - -The caller can query-encode the following fields: - -- `instances` — zero or more log collection instances to limit results to -- `time-start` — the earliest log to include -- One of: - - `time-end` — the latest log to include - - `max-count` — maximum number of logs to return, allows paging - - `stream` — boolean that keeps the response dangling, streaming in - logs like `tail -f`. Incompatible with logtail-time-end. - -In **stream=false** mode, the response is a single JSON object: - - { - // TODO: header fields - "logs": [ {}, {}, ... ] - } - -In **stream=true** mode, the response begins with a JSON header object -similar to the storage format, and then is a sequence of JSON log -objects, `{...}`, one per line. The server continues to send these until -the client closes the connection. - -## Configuration - -For organizations with a small number of instances writing logs, the -Configuration API are best used by a trusted human operator, usually -through a GUI. Organizations with many instances will need to automate -the creation of tokens. - -### `POST /collections` — create or delete a collection - -The caller must set the `collection` property and `action=create` or -`action=delete`, either form encoded or JSON encoded. Its character set -is restricted to the mundane: [a-zA-Z0-9-_.]+ - -Collection names are a global space. Typically they are a domain name. - -### `POST /instances` — adopt an instance into a collection - -The caller must send the following properties, form encoded or JSON encoded: - -- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) -- `instances` an instance public ID encoded as hex - -The collection name must be claimed by a group the caller belongs to. -The pair (collection-name, instance-public-ID) may or may not already have -logs associated with it. - -On failure, an error message is returned with a 4xx or 5xx status code: - +# Tailscale Logs Service + +The Tailscale Logs Service defines a REST interface for configuring, storing, +retrieving, and processing log entries. + +# Overview + +HTTP requests are received at the service **base URL** +[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +responses using standard HTTP response codes. + +Authorization for the configuration and retrieval APIs is done with a secret +API key passed as the HTTP basic auth username. Secret keys are generated via +the web UI at base URL. An example of using basic auth with curl: + + curl -u : https://log.tailscale.io/collections + +In the future, an HTTP header will allow using MessagePack instead of JSON. + +## Collections + +Logs are organized into collections. Inside each collection is any number of +instances. + +A collection is a domain name. It is a grouping of related logs. As a +guideline, create one collection per product using subdomains of your +company's domain name. Collections must be registered with the logs service +before any attempt is made to store logs. + +## Instances + +Each collection is a set of instances. There is one instance per machine +writing logs. + +An instance has a name and a number. An instance has a **private** and +**public** ID. The private ID is a 32-byte random number encoded as hex. +The public ID is the SHA-256 hash of the private ID, encoded as hex. + +The private ID is used to write logs. The only copy of the private ID +should be on the machine sending logs. Ideally it is generated on the +machine. Logs can be written as soon as a private ID is generated. + +The public ID is used to read and adopt logs. It is designed to be sent +to a service that also holds a logs service API key. + +The tailscale logs service will store any logs for a short period of time. +To enable logs retention, the log can be **adopted** using the public ID +and a logs service API key. +Once this is done, logs will be retained long-term (for the configured +retention period). + +Unadopted instance logs are stored temporarily to help with debugging: +a misconfigured machine writing logs with a bad ID can be spotted by +reading the logs. +If a public ID is not adopted, storage is tightly capped and logs are +deleted after 12 hours. + +# APIs + +## Storage + +### `POST /c//` — send a log + +The body of the request is JSON. + +A **single message** is an object with properties: + +`{ }` + +The client may send any properties it wants in the JSON message, except +for the `logtail` property which has special meaning. Inside the logtail +object the client may only set the following properties: + +- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" + +A future version of the logs service API will also support: + +- `client_time_offset` a integer of nanoseconds since the client was reset +- `client_time_reset` a boolean if set to true resets the time offset counter + +On receipt by the server the `client_time_offset` is transformed into a +`client_time` based on the `server_time` when the first (or +client_time_reset) event was received. + +If any other properties are set in the logtail object they are moved into +the "error" field, the message is saved and a 4xx status code is returned. + +A **batch of messages** is a JSON array filled with single message objects: + +`[ { }, { }, ... ]` + +If any of the array entries are not objects, the content is converted +into a message with a `"logtail": { "error": ...}` property, saved, and +a 4xx status code is returned. + +Similarly any other request content not matching one of these formats is +saved in a logtail error field, and a 4xx status code is returned. + +An invalid collection name returns `{"error": "invalid collection name"}` +along with a 403 status code. + +Clients are encouraged to: + +- POST as rapidly as possible (if not battery constrained). This minimizes + both the time necessary to see logs in a log viewer and the chance of + losing logs. +- Use HTTP/2 when streaming logs, as it does a much better job of + maintaining a TLS connection to minimize overhead for subsequent posts. + +A future version of logs service API will support sending requests with +`Content-Encoding: zstd`. + +## Retrieval + +### `GET /collections` — query the set of collections and instances + +Returns a JSON object listing all of the named collections. + +The caller can query-encode the following fields: + +- `collection-name` — limit the results to one collection + + ``` + { + "collections": { + "collection1.yourcompany.com": { + "instances": { + "" :{ + "first-seen": "timestamp", + "size": 4096 + }, + "" :{ + "first-seen": "timestamp", + "size": 512000, + "orphan": true, + } + } + } + } + } + ``` + +### `GET /c/` — query stored logs + +The caller can query-encode the following fields: + +- `instances` — zero or more log collection instances to limit results to +- `time-start` — the earliest log to include +- One of: + - `time-end` — the latest log to include + - `max-count` — maximum number of logs to return, allows paging + - `stream` — boolean that keeps the response dangling, streaming in + logs like `tail -f`. Incompatible with logtail-time-end. + +In **stream=false** mode, the response is a single JSON object: + + { + // TODO: header fields + "logs": [ {}, {}, ... ] + } + +In **stream=true** mode, the response begins with a JSON header object +similar to the storage format, and then is a sequence of JSON log +objects, `{...}`, one per line. The server continues to send these until +the client closes the connection. + +## Configuration + +For organizations with a small number of instances writing logs, the +Configuration API are best used by a trusted human operator, usually +through a GUI. Organizations with many instances will need to automate +the creation of tokens. + +### `POST /collections` — create or delete a collection + +The caller must set the `collection` property and `action=create` or +`action=delete`, either form encoded or JSON encoded. Its character set +is restricted to the mundane: [a-zA-Z0-9-_.]+ + +Collection names are a global space. Typically they are a domain name. + +### `POST /instances` — adopt an instance into a collection + +The caller must send the following properties, form encoded or JSON encoded: + +- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) +- `instances` an instance public ID encoded as hex + +The collection name must be claimed by a group the caller belongs to. +The pair (collection-name, instance-public-ID) may or may not already have +logs associated with it. + +On failure, an error message is returned with a 4xx or 5xx status code: + `{"error": "what went wrong"}` \ No newline at end of file diff --git a/logtail/example/logreprocess/demo.sh b/logtail/example/logreprocess/demo.sh index eaec706a38718..4ec819a67450d 100755 --- a/logtail/example/logreprocess/demo.sh +++ b/logtail/example/logreprocess/demo.sh @@ -1,86 +1,86 @@ -#!/bin/bash -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -# -# This shell script demonstrates writing logs from machines -# and then reprocessing those logs to amalgamate python tracebacks -# into a single log entry in a new collection. -# -# To run this demo, first install the example applications: -# -# go install tailscale.com/logtail/example/... -# -# Then generate a LOGTAIL_API_KEY and two test collections by visiting: -# -# https://log.tailscale.io -# -# Then set the three variables below. -trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT -set -e - -LOG_TEXT='server starting -config file loaded -answering queries -Traceback (most recent call last): - File "/Users/crawshaw/junk.py", line 6, in - main() - File "/Users/crawshaw/junk.py", line 4, in main - raise Exception("oops") -Exception: oops' - -die() { - echo "$0: $*" >&2 - exit 1 -} - -msg() { - echo "-- $*" >&2 -} - -if [ -z "$LOGTAIL_API_KEY" ]; then - die "LOGTAIL_API_KEY is not set" -fi - -if [ -z "$COLLECTION_IN" ]; then - die "COLLECTION_IN is not set" -fi - -if [ -z "$COLLECTION_OUT" ]; then - die "COLLECTION_OUT is not set" -fi - -# Private IDs are 32-bytes of random hex. -# Normally you'd keep the same private IDs from one run to the next, but -# this is just an example. -msg "Generating keys..." -privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) - -# Public IDs are the SHA-256 of the private ID. -publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') - -# Write the machine logs to the input collection. -# Notice that this doesn't require an API key. -msg "Producing new logs..." -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null - -# Adopt the logs, so they will be kept and are readable. -msg "Adopting logs..." -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 - -# Reprocess the logs, amalgamating python tracebacks. -# -# We'll take that reprocessed output and write it to a separate collection, -# again via logtail. -# -# Time out quickly because all our "interesting" logs (generated -# above) have already been processed. -msg "Reprocessing logs..." -logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | - logtail -c "$COLLECTION_OUT" -k $privateid3 +#!/bin/bash +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +# +# This shell script demonstrates writing logs from machines +# and then reprocessing those logs to amalgamate python tracebacks +# into a single log entry in a new collection. +# +# To run this demo, first install the example applications: +# +# go install tailscale.com/logtail/example/... +# +# Then generate a LOGTAIL_API_KEY and two test collections by visiting: +# +# https://log.tailscale.io +# +# Then set the three variables below. +trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT +set -e + +LOG_TEXT='server starting +config file loaded +answering queries +Traceback (most recent call last): + File "/Users/crawshaw/junk.py", line 6, in + main() + File "/Users/crawshaw/junk.py", line 4, in main + raise Exception("oops") +Exception: oops' + +die() { + echo "$0: $*" >&2 + exit 1 +} + +msg() { + echo "-- $*" >&2 +} + +if [ -z "$LOGTAIL_API_KEY" ]; then + die "LOGTAIL_API_KEY is not set" +fi + +if [ -z "$COLLECTION_IN" ]; then + die "COLLECTION_IN is not set" +fi + +if [ -z "$COLLECTION_OUT" ]; then + die "COLLECTION_OUT is not set" +fi + +# Private IDs are 32-bytes of random hex. +# Normally you'd keep the same private IDs from one run to the next, but +# this is just an example. +msg "Generating keys..." +privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) + +# Public IDs are the SHA-256 of the private ID. +publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') + +# Write the machine logs to the input collection. +# Notice that this doesn't require an API key. +msg "Producing new logs..." +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null + +# Adopt the logs, so they will be kept and are readable. +msg "Adopting logs..." +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 + +# Reprocess the logs, amalgamating python tracebacks. +# +# We'll take that reprocessed output and write it to a separate collection, +# again via logtail. +# +# Time out quickly because all our "interesting" logs (generated +# above) have already been processed. +msg "Reprocessing logs..." +logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | + logtail -c "$COLLECTION_OUT" -k $privateid3 diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go index e88d5b4856700..5dbf765788165 100644 --- a/logtail/example/logreprocess/logreprocess.go +++ b/logtail/example/logreprocess/logreprocess.go @@ -1,115 +1,115 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logreprocess program tails a log and reprocesses it. -package main - -import ( - "bufio" - "encoding/json" - "flag" - "io" - "log" - "net/http" - "os" - "strings" - "time" - - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name to read") - apiKey := flag.String("p", "", "logtail API key") - timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - log.SetFlags(0) - - if *timeout != 0 { - go func() { - <-time.After(*timeout) - log.Printf("logreprocess: timeout reached, quitting") - os.Exit(1) - }() - } - - req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) - if err != nil { - log.Fatal(err) - } - req.SetBasicAuth(*apiKey, "") - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - b, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) - } - log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) - } - - tracebackCache := make(map[logid.PublicID]*ProcessedMsg) - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - var msg Msg - if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { - log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) - } - var pMsg *ProcessedMsg - if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { - pMsg.Text += "\n" + msg.Text - if strings.HasPrefix(msg.Text, "Exception: ") { - delete(tracebackCache, msg.Logtail.Instance) - } else { - continue // write later - } - } else { - pMsg = &ProcessedMsg{ - OrigInstance: msg.Logtail.Instance, - Text: msg.Text, - } - pMsg.Logtail.ClientTime = msg.Logtail.ClientTime - } - - if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { - tracebackCache[msg.Logtail.Instance] = pMsg - continue // write later - } - - b, err := json.Marshal(pMsg) - if err != nil { - log.Fatal(err) - } - log.Printf("%s", b) - } - if err := scanner.Err(); err != nil { - log.Fatal(err) - } -} - -type Msg struct { - Logtail struct { - Instance logid.PublicID `json:"instance"` - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - Text string `json:"text"` -} - -type ProcessedMsg struct { - Logtail struct { - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - OrigInstance logid.PublicID `json:"orig_instance"` - Text string `json:"text"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logreprocess program tails a log and reprocesses it. +package main + +import ( + "bufio" + "encoding/json" + "flag" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name to read") + apiKey := flag.String("p", "", "logtail API key") + timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + log.SetFlags(0) + + if *timeout != 0 { + go func() { + <-time.After(*timeout) + log.Printf("logreprocess: timeout reached, quitting") + os.Exit(1) + }() + } + + req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + if err != nil { + log.Fatal(err) + } + req.SetBasicAuth(*apiKey, "") + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) + } + log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) + } + + tracebackCache := make(map[logid.PublicID]*ProcessedMsg) + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var msg Msg + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) + } + var pMsg *ProcessedMsg + if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { + pMsg.Text += "\n" + msg.Text + if strings.HasPrefix(msg.Text, "Exception: ") { + delete(tracebackCache, msg.Logtail.Instance) + } else { + continue // write later + } + } else { + pMsg = &ProcessedMsg{ + OrigInstance: msg.Logtail.Instance, + Text: msg.Text, + } + pMsg.Logtail.ClientTime = msg.Logtail.ClientTime + } + + if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { + tracebackCache[msg.Logtail.Instance] = pMsg + continue // write later + } + + b, err := json.Marshal(pMsg) + if err != nil { + log.Fatal(err) + } + log.Printf("%s", b) + } + if err := scanner.Err(); err != nil { + log.Fatal(err) + } +} + +type Msg struct { + Logtail struct { + Instance logid.PublicID `json:"instance"` + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + Text string `json:"text"` +} + +type ProcessedMsg struct { + Logtail struct { + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + OrigInstance logid.PublicID `json:"orig_instance"` + Text string `json:"text"` +} diff --git a/logtail/example/logtail/logtail.go b/logtail/example/logtail/logtail.go index e777055133904..0c9e442584410 100644 --- a/logtail/example/logtail/logtail.go +++ b/logtail/example/logtail/logtail.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logtail program logs stdin. -package main - -import ( - "bufio" - "flag" - "io" - "log" - "os" - - "tailscale.com/logtail" - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name") - privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - - log.SetFlags(0) - - var id logid.PrivateID - if err := id.UnmarshalText([]byte(*privateID)); err != nil { - log.Fatalf("logtail: bad -privateid: %v", err) - } - - logger := logtail.NewLogger(logtail.Config{ - Collection: *collection, - PrivateID: id, - }, log.Printf) - log.SetOutput(io.MultiWriter(logger, os.Stdout)) - defer logger.Flush() - defer log.Printf("logtail exited") - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - log.Println(scanner.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logtail program logs stdin. +package main + +import ( + "bufio" + "flag" + "io" + "log" + "os" + + "tailscale.com/logtail" + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name") + privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + + log.SetFlags(0) + + var id logid.PrivateID + if err := id.UnmarshalText([]byte(*privateID)); err != nil { + log.Fatalf("logtail: bad -privateid: %v", err) + } + + logger := logtail.NewLogger(logtail.Config{ + Collection: *collection, + PrivateID: id, + }, log.Printf) + log.SetOutput(io.MultiWriter(logger, os.Stdout)) + defer logger.Flush() + defer log.Printf("logtail exited") + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + log.Println(scanner.Text()) + } +} diff --git a/logtail/filch/filch.go b/logtail/filch/filch.go index 886fe239c71b8..d00206dd51487 100644 --- a/logtail/filch/filch.go +++ b/logtail/filch/filch.go @@ -1,284 +1,284 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filch is a file system queue that pilfers your stderr. -// (A FILe CHannel that filches.) -package filch - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "sync" -) - -var stderrFD = 2 // a variable for testing - -const defaultMaxFileSize = 50 << 20 - -type Options struct { - ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here - MaxFileSize int -} - -// A Filch uses two alternating files as a simplistic ring buffer. -type Filch struct { - OrigStderr *os.File - - mu sync.Mutex - cur *os.File - alt *os.File - altscan *bufio.Scanner - recovered int64 - - maxFileSize int64 - writeCounter int - - // buf is an initial buffer for altscan. - // As of August 2021, 99.96% of all log lines - // are below 4096 bytes in length. - // Since this cutoff is arbitrary, instead of using 4096, - // we subtract off the size of the rest of the struct - // so that the whole struct takes 4096 bytes - // (less on 32 bit platforms). - // This reduces allocation waste. - buf [4096 - 64]byte -} - -// TryReadline implements the logtail.Buffer interface. -func (f *Filch) TryReadLine() ([]byte, error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.altscan != nil { - if b, err := f.scan(); b != nil || err != nil { - return b, err - } - } - - f.cur, f.alt = f.alt, f.cur - if f.OrigStderr != nil { - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - if _, err := f.alt.Seek(0, io.SeekStart); err != nil { - return nil, err - } - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - return f.scan() -} - -func (f *Filch) scan() ([]byte, error) { - if f.altscan.Scan() { - return f.altscan.Bytes(), nil - } - err := f.altscan.Err() - err2 := f.alt.Truncate(0) - _, err3 := f.alt.Seek(0, io.SeekStart) - f.altscan = nil - if err != nil { - return nil, err - } - if err2 != nil { - return nil, err2 - } - if err3 != nil { - return nil, err3 - } - return nil, nil -} - -// Write implements the logtail.Buffer interface. -func (f *Filch) Write(b []byte) (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - if f.writeCounter == 100 { - // Check the file size every 100 writes. - f.writeCounter = 0 - fi, err := f.cur.Stat() - if err != nil { - return 0, err - } - if fi.Size() >= f.maxFileSize { - // This most likely means we are not draining. - // To limit the amount of space we use, throw away the old logs. - if err := moveContents(f.alt, f.cur); err != nil { - return 0, err - } - } - } - f.writeCounter++ - - if len(b) == 0 || b[len(b)-1] != '\n' { - bnl := make([]byte, len(b)+1) - copy(bnl, b) - bnl[len(bnl)-1] = '\n' - return f.cur.Write(bnl) - } - return f.cur.Write(b) -} - -// Close closes the Filch, releasing all os resources. -func (f *Filch) Close() (err error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.OrigStderr != nil { - if err2 := unsaveStderr(f.OrigStderr); err == nil { - err = err2 - } - f.OrigStderr = nil - } - - if err2 := f.cur.Close(); err == nil { - err = err2 - } - if err2 := f.alt.Close(); err == nil { - err = err2 - } - - return err -} - -// New creates a new filch around two log files, each starting with filePrefix. -func New(filePrefix string, opts Options) (f *Filch, err error) { - var f1, f2 *os.File - defer func() { - if err != nil { - if f1 != nil { - f1.Close() - } - if f2 != nil { - f2.Close() - } - err = fmt.Errorf("filch: %s", err) - } - }() - - path1 := filePrefix + ".log1.txt" - path2 := filePrefix + ".log2.txt" - - f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - - fi1, err := f1.Stat() - if err != nil { - return nil, err - } - fi2, err := f2.Stat() - if err != nil { - return nil, err - } - - mfs := defaultMaxFileSize - if opts.MaxFileSize > 0 { - mfs = opts.MaxFileSize - } - f = &Filch{ - OrigStderr: os.Stderr, // temporary, for past logs recovery - maxFileSize: int64(mfs), - } - - // Neither, either, or both files may exist and contain logs from - // the last time the process ran. The three cases are: - // - // - neither: all logs were read out and files were truncated - // - either: logs were being written into one of the files - // - both: the files were swapped and were starting to be - // read out, while new logs streamed into the other - // file, but the read out did not complete - if n := fi1.Size() + fi2.Size(); n > 0 { - f.recovered = n - } - switch { - case fi1.Size() > 0 && fi2.Size() == 0: - f.cur, f.alt = f2, f1 - case fi2.Size() > 0 && fi1.Size() == 0: - f.cur, f.alt = f1, f2 - case fi1.Size() > 0 && fi2.Size() > 0: // both - // We need to pick one of the files to be the elder, - // which we do using the mtime. - var older, newer *os.File - if fi1.ModTime().Before(fi2.ModTime()) { - older, newer = f1, f2 - } else { - older, newer = f2, f1 - } - if err := moveContents(older, newer); err != nil { - fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) - fmt.Fprintf(older, "filch: recover move failed: %v\n", err) - } - f.cur, f.alt = newer, older - default: - f.cur, f.alt = f1, f2 // does not matter - } - if f.recovered > 0 { - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - } - - f.OrigStderr = nil - if opts.ReplaceStderr { - f.OrigStderr, err = saveStderr() - if err != nil { - return nil, err - } - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - - return f, nil -} - -func moveContents(dst, src *os.File) (err error) { - defer func() { - _, err2 := src.Seek(0, io.SeekStart) - err3 := src.Truncate(0) - _, err4 := dst.Seek(0, io.SeekStart) - if err == nil { - err = err2 - } - if err == nil { - err = err3 - } - if err == nil { - err = err4 - } - }() - if _, err := src.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := dst.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := io.Copy(dst, src); err != nil { - return err - } - return nil -} - -func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0 : i+1], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filch is a file system queue that pilfers your stderr. +// (A FILe CHannel that filches.) +package filch + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "sync" +) + +var stderrFD = 2 // a variable for testing + +const defaultMaxFileSize = 50 << 20 + +type Options struct { + ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here + MaxFileSize int +} + +// A Filch uses two alternating files as a simplistic ring buffer. +type Filch struct { + OrigStderr *os.File + + mu sync.Mutex + cur *os.File + alt *os.File + altscan *bufio.Scanner + recovered int64 + + maxFileSize int64 + writeCounter int + + // buf is an initial buffer for altscan. + // As of August 2021, 99.96% of all log lines + // are below 4096 bytes in length. + // Since this cutoff is arbitrary, instead of using 4096, + // we subtract off the size of the rest of the struct + // so that the whole struct takes 4096 bytes + // (less on 32 bit platforms). + // This reduces allocation waste. + buf [4096 - 64]byte +} + +// TryReadline implements the logtail.Buffer interface. +func (f *Filch) TryReadLine() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.altscan != nil { + if b, err := f.scan(); b != nil || err != nil { + return b, err + } + } + + f.cur, f.alt = f.alt, f.cur + if f.OrigStderr != nil { + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + if _, err := f.alt.Seek(0, io.SeekStart); err != nil { + return nil, err + } + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + return f.scan() +} + +func (f *Filch) scan() ([]byte, error) { + if f.altscan.Scan() { + return f.altscan.Bytes(), nil + } + err := f.altscan.Err() + err2 := f.alt.Truncate(0) + _, err3 := f.alt.Seek(0, io.SeekStart) + f.altscan = nil + if err != nil { + return nil, err + } + if err2 != nil { + return nil, err2 + } + if err3 != nil { + return nil, err3 + } + return nil, nil +} + +// Write implements the logtail.Buffer interface. +func (f *Filch) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.writeCounter == 100 { + // Check the file size every 100 writes. + f.writeCounter = 0 + fi, err := f.cur.Stat() + if err != nil { + return 0, err + } + if fi.Size() >= f.maxFileSize { + // This most likely means we are not draining. + // To limit the amount of space we use, throw away the old logs. + if err := moveContents(f.alt, f.cur); err != nil { + return 0, err + } + } + } + f.writeCounter++ + + if len(b) == 0 || b[len(b)-1] != '\n' { + bnl := make([]byte, len(b)+1) + copy(bnl, b) + bnl[len(bnl)-1] = '\n' + return f.cur.Write(bnl) + } + return f.cur.Write(b) +} + +// Close closes the Filch, releasing all os resources. +func (f *Filch) Close() (err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.OrigStderr != nil { + if err2 := unsaveStderr(f.OrigStderr); err == nil { + err = err2 + } + f.OrigStderr = nil + } + + if err2 := f.cur.Close(); err == nil { + err = err2 + } + if err2 := f.alt.Close(); err == nil { + err = err2 + } + + return err +} + +// New creates a new filch around two log files, each starting with filePrefix. +func New(filePrefix string, opts Options) (f *Filch, err error) { + var f1, f2 *os.File + defer func() { + if err != nil { + if f1 != nil { + f1.Close() + } + if f2 != nil { + f2.Close() + } + err = fmt.Errorf("filch: %s", err) + } + }() + + path1 := filePrefix + ".log1.txt" + path2 := filePrefix + ".log2.txt" + + f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + + fi1, err := f1.Stat() + if err != nil { + return nil, err + } + fi2, err := f2.Stat() + if err != nil { + return nil, err + } + + mfs := defaultMaxFileSize + if opts.MaxFileSize > 0 { + mfs = opts.MaxFileSize + } + f = &Filch{ + OrigStderr: os.Stderr, // temporary, for past logs recovery + maxFileSize: int64(mfs), + } + + // Neither, either, or both files may exist and contain logs from + // the last time the process ran. The three cases are: + // + // - neither: all logs were read out and files were truncated + // - either: logs were being written into one of the files + // - both: the files were swapped and were starting to be + // read out, while new logs streamed into the other + // file, but the read out did not complete + if n := fi1.Size() + fi2.Size(); n > 0 { + f.recovered = n + } + switch { + case fi1.Size() > 0 && fi2.Size() == 0: + f.cur, f.alt = f2, f1 + case fi2.Size() > 0 && fi1.Size() == 0: + f.cur, f.alt = f1, f2 + case fi1.Size() > 0 && fi2.Size() > 0: // both + // We need to pick one of the files to be the elder, + // which we do using the mtime. + var older, newer *os.File + if fi1.ModTime().Before(fi2.ModTime()) { + older, newer = f1, f2 + } else { + older, newer = f2, f1 + } + if err := moveContents(older, newer); err != nil { + fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) + fmt.Fprintf(older, "filch: recover move failed: %v\n", err) + } + f.cur, f.alt = newer, older + default: + f.cur, f.alt = f1, f2 // does not matter + } + if f.recovered > 0 { + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + } + + f.OrigStderr = nil + if opts.ReplaceStderr { + f.OrigStderr, err = saveStderr() + if err != nil { + return nil, err + } + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + + return f, nil +} + +func moveContents(dst, src *os.File) (err error) { + defer func() { + _, err2 := src.Seek(0, io.SeekStart) + err3 := src.Truncate(0) + _, err4 := dst.Seek(0, io.SeekStart) + if err == nil { + err = err2 + } + if err == nil { + err = err3 + } + if err == nil { + err = err4 + } + }() + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := dst.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := io.Copy(dst, src); err != nil { + return err + } + return nil +} + +func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0 : i+1], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} diff --git a/logtail/filch/filch_stub.go b/logtail/filch/filch_stub.go index fe718d150d0b8..3bb82b1906f17 100644 --- a/logtail/filch/filch_stub.go +++ b/logtail/filch/filch_stub.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 || tamago - -package filch - -import ( - "os" -) - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 || tamago + +package filch + +import ( + "os" +) + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + return nil +} diff --git a/logtail/filch/filch_unix.go b/logtail/filch/filch_unix.go index b06ef6afde99f..2eae70aceb187 100644 --- a/logtail/filch/filch_unix.go +++ b/logtail/filch/filch_unix.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package filch - -import ( - "os" - - "golang.org/x/sys/unix" -) - -func saveStderr() (*os.File, error) { - fd, err := unix.Dup(stderrFD) - if err != nil { - return nil, err - } - return os.NewFile(uintptr(fd), "stderr"), nil -} - -func unsaveStderr(f *os.File) error { - err := dup2Stderr(f) - f.Close() - return err -} - -func dup2Stderr(f *os.File) error { - return unix.Dup2(int(f.Fd()), stderrFD) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package filch + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func saveStderr() (*os.File, error) { + fd, err := unix.Dup(stderrFD) + if err != nil { + return nil, err + } + return os.NewFile(uintptr(fd), "stderr"), nil +} + +func unsaveStderr(f *os.File) error { + err := dup2Stderr(f) + f.Close() + return err +} + +func dup2Stderr(f *os.File) error { + return unix.Dup2(int(f.Fd()), stderrFD) +} diff --git a/logtail/filch/filch_windows.go b/logtail/filch/filch_windows.go index 1419d660689ce..d60514bf00abe 100644 --- a/logtail/filch/filch_windows.go +++ b/logtail/filch/filch_windows.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filch - -import ( - "fmt" - "os" - "syscall" -) - -var kernel32 = syscall.MustLoadDLL("kernel32.dll") -var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") - -func setStdHandle(stdHandle int32, handle syscall.Handle) error { - r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) - if r == 0 { - if e != 0 { - return error(e) - } - return syscall.EINVAL - } - return nil -} - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - fd := int(f.Fd()) - err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) - if err != nil { - return fmt.Errorf("dup2Stderr: %w", err) - } - os.Stderr = f - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filch + +import ( + "fmt" + "os" + "syscall" +) + +var kernel32 = syscall.MustLoadDLL("kernel32.dll") +var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") + +func setStdHandle(stdHandle int32, handle syscall.Handle) error { + r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) + if r == 0 { + if e != 0 { + return error(e) + } + return syscall.EINVAL + } + return nil +} + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + fd := int(f.Fd()) + err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) + if err != nil { + return fmt.Errorf("dup2Stderr: %w", err) + } + os.Stderr = f + return nil +} diff --git a/metrics/fds_linux.go b/metrics/fds_linux.go index 66ebb419d787c..34740c2bb1c74 100644 --- a/metrics/fds_linux.go +++ b/metrics/fds_linux.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package metrics - -import ( - "io/fs" - "sync" - - "go4.org/mem" - "tailscale.com/util/dirwalk" -) - -// counter is a reusable counter for counting file descriptors. -type counter struct { - n int - - // cb is the (*counter).count method value. Creating it allocates, - // so we have to save it away and use a sync.Pool to keep currentFDs - // amortized alloc-free. - cb func(name mem.RO, de fs.DirEntry) error -} - -var counterPool = &sync.Pool{New: func() any { - c := new(counter) - c.cb = c.count - return c -}} - -func (c *counter) count(name mem.RO, de fs.DirEntry) error { - c.n++ - return nil -} - -func currentFDs() int { - c := counterPool.Get().(*counter) - defer counterPool.Put(c) - c.n = 0 - dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) - return c.n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "io/fs" + "sync" + + "go4.org/mem" + "tailscale.com/util/dirwalk" +) + +// counter is a reusable counter for counting file descriptors. +type counter struct { + n int + + // cb is the (*counter).count method value. Creating it allocates, + // so we have to save it away and use a sync.Pool to keep currentFDs + // amortized alloc-free. + cb func(name mem.RO, de fs.DirEntry) error +} + +var counterPool = &sync.Pool{New: func() any { + c := new(counter) + c.cb = c.count + return c +}} + +func (c *counter) count(name mem.RO, de fs.DirEntry) error { + c.n++ + return nil +} + +func currentFDs() int { + c := counterPool.Get().(*counter) + defer counterPool.Put(c) + c.n = 0 + dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) + return c.n +} diff --git a/metrics/fds_notlinux.go b/metrics/fds_notlinux.go index 5a59d4de9d8bf..2dae97cad86b9 100644 --- a/metrics/fds_notlinux.go +++ b/metrics/fds_notlinux.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package metrics - -func currentFDs() int { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package metrics + +func currentFDs() int { return 0 } diff --git a/metrics/metrics.go b/metrics/metrics.go index 0f67ffa305e7c..a07ddccae5107 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package metrics contains expvar & Prometheus types and code used by -// Tailscale for monitoring. -package metrics - -import ( - "expvar" - "fmt" - "io" - "slices" - "strings" -) - -// Set is a string-to-Var map variable that satisfies the expvar.Var -// interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of unrelated variables exported with a common prefix. -// -// This lets us have tsweb recognize *expvar.Map for different -// purposes in the future. (Or perhaps all uses of expvar.Map will -// require explicit types like this one, declaring how we want tsweb -// to export it to Prometheus.) -type Set struct { - expvar.Map -} - -// LabelMap is a string-to-Var map variable that satisfies the -// expvar.Var interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of variables with the same name, with a varying label -// value. Use this to export things that are intuitively breakdowns -// into different buckets. -type LabelMap struct { - Label string - expvar.Map -} - -// SetInt64 sets the *Int value stored under the given map key. -func (m *LabelMap) SetInt64(key string, v int64) { - m.Get(key).Set(v) -} - -// Get returns a direct pointer to the expvar.Int for key, creating it -// if necessary. -func (m *LabelMap) Get(key string) *expvar.Int { - m.Add(key, 0) - return m.Map.Get(key).(*expvar.Int) -} - -// GetIncrFunc returns a function that increments the expvar.Int named by key. -// -// Most callers should not need this; it exists to satisfy an -// interface elsewhere. -func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { - return m.Get(key).Add -} - -// GetFloat returns a direct pointer to the expvar.Float for key, creating it -// if necessary. -func (m *LabelMap) GetFloat(key string) *expvar.Float { - m.AddFloat(key, 0.0) - return m.Map.Get(key).(*expvar.Float) -} - -// CurrentFDs reports how many file descriptors are currently open. -// -// It only works on Linux. It returns zero otherwise. -func CurrentFDs() int { - return currentFDs() -} - -// Histogram is a histogram of values. -// It should be created with NewHistogram. -type Histogram struct { - // buckets is a list of bucket boundaries, in increasing order. - buckets []float64 - - // bucketStrings is a list of the same buckets, but as strings. - // This are allocated once at creation time by NewHistogram. - bucketStrings []string - - bucketVars []expvar.Int - sum expvar.Float - count expvar.Int -} - -// NewHistogram returns a new histogram that reports to the given -// expvar map under the given name. -// -// The buckets are the boundaries of the histogram buckets, in -// increasing order. The last bucket is +Inf. -func NewHistogram(buckets []float64) *Histogram { - if !slices.IsSorted(buckets) { - panic("buckets must be sorted") - } - labels := make([]string, len(buckets)) - for i, b := range buckets { - labels[i] = fmt.Sprintf("%v", b) - } - h := &Histogram{ - buckets: buckets, - bucketStrings: labels, - bucketVars: make([]expvar.Int, len(buckets)), - } - return h -} - -// Observe records a new observation in the histogram. -func (h *Histogram) Observe(v float64) { - h.sum.Add(v) - h.count.Add(1) - for i, b := range h.buckets { - if v <= b { - h.bucketVars[i].Add(1) - } - } -} - -// String returns a JSON representation of the histogram. -// This is used to satisfy the expvar.Var interface. -func (h *Histogram) String() string { - var b strings.Builder - fmt.Fprintf(&b, "{") - first := true - h.Do(func(kv expvar.KeyValue) { - if !first { - fmt.Fprintf(&b, ",") - } - fmt.Fprintf(&b, "%q: ", kv.Key) - if kv.Value != nil { - fmt.Fprintf(&b, "%v", kv.Value) - } else { - fmt.Fprint(&b, "null") - } - first = false - }) - fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) - fmt.Fprintf(&b, ",\"count\": %v", &h.count) - fmt.Fprintf(&b, "}") - return b.String() -} - -// Do calls f for each bucket in the histogram. -func (h *Histogram) Do(f func(expvar.KeyValue)) { - for i := range h.bucketVars { - f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) - } - f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) -} - -// PromExport writes the histogram to w in Prometheus exposition format. -func (h *Histogram) PromExport(w io.Writer, name string) { - fmt.Fprintf(w, "# TYPE %s histogram\n", name) - h.Do(func(kv expvar.KeyValue) { - fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) - }) - fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) - fmt.Fprintf(w, "%s_count %v\n", name, &h.count) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package metrics contains expvar & Prometheus types and code used by +// Tailscale for monitoring. +package metrics + +import ( + "expvar" + "fmt" + "io" + "slices" + "strings" +) + +// Set is a string-to-Var map variable that satisfies the expvar.Var +// interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of unrelated variables exported with a common prefix. +// +// This lets us have tsweb recognize *expvar.Map for different +// purposes in the future. (Or perhaps all uses of expvar.Map will +// require explicit types like this one, declaring how we want tsweb +// to export it to Prometheus.) +type Set struct { + expvar.Map +} + +// LabelMap is a string-to-Var map variable that satisfies the +// expvar.Var interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of variables with the same name, with a varying label +// value. Use this to export things that are intuitively breakdowns +// into different buckets. +type LabelMap struct { + Label string + expvar.Map +} + +// SetInt64 sets the *Int value stored under the given map key. +func (m *LabelMap) SetInt64(key string, v int64) { + m.Get(key).Set(v) +} + +// Get returns a direct pointer to the expvar.Int for key, creating it +// if necessary. +func (m *LabelMap) Get(key string) *expvar.Int { + m.Add(key, 0) + return m.Map.Get(key).(*expvar.Int) +} + +// GetIncrFunc returns a function that increments the expvar.Int named by key. +// +// Most callers should not need this; it exists to satisfy an +// interface elsewhere. +func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { + return m.Get(key).Add +} + +// GetFloat returns a direct pointer to the expvar.Float for key, creating it +// if necessary. +func (m *LabelMap) GetFloat(key string) *expvar.Float { + m.AddFloat(key, 0.0) + return m.Map.Get(key).(*expvar.Float) +} + +// CurrentFDs reports how many file descriptors are currently open. +// +// It only works on Linux. It returns zero otherwise. +func CurrentFDs() int { + return currentFDs() +} + +// Histogram is a histogram of values. +// It should be created with NewHistogram. +type Histogram struct { + // buckets is a list of bucket boundaries, in increasing order. + buckets []float64 + + // bucketStrings is a list of the same buckets, but as strings. + // This are allocated once at creation time by NewHistogram. + bucketStrings []string + + bucketVars []expvar.Int + sum expvar.Float + count expvar.Int +} + +// NewHistogram returns a new histogram that reports to the given +// expvar map under the given name. +// +// The buckets are the boundaries of the histogram buckets, in +// increasing order. The last bucket is +Inf. +func NewHistogram(buckets []float64) *Histogram { + if !slices.IsSorted(buckets) { + panic("buckets must be sorted") + } + labels := make([]string, len(buckets)) + for i, b := range buckets { + labels[i] = fmt.Sprintf("%v", b) + } + h := &Histogram{ + buckets: buckets, + bucketStrings: labels, + bucketVars: make([]expvar.Int, len(buckets)), + } + return h +} + +// Observe records a new observation in the histogram. +func (h *Histogram) Observe(v float64) { + h.sum.Add(v) + h.count.Add(1) + for i, b := range h.buckets { + if v <= b { + h.bucketVars[i].Add(1) + } + } +} + +// String returns a JSON representation of the histogram. +// This is used to satisfy the expvar.Var interface. +func (h *Histogram) String() string { + var b strings.Builder + fmt.Fprintf(&b, "{") + first := true + h.Do(func(kv expvar.KeyValue) { + if !first { + fmt.Fprintf(&b, ",") + } + fmt.Fprintf(&b, "%q: ", kv.Key) + if kv.Value != nil { + fmt.Fprintf(&b, "%v", kv.Value) + } else { + fmt.Fprint(&b, "null") + } + first = false + }) + fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) + fmt.Fprintf(&b, ",\"count\": %v", &h.count) + fmt.Fprintf(&b, "}") + return b.String() +} + +// Do calls f for each bucket in the histogram. +func (h *Histogram) Do(f func(expvar.KeyValue)) { + for i := range h.bucketVars { + f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) + } + f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) +} + +// PromExport writes the histogram to w in Prometheus exposition format. +func (h *Histogram) PromExport(w io.Writer, name string) { + fmt.Fprintf(w, "# TYPE %s histogram\n", name) + h.Do(func(kv expvar.KeyValue) { + fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) + }) + fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) + fmt.Fprintf(w, "%s_count %v\n", name, &h.count) +} diff --git a/net/art/art_test.go b/net/art/art_test.go index e3a427107e69b..daf8553ca020d 100644 --- a/net/art/art_test.go +++ b/net/art/art_test.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package art - -import ( - "os" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestMain(m *testing.M) { - if cibuild.On() { - // Skip CI on GitHub for now - // TODO: https://github.com/tailscale/tailscale/issues/7866 - os.Exit(0) - } - os.Exit(m.Run()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package art + +import ( + "os" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestMain(m *testing.M) { + if cibuild.On() { + // Skip CI on GitHub for now + // TODO: https://github.com/tailscale/tailscale/issues/7866 + os.Exit(0) + } + os.Exit(m.Run()) +} diff --git a/net/art/table.go b/net/art/table.go index 2e130d82f78a1..fa397577868a8 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -1,641 +1,641 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package art provides a routing table that implements the Allotment Routing -// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi -// Hariguchi. -// -// ART outperforms the traditional radix tree implementations for route lookups, -// insertions, and deletions. -// -// For more information, see Yoichi Hariguchi's paper: -// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf -package art - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "math/bits" - "net/netip" - "strings" - "sync" -) - -const ( - debugInsert = false - debugDelete = false -) - -// Table is an IPv4 and IPv6 routing table. -type Table[T any] struct { - v4 strideTable[T] - v6 strideTable[T] - initOnce sync.Once -} - -func (t *Table[T]) init() { - t.initOnce.Do(func() { - t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - }) -} - -func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { - if addr.Is6() { - return &t.v6 - } - return &t.v4 -} - -// Get does a route lookup for addr and returns the associated value, or nil if -// no route matched. -func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { - t.init() - - // Ideally we would use addr.AsSlice here, but AsSlice is just - // barely complex enough that it can't be inlined, and that in - // turn causes the slice to escape to the heap. Using As16 and - // manual slicing here helps the compiler keep Get alloc-free. - st := t.tableForAddr(addr) - rawAddr := addr.As16() - bs := rawAddr[:] - if addr.Is4() { - bs = bs[12:] - } - - i := 0 - // With path compression, we might skip over some address bits while walking - // to a strideTable leaf. This means the leaf answer we find might not be - // correct, because path compression took us down the wrong subtree. When - // that happens, we have to backtrack and figure out which most specific - // route further up the tree is relevant to addr, and return that. - // - // So, as we walk down the stride tables, each time we find a non-nil route - // result, we have to remember it and the associated strideTable prefix. - // - // We could also deal with this edge case of path compression by checking - // the strideTable prefix on each table as we descend, but that means we - // have to pay N prefix.Contains checks on every route lookup (where N is - // the number of strideTables in the path), rather than only paying M prefix - // comparisons in the edge case (where M is the number of strideTables in - // the path with a non-nil route of their own). - const maxDepth = 16 - type prefixAndRoute struct { - prefix netip.Prefix - route T - } - strideMatch := make([]prefixAndRoute, 0, maxDepth) -findLeaf: - for { - rt, rtOK, child := st.getValAndChild(bs[i]) - if rtOK { - // This strideTable contains a route that may be relevant to our - // search, remember it. - strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) - } - if child == nil { - // No sub-routes further down, the last thing we recorded - // in strideRoutes is tentatively the result, barring - // misdirection from path compression. - break findLeaf - } - st = child - // Path compression means we may be skipping over some intermediate - // tables. We have to skip forward to whatever depth st now references. - i = st.prefix.Bits() / 8 - } - - // Walk backwards through the hits we recorded in strideRoutes and - // stridePrefixes, returning the first one whose subtree matches addr. - // - // In the common case where path compression did not mislead us, we'll - // return on the first loop iteration because the last route we recorded was - // the correct most-specific route. - for i := len(strideMatch) - 1; i >= 0; i-- { - if m := strideMatch[i]; m.prefix.Contains(addr) { - return m.route, true - } - } - - // We either found no route hits at all (both previous loops terminated - // immediately), or we went on a wild goose chase down a compressed path for - // the wrong prefix, and also found no usable routes on the way back up to - // the root. This is a miss. - return ret, false -} - -// Insert adds pfx to the table, with value val. -// If pfx is already present in the table, its value is set to val. -func (t *Table[T]) Insert(pfx netip.Prefix, val T) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugInsert { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ninsert: start pfx=%s\n", pfx) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches that boil down - // to the fact that pfx.Bits() has (2^n)+1 values, rather than - // just 2^n. For example, an IPv4 prefix length can be 0 through - // 32, which is 33 values. - // - // This extra possible value creates a lot of problems as we do - // bits and bytes math to traverse strideTables below. So, we - // treat the default route 0/0 specially here, that way the rest - // of the logic goes back to having 2^n values to reason about, - // which can be done in a nice and regular fashion with no edge - // cases. - if pfx.Bits() == 0 { - if debugInsert { - fmt.Printf("insert: default route\n") - } - st.insert(0, 0, val) - return - } - - // No matter what we do as we traverse strideTables, our final - // action will be to insert the last 1-8 bits of pfx into a - // strideTable somewhere. - // - // We calculate upfront the byte position of the end of the - // prefix; the number of bits within that byte that contain prefix - // data; and the prefix of the strideTable into which we'll - // eventually insert. - // - // We need this in a couple different branches of the code below, - // and because the possible values are 1-indexed (1 through 32 for - // ipv4, 1 through 128 for ipv6), the math is very slightly - // unusual to account for the off-by-one indexing. Do it once up - // here, with this large comment, rather than reproduce the subtle - // math in multiple places further down. - finalByteIdx := (pfx.Bits() - 1) / 8 - finalBits := pfx.Bits() - (finalByteIdx * 8) - finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) - if err != nil { - panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) - } - if debugInsert { - fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) - } - - // The strideTable we want to insert into is potentially at the - // end of a chain of strideTables, each one encoding 8 bits of the - // prefix. - // - // We're expecting to walk down a path of tables, although with - // prefix compression we may end up skipping some links in the - // chain, or taking wrong turns and having to course correct. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for { - if debugInsert { - fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - if numBits <= 8 { - if debugInsert { - fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - // We've reached the end of the prefix, whichever - // strideTable we're looking at now is the place where we - // need to insert. - st.insert(bs[finalByteIdx], finalBits, val) - return - } - - // Otherwise, we need to go down at least one more level of - // strideTables. With prefix compression, each level of - // descent can have one of three outcomes: we find a place - // where prefix compression is possible; a place where prefix - // compression made us take a "wrong turn"; or a point along - // our intended path that we have to keep following. - child, created := st.getOrCreateChild(bs[byteIdx]) - switch { - case created: - // The subtree we need for pfx doesn't exist yet. The rest - // of the path, if we were to create it, will consist of a - // bunch of strideTables with a single child each. We can - // use path compression to elide those intermediates, and - // jump straight to the final strideTable that hosts this - // prefix. - child.prefix = finalStridePrefix - child.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) - } - return - case !prefixStrictlyContains(child.prefix, pfx): - // child already exists, but its prefix does not contain - // our destination. This means that the path between st - // and child was compressed by a previous insertion, and - // somewhere in the (implicit) compressed path we took a - // wrong turn, into the wrong part of st's subtree. - // - // This is okay, because pfx and child.prefix must have a - // common ancestor node somewhere between st and child. We - // can figure out what node that is, and materialize it. - // - // Once we've done that, we can immediately complete the - // remainder of the insertion in one of two ways, without - // further traversal. See a little further down for what - // those are. - if debugInsert { - fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) - } - intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) - intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? - st.setChild(bs[byteIdx], intermediate) - intermediate.setChild(addrOfExisting, child) - - if debugInsert { - fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) - } - - // Now, we have a chain of st -> intermediate -> child. - // - // pfx either lives in a different child of intermediate, - // or in intermediate itself. For example, if we created - // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have - // to go into a new child of intermediate, but - // pfx=1.2.0.0/18 would go into intermediate directly. - if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { - // pfx lives in intermediate. - if debugInsert { - fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) - } - intermediate.insert(bs[finalByteIdx], finalBits, val) - } else { - // pfx lives in a different child subtree of - // intermediate. By definition this subtree doesn't - // exist at all, otherwise we'd never have entered - // this entire "wrong turn" codepath in the first - // place. - // - // This means we can apply prefix compression as we - // create this new child, and we're done. - st, created = intermediate.getOrCreateChild(addrOfNew) - if !created { - panic("new child path unexpectedly exists during path decompression") - } - st.prefix = finalStridePrefix - st.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - } - - return - default: - // An expected child table exists along pfx's - // path. Continue traversing downwards. - st = child - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - if debugInsert { - fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) - } - } - } -} - -// Delete removes pfx from the table, if it is present. -func (t *Table[T]) Delete(pfx netip.Prefix) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugDelete { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches, just like - // Insert. See the comment in Insert for more details. Bottom - // line: we handle the default route as a special case, and that - // simplifies the rest of the code slightly. - if pfx.Bits() == 0 { - if debugDelete { - fmt.Printf("delete: default route\n") - } - st.delete(0, 0) - return - } - - // Deletion may drive the refcount of some strideTables down to - // zero. We need to clean up these dangling tables, so we have to - // keep track of which tables we touch on the way down, and which - // strideEntry index each child is registered in. - // - // Note that the strideIndex and strideTables entries are off-by-one. - // The child table pointer is recorded at i+1, but it is referenced by a - // particular index in the parent table, at index i. - // - // In other words: entry number strideIndexes[0] in - // strideTables[0] is the same pointer as strideTables[1]. - // - // This results in some slightly odd array accesses further down - // in this code, because in a single loop iteration we have to - // write to strideTables[N] and strideIndexes[N-1]. - strideIdx := 0 - strideTables := [16]*strideTable[T]{st} - strideIndexes := [15]uint8{} - - // Similar to Insert, navigate down the tree of strideTables, - // looking for the one that houses this prefix. This part is - // easier than with insertion, since we can bail if the path ends - // early or takes an unexpected detour. However, unlike - // insertion, there's a whole post-deletion cleanup phase later - // on. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for numBits > 8 { - if debugDelete { - fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - child := st.getChild(bs[byteIdx]) - if child == nil { - // Prefix can't exist in the table, because one of the - // necessary strideTables doesn't exist. - if debugDelete { - fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) - } - return - } - strideIndexes[strideIdx] = bs[byteIdx] - strideTables[strideIdx+1] = child - strideIdx++ - - // Path compression means byteIdx can jump forwards - // unpredictably. Recompute the next byte to look at from the - // child we just found. - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - st = child - - if debugDelete { - fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) - } - } - - // We reached a leaf stride table that seems to be in the right - // spot. But path compression might have led us to the wrong - // table. - if !prefixStrictlyContains(st.prefix, pfx) { - // Wrong table, the requested prefix can't exist since its - // path led us to the wrong place. - if debugDelete { - fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) - } - return - } - if debugDelete { - fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) - } - if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { - // We're in the right strideTable, but pfx wasn't in - // it. Refcounts haven't changed, so we can skip cleanup. - if debugDelete { - fmt.Printf("delete: prefix not present pfx=%s\n", pfx) - } - return - } - - // st.delete reduced st's refcount by one. This table may now be - // reclaimable, and depending on how we can reclaim it, the parent - // tables may also need to be reclaimed. This loop ends as soon as - // an iteration takes no action, or takes an action that doesn't - // alter the parent table's refcounts. - // - // We start our walk back at strideTables[strideIdx], which - // contains st. - for strideIdx > 0 { - cur := strideTables[strideIdx] - if debugDelete { - fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) - } - if cur.routeRefs > 0 { - // the strideTable has other route entries, it cannot be - // deleted or compacted. - if debugDelete { - fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) - } - return - } - switch cur.childRefs { - case 0: - // no routeRefs and no childRefs, this table can be - // deleted. This will alter the parent table's refcount, - // so we'll have to look at it as well (in the next loop - // iteration). - if debugDelete { - fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) - } - strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) - strideIdx-- - case 1: - // This table has no routes, and a single child. Compact - // this table out of existence by making the parent point - // directly at the one child. This does not affect the - // parent's refcounts, so the parent can't be eligible for - // deletion or compaction, and we can stop. - child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition - parent := strideTables[strideIdx-1] - if debugDelete { - fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) - } - strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) - return - default: - // This table has two or more children, so it's acting as a "fork in - // the road" between two prefix subtrees. It cannot be deleted, and - // thus no further cleanups are possible. - if debugDelete { - fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) - } - return - } - } -} - -// debugSummary prints the tree of allocated strideTables in t, with each -// strideTable's refcount. -func (t *Table[T]) debugSummary() string { - t.init() - var ret bytes.Buffer - fmt.Fprintf(&ret, "v4: ") - strideSummary(&ret, &t.v4, 4) - fmt.Fprintf(&ret, "v6: ") - strideSummary(&ret, &t.v6, 4) - return ret.String() -} - -func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { - fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) - indent += 4 - st.treeDebugStringRec(w, 1, indent) - for addr, child := range st.children { - if child == nil { - continue - } - fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) - strideSummary(w, child, indent) - } -} - -// prefixStrictlyContains reports whether child is a prefix within -// parent, but not parent itself. -func prefixStrictlyContains(parent, child netip.Prefix) bool { - return parent.Overlaps(child) && parent.Bits() < child.Bits() -} - -// computePrefixSplit returns the smallest common prefix that contains -// both a and b. lastCommon is 8-bit aligned, with aStride and bStride -// indicating the value of the 8-bit stride immediately following -// lastCommon. -// -// computePrefixSplit is used in constructing an intermediate -// strideTable when a new prefix needs to be inserted in a compressed -// table. It can be read as: given that a is already in the table, and -// b is being inserted, what is the prefix of the new intermediate -// strideTable that needs to be created, and at what addresses in that -// new strideTable should a and b's subsequent strideTables be -// attached? -// -// Note as a special case, this can be called with a==b. An example of -// when this happens: -// - We want to insert the prefix 1.2.0.0/16 -// - A strideTable exists for 1.2.0.0/16, because another child -// prefix already exists (e.g. 1.2.3.4/32) -// - The 1.0.0.0/8 strideTable does not exist, because path -// compression removed it. -// -// In this scenario, the caller of computePrefixSplit ends up making a -// "wrong turn" while traversing strideTables: it was looking for the -// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this -// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), -// and we return 1.0.0.0/8 as the missing intermediate. -func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { - a = a.Masked() - b = b.Masked() - if a.Bits() == 0 || b.Bits() == 0 { - panic("computePrefixSplit called with a default route") - } - if a.Addr().Is4() != b.Addr().Is4() { - panic("computePrefixSplit called with mismatched address families") - } - - minPrefixLen := a.Bits() - if b.Bits() < minPrefixLen { - minPrefixLen = b.Bits() - } - - commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) - // We want to know how many 8-bit strides are shared between a and - // b. Naively, this would be commonBits/8, but this introduces an - // off-by-one error. This is due to the way our ART stores - // prefixes whose length falls exactly on a stride boundary. - // - // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits - // correctly reports that these prefixes have their first 16 bits - // in common. However, in the ART they only share 1 common stride: - // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 - // is stored as 168/8 within that table, and not as 0/0 in the - // 192.168.0.0/16 table. - // - // So, when commonBits matches the length of one of the inputs and - // falls on a boundary between strides, the strideTable one - // further up from commonBits/8 is the one we need to create, - // which means we have to adjust the stride count down by one. - if commonBits == minPrefixLen { - commonBits-- - } - commonStrides := commonBits / 8 - lastCommon, err := a.Addr().Prefix(commonStrides * 8) - if err != nil { - panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) - } - if a.Addr().Is4() { - aStride = a.Addr().As4()[commonStrides] - bStride = b.Addr().As4()[commonStrides] - } else { - aStride = a.Addr().As16()[commonStrides] - bStride = b.Addr().As16()[commonStrides] - } - return lastCommon, aStride, bStride -} - -// commonBits returns the number of common leading bits of a and b. -// If the number of common bits exceeds maxBits, it returns maxBits -// instead. -func commonBits(a, b netip.Addr, maxBits int) int { - if a.Is4() != b.Is4() { - panic("commonStrides called with mismatched address families") - } - var common int - // The following implements an old bit-twiddling trick to compute - // the number of common leading bits: if you XOR two numbers - // together, equal bits become 0 and unequal bits become 1. You - // can then count the number of leading zeros (which is a single - // instruction on modern CPUs) to get the answer. - // - // This code is a little more complex than just XOR + count - // leading zeros, because IPv4 and IPv6 are different sizes, and - // for IPv6 we have to do the math in two 64-bit chunks because Go - // lacks a uint128 type. - if a.Is4() { - aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) - common = bits.LeadingZeros32(aNum ^ bNum) - } else { - aNumHi, aNumLo := ipv6AsUint(a) - bNumHi, bNumLo := ipv6AsUint(b) - common = bits.LeadingZeros64(aNumHi ^ bNumHi) - if common == 64 { - common += bits.LeadingZeros64(aNumLo ^ bNumLo) - } - } - if common > maxBits { - common = maxBits - } - return common -} - -// ipv4AsUint returns ip as a uint32. -func ipv4AsUint(ip netip.Addr) uint32 { - bs := ip.As4() - return binary.BigEndian.Uint32(bs[:]) -} - -// ipv6AsUint returns ip as a pair of uint64s. -func ipv6AsUint(ip netip.Addr) (uint64, uint64) { - bs := ip.As16() - return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package art provides a routing table that implements the Allotment Routing +// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi +// Hariguchi. +// +// ART outperforms the traditional radix tree implementations for route lookups, +// insertions, and deletions. +// +// For more information, see Yoichi Hariguchi's paper: +// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf +package art + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/bits" + "net/netip" + "strings" + "sync" +) + +const ( + debugInsert = false + debugDelete = false +) + +// Table is an IPv4 and IPv6 routing table. +type Table[T any] struct { + v4 strideTable[T] + v6 strideTable[T] + initOnce sync.Once +} + +func (t *Table[T]) init() { + t.initOnce.Do(func() { + t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + }) +} + +func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { + if addr.Is6() { + return &t.v6 + } + return &t.v4 +} + +// Get does a route lookup for addr and returns the associated value, or nil if +// no route matched. +func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { + t.init() + + // Ideally we would use addr.AsSlice here, but AsSlice is just + // barely complex enough that it can't be inlined, and that in + // turn causes the slice to escape to the heap. Using As16 and + // manual slicing here helps the compiler keep Get alloc-free. + st := t.tableForAddr(addr) + rawAddr := addr.As16() + bs := rawAddr[:] + if addr.Is4() { + bs = bs[12:] + } + + i := 0 + // With path compression, we might skip over some address bits while walking + // to a strideTable leaf. This means the leaf answer we find might not be + // correct, because path compression took us down the wrong subtree. When + // that happens, we have to backtrack and figure out which most specific + // route further up the tree is relevant to addr, and return that. + // + // So, as we walk down the stride tables, each time we find a non-nil route + // result, we have to remember it and the associated strideTable prefix. + // + // We could also deal with this edge case of path compression by checking + // the strideTable prefix on each table as we descend, but that means we + // have to pay N prefix.Contains checks on every route lookup (where N is + // the number of strideTables in the path), rather than only paying M prefix + // comparisons in the edge case (where M is the number of strideTables in + // the path with a non-nil route of their own). + const maxDepth = 16 + type prefixAndRoute struct { + prefix netip.Prefix + route T + } + strideMatch := make([]prefixAndRoute, 0, maxDepth) +findLeaf: + for { + rt, rtOK, child := st.getValAndChild(bs[i]) + if rtOK { + // This strideTable contains a route that may be relevant to our + // search, remember it. + strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) + } + if child == nil { + // No sub-routes further down, the last thing we recorded + // in strideRoutes is tentatively the result, barring + // misdirection from path compression. + break findLeaf + } + st = child + // Path compression means we may be skipping over some intermediate + // tables. We have to skip forward to whatever depth st now references. + i = st.prefix.Bits() / 8 + } + + // Walk backwards through the hits we recorded in strideRoutes and + // stridePrefixes, returning the first one whose subtree matches addr. + // + // In the common case where path compression did not mislead us, we'll + // return on the first loop iteration because the last route we recorded was + // the correct most-specific route. + for i := len(strideMatch) - 1; i >= 0; i-- { + if m := strideMatch[i]; m.prefix.Contains(addr) { + return m.route, true + } + } + + // We either found no route hits at all (both previous loops terminated + // immediately), or we went on a wild goose chase down a compressed path for + // the wrong prefix, and also found no usable routes on the way back up to + // the root. This is a miss. + return ret, false +} + +// Insert adds pfx to the table, with value val. +// If pfx is already present in the table, its value is set to val. +func (t *Table[T]) Insert(pfx netip.Prefix, val T) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugInsert { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ninsert: start pfx=%s\n", pfx) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches that boil down + // to the fact that pfx.Bits() has (2^n)+1 values, rather than + // just 2^n. For example, an IPv4 prefix length can be 0 through + // 32, which is 33 values. + // + // This extra possible value creates a lot of problems as we do + // bits and bytes math to traverse strideTables below. So, we + // treat the default route 0/0 specially here, that way the rest + // of the logic goes back to having 2^n values to reason about, + // which can be done in a nice and regular fashion with no edge + // cases. + if pfx.Bits() == 0 { + if debugInsert { + fmt.Printf("insert: default route\n") + } + st.insert(0, 0, val) + return + } + + // No matter what we do as we traverse strideTables, our final + // action will be to insert the last 1-8 bits of pfx into a + // strideTable somewhere. + // + // We calculate upfront the byte position of the end of the + // prefix; the number of bits within that byte that contain prefix + // data; and the prefix of the strideTable into which we'll + // eventually insert. + // + // We need this in a couple different branches of the code below, + // and because the possible values are 1-indexed (1 through 32 for + // ipv4, 1 through 128 for ipv6), the math is very slightly + // unusual to account for the off-by-one indexing. Do it once up + // here, with this large comment, rather than reproduce the subtle + // math in multiple places further down. + finalByteIdx := (pfx.Bits() - 1) / 8 + finalBits := pfx.Bits() - (finalByteIdx * 8) + finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) + if err != nil { + panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) + } + if debugInsert { + fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) + } + + // The strideTable we want to insert into is potentially at the + // end of a chain of strideTables, each one encoding 8 bits of the + // prefix. + // + // We're expecting to walk down a path of tables, although with + // prefix compression we may end up skipping some links in the + // chain, or taking wrong turns and having to course correct. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for { + if debugInsert { + fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + if numBits <= 8 { + if debugInsert { + fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + // We've reached the end of the prefix, whichever + // strideTable we're looking at now is the place where we + // need to insert. + st.insert(bs[finalByteIdx], finalBits, val) + return + } + + // Otherwise, we need to go down at least one more level of + // strideTables. With prefix compression, each level of + // descent can have one of three outcomes: we find a place + // where prefix compression is possible; a place where prefix + // compression made us take a "wrong turn"; or a point along + // our intended path that we have to keep following. + child, created := st.getOrCreateChild(bs[byteIdx]) + switch { + case created: + // The subtree we need for pfx doesn't exist yet. The rest + // of the path, if we were to create it, will consist of a + // bunch of strideTables with a single child each. We can + // use path compression to elide those intermediates, and + // jump straight to the final strideTable that hosts this + // prefix. + child.prefix = finalStridePrefix + child.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) + } + return + case !prefixStrictlyContains(child.prefix, pfx): + // child already exists, but its prefix does not contain + // our destination. This means that the path between st + // and child was compressed by a previous insertion, and + // somewhere in the (implicit) compressed path we took a + // wrong turn, into the wrong part of st's subtree. + // + // This is okay, because pfx and child.prefix must have a + // common ancestor node somewhere between st and child. We + // can figure out what node that is, and materialize it. + // + // Once we've done that, we can immediately complete the + // remainder of the insertion in one of two ways, without + // further traversal. See a little further down for what + // those are. + if debugInsert { + fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) + } + intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) + intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? + st.setChild(bs[byteIdx], intermediate) + intermediate.setChild(addrOfExisting, child) + + if debugInsert { + fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) + } + + // Now, we have a chain of st -> intermediate -> child. + // + // pfx either lives in a different child of intermediate, + // or in intermediate itself. For example, if we created + // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have + // to go into a new child of intermediate, but + // pfx=1.2.0.0/18 would go into intermediate directly. + if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { + // pfx lives in intermediate. + if debugInsert { + fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) + } + intermediate.insert(bs[finalByteIdx], finalBits, val) + } else { + // pfx lives in a different child subtree of + // intermediate. By definition this subtree doesn't + // exist at all, otherwise we'd never have entered + // this entire "wrong turn" codepath in the first + // place. + // + // This means we can apply prefix compression as we + // create this new child, and we're done. + st, created = intermediate.getOrCreateChild(addrOfNew) + if !created { + panic("new child path unexpectedly exists during path decompression") + } + st.prefix = finalStridePrefix + st.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + } + + return + default: + // An expected child table exists along pfx's + // path. Continue traversing downwards. + st = child + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + if debugInsert { + fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) + } + } + } +} + +// Delete removes pfx from the table, if it is present. +func (t *Table[T]) Delete(pfx netip.Prefix) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugDelete { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches, just like + // Insert. See the comment in Insert for more details. Bottom + // line: we handle the default route as a special case, and that + // simplifies the rest of the code slightly. + if pfx.Bits() == 0 { + if debugDelete { + fmt.Printf("delete: default route\n") + } + st.delete(0, 0) + return + } + + // Deletion may drive the refcount of some strideTables down to + // zero. We need to clean up these dangling tables, so we have to + // keep track of which tables we touch on the way down, and which + // strideEntry index each child is registered in. + // + // Note that the strideIndex and strideTables entries are off-by-one. + // The child table pointer is recorded at i+1, but it is referenced by a + // particular index in the parent table, at index i. + // + // In other words: entry number strideIndexes[0] in + // strideTables[0] is the same pointer as strideTables[1]. + // + // This results in some slightly odd array accesses further down + // in this code, because in a single loop iteration we have to + // write to strideTables[N] and strideIndexes[N-1]. + strideIdx := 0 + strideTables := [16]*strideTable[T]{st} + strideIndexes := [15]uint8{} + + // Similar to Insert, navigate down the tree of strideTables, + // looking for the one that houses this prefix. This part is + // easier than with insertion, since we can bail if the path ends + // early or takes an unexpected detour. However, unlike + // insertion, there's a whole post-deletion cleanup phase later + // on. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for numBits > 8 { + if debugDelete { + fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + child := st.getChild(bs[byteIdx]) + if child == nil { + // Prefix can't exist in the table, because one of the + // necessary strideTables doesn't exist. + if debugDelete { + fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) + } + return + } + strideIndexes[strideIdx] = bs[byteIdx] + strideTables[strideIdx+1] = child + strideIdx++ + + // Path compression means byteIdx can jump forwards + // unpredictably. Recompute the next byte to look at from the + // child we just found. + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + st = child + + if debugDelete { + fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) + } + } + + // We reached a leaf stride table that seems to be in the right + // spot. But path compression might have led us to the wrong + // table. + if !prefixStrictlyContains(st.prefix, pfx) { + // Wrong table, the requested prefix can't exist since its + // path led us to the wrong place. + if debugDelete { + fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) + } + return + } + if debugDelete { + fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) + } + if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { + // We're in the right strideTable, but pfx wasn't in + // it. Refcounts haven't changed, so we can skip cleanup. + if debugDelete { + fmt.Printf("delete: prefix not present pfx=%s\n", pfx) + } + return + } + + // st.delete reduced st's refcount by one. This table may now be + // reclaimable, and depending on how we can reclaim it, the parent + // tables may also need to be reclaimed. This loop ends as soon as + // an iteration takes no action, or takes an action that doesn't + // alter the parent table's refcounts. + // + // We start our walk back at strideTables[strideIdx], which + // contains st. + for strideIdx > 0 { + cur := strideTables[strideIdx] + if debugDelete { + fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) + } + if cur.routeRefs > 0 { + // the strideTable has other route entries, it cannot be + // deleted or compacted. + if debugDelete { + fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) + } + return + } + switch cur.childRefs { + case 0: + // no routeRefs and no childRefs, this table can be + // deleted. This will alter the parent table's refcount, + // so we'll have to look at it as well (in the next loop + // iteration). + if debugDelete { + fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) + } + strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) + strideIdx-- + case 1: + // This table has no routes, and a single child. Compact + // this table out of existence by making the parent point + // directly at the one child. This does not affect the + // parent's refcounts, so the parent can't be eligible for + // deletion or compaction, and we can stop. + child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition + parent := strideTables[strideIdx-1] + if debugDelete { + fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) + } + strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) + return + default: + // This table has two or more children, so it's acting as a "fork in + // the road" between two prefix subtrees. It cannot be deleted, and + // thus no further cleanups are possible. + if debugDelete { + fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) + } + return + } + } +} + +// debugSummary prints the tree of allocated strideTables in t, with each +// strideTable's refcount. +func (t *Table[T]) debugSummary() string { + t.init() + var ret bytes.Buffer + fmt.Fprintf(&ret, "v4: ") + strideSummary(&ret, &t.v4, 4) + fmt.Fprintf(&ret, "v6: ") + strideSummary(&ret, &t.v6, 4) + return ret.String() +} + +func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { + fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) + indent += 4 + st.treeDebugStringRec(w, 1, indent) + for addr, child := range st.children { + if child == nil { + continue + } + fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) + strideSummary(w, child, indent) + } +} + +// prefixStrictlyContains reports whether child is a prefix within +// parent, but not parent itself. +func prefixStrictlyContains(parent, child netip.Prefix) bool { + return parent.Overlaps(child) && parent.Bits() < child.Bits() +} + +// computePrefixSplit returns the smallest common prefix that contains +// both a and b. lastCommon is 8-bit aligned, with aStride and bStride +// indicating the value of the 8-bit stride immediately following +// lastCommon. +// +// computePrefixSplit is used in constructing an intermediate +// strideTable when a new prefix needs to be inserted in a compressed +// table. It can be read as: given that a is already in the table, and +// b is being inserted, what is the prefix of the new intermediate +// strideTable that needs to be created, and at what addresses in that +// new strideTable should a and b's subsequent strideTables be +// attached? +// +// Note as a special case, this can be called with a==b. An example of +// when this happens: +// - We want to insert the prefix 1.2.0.0/16 +// - A strideTable exists for 1.2.0.0/16, because another child +// prefix already exists (e.g. 1.2.3.4/32) +// - The 1.0.0.0/8 strideTable does not exist, because path +// compression removed it. +// +// In this scenario, the caller of computePrefixSplit ends up making a +// "wrong turn" while traversing strideTables: it was looking for the +// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this +// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), +// and we return 1.0.0.0/8 as the missing intermediate. +func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { + a = a.Masked() + b = b.Masked() + if a.Bits() == 0 || b.Bits() == 0 { + panic("computePrefixSplit called with a default route") + } + if a.Addr().Is4() != b.Addr().Is4() { + panic("computePrefixSplit called with mismatched address families") + } + + minPrefixLen := a.Bits() + if b.Bits() < minPrefixLen { + minPrefixLen = b.Bits() + } + + commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) + // We want to know how many 8-bit strides are shared between a and + // b. Naively, this would be commonBits/8, but this introduces an + // off-by-one error. This is due to the way our ART stores + // prefixes whose length falls exactly on a stride boundary. + // + // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits + // correctly reports that these prefixes have their first 16 bits + // in common. However, in the ART they only share 1 common stride: + // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 + // is stored as 168/8 within that table, and not as 0/0 in the + // 192.168.0.0/16 table. + // + // So, when commonBits matches the length of one of the inputs and + // falls on a boundary between strides, the strideTable one + // further up from commonBits/8 is the one we need to create, + // which means we have to adjust the stride count down by one. + if commonBits == minPrefixLen { + commonBits-- + } + commonStrides := commonBits / 8 + lastCommon, err := a.Addr().Prefix(commonStrides * 8) + if err != nil { + panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) + } + if a.Addr().Is4() { + aStride = a.Addr().As4()[commonStrides] + bStride = b.Addr().As4()[commonStrides] + } else { + aStride = a.Addr().As16()[commonStrides] + bStride = b.Addr().As16()[commonStrides] + } + return lastCommon, aStride, bStride +} + +// commonBits returns the number of common leading bits of a and b. +// If the number of common bits exceeds maxBits, it returns maxBits +// instead. +func commonBits(a, b netip.Addr, maxBits int) int { + if a.Is4() != b.Is4() { + panic("commonStrides called with mismatched address families") + } + var common int + // The following implements an old bit-twiddling trick to compute + // the number of common leading bits: if you XOR two numbers + // together, equal bits become 0 and unequal bits become 1. You + // can then count the number of leading zeros (which is a single + // instruction on modern CPUs) to get the answer. + // + // This code is a little more complex than just XOR + count + // leading zeros, because IPv4 and IPv6 are different sizes, and + // for IPv6 we have to do the math in two 64-bit chunks because Go + // lacks a uint128 type. + if a.Is4() { + aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) + common = bits.LeadingZeros32(aNum ^ bNum) + } else { + aNumHi, aNumLo := ipv6AsUint(a) + bNumHi, bNumLo := ipv6AsUint(b) + common = bits.LeadingZeros64(aNumHi ^ bNumHi) + if common == 64 { + common += bits.LeadingZeros64(aNumLo ^ bNumLo) + } + } + if common > maxBits { + common = maxBits + } + return common +} + +// ipv4AsUint returns ip as a uint32. +func ipv4AsUint(ip netip.Addr) uint32 { + bs := ip.As4() + return binary.BigEndian.Uint32(bs[:]) +} + +// ipv6AsUint returns ip as a pair of uint64s. +func ipv6AsUint(ip netip.Addr) (uint64, uint64) { + bs := ip.As16() + return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) +} diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 2a1fb18de967f..3ffc796e06d1b 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bufio" - "bytes" - _ "embed" - "fmt" - "os" - "os/exec" - "path/filepath" - - "tailscale.com/atomicfile" - "tailscale.com/types/logger" -) - -//go:embed resolvconf-workaround.sh -var workaroundScript []byte - -// resolvconfConfigName is the name of the config submitted to -// resolvconf. -// The name starts with 'tun' in order to match the hardcoded -// interface order in debian resolvconf, which will place this -// configuration ahead of regular network links. In theory, this -// doesn't matter because we then fix things up to ensure our config -// is the only one in use, but in case that fails, this will make our -// configuration slightly preferred. -// The 'inet' suffix has no specific meaning, but conventionally -// resolvconf implementations encourage adding a suffix roughly -// indicating where the config came from, and "inet" is the "none of -// the above" value (rather than, say, "ppp" or "dhcp"). -const resolvconfConfigName = "tun-tailscale.inet" - -// resolvconfLibcHookPath is the directory containing libc update -// scripts, which are run by Debian resolvconf when /etc/resolv.conf -// has been updated. -const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" - -// resolvconfHookPath is the name of the libc hook script we install -// to force Tailscale's DNS config to take effect. -var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") - -// resolvconfManager manages DNS configuration using the Debian -// implementation of the `resolvconf` program, written by Thomas Hood. -type resolvconfManager struct { - logf logger.Logf - listRecordsPath string - interfacesDir string - scriptInstalled bool // libc update script has been installed -} - -func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { - ret := &resolvconfManager{ - logf: logf, - listRecordsPath: "/lib/resolvconf/list-records", - interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work - } - - if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { - // This might be a Debian system from before the big /usr - // merge, try /usr instead. - ret.listRecordsPath = "/usr" + ret.listRecordsPath - } - // The runtime directory is currently (2020-04) canonically - // /etc/resolvconf/run, but the manpage is making noise about - // switching to /run/resolvconf and dropping the /etc path. So, - // let's probe the possible directories and use the first one - // that works. - for _, path := range []string{ - "/etc/resolvconf/run/interface", - "/run/resolvconf/interface", - "/var/run/resolvconf/interface", - } { - if _, err := os.Stat(path); err == nil { - ret.interfacesDir = path - break - } - } - if ret.interfacesDir == "" { - // None of the paths seem to work, use the canonical location - // that the current manpage says to use. - ret.interfacesDir = "/etc/resolvconf/run/interfaces" - } - - return ret, nil -} - -func (m *resolvconfManager) deleteTailscaleConfig() error { - cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - return nil -} - -func (m *resolvconfManager) SetDNS(config OSConfig) error { - if !m.scriptInstalled { - m.logf("injecting resolvconf workaround script") - if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { - return err - } - if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { - return err - } - m.scriptInstalled = true - } - - if config.IsZero() { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - } else { - stdin := new(bytes.Buffer) - writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go - - // This resolvconf implementation doesn't support exclusive - // mode or interface priorities, so it will end up blending - // our configuration with other sources. However, this will - // get fixed up by the script we injected above. - cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) - cmd.Stdin = stdin - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - } - - return nil -} - -func (m *resolvconfManager) SupportsSplitDNS() bool { - return false -} - -func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { - var bs bytes.Buffer - - cmd := exec.Command(m.listRecordsPath) - // list-records assumes it's being run with CWD set to the - // interfaces runtime dir, and returns nonsense otherwise. - cmd.Dir = m.interfacesDir - cmd.Stdout = &bs - if err := cmd.Run(); err != nil { - return OSConfig{}, err - } - - var conf bytes.Buffer - sc := bufio.NewScanner(&bs) - for sc.Scan() { - if sc.Text() == resolvconfConfigName { - continue - } - bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) - if err != nil { - if os.IsNotExist(err) { - // Probably raced with a deletion, that's okay. - continue - } - return OSConfig{}, err - } - conf.Write(bs) - conf.WriteByte('\n') - } - - return readResolv(&conf) -} - -func (m *resolvconfManager) Close() error { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - - if m.scriptInstalled { - m.logf("removing resolvconf workaround script") - os.Remove(resolvconfHookPath) // Best-effort - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bufio" + "bytes" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + + "tailscale.com/atomicfile" + "tailscale.com/types/logger" +) + +//go:embed resolvconf-workaround.sh +var workaroundScript []byte + +// resolvconfConfigName is the name of the config submitted to +// resolvconf. +// The name starts with 'tun' in order to match the hardcoded +// interface order in debian resolvconf, which will place this +// configuration ahead of regular network links. In theory, this +// doesn't matter because we then fix things up to ensure our config +// is the only one in use, but in case that fails, this will make our +// configuration slightly preferred. +// The 'inet' suffix has no specific meaning, but conventionally +// resolvconf implementations encourage adding a suffix roughly +// indicating where the config came from, and "inet" is the "none of +// the above" value (rather than, say, "ppp" or "dhcp"). +const resolvconfConfigName = "tun-tailscale.inet" + +// resolvconfLibcHookPath is the directory containing libc update +// scripts, which are run by Debian resolvconf when /etc/resolv.conf +// has been updated. +const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" + +// resolvconfHookPath is the name of the libc hook script we install +// to force Tailscale's DNS config to take effect. +var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") + +// resolvconfManager manages DNS configuration using the Debian +// implementation of the `resolvconf` program, written by Thomas Hood. +type resolvconfManager struct { + logf logger.Logf + listRecordsPath string + interfacesDir string + scriptInstalled bool // libc update script has been installed +} + +func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { + ret := &resolvconfManager{ + logf: logf, + listRecordsPath: "/lib/resolvconf/list-records", + interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work + } + + if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { + // This might be a Debian system from before the big /usr + // merge, try /usr instead. + ret.listRecordsPath = "/usr" + ret.listRecordsPath + } + // The runtime directory is currently (2020-04) canonically + // /etc/resolvconf/run, but the manpage is making noise about + // switching to /run/resolvconf and dropping the /etc path. So, + // let's probe the possible directories and use the first one + // that works. + for _, path := range []string{ + "/etc/resolvconf/run/interface", + "/run/resolvconf/interface", + "/var/run/resolvconf/interface", + } { + if _, err := os.Stat(path); err == nil { + ret.interfacesDir = path + break + } + } + if ret.interfacesDir == "" { + // None of the paths seem to work, use the canonical location + // that the current manpage says to use. + ret.interfacesDir = "/etc/resolvconf/run/interfaces" + } + + return ret, nil +} + +func (m *resolvconfManager) deleteTailscaleConfig() error { + cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + return nil +} + +func (m *resolvconfManager) SetDNS(config OSConfig) error { + if !m.scriptInstalled { + m.logf("injecting resolvconf workaround script") + if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { + return err + } + if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { + return err + } + m.scriptInstalled = true + } + + if config.IsZero() { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + } else { + stdin := new(bytes.Buffer) + writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go + + // This resolvconf implementation doesn't support exclusive + // mode or interface priorities, so it will end up blending + // our configuration with other sources. However, this will + // get fixed up by the script we injected above. + cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) + cmd.Stdin = stdin + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + } + + return nil +} + +func (m *resolvconfManager) SupportsSplitDNS() bool { + return false +} + +func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { + var bs bytes.Buffer + + cmd := exec.Command(m.listRecordsPath) + // list-records assumes it's being run with CWD set to the + // interfaces runtime dir, and returns nonsense otherwise. + cmd.Dir = m.interfacesDir + cmd.Stdout = &bs + if err := cmd.Run(); err != nil { + return OSConfig{}, err + } + + var conf bytes.Buffer + sc := bufio.NewScanner(&bs) + for sc.Scan() { + if sc.Text() == resolvconfConfigName { + continue + } + bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) + if err != nil { + if os.IsNotExist(err) { + // Probably raced with a deletion, that's okay. + continue + } + return OSConfig{}, err + } + conf.Write(bs) + conf.WriteByte('\n') + } + + return readResolv(&conf) +} + +func (m *resolvconfManager) Close() error { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + + if m.scriptInstalled { + m.logf("removing resolvconf workaround script") + os.Remove(resolvconfHookPath) // Best-effort + } + + return nil +} diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go index 5bd8093d65b7b..c221ca1beaa59 100644 --- a/net/dns/direct_notlinux.go +++ b/net/dns/direct_notlinux.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package dns - -func (m *directManager) runFileWatcher() { - // Not implemented on other platforms. Maybe it could resort to polling. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package dns + +func (m *directManager) runFileWatcher() { + // Not implemented on other platforms. Maybe it could resort to polling. +} diff --git a/net/dns/flush_default.go b/net/dns/flush_default.go index 73e446389e2c7..eb6d9da417104 100644 --- a/net/dns/flush_default.go +++ b/net/dns/flush_default.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package dns - -func flushCaches() error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package dns + +func flushCaches() error { + return nil +} diff --git a/net/dns/ini.go b/net/dns/ini.go index deec04019560f..1e47d606e970f 100644 --- a/net/dns/ini.go +++ b/net/dns/ini.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "regexp" - "strings" -) - -// parseIni parses a basic .ini file, used for wsl.conf. -func parseIni(data string) map[string]map[string]string { - sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) - kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) - - ini := map[string]map[string]string{} - var section string - for _, line := range strings.Split(data, "\n") { - if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { - section = res[1] - ini[section] = map[string]string{} - } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { - k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) - ini[section][k] = v - } - } - return ini -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "regexp" + "strings" +) + +// parseIni parses a basic .ini file, used for wsl.conf. +func parseIni(data string) map[string]map[string]string { + sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) + kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) + + ini := map[string]map[string]string{} + var section string + for _, line := range strings.Split(data, "\n") { + if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { + section = res[1] + ini[section] = map[string]string{} + } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { + k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) + ini[section][k] = v + } + } + return ini +} diff --git a/net/dns/ini_test.go b/net/dns/ini_test.go index 0e9eaa6727bbe..3afe7009caa27 100644 --- a/net/dns/ini_test.go +++ b/net/dns/ini_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "reflect" - "testing" -) - -func TestParseIni(t *testing.T) { - var tests = []struct { - src string - want map[string]map[string]string - }{ - { - src: `# appended wsl.conf file -[automount] - enabled = true - root=/mnt/ -# added by tailscale -[network] # trailing comment -generateResolvConf = false # trailing comment`, - want: map[string]map[string]string{ - "automount": {"enabled": "true", "root": "/mnt/"}, - "network": {"generateResolvConf": "false"}, - }, - }, - } - for _, test := range tests { - got := parseIni(test.src) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "reflect" + "testing" +) + +func TestParseIni(t *testing.T) { + var tests = []struct { + src string + want map[string]map[string]string + }{ + { + src: `# appended wsl.conf file +[automount] + enabled = true + root=/mnt/ +# added by tailscale +[network] # trailing comment +generateResolvConf = false # trailing comment`, + want: map[string]map[string]string{ + "automount": {"enabled": "true", "root": "/mnt/"}, + "network": {"generateResolvConf": "false"}, + }, + }, + } + for _, test := range tests { + got := parseIni(test.src) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) + } + } +} diff --git a/net/dns/noop.go b/net/dns/noop.go index c90162668e85d..9466b57a0f477 100644 --- a/net/dns/noop.go +++ b/net/dns/noop.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -type noopManager struct{} - -func (m noopManager) SetDNS(OSConfig) error { return nil } -func (m noopManager) SupportsSplitDNS() bool { return false } -func (m noopManager) Close() error { return nil } -func (m noopManager) GetBaseConfig() (OSConfig, error) { - return OSConfig{}, ErrGetBaseConfigNotSupported -} - -func NewNoopManager() (noopManager, error) { - return noopManager{}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +type noopManager struct{} + +func (m noopManager) SetDNS(OSConfig) error { return nil } +func (m noopManager) SupportsSplitDNS() bool { return false } +func (m noopManager) Close() error { return nil } +func (m noopManager) GetBaseConfig() (OSConfig, error) { + return OSConfig{}, ErrGetBaseConfigNotSupported +} + +func NewNoopManager() (noopManager, error) { + return noopManager{}, nil +} diff --git a/net/dns/resolvconf-workaround.sh b/net/dns/resolvconf-workaround.sh index 254b3949b1930..aec6708a06da1 100644 --- a/net/dns/resolvconf-workaround.sh +++ b/net/dns/resolvconf-workaround.sh @@ -1,62 +1,62 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -# -# This script is a workaround for a vpn-unfriendly behavior of the -# original resolvconf by Thomas Hood. Unlike the `openresolv` -# implementation (whose binary is also called resolvconf, -# confusingly), the original resolvconf lacks a way to specify -# "exclusive mode" for a provider configuration. In practice, this -# means that if Tailscale wants to install a DNS configuration, that -# config will get "blended" with the configs from other sources, -# rather than override those other sources. -# -# This script gets installed at /etc/resolvconf/update-libc.d, which -# is a directory of hook scripts that get run after resolvconf's libc -# helper has finished rewriting /etc/resolv.conf. It's meant to notify -# consumers of resolv.conf of a new configuration. -# -# Instead, we use that hook mechanism to reach into resolvconf's -# stuff, and rewrite the libc-generated resolv.conf to exclusively -# contain Tailscale's configuration - effectively implementing -# exclusive mode ourselves in post-production. - -set -e - -if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then - # Hook script being invoked by itself, skip. - exit 0 -fi - -if [ ! -f tun-tailscale.inet ]; then - # Tailscale isn't trying to manage DNS, do nothing. - exit 0 -fi - -if ! grep resolvconf /etc/resolv.conf >/dev/null; then - # resolvconf isn't managing /etc/resolv.conf, do nothing. - exit 0 -fi - -# Write out a modified /etc/resolv.conf containing just our config. -( - if [ -f /etc/resolvconf/resolv.conf.d/head ]; then - cat /etc/resolvconf/resolv.conf.d/head - fi - echo "# Tailscale workaround applied to set exclusive DNS configuration." - cat tun-tailscale.inet - if [ -f /etc/resolvconf/resolv.conf.d/base ]; then - # Keep options and sortlist, discard other base things since - # they're the things we're trying to override. - grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true - fi - if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then - cat /etc/resolvconf/resolv.conf.d/tail - fi -) >/etc/resolv.conf - -if [ -d /etc/resolvconf/update-libc.d ] ; then - # Re-notify libc watchers that we've changed resolv.conf again. - export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 - exec run-parts /etc/resolvconf/update-libc.d -fi +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +# +# This script is a workaround for a vpn-unfriendly behavior of the +# original resolvconf by Thomas Hood. Unlike the `openresolv` +# implementation (whose binary is also called resolvconf, +# confusingly), the original resolvconf lacks a way to specify +# "exclusive mode" for a provider configuration. In practice, this +# means that if Tailscale wants to install a DNS configuration, that +# config will get "blended" with the configs from other sources, +# rather than override those other sources. +# +# This script gets installed at /etc/resolvconf/update-libc.d, which +# is a directory of hook scripts that get run after resolvconf's libc +# helper has finished rewriting /etc/resolv.conf. It's meant to notify +# consumers of resolv.conf of a new configuration. +# +# Instead, we use that hook mechanism to reach into resolvconf's +# stuff, and rewrite the libc-generated resolv.conf to exclusively +# contain Tailscale's configuration - effectively implementing +# exclusive mode ourselves in post-production. + +set -e + +if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then + # Hook script being invoked by itself, skip. + exit 0 +fi + +if [ ! -f tun-tailscale.inet ]; then + # Tailscale isn't trying to manage DNS, do nothing. + exit 0 +fi + +if ! grep resolvconf /etc/resolv.conf >/dev/null; then + # resolvconf isn't managing /etc/resolv.conf, do nothing. + exit 0 +fi + +# Write out a modified /etc/resolv.conf containing just our config. +( + if [ -f /etc/resolvconf/resolv.conf.d/head ]; then + cat /etc/resolvconf/resolv.conf.d/head + fi + echo "# Tailscale workaround applied to set exclusive DNS configuration." + cat tun-tailscale.inet + if [ -f /etc/resolvconf/resolv.conf.d/base ]; then + # Keep options and sortlist, discard other base things since + # they're the things we're trying to override. + grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true + fi + if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then + cat /etc/resolvconf/resolv.conf.d/tail + fi +) >/etc/resolv.conf + +if [ -d /etc/resolvconf/update-libc.d ] ; then + # Re-notify libc watchers that we've changed resolv.conf again. + export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 + exec run-parts /etc/resolvconf/update-libc.d +fi diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go index 9e2a41c4ac45b..ca584ffcc5f1f 100644 --- a/net/dns/resolvconf.go +++ b/net/dns/resolvconf.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bytes" - "os/exec" -) - -func resolvconfStyle() string { - if _, err := exec.LookPath("resolvconf"); err != nil { - return "" - } - output, err := exec.Command("resolvconf", "--version").CombinedOutput() - if err != nil { - // Debian resolvconf doesn't understand --version, and - // exits with a specific error code. - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { - return "debian" - } - } - if bytes.HasPrefix(output, []byte("Debian resolvconf")) { - return "debian" - } - // Treat everything else as openresolv, by far the more popular implementation. - return "openresolv" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bytes" + "os/exec" +) + +func resolvconfStyle() string { + if _, err := exec.LookPath("resolvconf"); err != nil { + return "" + } + output, err := exec.Command("resolvconf", "--version").CombinedOutput() + if err != nil { + // Debian resolvconf doesn't understand --version, and + // exits with a specific error code. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { + return "debian" + } + } + if bytes.HasPrefix(output, []byte("Debian resolvconf")) { + return "debian" + } + // Treat everything else as openresolv, by far the more popular implementation. + return "openresolv" +} diff --git a/net/dns/resolvconffile/resolvconffile.go b/net/dns/resolvconffile/resolvconffile.go index 66c1600d8ecba..753000f6d33da 100644 --- a/net/dns/resolvconffile/resolvconffile.go +++ b/net/dns/resolvconffile/resolvconffile.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package resolvconffile parses & serializes /etc/resolv.conf-style files. -// -// It's a leaf package so both net/dns and net/dns/resolver can depend -// on it and we can unify a handful of implementations. -// -// The package is verbosely named to disambiguate it from resolvconf -// the daemon, which Tailscale also supports. -package resolvconffile - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "os" - "strings" - - "tailscale.com/util/dnsname" -) - -// Path is the canonical location of resolv.conf. -const Path = "/etc/resolv.conf" - -// Config represents a resolv.conf(5) file. -type Config struct { - // Nameservers are the IP addresses of the nameservers to use. - Nameservers []netip.Addr - - // SearchDomains are the domain suffixes to use when expanding - // single-label name queries. SearchDomains is additive to - // whatever non-Tailscale search domains the OS has. - SearchDomains []dnsname.FQDN -} - -// Write writes c to w. It does so in one Write call. -func (c *Config) Write(w io.Writer) error { - buf := new(bytes.Buffer) - io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") - io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") - io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") - for _, ns := range c.Nameservers { - io.WriteString(buf, "nameserver ") - io.WriteString(buf, ns.String()) - io.WriteString(buf, "\n") - } - if len(c.SearchDomains) > 0 { - io.WriteString(buf, "search") - for _, domain := range c.SearchDomains { - io.WriteString(buf, " ") - io.WriteString(buf, domain.WithoutTrailingDot()) - } - io.WriteString(buf, "\n") - } - _, err := w.Write(buf.Bytes()) - return err -} - -// Parse parses a resolv.conf file from r. -func Parse(r io.Reader) (*Config, error) { - config := new(Config) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - line, _, _ = strings.Cut(line, "#") // remove any comments - line = strings.TrimSpace(line) - - if s, ok := strings.CutPrefix(line, "nameserver"); ok { - nameserver := strings.TrimSpace(s) - if len(nameserver) == len(s) { - return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) - } - ip, err := netip.ParseAddr(nameserver) - if err != nil { - return nil, err - } - config.Nameservers = append(config.Nameservers, ip) - continue - } - - if s, ok := strings.CutPrefix(line, "search"); ok { - domains := strings.TrimSpace(s) - if len(domains) == len(s) { - // No leading space?! - return nil, fmt.Errorf("missing space after \"search\" in %q", line) - } - for len(domains) > 0 { - domain := domains - i := strings.IndexAny(domain, " \t") - if i != -1 { - domain = domain[:i] - domains = strings.TrimSpace(domains[i+1:]) - } else { - domains = "" - } - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) - } - config.SearchDomains = append(config.SearchDomains, fqdn) - } - } - } - return config, nil -} - -// ParseFile parses the named resolv.conf file. -func ParseFile(name string) (*Config, error) { - fi, err := os.Stat(name) - if err != nil { - return nil, err - } - if n := fi.Size(); n > 10<<10 { - return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) - } - all, err := os.ReadFile(name) - if err != nil { - return nil, err - } - return Parse(bytes.NewReader(all)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package resolvconffile parses & serializes /etc/resolv.conf-style files. +// +// It's a leaf package so both net/dns and net/dns/resolver can depend +// on it and we can unify a handful of implementations. +// +// The package is verbosely named to disambiguate it from resolvconf +// the daemon, which Tailscale also supports. +package resolvconffile + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "os" + "strings" + + "tailscale.com/util/dnsname" +) + +// Path is the canonical location of resolv.conf. +const Path = "/etc/resolv.conf" + +// Config represents a resolv.conf(5) file. +type Config struct { + // Nameservers are the IP addresses of the nameservers to use. + Nameservers []netip.Addr + + // SearchDomains are the domain suffixes to use when expanding + // single-label name queries. SearchDomains is additive to + // whatever non-Tailscale search domains the OS has. + SearchDomains []dnsname.FQDN +} + +// Write writes c to w. It does so in one Write call. +func (c *Config) Write(w io.Writer) error { + buf := new(bytes.Buffer) + io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") + io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") + io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") + for _, ns := range c.Nameservers { + io.WriteString(buf, "nameserver ") + io.WriteString(buf, ns.String()) + io.WriteString(buf, "\n") + } + if len(c.SearchDomains) > 0 { + io.WriteString(buf, "search") + for _, domain := range c.SearchDomains { + io.WriteString(buf, " ") + io.WriteString(buf, domain.WithoutTrailingDot()) + } + io.WriteString(buf, "\n") + } + _, err := w.Write(buf.Bytes()) + return err +} + +// Parse parses a resolv.conf file from r. +func Parse(r io.Reader) (*Config, error) { + config := new(Config) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + line, _, _ = strings.Cut(line, "#") // remove any comments + line = strings.TrimSpace(line) + + if s, ok := strings.CutPrefix(line, "nameserver"); ok { + nameserver := strings.TrimSpace(s) + if len(nameserver) == len(s) { + return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) + } + ip, err := netip.ParseAddr(nameserver) + if err != nil { + return nil, err + } + config.Nameservers = append(config.Nameservers, ip) + continue + } + + if s, ok := strings.CutPrefix(line, "search"); ok { + domains := strings.TrimSpace(s) + if len(domains) == len(s) { + // No leading space?! + return nil, fmt.Errorf("missing space after \"search\" in %q", line) + } + for len(domains) > 0 { + domain := domains + i := strings.IndexAny(domain, " \t") + if i != -1 { + domain = domain[:i] + domains = strings.TrimSpace(domains[i+1:]) + } else { + domains = "" + } + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) + } + config.SearchDomains = append(config.SearchDomains, fqdn) + } + } + } + return config, nil +} + +// ParseFile parses the named resolv.conf file. +func ParseFile(name string) (*Config, error) { + fi, err := os.Stat(name) + if err != nil { + return nil, err + } + if n := fi.Size(); n > 10<<10 { + return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) + } + all, err := os.ReadFile(name) + if err != nil { + return nil, err + } + return Parse(bytes.NewReader(all)) +} diff --git a/net/dns/resolvconfpath_default.go b/net/dns/resolvconfpath_default.go index 02f24a0cfa535..57e82c4c773ea 100644 --- a/net/dns/resolvconfpath_default.go +++ b/net/dns/resolvconfpath_default.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !gokrazy - -package dns - -const ( - resolvConf = "/etc/resolv.conf" - backupConf = "/etc/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !gokrazy + +package dns + +const ( + resolvConf = "/etc/resolv.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolvconfpath_gokrazy.go b/net/dns/resolvconfpath_gokrazy.go index 6315596d20efa..f0759b0e31a0f 100644 --- a/net/dns/resolvconfpath_gokrazy.go +++ b/net/dns/resolvconfpath_gokrazy.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build gokrazy - -package dns - -const ( - resolvConf = "/tmp/resolv.conf" - backupConf = "/tmp/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build gokrazy + +package dns + +const ( + resolvConf = "/tmp/resolv.conf" + backupConf = "/tmp/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go index d9ef970c224f2..a9c28476166fc 100644 --- a/net/dns/resolver/doh_test.go +++ b/net/dns/resolver/doh_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "context" - "flag" - "net/http" - "testing" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/net/dns/publicdns" -) - -var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") - -const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 - -func someDNSQuestion(t testing.TB) []byte { - b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - OpCode: 0, // query - RecursionDesired: true, - ID: someDNSID, - }) - b.StartQuestions() // err - b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName("tailscale.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - msg, err := b.Finish() - if err != nil { - t.Fatal(err) - } - return msg -} - -func TestDoH(t *testing.T) { - if !*testDoH { - t.Skip("skipping manual test without --test-doh flag") - } - prefixes := publicdns.KnownDoHPrefixes() - if len(prefixes) == 0 { - t.Fatal("no known DoH") - } - - f := &forwarder{} - - for _, urlBase := range prefixes { - t.Run(urlBase, func(t *testing.T) { - c, ok := f.getKnownDoHClientForProvider(urlBase) - if !ok { - t.Fatal("expected DoH") - } - res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) - if err != nil { - t.Fatal(err) - } - c.Transport.(*http.Transport).CloseIdleConnections() - - var p dnsmessage.Parser - h, err := p.Start(res) - if err != nil { - t.Fatal(err) - } - if h.ID != someDNSID { - t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) - } - - p.SkipAllQuestions() - aa, err := p.AllAnswers() - if err != nil { - t.Fatal(err) - } - if len(aa) == 0 { - t.Fatal("no answers") - } - for _, r := range aa { - t.Logf("got: %v", r.GoString()) - } - }) - } -} - -func TestDoHV6Fallback(t *testing.T) { - for _, base := range publicdns.KnownDoHPrefixes() { - for _, ip := range publicdns.DoHIPsOfBase(base) { - if ip.Is4() { - ip6, ok := publicdns.DoHV6(base) - if !ok { - t.Errorf("no v6 DoH known for %v", ip) - } else if !ip6.Is6() { - t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) - } - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "context" + "flag" + "net/http" + "testing" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/net/dns/publicdns" +) + +var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") + +const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 + +func someDNSQuestion(t testing.TB) []byte { + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ + OpCode: 0, // query + RecursionDesired: true, + ID: someDNSID, + }) + b.StartQuestions() // err + b.Question(dnsmessage.Question{ + Name: dnsmessage.MustNewName("tailscale.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + msg, err := b.Finish() + if err != nil { + t.Fatal(err) + } + return msg +} + +func TestDoH(t *testing.T) { + if !*testDoH { + t.Skip("skipping manual test without --test-doh flag") + } + prefixes := publicdns.KnownDoHPrefixes() + if len(prefixes) == 0 { + t.Fatal("no known DoH") + } + + f := &forwarder{} + + for _, urlBase := range prefixes { + t.Run(urlBase, func(t *testing.T) { + c, ok := f.getKnownDoHClientForProvider(urlBase) + if !ok { + t.Fatal("expected DoH") + } + res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) + if err != nil { + t.Fatal(err) + } + c.Transport.(*http.Transport).CloseIdleConnections() + + var p dnsmessage.Parser + h, err := p.Start(res) + if err != nil { + t.Fatal(err) + } + if h.ID != someDNSID { + t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) + } + + p.SkipAllQuestions() + aa, err := p.AllAnswers() + if err != nil { + t.Fatal(err) + } + if len(aa) == 0 { + t.Fatal("no answers") + } + for _, r := range aa { + t.Logf("got: %v", r.GoString()) + } + }) + } +} + +func TestDoHV6Fallback(t *testing.T) { + for _, base := range publicdns.KnownDoHPrefixes() { + for _, ip := range publicdns.DoHIPsOfBase(base) { + if ip.Is4() { + ip6, ok := publicdns.DoHV6(base) + if !ok { + t.Errorf("no v6 DoH known for %v", ip) + } else if !ip6.Is6() { + t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) + } + } + } + } +} diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go index 37cccc7f0c7ba..e3f979c194d91 100644 --- a/net/dns/resolver/macios_ext.go +++ b/net/dns/resolver/macios_ext.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_macext && (darwin || ios) - -package resolver - -import ( - "errors" - "net" - - "tailscale.com/net/netmon" - "tailscale.com/net/netns" -) - -func init() { - initListenConfig = initListenConfigNetworkExtension -} - -func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { - nif, ok := netMon.InterfaceState().Interface[tunName] - if !ok { - return errors.New("utun not found") - } - return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_macext && (darwin || ios) + +package resolver + +import ( + "errors" + "net" + + "tailscale.com/net/netmon" + "tailscale.com/net/netns" +) + +func init() { + initListenConfig = initListenConfigNetworkExtension +} + +func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { + nif, ok := netMon.InterfaceState().Interface[tunName] + if !ok { + return errors.New("utun not found") + } + return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) +} diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index be47cdfbcf913..82fd3bebf232c 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -1,333 +1,333 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "fmt" - "net" - "net/netip" - "strings" - "testing" - - "github.com/miekg/dns" -) - -// This file exists to isolate the test infrastructure -// that depends on github.com/miekg/dns -// from the rest, which only depends on dnsmessage. - -// resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToIPLowercase returns a handler function which canonicalizes responses -// by lowercasing the question and answer names, and responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.Question[0].Name = strings.ToLower(m.Question[0].Name) - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToTXT returns a handler function which responds to queries of type TXT -// it receives with the strings in txts. -func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - if question.Qtype != dns.TypeTXT { - w.WriteMsg(m) - return - } - - ans := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - Txt: txts, - } - - m.Answer = append(m.Answer, ans) - - queryInfo := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: "query-info.test.", - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - } - - if edns := req.IsEdns0(); edns == nil { - queryInfo.Txt = []string{"EDNS=false"} - } else { - queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} - } - - m.Extra = append(m.Extra, queryInfo) - - if ednsMaxSize > 0 { - m.SetEdns0(ednsMaxSize, false) - } - - if err := w.WriteMsg(m); err != nil { - panic(err) - } - } -} - -var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNameError) - w.WriteMsg(m) -}) - -// weirdoGoCNAMEHandler returns a DNS handler that satisfies -// Go's weird Resolver.LookupCNAME (read its godoc carefully!). -// -// This doesn't even return a CNAME record, because that's not -// what Go looks for. -func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - question := req.Question[0] - - switch question.Qtype { - case dns.TypeA: - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - }, - Target: target, - }) - case dns.TypeAAAA: - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 600, - }, - AAAA: net.ParseIP("1::2"), - }) - } - w.WriteMsg(m) - } -} - -// dnsHandler returns a handler that replies with the answers/options -// provided. -// -// Types supported: netip.Addr. -func dnsHandler(answers ...any) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies - - question := req.Question[0] - for _, a := range answers { - switch a := a.(type) { - default: - panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) - case netip.Addr: - ip := a - if ip.Is4() { - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ip.AsSlice(), - }) - } else if ip.Is6() { - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ip.AsSlice(), - }) - } - case dns.PTR: - ptr := a - ptr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &ptr) - case dns.CNAME: - c := a - c.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - } - m.Answer = append(m.Answer, &c) - case dns.TXT: - txt := a - txt.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &txt) - case dns.SRV: - srv := a - srv.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &srv) - case dns.NS: - rr := a - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &rr) - } - } - w.WriteMsg(m) - } -} - -func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { - if len(records)%2 != 0 { - panic("must have an even number of record values") - } - mux := dns.NewServeMux() - for i := 0; i < len(records); i += 2 { - name := records[i].(string) - handler := records[i+1].(dns.Handler) - mux.Handle(name, handler) - } - waitch := make(chan struct{}) - server := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - NotifyStartedFunc: func() { close(waitch) }, - ReusePort: true, - } - - go func() { - err := server.ListenAndServe() - if err != nil { - panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) - } - }() - - <-waitch - return server -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "fmt" + "net" + "net/netip" + "strings" + "testing" + + "github.com/miekg/dns" +) + +// This file exists to isolate the test infrastructure +// that depends on github.com/miekg/dns +// from the rest, which only depends on dnsmessage. + +// resolveToIP returns a handler function which responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToIPLowercase returns a handler function which canonicalizes responses +// by lowercasing the question and answer names, and responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.Question[0].Name = strings.ToLower(m.Question[0].Name) + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToTXT returns a handler function which responds to queries of type TXT +// it receives with the strings in txts. +func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + if question.Qtype != dns.TypeTXT { + w.WriteMsg(m) + return + } + + ans := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + Txt: txts, + } + + m.Answer = append(m.Answer, ans) + + queryInfo := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: "query-info.test.", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + } + + if edns := req.IsEdns0(); edns == nil { + queryInfo.Txt = []string{"EDNS=false"} + } else { + queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} + } + + m.Extra = append(m.Extra, queryInfo) + + if ednsMaxSize > 0 { + m.SetEdns0(ednsMaxSize, false) + } + + if err := w.WriteMsg(m); err != nil { + panic(err) + } + } +} + +var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeNameError) + w.WriteMsg(m) +}) + +// weirdoGoCNAMEHandler returns a DNS handler that satisfies +// Go's weird Resolver.LookupCNAME (read its godoc carefully!). +// +// This doesn't even return a CNAME record, because that's not +// what Go looks for. +func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + question := req.Question[0] + + switch question.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + }, + Target: target, + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 600, + }, + AAAA: net.ParseIP("1::2"), + }) + } + w.WriteMsg(m) + } +} + +// dnsHandler returns a handler that replies with the answers/options +// provided. +// +// Types supported: netip.Addr. +func dnsHandler(answers ...any) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies + + question := req.Question[0] + for _, a := range answers { + switch a := a.(type) { + default: + panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) + case netip.Addr: + ip := a + if ip.Is4() { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ip.AsSlice(), + }) + } else if ip.Is6() { + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ip.AsSlice(), + }) + } + case dns.PTR: + ptr := a + ptr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &ptr) + case dns.CNAME: + c := a + c.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + } + m.Answer = append(m.Answer, &c) + case dns.TXT: + txt := a + txt.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &txt) + case dns.SRV: + srv := a + srv.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &srv) + case dns.NS: + rr := a + rr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &rr) + } + } + w.WriteMsg(m) + } +} + +func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { + if len(records)%2 != 0 { + panic("must have an even number of record values") + } + mux := dns.NewServeMux() + for i := 0; i < len(records); i += 2 { + name := records[i].(string) + handler := records[i+1].(dns.Handler) + mux.Handle(name, handler) + } + waitch := make(chan struct{}) + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: mux, + NotifyStartedFunc: func() { close(waitch) }, + ReusePort: true, + } + + go func() { + err := server.ListenAndServe() + if err != nil { + panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) + } + }() + + <-waitch + return server +} diff --git a/net/dns/utf.go b/net/dns/utf.go index 267829c05fbfa..0c1db69acb33b 100644 --- a/net/dns/utf.go +++ b/net/dns/utf.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -// This code is only used in Windows builds, but is in an -// OS-independent file so tests can run all the time. - -import ( - "bytes" - "encoding/binary" - "unicode/utf16" -) - -// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so -// translates it to regular UTF-8. -// -// Some of wsl.exe's output get printed as UTF-16, which breaks a -// bunch of things. Try to detect this by looking for a zero byte in -// the first few bytes of output (which will appear if any of those -// codepoints are basic ASCII - very likely). From that we can infer -// that UTF-16 is being printed, and the byte order in use, and we -// decode that back to UTF-8. -// -// https://github.com/microsoft/WSL/issues/4607 -func maybeUnUTF16(bs []byte) []byte { - if len(bs)%2 != 0 { - // Can't be complete UTF-16. - return bs - } - checkLen := 20 - if len(bs) < checkLen { - checkLen = len(bs) - } - zeroOff := bytes.IndexByte(bs[:checkLen], 0) - if zeroOff == -1 { - return bs - } - - // We assume wsl.exe is trying to print an ASCII codepoint, - // meaning the zero byte is in the upper 8 bits of the - // codepoint. That means we can use the zero's byte offset to - // work out if we're seeing little-endian or big-endian - // UTF-16. - var endian binary.ByteOrder = binary.LittleEndian - if zeroOff%2 == 0 { - endian = binary.BigEndian - } - - var u16 []uint16 - for i := 0; i < len(bs); i += 2 { - u16 = append(u16, endian.Uint16(bs[i:])) - } - return []byte(string(utf16.Decode(u16))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +// This code is only used in Windows builds, but is in an +// OS-independent file so tests can run all the time. + +import ( + "bytes" + "encoding/binary" + "unicode/utf16" +) + +// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so +// translates it to regular UTF-8. +// +// Some of wsl.exe's output get printed as UTF-16, which breaks a +// bunch of things. Try to detect this by looking for a zero byte in +// the first few bytes of output (which will appear if any of those +// codepoints are basic ASCII - very likely). From that we can infer +// that UTF-16 is being printed, and the byte order in use, and we +// decode that back to UTF-8. +// +// https://github.com/microsoft/WSL/issues/4607 +func maybeUnUTF16(bs []byte) []byte { + if len(bs)%2 != 0 { + // Can't be complete UTF-16. + return bs + } + checkLen := 20 + if len(bs) < checkLen { + checkLen = len(bs) + } + zeroOff := bytes.IndexByte(bs[:checkLen], 0) + if zeroOff == -1 { + return bs + } + + // We assume wsl.exe is trying to print an ASCII codepoint, + // meaning the zero byte is in the upper 8 bits of the + // codepoint. That means we can use the zero's byte offset to + // work out if we're seeing little-endian or big-endian + // UTF-16. + var endian binary.ByteOrder = binary.LittleEndian + if zeroOff%2 == 0 { + endian = binary.BigEndian + } + + var u16 []uint16 + for i := 0; i < len(bs); i += 2 { + u16 = append(u16, endian.Uint16(bs[i:])) + } + return []byte(string(utf16.Decode(u16))) +} diff --git a/net/dns/utf_test.go b/net/dns/utf_test.go index fcf593497e08b..b5fd372622519 100644 --- a/net/dns/utf_test.go +++ b/net/dns/utf_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -import "testing" - -func TestMaybeUnUTF16(t *testing.T) { - tests := []struct { - in string - want string - }{ - {"abc", "abc"}, // UTF-8 - {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE - {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE - } - - for _, test := range tests { - got := string(maybeUnUTF16([]byte(test.in))) - if got != test.want { - t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import "testing" + +func TestMaybeUnUTF16(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"abc", "abc"}, // UTF-8 + {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE + {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE + } + + for _, test := range tests { + got := string(maybeUnUTF16([]byte(test.in))) + if got != test.want { + t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 6a4b969315050..ef4249b7401f3 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "net/netip" - "reflect" - "testing" - "time" - - "tailscale.com/tstest" -) - -var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") - -func TestDialer(t *testing.T) { - if *dialTest == "" { - t.Skip("skipping; --dial-test is blank") - } - r := &Resolver{Logf: t.Logf} - var std net.Dialer - dialer := Dialer(std.DialContext, r) - t0 := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - c, err := dialer(ctx, "tcp", *dialTest) - if err != nil { - t.Fatal(err) - } - t.Logf("dialed in %v", time.Since(t0)) - c.Close() -} - -func TestDialCall_DNSWasTrustworthy(t *testing.T) { - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - tests := []struct { - name string - steps []step - want bool - }{ - { - name: "no-info", - want: false, - }, - { - name: "previous-dial", - steps: []step{ - {mustIP("2003::1"), nil}, - {mustIP("2003::1"), errFail}, - }, - want: true, - }, - { - name: "no-previous-dial", - steps: []step{ - {mustIP("2003::1"), errFail}, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - dc := &dialCall{ - d: d, - } - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := dc.dnsWasTrustworthy() - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} - -func TestDialCall_uniqueIPs(t *testing.T) { - dc := &dialCall{} - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - dc.noteDialResult(mustIP("2003::1"), errFail) - dc.noteDialResult(mustIP("2003::2"), errFail) - got := dc.uniqueIPs([]netip.Addr{ - mustIP("2003::1"), - mustIP("2003::2"), - mustIP("2003::2"), - mustIP("2003::3"), - mustIP("2003::3"), - mustIP("2003::4"), - mustIP("2003::4"), - }) - want := []netip.Addr{ - mustIP("2003::3"), - mustIP("2003::4"), - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } -} - -func TestResolverAllHostStaticResult(t *testing.T) { - r := &Resolver{ - Logf: t.Logf, - SingleHost: "foo.bar", - SingleHostStaticResult: []netip.Addr{ - netip.MustParseAddr("2001:4860:4860::8888"), - netip.MustParseAddr("2001:4860:4860::8844"), - netip.MustParseAddr("8.8.8.8"), - netip.MustParseAddr("8.8.4.4"), - }, - } - ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") - if err != nil { - t.Fatal(err) - } - if got, want := ip4.String(), "8.8.8.8"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { - t.Errorf("allIPs got %q; want %q", got, want) - } - - _, _, _, err = r.LookupIP(context.Background(), "bad") - if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { - t.Errorf("bad dial error got %q; want %q", got, want) - } -} - -func TestShouldTryBootstrap(t *testing.T) { - tstest.Replace(t, &debug, func() bool { return true }) - - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - - canceled, cancel := context.WithCancel(context.Background()) - cancel() - - deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - - ctx := context.Background() - errFailed := errors.New("some failure") - - cacheWithFallback := &Resolver{ - Logf: t.Logf, - LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { - panic("unimplemented") - }, - } - cacheNoFallback := &Resolver{Logf: t.Logf} - - testCases := []struct { - name string - steps []step - ctx context.Context - err error - noFallback bool - want bool - }{ - { - name: "no-error", - ctx: ctx, - err: nil, - want: false, - }, - { - name: "canceled", - ctx: canceled, - err: errFailed, - want: false, - }, - { - name: "deadline-exceeded", - ctx: deadlineExceeded, - err: errFailed, - want: false, - }, - { - name: "no-fallback", - ctx: ctx, - err: errFailed, - noFallback: true, - want: false, - }, - { - name: "dns-was-trustworthy", - ctx: ctx, - err: errFailed, - steps: []step{ - {netip.MustParseAddr("2003::1"), nil}, - {netip.MustParseAddr("2003::1"), errFailed}, - }, - want: false, - }, - { - name: "should-bootstrap", - ctx: ctx, - err: errFailed, - want: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - if tt.noFallback { - d.dnsCache = cacheNoFallback - } else { - d.dnsCache = cacheWithFallback - } - dc := &dialCall{d: d} - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "net/netip" + "reflect" + "testing" + "time" + + "tailscale.com/tstest" +) + +var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") + +func TestDialer(t *testing.T) { + if *dialTest == "" { + t.Skip("skipping; --dial-test is blank") + } + r := &Resolver{Logf: t.Logf} + var std net.Dialer + dialer := Dialer(std.DialContext, r) + t0 := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + c, err := dialer(ctx, "tcp", *dialTest) + if err != nil { + t.Fatal(err) + } + t.Logf("dialed in %v", time.Since(t0)) + c.Close() +} + +func TestDialCall_DNSWasTrustworthy(t *testing.T) { + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + tests := []struct { + name string + steps []step + want bool + }{ + { + name: "no-info", + want: false, + }, + { + name: "previous-dial", + steps: []step{ + {mustIP("2003::1"), nil}, + {mustIP("2003::1"), errFail}, + }, + want: true, + }, + { + name: "no-previous-dial", + steps: []step{ + {mustIP("2003::1"), errFail}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + dc := &dialCall{ + d: d, + } + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := dc.dnsWasTrustworthy() + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func TestDialCall_uniqueIPs(t *testing.T) { + dc := &dialCall{} + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + dc.noteDialResult(mustIP("2003::1"), errFail) + dc.noteDialResult(mustIP("2003::2"), errFail) + got := dc.uniqueIPs([]netip.Addr{ + mustIP("2003::1"), + mustIP("2003::2"), + mustIP("2003::2"), + mustIP("2003::3"), + mustIP("2003::3"), + mustIP("2003::4"), + mustIP("2003::4"), + }) + want := []netip.Addr{ + mustIP("2003::3"), + mustIP("2003::4"), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + +func TestResolverAllHostStaticResult(t *testing.T) { + r := &Resolver{ + Logf: t.Logf, + SingleHost: "foo.bar", + SingleHostStaticResult: []netip.Addr{ + netip.MustParseAddr("2001:4860:4860::8888"), + netip.MustParseAddr("2001:4860:4860::8844"), + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("8.8.4.4"), + }, + } + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") + if err != nil { + t.Fatal(err) + } + if got, want := ip4.String(), "8.8.8.8"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { + t.Errorf("allIPs got %q; want %q", got, want) + } + + _, _, _, err = r.LookupIP(context.Background(), "bad") + if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { + t.Errorf("bad dial error got %q; want %q", got, want) + } +} + +func TestShouldTryBootstrap(t *testing.T) { + tstest.Replace(t, &debug, func() bool { return true }) + + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + + canceled, cancel := context.WithCancel(context.Background()) + cancel() + + deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + ctx := context.Background() + errFailed := errors.New("some failure") + + cacheWithFallback := &Resolver{ + Logf: t.Logf, + LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { + panic("unimplemented") + }, + } + cacheNoFallback := &Resolver{Logf: t.Logf} + + testCases := []struct { + name string + steps []step + ctx context.Context + err error + noFallback bool + want bool + }{ + { + name: "no-error", + ctx: ctx, + err: nil, + want: false, + }, + { + name: "canceled", + ctx: canceled, + err: errFailed, + want: false, + }, + { + name: "deadline-exceeded", + ctx: deadlineExceeded, + err: errFailed, + want: false, + }, + { + name: "no-fallback", + ctx: ctx, + err: errFailed, + noFallback: true, + want: false, + }, + { + name: "dns-was-trustworthy", + ctx: ctx, + err: errFailed, + steps: []step{ + {netip.MustParseAddr("2003::1"), nil}, + {netip.MustParseAddr("2003::1"), errFailed}, + }, + want: false, + }, + { + name: "should-bootstrap", + ctx: ctx, + err: errFailed, + want: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + if tt.noFallback { + d.dnsCache = cacheNoFallback + } else { + d.dnsCache = cacheWithFallback + } + dc := &dialCall{d: d} + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go index 18af324597a43..41fc334483f78 100644 --- a/net/dnscache/messagecache_test.go +++ b/net/dnscache/messagecache_test.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "runtime" - "testing" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/tstest" -) - -func TestMessageCache(t *testing.T) { - clock := tstest.NewClock(tstest.ClockOpts{ - Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), - }) - mc := &MessageCache{Clock: clock.Now} - mc.SetMaxCacheSize(2) - clock.Advance(time.Second) - - var out bytes.Buffer - if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("unexpected error: %v", err) - } - - if err := mc.AddCacheEntry( - makeQ(2, "foo.com."), - makeRes(2, "FOO.COM.", ttlOpt(10), - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { - t.Fatal(err) - } - - // Expect cache hit, with 10 seconds remaining. - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { - t.Errorf("TxID = %v; want %v", p.TxID, 3) - } else if p.TTL != 10 { - t.Errorf("TTL = %v; want 10", p.TTL) - } - - // One second elapses, expect a cache hit, with 9 seconds - // remaining. - clock.Advance(time.Second) - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { - t.Errorf("TxID = %v; want %v", p.TxID, 4) - } else if p.TTL != 9 { - t.Errorf("TTL = %v; want 9", p.TTL) - } - - // Expect cache miss on MX record. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on MX; got: %v", err) - } - // Expect cache miss on CHAOS class. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on CHAOS; got: %v", err) - } - - // Ten seconds elapses; expect a cache miss. - clock.Advance(10 * time.Second) - if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("expected cache miss, got: %v", err) - } -} - -type parsedMeta struct { - TxID uint16 - TTL uint32 -} - -func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { - t.Helper() - var p dnsmessage.Parser - h, err := p.Start(r) - if err != nil { - t.Fatal(err) - } - ret.TxID = h.ID - qq, err := p.AllQuestions() - if err != nil { - t.Fatalf("AllQuestions: %v", err) - } - if len(qq) != 1 { - t.Fatalf("num questions = %v; want 1", len(qq)) - } - aa, err := p.AllAnswers() - if err != nil { - t.Fatalf("AllAnswers: %v", err) - } - for _, r := range aa { - if ret.TTL == 0 { - ret.TTL = r.Header.TTL - } - if ret.TTL != r.Header.TTL { - t.Fatal("mixed TTLs") - } - } - return ret -} - -type responseOpt bool - -type ttlOpt uint32 - -func makeQ(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(false)) - return makeDNSPkt(txID, name, opt...) -} - -func makeRes(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(true)) - return makeDNSPkt(txID, name, opt...) -} - -func makeDNSPkt(txID uint16, name string, opt ...any) []byte { - typ := dnsmessage.TypeA - class := dnsmessage.ClassINET - var response bool - var answers []dnsmessage.ResourceBody - var ttl uint32 = 1 // one second by default - for _, o := range opt { - switch o := o.(type) { - case dnsmessage.Type: - typ = o - case dnsmessage.Class: - class = o - case responseOpt: - response = bool(o) - case dnsmessage.ResourceBody: - answers = append(answers, o) - case ttlOpt: - ttl = uint32(o) - default: - panic(fmt.Sprintf("unknown opt type %T", o)) - } - } - qname := dnsmessage.MustNewName(name) - msg := dnsmessage.Message{ - Header: dnsmessage.Header{ID: txID, Response: response}, - Questions: []dnsmessage.Question{ - { - Name: qname, - Type: typ, - Class: class, - }, - }, - } - for _, rb := range answers { - msg.Answers = append(msg.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: qname, - Type: typ, - Class: class, - TTL: ttl, - }, - Body: rb, - }) - } - buf, err := msg.Pack() - if err != nil { - panic(err) - } - return buf -} - -func TestASCIILowerName(t *testing.T) { - n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) - if got, want := n.String(), "foo.com."; got != want { - t.Errorf("got = %q; want %q", got, want) - } -} - -func TestGetDNSQueryCacheKey(t *testing.T) { - tests := []struct { - name string - pkt []byte - want msgQ - txID uint16 - anyTX bool - }{ - { - name: "empty", - }, - { - name: "a", - pkt: makeQ(123, "foo.com."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "aaaa", - pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), - want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, - txID: 6, - }, - { - name: "normalize_case", - pkt: makeQ(123, "FoO.CoM."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "ignore_response", - pkt: makeRes(123, "foo.com."), - }, - { - name: "ignore_question_with_answers", - pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), - }, - { - name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle - pkt: getGoNetPacketDNSQuery("from-go.foo."), - want: msgQ{"from-go.foo.", dnsmessage.TypeA}, - anyTX: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) - if !ok { - if tt.txID == 0 && got == (msgQ{}) { - return - } - t.Fatal("failed") - } - if got != tt.want { - t.Errorf("got %+v, want %+v", got, tt.want) - } - if gotTX != tt.txID && !tt.anyTX { - t.Errorf("got tx %v, want %v", gotTX, tt.txID) - } - }) - } -} - -func getGoNetPacketDNSQuery(name string) []byte { - if runtime.GOOS == "windows" { - // On Windows, Go's net.Resolver doesn't use the DNS client. - // See https://github.com/golang/go/issues/33097 which - // was approved but not yet implemented. - // For now just pretend it's implemented to make this test - // pass on Windows with complicated the caller. - return makeQ(123, name) - } - res := make(chan []byte, 1) - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return goResolverConn(res), nil - }, - } - r.LookupIP(context.Background(), "ip4", name) - return <-res -} - -type goResolverConn chan<- []byte - -func (goResolverConn) Close() error { return nil } -func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } -func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } -func (goResolverConn) SetDeadline(t time.Time) error { return nil } -func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } -func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } -func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } -func (c goResolverConn) Write(p []byte) (int, error) { - select { - case c <- p[2:]: // skip 2 byte length for TCP mode DNS query - default: - } - return 0, errors.New("boom") -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "runtime" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/tstest" +) + +func TestMessageCache(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), + }) + mc := &MessageCache{Clock: clock.Now} + mc.SetMaxCacheSize(2) + clock.Advance(time.Second) + + var out bytes.Buffer + if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("unexpected error: %v", err) + } + + if err := mc.AddCacheEntry( + makeQ(2, "foo.com."), + makeRes(2, "FOO.COM.", ttlOpt(10), + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { + t.Fatal(err) + } + + // Expect cache hit, with 10 seconds remaining. + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { + t.Errorf("TxID = %v; want %v", p.TxID, 3) + } else if p.TTL != 10 { + t.Errorf("TTL = %v; want 10", p.TTL) + } + + // One second elapses, expect a cache hit, with 9 seconds + // remaining. + clock.Advance(time.Second) + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { + t.Errorf("TxID = %v; want %v", p.TxID, 4) + } else if p.TTL != 9 { + t.Errorf("TTL = %v; want 9", p.TTL) + } + + // Expect cache miss on MX record. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on MX; got: %v", err) + } + // Expect cache miss on CHAOS class. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on CHAOS; got: %v", err) + } + + // Ten seconds elapses; expect a cache miss. + clock.Advance(10 * time.Second) + if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("expected cache miss, got: %v", err) + } +} + +type parsedMeta struct { + TxID uint16 + TTL uint32 +} + +func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { + t.Helper() + var p dnsmessage.Parser + h, err := p.Start(r) + if err != nil { + t.Fatal(err) + } + ret.TxID = h.ID + qq, err := p.AllQuestions() + if err != nil { + t.Fatalf("AllQuestions: %v", err) + } + if len(qq) != 1 { + t.Fatalf("num questions = %v; want 1", len(qq)) + } + aa, err := p.AllAnswers() + if err != nil { + t.Fatalf("AllAnswers: %v", err) + } + for _, r := range aa { + if ret.TTL == 0 { + ret.TTL = r.Header.TTL + } + if ret.TTL != r.Header.TTL { + t.Fatal("mixed TTLs") + } + } + return ret +} + +type responseOpt bool + +type ttlOpt uint32 + +func makeQ(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(false)) + return makeDNSPkt(txID, name, opt...) +} + +func makeRes(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(true)) + return makeDNSPkt(txID, name, opt...) +} + +func makeDNSPkt(txID uint16, name string, opt ...any) []byte { + typ := dnsmessage.TypeA + class := dnsmessage.ClassINET + var response bool + var answers []dnsmessage.ResourceBody + var ttl uint32 = 1 // one second by default + for _, o := range opt { + switch o := o.(type) { + case dnsmessage.Type: + typ = o + case dnsmessage.Class: + class = o + case responseOpt: + response = bool(o) + case dnsmessage.ResourceBody: + answers = append(answers, o) + case ttlOpt: + ttl = uint32(o) + default: + panic(fmt.Sprintf("unknown opt type %T", o)) + } + } + qname := dnsmessage.MustNewName(name) + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ID: txID, Response: response}, + Questions: []dnsmessage.Question{ + { + Name: qname, + Type: typ, + Class: class, + }, + }, + } + for _, rb := range answers { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: qname, + Type: typ, + Class: class, + TTL: ttl, + }, + Body: rb, + }) + } + buf, err := msg.Pack() + if err != nil { + panic(err) + } + return buf +} + +func TestASCIILowerName(t *testing.T) { + n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) + if got, want := n.String(), "foo.com."; got != want { + t.Errorf("got = %q; want %q", got, want) + } +} + +func TestGetDNSQueryCacheKey(t *testing.T) { + tests := []struct { + name string + pkt []byte + want msgQ + txID uint16 + anyTX bool + }{ + { + name: "empty", + }, + { + name: "a", + pkt: makeQ(123, "foo.com."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "aaaa", + pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), + want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, + txID: 6, + }, + { + name: "normalize_case", + pkt: makeQ(123, "FoO.CoM."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "ignore_response", + pkt: makeRes(123, "foo.com."), + }, + { + name: "ignore_question_with_answers", + pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), + }, + { + name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle + pkt: getGoNetPacketDNSQuery("from-go.foo."), + want: msgQ{"from-go.foo.", dnsmessage.TypeA}, + anyTX: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) + if !ok { + if tt.txID == 0 && got == (msgQ{}) { + return + } + t.Fatal("failed") + } + if got != tt.want { + t.Errorf("got %+v, want %+v", got, tt.want) + } + if gotTX != tt.txID && !tt.anyTX { + t.Errorf("got tx %v, want %v", gotTX, tt.txID) + } + }) + } +} + +func getGoNetPacketDNSQuery(name string) []byte { + if runtime.GOOS == "windows" { + // On Windows, Go's net.Resolver doesn't use the DNS client. + // See https://github.com/golang/go/issues/33097 which + // was approved but not yet implemented. + // For now just pretend it's implemented to make this test + // pass on Windows with complicated the caller. + return makeQ(123, name) + } + res := make(chan []byte, 1) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return goResolverConn(res), nil + }, + } + r.LookupIP(context.Background(), "ip4", name) + return <-res +} + +type goResolverConn chan<- []byte + +func (goResolverConn) Close() error { return nil } +func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } +func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } +func (goResolverConn) SetDeadline(t time.Time) error { return nil } +func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } +func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } +func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } +func (c goResolverConn) Write(p []byte) (int, error) { + select { + case c <- p[2:]: // skip 2 byte length for TCP mode DNS query + default: + } + return 0, errors.New("boom") +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/dnsfallback/update-dns-fallbacks.go b/net/dnsfallback/update-dns-fallbacks.go index ebbfc2ad17409..384e77e104cdc 100644 --- a/net/dnsfallback/update-dns-fallbacks.go +++ b/net/dnsfallback/update-dns-fallbacks.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - - "tailscale.com/tailcfg" -) - -func main() { - res, err := http.Get("https://login.tailscale.com/derpmap/default") - if err != nil { - log.Fatal(err) - } - if res.StatusCode != 200 { - res.Write(os.Stderr) - os.Exit(1) - } - dm := new(tailcfg.DERPMap) - if err := json.NewDecoder(res.Body).Decode(dm); err != nil { - log.Fatal(err) - } - for rid, r := range dm.Regions { - // Names misleading to check into git, as this is a - // static snapshot and doesn't reflect the live DERP - // map. - r.RegionCode = fmt.Sprintf("r%d", rid) - r.RegionName = r.RegionCode - } - out, err := json.MarshalIndent(dm, "", "\t") - if err != nil { - log.Fatal(err) - } - if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + + "tailscale.com/tailcfg" +) + +func main() { + res, err := http.Get("https://login.tailscale.com/derpmap/default") + if err != nil { + log.Fatal(err) + } + if res.StatusCode != 200 { + res.Write(os.Stderr) + os.Exit(1) + } + dm := new(tailcfg.DERPMap) + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + log.Fatal(err) + } + for rid, r := range dm.Regions { + // Names misleading to check into git, as this is a + // static snapshot and doesn't reflect the live DERP + // map. + r.RegionCode = fmt.Sprintf("r%d", rid) + r.RegionName = r.RegionCode + } + out, err := json.MarshalIndent(dm, "", "\t") + if err != nil { + log.Fatal(err) + } + if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { + log.Fatal(err) + } +} diff --git a/net/memnet/conn.go b/net/memnet/conn.go index f599612d93553..a9e1fd39901a0 100644 --- a/net/memnet/conn.go +++ b/net/memnet/conn.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "net/netip" - "time" -) - -// NetworkName is the network name returned by [net.Addr.Network] -// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. -const NetworkName = "mem" - -// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. -type Conn interface { - net.Conn - - // SetReadBlock blocks or unblocks the Read method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetReadBlock(bool) error - - // SetWriteBlock blocks or unblocks the Write method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetWriteBlock(bool) error -} - -// NewConn creates a pair of Conns that are wired together by pipes. -func NewConn(name string, maxBuf int) (Conn, Conn) { - r := NewPipe(name+"|0", maxBuf) - w := NewPipe(name+"|1", maxBuf) - - return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} -} - -// NewTCPConn creates a pair of Conns that are wired together by pipes. -func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { - r := NewPipe(src.String(), maxBuf) - w := NewPipe(dst.String(), maxBuf) - - lAddr := net.TCPAddrFromAddrPort(src) - rAddr := net.TCPAddrFromAddrPort(dst) - - return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} -} - -type connAddr string - -func (a connAddr) Network() string { return NetworkName } -func (a connAddr) String() string { return string(a) } - -type connHalf struct { - local, remote net.Addr - r, w *Pipe -} - -func (c *connHalf) LocalAddr() net.Addr { - if c.local != nil { - return c.local - } - return connAddr(c.r.name) -} - -func (c *connHalf) RemoteAddr() net.Addr { - if c.remote != nil { - return c.remote - } - return connAddr(c.w.name) -} - -func (c *connHalf) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} -func (c *connHalf) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *connHalf) Close() error { - if err := c.w.Close(); err != nil { - return err - } - return c.r.Close() -} - -func (c *connHalf) SetDeadline(t time.Time) error { - err1 := c.SetReadDeadline(t) - err2 := c.SetWriteDeadline(t) - if err1 != nil { - return err1 - } - return err2 -} -func (c *connHalf) SetReadDeadline(t time.Time) error { - return c.r.SetReadDeadline(t) -} -func (c *connHalf) SetWriteDeadline(t time.Time) error { - return c.w.SetWriteDeadline(t) -} - -func (c *connHalf) SetReadBlock(b bool) error { - if b { - return c.r.Block() - } - return c.r.Unblock() -} -func (c *connHalf) SetWriteBlock(b bool) error { - if b { - return c.w.Block() - } - return c.w.Unblock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "net/netip" + "time" +) + +// NetworkName is the network name returned by [net.Addr.Network] +// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. +const NetworkName = "mem" + +// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. +type Conn interface { + net.Conn + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +// NewTCPConn creates a pair of Conns that are wired together by pipes. +func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { + r := NewPipe(src.String(), maxBuf) + w := NewPipe(dst.String(), maxBuf) + + lAddr := net.TCPAddrFromAddrPort(src) + rAddr := net.TCPAddrFromAddrPort(dst) + + return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} +} + +type connAddr string + +func (a connAddr) Network() string { return NetworkName } +func (a connAddr) String() string { return string(a) } + +type connHalf struct { + local, remote net.Addr + r, w *Pipe +} + +func (c *connHalf) LocalAddr() net.Addr { + if c.local != nil { + return c.local + } + return connAddr(c.r.name) +} + +func (c *connHalf) RemoteAddr() net.Addr { + if c.remote != nil { + return c.remote + } + return connAddr(c.w.name) +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *connHalf) Close() error { + if err := c.w.Close(); err != nil { + return err + } + return c.r.Close() +} + +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} + +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/net/memnet/conn_test.go b/net/memnet/conn_test.go index 3eec80bc6a583..743ce5248cb9d 100644 --- a/net/memnet/conn_test.go +++ b/net/memnet/conn_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "testing" - - "golang.org/x/net/nettest" -) - -func TestConn(t *testing.T) { - nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { - c1, c2 = NewConn("test", bufferSize) - return c1, c2, func() { - c1.Close() - c2.Close() - }, nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "testing" + + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + c1, c2 = NewConn("test", bufferSize) + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} diff --git a/net/memnet/listener.go b/net/memnet/listener.go index d1364d7903d15..d84a2e443cbff 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "net" - "strings" - "sync" -) - -const ( - bufferSize = 256 * 1024 -) - -// Listener is a net.Listener using NewConn to create pairs of network -// connections connected in memory using a buffered pipe. It also provides a -// Dial method to establish new connections. -type Listener struct { - addr connAddr - ch chan Conn - closeOnce sync.Once - closed chan struct{} - - // NewConn, if non-nil, is called to create a new pair of connections - // when dialing. If nil, NewConn is used. - NewConn func(network, addr string, maxBuf int) (Conn, Conn) -} - -// Listen returns a new Listener for the provided address. -func Listen(addr string) *Listener { - return &Listener{ - addr: connAddr(addr), - ch: make(chan Conn), - closed: make(chan struct{}), - } -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr -} - -// Close closes the pipe listener. -func (l *Listener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - }) - return nil -} - -// Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { - select { - case c := <-l.ch: - return c, nil - case <-l.closed: - return nil, net.ErrClosed - } -} - -// Dial connects to the listener using the provided context. -// The provided Context must be non-nil. If the context expires before the -// connection is complete, an error is returned. Once successfully connected -// any expiration of the context will not affect the connection. -func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { - if !strings.HasSuffix(network, "tcp") { - return nil, net.UnknownNetworkError(network) - } - if connAddr(addr) != l.addr { - return nil, &net.AddrError{ - Err: "invalid address", - Addr: addr, - } - } - - newConn := l.NewConn - if newConn == nil { - newConn = func(network, addr string, maxBuf int) (Conn, Conn) { - return NewConn(addr, maxBuf) - } - } - c, s := newConn(network, addr, bufferSize) - defer func() { - if err != nil { - c.Close() - s.Close() - } - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-l.closed: - return nil, net.ErrClosed - case l.ch <- s: - return c, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "net" + "strings" + "sync" +) + +const ( + bufferSize = 256 * 1024 +) + +// Listener is a net.Listener using NewConn to create pairs of network +// connections connected in memory using a buffered pipe. It also provides a +// Dial method to establish new connections. +type Listener struct { + addr connAddr + ch chan Conn + closeOnce sync.Once + closed chan struct{} + + // NewConn, if non-nil, is called to create a new pair of connections + // when dialing. If nil, NewConn is used. + NewConn func(network, addr string, maxBuf int) (Conn, Conn) +} + +// Listen returns a new Listener for the provided address. +func Listen(addr string) *Listener { + return &Listener{ + addr: connAddr(addr), + ch: make(chan Conn), + closed: make(chan struct{}), + } +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + return l.addr +} + +// Close closes the pipe listener. +func (l *Listener) Close() error { + l.closeOnce.Do(func() { + close(l.closed) + }) + return nil +} + +// Accept blocks until a new connection is available or the listener is closed. +func (l *Listener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Dial connects to the listener using the provided context. +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected +// any expiration of the context will not affect the connection. +func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { + if !strings.HasSuffix(network, "tcp") { + return nil, net.UnknownNetworkError(network) + } + if connAddr(addr) != l.addr { + return nil, &net.AddrError{ + Err: "invalid address", + Addr: addr, + } + } + + newConn := l.NewConn + if newConn == nil { + newConn = func(network, addr string, maxBuf int) (Conn, Conn) { + return NewConn(addr, maxBuf) + } + } + c, s := newConn(network, addr, bufferSize) + defer func() { + if err != nil { + c.Close() + s.Close() + } + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.closed: + return nil, net.ErrClosed + case l.ch <- s: + return c, nil + } +} diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 989d5e9e4bb2b..73b67841ad08c 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "testing" -) - -func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() - go func() { - c, err := l.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - }() - - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { - c.Close() - t.Fatalf("dial to invalid address succeeded") - } - c, err := l.Dial(context.Background(), "tcp", "srv.local") - if err != nil { - t.Fatalf("dial failed: %v", err) - return - } - c.Close() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "testing" +) + +func TestListener(t *testing.T) { + l := Listen("srv.local") + defer l.Close() + go func() { + c, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + }() + + if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + c.Close() + t.Fatalf("dial to invalid address succeeded") + } + c, err := l.Dial(context.Background(), "tcp", "srv.local") + if err != nil { + t.Fatalf("dial failed: %v", err) + return + } + c.Close() +} diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index 2fc13b4b2436f..c8799bc17035e 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package memnet implements an in-memory network implementation. -// It is useful for dialing and listening on in-memory addresses -// in tests and other situations where you don't want to use the -// network. -package memnet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package memnet implements an in-memory network implementation. +// It is useful for dialing and listening on in-memory addresses +// in tests and other situations where you don't want to use the +// network. +package memnet diff --git a/net/memnet/pipe.go b/net/memnet/pipe.go index 51bee109024d0..47163508353a6 100644 --- a/net/memnet/pipe.go +++ b/net/memnet/pipe.go @@ -1,244 +1,244 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" -) - -const debugPipe = false - -// Pipe implements an in-memory FIFO with timeouts. -type Pipe struct { - name string - maxBuf int - mu sync.Mutex - cnd *sync.Cond - - blocked bool - closed bool - buf bytes.Buffer - readTimeout time.Time - writeTimeout time.Time - cancelReadTimer func() - cancelWriteTimer func() -} - -// NewPipe creates a Pipe with a buffer size fixed at maxBuf. -func NewPipe(name string, maxBuf int) *Pipe { - p := &Pipe{ - name: name, - maxBuf: maxBuf, - } - p.cnd = sync.NewCond(&p.mu) - return p -} - -// readOrBlock attempts to read from the buffer, if the buffer is empty and -// the connection hasn't been closed it will block until there is a change. -func (p *Pipe) readOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - n, err := p.buf.Read(b) - // err will either be nil or io.EOF. - if err == io.EOF { - if p.closed { - return n, err - } - // Wait for something to change. - p.cnd.Wait() - } - return n, nil -} - -// Read implements io.Reader. -// Once the buffer is drained (i.e. after Close), subsequent calls will -// return io.EOF. -func (p *Pipe) Read(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) - }() - } - for n == 0 { - n2, err := p.readOrBlock(b) - if err != nil { - return n2, err - } - n += n2 - } - p.cnd.Signal() - return n, nil -} - -// writeOrBlock attempts to write to the buffer, if the buffer is full it will -// block until there is a change. -func (p *Pipe) writeOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.closed { - return 0, net.ErrClosed - } - if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - // Optimistically we want to write the entire slice. - n := len(b) - if limit := p.maxBuf - p.buf.Len(); limit < n { - // However, we don't have enough capacity to write everything. - n = limit - } - if n == 0 { - // Wait for something to change. - p.cnd.Wait() - return 0, nil - } - - p.buf.Write(b[:n]) - p.cnd.Signal() - return n, nil -} - -// Write implements io.Writer. -func (p *Pipe) Write(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) - }() - } - for len(b) > 0 { - n2, err := p.writeOrBlock(b) - if err != nil { - return n + n2, err - } - n += n2 - b = b[n2:] - } - return n, nil -} - -// Close closes the pipe. -func (p *Pipe) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - p.closed = true - p.blocked = false - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cnd.Broadcast() - - return nil -} - -func (p *Pipe) deadlineTimer(t time.Time) func() { - if t.IsZero() { - return nil - } - if t.Before(time.Now()) { - p.cnd.Broadcast() - return nil - } - ctx, cancel := context.WithDeadline(context.Background(), t) - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - p.cnd.Broadcast() - } - }() - return cancel -} - -// SetReadDeadline sets the deadline for future Read calls. -func (p *Pipe) SetReadDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.readTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cancelReadTimer = p.deadlineTimer(t) - return nil -} - -// SetWriteDeadline sets the deadline for future Write calls. -func (p *Pipe) SetWriteDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.writeTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - p.cancelWriteTimer = p.deadlineTimer(t) - return nil -} - -// Block will cause all calls to Read and Write to block until they either -// timeout, are unblocked or the pipe is closed. -func (p *Pipe) Block() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = true - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) - } - p.cnd.Broadcast() - return nil -} - -// Unblock will cause all blocked Read/Write calls to continue execution. -func (p *Pipe) Unblock() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = false - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if !blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) - } - p.cnd.Broadcast() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + mu sync.Mutex + cnd *sync.Cond + + blocked bool + closed bool + buf bytes.Buffer + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + p := &Pipe{ + name: name, + maxBuf: maxBuf, + } + p.cnd = sync.NewCond(&p.mu) + return p +} + +// readOrBlock attempts to read from the buffer, if the buffer is empty and +// the connection hasn't been closed it will block until there is a change. +func (p *Pipe) readOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + n, err := p.buf.Read(b) + // err will either be nil or io.EOF. + if err == io.EOF { + if p.closed { + return n, err + } + // Wait for something to change. + p.cnd.Wait() + } + return n, nil +} + +// Read implements io.Reader. +// Once the buffer is drained (i.e. after Close), subsequent calls will +// return io.EOF. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for n == 0 { + n2, err := p.readOrBlock(b) + if err != nil { + return n2, err + } + n += n2 + } + p.cnd.Signal() + return n, nil +} + +// writeOrBlock attempts to write to the buffer, if the buffer is full it will +// block until there is a change. +func (p *Pipe) writeOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, net.ErrClosed + } + if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + // Optimistically we want to write the entire slice. + n := len(b) + if limit := p.maxBuf - p.buf.Len(); limit < n { + // However, we don't have enough capacity to write everything. + n = limit + } + if n == 0 { + // Wait for something to change. + p.cnd.Wait() + return 0, nil + } + + p.buf.Write(b[:n]) + p.cnd.Signal() + return n, nil +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for len(b) > 0 { + n2, err := p.writeOrBlock(b) + if err != nil { + return n + n2, err + } + n += n2 + b = b[n2:] + } + return n, nil +} + +// Close closes the pipe. +func (p *Pipe) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + p.blocked = false + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cnd.Broadcast() + + return nil +} + +func (p *Pipe) deadlineTimer(t time.Time) func() { + if t.IsZero() { + return nil + } + if t.Before(time.Now()) { + p.cnd.Broadcast() + return nil + } + ctx, cancel := context.WithDeadline(context.Background(), t) + go func() { + <-ctx.Done() + if ctx.Err() == context.DeadlineExceeded { + p.cnd.Broadcast() + } + }() + return cancel +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.readTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cancelReadTimer = p.deadlineTimer(t) + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.writeTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + p.cancelWriteTimer = p.deadlineTimer(t) + return nil +} + +// Block will cause all calls to Read and Write to block until they either +// timeout, are unblocked or the pipe is closed. +func (p *Pipe) Block() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = true + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) + } + p.cnd.Broadcast() + return nil +} + +// Unblock will cause all blocked Read/Write calls to continue execution. +func (p *Pipe) Unblock() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = false + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) + } + p.cnd.Broadcast() + return nil +} diff --git a/net/memnet/pipe_test.go b/net/memnet/pipe_test.go index b3775cf7f9130..a86d65388e27d 100644 --- a/net/memnet/pipe_test.go +++ b/net/memnet/pipe_test.go @@ -1,117 +1,117 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "errors" - "fmt" - "os" - "testing" - "time" -) - -func TestPipeHello(t *testing.T) { - p := NewPipe("p1", 1<<16) - msg := "Hello, World!" - if n, err := p.Write([]byte(msg)); err != nil { - t.Fatal(err) - } else if n != len(msg) { - t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) - } - b := make([]byte, len(msg)) - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != len(b) { - t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) - } - if got := string(b); got != msg { - t.Errorf("p.Read: %q, want %q", got, msg) - } -} - -func TestPipeTimeout(t *testing.T) { - t.Run("write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) - n, err := p.Write([]byte{'h'}) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing write timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h'}) - - p.SetReadDeadline(time.Now().Add(-1 * time.Second)) - b := make([]byte, 1) - n, err := p.Read(b) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing read timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("block-write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want write timeout got: %v", err) - } - }) - t.Run("block-read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h', 'i'}) - p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) - b := make([]byte, 1) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want read timeout got: %v", err) - } - }) -} - -func TestLimit(t *testing.T) { - p := NewPipe("p1", 1) - errCh := make(chan error) - go func() { - n, err := p.Write([]byte{'a', 'b', 'c'}) - if err != nil { - errCh <- err - } else if n != 3 { - errCh <- fmt.Errorf("p.Write n=%d, want 3", n) - } else { - errCh <- nil - } - }() - b := make([]byte, 3) - - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - - if err := <-errCh; err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "errors" + "fmt" + "os" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want read timeout got: %v", err) + } + }) +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + + if err := <-errCh; err != nil { + t.Error(err) + } +} diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 6f85a52b7c550..1ab6c053a523e 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr -// to Go 1.18's net/netip. -// -// TODO(bradfitz): delete this package eventually. Tracking bug is -// https://github.com/tailscale/tailscale/issues/5162 -package netaddr - -import ( - "net" - "net/netip" -) - -// IPv4 returns the IP of the IPv4 address a.b.c.d. -func IPv4(a, b, c, d uint8) netip.Addr { - return netip.AddrFrom4([4]byte{a, b, c, d}) -} - -// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. -// -// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 -func Unmap(ap netip.AddrPort) netip.AddrPort { - return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) -} - -// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. -// If std is invalid, ok is false. -func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { - ip, ok := netip.AddrFromSlice(std.IP) - if !ok { - return netip.Prefix{}, false - } - ip = ip.Unmap() - - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { - // Invalid mask. - return netip.Prefix{}, false - } - - ones, bits := std.Mask.Size() - if ones == 0 && bits == 0 { - // IPPrefix does not support non-contiguous masks. - return netip.Prefix{}, false - } - - return netip.PrefixFrom(ip, ones), true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr +// to Go 1.18's net/netip. +// +// TODO(bradfitz): delete this package eventually. Tracking bug is +// https://github.com/tailscale/tailscale/issues/5162 +package netaddr + +import ( + "net" + "net/netip" +) + +// IPv4 returns the IP of the IPv4 address a.b.c.d. +func IPv4(a, b, c, d uint8) netip.Addr { + return netip.AddrFrom4([4]byte{a, b, c, d}) +} + +// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. +// +// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 +func Unmap(ap netip.AddrPort) netip.AddrPort { + return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) +} + +// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. +// If std is invalid, ok is false. +func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { + ip, ok := netip.AddrFromSlice(std.IP) + if !ok { + return netip.Prefix{}, false + } + ip = ip.Unmap() + + if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { + // Invalid mask. + return netip.Prefix{}, false + } + + ones, bits := std.Mask.Size() + if ones == 0 && bits == 0 { + // IPPrefix does not support non-contiguous masks. + return netip.Prefix{}, false + } + + return netip.PrefixFrom(ip, ones), true +} diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index f570b89302a1b..e2387440d33d5 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package neterror classifies network errors. -package neterror - -import ( - "errors" - "fmt" - "runtime" - "syscall" -) - -var errEPERM error = syscall.EPERM // box it into interface just once - -// TreatAsLostUDP reports whether err is an error from a UDP send -// operation that should be treated as a UDP packet that just got -// lost. -// -// Notably, on Linux this reports true for EPERM errors (from outbound -// firewall blocks) which aren't really send errors; they're just -// sends that are never going to make it because the local OS blocked -// it. -func TreatAsLostUDP(err error) bool { - if err == nil { - return false - } - switch runtime.GOOS { - case "linux": - // Linux, while not documented in the man page, - // returns EPERM when there's an OUTPUT rule with -j - // DROP or -j REJECT. We use this very specific - // Linux+EPERM check rather than something super broad - // like net.Error.Temporary which could be anything. - // - // For now we only do this on Linux, as such outgoing - // firewall violations mapping to syscall errors - // hasn't yet been observed on other OSes. - return errors.Is(err, errEPERM) - } - return false -} - -var packetWasTruncated func(error) bool // non-nil on Windows at least - -// PacketWasTruncated reports whether err indicates truncation but the RecvFrom -// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom -// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received -// datagram is larger than the provided buffer. When that happens, both a valid -// size and an error are returned (as per the partial fix for golang/go#14074). -// If the WSAEMSGSIZE error is returned, then we ignore the error to get -// semantics similar to the POSIX operating systems. One caveat is that it -// appears that the source address is not returned when WSAEMSGSIZE occurs, but -// we do not currently look at the source address. -func PacketWasTruncated(err error) bool { - if packetWasTruncated == nil { - return false - } - return packetWasTruncated(err) -} - -var shouldDisableUDPGSO func(error) bool // non-nil on Linux - -func ShouldDisableUDPGSO(err error) bool { - if shouldDisableUDPGSO == nil { - return false - } - return shouldDisableUDPGSO(err) -} - -type ErrUDPGSODisabled struct { - OnLaddr string - RetryErr error -} - -func (e ErrUDPGSODisabled) Error() string { - return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) -} - -func (e ErrUDPGSODisabled) Unwrap() error { - return e.RetryErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package neterror classifies network errors. +package neterror + +import ( + "errors" + "fmt" + "runtime" + "syscall" +) + +var errEPERM error = syscall.EPERM // box it into interface just once + +// TreatAsLostUDP reports whether err is an error from a UDP send +// operation that should be treated as a UDP packet that just got +// lost. +// +// Notably, on Linux this reports true for EPERM errors (from outbound +// firewall blocks) which aren't really send errors; they're just +// sends that are never going to make it because the local OS blocked +// it. +func TreatAsLostUDP(err error) bool { + if err == nil { + return false + } + switch runtime.GOOS { + case "linux": + // Linux, while not documented in the man page, + // returns EPERM when there's an OUTPUT rule with -j + // DROP or -j REJECT. We use this very specific + // Linux+EPERM check rather than something super broad + // like net.Error.Temporary which could be anything. + // + // For now we only do this on Linux, as such outgoing + // firewall violations mapping to syscall errors + // hasn't yet been observed on other OSes. + return errors.Is(err, errEPERM) + } + return false +} + +var packetWasTruncated func(error) bool // non-nil on Windows at least + +// PacketWasTruncated reports whether err indicates truncation but the RecvFrom +// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom +// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received +// datagram is larger than the provided buffer. When that happens, both a valid +// size and an error are returned (as per the partial fix for golang/go#14074). +// If the WSAEMSGSIZE error is returned, then we ignore the error to get +// semantics similar to the POSIX operating systems. One caveat is that it +// appears that the source address is not returned when WSAEMSGSIZE occurs, but +// we do not currently look at the source address. +func PacketWasTruncated(err error) bool { + if packetWasTruncated == nil { + return false + } + return packetWasTruncated(err) +} + +var shouldDisableUDPGSO func(error) bool // non-nil on Linux + +func ShouldDisableUDPGSO(err error) bool { + if shouldDisableUDPGSO == nil { + return false + } + return shouldDisableUDPGSO(err) +} + +type ErrUDPGSODisabled struct { + OnLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go index 3f402dd30d236..857367fe8ebb5 100644 --- a/net/neterror/neterror_linux.go +++ b/net/neterror/neterror_linux.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "os" - - "golang.org/x/sys/unix" -) - -func init() { - shouldDisableUDPGSO = func(err error) bool { - var serr *os.SyscallError - if errors.As(err, &serr) { - // EIO is returned by udp_send_skb() if the device driver does not - // have tx checksumming enabled, which is a hard requirement of - // UDP_SEGMENT. See: - // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 - // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - return serr.Err == unix.EIO - } - return false - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func init() { + shouldDisableUDPGSO = func(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not + // have tx checksumming enabled, which is a hard requirement of + // UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false + } +} diff --git a/net/neterror/neterror_linux_test.go b/net/neterror/neterror_linux_test.go index 1d600d6b6e073..5b99060741351 100644 --- a/net/neterror/neterror_linux_test.go +++ b/net/neterror/neterror_linux_test.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "net" - "os" - "syscall" - "testing" -) - -func TestTreatAsLostUDP(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {"nil", nil, false}, - {"non-nil", errors.New("foo"), false}, - {"eperm", syscall.EPERM, true}, - { - name: "operror", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EPERM, - }, - }, - want: true, - }, - { - name: "host_unreach", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EHOSTUNREACH, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := TreatAsLostUDP(tt.err); got != tt.want { - t.Errorf("got = %v; want %v", got, tt.want) - } - }) - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "net" + "os" + "syscall" + "testing" +) + +func TestTreatAsLostUDP(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"non-nil", errors.New("foo"), false}, + {"eperm", syscall.EPERM, true}, + { + name: "operror", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EPERM, + }, + }, + want: true, + }, + { + name: "host_unreach", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EHOSTUNREACH, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TreatAsLostUDP(tt.err); got != tt.want { + t.Errorf("got = %v; want %v", got, tt.want) + } + }) + } + +} diff --git a/net/neterror/neterror_windows.go b/net/neterror/neterror_windows.go index c293ec4a96295..bf112f5ed7ab7 100644 --- a/net/neterror/neterror_windows.go +++ b/net/neterror/neterror_windows.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - - "golang.org/x/sys/windows" -) - -func init() { - packetWasTruncated = func(err error) bool { - return errors.Is(err, windows.WSAEMSGSIZE) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +func init() { + packetWasTruncated = func(err error) bool { + return errors.Is(err, windows.WSAEMSGSIZE) + } +} diff --git a/net/netkernelconf/netkernelconf.go b/net/netkernelconf/netkernelconf.go index 23ec9c5b69f19..3ea502b377fdf 100644 --- a/net/netkernelconf/netkernelconf.go +++ b/net/netkernelconf/netkernelconf.go @@ -1,5 +1,5 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netkernelconf contains code for checking kernel netdev config. -package netkernelconf +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netkernelconf contains code for checking kernel netdev config. +package netkernelconf diff --git a/net/netknob/netknob.go b/net/netknob/netknob.go index 0b271fc95b720..53171f4243f8d 100644 --- a/net/netknob/netknob.go +++ b/net/netknob/netknob.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netknob has Tailscale network knobs. -package netknob - -import ( - "runtime" - "time" -) - -// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive -// value for the current runtime.GOOS. -func PlatformTCPKeepAlive() time.Duration { - switch runtime.GOOS { - case "ios", "android": - // Disable TCP keep-alives on mobile platforms. - // See https://github.com/golang/go/issues/48622. - // - // TODO(bradfitz): in 1.17.x, try disabling TCP - // keep-alives on for all platforms. - return -1 - } - - // Otherwise, default to 30 seconds, which is mostly what we - // used to do. In some places we used the zero value, which Go - // defaults to 15 seconds. But 30 seconds is fine. - return 30 * time.Second -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netknob has Tailscale network knobs. +package netknob + +import ( + "runtime" + "time" +) + +// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive +// value for the current runtime.GOOS. +func PlatformTCPKeepAlive() time.Duration { + switch runtime.GOOS { + case "ios", "android": + // Disable TCP keep-alives on mobile platforms. + // See https://github.com/golang/go/issues/48622. + // + // TODO(bradfitz): in 1.17.x, try disabling TCP + // keep-alives on for all platforms. + return -1 + } + + // Otherwise, default to 30 seconds, which is mostly what we + // used to do. In some places we used the zero value, which Go + // defaults to 15 seconds. But 30 seconds is fine. + return 30 * time.Second +} diff --git a/net/netmon/netmon_darwin_test.go b/net/netmon/netmon_darwin_test.go index 77a212683e035..84c67cf6fa3e2 100644 --- a/net/netmon/netmon_darwin_test.go +++ b/net/netmon/netmon_darwin_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "encoding/hex" - "strings" - "testing" - - "golang.org/x/net/route" -) - -func TestIssue1416RIB(t *testing.T) { - const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` - rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) - if err != nil { - t.Fatal(err) - } - msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) - if err != nil { - t.Logf("ParseRIB: %v", err) - t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") - t.Fatal(err) - } - t.Logf("Got: %#v", msgs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "encoding/hex" + "strings" + "testing" + + "golang.org/x/net/route" +) + +func TestIssue1416RIB(t *testing.T) { + const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` + rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) + if err != nil { + t.Fatal(err) + } + msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) + if err != nil { + t.Logf("ParseRIB: %v", err) + t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") + t.Fatal(err) + } + t.Logf("Got: %#v", msgs) +} diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 724f964c98747..30480a1d3387e 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "bufio" - "fmt" - "net" - "strings" - - "tailscale.com/types/logger" -) - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// devdConn implements osMon using devd(8). -type devdConn struct { - conn net.Conn -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") - if err != nil { - logf("devd dial error: %v, falling back to polling method", err) - return newPollingMon(logf, m) - } - return &devdConn{conn}, nil -} - -func (c *devdConn) IsInterestingInterface(iface string) bool { return true } - -func (c *devdConn) Close() error { - return c.conn.Close() -} - -func (c *devdConn) Receive() (message, error) { - for { - msg, err := bufio.NewReader(c.conn).ReadString('\n') - if err != nil { - return nil, fmt.Errorf("reading devd socket: %v", err) - } - // Only return messages related to the network subsystem. - if !strings.Contains(msg, "system=IFNET") { - continue - } - // TODO: this is where the devd-specific message would - // get converted into a "standard" event message and returned. - return unspecifiedMessage{}, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "bufio" + "fmt" + "net" + "strings" + + "tailscale.com/types/logger" +) + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// devdConn implements osMon using devd(8). +type devdConn struct { + conn net.Conn +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") + if err != nil { + logf("devd dial error: %v, falling back to polling method", err) + return newPollingMon(logf, m) + } + return &devdConn{conn}, nil +} + +func (c *devdConn) IsInterestingInterface(iface string) bool { return true } + +func (c *devdConn) Close() error { + return c.conn.Close() +} + +func (c *devdConn) Receive() (message, error) { + for { + msg, err := bufio.NewReader(c.conn).ReadString('\n') + if err != nil { + return nil, fmt.Errorf("reading devd socket: %v", err) + } + // Only return messages related to the network subsystem. + if !strings.Contains(msg, "system=IFNET") { + continue + } + // TODO: this is where the devd-specific message would + // get converted into a "standard" event message and returned. + return unspecifiedMessage{}, nil + } +} diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index 888afa92d7612..dd23dd34263c5 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -1,290 +1,290 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !android - -package netmon - -import ( - "net" - "net/netip" - "time" - - "github.com/jsimonetti/rtnetlink" - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" - "tailscale.com/envknob" - "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" -) - -var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// nlConn wraps a *netlink.Conn and returns a monitor.Message -// instead of a netlink.Message. Currently, messages are discarded, -// but down the line, when messages trigger different logic depending -// on the type of event, this provides the capability of handling -// each architecture-specific message in a generic fashion. -type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message - - // addrCache maps interface indices to a set of addresses, and is - // used to suppress duplicate RTM_NEWADDR messages. It is populated - // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See - // issue #4282. - addrCache map[uint32]map[netip.Addr]bool -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - // Routes get us most of the events of interest, but we need - // address as well to cover things like DHCP deciding to give - // us a new address upon renewal - routing wouldn't change, - // but all reachability would. - Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | - unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | - unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix - }) - if err != nil { - // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support - logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") - return newPollingMon(logf, m) - } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil -} - -func (c *nlConn) IsInterestingInterface(iface string) bool { return true } - -func (c *nlConn) Close() error { return c.conn.Close() } - -func (c *nlConn) Receive() (message, error) { - if len(c.buffered) == 0 { - var err error - c.buffered, err = c.conn.Receive() - if err != nil { - return nil, err - } - if len(c.buffered) == 0 { - // Unexpected. Not seen in wild, but sleep defensively. - time.Sleep(time.Second) - return ignoreMessage{}, nil - } - } - msg := c.buffered[0] - c.buffered = c.buffered[1:] - - // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h - // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html - switch msg.Header.Type { - case unix.RTM_NEWADDR, unix.RTM_DELADDR: - var rmsg rtnetlink.AddressMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("failed to parse type %v: %v", msg.Header.Type, err) - return unspecifiedMessage{}, nil - } - - nip := netaddrIP(rmsg.Attributes.Address) - - if debugNetlinkMessages() { - typ := "RTM_NEWADDR" - if msg.Header.Type == unix.RTM_DELADDR { - typ = "RTM_DELADDR" - } - - // label attributes are seemingly only populated for IPv4 addresses in the wild. - label := rmsg.Attributes.Label - if label == "" { - itf, err := net.InterfaceByIndex(int(rmsg.Index)) - if err == nil { - label = itf.Name - } - } - - c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) - } - - addrs := c.addrCache[rmsg.Index] - - // Ignore duplicate RTM_NEWADDR messages using c.addrCache to - // detect them. See nlConn.addrcache and issue #4282. - if msg.Header.Type == unix.RTM_NEWADDR { - if addrs == nil { - addrs = make(map[netip.Addr]bool) - c.addrCache[rmsg.Index] = addrs - } - - if addrs[nip] { - if debugNetlinkMessages() { - c.logf("ignored duplicate RTM_NEWADDR for %s", nip) - } - return ignoreMessage{}, nil - } - - addrs[nip] = true - } else { // msg.Header.Type == unix.RTM_DELADDR - if addrs != nil { - delete(addrs, nip) - } - - if len(addrs) == 0 { - delete(c.addrCache, rmsg.Index) - } - } - - nam := &newAddrMessage{ - IfIndex: rmsg.Index, - Addr: nip, - Delete: msg.Header.Type == unix.RTM_DELADDR, - } - if debugNetlinkMessages() { - c.logf("%+v", nam) - } - return nam, nil - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - typeStr := "RTM_NEWROUTE" - if msg.Header.Type == unix.RTM_DELROUTE { - typeStr = "RTM_DELROUTE" - } - var rmsg rtnetlink.RouteMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("%s: failed to parse: %v", typeStr, err) - return unspecifiedMessage{}, nil - } - src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) - dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) - gw := netaddrIP(rmsg.Attributes.Gateway) - - if msg.Header.Type == unix.RTM_NEWROUTE && - (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && - (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { - - if debugNetlinkMessages() { - c.logf("%s ignored", typeStr) - } - - // Normal Linux route changes on new interface coming up; don't log or react. - return ignoreMessage{}, nil - } - - if rmsg.Table == tsTable && dst.IsSingleIP() { - // Don't log. Spammy and normal to see a bunch of these on start-up, - // which we make ourselves. - } else if tsaddr.IsTailscaleIP(dst.Addr()) { - // Verbose only. - c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } else { - c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } - if msg.Header.Type == unix.RTM_DELROUTE { - // Just logging it for now. - // (Debugging https://github.com/tailscale/tailscale/issues/643) - return unspecifiedMessage{}, nil - } - - nrm := &newRouteMessage{ - Table: rmsg.Table, - Src: src, - Dst: dst, - Gateway: gw, - } - if debugNetlinkMessages() { - c.logf("%+v", nrm) - } - return nrm, nil - case unix.RTM_NEWRULE: - // Probably ourselves adding it. - return ignoreMessage{}, nil - case unix.RTM_DELRULE: - // For https://github.com/tailscale/tailscale/issues/1591 where - // systemd-networkd deletes our rules. - var rmsg rtnetlink.RouteMessage - err := rmsg.UnmarshalBinary(msg.Data) - if err != nil { - c.logf("ip rule deleted; failed to parse netlink message: %v", err) - } else { - c.logf("ip rule deleted: %+v", rmsg) - // On `ip -4 rule del pref 5210 table main`, logs: - // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} - } - rdm := ipRuleDeletedMessage{ - table: rmsg.Table, - priority: rmsg.Attributes.Priority, - } - if debugNetlinkMessages() { - c.logf("%+v", rdm) - } - return rdm, nil - case unix.RTM_NEWLINK, unix.RTM_DELLINK: - // This is an unhandled message, but don't print an error. - // See https://github.com/tailscale/tailscale/issues/6806 - return unspecifiedMessage{}, nil - default: - c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) - return unspecifiedMessage{}, nil - } -} - -func netaddrIP(std net.IP) netip.Addr { - ip, _ := netip.AddrFromSlice(std) - return ip.Unmap() -} - -func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { - ip, _ := netip.AddrFromSlice(std) - return netip.PrefixFrom(ip.Unmap(), int(bits)) -} - -func condNetAddrPrefix(ipp netip.Prefix) string { - if !ipp.Addr().IsValid() { - return "" - } - return ipp.String() -} - -func condNetAddrIP(ip netip.Addr) string { - if !ip.IsValid() { - return "" - } - return ip.String() -} - -// newRouteMessage is a message for a new route being added. -type newRouteMessage struct { - Src, Dst netip.Prefix - Gateway netip.Addr - Table uint8 -} - -const tsTable = 52 - -func (m *newRouteMessage) ignore() bool { - return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) -} - -// newAddrMessage is a message for a new address being added. -type newAddrMessage struct { - Delete bool - Addr netip.Addr - IfIndex uint32 // interface index -} - -func (m *newAddrMessage) ignore() bool { - return tsaddr.IsTailscaleIP(m.Addr) -} - -type ignoreMessage struct{} - -func (ignoreMessage) ignore() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !android + +package netmon + +import ( + "net" + "net/netip" + "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" + "tailscale.com/envknob" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// nlConn wraps a *netlink.Conn and returns a monitor.Message +// instead of a netlink.Message. Currently, messages are discarded, +// but down the line, when messages trigger different logic depending +// on the type of event, this provides the capability of handling +// each architecture-specific message in a generic fashion. +type nlConn struct { + logf logger.Logf + conn *netlink.Conn + buffered []netlink.Message + + // addrCache maps interface indices to a set of addresses, and is + // used to suppress duplicate RTM_NEWADDR messages. It is populated + // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See + // issue #4282. + addrCache map[uint32]map[netip.Addr]bool +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + // Routes get us most of the events of interest, but we need + // address as well to cover things like DHCP deciding to give + // us a new address upon renewal - routing wouldn't change, + // but all reachability would. + Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | + unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | + unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix + }) + if err != nil { + // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support + logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") + return newPollingMon(logf, m) + } + return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil +} + +func (c *nlConn) IsInterestingInterface(iface string) bool { return true } + +func (c *nlConn) Close() error { return c.conn.Close() } + +func (c *nlConn) Receive() (message, error) { + if len(c.buffered) == 0 { + var err error + c.buffered, err = c.conn.Receive() + if err != nil { + return nil, err + } + if len(c.buffered) == 0 { + // Unexpected. Not seen in wild, but sleep defensively. + time.Sleep(time.Second) + return ignoreMessage{}, nil + } + } + msg := c.buffered[0] + c.buffered = c.buffered[1:] + + // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h + // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html + switch msg.Header.Type { + case unix.RTM_NEWADDR, unix.RTM_DELADDR: + var rmsg rtnetlink.AddressMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("failed to parse type %v: %v", msg.Header.Type, err) + return unspecifiedMessage{}, nil + } + + nip := netaddrIP(rmsg.Attributes.Address) + + if debugNetlinkMessages() { + typ := "RTM_NEWADDR" + if msg.Header.Type == unix.RTM_DELADDR { + typ = "RTM_DELADDR" + } + + // label attributes are seemingly only populated for IPv4 addresses in the wild. + label := rmsg.Attributes.Label + if label == "" { + itf, err := net.InterfaceByIndex(int(rmsg.Index)) + if err == nil { + label = itf.Name + } + } + + c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) + } + + addrs := c.addrCache[rmsg.Index] + + // Ignore duplicate RTM_NEWADDR messages using c.addrCache to + // detect them. See nlConn.addrcache and issue #4282. + if msg.Header.Type == unix.RTM_NEWADDR { + if addrs == nil { + addrs = make(map[netip.Addr]bool) + c.addrCache[rmsg.Index] = addrs + } + + if addrs[nip] { + if debugNetlinkMessages() { + c.logf("ignored duplicate RTM_NEWADDR for %s", nip) + } + return ignoreMessage{}, nil + } + + addrs[nip] = true + } else { // msg.Header.Type == unix.RTM_DELADDR + if addrs != nil { + delete(addrs, nip) + } + + if len(addrs) == 0 { + delete(c.addrCache, rmsg.Index) + } + } + + nam := &newAddrMessage{ + IfIndex: rmsg.Index, + Addr: nip, + Delete: msg.Header.Type == unix.RTM_DELADDR, + } + if debugNetlinkMessages() { + c.logf("%+v", nam) + } + return nam, nil + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + typeStr := "RTM_NEWROUTE" + if msg.Header.Type == unix.RTM_DELROUTE { + typeStr = "RTM_DELROUTE" + } + var rmsg rtnetlink.RouteMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("%s: failed to parse: %v", typeStr, err) + return unspecifiedMessage{}, nil + } + src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) + dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) + gw := netaddrIP(rmsg.Attributes.Gateway) + + if msg.Header.Type == unix.RTM_NEWROUTE && + (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && + (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { + + if debugNetlinkMessages() { + c.logf("%s ignored", typeStr) + } + + // Normal Linux route changes on new interface coming up; don't log or react. + return ignoreMessage{}, nil + } + + if rmsg.Table == tsTable && dst.IsSingleIP() { + // Don't log. Spammy and normal to see a bunch of these on start-up, + // which we make ourselves. + } else if tsaddr.IsTailscaleIP(dst.Addr()) { + // Verbose only. + c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } else { + c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } + if msg.Header.Type == unix.RTM_DELROUTE { + // Just logging it for now. + // (Debugging https://github.com/tailscale/tailscale/issues/643) + return unspecifiedMessage{}, nil + } + + nrm := &newRouteMessage{ + Table: rmsg.Table, + Src: src, + Dst: dst, + Gateway: gw, + } + if debugNetlinkMessages() { + c.logf("%+v", nrm) + } + return nrm, nil + case unix.RTM_NEWRULE: + // Probably ourselves adding it. + return ignoreMessage{}, nil + case unix.RTM_DELRULE: + // For https://github.com/tailscale/tailscale/issues/1591 where + // systemd-networkd deletes our rules. + var rmsg rtnetlink.RouteMessage + err := rmsg.UnmarshalBinary(msg.Data) + if err != nil { + c.logf("ip rule deleted; failed to parse netlink message: %v", err) + } else { + c.logf("ip rule deleted: %+v", rmsg) + // On `ip -4 rule del pref 5210 table main`, logs: + // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} + } + rdm := ipRuleDeletedMessage{ + table: rmsg.Table, + priority: rmsg.Attributes.Priority, + } + if debugNetlinkMessages() { + c.logf("%+v", rdm) + } + return rdm, nil + case unix.RTM_NEWLINK, unix.RTM_DELLINK: + // This is an unhandled message, but don't print an error. + // See https://github.com/tailscale/tailscale/issues/6806 + return unspecifiedMessage{}, nil + default: + c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) + return unspecifiedMessage{}, nil + } +} + +func netaddrIP(std net.IP) netip.Addr { + ip, _ := netip.AddrFromSlice(std) + return ip.Unmap() +} + +func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { + ip, _ := netip.AddrFromSlice(std) + return netip.PrefixFrom(ip.Unmap(), int(bits)) +} + +func condNetAddrPrefix(ipp netip.Prefix) string { + if !ipp.Addr().IsValid() { + return "" + } + return ipp.String() +} + +func condNetAddrIP(ip netip.Addr) string { + if !ip.IsValid() { + return "" + } + return ip.String() +} + +// newRouteMessage is a message for a new route being added. +type newRouteMessage struct { + Src, Dst netip.Prefix + Gateway netip.Addr + Table uint8 +} + +const tsTable = 52 + +func (m *newRouteMessage) ignore() bool { + return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) +} + +// newAddrMessage is a message for a new address being added. +type newAddrMessage struct { + Delete bool + Addr netip.Addr + IfIndex uint32 // interface index +} + +func (m *newAddrMessage) ignore() bool { + return tsaddr.IsTailscaleIP(m.Addr) +} + +type ignoreMessage struct{} + +func (ignoreMessage) ignore() bool { return true } diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 1ce4a51deadc4..3d6f94731077a 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (!linux && !freebsd && !windows && !darwin) || android - -package netmon - -import ( - "tailscale.com/types/logger" -) - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - return newPollingMon(logf, m) -} - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (!linux && !freebsd && !windows && !darwin) || android + +package netmon + +import ( + "tailscale.com/types/logger" +) + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + return newPollingMon(logf, m) +} + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } diff --git a/net/netmon/polling.go b/net/netmon/polling.go index bb7210b94ed62..ce1618ed6c987 100644 --- a/net/netmon/polling.go +++ b/net/netmon/polling.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !darwin - -package netmon - -import ( - "bytes" - "errors" - "os" - "runtime" - "sync" - "time" - - "tailscale.com/types/logger" -) - -func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { - return &pollingMon{ - logf: logf, - m: m, - stop: make(chan struct{}), - }, nil -} - -// pollingMon is a bad but portable implementation of the link monitor -// that works by polling the interface state every 10 seconds, in lieu -// of anything to subscribe to. -type pollingMon struct { - logf logger.Logf - m *Monitor - - closeOnce sync.Once - stop chan struct{} -} - -func (pm *pollingMon) IsInterestingInterface(iface string) bool { - return true -} - -func (pm *pollingMon) Close() error { - pm.closeOnce.Do(func() { - close(pm.stop) - }) - return nil -} - -func (pm *pollingMon) isCloudRun() bool { - // https: //cloud.google.com/run/docs/reference/container-contract#env-vars - if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || - os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { - return false - } - vers, err := os.ReadFile("/proc/version") - if err != nil { - pm.logf("Failed to read /proc/version: %v", err) - return false - } - return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" -} - -func (pm *pollingMon) Receive() (message, error) { - d := 10 * time.Second - if runtime.GOOS == "android" { - // We'll have Android notify the link monitor to wake up earlier, - // so this can go very slowly there, to save battery. - // https://github.com/tailscale/tailscale/issues/1427 - d = 10 * time.Minute - } else if pm.isCloudRun() { - // Cloud Run routes never change at runtime. the containers are killed within - // 15 minutes by default, set the interval long enough to be effectively infinite. - pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") - d = 24 * time.Hour - } - timer := time.NewTimer(d) - defer timer.Stop() - for { - select { - case <-timer.C: - return unspecifiedMessage{}, nil - case <-pm.stop: - return nil, errors.New("stopped") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !darwin + +package netmon + +import ( + "bytes" + "errors" + "os" + "runtime" + "sync" + "time" + + "tailscale.com/types/logger" +) + +func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { + return &pollingMon{ + logf: logf, + m: m, + stop: make(chan struct{}), + }, nil +} + +// pollingMon is a bad but portable implementation of the link monitor +// that works by polling the interface state every 10 seconds, in lieu +// of anything to subscribe to. +type pollingMon struct { + logf logger.Logf + m *Monitor + + closeOnce sync.Once + stop chan struct{} +} + +func (pm *pollingMon) IsInterestingInterface(iface string) bool { + return true +} + +func (pm *pollingMon) Close() error { + pm.closeOnce.Do(func() { + close(pm.stop) + }) + return nil +} + +func (pm *pollingMon) isCloudRun() bool { + // https: //cloud.google.com/run/docs/reference/container-contract#env-vars + if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || + os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { + return false + } + vers, err := os.ReadFile("/proc/version") + if err != nil { + pm.logf("Failed to read /proc/version: %v", err) + return false + } + return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" +} + +func (pm *pollingMon) Receive() (message, error) { + d := 10 * time.Second + if runtime.GOOS == "android" { + // We'll have Android notify the link monitor to wake up earlier, + // so this can go very slowly there, to save battery. + // https://github.com/tailscale/tailscale/issues/1427 + d = 10 * time.Minute + } else if pm.isCloudRun() { + // Cloud Run routes never change at runtime. the containers are killed within + // 15 minutes by default, set the interval long enough to be effectively infinite. + pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") + d = 24 * time.Hour + } + timer := time.NewTimer(d) + defer timer.Stop() + for { + select { + case <-timer.C: + return unspecifiedMessage{}, nil + case <-pm.stop: + return nil, errors.New("stopped") + } + } +} diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index 79084ff11f521..162e5c79a62fa 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -1,75 +1,75 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build android - -package netns - -import ( - "fmt" - "sync" - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -var ( - androidProtectFuncMu sync.Mutex - androidProtectFunc func(fd int) error -) - -// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. -func UseSocketMark() bool { - return false -} - -// SetAndroidProtectFunc register a func that Android provides that JNI calls into -// https://developer.android.com/reference/android/net/VpnService#protect(int) -// which is documented as: -// -// "Protect a socket from VPN connections. After protecting, data sent -// through this socket will go directly to the underlying network, so -// its traffic will not be forwarded through the VPN. This method is -// useful if some connections need to be kept outside of VPN. For -// example, a VPN tunnel should protect itself if its destination is -// covered by VPN routes. Otherwise its outgoing packets will be sent -// back to the VPN interface and cause an infinite loop. This method -// will fail if the application is not prepared or is revoked." -// -// A nil func disables the use the hook. -// -// This indirection is necessary because this is the supported, stable -// interface to use on Android, and doing the sockopts to set the -// fwmark return errors on Android. The actual implementation of -// VpnService.protect ends up doing an IPC to another process on -// Android, asking for the fwmark to be set. -func SetAndroidProtectFunc(f func(fd int) error) { - androidProtectFuncMu.Lock() - defer androidProtectFuncMu.Unlock() - androidProtectFunc = f -} - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC marks c as necessary to dial in a separate network namespace. -// -// It's intentionally the same signature as net.Dialer.Control -// and net.ListenConfig.Control. -func controlC(network, address string, c syscall.RawConn) error { - var sockErr error - err := c.Control(func(fd uintptr) { - androidProtectFuncMu.Lock() - f := androidProtectFunc - androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) - } - }) - if err != nil { - return fmt.Errorf("RawConn.Control on %T: %w", c, err) - } - return sockErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build android + +package netns + +import ( + "fmt" + "sync" + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +var ( + androidProtectFuncMu sync.Mutex + androidProtectFunc func(fd int) error +) + +// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. +func UseSocketMark() bool { + return false +} + +// SetAndroidProtectFunc register a func that Android provides that JNI calls into +// https://developer.android.com/reference/android/net/VpnService#protect(int) +// which is documented as: +// +// "Protect a socket from VPN connections. After protecting, data sent +// through this socket will go directly to the underlying network, so +// its traffic will not be forwarded through the VPN. This method is +// useful if some connections need to be kept outside of VPN. For +// example, a VPN tunnel should protect itself if its destination is +// covered by VPN routes. Otherwise its outgoing packets will be sent +// back to the VPN interface and cause an infinite loop. This method +// will fail if the application is not prepared or is revoked." +// +// A nil func disables the use the hook. +// +// This indirection is necessary because this is the supported, stable +// interface to use on Android, and doing the sockopts to set the +// fwmark return errors on Android. The actual implementation of +// VpnService.protect ends up doing an IPC to another process on +// Android, asking for the fwmark to be set. +func SetAndroidProtectFunc(f func(fd int) error) { + androidProtectFuncMu.Lock() + defer androidProtectFuncMu.Unlock() + androidProtectFunc = f +} + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC marks c as necessary to dial in a separate network namespace. +// +// It's intentionally the same signature as net.Dialer.Control +// and net.ListenConfig.Control. +func controlC(network, address string, c syscall.RawConn) error { + var sockErr error + err := c.Control(func(fd uintptr) { + androidProtectFuncMu.Lock() + f := androidProtectFunc + androidProtectFuncMu.Unlock() + if f != nil { + sockErr = f(int(fd)) + } + }) + if err != nil { + return fmt.Errorf("RawConn.Control on %T: %w", c, err) + } + return sockErr +} diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go index 02db19e7566fa..94f24d8fa4e19 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package netns - -import ( - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC does nothing to c. -func controlC(network, address string, c syscall.RawConn) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package netns + +import ( + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC does nothing to c. +func controlC(network, address string, c syscall.RawConn) error { + return nil +} diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go index cc221bcb1712c..a5000f37f0a44 100644 --- a/net/netns/netns_linux_test.go +++ b/net/netns/netns_linux_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netns - -import ( - "testing" -) - -func TestSocketMarkWorks(t *testing.T) { - _ = socketMarkWorks() - // we cannot actually assert whether the test runner has SO_MARK available - // or not, as we don't know. We're just checking that it doesn't panic. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netns + +import ( + "testing" +) + +func TestSocketMarkWorks(t *testing.T) { + _ = socketMarkWorks() + // we cannot actually assert whether the test runner has SO_MARK available + // or not, as we don't know. We're just checking that it doesn't panic. +} diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 1c6d699ac88aa..82f919b946d4a 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netns contains the common code for using the Go net package -// in a logical "network namespace" to avoid routing loops where -// Tailscale-created packets would otherwise loop back through -// Tailscale routes. -// -// Despite the name netns, the exact mechanism used differs by -// operating system, and perhaps even by version of the OS. -// -// The netns package also handles connecting via SOCKS proxies when -// configured by the environment. -package netns - -import ( - "flag" - "testing" -) - -var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") - -func TestDial(t *testing.T) { - if !*extNetwork { - t.Skip("skipping test without --use-external-network") - } - d := NewDialer(t.Logf, nil) - c, err := d.Dial("tcp", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) - - c, err = d.Dial("tcp4", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) -} - -func TestIsLocalhost(t *testing.T) { - tests := []struct { - name string - host string - want bool - }{ - {"IPv4 loopback", "127.0.0.1", true}, - {"IPv4 !loopback", "192.168.0.1", false}, - {"IPv4 loopback with port", "127.0.0.1:1", true}, - {"IPv4 !loopback with port", "192.168.0.1:1", false}, - {"IPv4 unspecified", "0.0.0.0", false}, - {"IPv4 unspecified with port", "0.0.0.0:1", false}, - {"IPv6 loopback", "::1", true}, - {"IPv6 !loopback", "2001:4860:4860::8888", false}, - {"IPv6 loopback with port", "[::1]:1", true}, - {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, - {"IPv6 unspecified", "::", false}, - {"IPv6 unspecified with port", "[::]:1", false}, - {"empty", "", false}, - {"hostname", "example.com", false}, - {"localhost", "localhost", true}, - {"localhost6", "localhost6", true}, - {"localhost with port", "localhost:1", true}, - {"localhost6 with port", "localhost6:1", true}, - {"ip6-localhost", "ip6-localhost", true}, - {"ip6-localhost with port", "ip6-localhost:1", true}, - {"ip6-loopback", "ip6-loopback", true}, - {"ip6-loopback with port", "ip6-loopback:1", true}, - } - - for _, test := range tests { - if got := isLocalhost(test.host); got != test.want { - t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netns contains the common code for using the Go net package +// in a logical "network namespace" to avoid routing loops where +// Tailscale-created packets would otherwise loop back through +// Tailscale routes. +// +// Despite the name netns, the exact mechanism used differs by +// operating system, and perhaps even by version of the OS. +// +// The netns package also handles connecting via SOCKS proxies when +// configured by the environment. +package netns + +import ( + "flag" + "testing" +) + +var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") + +func TestDial(t *testing.T) { + if !*extNetwork { + t.Skip("skipping test without --use-external-network") + } + d := NewDialer(t.Logf, nil) + c, err := d.Dial("tcp", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) + + c, err = d.Dial("tcp4", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + host string + want bool + }{ + {"IPv4 loopback", "127.0.0.1", true}, + {"IPv4 !loopback", "192.168.0.1", false}, + {"IPv4 loopback with port", "127.0.0.1:1", true}, + {"IPv4 !loopback with port", "192.168.0.1:1", false}, + {"IPv4 unspecified", "0.0.0.0", false}, + {"IPv4 unspecified with port", "0.0.0.0:1", false}, + {"IPv6 loopback", "::1", true}, + {"IPv6 !loopback", "2001:4860:4860::8888", false}, + {"IPv6 loopback with port", "[::1]:1", true}, + {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, + {"IPv6 unspecified", "::", false}, + {"IPv6 unspecified with port", "[::]:1", false}, + {"empty", "", false}, + {"hostname", "example.com", false}, + {"localhost", "localhost", true}, + {"localhost6", "localhost6", true}, + {"localhost with port", "localhost:1", true}, + {"localhost6 with port", "localhost6:1", true}, + {"ip6-localhost", "ip6-localhost", true}, + {"ip6-localhost with port", "ip6-localhost:1", true}, + {"ip6-loopback", "ip6-loopback", true}, + {"ip6-loopback with port", "ip6-loopback:1", true}, + } + + for _, test := range tests { + if got := isLocalhost(test.host); got != test.want { + t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) + } + } +} diff --git a/net/netns/socks.go b/net/netns/socks.go index a3d10d3ae80c5..eea69d8651eda 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !js - -package netns - -import "golang.org/x/net/proxy" - -func init() { - wrapDialer = wrapSocks -} - -func wrapSocks(d Dialer) Dialer { - if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { - return cd - } - return d -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !js + +package netns + +import "golang.org/x/net/proxy" + +func init() { + wrapDialer = wrapSocks +} + +func wrapSocks(d Dialer) Dialer { + if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { + return cd + } + return d +} diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go index 53121dc52e202..53c7d7757eac6 100644 --- a/net/netstat/netstat.go +++ b/net/netstat/netstat.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netstat returns the local machine's network connection table. -package netstat - -import ( - "errors" - "net/netip" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -type Entry struct { - Local, Remote netip.AddrPort - Pid int - State string // TODO: type? - OSMetadata OSMetadata -} - -// Table contains local machine's TCP connection entries. -// -// Currently only TCP (IPv4 and IPv6) are included. -type Table struct { - Entries []Entry -} - -// Get returns the connection table. -// -// It returns ErrNotImplemented if the table is not available for the -// current operating system. -func Get() (*Table, error) { - return get() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netstat returns the local machine's network connection table. +package netstat + +import ( + "errors" + "net/netip" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +type Entry struct { + Local, Remote netip.AddrPort + Pid int + State string // TODO: type? + OSMetadata OSMetadata +} + +// Table contains local machine's TCP connection entries. +// +// Currently only TCP (IPv4 and IPv6) are included. +type Table struct { + Entries []Entry +} + +// Get returns the connection table. +// +// It returns ErrNotImplemented if the table is not available for the +// current operating system. +func Get() (*Table, error) { + return get() +} diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go index 608b1a617bc5d..e455c8ce931de 100644 --- a/net/netstat/netstat_noimpl.go +++ b/net/netstat/netstat_noimpl.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package netstat - -// OSMetadata includes any additional OS-specific information that may be -// obtained during the retrieval of a given Entry. -type OSMetadata struct{} - -func get() (*Table, error) { - return nil, ErrNotImplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package netstat + +// OSMetadata includes any additional OS-specific information that may be +// obtained during the retrieval of a given Entry. +type OSMetadata struct{} + +func get() (*Table, error) { + return nil, ErrNotImplemented +} diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go index 74f4fcec02616..38827df5ef65a 100644 --- a/net/netstat/netstat_test.go +++ b/net/netstat/netstat_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstat - -import ( - "testing" -) - -func TestGet(t *testing.T) { - nt, err := Get() - if err == ErrNotImplemented { - t.Skip("TODO: not implemented") - } - if err != nil { - t.Fatal(err) - } - for _, e := range nt.Entries { - t.Logf("Entry: %+v", e) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstat + +import ( + "testing" +) + +func TestGet(t *testing.T) { + nt, err := Get() + if err == ErrNotImplemented { + t.Skip("TODO: not implemented") + } + if err != nil { + t.Fatal(err) + } + for _, e := range nt.Entries { + t.Logf("Entry: %+v", e) + } +} diff --git a/net/packet/doc.go b/net/packet/doc.go index f3cb93db87e03..ce6c0c30716c6 100644 --- a/net/packet/doc.go +++ b/net/packet/doc.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package packet contains packet parsing and marshaling utilities. -// -// Parsed provides allocation-free minimal packet header decoding, for -// use in packet filtering. The other types in the package are for -// constructing and marshaling packets into []bytes. -// -// To support allocation-free parsing, this package defines IPv4 and -// IPv6 address types. You should prefer to use netaddr's types, -// except where you absolutely need allocation-free IP handling -// (i.e. in the tunnel datapath) and are willing to implement all -// codepaths and data structures twice, once per IP family. -package packet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package packet contains packet parsing and marshaling utilities. +// +// Parsed provides allocation-free minimal packet header decoding, for +// use in packet filtering. The other types in the package are for +// constructing and marshaling packets into []bytes. +// +// To support allocation-free parsing, this package defines IPv4 and +// IPv6 address types. You should prefer to use netaddr's types, +// except where you absolutely need allocation-free IP handling +// (i.e. in the tunnel datapath) and are willing to implement all +// codepaths and data structures twice, once per IP family. +package packet diff --git a/net/packet/header.go b/net/packet/header.go index 0b1947c0abc36..dbe84429adbd8 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "errors" - "math" -) - -const tcpHeaderLength = 20 -const sctpHeaderLength = 12 - -// maxPacketLength is the largest length that all headers support. -// IPv4 headers using uint16 for this forces an upper bound of 64KB. -const maxPacketLength = math.MaxUint16 - -var ( - // errSmallBuffer is returned when Marshal receives a buffer - // too small to contain the header to marshal. - errSmallBuffer = errors.New("buffer too small") - // errLargePacket is returned when Marshal receives a payload - // larger than the maximum representable size in header - // fields. - errLargePacket = errors.New("packet too large") -) - -// Header is a packet header capable of marshaling itself into a byte -// buffer. -type Header interface { - // Len returns the length of the marshaled packet. - Len() int - // Marshal serializes the header into buf, which must be at - // least Len() bytes long. Implementations of Marshal assume - // that bytes after the first Len() are payload bytes for the - // purpose of computing length and checksum fields. Marshal - // implementations must not allocate memory. - Marshal(buf []byte) error -} - -// HeaderChecksummer is implemented by Header implementations that -// need to do a checksum over their payloads. -type HeaderChecksummer interface { - Header - - // WriteCheck writes the correct checksum into buf, which should - // be be the already-marshalled header and payload. - WriteChecksum(buf []byte) -} - -// Generate generates a new packet with the given Header and -// payload. This function allocates memory, see Header.Marshal for an -// allocation-free option. -func Generate(h Header, payload []byte) []byte { - hlen := h.Len() - buf := make([]byte, hlen+len(payload)) - - copy(buf[hlen:], payload) - h.Marshal(buf) - - if hc, ok := h.(HeaderChecksummer); ok { - hc.WriteChecksum(buf) - } - - return buf -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "errors" + "math" +) + +const tcpHeaderLength = 20 +const sctpHeaderLength = 12 + +// maxPacketLength is the largest length that all headers support. +// IPv4 headers using uint16 for this forces an upper bound of 64KB. +const maxPacketLength = math.MaxUint16 + +var ( + // errSmallBuffer is returned when Marshal receives a buffer + // too small to contain the header to marshal. + errSmallBuffer = errors.New("buffer too small") + // errLargePacket is returned when Marshal receives a payload + // larger than the maximum representable size in header + // fields. + errLargePacket = errors.New("packet too large") +) + +// Header is a packet header capable of marshaling itself into a byte +// buffer. +type Header interface { + // Len returns the length of the marshaled packet. + Len() int + // Marshal serializes the header into buf, which must be at + // least Len() bytes long. Implementations of Marshal assume + // that bytes after the first Len() are payload bytes for the + // purpose of computing length and checksum fields. Marshal + // implementations must not allocate memory. + Marshal(buf []byte) error +} + +// HeaderChecksummer is implemented by Header implementations that +// need to do a checksum over their payloads. +type HeaderChecksummer interface { + Header + + // WriteCheck writes the correct checksum into buf, which should + // be be the already-marshalled header and payload. + WriteChecksum(buf []byte) +} + +// Generate generates a new packet with the given Header and +// payload. This function allocates memory, see Header.Marshal for an +// allocation-free option. +func Generate(h Header, payload []byte) []byte { + hlen := h.Len() + buf := make([]byte, hlen+len(payload)) + + copy(buf[hlen:], payload) + h.Marshal(buf) + + if hc, ok := h.(HeaderChecksummer); ok { + hc.WriteChecksum(buf) + } + + return buf +} diff --git a/net/packet/icmp.go b/net/packet/icmp.go index 7b86edd815384..89a7aaa32bec4 100644 --- a/net/packet/icmp.go +++ b/net/packet/icmp.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - crand "crypto/rand" - - "encoding/binary" -) - -// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 -// derived from them, along with the id, sequence and given payload in a buffer. -// It returns an error if the random source could not be read. -func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { - buf = make([]byte, len(payload)+4) - - // make a completely random id/sequence combo, which is very unlikely to - // collide with a running ping sequence on the host system. Errors are - // ignored, that would result in collisions, but errors reading from the - // random device are rare, and will cause this process universe to soon end. - crand.Read(buf[:4]) - - idSeq = binary.LittleEndian.Uint32(buf) - copy(buf[4:], payload) - - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + crand "crypto/rand" + + "encoding/binary" +) + +// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 +// derived from them, along with the id, sequence and given payload in a buffer. +// It returns an error if the random source could not be read. +func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { + buf = make([]byte, len(payload)+4) + + // make a completely random id/sequence combo, which is very unlikely to + // collide with a running ping sequence on the host system. Errors are + // ignored, that would result in collisions, but errors reading from the + // random device are rare, and will cause this process universe to soon end. + crand.Read(buf[:4]) + + idSeq = binary.LittleEndian.Uint32(buf) + copy(buf[4:], payload) + + return +} diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go index c2fab353a582d..f34883ca41e7e 100644 --- a/net/packet/icmp6_test.go +++ b/net/packet/icmp6_test.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" - - "tailscale.com/types/ipproto" -) - -func TestICMPv6PingResponse(t *testing.T) { - pingHdr := ICMP6Header{ - IP6Header: IP6Header{ - Src: netip.MustParseAddr("1::1"), - Dst: netip.MustParseAddr("2::2"), - IPProto: ipproto.ICMPv6, - }, - Type: ICMP6EchoRequest, - Code: ICMP6NoCode, - } - - // echoReqLen is 2 bytes identifier + 2 bytes seq number. - // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 - // Packet.IsEchoRequest verifies that these 4 bytes are present. - const echoReqLen = 4 - buf := make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - var p Parsed - p.Decode(buf) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - pingHdr.ToResponse() - buf = make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - p.Decode(buf) - if p.IsEchoRequest() { - t.Fatalf("unexpectedly still an echo request: %+v", p) - } - if !p.IsEchoResponse() { - t.Fatalf("not an echo response: %+v", p) - } -} - -func TestICMPv6Checksum(t *testing.T) { - const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - // The packet that we'd originally generated incorrectly, but with the checksum - // bytes fixed per WireShark's correct calculation: - const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - - var p Parsed - p.Decode([]byte(req)) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - h := p.ICMP6Header() - h.ToResponse() - pong := Generate(&h, p.Payload()) - - if string(pong) != wantRes { - t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" + + "tailscale.com/types/ipproto" +) + +func TestICMPv6PingResponse(t *testing.T) { + pingHdr := ICMP6Header{ + IP6Header: IP6Header{ + Src: netip.MustParseAddr("1::1"), + Dst: netip.MustParseAddr("2::2"), + IPProto: ipproto.ICMPv6, + }, + Type: ICMP6EchoRequest, + Code: ICMP6NoCode, + } + + // echoReqLen is 2 bytes identifier + 2 bytes seq number. + // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 + // Packet.IsEchoRequest verifies that these 4 bytes are present. + const echoReqLen = 4 + buf := make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(buf) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + pingHdr.ToResponse() + buf = make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + p.Decode(buf) + if p.IsEchoRequest() { + t.Fatalf("unexpectedly still an echo request: %+v", p) + } + if !p.IsEchoResponse() { + t.Fatalf("not an echo response: %+v", p) + } +} + +func TestICMPv6Checksum(t *testing.T) { + const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + // The packet that we'd originally generated incorrectly, but with the checksum + // bytes fixed per WireShark's correct calculation: + const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + + var p Parsed + p.Decode([]byte(req)) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + h := p.ICMP6Header() + h.ToResponse() + pong := Generate(&h, p.Payload()) + + if string(pong) != wantRes { + t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) + } +} diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 596bc766d9a17..967a8dba7f57b 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -1,116 +1,116 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "errors" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip4HeaderLength is the length of an IPv4 header with no IP options. -const ip4HeaderLength = 20 - -// IP4Header represents an IPv4 packet header. -type IP4Header struct { - IPProto ipproto.Proto - IPID uint16 - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP4Header) Len() int { - return ip4HeaderLength -} - -var errWrongFamily = errors.New("wrong address family for src/dst IP") - -// Marshal implements Header. -func (h IP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - if !h.Src.Is4() || !h.Dst.Is4() { - return errWrongFamily - } - - buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL - buf[1] = 0x00 // DSCP + ECN - binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length - binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID - binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset - buf[8] = 64 // TTL - buf[9] = uint8(h.IPProto) // Inner protocol - // Blank checksum. This is necessary even though we overwrite - // it later, because the checksum computation runs over these - // bytes and expects them to be zero. - binary.BigEndian.PutUint16(buf[10:12], 0) - src := h.Src.As4() - dst := h.Dst.As4() - copy(buf[12:16], src[:]) - copy(buf[16:20], dst[:]) - - binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum - - return nil -} - -// ToResponse implements Header. -func (h *IP4Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = ^h.IPID -} - -// ip4Checksum computes an IPv4 checksum, as specified in -// https://tools.ietf.org/html/rfc1071 -func ip4Checksum(b []byte) uint16 { - var ac uint32 - i := 0 - n := len(b) - for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint32(b[i]) << 8 - } - for (ac >> 16) > 0 { - ac = (ac >> 16) + (ac & 0xffff) - } - return uint16(^ac) -} - -// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP -// pseudo-header is smaller than the real IPv4 header. -const ip4PseudoHeaderOffset = 8 - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. The pseudo-header starts -// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP -// header, while leaving enough space in buf for a full IPv4 header. -func (h IP4Header) marshalPseudo(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - length := len(buf) - h.Len() - src, dst := h.Src.As4(), h.Dst.As4() - copy(buf[8:12], src[:]) - copy(buf[12:16], dst[:]) - buf[16] = 0x0 - buf[17] = uint8(h.IPProto) - binary.BigEndian.PutUint16(buf[18:20], uint16(length)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "errors" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip4HeaderLength is the length of an IPv4 header with no IP options. +const ip4HeaderLength = 20 + +// IP4Header represents an IPv4 packet header. +type IP4Header struct { + IPProto ipproto.Proto + IPID uint16 + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP4Header) Len() int { + return ip4HeaderLength +} + +var errWrongFamily = errors.New("wrong address family for src/dst IP") + +// Marshal implements Header. +func (h IP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + if !h.Src.Is4() || !h.Dst.Is4() { + return errWrongFamily + } + + buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL + buf[1] = 0x00 // DSCP + ECN + binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length + binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID + binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset + buf[8] = 64 // TTL + buf[9] = uint8(h.IPProto) // Inner protocol + // Blank checksum. This is necessary even though we overwrite + // it later, because the checksum computation runs over these + // bytes and expects them to be zero. + binary.BigEndian.PutUint16(buf[10:12], 0) + src := h.Src.As4() + dst := h.Dst.As4() + copy(buf[12:16], src[:]) + copy(buf[16:20], dst[:]) + + binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum + + return nil +} + +// ToResponse implements Header. +func (h *IP4Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = ^h.IPID +} + +// ip4Checksum computes an IPv4 checksum, as specified in +// https://tools.ietf.org/html/rfc1071 +func ip4Checksum(b []byte) uint16 { + var ac uint32 + i := 0 + n := len(b) + for n >= 2 { + ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint32(b[i]) << 8 + } + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(^ac) +} + +// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP +// pseudo-header is smaller than the real IPv4 header. +const ip4PseudoHeaderOffset = 8 + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. The pseudo-header starts +// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP +// header, while leaving enough space in buf for a full IPv4 header. +func (h IP4Header) marshalPseudo(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + length := len(buf) - h.Len() + src, dst := h.Src.As4(), h.Dst.As4() + copy(buf[8:12], src[:]) + copy(buf[12:16], dst[:]) + buf[16] = 0x0 + buf[17] = uint8(h.IPProto) + binary.BigEndian.PutUint16(buf[18:20], uint16(length)) + return nil +} diff --git a/net/packet/ip6.go b/net/packet/ip6.go index cebc46c534c04..d26b9a1619b31 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -1,76 +1,76 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip6HeaderLength is the length of an IPv6 header with no IP options. -const ip6HeaderLength = 40 - -// IP6Header represents an IPv6 packet header. -type IP6Header struct { - IPProto ipproto.Proto - IPID uint32 // only lower 20 bits used - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP6Header) Len() int { - return ip6HeaderLength -} - -// Marshal implements Header. -func (h IP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) - buf[0] = 0x60 - binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length - buf[6] = uint8(h.IPProto) // Inner protocol - buf[7] = 64 // TTL - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[8:24], src[:]) - copy(buf[24:40], dst[:]) - - return nil -} - -// ToResponse implements Header. -func (h *IP6Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = (^h.IPID) & 0x000FFFFF -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. -func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[:16], src[:]) - copy(buf[16:32], dst[:]) - binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) - buf[36] = 0 - buf[37] = 0 - buf[38] = 0 - buf[39] = byte(proto) // NextProto - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip6HeaderLength is the length of an IPv6 header with no IP options. +const ip6HeaderLength = 40 + +// IP6Header represents an IPv6 packet header. +type IP6Header struct { + IPProto ipproto.Proto + IPID uint32 // only lower 20 bits used + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP6Header) Len() int { + return ip6HeaderLength +} + +// Marshal implements Header. +func (h IP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) + buf[0] = 0x60 + binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length + buf[6] = uint8(h.IPProto) // Inner protocol + buf[7] = 64 // TTL + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[8:24], src[:]) + copy(buf[24:40], dst[:]) + + return nil +} + +// ToResponse implements Header. +func (h *IP6Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = (^h.IPID) & 0x000FFFFF +} + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. +func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[:16], src[:]) + copy(buf[16:32], dst[:]) + binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) + buf[36] = 0 + buf[37] = 0 + buf[38] = 0 + buf[39] = byte(proto) // NextProto + return nil +} diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index 4ec24e1ea0a4c..e261e6a4199b3 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" -) - -func TestTailscaleRejectedHeader(t *testing.T) { - tests := []struct { - h TailscaleRejectedHeader - wantStr string - }{ - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("5.5.5.5"), - IPDst: netip.MustParseAddr("1.2.3.4"), - Src: netip.MustParseAddrPort("1.2.3.4:567"), - Dst: netip.MustParseAddrPort("5.5.5.5:443"), - Proto: TCP, - Reason: RejectedDueToACLs, - }, - wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToShieldsUp, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToIPForwarding, - MaybeBroken: true, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", - }, - } - for i, tt := range tests { - gotStr := tt.h.String() - if gotStr != tt.wantStr { - t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) - continue - } - pkt := make([]byte, tt.h.Len()) - tt.h.Marshal(pkt) - - var p Parsed - p.Decode(pkt) - t.Logf("Parsed: %+v", p) - t.Logf("Parsed: %s", p.String()) - back, ok := p.AsTailscaleRejectedHeader() - if !ok { - t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) - continue - } - if back != tt.h { - t.Errorf("%v. %q parsed back as %q", i, tt.h, back) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" +) + +func TestTailscaleRejectedHeader(t *testing.T) { + tests := []struct { + h TailscaleRejectedHeader + wantStr string + }{ + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("5.5.5.5"), + IPDst: netip.MustParseAddr("1.2.3.4"), + Src: netip.MustParseAddrPort("1.2.3.4:567"), + Dst: netip.MustParseAddrPort("5.5.5.5:443"), + Proto: TCP, + Reason: RejectedDueToACLs, + }, + wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToShieldsUp, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToIPForwarding, + MaybeBroken: true, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", + }, + } + for i, tt := range tests { + gotStr := tt.h.String() + if gotStr != tt.wantStr { + t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) + continue + } + pkt := make([]byte, tt.h.Len()) + tt.h.Marshal(pkt) + + var p Parsed + p.Decode(pkt) + t.Logf("Parsed: %+v", p) + t.Logf("Parsed: %s", p.String()) + back, ok := p.AsTailscaleRejectedHeader() + if !ok { + t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) + continue + } + if back != tt.h { + t.Errorf("%v. %q parsed back as %q", i, tt.h, back) + } + } +} diff --git a/net/packet/udp4.go b/net/packet/udp4.go index c8761baef2d36..0d5bca73e8c89 100644 --- a/net/packet/udp4.go +++ b/net/packet/udp4.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// udpHeaderLength is the size of the UDP packet header, not including -// the outer IP header. -const udpHeaderLength = 8 - -// UDP4Header is an IPv4+UDP header. -type UDP4Header struct { - IP4Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP4Header) Len() int { - return h.IP4Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP4Header.Len() - binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) - binary.BigEndian.PutUint16(buf[22:24], h.DstPort) - binary.BigEndian.PutUint16(buf[24:26], uint16(length)) - binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP4Header.marshalPseudo(buf) - binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) - - h.IP4Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP4Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP4Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// udpHeaderLength is the size of the UDP packet header, not including +// the outer IP header. +const udpHeaderLength = 8 + +// UDP4Header is an IPv4+UDP header. +type UDP4Header struct { + IP4Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP4Header) Len() int { + return h.IP4Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP4Header.Len() + binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) + binary.BigEndian.PutUint16(buf[22:24], h.DstPort) + binary.BigEndian.PutUint16(buf[24:26], uint16(length)) + binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP4Header.marshalPseudo(buf) + binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) + + h.IP4Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP4Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP4Header.ToResponse() +} diff --git a/net/packet/udp6.go b/net/packet/udp6.go index c8634b5080aea..10fdcb99e525c 100644 --- a/net/packet/udp6.go +++ b/net/packet/udp6.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// UDP6Header is an IPv6+UDP header. -type UDP6Header struct { - IP6Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP6Header) Len() int { - return h.IP6Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP6Header.Len() - binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) - binary.BigEndian.PutUint16(buf[42:44], h.DstPort) - binary.BigEndian.PutUint16(buf[44:46], uint16(length)) - binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP6Header.marshalPseudo(buf, ipproto.UDP) - binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) - - h.IP6Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP6Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP6Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// UDP6Header is an IPv6+UDP header. +type UDP6Header struct { + IP6Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP6Header) Len() int { + return h.IP6Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP6Header.Len() + binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) + binary.BigEndian.PutUint16(buf[42:44], h.DstPort) + binary.BigEndian.PutUint16(buf[44:46], uint16(length)) + binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP6Header.marshalPseudo(buf, ipproto.UDP) + binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) + + h.IP6Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP6Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP6Header.ToResponse() +} diff --git a/net/ping/ping.go b/net/ping/ping.go index f2093292a7a2c..01f3dcf2c4976 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -1,343 +1,343 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ping allows sending ICMP echo requests to a host in order to -// determine network latency. -package ping - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/types/logger" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" -) - -const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" -) - -type response struct { - t time.Time - err error -} - -type outstanding struct { - ch chan response - data []byte -} - -// PacketListener defines the interface required to listen to packages -// on an address. -type ListenPacketer interface { - ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) -} - -// Pinger represents a set of ICMP echo requests to be sent at a single time. -// -// A new instance should be created for each concurrent set of ping requests; -// this type should not be reused. -type Pinger struct { - lp ListenPacketer - - // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool - timeNow func() time.Time - id uint16 // uint16 per RFC 792 - wg sync.WaitGroup - - // Following fields protected by mu - mu sync.Mutex - // conns is a map of "type" to net.PacketConn, type is either - // "ip4:icmp" or "ip6:icmp" - conns map[string]net.PacketConn - seq uint16 // uint16 per RFC 792 - pings map[uint16]outstanding -} - -// New creates a new Pinger. The Context provided will be used to create -// network listeners, and to set an absolute deadline (if any) on the net.Conn -func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { - var id [2]byte - if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { - panic("net/ping: New:" + err.Error()) - } - - return &Pinger{ - lp: lp, - Logf: logf, - timeNow: time.Now, - id: binary.LittleEndian.Uint16(id[:]), - pings: make(map[uint16]outstanding), - } -} - -func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { - if p.closed.Load() { - return nil, net.ErrClosed - } - - c, err := p.lp.ListenPacket(ctx, typ, addr) - if err != nil { - return nil, err - } - - // Start by setting the deadline from the context; note that this - // applies to all future I/O, so we only need to do it once. - deadline, ok := ctx.Deadline() - if ok { - if err := c.SetReadDeadline(deadline); err != nil { - return nil, err - } - } - - p.wg.Add(1) - go p.run(ctx, c, typ) - - return c, err -} - -// getConn creates or returns a conn matching typ which is ip4:icmp -// or ip6:icmp. -func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if c, ok := p.conns[typ]; ok { - return c, nil - } - - var addr = "0.0.0.0" - if typ == v6Type { - addr = "::" - } - c, err := p.mkconn(ctx, typ, addr) - if err != nil { - return nil, err - } - mak.Set(&p.conns, typ, c) - return c, nil -} - -func (p *Pinger) logf(format string, a ...any) { - if p.Logf != nil { - p.Logf(format, a...) - } else { - log.Printf(format, a...) - } -} - -func (p *Pinger) vlogf(format string, a ...any) { - if p.Verbose { - p.logf(format, a...) - } -} - -func (p *Pinger) Close() error { - p.closed.Store(true) - - p.mu.Lock() - conns := p.conns - p.conns = nil - p.mu.Unlock() - - var errors []error - for _, c := range conns { - if err := c.Close(); err != nil { - errors = append(errors, err) - } - } - - p.wg.Wait() - p.cleanupOutstanding() - - return multierr.New(errors...) -} - -func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { - defer p.wg.Done() - defer func() { - conn.Close() - p.mu.Lock() - delete(p.conns, typ) - p.mu.Unlock() - }() - buf := make([]byte, 1500) - -loop: - for { - select { - case <-ctx.Done(): - break loop - default: - } - - n, _, err := conn.ReadFrom(buf) - if err != nil { - // Ignore temporary errors; everything else is fatal - if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { - break - } - continue - } - - p.handleResponse(buf[:n], p.timeNow(), typ) - } -} - -func (p *Pinger) cleanupOutstanding() { - // Complete outstanding requests - p.mu.Lock() - defer p.mu.Unlock() - for _, o := range p.pings { - o.ch <- response{err: net.ErrClosed} - } -} - -func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { - // We need to handle responding to both IPv4 - // and IPv6. - var icmpType icmp.Type - switch typ { - case v4Type: - icmpType = ipv4.ICMPTypeEchoReply - case v6Type: - icmpType = ipv6.ICMPTypeEchoReply - default: - p.vlogf("handleResponse: unknown icmp.Type") - return - } - - m, err := icmp.ParseMessage(icmpType.Protocol(), buf) - if err != nil { - p.vlogf("handleResponse: invalid packet: %v", err) - return - } - - if m.Type != icmpType { - p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) - return - } - - resp, ok := m.Body.(*icmp.Echo) - if !ok || resp == nil { - p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) - return - } - - // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { - p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) - return - } - - // Search for existing running echo request - var o outstanding - p.mu.Lock() - if o, ok = p.pings[uint16(resp.Seq)]; ok { - // Ensure that the data matches before we delete from our map, - // so a future correct packet will be handled correctly. - if bytes.Equal(resp.Data, o.data) { - delete(p.pings, uint16(resp.Seq)) - } else { - p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) - ok = false - } - } else { - p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) - } - p.mu.Unlock() - - if ok { - o.ch <- response{t: now} - } -} - -// Send sends an ICMP Echo Request packet to the destination, waits for a -// response, and returns the duration between when the request was sent and -// when the reply was received. -// -// If provided, "data" is sent with the packet and is compared upon receiving a -// reply. -func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { - // Use sequential sequence numbers on the assumption that we will not - // wrap around when using a single Pinger instance - p.mu.Lock() - p.seq++ - seq := p.seq - p.mu.Unlock() - - // Check whether the address is IPv4 or IPv6 to - // determine the icmp.Type and conn to use. - var conn net.PacketConn - var icmpType icmp.Type = ipv4.ICMPTypeEcho - ap, err := netip.ParseAddr(dest.String()) - if err != nil { - return 0, err - } - if ap.Is6() { - icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) - } else { - conn, err = p.getConn(ctx, v4Type) - } - if err != nil { - return 0, err - } - - m := icmp.Message{ - Type: icmpType, - Code: 0, - Body: &icmp.Echo{ - ID: int(p.id), - Seq: int(seq), - Data: data, - }, - } - b, err := m.Marshal(nil) - if err != nil { - return 0, err - } - - // Register our response before sending since we could otherwise race a - // quick reply. - ch := make(chan response, 1) - p.mu.Lock() - p.pings[seq] = outstanding{ch: ch, data: data} - p.mu.Unlock() - - start := p.timeNow() - n, err := conn.WriteTo(b, dest) - if err != nil { - return 0, err - } else if n != len(b) { - return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) - } - - select { - case resp := <-ch: - if resp.err != nil { - return 0, resp.err - } - return resp.t.Sub(start), nil - - case <-ctx.Done(): - return 0, ctx.Err() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ping allows sending ICMP echo requests to a host in order to +// determine network latency. +package ping + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" +) + +const ( + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" +) + +type response struct { + t time.Time + err error +} + +type outstanding struct { + ch chan response + data []byte +} + +// PacketListener defines the interface required to listen to packages +// on an address. +type ListenPacketer interface { + ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) +} + +// Pinger represents a set of ICMP echo requests to be sent at a single time. +// +// A new instance should be created for each concurrent set of ping requests; +// this type should not be reused. +type Pinger struct { + lp ListenPacketer + + // closed guards against send incrementing the waitgroup concurrently with close. + closed atomic.Bool + Logf logger.Logf + Verbose bool + timeNow func() time.Time + id uint16 // uint16 per RFC 792 + wg sync.WaitGroup + + // Following fields protected by mu + mu sync.Mutex + // conns is a map of "type" to net.PacketConn, type is either + // "ip4:icmp" or "ip6:icmp" + conns map[string]net.PacketConn + seq uint16 // uint16 per RFC 792 + pings map[uint16]outstanding +} + +// New creates a new Pinger. The Context provided will be used to create +// network listeners, and to set an absolute deadline (if any) on the net.Conn +func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { + var id [2]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + panic("net/ping: New:" + err.Error()) + } + + return &Pinger{ + lp: lp, + Logf: logf, + timeNow: time.Now, + id: binary.LittleEndian.Uint16(id[:]), + pings: make(map[uint16]outstanding), + } +} + +func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { + if p.closed.Load() { + return nil, net.ErrClosed + } + + c, err := p.lp.ListenPacket(ctx, typ, addr) + if err != nil { + return nil, err + } + + // Start by setting the deadline from the context; note that this + // applies to all future I/O, so we only need to do it once. + deadline, ok := ctx.Deadline() + if ok { + if err := c.SetReadDeadline(deadline); err != nil { + return nil, err + } + } + + p.wg.Add(1) + go p.run(ctx, c, typ) + + return c, err +} + +// getConn creates or returns a conn matching typ which is ip4:icmp +// or ip6:icmp. +func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if c, ok := p.conns[typ]; ok { + return c, nil + } + + var addr = "0.0.0.0" + if typ == v6Type { + addr = "::" + } + c, err := p.mkconn(ctx, typ, addr) + if err != nil { + return nil, err + } + mak.Set(&p.conns, typ, c) + return c, nil +} + +func (p *Pinger) logf(format string, a ...any) { + if p.Logf != nil { + p.Logf(format, a...) + } else { + log.Printf(format, a...) + } +} + +func (p *Pinger) vlogf(format string, a ...any) { + if p.Verbose { + p.logf(format, a...) + } +} + +func (p *Pinger) Close() error { + p.closed.Store(true) + + p.mu.Lock() + conns := p.conns + p.conns = nil + p.mu.Unlock() + + var errors []error + for _, c := range conns { + if err := c.Close(); err != nil { + errors = append(errors, err) + } + } + + p.wg.Wait() + p.cleanupOutstanding() + + return multierr.New(errors...) +} + +func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { + defer p.wg.Done() + defer func() { + conn.Close() + p.mu.Lock() + delete(p.conns, typ) + p.mu.Unlock() + }() + buf := make([]byte, 1500) + +loop: + for { + select { + case <-ctx.Done(): + break loop + default: + } + + n, _, err := conn.ReadFrom(buf) + if err != nil { + // Ignore temporary errors; everything else is fatal + if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { + break + } + continue + } + + p.handleResponse(buf[:n], p.timeNow(), typ) + } +} + +func (p *Pinger) cleanupOutstanding() { + // Complete outstanding requests + p.mu.Lock() + defer p.mu.Unlock() + for _, o := range p.pings { + o.ch <- response{err: net.ErrClosed} + } +} + +func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { + // We need to handle responding to both IPv4 + // and IPv6. + var icmpType icmp.Type + switch typ { + case v4Type: + icmpType = ipv4.ICMPTypeEchoReply + case v6Type: + icmpType = ipv6.ICMPTypeEchoReply + default: + p.vlogf("handleResponse: unknown icmp.Type") + return + } + + m, err := icmp.ParseMessage(icmpType.Protocol(), buf) + if err != nil { + p.vlogf("handleResponse: invalid packet: %v", err) + return + } + + if m.Type != icmpType { + p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) + return + } + + resp, ok := m.Body.(*icmp.Echo) + if !ok || resp == nil { + p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) + return + } + + // We assume we sent this if the ID in the response is ours. + if uint16(resp.ID) != p.id { + p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) + return + } + + // Search for existing running echo request + var o outstanding + p.mu.Lock() + if o, ok = p.pings[uint16(resp.Seq)]; ok { + // Ensure that the data matches before we delete from our map, + // so a future correct packet will be handled correctly. + if bytes.Equal(resp.Data, o.data) { + delete(p.pings, uint16(resp.Seq)) + } else { + p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) + ok = false + } + } else { + p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) + } + p.mu.Unlock() + + if ok { + o.ch <- response{t: now} + } +} + +// Send sends an ICMP Echo Request packet to the destination, waits for a +// response, and returns the duration between when the request was sent and +// when the reply was received. +// +// If provided, "data" is sent with the packet and is compared upon receiving a +// reply. +func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { + // Use sequential sequence numbers on the assumption that we will not + // wrap around when using a single Pinger instance + p.mu.Lock() + p.seq++ + seq := p.seq + p.mu.Unlock() + + // Check whether the address is IPv4 or IPv6 to + // determine the icmp.Type and conn to use. + var conn net.PacketConn + var icmpType icmp.Type = ipv4.ICMPTypeEcho + ap, err := netip.ParseAddr(dest.String()) + if err != nil { + return 0, err + } + if ap.Is6() { + icmpType = ipv6.ICMPTypeEchoRequest + conn, err = p.getConn(ctx, v6Type) + } else { + conn, err = p.getConn(ctx, v4Type) + } + if err != nil { + return 0, err + } + + m := icmp.Message{ + Type: icmpType, + Code: 0, + Body: &icmp.Echo{ + ID: int(p.id), + Seq: int(seq), + Data: data, + }, + } + b, err := m.Marshal(nil) + if err != nil { + return 0, err + } + + // Register our response before sending since we could otherwise race a + // quick reply. + ch := make(chan response, 1) + p.mu.Lock() + p.pings[seq] = outstanding{ch: ch, data: data} + p.mu.Unlock() + + start := p.timeNow() + n, err := conn.WriteTo(b, dest) + if err != nil { + return 0, err + } else if n != len(b) { + return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) + } + + select { + case resp := <-ch: + if resp.err != nil { + return 0, resp.err + } + return resp.t.Sub(start), nil + + case <-ctx.Done(): + return 0, ctx.Err() + } +} diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go index 5232f6ada85e0..bbedbcad80e44 100644 --- a/net/ping/ping_test.go +++ b/net/ping/ping_test.go @@ -1,350 +1,350 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ping - -import ( - "context" - "errors" - "fmt" - "net" - "testing" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/tstest" - "tailscale.com/util/mak" -) - -var ( - localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} -) - -func TestPinger(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: ipv4.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestV6Pinger(t *testing.T) { - if c, err := net.ListenPacket("udp6", "::1"); err != nil { - // skip test if we can't use IPv6. - t.Skipf("IPv6 not supported: %s", err) - } else { - c.Close() - } - - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv6.ICMPTypeEchoReply, - Code: ipv6.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestPingerTimeout(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - clock := &tstest.Clock{} - p, closeP := mockPinger(t, clock) - defer closeP() - - // Send a ping in the background - r := make(chan error, 1) - go func() { - _, err := p.Send(ctx, localhost, []byte("data goes here")) - r <- err - }() - - // Wait until we're blocking - p.waitOutstanding(t, ctx, 1) - - // Close everything down - p.cleanupOutstanding() - - // Should have got an error from the ping - err := <-r - if !errors.Is(err, net.ErrClosed) { - t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) - } -} - -func TestPingerMismatch(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // "Receive" a bunch of intentionally malformed packets that should not - // result in the Send call above returning - badPackets := []struct { - name string - pkt *icmp.Message - }{ - { - name: "wrong type", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeDestinationUnreachable, - Code: 0, - Body: &icmp.DstUnreach{}, - }, - }, - { - name: "wrong id", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 9999, - Seq: 1, - Data: bodyData, - }, - }, - }, - { - name: "wrong seq", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 5, - Data: bodyData, - }, - }, - }, - { - name: "bad body", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - - // Intentionally missing first byte - Data: bodyData[1:], - }, - }, - }, - } - - const fakeDuration = 100 * time.Millisecond - tm := clock.Now().Add(fakeDuration) - - for _, tt := range badPackets { - fakeResponse := mustMarshal(t, tt.pkt) - p.handleResponse(fakeResponse, tm, v4Type) - } - - // Also "receive" a packet that does not unmarshal as an ICMP packet - p.handleResponse([]byte("foo"), tm, v4Type) - - select { - case <-r: - t.Fatal("wanted timeout") - case <-ctx.Done(): - t.Logf("test correctly timed out") - } -} - -// udpingPacketConn will convert potentially ICMP destination addrs to UDP -// destination addrs in WriteTo so that a test that is intending to send ICMP -// traffic will instead send UDP traffic, without the higher level Pinger being -// aware of this difference. -type udpingPacketConn struct { - net.PacketConn - // destPort will be configured by the test to be the peer expected to respond to a ping. - destPort uint16 -} - -func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { - switch d := dest.(type) { - case *net.IPAddr: - udpAddr := &net.UDPAddr{ - IP: d.IP, - Port: int(u.destPort), - Zone: d.Zone, - } - return u.PacketConn.WriteTo(body, udpAddr) - } - return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) -} - -func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { - p := New(context.Background(), t.Logf, nil) - p.timeNow = clock.Now - p.Verbose = true - p.id = 1234 - - // In tests, we use UDP so that we can test without being root; this - // doesn't matter because we mock out the ICMP reply below to be a real - // ICMP echo reply packet. - conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn6, err := net.ListenPacket("udp6", "[::]:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn4 = &udpingPacketConn{ - destPort: 12345, - PacketConn: conn4, - } - conn6 = &udpingPacketConn{ - PacketConn: conn6, - destPort: 12345, - } - - mak.Set(&p.conns, v4Type, conn4) - mak.Set(&p.conns, v6Type, conn6) - done := func() { - if err := p.Close(); err != nil { - t.Errorf("error on close: %v", err) - } - } - return p, done -} - -func mustMarshal(t *testing.T, m *icmp.Message) []byte { - t.Helper() - - b, err := m.Marshal(nil) - if err != nil { - t.Fatal(err) - } - return b -} - -func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { - // This is a bit janky, but... we busy-loop to wait for the Send call - // to write to our map so we know that a response will be handled. - var haveMapEntry bool - for !haveMapEntry { - time.Sleep(10 * time.Millisecond) - select { - case <-ctx.Done(): - t.Error("no entry in ping map before timeout") - return - default: - } - - p.mu.Lock() - haveMapEntry = len(p.pings) == count - p.mu.Unlock() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ping + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/tstest" + "tailscale.com/util/mak" +) + +var ( + localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} +) + +func TestPinger(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: ipv4.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestV6Pinger(t *testing.T) { + if c, err := net.ListenPacket("udp6", "::1"); err != nil { + // skip test if we can't use IPv6. + t.Skipf("IPv6 not supported: %s", err) + } else { + c.Close() + } + + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: ipv6.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestPingerTimeout(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + clock := &tstest.Clock{} + p, closeP := mockPinger(t, clock) + defer closeP() + + // Send a ping in the background + r := make(chan error, 1) + go func() { + _, err := p.Send(ctx, localhost, []byte("data goes here")) + r <- err + }() + + // Wait until we're blocking + p.waitOutstanding(t, ctx, 1) + + // Close everything down + p.cleanupOutstanding() + + // Should have got an error from the ping + err := <-r + if !errors.Is(err, net.ErrClosed) { + t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) + } +} + +func TestPingerMismatch(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // "Receive" a bunch of intentionally malformed packets that should not + // result in the Send call above returning + badPackets := []struct { + name string + pkt *icmp.Message + }{ + { + name: "wrong type", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeDestinationUnreachable, + Code: 0, + Body: &icmp.DstUnreach{}, + }, + }, + { + name: "wrong id", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 9999, + Seq: 1, + Data: bodyData, + }, + }, + }, + { + name: "wrong seq", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 5, + Data: bodyData, + }, + }, + }, + { + name: "bad body", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + + // Intentionally missing first byte + Data: bodyData[1:], + }, + }, + }, + } + + const fakeDuration = 100 * time.Millisecond + tm := clock.Now().Add(fakeDuration) + + for _, tt := range badPackets { + fakeResponse := mustMarshal(t, tt.pkt) + p.handleResponse(fakeResponse, tm, v4Type) + } + + // Also "receive" a packet that does not unmarshal as an ICMP packet + p.handleResponse([]byte("foo"), tm, v4Type) + + select { + case <-r: + t.Fatal("wanted timeout") + case <-ctx.Done(): + t.Logf("test correctly timed out") + } +} + +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + +func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { + p := New(context.Background(), t.Logf, nil) + p.timeNow = clock.Now + p.Verbose = true + p.id = 1234 + + // In tests, we use UDP so that we can test without being root; this + // doesn't matter because we mock out the ICMP reply below to be a real + // ICMP echo reply packet. + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn4 = &udpingPacketConn{ + destPort: 12345, + PacketConn: conn4, + } + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: 12345, + } + + mak.Set(&p.conns, v4Type, conn4) + mak.Set(&p.conns, v6Type, conn6) + done := func() { + if err := p.Close(); err != nil { + t.Errorf("error on close: %v", err) + } + } + return p, done +} + +func mustMarshal(t *testing.T, m *icmp.Message) []byte { + t.Helper() + + b, err := m.Marshal(nil) + if err != nil { + t.Fatal(err) + } + return b +} + +func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { + // This is a bit janky, but... we busy-loop to wait for the Send call + // to write to our map so we know that a response will be handled. + var haveMapEntry bool + for !haveMapEntry { + time.Sleep(10 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("no entry in ping map before timeout") + return + default: + } + + p.mu.Lock() + haveMapEntry = len(p.pings) == count + p.mu.Unlock() + } +} diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go index 3dece72367423..8f8eef3ef8399 100644 --- a/net/portmapper/pcp_test.go +++ b/net/portmapper/pcp_test.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portmapper - -import ( - "encoding/binary" - "net/netip" - "testing" - - "tailscale.com/net/netaddr" -) - -var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} - -func TestParsePCPMapResponse(t *testing.T) { - mapping, err := parsePCPMapResponse(examplePCPMapResponse) - if err != nil { - t.Fatalf("failed to parse PCP Map Response: %v", err) - } - if mapping == nil { - t.Fatalf("got nil mapping when expected non-nil") - } - expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") - if mapping.external != expectedAddr { - t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) - } -} - -const ( - serverResponseBit = 1 << 7 - fakeLifetimeSec = 1<<31 - 1 -) - -func buildPCPDiscoResponse(req []byte) []byte { - out := make([]byte, 24) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - return out -} - -func buildPCPMapResponse(req []byte) []byte { - out := make([]byte, 24+36) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - binary.BigEndian.PutUint32(out[4:8], 1<<30) - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - mapResp := out[24:] - mapReq := req[24:] - // copy nonce, protocol and internal port - copy(mapResp[:13], mapReq[:13]) - copy(mapResp[16:18], mapReq[16:18]) - // assign external port - binary.BigEndian.PutUint16(mapResp[18:20], 4242) - assignedIP := netaddr.IPv4(127, 0, 0, 1) - assignedIP16 := assignedIP.As16() - copy(mapResp[20:36], assignedIP16[:]) - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portmapper + +import ( + "encoding/binary" + "net/netip" + "testing" + + "tailscale.com/net/netaddr" +) + +var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} + +func TestParsePCPMapResponse(t *testing.T) { + mapping, err := parsePCPMapResponse(examplePCPMapResponse) + if err != nil { + t.Fatalf("failed to parse PCP Map Response: %v", err) + } + if mapping == nil { + t.Fatalf("got nil mapping when expected non-nil") + } + expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") + if mapping.external != expectedAddr { + t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) + } +} + +const ( + serverResponseBit = 1 << 7 + fakeLifetimeSec = 1<<31 - 1 +) + +func buildPCPDiscoResponse(req []byte) []byte { + out := make([]byte, 24) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + return out +} + +func buildPCPMapResponse(req []byte) []byte { + out := make([]byte, 24+36) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + binary.BigEndian.PutUint32(out[4:8], 1<<30) + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + mapResp := out[24:] + mapReq := req[24:] + // copy nonce, protocol and internal port + copy(mapResp[:13], mapReq[:13]) + copy(mapResp[16:18], mapReq[16:18]) + // assign external port + binary.BigEndian.PutUint16(mapResp[18:20], 4242) + assignedIP := netaddr.IPv4(127, 0, 0, 1) + assignedIP16 := assignedIP.As16() + copy(mapResp[20:36], assignedIP16[:]) + return out +} diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go index 12c3107de8339..ff5aaff3b975f 100644 --- a/net/proxymux/mux.go +++ b/net/proxymux/mux.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package proxymux splits a net.Listener in two, routing SOCKS5 -// connections to one and HTTP requests to the other. -// -// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the -// same listener. -package proxymux - -import ( - "io" - "net" - "sync" - "time" -) - -// SplitSOCKSAndHTTP accepts connections on ln and passes connections -// through to either socksListener or httpListener, depending the -// first byte sent by the client. -func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { - sl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - hl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - - go splitSOCKSAndHTTPListener(ln, sl, hl) - - return sl, hl -} - -func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { - for { - conn, err := ln.Accept() - if err != nil { - sl.Close() - hl.Close() - return - } - go routeConn(conn, sl, hl) - } -} - -func routeConn(c net.Conn, socksListener, httpListener *listener) { - if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { - c.Close() - return - } - - var b [1]byte - if _, err := io.ReadFull(c, b[:]); err != nil { - c.Close() - return - } - - if err := c.SetReadDeadline(time.Time{}); err != nil { - c.Close() - return - } - - conn := &connWithOneByte{ - Conn: c, - b: b[0], - } - - // First byte of a SOCKS5 session is a version byte set to 5. - var ln *listener - if b[0] == 5 { - ln = socksListener - } else { - ln = httpListener - } - select { - case ln.c <- conn: - case <-ln.closed: - c.Close() - } -} - -type listener struct { - addr net.Addr - c chan net.Conn - mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. - closed chan struct{} -} - -func (ln *listener) Accept() (net.Conn, error) { - // Once closed, reliably stay closed, don't race with attempts at - // further connections. - select { - case <-ln.closed: - return nil, net.ErrClosed - default: - } - select { - case ret := <-ln.c: - return ret, nil - case <-ln.closed: - return nil, net.ErrClosed - } -} - -func (ln *listener) Close() error { - ln.mu.Lock() - defer ln.mu.Unlock() - select { - case <-ln.closed: - // Already closed - default: - close(ln.closed) - } - return nil -} - -func (ln *listener) Addr() net.Addr { - return ln.addr -} - -// connWithOneByte is a net.Conn that returns b for the first read -// request, then forwards everything else to Conn. -type connWithOneByte struct { - net.Conn - - b byte - bRead bool -} - -func (c *connWithOneByte) Read(bs []byte) (int, error) { - if c.bRead { - return c.Conn.Read(bs) - } - if len(bs) == 0 { - return 0, nil - } - c.bRead = true - bs[0] = c.b - return 1, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package proxymux splits a net.Listener in two, routing SOCKS5 +// connections to one and HTTP requests to the other. +// +// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the +// same listener. +package proxymux + +import ( + "io" + "net" + "sync" + "time" +) + +// SplitSOCKSAndHTTP accepts connections on ln and passes connections +// through to either socksListener or httpListener, depending the +// first byte sent by the client. +func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { + sl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + hl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + + go splitSOCKSAndHTTPListener(ln, sl, hl) + + return sl, hl +} + +func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { + for { + conn, err := ln.Accept() + if err != nil { + sl.Close() + hl.Close() + return + } + go routeConn(conn, sl, hl) + } +} + +func routeConn(c net.Conn, socksListener, httpListener *listener) { + if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { + c.Close() + return + } + + var b [1]byte + if _, err := io.ReadFull(c, b[:]); err != nil { + c.Close() + return + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + c.Close() + return + } + + conn := &connWithOneByte{ + Conn: c, + b: b[0], + } + + // First byte of a SOCKS5 session is a version byte set to 5. + var ln *listener + if b[0] == 5 { + ln = socksListener + } else { + ln = httpListener + } + select { + case ln.c <- conn: + case <-ln.closed: + c.Close() + } +} + +type listener struct { + addr net.Addr + c chan net.Conn + mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. + closed chan struct{} +} + +func (ln *listener) Accept() (net.Conn, error) { + // Once closed, reliably stay closed, don't race with attempts at + // further connections. + select { + case <-ln.closed: + return nil, net.ErrClosed + default: + } + select { + case ret := <-ln.c: + return ret, nil + case <-ln.closed: + return nil, net.ErrClosed + } +} + +func (ln *listener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + select { + case <-ln.closed: + // Already closed + default: + close(ln.closed) + } + return nil +} + +func (ln *listener) Addr() net.Addr { + return ln.addr +} + +// connWithOneByte is a net.Conn that returns b for the first read +// request, then forwards everything else to Conn. +type connWithOneByte struct { + net.Conn + + b byte + bRead bool +} + +func (c *connWithOneByte) Read(bs []byte) (int, error) { + if c.bRead { + return c.Conn.Read(bs) + } + if len(bs) == 0 { + return 0, nil + } + c.bRead = true + bs[0] = c.b + return 1, nil +} diff --git a/net/routetable/routetable_darwin.go b/net/routetable/routetable_darwin.go index 7de80a66229e9..7f525ae32807a 100644 --- a/net/routetable/routetable_darwin.go +++ b/net/routetable/routetable_darwin.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP2 - parseType = unix.NET_RT_IFLIST2 - rmExpectedType = unix.RTM_GET2 - - // Skip routes that were cloned from a parent - skipFlags = unix.RTF_WASCLONED -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_GLOBAL: "global", - unix.RTF_HOST: "host", - unix.RTF_IFSCOPE: "ifscope", - unix.RTF_LOCAL: "local", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_ROUTER: "router", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", - // More obscure flags, just to have full coverage. - unix.RTF_LLINFO: "{RTF_LLINFO}", - unix.RTF_PRCLONING: "{RTF_PRCLONING}", - unix.RTF_CLONING: "{RTF_CLONING}", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP2 + parseType = unix.NET_RT_IFLIST2 + rmExpectedType = unix.RTM_GET2 + + // Skip routes that were cloned from a parent + skipFlags = unix.RTF_WASCLONED +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_GLOBAL: "global", + unix.RTF_HOST: "host", + unix.RTF_IFSCOPE: "ifscope", + unix.RTF_LOCAL: "local", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_ROUTER: "router", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", + // More obscure flags, just to have full coverage. + unix.RTF_LLINFO: "{RTF_LLINFO}", + unix.RTF_PRCLONING: "{RTF_PRCLONING}", + unix.RTF_CLONING: "{RTF_CLONING}", +} diff --git a/net/routetable/routetable_freebsd.go b/net/routetable/routetable_freebsd.go index aa4e03c41236a..8e57a330246ed 100644 --- a/net/routetable/routetable_freebsd.go +++ b/net/routetable/routetable_freebsd.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP - parseType = unix.NET_RT_IFLIST - rmExpectedType = unix.RTM_GET - - // Nothing to skip - skipFlags = 0 -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_HOST: "host", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP + parseType = unix.NET_RT_IFLIST + rmExpectedType = unix.RTM_GET + + // Nothing to skip + skipFlags = 0 +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_HOST: "host", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", +} diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 521fe1911aaa5..35c83e374564f 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin && !freebsd - -package routetable - -import ( - "errors" - "runtime" -) - -var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) - -func Get(max int) ([]RouteEntry, error) { - return nil, errUnsupported -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin && !freebsd + +package routetable + +import ( + "errors" + "runtime" +) + +var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) + +func Get(max int) ([]RouteEntry, error) { + return nil, errUnsupported +} diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go index fb524a5c53684..715c1ee06e9a9 100644 --- a/net/sockstats/sockstats.go +++ b/net/sockstats/sockstats.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sockstats collects statistics about network sockets used by -// the Tailscale client. The context where sockets are used must be -// instrumented with the WithSockStats() function. -// -// Only available on POSIX platforms when built with Tailscale's fork of Go. -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -// SockStats contains statistics for sockets instrumented with the -// WithSockStats() function -type SockStats struct { - Stats map[Label]SockStat - CurrentInterfaceCellular bool -} - -// SockStat contains the sent and received bytes for a socket instrumented with -// the WithSockStats() function. -type SockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// Label is an identifier for a socket that stats are collected for. A finite -// set of values that may be used to label a socket to encourage grouping and -// to make storage more efficient. -type Label uint8 - -//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label - -// Labels are named after the package and function/struct that uses the socket. -// Values may be persisted and thus existing entries should not be re-numbered. -const ( - LabelControlClientAuto Label = 0 // control/controlclient/auto.go - LabelControlClientDialer Label = 1 // control/controlhttp/client.go - LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go - LabelLogtailLogger Label = 3 // logtail/logtail.go - LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go - LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go - LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go - LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go - LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go - LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go - LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go - LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go - LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go -) - -// WithSockStats instruments a context so that sockets created with it will -// have their statistics collected. -func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return withSockStats(ctx, label, logf) -} - -// Get returns the current socket statistics. -func Get() *SockStats { - return get() -} - -// InterfaceSockStats contains statistics for sockets instrumented with the -// WithSockStats() function, broken down by interface. The statistics may be a -// subset of the total if interfaces were added after the instrumented socket -// was created. -type InterfaceSockStats struct { - Stats map[Label]InterfaceSockStat - Interfaces []string -} - -// InterfaceSockStat contains the per-interface sent and received bytes for a -// socket instrumented with the WithSockStats() function. -type InterfaceSockStat struct { - TxBytesByInterface map[string]uint64 - RxBytesByInterface map[string]uint64 -} - -// GetWithInterfaces is a variant of Get that returns the current socket -// statistics broken down by interface. It is slightly more expensive than Get. -func GetInterfaces() *InterfaceSockStats { - return getInterfaces() -} - -// ValidationSockStats contains external validation numbers for sockets -// instrumented with WithSockStats. It may be a subset of the all sockets, -// depending on what externa measurement mechanisms the platform supports. -type ValidationSockStats struct { - Stats map[Label]ValidationSockStat -} - -// ValidationSockStat contains the validation bytes for a socket instrumented -// with WithSockStats. -type ValidationSockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// GetValidation is a variant of Get that returns external validation numbers -// for stats. It is more expensive than Get and should be used in debug -// interfaces only. -func GetValidation() *ValidationSockStats { - return getValidation() -} - -// SetNetMon configures the sockstats package to monitor the active -// interface, so that per-interface stats can be collected. -func SetNetMon(netMon *netmon.Monitor) { - setNetMon(netMon) -} - -// DebugInfo returns a string containing debug information about the tracked -// statistics. -func DebugInfo() string { - return debugInfo() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sockstats collects statistics about network sockets used by +// the Tailscale client. The context where sockets are used must be +// instrumented with the WithSockStats() function. +// +// Only available on POSIX platforms when built with Tailscale's fork of Go. +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +// SockStats contains statistics for sockets instrumented with the +// WithSockStats() function +type SockStats struct { + Stats map[Label]SockStat + CurrentInterfaceCellular bool +} + +// SockStat contains the sent and received bytes for a socket instrumented with +// the WithSockStats() function. +type SockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// Label is an identifier for a socket that stats are collected for. A finite +// set of values that may be used to label a socket to encourage grouping and +// to make storage more efficient. +type Label uint8 + +//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label + +// Labels are named after the package and function/struct that uses the socket. +// Values may be persisted and thus existing entries should not be re-numbered. +const ( + LabelControlClientAuto Label = 0 // control/controlclient/auto.go + LabelControlClientDialer Label = 1 // control/controlhttp/client.go + LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go + LabelLogtailLogger Label = 3 // logtail/logtail.go + LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go + LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go + LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go + LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go + LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go + LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go + LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go + LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go + LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go +) + +// WithSockStats instruments a context so that sockets created with it will +// have their statistics collected. +func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return withSockStats(ctx, label, logf) +} + +// Get returns the current socket statistics. +func Get() *SockStats { + return get() +} + +// InterfaceSockStats contains statistics for sockets instrumented with the +// WithSockStats() function, broken down by interface. The statistics may be a +// subset of the total if interfaces were added after the instrumented socket +// was created. +type InterfaceSockStats struct { + Stats map[Label]InterfaceSockStat + Interfaces []string +} + +// InterfaceSockStat contains the per-interface sent and received bytes for a +// socket instrumented with the WithSockStats() function. +type InterfaceSockStat struct { + TxBytesByInterface map[string]uint64 + RxBytesByInterface map[string]uint64 +} + +// GetWithInterfaces is a variant of Get that returns the current socket +// statistics broken down by interface. It is slightly more expensive than Get. +func GetInterfaces() *InterfaceSockStats { + return getInterfaces() +} + +// ValidationSockStats contains external validation numbers for sockets +// instrumented with WithSockStats. It may be a subset of the all sockets, +// depending on what externa measurement mechanisms the platform supports. +type ValidationSockStats struct { + Stats map[Label]ValidationSockStat +} + +// ValidationSockStat contains the validation bytes for a socket instrumented +// with WithSockStats. +type ValidationSockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// GetValidation is a variant of Get that returns external validation numbers +// for stats. It is more expensive than Get and should be used in debug +// interfaces only. +func GetValidation() *ValidationSockStats { + return getValidation() +} + +// SetNetMon configures the sockstats package to monitor the active +// interface, so that per-interface stats can be collected. +func SetNetMon(netMon *netmon.Monitor) { + setNetMon(netMon) +} + +// DebugInfo returns a string containing debug information about the tracked +// statistics. +func DebugInfo() string { + return debugInfo() +} diff --git a/net/sockstats/sockstats_noop.go b/net/sockstats/sockstats_noop.go index 797fdc42bde18..96723111ade7a 100644 --- a/net/sockstats/sockstats_noop.go +++ b/net/sockstats/sockstats_noop.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) - -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -const IsAvailable = false - -func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return ctx -} - -func get() *SockStats { - return nil -} - -func getInterfaces() *InterfaceSockStats { - return nil -} - -func getValidation() *ValidationSockStats { - return nil -} - -func setNetMon(netMon *netmon.Monitor) { -} - -func debugInfo() string { - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) + +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +const IsAvailable = false + +func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return ctx +} + +func get() *SockStats { + return nil +} + +func getInterfaces() *InterfaceSockStats { + return nil +} + +func getValidation() *ValidationSockStats { + return nil +} + +func setNetMon(netMon *netmon.Monitor) { +} + +func debugInfo() string { + return "" +} diff --git a/net/sockstats/sockstats_tsgo_darwin.go b/net/sockstats/sockstats_tsgo_darwin.go index 4b03ed6162965..321d32e04e5f0 100644 --- a/net/sockstats/sockstats_tsgo_darwin.go +++ b/net/sockstats/sockstats_tsgo_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build tailscale_go && (darwin || ios) - -package sockstats - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - tcpConnStats = darwinTcpConnStats -} - -func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { - c.Control(func(fd uintptr) { - if rawInfo, err := unix.GetsockoptTCPConnectionInfo( - int(fd), - unix.IPPROTO_TCP, - unix.TCP_CONNECTION_INFO, - ); err == nil { - tx = uint64(rawInfo.Txbytes) - rx = uint64(rawInfo.Rxbytes) - } - }) - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go && (darwin || ios) + +package sockstats + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + tcpConnStats = darwinTcpConnStats +} + +func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { + c.Control(func(fd uintptr) { + if rawInfo, err := unix.GetsockoptTCPConnectionInfo( + int(fd), + unix.IPPROTO_TCP, + unix.TCP_CONNECTION_INFO, + ); err == nil { + tx = uint64(rawInfo.Txbytes) + rx = uint64(rawInfo.Rxbytes) + } + }) + return +} diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go index 89639c12d5fc2..7ab0881cc22f9 100644 --- a/net/speedtest/speedtest.go +++ b/net/speedtest/speedtest.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package speedtest contains both server and client code for -// running speedtests between tailscale nodes. -package speedtest - -import ( - "time" -) - -const ( - blockSize = 2 * 1024 * 1024 // size of the block of data to send - MinDuration = 5 * time.Second // minimum duration for a test - DefaultDuration = MinDuration // default duration for a test - MaxDuration = 30 * time.Second // maximum duration for a test - version = 2 // value used when comparing client and server versions - increment = time.Second // increment to display results for, in seconds - minInterval = 10 * time.Millisecond // minimum interval length for a result to be included - DefaultPort = 20333 -) - -// config is the initial message sent to the server, that contains information on how to -// conduct the test. -type config struct { - Version int `json:"version"` - TestDuration time.Duration `json:"time"` - Direction Direction `json:"direction"` -} - -// configResponse is the response to the testConfig message. If the server has an -// error with the config, the Error variable will hold that error value. -type configResponse struct { - Error string `json:"error,omitempty"` -} - -// This represents the Result of a speedtest within a specific interval -type Result struct { - Bytes int // number of bytes sent/received during the interval - IntervalStart time.Time // start of the interval - IntervalEnd time.Time // end of the interval - Total bool // if true, this result struct represents the entire test, rather than a segment of the test -} - -func (r Result) MBitsPerSecond() float64 { - return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() -} - -func (r Result) MegaBytes() float64 { - return float64(r.Bytes) / 1000000.0 -} - -func (r Result) MegaBits() float64 { - return r.MegaBytes() * 8.0 -} - -func (r Result) Interval() time.Duration { - return r.IntervalEnd.Sub(r.IntervalStart) -} - -type Direction int - -const ( - Download Direction = iota - Upload -) - -func (d Direction) String() string { - switch d { - case Upload: - return "upload" - case Download: - return "download" - default: - return "" - } -} - -func (d *Direction) Reverse() { - switch *d { - case Upload: - *d = Download - case Download: - *d = Upload - default: - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package speedtest contains both server and client code for +// running speedtests between tailscale nodes. +package speedtest + +import ( + "time" +) + +const ( + blockSize = 2 * 1024 * 1024 // size of the block of data to send + MinDuration = 5 * time.Second // minimum duration for a test + DefaultDuration = MinDuration // default duration for a test + MaxDuration = 30 * time.Second // maximum duration for a test + version = 2 // value used when comparing client and server versions + increment = time.Second // increment to display results for, in seconds + minInterval = 10 * time.Millisecond // minimum interval length for a result to be included + DefaultPort = 20333 +) + +// config is the initial message sent to the server, that contains information on how to +// conduct the test. +type config struct { + Version int `json:"version"` + TestDuration time.Duration `json:"time"` + Direction Direction `json:"direction"` +} + +// configResponse is the response to the testConfig message. If the server has an +// error with the config, the Error variable will hold that error value. +type configResponse struct { + Error string `json:"error,omitempty"` +} + +// This represents the Result of a speedtest within a specific interval +type Result struct { + Bytes int // number of bytes sent/received during the interval + IntervalStart time.Time // start of the interval + IntervalEnd time.Time // end of the interval + Total bool // if true, this result struct represents the entire test, rather than a segment of the test +} + +func (r Result) MBitsPerSecond() float64 { + return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() +} + +func (r Result) MegaBytes() float64 { + return float64(r.Bytes) / 1000000.0 +} + +func (r Result) MegaBits() float64 { + return r.MegaBytes() * 8.0 +} + +func (r Result) Interval() time.Duration { + return r.IntervalEnd.Sub(r.IntervalStart) +} + +type Direction int + +const ( + Download Direction = iota + Upload +) + +func (d Direction) String() string { + switch d { + case Upload: + return "upload" + case Download: + return "download" + default: + return "" + } +} + +func (d *Direction) Reverse() { + switch *d { + case Upload: + *d = Download + case Download: + *d = Upload + default: + } +} diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go index cc34c468c22c0..299a12a8dfaec 100644 --- a/net/speedtest/speedtest_client.go +++ b/net/speedtest/speedtest_client.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "encoding/json" - "errors" - "net" - "time" -) - -// RunClient dials the given address and starts a speedtest. -// It returns any errors that come up in the tests. -// If there are no errors in the test, it returns a slice of results. -func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { - conn, err := net.Dial("tcp", host) - if err != nil { - return nil, err - } - - conf := config{TestDuration: duration, Version: version, Direction: direction} - - defer conn.Close() - encoder := json.NewEncoder(conn) - - if err = encoder.Encode(conf); err != nil { - return nil, err - } - - var response configResponse - decoder := json.NewDecoder(conn) - if err = decoder.Decode(&response); err != nil { - return nil, err - } - if response.Error != "" { - return nil, errors.New(response.Error) - } - - return doTest(conn, conf) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "encoding/json" + "errors" + "net" + "time" +) + +// RunClient dials the given address and starts a speedtest. +// It returns any errors that come up in the tests. +// If there are no errors in the test, it returns a slice of results. +func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + + conf := config{TestDuration: duration, Version: version, Direction: direction} + + defer conn.Close() + encoder := json.NewEncoder(conn) + + if err = encoder.Encode(conf); err != nil { + return nil, err + } + + var response configResponse + decoder := json.NewDecoder(conn) + if err = decoder.Decode(&response); err != nil { + return nil, err + } + if response.Error != "" { + return nil, errors.New(response.Error) + } + + return doTest(conn, conf) +} diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index d2673464e3132..9dd78b195fff4 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -1,146 +1,146 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" -) - -// Serve starts up the server on a given host and port pair. It starts to listen for -// connections and handles each one in a goroutine. Because it runs in an infinite loop, -// this function only returns if any of the speedtests return with errors, or if the -// listener is closed. -func Serve(l net.Listener) error { - for { - conn, err := l.Accept() - if errors.Is(err, net.ErrClosed) { - return nil - } - if err != nil { - return err - } - err = handleConnection(conn) - if err != nil { - return err - } - } -} - -// handleConnection handles the initial exchange between the server and the client. -// It reads the testconfig message into a config struct. If any errors occur with -// the testconfig (specifically, if there is a version mismatch), it will return those -// errors to the client with a configResponse. After the exchange, it will start -// the speed test. -func handleConnection(conn net.Conn) error { - defer conn.Close() - var conf config - - decoder := json.NewDecoder(conn) - err := decoder.Decode(&conf) - encoder := json.NewEncoder(conn) - - // Both return and encode errors that occurred before the test started. - if err != nil { - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // The server should always be doing the opposite of what the client is doing. - conf.Direction.Reverse() - - if conf.Version != version { - err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // Start the test - encoder.Encode(configResponse{}) - _, err = doTest(conn, conf) - return err -} - -// TODO include code to detect whether the code is direct vs DERP - -// doTest contains the code to run both the upload and download speedtest. -// the direction value in the config parameter determines which test to run. -func doTest(conn net.Conn, conf config) ([]Result, error) { - bufferData := make([]byte, blockSize) - - intervalBytes := 0 - totalBytes := 0 - - var currentTime time.Time - var results []Result - - if conf.Direction == Download { - conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) - } else { - _, err := rand.Read(bufferData) - if err != nil { - return nil, err - } - - } - - startTime := time.Now() - lastCalculated := startTime - -SpeedTestLoop: - for { - var n int - var err error - - if conf.Direction == Download { - n, err = io.ReadFull(conn, bufferData) - switch err { - case io.EOF, io.ErrUnexpectedEOF: - break SpeedTestLoop - case nil: - // successful read - default: - return nil, fmt.Errorf("unexpected error has occurred: %w", err) - } - } else { - n, err = conn.Write(bufferData) - if err != nil { - // If the write failed, there is most likely something wrong with the connection. - return nil, fmt.Errorf("upload failed: %w", err) - } - } - intervalBytes += n - - currentTime = time.Now() - // checks if the current time is more or equal to the lastCalculated time plus the increment - if currentTime.Sub(lastCalculated) >= increment { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - lastCalculated = currentTime - totalBytes += intervalBytes - intervalBytes = 0 - } - - if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { - break SpeedTestLoop - } - } - - // get last segment - if currentTime.Sub(lastCalculated) > minInterval { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - } - - // get total - totalBytes += intervalBytes - if currentTime.Sub(startTime) > minInterval { - results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) - } - - return results, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Serve starts up the server on a given host and port pair. It starts to listen for +// connections and handles each one in a goroutine. Because it runs in an infinite loop, +// this function only returns if any of the speedtests return with errors, or if the +// listener is closed. +func Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if errors.Is(err, net.ErrClosed) { + return nil + } + if err != nil { + return err + } + err = handleConnection(conn) + if err != nil { + return err + } + } +} + +// handleConnection handles the initial exchange between the server and the client. +// It reads the testconfig message into a config struct. If any errors occur with +// the testconfig (specifically, if there is a version mismatch), it will return those +// errors to the client with a configResponse. After the exchange, it will start +// the speed test. +func handleConnection(conn net.Conn) error { + defer conn.Close() + var conf config + + decoder := json.NewDecoder(conn) + err := decoder.Decode(&conf) + encoder := json.NewEncoder(conn) + + // Both return and encode errors that occurred before the test started. + if err != nil { + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // The server should always be doing the opposite of what the client is doing. + conf.Direction.Reverse() + + if conf.Version != version { + err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // Start the test + encoder.Encode(configResponse{}) + _, err = doTest(conn, conf) + return err +} + +// TODO include code to detect whether the code is direct vs DERP + +// doTest contains the code to run both the upload and download speedtest. +// the direction value in the config parameter determines which test to run. +func doTest(conn net.Conn, conf config) ([]Result, error) { + bufferData := make([]byte, blockSize) + + intervalBytes := 0 + totalBytes := 0 + + var currentTime time.Time + var results []Result + + if conf.Direction == Download { + conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) + } else { + _, err := rand.Read(bufferData) + if err != nil { + return nil, err + } + + } + + startTime := time.Now() + lastCalculated := startTime + +SpeedTestLoop: + for { + var n int + var err error + + if conf.Direction == Download { + n, err = io.ReadFull(conn, bufferData) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + break SpeedTestLoop + case nil: + // successful read + default: + return nil, fmt.Errorf("unexpected error has occurred: %w", err) + } + } else { + n, err = conn.Write(bufferData) + if err != nil { + // If the write failed, there is most likely something wrong with the connection. + return nil, fmt.Errorf("upload failed: %w", err) + } + } + intervalBytes += n + + currentTime = time.Now() + // checks if the current time is more or equal to the lastCalculated time plus the increment + if currentTime.Sub(lastCalculated) >= increment { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + lastCalculated = currentTime + totalBytes += intervalBytes + intervalBytes = 0 + } + + if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { + break SpeedTestLoop + } + } + + // get last segment + if currentTime.Sub(lastCalculated) > minInterval { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + } + + // get total + totalBytes += intervalBytes + if currentTime.Sub(startTime) > minInterval { + results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) + } + + return results, nil +} diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index a413e9efafcd4..55dcbeea1abdf 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "net" - "testing" - "time" -) - -func TestDownload(t *testing.T) { - // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { l.Close() }) - - serverIP := l.Addr().String() - t.Log("server IP found:", serverIP) - - type state struct { - err error - } - displayResult := func(t *testing.T, r Result, start time.Time) { - t.Helper() - t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) - } - stateChan := make(chan state, 1) - - go func() { - err := Serve(l) - stateChan <- state{err: err} - }() - - // ensure that the test returns an appropriate number of Result structs - expectedLen := int(DefaultDuration.Seconds()) + 1 - - t.Run("download test", func(t *testing.T) { - // conduct a download test - results, err := RunClient(Download, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("download test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - t.Run("upload test", func(t *testing.T) { - // conduct an upload test - results, err := RunClient(Upload, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("upload test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - // causes the server goroutine to finish - l.Close() - - testState := <-stateChan - if testState.err != nil { - t.Error("server error:", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "net" + "testing" + "time" +) + +func TestDownload(t *testing.T) { + // start a listener and find the port where the server will be listening. + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { l.Close() }) + + serverIP := l.Addr().String() + t.Log("server IP found:", serverIP) + + type state struct { + err error + } + displayResult := func(t *testing.T, r Result, start time.Time) { + t.Helper() + t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) + } + stateChan := make(chan state, 1) + + go func() { + err := Serve(l) + stateChan <- state{err: err} + }() + + // ensure that the test returns an appropriate number of Result structs + expectedLen := int(DefaultDuration.Seconds()) + 1 + + t.Run("download test", func(t *testing.T) { + // conduct a download test + results, err := RunClient(Download, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("download test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + t.Run("upload test", func(t *testing.T) { + // conduct an upload test + results, err := RunClient(Upload, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("upload test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + // causes the server goroutine to finish + l.Close() + + testState := <-stateChan + if testState.err != nil { + t.Error("server error:", err) + } +} diff --git a/net/stun/stun.go b/net/stun/stun.go index 81cf9b6080d26..eeac23cbbd45d 100644 --- a/net/stun/stun.go +++ b/net/stun/stun.go @@ -1,312 +1,312 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package STUN generates STUN request packets and parses response packets. -package stun - -import ( - "bytes" - crand "crypto/rand" - "encoding/binary" - "errors" - "hash/crc32" - "net" - "net/netip" -) - -const ( - attrNumSoftware = 0x8022 - attrNumFingerprint = 0x8028 - attrMappedAddress = 0x0001 - attrXorMappedAddress = 0x0020 - // This alternative attribute type is not - // mentioned in the RFC, but the shift into - // the "comprehension-optional" range seems - // like an easy mistake for a server to make. - // And servers appear to send it. - attrXorMappedAddressAlt = 0x8020 - - software = "tailnode" // notably: 8 bytes long, so no padding - bindingRequest = "\x00\x01" - magicCookie = "\x21\x12\xa4\x42" - lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 - headerLen = 20 -) - -// TxID is a transaction ID. -type TxID [12]byte - -// NewTxID returns a new random TxID. -func NewTxID() TxID { - var tx TxID - if _, err := crand.Read(tx[:]); err != nil { - panic(err) - } - return tx -} - -// Request generates a binding request STUN packet. -// The transaction ID, tID, should be a random sequence of bytes. -func Request(tID TxID) []byte { - // STUN header, RFC5389 Section 6. - const lenAttrSoftware = 4 + len(software) - b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) - b = append(b, bindingRequest...) - b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header - b = append(b, magicCookie...) - b = append(b, tID[:]...) - - // Attribute SOFTWARE, RFC5389 Section 15.5. - b = appendU16(b, attrNumSoftware) - b = appendU16(b, uint16(len(software))) - b = append(b, software...) - - // Attribute FINGERPRINT, RFC5389 Section 15.5. - fp := fingerPrint(b) - b = appendU16(b, attrNumFingerprint) - b = appendU16(b, 4) - b = appendU32(b, fp) - - return b -} - -func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } - -func appendU16(b []byte, v uint16) []byte { - return append(b, byte(v>>8), byte(v)) -} - -func appendU32(b []byte, v uint32) []byte { - return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// ParseBindingRequest parses a STUN binding request. -// -// It returns an error unless it advertises that it came from -// Tailscale. -func ParseBindingRequest(b []byte) (TxID, error) { - if !Is(b) { - return TxID{}, ErrNotSTUN - } - if string(b[:len(bindingRequest)]) != bindingRequest { - return TxID{}, ErrNotBindingRequest - } - var txID TxID - copy(txID[:], b[8:8+len(txID)]) - var softwareOK bool - var lastAttr uint16 - var gotFP uint32 - if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { - lastAttr = attrType - if attrType == attrNumSoftware && string(a) == software { - softwareOK = true - } - if attrType == attrNumFingerprint && len(a) == 4 { - gotFP = binary.BigEndian.Uint32(a) - } - return nil - }); err != nil { - return TxID{}, err - } - if !softwareOK { - return TxID{}, ErrWrongSoftware - } - if lastAttr != attrNumFingerprint { - return TxID{}, ErrNoFingerprint - } - wantFP := fingerPrint(b[:len(b)-lenFingerprint]) - if gotFP != wantFP { - return TxID{}, ErrWrongFingerprint - } - return txID, nil -} - -var ( - ErrNotSTUN = errors.New("response is not a STUN packet") - ErrNotSuccessResponse = errors.New("STUN packet is not a response") - ErrMalformedAttrs = errors.New("STUN response has malformed attributes") - ErrNotBindingRequest = errors.New("STUN request not a binding request") - ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") - ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") - ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") -) - -func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { - for len(b) > 0 { - if len(b) < 4 { - return ErrMalformedAttrs - } - attrType := binary.BigEndian.Uint16(b[:2]) - attrLen := int(binary.BigEndian.Uint16(b[2:4])) - attrLenWithPad := (attrLen + 3) &^ 3 - b = b[4:] - if attrLenWithPad > len(b) { - return ErrMalformedAttrs - } - if err := fn(attrType, b[:attrLen]); err != nil { - return err - } - b = b[attrLenWithPad:] - } - return nil -} - -// Response generates a binding response. -func Response(txID TxID, addrPort netip.AddrPort) []byte { - addr := addrPort.Addr() - - var fam byte - if addr.Is4() { - fam = 1 - } else if addr.Is6() { - fam = 2 - } else { - return nil - } - attrsLen := 8 + addr.BitLen()/8 - b := make([]byte, 0, headerLen+attrsLen) - - // Header - b = append(b, 0x01, 0x01) // success - b = appendU16(b, uint16(attrsLen)) - b = append(b, magicCookie...) - b = append(b, txID[:]...) - - // Attributes (well, one) - b = appendU16(b, attrXorMappedAddress) - b = appendU16(b, uint16(4+addr.BitLen()/8)) - b = append(b, - 0, // unused byte - fam) - b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie - ipa := addr.As16() - for i, o := range ipa[16-addr.BitLen()/8:] { - if i < 4 { - b = append(b, o^magicCookie[i]) - } else { - b = append(b, o^txID[i-len(magicCookie)]) - } - } - return b -} - -// ParseResponse parses a successful binding response STUN packet. -// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. -func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { - if !Is(b) { - return tID, netip.AddrPort{}, ErrNotSTUN - } - copy(tID[:], b[8:8+len(tID)]) - if b[0] != 0x01 || b[1] != 0x01 { - return tID, netip.AddrPort{}, ErrNotSuccessResponse - } - attrsLen := int(binary.BigEndian.Uint16(b[2:4])) - b = b[headerLen:] // remove STUN header - if attrsLen > len(b) { - return tID, netip.AddrPort{}, ErrMalformedAttrs - } else if len(b) > attrsLen { - b = b[:attrsLen] // trim trailing packet bytes - } - - var fallbackAddr netip.AddrPort - - // Read through the attributes. - // The the addr+port reported by XOR-MAPPED-ADDRESS - // as the canonical value. If the attribute is not - // present but the STUN server responds with - // MAPPED-ADDRESS we fall back to it. - if err := foreachAttr(b, func(attrType uint16, attr []byte) error { - switch attrType { - case attrXorMappedAddress, attrXorMappedAddressAlt: - ipSlice, port, err := xorMappedAddress(tID, attr) - if err != nil { - return err - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - addr = netip.AddrPortFrom(ip.Unmap(), port) - } - case attrMappedAddress: - ipSlice, port, err := mappedAddress(attr) - if err != nil { - return ErrMalformedAttrs - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) - } - } - return nil - - }); err != nil { - return TxID{}, netip.AddrPort{}, err - } - - if addr.IsValid() { - return tID, addr, nil - } - if fallbackAddr.IsValid() { - return tID, fallbackAddr, nil - } - return tID, netip.AddrPort{}, ErrMalformedAttrs -} - -func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { - // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - xorPort := binary.BigEndian.Uint16(b[2:4]) - addrField := b[4:] - port = xorPort ^ 0x2112 // first half of magicCookie - - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - xorAddr := addrField[:addrLen] - addr = make([]byte, addrLen) - for i := range xorAddr { - if i < len(magicCookie) { - addr[i] = xorAddr[i] ^ magicCookie[i] - } else { - addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] - } - } - return addr, port, nil -} - -func familyAddrLen(fam byte) int { - switch fam { - case 0x01: // IPv4 - return net.IPv4len - case 0x02: // IPv6 - return net.IPv6len - default: - return 0 - } -} - -func mappedAddress(b []byte) (addr []byte, port uint16, err error) { - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - port = uint16(b[2])<<8 | uint16(b[3]) - addrField := b[4:] - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - return bytes.Clone(addrField[:addrLen]), port, nil -} - -// Is reports whether b is a STUN message. -func Is(b []byte) bool { - return len(b) >= headerLen && - b[0]&0b11000000 == 0 && // top two bits must be zero - string(b[4:8]) == magicCookie -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package STUN generates STUN request packets and parses response packets. +package stun + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "errors" + "hash/crc32" + "net" + "net/netip" +) + +const ( + attrNumSoftware = 0x8022 + attrNumFingerprint = 0x8028 + attrMappedAddress = 0x0001 + attrXorMappedAddress = 0x0020 + // This alternative attribute type is not + // mentioned in the RFC, but the shift into + // the "comprehension-optional" range seems + // like an easy mistake for a server to make. + // And servers appear to send it. + attrXorMappedAddressAlt = 0x8020 + + software = "tailnode" // notably: 8 bytes long, so no padding + bindingRequest = "\x00\x01" + magicCookie = "\x21\x12\xa4\x42" + lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 + headerLen = 20 +) + +// TxID is a transaction ID. +type TxID [12]byte + +// NewTxID returns a new random TxID. +func NewTxID() TxID { + var tx TxID + if _, err := crand.Read(tx[:]); err != nil { + panic(err) + } + return tx +} + +// Request generates a binding request STUN packet. +// The transaction ID, tID, should be a random sequence of bytes. +func Request(tID TxID) []byte { + // STUN header, RFC5389 Section 6. + const lenAttrSoftware = 4 + len(software) + b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) + b = append(b, bindingRequest...) + b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header + b = append(b, magicCookie...) + b = append(b, tID[:]...) + + // Attribute SOFTWARE, RFC5389 Section 15.5. + b = appendU16(b, attrNumSoftware) + b = appendU16(b, uint16(len(software))) + b = append(b, software...) + + // Attribute FINGERPRINT, RFC5389 Section 15.5. + fp := fingerPrint(b) + b = appendU16(b, attrNumFingerprint) + b = appendU16(b, 4) + b = appendU32(b, fp) + + return b +} + +func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } + +func appendU16(b []byte, v uint16) []byte { + return append(b, byte(v>>8), byte(v)) +} + +func appendU32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// ParseBindingRequest parses a STUN binding request. +// +// It returns an error unless it advertises that it came from +// Tailscale. +func ParseBindingRequest(b []byte) (TxID, error) { + if !Is(b) { + return TxID{}, ErrNotSTUN + } + if string(b[:len(bindingRequest)]) != bindingRequest { + return TxID{}, ErrNotBindingRequest + } + var txID TxID + copy(txID[:], b[8:8+len(txID)]) + var softwareOK bool + var lastAttr uint16 + var gotFP uint32 + if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { + lastAttr = attrType + if attrType == attrNumSoftware && string(a) == software { + softwareOK = true + } + if attrType == attrNumFingerprint && len(a) == 4 { + gotFP = binary.BigEndian.Uint32(a) + } + return nil + }); err != nil { + return TxID{}, err + } + if !softwareOK { + return TxID{}, ErrWrongSoftware + } + if lastAttr != attrNumFingerprint { + return TxID{}, ErrNoFingerprint + } + wantFP := fingerPrint(b[:len(b)-lenFingerprint]) + if gotFP != wantFP { + return TxID{}, ErrWrongFingerprint + } + return txID, nil +} + +var ( + ErrNotSTUN = errors.New("response is not a STUN packet") + ErrNotSuccessResponse = errors.New("STUN packet is not a response") + ErrMalformedAttrs = errors.New("STUN response has malformed attributes") + ErrNotBindingRequest = errors.New("STUN request not a binding request") + ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") + ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") + ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") +) + +func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { + for len(b) > 0 { + if len(b) < 4 { + return ErrMalformedAttrs + } + attrType := binary.BigEndian.Uint16(b[:2]) + attrLen := int(binary.BigEndian.Uint16(b[2:4])) + attrLenWithPad := (attrLen + 3) &^ 3 + b = b[4:] + if attrLenWithPad > len(b) { + return ErrMalformedAttrs + } + if err := fn(attrType, b[:attrLen]); err != nil { + return err + } + b = b[attrLenWithPad:] + } + return nil +} + +// Response generates a binding response. +func Response(txID TxID, addrPort netip.AddrPort) []byte { + addr := addrPort.Addr() + + var fam byte + if addr.Is4() { + fam = 1 + } else if addr.Is6() { + fam = 2 + } else { + return nil + } + attrsLen := 8 + addr.BitLen()/8 + b := make([]byte, 0, headerLen+attrsLen) + + // Header + b = append(b, 0x01, 0x01) // success + b = appendU16(b, uint16(attrsLen)) + b = append(b, magicCookie...) + b = append(b, txID[:]...) + + // Attributes (well, one) + b = appendU16(b, attrXorMappedAddress) + b = appendU16(b, uint16(4+addr.BitLen()/8)) + b = append(b, + 0, // unused byte + fam) + b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie + ipa := addr.As16() + for i, o := range ipa[16-addr.BitLen()/8:] { + if i < 4 { + b = append(b, o^magicCookie[i]) + } else { + b = append(b, o^txID[i-len(magicCookie)]) + } + } + return b +} + +// ParseResponse parses a successful binding response STUN packet. +// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. +func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { + if !Is(b) { + return tID, netip.AddrPort{}, ErrNotSTUN + } + copy(tID[:], b[8:8+len(tID)]) + if b[0] != 0x01 || b[1] != 0x01 { + return tID, netip.AddrPort{}, ErrNotSuccessResponse + } + attrsLen := int(binary.BigEndian.Uint16(b[2:4])) + b = b[headerLen:] // remove STUN header + if attrsLen > len(b) { + return tID, netip.AddrPort{}, ErrMalformedAttrs + } else if len(b) > attrsLen { + b = b[:attrsLen] // trim trailing packet bytes + } + + var fallbackAddr netip.AddrPort + + // Read through the attributes. + // The the addr+port reported by XOR-MAPPED-ADDRESS + // as the canonical value. If the attribute is not + // present but the STUN server responds with + // MAPPED-ADDRESS we fall back to it. + if err := foreachAttr(b, func(attrType uint16, attr []byte) error { + switch attrType { + case attrXorMappedAddress, attrXorMappedAddressAlt: + ipSlice, port, err := xorMappedAddress(tID, attr) + if err != nil { + return err + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + addr = netip.AddrPortFrom(ip.Unmap(), port) + } + case attrMappedAddress: + ipSlice, port, err := mappedAddress(attr) + if err != nil { + return ErrMalformedAttrs + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) + } + } + return nil + + }); err != nil { + return TxID{}, netip.AddrPort{}, err + } + + if addr.IsValid() { + return tID, addr, nil + } + if fallbackAddr.IsValid() { + return tID, fallbackAddr, nil + } + return tID, netip.AddrPort{}, ErrMalformedAttrs +} + +func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { + // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + xorPort := binary.BigEndian.Uint16(b[2:4]) + addrField := b[4:] + port = xorPort ^ 0x2112 // first half of magicCookie + + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + xorAddr := addrField[:addrLen] + addr = make([]byte, addrLen) + for i := range xorAddr { + if i < len(magicCookie) { + addr[i] = xorAddr[i] ^ magicCookie[i] + } else { + addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] + } + } + return addr, port, nil +} + +func familyAddrLen(fam byte) int { + switch fam { + case 0x01: // IPv4 + return net.IPv4len + case 0x02: // IPv6 + return net.IPv6len + default: + return 0 + } +} + +func mappedAddress(b []byte) (addr []byte, port uint16, err error) { + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + port = uint16(b[2])<<8 | uint16(b[3]) + addrField := b[4:] + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + return bytes.Clone(addrField[:addrLen]), port, nil +} + +// Is reports whether b is a STUN message. +func Is(b []byte) bool { + return len(b) >= headerLen && + b[0]&0b11000000 == 0 && // top two bits must be zero + string(b[4:8]) == magicCookie +} diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go index 9ddb418950b39..6f0c9e3b0beae 100644 --- a/net/stun/stun_fuzzer.go +++ b/net/stun/stun_fuzzer.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package stun - -func FuzzStunParser(data []byte) int { - _, _, _ = ParseResponse(data) - - _, _ = ParseBindingRequest(data) - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package stun + +func FuzzStunParser(data []byte) int { + _, _, _ = ParseResponse(data) + + _, _ = ParseBindingRequest(data) + return 1 +} diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go index adc40ca372cf5..a757add9f8f46 100644 --- a/net/tcpinfo/tcpinfo.go +++ b/net/tcpinfo/tcpinfo.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tcpinfo provides platform-agnostic accessors to information about a -// TCP connection (e.g. RTT, MSS, etc.). -package tcpinfo - -import ( - "errors" - "net" - "time" -) - -var ( - ErrNotTCP = errors.New("tcpinfo: not a TCP conn") - ErrUnimplemented = errors.New("tcpinfo: unimplemented") -) - -// RTT returns the RTT for the given net.Conn. -// -// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then -// ErrNotTCP will be returned. If retrieving the RTT is not supported on the -// current platform, ErrUnimplemented will be returned. -func RTT(conn net.Conn) (time.Duration, error) { - tcpConn, err := unwrap(conn) - if err != nil { - return 0, err - } - - return rttImpl(tcpConn) -} - -// netConner is implemented by crypto/tls.Conn to unwrap into an underlying -// net.Conn. -type netConner interface { - NetConn() net.Conn -} - -// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn -func unwrap(nc net.Conn) (*net.TCPConn, error) { - for { - switch v := nc.(type) { - case *net.TCPConn: - return v, nil - case netConner: - nc = v.NetConn() - default: - return nil, ErrNotTCP - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tcpinfo provides platform-agnostic accessors to information about a +// TCP connection (e.g. RTT, MSS, etc.). +package tcpinfo + +import ( + "errors" + "net" + "time" +) + +var ( + ErrNotTCP = errors.New("tcpinfo: not a TCP conn") + ErrUnimplemented = errors.New("tcpinfo: unimplemented") +) + +// RTT returns the RTT for the given net.Conn. +// +// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then +// ErrNotTCP will be returned. If retrieving the RTT is not supported on the +// current platform, ErrUnimplemented will be returned. +func RTT(conn net.Conn) (time.Duration, error) { + tcpConn, err := unwrap(conn) + if err != nil { + return 0, err + } + + return rttImpl(tcpConn) +} + +// netConner is implemented by crypto/tls.Conn to unwrap into an underlying +// net.Conn. +type netConner interface { + NetConn() net.Conn +} + +// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn +func unwrap(nc net.Conn) (*net.TCPConn, error) { + for { + switch v := nc.(type) { + case *net.TCPConn: + return v, nil + case netConner: + nc = v.NetConn() + default: + return nil, ErrNotTCP + } + } +} diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go index bc4ac08b38b04..53fa22fbf5bed 100644 --- a/net/tcpinfo/tcpinfo_darwin.go +++ b/net/tcpinfo/tcpinfo_darwin.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPConnectionInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPConnectionInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil +} diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go index 5d86055bb8499..885d462c95e35 100644 --- a/net/tcpinfo/tcpinfo_linux.go +++ b/net/tcpinfo/tcpinfo_linux.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil +} diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go index f219cda1bd4a0..be45523aeb00d 100644 --- a/net/tcpinfo/tcpinfo_other.go +++ b/net/tcpinfo/tcpinfo_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin - -package tcpinfo - -import ( - "net" - "time" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - return 0, ErrUnimplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package tcpinfo + +import ( + "net" + "time" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + return 0, ErrUnimplemented +} diff --git a/net/tlsdial/deps_test.go b/net/tlsdial/deps_test.go index 750cb300ae5eb..7a93899c2f126 100644 --- a/net/tlsdial/deps_test.go +++ b/net/tlsdial/deps_test.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build for_go_mod_tidy_only - -package tlsdial - -import _ "filippo.io/mkcert" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build for_go_mod_tidy_only + +package tlsdial + +import _ "filippo.io/mkcert" diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index f846b853e1432..43461a135e1c5 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func TestDNSMapFromNetworkMap(t *testing.T) { - pfx := netip.MustParsePrefix - ip := netip.MustParseAddr - tests := []struct { - name string - nm *netmap.NetworkMap - want dnsMap - }{ - { - name: "self", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - }, - }, - { - name: "self_and_peers", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - "a": ip("100.0.0.201"), - "a.tailnet": ip("100.0.0.201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - { - name: "self_has_v6_only", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100::123/128"), - }, - }).View(), - Peers: nodeViews([]*tailcfg.Node{ - { - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }, - { - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }, - }), - }, - want: dnsMap{ - "foo": ip("100::123"), - "foo.tailnet": ip("100::123"), - "a": ip("100::201"), - "a.tailnet": ip("100::201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := dnsMapFromNetworkMap(tt.nm) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func TestDNSMapFromNetworkMap(t *testing.T) { + pfx := netip.MustParsePrefix + ip := netip.MustParseAddr + tests := []struct { + name string + nm *netmap.NetworkMap + want dnsMap + }{ + { + name: "self", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + }, + }, + { + name: "self_and_peers", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }).View(), + (&tailcfg.Node{ + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }).View(), + }, + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + "a": ip("100.0.0.201"), + "a.tailnet": ip("100.0.0.201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + { + name: "self_has_v6_only", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100::123/128"), + }, + }).View(), + Peers: nodeViews([]*tailcfg.Node{ + { + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }, + { + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }, + }), + }, + want: dnsMap{ + "foo": ip("100::123"), + "foo.tailnet": ip("100::123"), + "a": ip("100::201"), + "a.tailnet": ip("100::201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := dnsMapFromNetworkMap(tt.nm) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) + } + }) + } +} diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go index 64c127fd3270a..d830398cdfb9c 100644 --- a/net/tsdial/dohclient.go +++ b/net/tsdial/dohclient.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "time" - - "tailscale.com/net/dnscache" -) - -// dohConn is a net.PacketConn suitable for returning from -// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' -// ExitDNS DoH proxy service. -type dohConn struct { - ctx context.Context - baseURL string - hc *http.Client // if nil, default is used - dnsCache *dnscache.MessageCache - - rbuf bytes.Buffer -} - -var ( - _ net.Conn = (*dohConn)(nil) - _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics -) - -func (*dohConn) Close() error { return nil } -func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } -func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } -func (*dohConn) SetDeadline(t time.Time) error { return nil } -func (*dohConn) SetReadDeadline(t time.Time) error { return nil } -func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } - -func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Write(p) -} - -func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(p) - return n, todoAddr{}, err -} - -func (c *dohConn) Read(p []byte) (n int, err error) { - return c.rbuf.Read(p) -} - -func (c *dohConn) Write(packet []byte) (n int, err error) { - if c.dnsCache != nil { - err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) - if err == nil { - // Cache hit. - // TODO(bradfitz): add clientmetric - return len(packet), nil - } - c.rbuf.Reset() - } - req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) - if err != nil { - return 0, err - } - const dohType = "application/dns-message" - req.Header.Set("Content-Type", dohType) - hc := c.hc - if hc == nil { - hc = http.DefaultClient - } - hres, err := hc.Do(req) - if err != nil { - return 0, err - } - defer hres.Body.Close() - if hres.StatusCode != 200 { - return 0, errors.New(hres.Status) - } - if ct := hres.Header.Get("Content-Type"); ct != dohType { - return 0, fmt.Errorf("unexpected response Content-Type %q", ct) - } - _, err = io.Copy(&c.rbuf, hres.Body) - if err != nil { - return 0, err - } - if c.dnsCache != nil { - c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) - } - return len(packet), nil -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "time" + + "tailscale.com/net/dnscache" +) + +// dohConn is a net.PacketConn suitable for returning from +// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' +// ExitDNS DoH proxy service. +type dohConn struct { + ctx context.Context + baseURL string + hc *http.Client // if nil, default is used + dnsCache *dnscache.MessageCache + + rbuf bytes.Buffer +} + +var ( + _ net.Conn = (*dohConn)(nil) + _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics +) + +func (*dohConn) Close() error { return nil } +func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } +func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } +func (*dohConn) SetDeadline(t time.Time) error { return nil } +func (*dohConn) SetReadDeadline(t time.Time) error { return nil } +func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.Write(p) +} + +func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + return n, todoAddr{}, err +} + +func (c *dohConn) Read(p []byte) (n int, err error) { + return c.rbuf.Read(p) +} + +func (c *dohConn) Write(packet []byte) (n int, err error) { + if c.dnsCache != nil { + err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) + if err == nil { + // Cache hit. + // TODO(bradfitz): add clientmetric + return len(packet), nil + } + c.rbuf.Reset() + } + req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) + if err != nil { + return 0, err + } + const dohType = "application/dns-message" + req.Header.Set("Content-Type", dohType) + hc := c.hc + if hc == nil { + hc = http.DefaultClient + } + hres, err := hc.Do(req) + if err != nil { + return 0, err + } + defer hres.Body.Close() + if hres.StatusCode != 200 { + return 0, errors.New(hres.Status) + } + if ct := hres.Header.Get("Content-Type"); ct != dohType { + return 0, fmt.Errorf("unexpected response Content-Type %q", ct) + } + _, err = io.Copy(&c.rbuf, hres.Body) + if err != nil { + return 0, err + } + if c.dnsCache != nil { + c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) + } + return len(packet), nil +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/tsdial/dohclient_test.go b/net/tsdial/dohclient_test.go index 41a66f8f71edd..23255769f4847 100644 --- a/net/tsdial/dohclient_test.go +++ b/net/tsdial/dohclient_test.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "context" - "flag" - "net" - "testing" - "time" -) - -var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") - -func TestDoHResolve(t *testing.T) { - if *dohBase == "" { - t.Skip("skipping manual test without --doh-base= set") - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - var r net.Resolver - r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - return &dohConn{ctx: ctx, baseURL: *dohBase}, nil - } - addrs, err := r.LookupIP(ctx, "ip4", "google.com.") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %q", addrs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "flag" + "net" + "testing" + "time" +) + +var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") + +func TestDoHResolve(t *testing.T) { + if *dohBase == "" { + t.Skip("skipping manual test without --doh-base= set") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var r net.Resolver + r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return &dohConn{ctx: ctx, baseURL: *dohBase}, nil + } + addrs, err := r.LookupIP(ctx, "ip4", "google.com.") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %q", addrs) +} diff --git a/net/tshttpproxy/mksyscall.go b/net/tshttpproxy/mksyscall.go index 467dc49170092..f8fdae89b55f0 100644 --- a/net/tshttpproxy/mksyscall.go +++ b/net/tshttpproxy/mksyscall.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go - -//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree -//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle -//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl -//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree +//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle +//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl +//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go index 09019893ade8c..b241c256d4798 100644 --- a/net/tshttpproxy/tshttpproxy_linux.go +++ b/net/tshttpproxy/tshttpproxy_linux.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "net/http" - "net/url" - - "tailscale.com/version/distro" -) - -func init() { - sysProxyFromEnv = linuxSysProxyFromEnv -} - -func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { - if distro.Get() == distro.Synology { - return synologyProxyFromConfigCached(req) - } - return nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "net/http" + "net/url" + + "tailscale.com/version/distro" +) + +func init() { + sysProxyFromEnv = linuxSysProxyFromEnv +} + +func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { + if distro.Get() == distro.Synology { + return synologyProxyFromConfigCached(req) + } + return nil, nil +} diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index e11c9d05996ed..3061740f3beff 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -1,376 +1,376 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestSynologyProxyFromConfigCached(t *testing.T) { - req, err := http.NewRequest("GET", "http://example.org/", nil) - if err != nil { - t.Fatal(err) - } - - tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) - - t.Run("no config file", func(t *testing.T) { - if _, err := os.Stat(synologyProxyConfigPath); err == nil { - t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) - } - - cache.updated = time.Time{} - cache.httpProxy = nil - cache.httpsProxy = nil - - if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { - t.Fatalf("got %s, %v; want nil, nil", val, err) - } - - if got, want := cache.updated, time.Unix(0, 0); got != want { - t.Fatalf("got %s, want %s", got, want) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("config file updated", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - - if cache.httpProxy == nil { - t.Fatal("http proxy was not cached") - } - if cache.httpsProxy == nil { - t.Fatal("https proxy was not cached") - } - - if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { - t.Fatalf("got %s; want %s", val, want) - } - }) - - t.Run("config file removed", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = urlMustParse("http://127.0.0.1/") - cache.httpsProxy = urlMustParse("http://127.0.0.1/") - - if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - if val != nil { - t.Fatalf("got %s; want nil", val) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("picks proxy from request scheme", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - httpReq, err := http.NewRequest("GET", "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err := synologyProxyFromConfigCached(httpReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.55:80"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - - httpsReq, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err = synologyProxyFromConfigCached(httpsReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.66:443"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - }) -} - -func TestSynologyProxiesFromConfig(t *testing.T) { - var ( - openReader io.ReadCloser - openErr error - ) - tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { - return openReader, openErr - }) - - t.Run("with config", func(t *testing.T) { - mc := &mustCloser{Reader: strings.NewReader(` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 - `)} - defer mc.check(t) - openReader = mc - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - }) - - t.Run("nonexistent config", func(t *testing.T) { - openReader = nil - openErr = os.ErrNotExist - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != nil { - t.Fatalf("expected no error, got %s", err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - - t.Run("error opening config", func(t *testing.T) { - openReader = nil - openErr = errors.New("example error") - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != openErr { - t.Fatalf("expected %s, got %s", openErr, err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - -} - -func TestParseSynologyConfig(t *testing.T) { - cases := map[string]struct { - input string - httpProxy *url.URL - httpsProxy *url.URL - err error - }{ - "populated": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), - err: nil, - }, - "no-auth": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=no -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://10.0.0.55:80"), - httpsProxy: urlMustParse("http://10.0.0.66:8443"), - err: nil, - }, - "http-only": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host= -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: nil, - err: nil, - }, - "empty": { - input: ` -proxy_user= -proxy_pwd= -proxy_enabled= -adv_enabled= -bypass_enabled= -auth_enabled= -https_host= -https_port= -http_host= -http_port= -`, - httpProxy: nil, - httpsProxy: nil, - err: nil, - }, - } - - for name, example := range cases { - t.Run(name, func(t *testing.T) { - httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) - if err != example.err { - t.Fatal(err) - } - if example.err != nil { - return - } - - if example.httpProxy == nil && httpProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpProxy != nil { - if httpProxy == nil { - t.Fatalf("got nil, want %s", example.httpProxy) - } - - if got, want := example.httpProxy.String(), httpProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - - if example.httpsProxy == nil && httpsProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpsProxy != nil { - if httpsProxy == nil { - t.Fatalf("got nil, want %s", example.httpsProxy) - } - - if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - }) - } -} -func urlMustParse(u string) *url.URL { - r, err := url.Parse(u) - if err != nil { - panic(fmt.Sprintf("urlMustParse: %s", err)) - } - return r -} - -type mustCloser struct { - io.Reader - closed bool -} - -func (m *mustCloser) Close() error { - m.closed = true - return nil -} - -func (m *mustCloser) check(t *testing.T) { - if !m.closed { - t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestSynologyProxyFromConfigCached(t *testing.T) { + req, err := http.NewRequest("GET", "http://example.org/", nil) + if err != nil { + t.Fatal(err) + } + + tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) + + t.Run("no config file", func(t *testing.T) { + if _, err := os.Stat(synologyProxyConfigPath); err == nil { + t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) + } + + cache.updated = time.Time{} + cache.httpProxy = nil + cache.httpsProxy = nil + + if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { + t.Fatalf("got %s, %v; want nil, nil", val, err) + } + + if got, want := cache.updated, time.Unix(0, 0); got != want { + t.Fatalf("got %s, want %s", got, want) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("config file updated", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + + if cache.httpProxy == nil { + t.Fatal("http proxy was not cached") + } + if cache.httpsProxy == nil { + t.Fatal("https proxy was not cached") + } + + if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { + t.Fatalf("got %s; want %s", val, want) + } + }) + + t.Run("config file removed", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = urlMustParse("http://127.0.0.1/") + cache.httpsProxy = urlMustParse("http://127.0.0.1/") + + if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + if val != nil { + t.Fatalf("got %s; want nil", val) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("picks proxy from request scheme", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + httpReq, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err := synologyProxyFromConfigCached(httpReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.55:80"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + + httpsReq, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err = synologyProxyFromConfigCached(httpsReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.66:443"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + }) +} + +func TestSynologyProxiesFromConfig(t *testing.T) { + var ( + openReader io.ReadCloser + openErr error + ) + tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { + return openReader, openErr + }) + + t.Run("with config", func(t *testing.T) { + mc := &mustCloser{Reader: strings.NewReader(` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 + `)} + defer mc.check(t) + openReader = mc + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + }) + + t.Run("nonexistent config", func(t *testing.T) { + openReader = nil + openErr = os.ErrNotExist + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != nil { + t.Fatalf("expected no error, got %s", err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + + t.Run("error opening config", func(t *testing.T) { + openReader = nil + openErr = errors.New("example error") + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != openErr { + t.Fatalf("expected %s, got %s", openErr, err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + +} + +func TestParseSynologyConfig(t *testing.T) { + cases := map[string]struct { + input string + httpProxy *url.URL + httpsProxy *url.URL + err error + }{ + "populated": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), + err: nil, + }, + "no-auth": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=no +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://10.0.0.55:80"), + httpsProxy: urlMustParse("http://10.0.0.66:8443"), + err: nil, + }, + "http-only": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host= +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: nil, + err: nil, + }, + "empty": { + input: ` +proxy_user= +proxy_pwd= +proxy_enabled= +adv_enabled= +bypass_enabled= +auth_enabled= +https_host= +https_port= +http_host= +http_port= +`, + httpProxy: nil, + httpsProxy: nil, + err: nil, + }, + } + + for name, example := range cases { + t.Run(name, func(t *testing.T) { + httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) + if err != example.err { + t.Fatal(err) + } + if example.err != nil { + return + } + + if example.httpProxy == nil && httpProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpProxy != nil { + if httpProxy == nil { + t.Fatalf("got nil, want %s", example.httpProxy) + } + + if got, want := example.httpProxy.String(), httpProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + + if example.httpsProxy == nil && httpsProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpsProxy != nil { + if httpsProxy == nil { + t.Fatalf("got nil, want %s", example.httpsProxy) + } + + if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + }) + } +} +func urlMustParse(u string) *url.URL { + r, err := url.Parse(u) + if err != nil { + panic(fmt.Sprintf("urlMustParse: %s", err)) + } + return r +} + +type mustCloser struct { + io.Reader + closed bool +} + +func (m *mustCloser) Close() error { + m.closed = true + return nil +} + +func (m *mustCloser) check(t *testing.T) { + if !m.closed { + t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) + } +} diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index cb6b24c8355e8..06a1f5ae445d0 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -1,276 +1,276 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -import ( - "context" - "encoding/base64" - "fmt" - "log" - "net/http" - "net/url" - "runtime" - "strings" - "sync" - "syscall" - "time" - "unsafe" - - "github.com/alexbrainman/sspi/negotiate" - "golang.org/x/sys/windows" - "tailscale.com/hostinfo" - "tailscale.com/syncs" - "tailscale.com/types/logger" - "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpver" -) - -func init() { - sysProxyFromEnv = proxyFromWinHTTPOrCache - sysAuthHeader = sysAuthHeaderWindows -} - -var cachedProxy struct { - sync.Mutex - val *url.URL -} - -// proxyErrorf is a rate-limited logger specifically for errors asking -// WinHTTP for the proxy information. We don't want to log about -// errors often, otherwise the log message itself will generate a new -// HTTP request which ultimately will call back into us to log again, -// forever. So for errors, we only log a bit. -var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) - -var ( - metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") - metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") - metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") - metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") - metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") - metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") -) - -func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { - if req.URL == nil { - return nil, nil - } - urlStr := req.URL.String() - - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() - - type result struct { - proxy *url.URL - err error - } - resc := make(chan result, 1) - go func() { - proxy, err := proxyFromWinHTTP(ctx, urlStr) - resc <- result{proxy, err} - }() - - select { - case res := <-resc: - err := res.err - if err == nil { - metricSuccess.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { - log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) - } - cachedProxy.val = res.proxy - return res.proxy, nil - } - - // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages - const ( - ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 - ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 - ) - if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { - metricErrDetectionFailed.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - if err == windows.ERROR_INVALID_PARAMETER { - metricErrInvalidParameters.Add(1) - // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) - // TODO(bradfitz): figure this out. - setNoProxyUntil(time.Hour) - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) - return nil, nil - } - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) - if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { - metricErrDownloadScript.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - metricErrOther.Add(1) - return nil, err - case <-ctx.Done(): - metricErrTimeout.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) - return cachedProxy.val, nil - } -} - -func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - whi, err := httpOpen() - if err != nil { - proxyErrorf("winhttp: Open: %v", err) - return nil, err - } - defer whi.Close() - - t0 := time.Now() - v, err := whi.GetProxyForURL(urlStr) - td := time.Since(t0).Round(time.Millisecond) - if err := ctx.Err(); err != nil { - log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) - return nil, err - } - if err != nil { - return nil, err - } - if v == "" { - return nil, nil - } - // Discard all but first proxy value for now. - if i := strings.Index(v, ";"); i != -1 { - v = v[:i] - } - if !strings.HasPrefix(v, "https://") { - v = "http://" + v - } - return url.Parse(v) -} - -var userAgent = windows.StringToUTF16Ptr("Tailscale") - -const ( - winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 - winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 - winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 - winHTTP_AUTOPROXY_AUTO_DETECT = 1 - winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 - winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 -) - -// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! -const win8dot1Ver = "6.3" - -// accessType is the flag we must pass to WinHttpOpen for proxy resolution -// depending on whether or not we're running Windows < 8.1 -var accessType syncs.AtomicValue[uint32] - -func getAccessFlag() uint32 { - if flag, ok := accessType.LoadOk(); ok { - return flag - } - var flag uint32 - if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { - flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY - } else { - flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY - } - accessType.Store(flag) - return flag -} - -func httpOpen() (winHTTPInternet, error) { - return winHTTPOpen( - userAgent, - getAccessFlag(), - nil, /* WINHTTP_NO_PROXY_NAME */ - nil, /* WINHTTP_NO_PROXY_BYPASS */ - 0, - ) -} - -type winHTTPInternet windows.Handle - -func (hi winHTTPInternet) Close() error { - return winHTTPCloseHandle(hi) -} - -// WINHTTP_AUTOPROXY_OPTIONS -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options -type winHTTPAutoProxyOptions struct { - DwFlags uint32 - DwAutoDetectFlags uint32 - AutoConfigUrl *uint16 - _ uintptr - _ uint32 - FAutoLogonIfChallenged int32 // BOOL -} - -// WINHTTP_PROXY_INFO -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info -type winHTTPProxyInfo struct { - AccessType uint32 - Proxy *uint16 - ProxyBypass *uint16 -} - -type winHGlobal windows.Handle - -func globalFreeUTF16Ptr(p *uint16) error { - return globalFree((winHGlobal)(unsafe.Pointer(p))) -} - -func (pi *winHTTPProxyInfo) free() { - if pi.Proxy != nil { - globalFreeUTF16Ptr(pi.Proxy) - pi.Proxy = nil - } - if pi.ProxyBypass != nil { - globalFreeUTF16Ptr(pi.ProxyBypass) - pi.ProxyBypass = nil - } -} - -var proxyForURLOpts = &winHTTPAutoProxyOptions{ - DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, - DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, -} - -func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { - var out winHTTPProxyInfo - err := winHTTPGetProxyForURL( - hi, - windows.StringToUTF16Ptr(urlStr), - proxyForURLOpts, - &out, - ) - if err != nil { - return "", err - } - defer out.free() - return windows.UTF16PtrToString(out.Proxy), nil -} - -func sysAuthHeaderWindows(u *url.URL) (string, error) { - spn := "HTTP/" + u.Hostname() - creds, err := negotiate.AcquireCurrentUserCredentials() - if err != nil { - return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) - } - defer creds.Release() - - secCtx, token, err := negotiate.NewClientContext(creds, spn) - if err != nil { - return "", fmt.Errorf("negotiate.NewClientContext: %w", err) - } - defer secCtx.Release() - - return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +import ( + "context" + "encoding/base64" + "fmt" + "log" + "net/http" + "net/url" + "runtime" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "github.com/alexbrainman/sspi/negotiate" + "golang.org/x/sys/windows" + "tailscale.com/hostinfo" + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" + "tailscale.com/util/cmpver" +) + +func init() { + sysProxyFromEnv = proxyFromWinHTTPOrCache + sysAuthHeader = sysAuthHeaderWindows +} + +var cachedProxy struct { + sync.Mutex + val *url.URL +} + +// proxyErrorf is a rate-limited logger specifically for errors asking +// WinHTTP for the proxy information. We don't want to log about +// errors often, otherwise the log message itself will generate a new +// HTTP request which ultimately will call back into us to log again, +// forever. So for errors, we only log a bit. +var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) + +var ( + metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") + metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") + metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") + metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") + metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") + metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") +) + +func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { + if req.URL == nil { + return nil, nil + } + urlStr := req.URL.String() + + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + + type result struct { + proxy *url.URL + err error + } + resc := make(chan result, 1) + go func() { + proxy, err := proxyFromWinHTTP(ctx, urlStr) + resc <- result{proxy, err} + }() + + select { + case res := <-resc: + err := res.err + if err == nil { + metricSuccess.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { + log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) + } + cachedProxy.val = res.proxy + return res.proxy, nil + } + + // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages + const ( + ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 + ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 + ) + if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { + metricErrDetectionFailed.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + if err == windows.ERROR_INVALID_PARAMETER { + metricErrInvalidParameters.Add(1) + // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) + // TODO(bradfitz): figure this out. + setNoProxyUntil(time.Hour) + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) + return nil, nil + } + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) + if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { + metricErrDownloadScript.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + metricErrOther.Add(1) + return nil, err + case <-ctx.Done(): + metricErrTimeout.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) + return cachedProxy.val, nil + } +} + +func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + whi, err := httpOpen() + if err != nil { + proxyErrorf("winhttp: Open: %v", err) + return nil, err + } + defer whi.Close() + + t0 := time.Now() + v, err := whi.GetProxyForURL(urlStr) + td := time.Since(t0).Round(time.Millisecond) + if err := ctx.Err(); err != nil { + log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) + return nil, err + } + if err != nil { + return nil, err + } + if v == "" { + return nil, nil + } + // Discard all but first proxy value for now. + if i := strings.Index(v, ";"); i != -1 { + v = v[:i] + } + if !strings.HasPrefix(v, "https://") { + v = "http://" + v + } + return url.Parse(v) +} + +var userAgent = windows.StringToUTF16Ptr("Tailscale") + +const ( + winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 + winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 + winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 + winHTTP_AUTOPROXY_AUTO_DETECT = 1 + winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 + winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 +) + +// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! +const win8dot1Ver = "6.3" + +// accessType is the flag we must pass to WinHttpOpen for proxy resolution +// depending on whether or not we're running Windows < 8.1 +var accessType syncs.AtomicValue[uint32] + +func getAccessFlag() uint32 { + if flag, ok := accessType.LoadOk(); ok { + return flag + } + var flag uint32 + if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { + flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY + } else { + flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY + } + accessType.Store(flag) + return flag +} + +func httpOpen() (winHTTPInternet, error) { + return winHTTPOpen( + userAgent, + getAccessFlag(), + nil, /* WINHTTP_NO_PROXY_NAME */ + nil, /* WINHTTP_NO_PROXY_BYPASS */ + 0, + ) +} + +type winHTTPInternet windows.Handle + +func (hi winHTTPInternet) Close() error { + return winHTTPCloseHandle(hi) +} + +// WINHTTP_AUTOPROXY_OPTIONS +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options +type winHTTPAutoProxyOptions struct { + DwFlags uint32 + DwAutoDetectFlags uint32 + AutoConfigUrl *uint16 + _ uintptr + _ uint32 + FAutoLogonIfChallenged int32 // BOOL +} + +// WINHTTP_PROXY_INFO +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info +type winHTTPProxyInfo struct { + AccessType uint32 + Proxy *uint16 + ProxyBypass *uint16 +} + +type winHGlobal windows.Handle + +func globalFreeUTF16Ptr(p *uint16) error { + return globalFree((winHGlobal)(unsafe.Pointer(p))) +} + +func (pi *winHTTPProxyInfo) free() { + if pi.Proxy != nil { + globalFreeUTF16Ptr(pi.Proxy) + pi.Proxy = nil + } + if pi.ProxyBypass != nil { + globalFreeUTF16Ptr(pi.ProxyBypass) + pi.ProxyBypass = nil + } +} + +var proxyForURLOpts = &winHTTPAutoProxyOptions{ + DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, + DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, +} + +func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { + var out winHTTPProxyInfo + err := winHTTPGetProxyForURL( + hi, + windows.StringToUTF16Ptr(urlStr), + proxyForURLOpts, + &out, + ) + if err != nil { + return "", err + } + defer out.free() + return windows.UTF16PtrToString(out.Proxy), nil +} + +func sysAuthHeaderWindows(u *url.URL) (string, error) { + spn := "HTTP/" + u.Hostname() + creds, err := negotiate.AcquireCurrentUserCredentials() + if err != nil { + return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) + } + defer creds.Release() + + secCtx, token, err := negotiate.NewClientContext(creds, spn) + if err != nil { + return "", fmt.Errorf("negotiate.NewClientContext: %w", err) + } + defer secCtx.Release() + + return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil +} diff --git a/net/tstun/fake.go b/net/tstun/fake.go index a002952a3eef5..3d86bb3df4ca9 100644 --- a/net/tstun/fake.go +++ b/net/tstun/fake.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTUN struct { - evchan chan tun.Event - closechan chan struct{} -} - -// NewFake returns a tun.Device that does nothing. -func NewFake() tun.Device { - return &fakeTUN{ - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTUN) File() *os.File { - panic("fakeTUN.File() called, which makes no sense") -} - -func (t *fakeTUN) Close() error { - close(t.closechan) - close(t.evchan) - return nil -} - -func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { - <-t.closechan - return 0, io.EOF -} - -func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { - select { - case <-t.closechan: - return 0, ErrClosed - default: - } - return 1, nil -} - -// FakeTUNName is the name of the fake TUN device. -const FakeTUNName = "FakeTUN" - -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } -func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } -func (t *fakeTUN) BatchSize() int { return 1 } -func (t *fakeTUN) IsFakeTun() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "io" + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +type fakeTUN struct { + evchan chan tun.Event + closechan chan struct{} +} + +// NewFake returns a tun.Device that does nothing. +func NewFake() tun.Device { + return &fakeTUN{ + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTUN) File() *os.File { + panic("fakeTUN.File() called, which makes no sense") +} + +func (t *fakeTUN) Close() error { + close(t.closechan) + close(t.evchan) + return nil +} + +func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { + <-t.closechan + return 0, io.EOF +} + +func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { + select { + case <-t.closechan: + return 0, ErrClosed + default: + } + return 1, nil +} + +// FakeTUNName is the name of the fake TUN device. +const FakeTUNName = "FakeTUN" + +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } +func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } +func (t *fakeTUN) BatchSize() int { return 1 } +func (t *fakeTUN) IsFakeTun() bool { return true } diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go index 4d453b72c83bc..8cf569f982010 100644 --- a/net/tstun/ifstatus_noop.go +++ b/net/tstun/ifstatus_noop.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import ( - "time" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" -) - -// Dummy implementation that does nothing. -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import ( + "time" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" +) + +// Dummy implementation that does nothing. +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + return nil +} diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go index 6c6377bb40fb6..fd9fc2112524c 100644 --- a/net/tstun/ifstatus_windows.go +++ b/net/tstun/ifstatus_windows.go @@ -1,109 +1,109 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "fmt" - "sync" - "time" - - "github.com/tailscale/wireguard-go/tun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "tailscale.com/types/logger" -) - -// ifaceWatcher waits for an interface to be up. -type ifaceWatcher struct { - logf logger.Logf - luid winipcfg.LUID - - mu sync.Mutex // guards following - done bool - sig chan bool -} - -// callback is the callback we register with Windows to call when IP interface changes. -func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. - if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { - // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. - go iw.isUp() - } -} - -func (iw *ifaceWatcher) isUp() bool { - iw.mu.Lock() - defer iw.mu.Unlock() - - if iw.done { - // We already know that it's up - return true - } - - if iw.getOperStatus() != winipcfg.IfOperStatusUp { - return false - } - - iw.done = true - iw.sig <- true - return true -} - -func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { - ifc, err := iw.luid.Interface() - if err != nil { - iw.logf("iw.luid.Interface error: %v", err) - return 0 - } - return ifc.OperStatus -} - -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - iw := &ifaceWatcher{ - luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), - logf: logger.WithPrefix(logf, "waitInterfaceUp: "), - } - - // Just in case check the status first - if iw.getOperStatus() == winipcfg.IfOperStatusUp { - iw.logf("TUN interface already up; no need to wait") - return nil - } - - iw.sig = make(chan bool, 1) - cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) - if err != nil { - iw.logf("RegisterInterfaceChangeCallback error: %v", err) - return err - } - defer cb.Unregister() - - t0 := time.Now() - expires := t0.Add(timeout) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - iw.logf("waiting for TUN interface to come up...") - - select { - case <-iw.sig: - iw.logf("TUN interface is up after %v", time.Since(t0)) - return nil - case <-ticker.C: - } - - if iw.isUp() { - // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work - // or it came up in the same moment as tick. Indicate this in the log message. - iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) - return nil - } - - if expires.Before(time.Now()) { - iw.logf("timeout waiting %v for TUN interface to come up", timeout) - return fmt.Errorf("timeout waiting for TUN interface to come up") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "fmt" + "sync" + "time" + + "github.com/tailscale/wireguard-go/tun" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/types/logger" +) + +// ifaceWatcher waits for an interface to be up. +type ifaceWatcher struct { + logf logger.Logf + luid winipcfg.LUID + + mu sync.Mutex // guards following + done bool + sig chan bool +} + +// callback is the callback we register with Windows to call when IP interface changes. +func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. + if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { + // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. + go iw.isUp() + } +} + +func (iw *ifaceWatcher) isUp() bool { + iw.mu.Lock() + defer iw.mu.Unlock() + + if iw.done { + // We already know that it's up + return true + } + + if iw.getOperStatus() != winipcfg.IfOperStatusUp { + return false + } + + iw.done = true + iw.sig <- true + return true +} + +func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { + ifc, err := iw.luid.Interface() + if err != nil { + iw.logf("iw.luid.Interface error: %v", err) + return 0 + } + return ifc.OperStatus +} + +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + iw := &ifaceWatcher{ + luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), + logf: logger.WithPrefix(logf, "waitInterfaceUp: "), + } + + // Just in case check the status first + if iw.getOperStatus() == winipcfg.IfOperStatusUp { + iw.logf("TUN interface already up; no need to wait") + return nil + } + + iw.sig = make(chan bool, 1) + cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) + if err != nil { + iw.logf("RegisterInterfaceChangeCallback error: %v", err) + return err + } + defer cb.Unregister() + + t0 := time.Now() + expires := t0.Add(timeout) + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + iw.logf("waiting for TUN interface to come up...") + + select { + case <-iw.sig: + iw.logf("TUN interface is up after %v", time.Since(t0)) + return nil + case <-ticker.C: + } + + if iw.isUp() { + // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work + // or it came up in the same moment as tick. Indicate this in the log message. + iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) + return nil + } + + if expires.Before(time.Now()) { + iw.logf("timeout waiting %v for TUN interface to come up", timeout) + return fmt.Errorf("timeout waiting for TUN interface to come up") + } + } +} diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go index 7f546110995ee..681e79269f75f 100644 --- a/net/tstun/linkattrs_linux.go +++ b/net/tstun/linkattrs_linux.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "github.com/mdlayher/genetlink" - "github.com/mdlayher/netlink" - "github.com/tailscale/wireguard-go/tun" - "golang.org/x/sys/unix" -) - -// setLinkSpeed sets the advertised link speed of the TUN interface. -func setLinkSpeed(iface tun.Device, mbps int) error { - name, err := iface.Name() - if err != nil { - return err - } - - conn, err := genetlink.Dial(&netlink.Config{Strict: true}) - if err != nil { - return err - } - - defer conn.Close() - - f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) - if err != nil { - return err - } - - ae := netlink.NewAttributeEncoder() - ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { - nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) - return nil - }) - ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) - - b, err := ae.Encode() - if err != nil { - return err - } - - _, err = conn.Execute( - genetlink.Message{ - Header: genetlink.Header{ - Command: unix.ETHTOOL_MSG_LINKMODES_SET, - Version: unix.ETHTOOL_GENL_VERSION, - }, - Data: b, - }, - f.ID, - netlink.Request|netlink.Acknowledge, - ) - return err -} - -// setLinkAttrs sets up link attributes that can be queried by external tools. -// Its failure is non-fatal to interface bringup. -func setLinkAttrs(iface tun.Device) error { - // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). - return setLinkSpeed(iface, unix.SPEED_UNKNOWN) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "github.com/mdlayher/genetlink" + "github.com/mdlayher/netlink" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" +) + +// setLinkSpeed sets the advertised link speed of the TUN interface. +func setLinkSpeed(iface tun.Device, mbps int) error { + name, err := iface.Name() + if err != nil { + return err + } + + conn, err := genetlink.Dial(&netlink.Config{Strict: true}) + if err != nil { + return err + } + + defer conn.Close() + + f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) + if err != nil { + return err + } + + ae := netlink.NewAttributeEncoder() + ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { + nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) + return nil + }) + ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) + + b, err := ae.Encode() + if err != nil { + return err + } + + _, err = conn.Execute( + genetlink.Message{ + Header: genetlink.Header{ + Command: unix.ETHTOOL_MSG_LINKMODES_SET, + Version: unix.ETHTOOL_GENL_VERSION, + }, + Data: b, + }, + f.ID, + netlink.Request|netlink.Acknowledge, + ) + return err +} + +// setLinkAttrs sets up link attributes that can be queried by external tools. +// Its failure is non-fatal to interface bringup. +func setLinkAttrs(iface tun.Device) error { + // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). + return setLinkSpeed(iface, unix.SPEED_UNKNOWN) +} diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go index 45dd000b3d7d4..7a7b40fc2652b 100644 --- a/net/tstun/linkattrs_notlinux.go +++ b/net/tstun/linkattrs_notlinux.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func setLinkAttrs(iface tun.Device) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func setLinkAttrs(iface tun.Device) error { + return nil +} diff --git a/net/tstun/mtu.go b/net/tstun/mtu.go index b72a19bdebe6e..004529c205f9e 100644 --- a/net/tstun/mtu.go +++ b/net/tstun/mtu.go @@ -1,161 +1,161 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "tailscale.com/envknob" -) - -// The MTU (Maximum Transmission Unit) of a network interface is the largest -// packet that can be sent or received through that interface, including all -// headers above the link layer (e.g. IP headers, UDP headers, Wireguard -// headers, etc.). We have to think about several different values of MTU: -// -// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an -// Ethernet network card will default to a 1500 byte MTU. The user may change -// this MTU at any time. -// -// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward -// to make room for the wireguard/tailscale headers. For example, if the -// underlying network interface's MTU is 1500 bytes, the maximum size of a -// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU -// at any time via the OS's tools (ifconfig, ip, etc.). -// -// User configured initial MTU: The MTU the tailscale TUN should be created -// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the -// underlying interface MTU by 80 bytes to make room for the wireguard -// headers. This envknob is mostly for debugging. This value is used once at TUN -// creation and ignored thereafter. -// -// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, -// etc.). This MTU can change at any time. Setting the MTU this way goes through -// the MTU() method of tailscale's TUN wrapper. -// -// Maximum probed MTU: This is the largest MTU size that we send probe packets -// for. -// -// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets -// will get to their destination. Tailscale defaults to this MTU in the absence -// of path MTU probe information or user MTU configuration. We may occasionally -// find a path that needs a smaller MTU but it is very rare. -// -// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults -// to the Safe MTU unless we have path MTU probe results that tell us otherwise. -// -// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of -// priority, it is: -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -// -// Current MTU: This the MTU of the tailscale TUN at any given moment -// after TUN creation. In order of priority, it is: -// -// 1. The MTU set by the user via the OS, if it has ever been set -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU - -// TUNMTU is the MTU for the tailscale TUN. -type TUNMTU uint32 - -// WireMTU is the MTU for the underlying network devices. -type WireMTU uint32 - -const ( - // maxTUNMTU is the largest MTU we will consider for the Tailscale - // TUN. This is inherited from wireguard-go and can be surprisingly - // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 - // - 32 bytes. - // TODO(val,raggi): On Windows this seems to derive from RIO driver - // constraints in Wireguard but we don't use RIO so could probably make - // this bigger. - maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) - // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we - // use in the absence of other information such as path MTU probes. - safeTUNMTU TUNMTU = 1280 -) - -// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time -// magicsock discovery begins, it will send a set of pings, one of each size -// listed below. -var WireMTUsToProbe = []WireMTU{ - WireMTU(safeTUNMTU), // Tailscale over Tailscale :) - TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default - 1400, // Most common MTU minus a few bytes for tunnels - 1500, // Most common MTU - 8000, // Should fit inside all jumbo frame sizes - 9000, // Most jumbo frames are this size or larger -} - -// wgHeaderLen is the length of all the headers Wireguard adds to a packet -// in the worst case (IPv6). This constant is for use when we can't or -// shouldn't use information about the IP version of a specific packet -// (e.g., calculating the MTU for the Tailscale interface. -// -// A Wireguard header includes: -// -// - 20-byte IPv4 header or 40-byte IPv6 header -// - 8-byte UDP header -// - 4-byte type -// - 4-byte key index -// - 8-byte nonce -// - 16-byte authentication tag -const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 - -// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and -// returns the on-the-wire MTU necessary to transmit the largest packet that -// will fit through the TUN, given that we have to add wireguard headers. -func TUNToWireMTU(t TUNMTU) WireMTU { - return WireMTU(t + wgHeaderLen) -} - -// WireToTUNMTU takes the MTU of an underlying network device and returns the -// largest possible MTU for a Tailscale TUN operating on top of that device, -// given that we have to add wireguard headers. -func WireToTUNMTU(w WireMTU) TUNMTU { - if w < wgHeaderLen { - return 0 - } - return TUNMTU(w - wgHeaderLen) -} - -// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN -// MTU. It is also the path MTU that we default to if we have no -// information about the path to a peer. -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -func DefaultTUNMTU() TUNMTU { - if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { - return min(TUNMTU(m), maxTUNMTU) - } - - debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") - if debugPMTUD { - // TODO: While we are just probing MTU but not generating PTB, - // this has to continue to return the safe MTU. When we add the - // code to generate PTB, this will be: - // - // return WireToTUNMTU(maxProbedWireMTU) - return safeTUNMTU - } - - return safeTUNMTU -} - -// SafeWireMTU returns the wire MTU that is safe to use if we have no -// information about the path MTU to this peer. -func SafeWireMTU() WireMTU { - return TUNToWireMTU(safeTUNMTU) -} - -// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard -// overhead. -func DefaultWireMTU() WireMTU { - return TUNToWireMTU(DefaultTUNMTU()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "tailscale.com/envknob" +) + +// The MTU (Maximum Transmission Unit) of a network interface is the largest +// packet that can be sent or received through that interface, including all +// headers above the link layer (e.g. IP headers, UDP headers, Wireguard +// headers, etc.). We have to think about several different values of MTU: +// +// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an +// Ethernet network card will default to a 1500 byte MTU. The user may change +// this MTU at any time. +// +// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward +// to make room for the wireguard/tailscale headers. For example, if the +// underlying network interface's MTU is 1500 bytes, the maximum size of a +// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU +// at any time via the OS's tools (ifconfig, ip, etc.). +// +// User configured initial MTU: The MTU the tailscale TUN should be created +// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the +// underlying interface MTU by 80 bytes to make room for the wireguard +// headers. This envknob is mostly for debugging. This value is used once at TUN +// creation and ignored thereafter. +// +// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, +// etc.). This MTU can change at any time. Setting the MTU this way goes through +// the MTU() method of tailscale's TUN wrapper. +// +// Maximum probed MTU: This is the largest MTU size that we send probe packets +// for. +// +// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets +// will get to their destination. Tailscale defaults to this MTU in the absence +// of path MTU probe information or user MTU configuration. We may occasionally +// find a path that needs a smaller MTU but it is very rare. +// +// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults +// to the Safe MTU unless we have path MTU probe results that tell us otherwise. +// +// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of +// priority, it is: +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +// +// Current MTU: This the MTU of the tailscale TUN at any given moment +// after TUN creation. In order of priority, it is: +// +// 1. The MTU set by the user via the OS, if it has ever been set +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU + +// TUNMTU is the MTU for the tailscale TUN. +type TUNMTU uint32 + +// WireMTU is the MTU for the underlying network devices. +type WireMTU uint32 + +const ( + // maxTUNMTU is the largest MTU we will consider for the Tailscale + // TUN. This is inherited from wireguard-go and can be surprisingly + // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 + // - 32 bytes. + // TODO(val,raggi): On Windows this seems to derive from RIO driver + // constraints in Wireguard but we don't use RIO so could probably make + // this bigger. + maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) + // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we + // use in the absence of other information such as path MTU probes. + safeTUNMTU TUNMTU = 1280 +) + +// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time +// magicsock discovery begins, it will send a set of pings, one of each size +// listed below. +var WireMTUsToProbe = []WireMTU{ + WireMTU(safeTUNMTU), // Tailscale over Tailscale :) + TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default + 1400, // Most common MTU minus a few bytes for tunnels + 1500, // Most common MTU + 8000, // Should fit inside all jumbo frame sizes + 9000, // Most jumbo frames are this size or larger +} + +// wgHeaderLen is the length of all the headers Wireguard adds to a packet +// in the worst case (IPv6). This constant is for use when we can't or +// shouldn't use information about the IP version of a specific packet +// (e.g., calculating the MTU for the Tailscale interface. +// +// A Wireguard header includes: +// +// - 20-byte IPv4 header or 40-byte IPv6 header +// - 8-byte UDP header +// - 4-byte type +// - 4-byte key index +// - 8-byte nonce +// - 16-byte authentication tag +const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 + +// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and +// returns the on-the-wire MTU necessary to transmit the largest packet that +// will fit through the TUN, given that we have to add wireguard headers. +func TUNToWireMTU(t TUNMTU) WireMTU { + return WireMTU(t + wgHeaderLen) +} + +// WireToTUNMTU takes the MTU of an underlying network device and returns the +// largest possible MTU for a Tailscale TUN operating on top of that device, +// given that we have to add wireguard headers. +func WireToTUNMTU(w WireMTU) TUNMTU { + if w < wgHeaderLen { + return 0 + } + return TUNMTU(w - wgHeaderLen) +} + +// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN +// MTU. It is also the path MTU that we default to if we have no +// information about the path to a peer. +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +func DefaultTUNMTU() TUNMTU { + if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { + return min(TUNMTU(m), maxTUNMTU) + } + + debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") + if debugPMTUD { + // TODO: While we are just probing MTU but not generating PTB, + // this has to continue to return the safe MTU. When we add the + // code to generate PTB, this will be: + // + // return WireToTUNMTU(maxProbedWireMTU) + return safeTUNMTU + } + + return safeTUNMTU +} + +// SafeWireMTU returns the wire MTU that is safe to use if we have no +// information about the path MTU to this peer. +func SafeWireMTU() WireMTU { + return TUNToWireMTU(safeTUNMTU) +} + +// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard +// overhead. +func DefaultWireMTU() WireMTU { + return TUNToWireMTU(DefaultTUNMTU()) +} diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go index fc5274ae1037c..8d165bfd341a9 100644 --- a/net/tstun/mtu_test.go +++ b/net/tstun/mtu_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package tstun - -import ( - "os" - "strconv" - "testing" -) - -// Test the default MTU in the presence of various envknobs. -func TestDefaultTunMTU(t *testing.T) { - // Save and restore the envknobs we will be changing. - - // TS_DEBUG_MTU sets the MTU to a specific value. - defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) - os.Setenv("TS_DEBUG_MTU", "") - - // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. - defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") - - // With no MTU envknobs set, we should get the conservative MTU. - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - - // If set, TS_DEBUG_MTU should set the MTU. - mtu := maxTUNMTU - 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) - } - - // MTU should be clamped to maxTunMTU. - mtu = maxTUNMTU + 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != maxTUNMTU { - t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) - } - - // If PMTUD is enabled, the MTU should default to the safe MTU, but only - // if the user hasn't requested a specific MTU. - // - // TODO: When PMTUD is generating PTB responses, this will become the - // largest MTU we probe. - os.Setenv("TS_DEBUG_MTU", "") - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. - mtu = WireToTUNMTU(MaxPacketSize - 1) - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) - } -} - -// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. -func TestMTUConversion(t *testing.T) { - tests := []struct { - w WireMTU - t TUNMTU - }{ - {w: 0, t: 0}, - {w: wgHeaderLen - 1, t: 0}, - {w: wgHeaderLen, t: 0}, - {w: wgHeaderLen + 1, t: 1}, - {w: 1360, t: 1280}, - {w: 1500, t: 1420}, - {w: 9000, t: 8920}, - } - - for _, tt := range tests { - m := WireToTUNMTU(tt.w) - if m != tt.t { - t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) - } - } - - tests2 := []struct { - t TUNMTU - w WireMTU - }{ - {t: 0, w: wgHeaderLen}, - {t: 1, w: wgHeaderLen + 1}, - {t: 1280, w: 1360}, - {t: 1420, w: 1500}, - {t: 8920, w: 9000}, - } - - for _, tt := range tests2 { - m := TUNToWireMTU(tt.t) - if m != tt.w { - t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package tstun + +import ( + "os" + "strconv" + "testing" +) + +// Test the default MTU in the presence of various envknobs. +func TestDefaultTunMTU(t *testing.T) { + // Save and restore the envknobs we will be changing. + + // TS_DEBUG_MTU sets the MTU to a specific value. + defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) + os.Setenv("TS_DEBUG_MTU", "") + + // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. + defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") + + // With no MTU envknobs set, we should get the conservative MTU. + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + + // If set, TS_DEBUG_MTU should set the MTU. + mtu := maxTUNMTU - 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) + } + + // MTU should be clamped to maxTunMTU. + mtu = maxTUNMTU + 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != maxTUNMTU { + t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) + } + + // If PMTUD is enabled, the MTU should default to the safe MTU, but only + // if the user hasn't requested a specific MTU. + // + // TODO: When PMTUD is generating PTB responses, this will become the + // largest MTU we probe. + os.Setenv("TS_DEBUG_MTU", "") + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. + mtu = WireToTUNMTU(MaxPacketSize - 1) + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) + } +} + +// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. +func TestMTUConversion(t *testing.T) { + tests := []struct { + w WireMTU + t TUNMTU + }{ + {w: 0, t: 0}, + {w: wgHeaderLen - 1, t: 0}, + {w: wgHeaderLen, t: 0}, + {w: wgHeaderLen + 1, t: 1}, + {w: 1360, t: 1280}, + {w: 1500, t: 1420}, + {w: 9000, t: 8920}, + } + + for _, tt := range tests { + m := WireToTUNMTU(tt.w) + if m != tt.t { + t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) + } + } + + tests2 := []struct { + t TUNMTU + w WireMTU + }{ + {t: 0, w: wgHeaderLen}, + {t: 1, w: wgHeaderLen + 1}, + {t: 1280, w: 1360}, + {t: 1420, w: 1500}, + {t: 8920, w: 9000}, + } + + for _, tt := range tests2 { + m := TUNToWireMTU(tt.t) + if m != tt.w { + t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) + } + } +} diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go index e08f12bc14129..9600ceb77328f 100644 --- a/net/tstun/tun_linux.go +++ b/net/tstun/tun_linux.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "bytes" - "errors" - "os" - "os/exec" - "strings" - "syscall" - - "tailscale.com/types/logger" - "tailscale.com/version/distro" -) - -func init() { - tunDiagnoseFailure = diagnoseLinuxTUNFailure -} - -func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { - if errors.Is(createErr, syscall.EBUSY) { - logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) - logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") - logf("... and then kill those PID(s)") - return - } - - var un syscall.Utsname - err := syscall.Uname(&un) - if err != nil { - logf("no TUN, and failed to look up kernel version: %v", err) - return - } - kernel := utsReleaseField(&un) - logf("Linux kernel version: %s", kernel) - - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() - if err == nil { - logf("'modprobe tun' successful") - // Either tun is currently loaded, or it's statically - // compiled into the kernel (which modprobe checks - // with /lib/modules/$(uname -r)/modules.builtin) - // - // So if there's a problem at this point, it's - // probably because /dev/net/tun doesn't exist. - const dev = "/dev/net/tun" - if fi, err := os.Stat(dev); err != nil { - logf("tun module loaded in kernel, but %s does not exist", dev) - } else { - logf("%s: %v", dev, fi.Mode()) - } - - // We failed to find why it failed. Just let our - // caller report the error it got from wireguard-go. - return - } - logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) - - switch distro.Get() { - case distro.Debian: - dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() - if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(dpkgOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) - } - case distro.Arch: - findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() - if len(bytes.TrimSpace(findOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(findOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) - } - case distro.OpenWrt: - out, err := exec.Command("opkg", "list-installed").CombinedOutput() - if err != nil { - logf("error querying OpenWrt installed packages: %s", out) - return - } - for _, pkg := range []string{"kmod-tun", "ca-bundle"} { - if !bytes.Contains(out, []byte(pkg+" - ")) { - logf("Missing required package %s; run: opkg install %s", pkg, pkg) - } - } - } -} - -func utsReleaseField(u *syscall.Utsname) string { - var sb strings.Builder - for _, v := range u.Release { - if v == 0 { - break - } - sb.WriteByte(byte(v)) - } - return strings.TrimSpace(sb.String()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "bytes" + "errors" + "os" + "os/exec" + "strings" + "syscall" + + "tailscale.com/types/logger" + "tailscale.com/version/distro" +) + +func init() { + tunDiagnoseFailure = diagnoseLinuxTUNFailure +} + +func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { + if errors.Is(createErr, syscall.EBUSY) { + logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) + logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") + logf("... and then kill those PID(s)") + return + } + + var un syscall.Utsname + err := syscall.Uname(&un) + if err != nil { + logf("no TUN, and failed to look up kernel version: %v", err) + return + } + kernel := utsReleaseField(&un) + logf("Linux kernel version: %s", kernel) + + modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() + if err == nil { + logf("'modprobe tun' successful") + // Either tun is currently loaded, or it's statically + // compiled into the kernel (which modprobe checks + // with /lib/modules/$(uname -r)/modules.builtin) + // + // So if there's a problem at this point, it's + // probably because /dev/net/tun doesn't exist. + const dev = "/dev/net/tun" + if fi, err := os.Stat(dev); err != nil { + logf("tun module loaded in kernel, but %s does not exist", dev) + } else { + logf("%s: %v", dev, fi.Mode()) + } + + // We failed to find why it failed. Just let our + // caller report the error it got from wireguard-go. + return + } + logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) + + switch distro.Get() { + case distro.Debian: + dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() + if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(dpkgOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) + } + case distro.Arch: + findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() + if len(bytes.TrimSpace(findOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(findOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) + } + case distro.OpenWrt: + out, err := exec.Command("opkg", "list-installed").CombinedOutput() + if err != nil { + logf("error querying OpenWrt installed packages: %s", out) + return + } + for _, pkg := range []string{"kmod-tun", "ca-bundle"} { + if !bytes.Contains(out, []byte(pkg+" - ")) { + logf("Missing required package %s; run: opkg install %s", pkg, pkg) + } + } + } +} + +func utsReleaseField(u *syscall.Utsname) string { + var sb strings.Builder + for _, v := range u.Release { + if v == 0 { + break + } + sb.WriteByte(byte(v)) + } + return strings.TrimSpace(sb.String()) +} diff --git a/net/tstun/tun_macos.go b/net/tstun/tun_macos.go index f71494f0b91b6..3506f05b1e4c9 100644 --- a/net/tstun/tun_macos.go +++ b/net/tstun/tun_macos.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package tstun - -import ( - "os" - - "tailscale.com/types/logger" -) - -func init() { - tunDiagnoseFailure = diagnoseDarwinTUNFailure -} - -func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { - if os.Getuid() != 0 { - logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") - } - if tunName != "utun" { - logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package tstun + +import ( + "os" + + "tailscale.com/types/logger" +) + +func init() { + tunDiagnoseFailure = diagnoseDarwinTUNFailure +} + +func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { + if os.Getuid() != 0 { + logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") + } + if tunName != "utun" { + logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) + } +} diff --git a/net/tstun/tun_notwindows.go b/net/tstun/tun_notwindows.go index 60f1c62bacaab..087fcd4eec784 100644 --- a/net/tstun/tun_notwindows.go +++ b/net/tstun/tun_notwindows.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func interfaceName(dev tun.Device) (string, error) { - return dev.Name() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func interfaceName(dev tun.Device) (string, error) { + return dev.Name() +} diff --git a/packages/deb/deb.go b/packages/deb/deb.go index 1be7f96526d1e..30e3f2b4d360c 100644 --- a/packages/deb/deb.go +++ b/packages/deb/deb.go @@ -1,182 +1,182 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package deb extracts metadata from Debian packages. -package deb - -import ( - "archive/tar" - "bufio" - "bytes" - "compress/gzip" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strconv" - "strings" -) - -// Info is the Debian package metadata needed to integrate the package -// into a repository. -type Info struct { - // Version is the version of the package, as reported by dpkg. - Version string - // Arch is the Debian CPU architecture the package is for. - Arch string - // Control is the entire contents of the package's control file, - // with leading and trailing whitespace removed. - Control []byte - // MD5 is the MD5 hash of the package file. - MD5 []byte - // SHA1 is the SHA1 hash of the package file. - SHA1 []byte - // SHA256 is the SHA256 hash of the package file. - SHA256 []byte -} - -// ReadFile returns Debian package metadata from the .deb file at path. -func ReadFile(path string) (*Info, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - return Read(f) -} - -// Read returns Debian package metadata from the .deb file in r. -func Read(r io.Reader) (*Info, error) { - b := bufio.NewReader(r) - - m5, s1, s256 := md5.New(), sha1.New(), sha256.New() - summers := io.MultiWriter(m5, s1, s256) - r = io.TeeReader(b, summers) - - t, err := findControlTar(r) - if err != nil { - return nil, fmt.Errorf("searching for control.tar.gz: %w", err) - } - - control, err := findControlFile(t) - if err != nil { - return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) - } - - arch, version, err := findArchAndVersion(control) - if err != nil { - return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) - } - - // Exhaust the remainder of r, so that the summers see the entire file. - if _, err := io.Copy(io.Discard, r); err != nil { - return nil, fmt.Errorf("hashing file: %w", err) - } - - return &Info{ - Version: version, - Arch: arch, - Control: control, - MD5: m5.Sum(nil), - SHA1: s1.Sum(nil), - SHA256: s256.Sum(nil), - }, nil -} - -// findControlTar reads r as an `ar` archive, finds a tarball named -// `control.tar.gz` within, and returns a reader for that file. -func findControlTar(r io.Reader) (tarReader io.Reader, err error) { - var magic [8]byte - if _, err := io.ReadFull(r, magic[:]); err != nil { - return nil, fmt.Errorf("reading ar magic: %w", err) - } - if string(magic[:]) != "!\n" { - return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) - } - - for { - var hdr [60]byte - if _, err := io.ReadFull(r, hdr[:]); err != nil { - return nil, fmt.Errorf("reading file header: %w", err) - } - filename := strings.TrimSpace(string(hdr[:16])) - size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) - if err != nil { - return nil, fmt.Errorf("reading size of file %q: %w", filename, err) - } - if filename == "control.tar.gz" { - return io.LimitReader(r, size), nil - } - - // files in ar are padded out to 2 bytes. - if size%2 == 1 { - size++ - } - if _, err := io.CopyN(io.Discard, r, size); err != nil { - return nil, fmt.Errorf("seeking past file %q: %w", filename, err) - } - } -} - -// findControlFile reads r as a tar.gz archive, finds a file named -// `control` within, and returns its contents. -func findControlFile(r io.Reader) (control []byte, err error) { - gz, err := gzip.NewReader(r) - if err != nil { - return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) - } - defer gz.Close() - - tr := tar.NewReader(gz) - for { - hdr, err := tr.Next() - if err != nil { - if errors.Is(err, io.EOF) { - return nil, errors.New("EOF while looking for control file in control.tar.gz") - } - return nil, fmt.Errorf("reading tar header: %w", err) - } - - if filepath.Clean(hdr.Name) != "control" { - continue - } - - // Found control file - break - } - - bs, err := io.ReadAll(tr) - if err != nil { - return nil, fmt.Errorf("reading control file: %w", err) - } - - return bytes.TrimSpace(bs), nil -} - -var ( - archKey = []byte("Architecture:") - versionKey = []byte("Version:") -) - -// findArchAndVersion extracts the architecture and version strings -// from the given control file. -func findArchAndVersion(control []byte) (arch string, version string, err error) { - b := bytes.NewBuffer(control) - for { - l, err := b.ReadBytes('\n') - if err != nil { - return "", "", err - } - if bytes.HasPrefix(l, archKey) { - arch = string(bytes.TrimSpace(l[len(archKey):])) - } else if bytes.HasPrefix(l, versionKey) { - version = string(bytes.TrimSpace(l[len(versionKey):])) - } - if arch != "" && version != "" { - return arch, version, nil - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package deb extracts metadata from Debian packages. +package deb + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" +) + +// Info is the Debian package metadata needed to integrate the package +// into a repository. +type Info struct { + // Version is the version of the package, as reported by dpkg. + Version string + // Arch is the Debian CPU architecture the package is for. + Arch string + // Control is the entire contents of the package's control file, + // with leading and trailing whitespace removed. + Control []byte + // MD5 is the MD5 hash of the package file. + MD5 []byte + // SHA1 is the SHA1 hash of the package file. + SHA1 []byte + // SHA256 is the SHA256 hash of the package file. + SHA256 []byte +} + +// ReadFile returns Debian package metadata from the .deb file at path. +func ReadFile(path string) (*Info, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return Read(f) +} + +// Read returns Debian package metadata from the .deb file in r. +func Read(r io.Reader) (*Info, error) { + b := bufio.NewReader(r) + + m5, s1, s256 := md5.New(), sha1.New(), sha256.New() + summers := io.MultiWriter(m5, s1, s256) + r = io.TeeReader(b, summers) + + t, err := findControlTar(r) + if err != nil { + return nil, fmt.Errorf("searching for control.tar.gz: %w", err) + } + + control, err := findControlFile(t) + if err != nil { + return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) + } + + arch, version, err := findArchAndVersion(control) + if err != nil { + return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) + } + + // Exhaust the remainder of r, so that the summers see the entire file. + if _, err := io.Copy(io.Discard, r); err != nil { + return nil, fmt.Errorf("hashing file: %w", err) + } + + return &Info{ + Version: version, + Arch: arch, + Control: control, + MD5: m5.Sum(nil), + SHA1: s1.Sum(nil), + SHA256: s256.Sum(nil), + }, nil +} + +// findControlTar reads r as an `ar` archive, finds a tarball named +// `control.tar.gz` within, and returns a reader for that file. +func findControlTar(r io.Reader) (tarReader io.Reader, err error) { + var magic [8]byte + if _, err := io.ReadFull(r, magic[:]); err != nil { + return nil, fmt.Errorf("reading ar magic: %w", err) + } + if string(magic[:]) != "!\n" { + return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) + } + + for { + var hdr [60]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("reading file header: %w", err) + } + filename := strings.TrimSpace(string(hdr[:16])) + size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) + if err != nil { + return nil, fmt.Errorf("reading size of file %q: %w", filename, err) + } + if filename == "control.tar.gz" { + return io.LimitReader(r, size), nil + } + + // files in ar are padded out to 2 bytes. + if size%2 == 1 { + size++ + } + if _, err := io.CopyN(io.Discard, r, size); err != nil { + return nil, fmt.Errorf("seeking past file %q: %w", filename, err) + } + } +} + +// findControlFile reads r as a tar.gz archive, finds a file named +// `control` within, and returns its contents. +func findControlFile(r io.Reader) (control []byte, err error) { + gz, err := gzip.NewReader(r) + if err != nil { + return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.New("EOF while looking for control file in control.tar.gz") + } + return nil, fmt.Errorf("reading tar header: %w", err) + } + + if filepath.Clean(hdr.Name) != "control" { + continue + } + + // Found control file + break + } + + bs, err := io.ReadAll(tr) + if err != nil { + return nil, fmt.Errorf("reading control file: %w", err) + } + + return bytes.TrimSpace(bs), nil +} + +var ( + archKey = []byte("Architecture:") + versionKey = []byte("Version:") +) + +// findArchAndVersion extracts the architecture and version strings +// from the given control file. +func findArchAndVersion(control []byte) (arch string, version string, err error) { + b := bytes.NewBuffer(control) + for { + l, err := b.ReadBytes('\n') + if err != nil { + return "", "", err + } + if bytes.HasPrefix(l, archKey) { + arch = string(bytes.TrimSpace(l[len(archKey):])) + } else if bytes.HasPrefix(l, versionKey) { + version = string(bytes.TrimSpace(l[len(versionKey):])) + } + if arch != "" && version != "" { + return arch, version, nil + } + } +} diff --git a/packages/deb/deb_test.go b/packages/deb/deb_test.go index 0ff43da21d151..1a25f67ad4875 100644 --- a/packages/deb/deb_test.go +++ b/packages/deb/deb_test.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deb - -import ( - "bytes" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "encoding/hex" - "fmt" - "hash" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" -) - -func TestDebInfo(t *testing.T) { - tests := []struct { - name string - in []byte - want *Info - wantErr bool - }{ - { - name: "simple", - in: mkTestDeb("1.2.3", "amd64"), - want: &Info{ - Version: "1.2.3", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "arm64", - in: mkTestDeb("1.2.3", "arm64"), - want: &Info{ - Version: "1.2.3", - Arch: "arm64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "arm64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "unstable", - in: mkTestDeb("1.7.25", "amd64"), - want: &Info{ - Version: "1.7.25", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.7.25", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - - // These truncation tests assume the structure of a .deb - // package, which is as follows: - // magic: 8 bytes - // file header: 60 bytes, before each file blob - // - // The first file in a .deb ar is "debian-binary", which is 4 - // bytes long and consists of "2.0\n". - // The second file is control.tar.gz, which is what we care - // about introspecting for metadata. - // The final file is data.tar.gz, which we don't care about. - // - // The first file in control.tar.gz is the "control" file we - // want to read for metadata. - { - name: "truncated_ar_magic", - in: mkTestDeb("1.7.25", "amd64")[:4], - wantErr: true, - }, - { - name: "truncated_ar_header", - in: mkTestDeb("1.7.25", "amd64")[:30], - wantErr: true, - }, - { - name: "missing_control_tgz", - // Truncate right after the "debian-binary" file, which - // makes the file a valid 1-file archive that's missing - // control.tar.gz. - in: mkTestDeb("1.7.25", "amd64")[:72], - wantErr: true, - }, - { - name: "truncated_tgz", - in: mkTestDeb("1.7.25", "amd64")[:172], - wantErr: true, - }, - } - - for _, test := range tests { - // mkTestDeb returns non-deterministic output due to - // timestamps embedded in the package file, so compute the - // wanted hashes on the fly here. - if test.want != nil { - test.want.MD5 = mkHash(test.in, md5.New) - test.want.SHA1 = mkHash(test.in, sha1.New) - test.want.SHA256 = mkHash(test.in, sha256.New) - } - - t.Run(test.name, func(t *testing.T) { - b := bytes.NewBuffer(test.in) - got, err := Read(b) - if err != nil { - if test.wantErr { - t.Logf("got expected error: %v", err) - return - } - t.Fatalf("reading deb info: %v", err) - } - if diff := diff(got, test.want); diff != "" { - t.Fatalf("parsed info diff (-got+want):\n%s", diff) - } - }) - } -} - -func diff(got, want any) string { - matchField := func(name string) func(p cmp.Path) bool { - return func(p cmp.Path) bool { - if len(p) != 3 { - return false - } - return p[2].String() == "."+name - } - } - toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) - toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) - return cmp.Diff(got, want, - cmp.FilterPath(matchField("Control"), toLines), - cmp.FilterPath(matchField("MD5"), toHex), - cmp.FilterPath(matchField("SHA1"), toHex), - cmp.FilterPath(matchField("SHA256"), toHex)) -} - -func mkTestDeb(version, arch string) []byte { - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Description: "test package", - Arch: arch, - Platform: "linux", - Version: version, - Section: "net", - Priority: "extra", - Maintainer: "Tail Scalar", - }) - - pkg, err := nfpm.Get("deb") - if err != nil { - panic(fmt.Sprintf("getting deb packager: %v", err)) - } - - var b bytes.Buffer - if err := pkg.Package(info, &b); err != nil { - panic(fmt.Sprintf("creating deb package: %v", err)) - } - - return b.Bytes() -} - -func mkControl(fs ...string) []byte { - if len(fs)%2 != 0 { - panic("odd number of control file fields") - } - var b bytes.Buffer - for i := 0; i < len(fs); i = i + 2 { - k, v := fs[i], fs[i+1] - fmt.Fprintf(&b, "%s: %s\n", k, v) - } - return bytes.TrimSpace(b.Bytes()) -} - -func mkHash(b []byte, hasher func() hash.Hash) []byte { - h := hasher() - h.Write(b) - return h.Sum(nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deb + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "hash" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" +) + +func TestDebInfo(t *testing.T) { + tests := []struct { + name string + in []byte + want *Info + wantErr bool + }{ + { + name: "simple", + in: mkTestDeb("1.2.3", "amd64"), + want: &Info{ + Version: "1.2.3", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "arm64", + in: mkTestDeb("1.2.3", "arm64"), + want: &Info{ + Version: "1.2.3", + Arch: "arm64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "arm64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "unstable", + in: mkTestDeb("1.7.25", "amd64"), + want: &Info{ + Version: "1.7.25", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.7.25", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + + // These truncation tests assume the structure of a .deb + // package, which is as follows: + // magic: 8 bytes + // file header: 60 bytes, before each file blob + // + // The first file in a .deb ar is "debian-binary", which is 4 + // bytes long and consists of "2.0\n". + // The second file is control.tar.gz, which is what we care + // about introspecting for metadata. + // The final file is data.tar.gz, which we don't care about. + // + // The first file in control.tar.gz is the "control" file we + // want to read for metadata. + { + name: "truncated_ar_magic", + in: mkTestDeb("1.7.25", "amd64")[:4], + wantErr: true, + }, + { + name: "truncated_ar_header", + in: mkTestDeb("1.7.25", "amd64")[:30], + wantErr: true, + }, + { + name: "missing_control_tgz", + // Truncate right after the "debian-binary" file, which + // makes the file a valid 1-file archive that's missing + // control.tar.gz. + in: mkTestDeb("1.7.25", "amd64")[:72], + wantErr: true, + }, + { + name: "truncated_tgz", + in: mkTestDeb("1.7.25", "amd64")[:172], + wantErr: true, + }, + } + + for _, test := range tests { + // mkTestDeb returns non-deterministic output due to + // timestamps embedded in the package file, so compute the + // wanted hashes on the fly here. + if test.want != nil { + test.want.MD5 = mkHash(test.in, md5.New) + test.want.SHA1 = mkHash(test.in, sha1.New) + test.want.SHA256 = mkHash(test.in, sha256.New) + } + + t.Run(test.name, func(t *testing.T) { + b := bytes.NewBuffer(test.in) + got, err := Read(b) + if err != nil { + if test.wantErr { + t.Logf("got expected error: %v", err) + return + } + t.Fatalf("reading deb info: %v", err) + } + if diff := diff(got, test.want); diff != "" { + t.Fatalf("parsed info diff (-got+want):\n%s", diff) + } + }) + } +} + +func diff(got, want any) string { + matchField := func(name string) func(p cmp.Path) bool { + return func(p cmp.Path) bool { + if len(p) != 3 { + return false + } + return p[2].String() == "."+name + } + } + toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) + toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) + return cmp.Diff(got, want, + cmp.FilterPath(matchField("Control"), toLines), + cmp.FilterPath(matchField("MD5"), toHex), + cmp.FilterPath(matchField("SHA1"), toHex), + cmp.FilterPath(matchField("SHA256"), toHex)) +} + +func mkTestDeb(version, arch string) []byte { + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Description: "test package", + Arch: arch, + Platform: "linux", + Version: version, + Section: "net", + Priority: "extra", + Maintainer: "Tail Scalar", + }) + + pkg, err := nfpm.Get("deb") + if err != nil { + panic(fmt.Sprintf("getting deb packager: %v", err)) + } + + var b bytes.Buffer + if err := pkg.Package(info, &b); err != nil { + panic(fmt.Sprintf("creating deb package: %v", err)) + } + + return b.Bytes() +} + +func mkControl(fs ...string) []byte { + if len(fs)%2 != 0 { + panic("odd number of control file fields") + } + var b bytes.Buffer + for i := 0; i < len(fs); i = i + 2 { + k, v := fs[i], fs[i+1] + fmt.Fprintf(&b, "%s: %s\n", k, v) + } + return bytes.TrimSpace(b.Bytes()) +} + +func mkHash(b []byte, hasher func() hash.Hash) []byte { + h := hasher() + h.Write(b) + return h.Sum(nil) +} diff --git a/paths/migrate.go b/paths/migrate.go index 11d90a6272a65..3a23ecca34fdc 100644 --- a/paths/migrate.go +++ b/paths/migrate.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - - "tailscale.com/types/logger" -) - -// TryConfigFileMigration carefully copies the contents of oldFile to -// newFile, returning the path which should be used to read the config. -// - if newFile already exists, don't modify it just return its path -// - if neither oldFile nor newFile exist, return newFile for a fresh -// default config to be written to. -// - if oldFile exists but copying to newFile fails, return oldFile so -// there will at least be some config to work with. -func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { - _, err := os.Stat(newFile) - if err == nil { - // Common case for a system which has already been migrated. - return newFile - } - if !os.IsNotExist(err) { - logf("TryConfigFileMigration failed; new file: %v", err) - return newFile - } - - contents, err := os.ReadFile(oldFile) - if err != nil { - // Common case for a new user. - return newFile - } - - if err = MkStateDir(filepath.Dir(newFile)); err != nil { - logf("TryConfigFileMigration failed; MkStateDir: %v", err) - return oldFile - } - - err = os.WriteFile(newFile, contents, 0600) - if err != nil { - removeErr := os.Remove(newFile) - if removeErr != nil { - logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", - err, removeErr) - return oldFile - } - logf("TryConfigFileMigration failed; write newFile: %v", err) - return oldFile - } - - logf("TryConfigFileMigration: successfully migrated: from %v to %v", - oldFile, newFile) - - return newFile -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + + "tailscale.com/types/logger" +) + +// TryConfigFileMigration carefully copies the contents of oldFile to +// newFile, returning the path which should be used to read the config. +// - if newFile already exists, don't modify it just return its path +// - if neither oldFile nor newFile exist, return newFile for a fresh +// default config to be written to. +// - if oldFile exists but copying to newFile fails, return oldFile so +// there will at least be some config to work with. +func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { + _, err := os.Stat(newFile) + if err == nil { + // Common case for a system which has already been migrated. + return newFile + } + if !os.IsNotExist(err) { + logf("TryConfigFileMigration failed; new file: %v", err) + return newFile + } + + contents, err := os.ReadFile(oldFile) + if err != nil { + // Common case for a new user. + return newFile + } + + if err = MkStateDir(filepath.Dir(newFile)); err != nil { + logf("TryConfigFileMigration failed; MkStateDir: %v", err) + return oldFile + } + + err = os.WriteFile(newFile, contents, 0600) + if err != nil { + removeErr := os.Remove(newFile) + if removeErr != nil { + logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", + err, removeErr) + return oldFile + } + logf("TryConfigFileMigration failed; write newFile: %v", err) + return oldFile + } + + logf("TryConfigFileMigration: successfully migrated: from %v to %v", + oldFile, newFile) + + return newFile +} diff --git a/paths/paths.go b/paths/paths.go index 8cee4cabfd2a9..28c3be02a9c86 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package paths returns platform and user-specific default paths to -// Tailscale files and directories. -package paths - -import ( - "os" - "path/filepath" - "runtime" - - "tailscale.com/syncs" - "tailscale.com/version/distro" -) - -// AppSharedDir is a string set by the iOS or Android app on start -// containing a directory we can read/write in. -var AppSharedDir syncs.AtomicValue[string] - -// DefaultTailscaledSocket returns the path to the tailscaled Unix socket -// or the empty string if there's no reasonable default. -func DefaultTailscaledSocket() string { - if runtime.GOOS == "windows" { - return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` - } - if runtime.GOOS == "darwin" { - return "/var/run/tailscaled.socket" - } - if runtime.GOOS == "plan9" { - return "/srv/tailscaled.sock" - } - switch distro.Get() { - case distro.Synology: - if distro.DSMVersion() == 6 { - return "/var/packages/Tailscale/etc/tailscaled.sock" - } - // DSM 7 (and higher? or failure to detect.) - return "/var/packages/Tailscale/var/tailscaled.sock" - case distro.Gokrazy: - return "/perm/tailscaled/tailscaled.sock" - case distro.QNAP: - return "/tmp/tailscale/tailscaled.sock" - } - if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { - return "/var/run/tailscale/tailscaled.sock" - } - return "tailscaled.sock" -} - -// Overridden in init by OS-specific files. -var ( - stateFileFunc func() string - - // ensureStateDirPerms applies a restrictive ACL/chmod - // to the provided directory. - ensureStateDirPerms = func(string) error { return nil } -) - -// DefaultTailscaledStateFile returns the default path to the -// tailscaled state file, or the empty string if there's no reasonable -// default value. -func DefaultTailscaledStateFile() string { - if f := stateFileFunc; f != nil { - return f() - } - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") - } - return "" -} - -// MkStateDir ensures that dirPath, the daemon's configuration directory -// containing machine keys etc, both exists and has the correct permissions. -// We want it to only be accessible to the user the daemon is running under. -func MkStateDir(dirPath string) error { - if err := os.MkdirAll(dirPath, 0700); err != nil { - return err - } - return ensureStateDirPerms(dirPath) -} - -// LegacyStateFilePath returns the legacy path to the state file when -// it was stored under the current user's %LocalAppData%. -// -// It is only called on Windows. -func LegacyStateFilePath() string { - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") - } - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package paths returns platform and user-specific default paths to +// Tailscale files and directories. +package paths + +import ( + "os" + "path/filepath" + "runtime" + + "tailscale.com/syncs" + "tailscale.com/version/distro" +) + +// AppSharedDir is a string set by the iOS or Android app on start +// containing a directory we can read/write in. +var AppSharedDir syncs.AtomicValue[string] + +// DefaultTailscaledSocket returns the path to the tailscaled Unix socket +// or the empty string if there's no reasonable default. +func DefaultTailscaledSocket() string { + if runtime.GOOS == "windows" { + return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` + } + if runtime.GOOS == "darwin" { + return "/var/run/tailscaled.socket" + } + if runtime.GOOS == "plan9" { + return "/srv/tailscaled.sock" + } + switch distro.Get() { + case distro.Synology: + if distro.DSMVersion() == 6 { + return "/var/packages/Tailscale/etc/tailscaled.sock" + } + // DSM 7 (and higher? or failure to detect.) + return "/var/packages/Tailscale/var/tailscaled.sock" + case distro.Gokrazy: + return "/perm/tailscaled/tailscaled.sock" + case distro.QNAP: + return "/tmp/tailscale/tailscaled.sock" + } + if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { + return "/var/run/tailscale/tailscaled.sock" + } + return "tailscaled.sock" +} + +// Overridden in init by OS-specific files. +var ( + stateFileFunc func() string + + // ensureStateDirPerms applies a restrictive ACL/chmod + // to the provided directory. + ensureStateDirPerms = func(string) error { return nil } +) + +// DefaultTailscaledStateFile returns the default path to the +// tailscaled state file, or the empty string if there's no reasonable +// default value. +func DefaultTailscaledStateFile() string { + if f := stateFileFunc; f != nil { + return f() + } + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") + } + return "" +} + +// MkStateDir ensures that dirPath, the daemon's configuration directory +// containing machine keys etc, both exists and has the correct permissions. +// We want it to only be accessible to the user the daemon is running under. +func MkStateDir(dirPath string) error { + if err := os.MkdirAll(dirPath, 0700); err != nil { + return err + } + return ensureStateDirPerms(dirPath) +} + +// LegacyStateFilePath returns the legacy path to the state file when +// it was stored under the current user's %LocalAppData%. +// +// It is only called on Windows. +func LegacyStateFilePath() string { + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") + } + return "" +} diff --git a/paths/paths_windows.go b/paths/paths_windows.go index 2249810494b14..4705400655212 100644 --- a/paths/paths_windows.go +++ b/paths/paths_windows.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - "strings" - - "golang.org/x/sys/windows" - "tailscale.com/util/winutil" -) - -func init() { - ensureStateDirPerms = ensureStateDirPermsWindows -} - -// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. -// It sets the following security attributes on the directory: -// Owner: The user for the current process; -// Primary Group: The primary group for the current process; -// DACL: Full control to the current user and to the Administrators group. -// -// (We include Administrators so that admin users may still access logs; -// granting access exclusively to LocalSystem would require admins to use -// special tools to access the Log directory) -// -// Inheritance: The directory does not inherit the ACL from its parent. -// -// However, any directories and/or files created within this -// directory *do* inherit the ACL that we are setting. -func ensureStateDirPermsWindows(dirPath string) error { - fi, err := os.Stat(dirPath) - if err != nil { - return err - } - if !fi.IsDir() { - return os.ErrInvalid - } - if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { - return nil - } - - // We need the info for our current user as SIDs - sids, err := winutil.GetCurrentUserSIDs() - if err != nil { - return err - } - - // We also need the SID for the Administrators group so that admins may - // easily access logs. - adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) - if err != nil { - return err - } - - // Munge the SIDs into the format required by EXPLICIT_ACCESS. - userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, - windows.TrusteeValueFromSID(sids.User)} - - adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, - windows.TrusteeValueFromSID(adminGroupSid)} - - // We declare our access rights via this array of EXPLICIT_ACCESS structures. - // We set full access to our user and to Administrators. - // We configure the DACL such that any files or directories created within - // dirPath will also inherit this DACL. - explicitAccess := []windows.EXPLICIT_ACCESS{ - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - userTrustee, - }, - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - adminTrustee, - }, - } - - dacl, err := windows.ACLFromEntries(explicitAccess, nil) - if err != nil { - return err - } - - // We now reset the file's owner, primary group, and DACL. - // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL - // does not inherit any ACL entries from the parent directory. - const flags = windows.OWNER_SECURITY_INFORMATION | - windows.GROUP_SECURITY_INFORMATION | - windows.DACL_SECURITY_INFORMATION | - windows.PROTECTED_DACL_SECURITY_INFORMATION - return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, - sids.User, sids.PrimaryGroup, dacl, nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/windows" + "tailscale.com/util/winutil" +) + +func init() { + ensureStateDirPerms = ensureStateDirPermsWindows +} + +// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. +// It sets the following security attributes on the directory: +// Owner: The user for the current process; +// Primary Group: The primary group for the current process; +// DACL: Full control to the current user and to the Administrators group. +// +// (We include Administrators so that admin users may still access logs; +// granting access exclusively to LocalSystem would require admins to use +// special tools to access the Log directory) +// +// Inheritance: The directory does not inherit the ACL from its parent. +// +// However, any directories and/or files created within this +// directory *do* inherit the ACL that we are setting. +func ensureStateDirPermsWindows(dirPath string) error { + fi, err := os.Stat(dirPath) + if err != nil { + return err + } + if !fi.IsDir() { + return os.ErrInvalid + } + if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { + return nil + } + + // We need the info for our current user as SIDs + sids, err := winutil.GetCurrentUserSIDs() + if err != nil { + return err + } + + // We also need the SID for the Administrators group so that admins may + // easily access logs. + adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + + // Munge the SIDs into the format required by EXPLICIT_ACCESS. + userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, + windows.TrusteeValueFromSID(sids.User)} + + adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + windows.TrusteeValueFromSID(adminGroupSid)} + + // We declare our access rights via this array of EXPLICIT_ACCESS structures. + // We set full access to our user and to Administrators. + // We configure the DACL such that any files or directories created within + // dirPath will also inherit this DACL. + explicitAccess := []windows.EXPLICIT_ACCESS{ + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + userTrustee, + }, + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + adminTrustee, + }, + } + + dacl, err := windows.ACLFromEntries(explicitAccess, nil) + if err != nil { + return err + } + + // We now reset the file's owner, primary group, and DACL. + // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL + // does not inherit any ACL entries from the parent directory. + const flags = windows.OWNER_SECURITY_INFORMATION | + windows.GROUP_SECURITY_INFORMATION | + windows.DACL_SECURITY_INFORMATION | + windows.PROTECTED_DACL_SECURITY_INFORMATION + return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, + sids.User, sids.PrimaryGroup, dacl, nil) +} diff --git a/portlist/clean.go b/portlist/clean.go index cad1562c3e1d8..7e137de948e99 100644 --- a/portlist/clean.go +++ b/portlist/clean.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "path/filepath" - "strings" -) - -// argvSubject takes a command and its flags, and returns the -// short/pretty name for the process. This is usually the basename of -// the binary being executed, but can sometimes vary (e.g. so that we -// don't report all Java programs as "java"). -func argvSubject(argv ...string) string { - if len(argv) == 0 { - return "" - } - ret := filepath.Base(argv[0]) - - // Handle special cases. - switch { - case ret == "mono" && len(argv) >= 2: - // .Net programs execute as `mono actualProgram.exe`. - ret = filepath.Base(argv[1]) - } - - // Handle space separated argv - ret, _, _ = strings.Cut(ret, " ") - - // Remove common noise. - ret = strings.TrimSpace(ret) - ret = strings.TrimSuffix(ret, ".exe") - - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "path/filepath" + "strings" +) + +// argvSubject takes a command and its flags, and returns the +// short/pretty name for the process. This is usually the basename of +// the binary being executed, but can sometimes vary (e.g. so that we +// don't report all Java programs as "java"). +func argvSubject(argv ...string) string { + if len(argv) == 0 { + return "" + } + ret := filepath.Base(argv[0]) + + // Handle special cases. + switch { + case ret == "mono" && len(argv) >= 2: + // .Net programs execute as `mono actualProgram.exe`. + ret = filepath.Base(argv[1]) + } + + // Handle space separated argv + ret, _, _ = strings.Cut(ret, " ") + + // Remove common noise. + ret = strings.TrimSpace(ret) + ret = strings.TrimSuffix(ret, ".exe") + + return ret +} diff --git a/portlist/clean_test.go b/portlist/clean_test.go index cca18ab8eb2c6..5a1e34405eed0 100644 --- a/portlist/clean_test.go +++ b/portlist/clean_test.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import "testing" - -func TestArgvSubject(t *testing.T) { - tests := []struct { - in []string - want string - }{ - { - in: nil, - want: "", - }, - { - in: []string{"/usr/bin/sshd"}, - want: "sshd", - }, - { - in: []string{"/bin/mono"}, - want: "mono", - }, - { - in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, - want: "exampleProgram", - }, - { - in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, - want: "exampleProgram.bin", - }, - { - in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, - want: "sshd_config", - }, - { - in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, - want: "sshd", - }, - { - in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, - want: "eic_run_authorized_keys", - }, - { - in: []string{"/usr/bin/nginx worker"}, - want: "nginx", - }, - } - - for _, test := range tests { - got := argvSubject(test.in...) - if got != test.want { - t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import "testing" + +func TestArgvSubject(t *testing.T) { + tests := []struct { + in []string + want string + }{ + { + in: nil, + want: "", + }, + { + in: []string{"/usr/bin/sshd"}, + want: "sshd", + }, + { + in: []string{"/bin/mono"}, + want: "mono", + }, + { + in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, + want: "exampleProgram", + }, + { + in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, + want: "exampleProgram.bin", + }, + { + in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, + want: "sshd_config", + }, + { + in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, + want: "sshd", + }, + { + in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, + want: "eic_run_authorized_keys", + }, + { + in: []string{"/usr/bin/nginx worker"}, + want: "nginx", + }, + } + + for _, test := range tests { + got := argvSubject(test.in...) + if got != test.want { + t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/portlist/netstat_test.go b/portlist/netstat_test.go index d04b657f623f4..023b75b794426 100644 --- a/portlist/netstat_test.go +++ b/portlist/netstat_test.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "encoding/json" - "fmt" - "strings" - "testing" - - "go4.org/mem" -) - -func TestParsePort(t *testing.T) { - type InOut struct { - in string - expect int - } - tests := []InOut{ - {"1.2.3.4:5678", 5678}, - {"0.0.0.0.999", 999}, - {"1.2.3.4:*", 0}, - {"5.5.5.5:0", 0}, - {"[1::2]:5", 5}, - {"[1::2].5", 5}, - {"gibberish", -1}, - } - - for _, io := range tests { - got := parsePort(mem.S(io.in)) - if got != io.expect { - t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) - } - } -} - -const netstatOutput = ` -// macOS -tcp4 0 0 *.23 *.* LISTEN -tcp6 0 0 *.24 *.* LISTEN -tcp4 0 0 *.8185 *.* LISTEN -tcp4 0 0 127.0.0.1.8186 *.* LISTEN -tcp6 0 0 ::1.8187 *.* LISTEN -tcp4 0 0 127.1.2.3.8188 *.* LISTEN - -udp6 0 0 *.106 *.* -udp4 0 0 *.104 *.* -udp46 0 0 *.146 *.* -` - -func TestParsePortsNetstat(t *testing.T) { - for _, loopBack := range [...]bool{false, true} { - t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { - want := List{ - {"tcp", 23, "", 0}, - {"tcp", 24, "", 0}, - {"udp", 104, "", 0}, - {"udp", 106, "", 0}, - {"udp", 146, "", 0}, - {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false - } - if loopBack { - want = append(want, - Port{"tcp", 8186, "", 0}, - Port{"tcp", 8187, "", 0}, - Port{"tcp", 8188, "", 0}, - ) - } - pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) - if err != nil { - t.Fatal(err) - } - pl = sortAndDedup(pl) - jgot, _ := json.MarshalIndent(pl, "", "\t") - jwant, _ := json.MarshalIndent(want, "", "\t") - if len(pl) != len(want) { - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - for i := range pl { - if pl[i] != want[i] { - t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", - i, pl[i], want[i]) - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "encoding/json" + "fmt" + "strings" + "testing" + + "go4.org/mem" +) + +func TestParsePort(t *testing.T) { + type InOut struct { + in string + expect int + } + tests := []InOut{ + {"1.2.3.4:5678", 5678}, + {"0.0.0.0.999", 999}, + {"1.2.3.4:*", 0}, + {"5.5.5.5:0", 0}, + {"[1::2]:5", 5}, + {"[1::2].5", 5}, + {"gibberish", -1}, + } + + for _, io := range tests { + got := parsePort(mem.S(io.in)) + if got != io.expect { + t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) + } + } +} + +const netstatOutput = ` +// macOS +tcp4 0 0 *.23 *.* LISTEN +tcp6 0 0 *.24 *.* LISTEN +tcp4 0 0 *.8185 *.* LISTEN +tcp4 0 0 127.0.0.1.8186 *.* LISTEN +tcp6 0 0 ::1.8187 *.* LISTEN +tcp4 0 0 127.1.2.3.8188 *.* LISTEN + +udp6 0 0 *.106 *.* +udp4 0 0 *.104 *.* +udp46 0 0 *.146 *.* +` + +func TestParsePortsNetstat(t *testing.T) { + for _, loopBack := range [...]bool{false, true} { + t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { + want := List{ + {"tcp", 23, "", 0}, + {"tcp", 24, "", 0}, + {"udp", 104, "", 0}, + {"udp", 106, "", 0}, + {"udp", 146, "", 0}, + {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false + } + if loopBack { + want = append(want, + Port{"tcp", 8186, "", 0}, + Port{"tcp", 8187, "", 0}, + Port{"tcp", 8188, "", 0}, + ) + } + pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) + if err != nil { + t.Fatal(err) + } + pl = sortAndDedup(pl) + jgot, _ := json.MarshalIndent(pl, "", "\t") + jwant, _ := json.MarshalIndent(want, "", "\t") + if len(pl) != len(want) { + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + for i := range pl { + if pl[i] != want[i] { + t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", + i, pl[i], want[i]) + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + } + }) + } +} diff --git a/portlist/poller.go b/portlist/poller.go index 226f3b9958e8d..423bad3be33ba 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -1,122 +1,122 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file contains the code related to the Poller type and its methods. -// The hot loop to keep efficient is Poller.Run. - -package portlist - -import ( - "errors" - "fmt" - "runtime" - "slices" - "sync" - "time" - - "tailscale.com/envknob" -) - -var ( - newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. - pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs - debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") -) - -// PollInterval is the recommended OS-specific interval -// to wait between *Poller.Poll method calls. -func PollInterval() time.Duration { - return pollInterval -} - -// Poller scans the systems for listening ports periodically and sends -// the results to C. -type Poller struct { - // IncludeLocalhost controls whether services bound to localhost are included. - // - // This field should only be changed before calling Run. - IncludeLocalhost bool - - // os, if non-nil, is an OS-specific implementation of the portlist getting - // code. When non-nil, it's responsible for getting the complete list of - // cached ports complete with the process name. That is, when set, - // addProcesses is not used. - // A nil values means we don't have code for getting the list on the current - // operating system. - os osImpl - initOnce sync.Once // guards init of os - initErr error - - // scatch is memory for Poller.getList to reuse between calls. - scratch []Port - - prev List // most recent data, not aliasing scratch -} - -// osImpl is the OS-specific implementation of getting the open listening ports. -type osImpl interface { - Close() error - - // AppendListeningPorts appends to base (which must have length 0 but - // optional capacity) the list of listening ports. The Port struct should be - // populated as completely as possible. Another pass will not add anything - // to it. - // - // The appended ports should be in a sorted (or at least stable) order so - // the caller can cheaply detect when there are no changes. - AppendListeningPorts(base []Port) ([]Port, error) -} - -func (p *Poller) setPrev(pl List) { - // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want - // that to except to the caller. - p.prev = slices.Clone(pl) -} - -// init initializes the Poller by ensuring it has an underlying -// OS implementation and is not turned off by envknob. -func (p *Poller) init() { - switch { - case debugDisablePortlist(): - p.initErr = errors.New("portlist disabled by envknob") - case newOSImpl == nil: - p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) - default: - p.os = newOSImpl(p.IncludeLocalhost) - } -} - -// Close closes the Poller. -func (p *Poller) Close() error { - if p.initErr != nil { - return p.initErr - } - if p.os == nil { - return nil - } - return p.os.Close() -} - -// Poll returns the list of listening ports, if changed from -// a previous call as indicated by the changed result. -func (p *Poller) Poll() (ports []Port, changed bool, err error) { - p.initOnce.Do(p.init) - if p.initErr != nil { - return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) - } - pl, err := p.getList() - if err != nil { - return nil, false, err - } - if pl.equal(p.prev) { - return nil, false, nil - } - p.setPrev(pl) - return p.prev, true, nil -} - -func (p *Poller) getList() (List, error) { - var err error - p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) - return p.scratch, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains the code related to the Poller type and its methods. +// The hot loop to keep efficient is Poller.Run. + +package portlist + +import ( + "errors" + "fmt" + "runtime" + "slices" + "sync" + "time" + + "tailscale.com/envknob" +) + +var ( + newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. + pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs + debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") +) + +// PollInterval is the recommended OS-specific interval +// to wait between *Poller.Poll method calls. +func PollInterval() time.Duration { + return pollInterval +} + +// Poller scans the systems for listening ports periodically and sends +// the results to C. +type Poller struct { + // IncludeLocalhost controls whether services bound to localhost are included. + // + // This field should only be changed before calling Run. + IncludeLocalhost bool + + // os, if non-nil, is an OS-specific implementation of the portlist getting + // code. When non-nil, it's responsible for getting the complete list of + // cached ports complete with the process name. That is, when set, + // addProcesses is not used. + // A nil values means we don't have code for getting the list on the current + // operating system. + os osImpl + initOnce sync.Once // guards init of os + initErr error + + // scatch is memory for Poller.getList to reuse between calls. + scratch []Port + + prev List // most recent data, not aliasing scratch +} + +// osImpl is the OS-specific implementation of getting the open listening ports. +type osImpl interface { + Close() error + + // AppendListeningPorts appends to base (which must have length 0 but + // optional capacity) the list of listening ports. The Port struct should be + // populated as completely as possible. Another pass will not add anything + // to it. + // + // The appended ports should be in a sorted (or at least stable) order so + // the caller can cheaply detect when there are no changes. + AppendListeningPorts(base []Port) ([]Port, error) +} + +func (p *Poller) setPrev(pl List) { + // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want + // that to except to the caller. + p.prev = slices.Clone(pl) +} + +// init initializes the Poller by ensuring it has an underlying +// OS implementation and is not turned off by envknob. +func (p *Poller) init() { + switch { + case debugDisablePortlist(): + p.initErr = errors.New("portlist disabled by envknob") + case newOSImpl == nil: + p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) + default: + p.os = newOSImpl(p.IncludeLocalhost) + } +} + +// Close closes the Poller. +func (p *Poller) Close() error { + if p.initErr != nil { + return p.initErr + } + if p.os == nil { + return nil + } + return p.os.Close() +} + +// Poll returns the list of listening ports, if changed from +// a previous call as indicated by the changed result. +func (p *Poller) Poll() (ports []Port, changed bool, err error) { + p.initOnce.Do(p.init) + if p.initErr != nil { + return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) + } + pl, err := p.getList() + if err != nil { + return nil, false, err + } + if pl.equal(p.prev) { + return nil, false, nil + } + p.setPrev(pl) + return p.prev, true, nil +} + +func (p *Poller) getList() (List, error) { + var err error + p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) + return p.scratch, err +} diff --git a/portlist/portlist.go b/portlist/portlist.go index 6d24cedcc5038..9f7af40d08dc1 100644 --- a/portlist/portlist.go +++ b/portlist/portlist.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file is just the types. The bulk of the code is in poller.go. - -// The portlist package contains code that checks what ports are open and -// listening on the current machine. -package portlist - -import ( - "fmt" - "sort" - "strings" -) - -// Port is a listening port on the machine. -type Port struct { - Proto string // "tcp" or "udp" - Port uint16 // port number - Process string // optional process name, if found (requires suitable permissions) - Pid int // process ID, if known (requires suitable permissions) -} - -// List is a list of Ports. -type List []Port - -func (a *Port) lessThan(b *Port) bool { - if a.Port != b.Port { - return a.Port < b.Port - } - if a.Proto != b.Proto { - return a.Proto < b.Proto - } - return a.Process < b.Process -} - -func (a *Port) equal(b *Port) bool { - return a.Port == b.Port && - a.Proto == b.Proto && - a.Process == b.Process -} - -func (a List) equal(b List) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if !a[i].equal(&b[i]) { - return false - } - } - return true -} - -func (pl List) String() string { - var sb strings.Builder - for _, v := range pl { - fmt.Fprintf(&sb, "%-3s %5d %#v\n", - v.Proto, v.Port, v.Process) - } - return strings.TrimRight(sb.String(), "\n") -} - -// sortAndDedup sorts ps in place (by Port.lessThan) and then returns -// a subset of it with duplicate (Proto, Port) removed. -func sortAndDedup(ps List) List { - sort.Slice(ps, func(i, j int) bool { - return (&ps[i]).lessThan(&ps[j]) - }) - out := ps[:0] - var last Port - for _, p := range ps { - if last.Proto == p.Proto && last.Port == p.Port { - continue - } - out = append(out, p) - last = p - } - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file is just the types. The bulk of the code is in poller.go. + +// The portlist package contains code that checks what ports are open and +// listening on the current machine. +package portlist + +import ( + "fmt" + "sort" + "strings" +) + +// Port is a listening port on the machine. +type Port struct { + Proto string // "tcp" or "udp" + Port uint16 // port number + Process string // optional process name, if found (requires suitable permissions) + Pid int // process ID, if known (requires suitable permissions) +} + +// List is a list of Ports. +type List []Port + +func (a *Port) lessThan(b *Port) bool { + if a.Port != b.Port { + return a.Port < b.Port + } + if a.Proto != b.Proto { + return a.Proto < b.Proto + } + return a.Process < b.Process +} + +func (a *Port) equal(b *Port) bool { + return a.Port == b.Port && + a.Proto == b.Proto && + a.Process == b.Process +} + +func (a List) equal(b List) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].equal(&b[i]) { + return false + } + } + return true +} + +func (pl List) String() string { + var sb strings.Builder + for _, v := range pl { + fmt.Fprintf(&sb, "%-3s %5d %#v\n", + v.Proto, v.Port, v.Process) + } + return strings.TrimRight(sb.String(), "\n") +} + +// sortAndDedup sorts ps in place (by Port.lessThan) and then returns +// a subset of it with duplicate (Proto, Port) removed. +func sortAndDedup(ps List) List { + sort.Slice(ps, func(i, j int) bool { + return (&ps[i]).lessThan(&ps[j]) + }) + out := ps[:0] + var last Port + for _, p := range ps { + if last.Proto == p.Proto && last.Port == p.Port { + continue + } + out = append(out, p) + last = p + } + return out +} diff --git a/portlist/portlist_macos.go b/portlist/portlist_macos.go index 2f4fee351f1cf..e67b2c9b8c064 100644 --- a/portlist/portlist_macos.go +++ b/portlist/portlist_macos.go @@ -1,230 +1,230 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "bytes" - "fmt" - "log" - "os/exec" - "strings" - "sync/atomic" - "time" - - "go4.org/mem" -) - -func init() { - newOSImpl = newMacOSImpl - - // We have to run netstat, which is a bit expensive, so don't do it too often. - pollInterval = 5 * time.Second -} - -type macOSImpl struct { - known map[protoPort]*portMeta // inode string => metadata - netstatPath string // lazily populated - - br *bufio.Reader // reused - portsBuf []Port - includeLocalhost bool -} - -type protoPort struct { - proto string - port uint16 -} - -type portMeta struct { - port Port - keep bool -} - -func newMacOSImpl(includeLocalhost bool) osImpl { - return &macOSImpl{ - known: map[protoPort]*portMeta{}, - br: bufio.NewReader(bytes.NewReader(nil)), - includeLocalhost: includeLocalhost, - } -} - -func (*macOSImpl) Close() error { return nil } - -func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { - var err error - im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - var needProcs bool - for _, p := range im.portsBuf { - fp := protoPort{ - proto: p.Proto, - port: p.Port, - } - if pm, ok := im.known[fp]; ok { - pm.keep = true - } else { - needProcs = true - im.known[fp] = &portMeta{ - port: p, - keep: true, - } - } - } - - ret := base - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - } - } - - if needProcs { - im.addProcesses() // best effort - } - - for _, m := range im.known { - ret = append(ret, m.port) - } - return sortAndDedup(ret), nil -} - -func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { - if im.netstatPath == "" { - var err error - im.netstatPath, err = exec.LookPath("netstat") - if err != nil { - return nil, fmt.Errorf("netstat: lookup: %v", err) - } - } - - cmd := exec.Command(im.netstatPath, "-na") - outPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - im.br.Reset(outPipe) - - if err := cmd.Start(); err != nil { - return nil, err - } - defer cmd.Process.Wait() - defer cmd.Process.Kill() - - return appendParsePortsNetstat(base, im.br, im.includeLocalhost) -} - -var lsofFailed atomic.Bool - -// In theory, lsof could replace the function of both listPorts() and -// addProcesses(), since it provides a superset of the netstat output. -// However, "netstat -na" runs ~100x faster than lsof on my machine, so -// we should do it only if the list of open ports has actually changed. -// -// This fails in a macOS sandbox (i.e. in the Mac App Store or System -// Extension GUI build), but does at least work in the -// tailscaled-on-macos mode. -func (im *macOSImpl) addProcesses() error { - if lsofFailed.Load() { - // This previously failed in the macOS sandbox, so don't try again. - return nil - } - exe, err := exec.LookPath("lsof") - if err != nil { - return fmt.Errorf("lsof: lookup: %v", err) - } - lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") - outPipe, err := lsofCmd.StdoutPipe() - if err != nil { - return err - } - err = lsofCmd.Start() - if err != nil { - var stderr []byte - if xe, ok := err.(*exec.ExitError); ok { - stderr = xe.Stderr - } - // fails when run in a macOS sandbox, so make this non-fatal. - if lsofFailed.CompareAndSwap(false, true) { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) - } - return nil - } - defer func() { - ps, err := lsofCmd.Process.Wait() - if err != nil || ps.ExitCode() != 0 { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) - lsofFailed.Store(true) - } - }() - defer lsofCmd.Process.Kill() - - im.br.Reset(outPipe) - - var cmd, proto string - var pid int - for { - line, err := im.br.ReadBytes('\n') - if err != nil { - break - } - if len(line) < 1 { - continue - } - field, val := line[0], bytes.TrimSpace(line[1:]) - switch field { - case 'p': - // starting a new process - cmd = "" - proto = "" - pid = 0 - if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { - pid = int(p) - } - case 'c': - cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? - case 'P': - proto = lsofProtoLower(val) - case 'n': - if mem.Contains(mem.B(val), mem.S("->")) { - continue - } - // a listening port - port := parsePort(mem.B(val)) - if port <= 0 { - continue - } - pp := protoPort{proto, uint16(port)} - m := im.known[pp] - switch { - case m != nil: - m.port.Process = cmd - m.port.Pid = pid - default: - // ignore: processes and ports come and go - } - } - } - - return nil -} - -func lsofProtoLower(p []byte) string { - if string(p) == "TCP" { - return "tcp" - } - if string(p) == "UDP" { - return "udp" - } - return strings.ToLower(string(p)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "bytes" + "fmt" + "log" + "os/exec" + "strings" + "sync/atomic" + "time" + + "go4.org/mem" +) + +func init() { + newOSImpl = newMacOSImpl + + // We have to run netstat, which is a bit expensive, so don't do it too often. + pollInterval = 5 * time.Second +} + +type macOSImpl struct { + known map[protoPort]*portMeta // inode string => metadata + netstatPath string // lazily populated + + br *bufio.Reader // reused + portsBuf []Port + includeLocalhost bool +} + +type protoPort struct { + proto string + port uint16 +} + +type portMeta struct { + port Port + keep bool +} + +func newMacOSImpl(includeLocalhost bool) osImpl { + return &macOSImpl{ + known: map[protoPort]*portMeta{}, + br: bufio.NewReader(bytes.NewReader(nil)), + includeLocalhost: includeLocalhost, + } +} + +func (*macOSImpl) Close() error { return nil } + +func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { + var err error + im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + var needProcs bool + for _, p := range im.portsBuf { + fp := protoPort{ + proto: p.Proto, + port: p.Port, + } + if pm, ok := im.known[fp]; ok { + pm.keep = true + } else { + needProcs = true + im.known[fp] = &portMeta{ + port: p, + keep: true, + } + } + } + + ret := base + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + } + } + + if needProcs { + im.addProcesses() // best effort + } + + for _, m := range im.known { + ret = append(ret, m.port) + } + return sortAndDedup(ret), nil +} + +func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { + if im.netstatPath == "" { + var err error + im.netstatPath, err = exec.LookPath("netstat") + if err != nil { + return nil, fmt.Errorf("netstat: lookup: %v", err) + } + } + + cmd := exec.Command(im.netstatPath, "-na") + outPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + im.br.Reset(outPipe) + + if err := cmd.Start(); err != nil { + return nil, err + } + defer cmd.Process.Wait() + defer cmd.Process.Kill() + + return appendParsePortsNetstat(base, im.br, im.includeLocalhost) +} + +var lsofFailed atomic.Bool + +// In theory, lsof could replace the function of both listPorts() and +// addProcesses(), since it provides a superset of the netstat output. +// However, "netstat -na" runs ~100x faster than lsof on my machine, so +// we should do it only if the list of open ports has actually changed. +// +// This fails in a macOS sandbox (i.e. in the Mac App Store or System +// Extension GUI build), but does at least work in the +// tailscaled-on-macos mode. +func (im *macOSImpl) addProcesses() error { + if lsofFailed.Load() { + // This previously failed in the macOS sandbox, so don't try again. + return nil + } + exe, err := exec.LookPath("lsof") + if err != nil { + return fmt.Errorf("lsof: lookup: %v", err) + } + lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") + outPipe, err := lsofCmd.StdoutPipe() + if err != nil { + return err + } + err = lsofCmd.Start() + if err != nil { + var stderr []byte + if xe, ok := err.(*exec.ExitError); ok { + stderr = xe.Stderr + } + // fails when run in a macOS sandbox, so make this non-fatal. + if lsofFailed.CompareAndSwap(false, true) { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) + } + return nil + } + defer func() { + ps, err := lsofCmd.Process.Wait() + if err != nil || ps.ExitCode() != 0 { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) + lsofFailed.Store(true) + } + }() + defer lsofCmd.Process.Kill() + + im.br.Reset(outPipe) + + var cmd, proto string + var pid int + for { + line, err := im.br.ReadBytes('\n') + if err != nil { + break + } + if len(line) < 1 { + continue + } + field, val := line[0], bytes.TrimSpace(line[1:]) + switch field { + case 'p': + // starting a new process + cmd = "" + proto = "" + pid = 0 + if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { + pid = int(p) + } + case 'c': + cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? + case 'P': + proto = lsofProtoLower(val) + case 'n': + if mem.Contains(mem.B(val), mem.S("->")) { + continue + } + // a listening port + port := parsePort(mem.B(val)) + if port <= 0 { + continue + } + pp := protoPort{proto, uint16(port)} + m := im.known[pp] + switch { + case m != nil: + m.port.Process = cmd + m.port.Pid = pid + default: + // ignore: processes and ports come and go + } + } + } + + return nil +} + +func lsofProtoLower(p []byte) string { + if string(p) == "TCP" { + return "tcp" + } + if string(p) == "UDP" { + return "udp" + } + return strings.ToLower(string(p)) +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go index c164dbad75485..f449973599247 100644 --- a/portlist/portlist_windows.go +++ b/portlist/portlist_windows.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "time" - - "tailscale.com/net/netstat" -) - -func init() { - newOSImpl = newWindowsImpl - // The portlist poller used to fork on Windows, which is insanely expensive, - // so historically we only did this every 5 seconds on Windows. Maybe we - // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as - // of 2022-11-04. - pollInterval = 5 * time.Second -} - -type famPort struct { - proto string - port uint16 - pid uint32 -} - -type windowsImpl struct { - known map[famPort]*portMeta // inode string => metadata - includeLocalhost bool -} - -type portMeta struct { - port Port - keep bool -} - -func newWindowsImpl(includeLocalhost bool) osImpl { - return &windowsImpl{ - known: map[famPort]*portMeta{}, - includeLocalhost: includeLocalhost, - } -} - -func (*windowsImpl) Close() error { return nil } - -func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { - // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style - // API to that package instead/additionally. - tab, err := netstat.Get() - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - ret := base - for _, e := range tab.Entries { - if e.State != "LISTEN" { - continue - } - if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { - continue - } - fp := famPort{ - proto: "tcp", // TODO(bradfitz): UDP too; add to netstat - port: e.Local.Port(), - pid: uint32(e.Pid), - } - pm, ok := im.known[fp] - if ok { - pm.keep = true - continue - } - var process string - if e.OSMetadata != nil { - if module, err := e.OSMetadata.GetModule(); err == nil { - process = module - } - } - pm = &portMeta{ - keep: true, - port: Port{ - Proto: "tcp", - Port: e.Local.Port(), - Process: process, - Pid: e.Pid, - }, - } - im.known[fp] = pm - } - - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - continue - } - ret = append(ret, m.port) - } - - return sortAndDedup(ret), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "time" + + "tailscale.com/net/netstat" +) + +func init() { + newOSImpl = newWindowsImpl + // The portlist poller used to fork on Windows, which is insanely expensive, + // so historically we only did this every 5 seconds on Windows. Maybe we + // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as + // of 2022-11-04. + pollInterval = 5 * time.Second +} + +type famPort struct { + proto string + port uint16 + pid uint32 +} + +type windowsImpl struct { + known map[famPort]*portMeta // inode string => metadata + includeLocalhost bool +} + +type portMeta struct { + port Port + keep bool +} + +func newWindowsImpl(includeLocalhost bool) osImpl { + return &windowsImpl{ + known: map[famPort]*portMeta{}, + includeLocalhost: includeLocalhost, + } +} + +func (*windowsImpl) Close() error { return nil } + +func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { + // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style + // API to that package instead/additionally. + tab, err := netstat.Get() + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + ret := base + for _, e := range tab.Entries { + if e.State != "LISTEN" { + continue + } + if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { + continue + } + fp := famPort{ + proto: "tcp", // TODO(bradfitz): UDP too; add to netstat + port: e.Local.Port(), + pid: uint32(e.Pid), + } + pm, ok := im.known[fp] + if ok { + pm.keep = true + continue + } + var process string + if e.OSMetadata != nil { + if module, err := e.OSMetadata.GetModule(); err == nil { + process = module + } + } + pm = &portMeta{ + keep: true, + port: Port{ + Proto: "tcp", + Port: e.Local.Port(), + Process: process, + Pid: e.Pid, + }, + } + im.known[fp] = pm + } + + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + continue + } + ret = append(ret, m.port) + } + + return sortAndDedup(ret), nil +} diff --git a/posture/serialnumber_macos.go b/posture/serialnumber_macos.go index ce0b996837889..48355d31393ee 100644 --- a/posture/serialnumber_macos.go +++ b/posture/serialnumber_macos.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && darwin && !ios - -package posture - -// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit -// #include -// #include -// -// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 -// #define kIOMainPortDefault kIOMasterPortDefault -// #endif -// -// const char * -// getSerialNumber() -// { -// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); -// if (!matching) { -// return "err: failed to create dictionary to match IOServices"; -// } -// -// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); -// if (!service) { -// return "err: failed to look up registered IOService objects that match a matching dictionary"; -// } -// -// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( -// service, -// CFSTR("IOPlatformSerialNumber"), -// kCFAllocatorDefault, -// 0 -// ); -// if (!serialNumberRef) { -// return "err: failed to look up serial number in IORegistry"; -// } -// -// CFIndex length = CFStringGetLength(serialNumberRef); -// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; -// char *serialNumberBuf = (char *)malloc(max_size); -// -// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); -// -// CFRelease(serialNumberRef); -// IOObjectRelease(service); -// -// if (!result) { -// free(serialNumberBuf); -// -// return "err: failed to convert serial number reference to string"; -// } -// -// return serialNumberBuf; -// } -import "C" -import ( - "fmt" - "strings" - - "tailscale.com/types/logger" -) - -// GetSerialNumber returns the platform serial sumber as reported by IOKit. -func GetSerialNumbers(_ logger.Logf) ([]string, error) { - csn := C.getSerialNumber() - serialNumber := C.GoString(csn) - - if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { - return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) - } - - return []string{serialNumber}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && darwin && !ios + +package posture + +// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit +// #include +// #include +// +// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 +// #define kIOMainPortDefault kIOMasterPortDefault +// #endif +// +// const char * +// getSerialNumber() +// { +// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); +// if (!matching) { +// return "err: failed to create dictionary to match IOServices"; +// } +// +// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); +// if (!service) { +// return "err: failed to look up registered IOService objects that match a matching dictionary"; +// } +// +// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( +// service, +// CFSTR("IOPlatformSerialNumber"), +// kCFAllocatorDefault, +// 0 +// ); +// if (!serialNumberRef) { +// return "err: failed to look up serial number in IORegistry"; +// } +// +// CFIndex length = CFStringGetLength(serialNumberRef); +// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; +// char *serialNumberBuf = (char *)malloc(max_size); +// +// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); +// +// CFRelease(serialNumberRef); +// IOObjectRelease(service); +// +// if (!result) { +// free(serialNumberBuf); +// +// return "err: failed to convert serial number reference to string"; +// } +// +// return serialNumberBuf; +// } +import "C" +import ( + "fmt" + "strings" + + "tailscale.com/types/logger" +) + +// GetSerialNumber returns the platform serial sumber as reported by IOKit. +func GetSerialNumbers(_ logger.Logf) ([]string, error) { + csn := C.getSerialNumber() + serialNumber := C.GoString(csn) + + if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { + return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) + } + + return []string{serialNumber}, nil +} diff --git a/posture/serialnumber_notmacos_test.go b/posture/serialnumber_notmacos_test.go index 8106c34b36541..f2a15e0373caf 100644 --- a/posture/serialnumber_notmacos_test.go +++ b/posture/serialnumber_notmacos_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Build on Windows, Linux and *BSD - -//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd - -package posture - -import ( - "fmt" - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumberNotMac(t *testing.T) { - // This test is intentionally skipped as it will - // require root on Linux to get access to the serials. - // The test case is intended for local testing. - // Comment out skip for local testing. - t.Skip() - - sns, err := GetSerialNumbers(logger.Discard) - if err != nil { - t.Fatalf("failed to get serial number: %s", err) - } - - if len(sns) == 0 { - t.Fatalf("expected at least one serial number, got %v", sns) - } - - if len(sns[0]) <= 0 { - t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) - } - - fmt.Printf("serials: %v\n", sns) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Build on Windows, Linux and *BSD + +//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd + +package posture + +import ( + "fmt" + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumberNotMac(t *testing.T) { + // This test is intentionally skipped as it will + // require root on Linux to get access to the serials. + // The test case is intended for local testing. + // Comment out skip for local testing. + t.Skip() + + sns, err := GetSerialNumbers(logger.Discard) + if err != nil { + t.Fatalf("failed to get serial number: %s", err) + } + + if len(sns) == 0 { + t.Fatalf("expected at least one serial number, got %v", sns) + } + + if len(sns[0]) <= 0 { + t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) + } + + fmt.Printf("serials: %v\n", sns) +} diff --git a/posture/serialnumber_test.go b/posture/serialnumber_test.go index 1ab8193367bc2..fac4392fab7d3 100644 --- a/posture/serialnumber_test.go +++ b/posture/serialnumber_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package posture - -import ( - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumber(t *testing.T) { - // ensure GetSerialNumbers is implemented - // or covered by a stub on a given platform. - _, _ = GetSerialNumbers(logger.Discard) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package posture + +import ( + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumber(t *testing.T) { + // ensure GetSerialNumbers is implemented + // or covered by a stub on a given platform. + _, _ = GetSerialNumbers(logger.Discard) +} diff --git a/pull-toolchain.sh b/pull-toolchain.sh index 87350ff53e39a..f5a19e7d75de1 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -1,16 +1,16 @@ -#!/bin/sh -# Retrieve the latest Go toolchain. -# -set -eu -cd "$(dirname "$0")" - -read -r go_branch go.toolchain.rev -fi - -if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then - echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 -fi +#!/bin/sh +# Retrieve the latest Go toolchain. +# +set -eu +cd "$(dirname "$0")" + +read -r go_branch go.toolchain.rev +fi + +if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then + echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 +fi diff --git a/release/deb/debian.postrm.sh b/release/deb/debian.postrm.sh index 93d90b0ea2707..f4dd4ed9cdc15 100755 --- a/release/deb/debian.postrm.sh +++ b/release/deb/debian.postrm.sh @@ -1,17 +1,17 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscaled.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscaled.service' >/dev/null || true - deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true - rm -rf /var/lib/tailscale - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscaled.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscaled.service' >/dev/null || true + deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true + rm -rf /var/lib/tailscale + fi +fi diff --git a/release/deb/debian.prerm.sh b/release/deb/debian.prerm.sh index a712a08c8181f..9be58ede4d963 100755 --- a/release/deb/debian.prerm.sh +++ b/release/deb/debian.prerm.sh @@ -1,7 +1,7 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true + fi +fi diff --git a/release/dist/memoize.go b/release/dist/memoize.go index f148cd2b79c86..0927ac0a81540 100644 --- a/release/dist/memoize.go +++ b/release/dist/memoize.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dist - -import ( - "sync" - - "tailscale.com/util/deephash" -) - -// MemoizedFn is a function that memoize.Do can call. -type MemoizedFn[T any] func() (T, error) - -// Memoize runs MemoizedFns and remembers their results. -type Memoize[O any] struct { - mu sync.Mutex - cond *sync.Cond - outs map[deephash.Sum]O - errs map[deephash.Sum]error - inflight map[deephash.Sum]bool -} - -// Do runs fn and returns its result. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.cond == nil { - m.cond = sync.NewCond(&m.mu) - m.outs = map[deephash.Sum]O{} - m.errs = map[deephash.Sum]error{} - m.inflight = map[deephash.Sum]bool{} - } - - k := deephash.Hash(&key) - - for m.inflight[k] { - m.cond.Wait() - } - if err := m.errs[k]; err != nil { - var ret O - return ret, err - } - if ret, ok := m.outs[k]; ok { - return ret, nil - } - - m.inflight[k] = true - m.mu.Unlock() - defer func() { - m.mu.Lock() - delete(m.inflight, k) - if err != nil { - m.errs[k] = err - } else { - m.outs[k] = ret - } - m.cond.Broadcast() - }() - - ret, err = fn() - if err != nil { - var ret O - return ret, err - } - return ret, nil -} - -// once is like memoize, but for functions that don't return non-error values. -type once struct { - m Memoize[any] -} - -// Do runs fn. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (o *once) Do(key any, fn func() error) error { - _, err := o.m.Do(key, func() (any, error) { - return nil, fn() - }) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dist + +import ( + "sync" + + "tailscale.com/util/deephash" +) + +// MemoizedFn is a function that memoize.Do can call. +type MemoizedFn[T any] func() (T, error) + +// Memoize runs MemoizedFns and remembers their results. +type Memoize[O any] struct { + mu sync.Mutex + cond *sync.Cond + outs map[deephash.Sum]O + errs map[deephash.Sum]error + inflight map[deephash.Sum]bool +} + +// Do runs fn and returns its result. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cond == nil { + m.cond = sync.NewCond(&m.mu) + m.outs = map[deephash.Sum]O{} + m.errs = map[deephash.Sum]error{} + m.inflight = map[deephash.Sum]bool{} + } + + k := deephash.Hash(&key) + + for m.inflight[k] { + m.cond.Wait() + } + if err := m.errs[k]; err != nil { + var ret O + return ret, err + } + if ret, ok := m.outs[k]; ok { + return ret, nil + } + + m.inflight[k] = true + m.mu.Unlock() + defer func() { + m.mu.Lock() + delete(m.inflight, k) + if err != nil { + m.errs[k] = err + } else { + m.outs[k] = ret + } + m.cond.Broadcast() + }() + + ret, err = fn() + if err != nil { + var ret O + return ret, err + } + return ret, nil +} + +// once is like memoize, but for functions that don't return non-error values. +type once struct { + m Memoize[any] +} + +// Do runs fn. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (o *once) Do(key any, fn func() error) error { + _, err := o.m.Do(key, func() (any, error) { + return nil, fn() + }) + return err +} diff --git a/release/dist/synology/files/Tailscale.sc b/release/dist/synology/files/Tailscale.sc index f3bb1f0bdbe5d..707ac6bb079b1 100644 --- a/release/dist/synology/files/Tailscale.sc +++ b/release/dist/synology/files/Tailscale.sc @@ -1,6 +1,6 @@ -[Tailscale] -title="Tailscale" -desc="Tailscale VPN" -port_forward="no" -src.ports="41641/udp" +[Tailscale] +title="Tailscale" +desc="Tailscale VPN" +port_forward="no" +src.ports="41641/udp" dst.ports="41641/udp" \ No newline at end of file diff --git a/release/dist/synology/files/config b/release/dist/synology/files/config index 1cf1a6cfaee47..4dbc48dfb9434 100644 --- a/release/dist/synology/files/config +++ b/release/dist/synology/files/config @@ -1,11 +1,11 @@ -{ - ".url": { - "SYNO.SDS.Tailscale": { - "type": "url", - "title": "Tailscale", - "icon": "PACKAGE_ICON_256.PNG", - "url": "webman/3rdparty/Tailscale/index.cgi/", - "urlTarget": "_syno_tailscale" - } - } -} +{ + ".url": { + "SYNO.SDS.Tailscale": { + "type": "url", + "title": "Tailscale", + "icon": "PACKAGE_ICON_256.PNG", + "url": "webman/3rdparty/Tailscale/index.cgi/", + "urlTarget": "_syno_tailscale" + } + } +} diff --git a/release/dist/synology/files/index.cgi b/release/dist/synology/files/index.cgi index 996160d1dca4e..2c1990cfd138a 100755 --- a/release/dist/synology/files/index.cgi +++ b/release/dist/synology/files/index.cgi @@ -1,2 +1,2 @@ -#! /bin/sh -exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" +#! /bin/sh +exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" diff --git a/release/dist/synology/files/logrotate-dsm6 b/release/dist/synology/files/logrotate-dsm6 index a52a6ba24c59e..2df64283afc30 100644 --- a/release/dist/synology/files/logrotate-dsm6 +++ b/release/dist/synology/files/logrotate-dsm6 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/etc/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/etc/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/logrotate-dsm7 b/release/dist/synology/files/logrotate-dsm7 index 3fe6775102b72..7020dc925c2ca 100644 --- a/release/dist/synology/files/logrotate-dsm7 +++ b/release/dist/synology/files/logrotate-dsm7 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/var/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/var/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/privilege-dsm6 b/release/dist/synology/files/privilege-dsm6 index c638528d199bc..4b6fe093a1f23 100644 --- a/release/dist/synology/files/privilege-dsm6 +++ b/release/dist/synology/files/privilege-dsm6 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "root" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "root" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7 b/release/dist/synology/files/privilege-dsm7 index 4eca66cff5dd0..93a9c4f7d7bb5 100644 --- a/release/dist/synology/files/privilege-dsm7 +++ b/release/dist/synology/files/privilege-dsm7 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7.for-package-center b/release/dist/synology/files/privilege-dsm7.for-package-center index b2f93cee1a3c6..db14683460909 100644 --- a/release/dist/synology/files/privilege-dsm7.for-package-center +++ b/release/dist/synology/files/privilege-dsm7.for-package-center @@ -1,13 +1,13 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale", - "tool": [{ - "relpath": "bin/tailscaled", - "user": "package", - "group": "package", - "capabilities": "cap_net_admin,cap_chown,cap_net_raw" - }] -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale", + "tool": [{ + "relpath": "bin/tailscaled", + "user": "package", + "group": "package", + "capabilities": "cap_net_admin,cap_chown,cap_net_raw" + }] +} diff --git a/release/dist/synology/files/resource b/release/dist/synology/files/resource index 706c97671ed47..0da0002ef2fb2 100644 --- a/release/dist/synology/files/resource +++ b/release/dist/synology/files/resource @@ -1,11 +1,11 @@ -{ - "port-config": { - "protocol-file": "conf/Tailscale.sc" - }, - "usr-local-linker": { - "bin": ["bin/tailscale"] - }, - "syslog-config": { - "logrotate-relpath": "conf/logrotate.conf" - } +{ + "port-config": { + "protocol-file": "conf/Tailscale.sc" + }, + "usr-local-linker": { + "bin": ["bin/tailscale"] + }, + "syslog-config": { + "logrotate-relpath": "conf/logrotate.conf" + } } \ No newline at end of file diff --git a/release/dist/synology/files/scripts/postupgrade b/release/dist/synology/files/scripts/postupgrade index 2a7fba5b6f483..92b94c40c5f2b 100644 --- a/release/dist/synology/files/scripts/postupgrade +++ b/release/dist/synology/files/scripts/postupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/preupgrade b/release/dist/synology/files/scripts/preupgrade index 2a7fba5b6f483..92b94c40c5f2b 100644 --- a/release/dist/synology/files/scripts/preupgrade +++ b/release/dist/synology/files/scripts/preupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/start-stop-status b/release/dist/synology/files/scripts/start-stop-status index 311f9293bd62a..e6ece04e3383e 100755 --- a/release/dist/synology/files/scripts/start-stop-status +++ b/release/dist/synology/files/scripts/start-stop-status @@ -1,129 +1,129 @@ -#!/bin/bash - -SERVICE_NAME="tailscale" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - PKGVAR="/var/packages/Tailscale/etc" -else - PKGVAR="${SYNOPKG_PKGVAR}" -fi - -PID_FILE="${PKGVAR}/tailscaled.pid" -LOG_FILE="${PKGVAR}/tailscaled.stdout.log" -STATE_FILE="${PKGVAR}/tailscaled.state" -SOCKET_FILE="${PKGVAR}/tailscaled.sock" -PORT="41641" - -SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ ---state=${STATE_FILE} \ ---socket=${SOCKET_FILE} \ ---port=$PORT" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" -fi - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - chown -R tailscale:tailscale "${PKGVAR}/" -fi - -start_daemon() { - local ts=$(date --iso-8601=second) - echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} - STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & - # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. - # Use jobs -p to retrieve the PID of the most recent process group leader. - jobs -p >"${PID_FILE}" -} - -stop_daemon() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - local ts=$(date --iso-8601=second) - echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} - kill -TERM $PID >>${LOG_FILE} 2>&1 - wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 - rm -f "${PID_FILE}" >/dev/null - fi -} - -daemon_status() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - if ps -o pid -p ${PID} > /dev/null; then - return - fi - rm -f "${PID_FILE}" >/dev/null - fi - return 1 -} - -wait_for_status() { - # 20 tries - # sleeps for 1 second after each try - local counter=20 - while [ ${counter} -gt 0 ]; do - daemon_status - [ $? -eq $1 ] && return - counter=$((counter - 1)) - sleep 1 - done - return 1 -} - -ensure_tun_created() { - if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - return - fi - # Create the necessary file structure for /dev/net/tun - if ([ ! -c /dev/net/tun ]); then - if ([ ! -d /dev/net ]); then - mkdir -m 755 /dev/net - fi - mknod /dev/net/tun c 10 200 - chmod 0755 /dev/net/tun - fi - - # Load the tun module if not already loaded - if (!(lsmod | grep -q "^tun\s")); then - insmod /lib/modules/tun.ko - fi -} - -case $1 in -start) - if daemon_status; then - exit 0 - else - ensure_tun_created - start_daemon - exit $? - fi - ;; -stop) - if daemon_status; then - stop_daemon - exit $? - else - exit 0 - fi - ;; -status) - if daemon_status; then - echo "${SERVICE_NAME} is running" - exit 0 - else - echo "${SERVICE_NAME} is not running" - exit 3 - fi - ;; -log) - exit 0 - ;; -*) - echo "command $1 is not implemented" - exit 0 - ;; -esac +#!/bin/bash + +SERVICE_NAME="tailscale" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + PKGVAR="/var/packages/Tailscale/etc" +else + PKGVAR="${SYNOPKG_PKGVAR}" +fi + +PID_FILE="${PKGVAR}/tailscaled.pid" +LOG_FILE="${PKGVAR}/tailscaled.stdout.log" +STATE_FILE="${PKGVAR}/tailscaled.state" +SOCKET_FILE="${PKGVAR}/tailscaled.sock" +PORT="41641" + +SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ +--state=${STATE_FILE} \ +--socket=${SOCKET_FILE} \ +--port=$PORT" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" +fi + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + chown -R tailscale:tailscale "${PKGVAR}/" +fi + +start_daemon() { + local ts=$(date --iso-8601=second) + echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} + STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & + # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. + # Use jobs -p to retrieve the PID of the most recent process group leader. + jobs -p >"${PID_FILE}" +} + +stop_daemon() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + local ts=$(date --iso-8601=second) + echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} + kill -TERM $PID >>${LOG_FILE} 2>&1 + wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 + rm -f "${PID_FILE}" >/dev/null + fi +} + +daemon_status() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + if ps -o pid -p ${PID} > /dev/null; then + return + fi + rm -f "${PID_FILE}" >/dev/null + fi + return 1 +} + +wait_for_status() { + # 20 tries + # sleeps for 1 second after each try + local counter=20 + while [ ${counter} -gt 0 ]; do + daemon_status + [ $? -eq $1 ] && return + counter=$((counter - 1)) + sleep 1 + done + return 1 +} + +ensure_tun_created() { + if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + return + fi + # Create the necessary file structure for /dev/net/tun + if ([ ! -c /dev/net/tun ]); then + if ([ ! -d /dev/net ]); then + mkdir -m 755 /dev/net + fi + mknod /dev/net/tun c 10 200 + chmod 0755 /dev/net/tun + fi + + # Load the tun module if not already loaded + if (!(lsmod | grep -q "^tun\s")); then + insmod /lib/modules/tun.ko + fi +} + +case $1 in +start) + if daemon_status; then + exit 0 + else + ensure_tun_created + start_daemon + exit $? + fi + ;; +stop) + if daemon_status; then + stop_daemon + exit $? + else + exit 0 + fi + ;; +status) + if daemon_status; then + echo "${SERVICE_NAME} is running" + exit 0 + else + echo "${SERVICE_NAME} is not running" + exit 3 + fi + ;; +log) + exit 0 + ;; +*) + echo "command $1 is not implemented" + exit 0 + ;; +esac diff --git a/release/dist/unixpkgs/pkgs.go b/release/dist/unixpkgs/pkgs.go index 60a038eb49d21..bad6ce572e675 100644 --- a/release/dist/unixpkgs/pkgs.go +++ b/release/dist/unixpkgs/pkgs.go @@ -1,472 +1,472 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package unixpkgs contains dist Targets for building unix Tailscale packages. -package unixpkgs - -import ( - "archive/tar" - "compress/gzip" - "errors" - "fmt" - "io" - "log" - "os" - "path/filepath" - "strings" - - "github.com/goreleaser/nfpm/v2" - "github.com/goreleaser/nfpm/v2/files" - "tailscale.com/release/dist" -) - -type tgzTarget struct { - filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] - goEnv map[string]string - signer dist.Signer -} - -func (t *tgzTarget) arch() string { - if t.filenameArch != "" { - return t.filenameArch - } - return t.goEnv["GOARCH"] -} - -func (t *tgzTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *tgzTarget) String() string { - return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) -} - -func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { - var filename string - if t.goEnv["GOOS"] == "linux" { - // Linux used to be the only tgz architecture, so we didn't put the OS - // name in the filename. - filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) - } else { - filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) - } - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - log.Printf("Building %s", filename) - - out := filepath.Join(b.Out, filename) - f, err := os.Create(out) - if err != nil { - return nil, err - } - defer f.Close() - gw := gzip.NewWriter(f) - defer gw.Close() - tw := tar.NewWriter(gw) - defer tw.Close() - - addFile := func(src, dst string, mode int64) error { - f, err := os.Open(src) - if err != nil { - return err - } - defer f.Close() - fi, err := f.Stat() - if err != nil { - return err - } - hdr := &tar.Header{ - Name: dst, - Size: fi.Size(), - Mode: mode, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - if err := tw.WriteHeader(hdr); err != nil { - return err - } - if _, err = io.Copy(tw, f); err != nil { - return err - } - return nil - } - addDir := func(name string) error { - hdr := &tar.Header{ - Name: name + "/", - Mode: 0755, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - return tw.WriteHeader(hdr) - } - dir := strings.TrimSuffix(filename, ".tgz") - if err := addDir(dir); err != nil { - return nil, err - } - if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { - return nil, err - } - if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { - return nil, err - } - if t.os() == "linux" { - dir = filepath.Join(dir, "systemd") - if err := addDir(dir); err != nil { - return nil, err - } - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { - return nil, err - } - } - if err := tw.Close(); err != nil { - return nil, err - } - if err := gw.Close(); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - files := []string{filename} - - if t.signer != nil { - outSig := out + ".sig" - if err := t.signer.SignFile(out, outSig); err != nil { - return nil, err - } - files = append(files, filepath.Base(outSig)) - } - - return files, nil -} - -type debTarget struct { - goEnv map[string]string -} - -func (t *debTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *debTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *debTarget) String() string { - return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) -} - -func (t *debTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("deb only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := debArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - }, 0, "deb", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Section: "net", - Priority: "extra", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), - }, - Depends: []string{ - // iptables is almost always required but not strictly needed. - // Even if you can technically run Tailscale without it (by - // manually configuring nftables or userspace mode), we still - // mark this as "Depends" because our previous experiment in - // https://github.com/tailscale/tailscale/issues/9236 of making - // it only Recommends caused too many problems. Until our - // nftables table is more mature, we'd rather err on the side of - // wasting a little disk by including iptables for people who - // might not need it rather than handle reports of it being - // missing. - "iptables", - }, - Recommends: []string{ - "tailscale-archive-keyring (>= 1.35.181)", - // The "ip" command isn't needed since 2021-11-01 in - // 408b0923a61972ed but kept as an option as of - // 2021-11-18 in d24ed3f68e35e802d531371. See - // https://github.com/tailscale/tailscale/issues/391. - // We keep it recommended because it's usually - // installed anyway and it's useful for debugging. But - // we can live without it, so it's not Depends. - "iproute2", - }, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - }, - }) - pkg, err := nfpm.Get("deb") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) - log.Printf("Building %s", filename) - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -type rpmTarget struct { - goEnv map[string]string - signer dist.Signer -} - -func (t *rpmTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *rpmTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *rpmTarget) String() string { - return fmt.Sprintf("linux/%s/rpm", t.arch()) -} - -func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("rpm only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := rpmArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. - // Creating an empty directory at install time resolves this issue. - &files.Content{ - Type: files.TypeDir, - Destination: "/var/cache/tailscale", - }, - }, 0, "rpm", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), - }, - Depends: []string{"iptables", "iproute"}, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - RPM: nfpm.RPM{ - Group: "Network", - Signature: nfpm.RPMSignature{ - PackageSignature: nfpm.PackageSignature{ - SignFn: t.signer, - }, - }, - }, - }, - }) - pkg, err := nfpm.Get("rpm") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) - log.Printf("Building %s", filename) - - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -// debArch returns the debian arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func debArch(arch string) string { - switch arch { - case "386": - return "i386" - case "arm": - // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for - // GOARM=6 and 7. But we have some tech debt to pay off here before we - // can ship more than 1 ARM deb, so for now match redo's behavior of - // shipping armv5 binaries in an armv7 trenchcoat. - return "armhf" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} - -// rpmArch returns the RPM arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func rpmArch(arch string) string { - switch arch { - case "amd64": - return "x86_64" - case "386": - return "i386" - case "arm": - return "armv7hl" - case "arm64": - return "aarch64" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package unixpkgs contains dist Targets for building unix Tailscale packages. +package unixpkgs + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/goreleaser/nfpm/v2" + "github.com/goreleaser/nfpm/v2/files" + "tailscale.com/release/dist" +) + +type tgzTarget struct { + filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] + goEnv map[string]string + signer dist.Signer +} + +func (t *tgzTarget) arch() string { + if t.filenameArch != "" { + return t.filenameArch + } + return t.goEnv["GOARCH"] +} + +func (t *tgzTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *tgzTarget) String() string { + return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) +} + +func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { + var filename string + if t.goEnv["GOOS"] == "linux" { + // Linux used to be the only tgz architecture, so we didn't put the OS + // name in the filename. + filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) + } else { + filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) + } + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + log.Printf("Building %s", filename) + + out := filepath.Join(b.Out, filename) + f, err := os.Create(out) + if err != nil { + return nil, err + } + defer f.Close() + gw := gzip.NewWriter(f) + defer gw.Close() + tw := tar.NewWriter(gw) + defer tw.Close() + + addFile := func(src, dst string, mode int64) error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + return err + } + hdr := &tar.Header{ + Name: dst, + Size: fi.Size(), + Mode: mode, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + if err := tw.WriteHeader(hdr); err != nil { + return err + } + if _, err = io.Copy(tw, f); err != nil { + return err + } + return nil + } + addDir := func(name string) error { + hdr := &tar.Header{ + Name: name + "/", + Mode: 0755, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + return tw.WriteHeader(hdr) + } + dir := strings.TrimSuffix(filename, ".tgz") + if err := addDir(dir); err != nil { + return nil, err + } + if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { + return nil, err + } + if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { + return nil, err + } + if t.os() == "linux" { + dir = filepath.Join(dir, "systemd") + if err := addDir(dir); err != nil { + return nil, err + } + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { + return nil, err + } + } + if err := tw.Close(); err != nil { + return nil, err + } + if err := gw.Close(); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + files := []string{filename} + + if t.signer != nil { + outSig := out + ".sig" + if err := t.signer.SignFile(out, outSig); err != nil { + return nil, err + } + files = append(files, filepath.Base(outSig)) + } + + return files, nil +} + +type debTarget struct { + goEnv map[string]string +} + +func (t *debTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *debTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *debTarget) String() string { + return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) +} + +func (t *debTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("deb only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := debArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + }, 0, "deb", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Section: "net", + Priority: "extra", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), + }, + Depends: []string{ + // iptables is almost always required but not strictly needed. + // Even if you can technically run Tailscale without it (by + // manually configuring nftables or userspace mode), we still + // mark this as "Depends" because our previous experiment in + // https://github.com/tailscale/tailscale/issues/9236 of making + // it only Recommends caused too many problems. Until our + // nftables table is more mature, we'd rather err on the side of + // wasting a little disk by including iptables for people who + // might not need it rather than handle reports of it being + // missing. + "iptables", + }, + Recommends: []string{ + "tailscale-archive-keyring (>= 1.35.181)", + // The "ip" command isn't needed since 2021-11-01 in + // 408b0923a61972ed but kept as an option as of + // 2021-11-18 in d24ed3f68e35e802d531371. See + // https://github.com/tailscale/tailscale/issues/391. + // We keep it recommended because it's usually + // installed anyway and it's useful for debugging. But + // we can live without it, so it's not Depends. + "iproute2", + }, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + }, + }) + pkg, err := nfpm.Get("deb") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) + log.Printf("Building %s", filename) + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +type rpmTarget struct { + goEnv map[string]string + signer dist.Signer +} + +func (t *rpmTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *rpmTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *rpmTarget) String() string { + return fmt.Sprintf("linux/%s/rpm", t.arch()) +} + +func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("rpm only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := rpmArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. + // Creating an empty directory at install time resolves this issue. + &files.Content{ + Type: files.TypeDir, + Destination: "/var/cache/tailscale", + }, + }, 0, "rpm", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), + }, + Depends: []string{"iptables", "iproute"}, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + RPM: nfpm.RPM{ + Group: "Network", + Signature: nfpm.RPMSignature{ + PackageSignature: nfpm.PackageSignature{ + SignFn: t.signer, + }, + }, + }, + }, + }) + pkg, err := nfpm.Get("rpm") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) + log.Printf("Building %s", filename) + + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +// debArch returns the debian arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func debArch(arch string) string { + switch arch { + case "386": + return "i386" + case "arm": + // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for + // GOARM=6 and 7. But we have some tech debt to pay off here before we + // can ship more than 1 ARM deb, so for now match redo's behavior of + // shipping armv5 binaries in an armv7 trenchcoat. + return "armhf" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} + +// rpmArch returns the RPM arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func rpmArch(arch string) string { + switch arch { + case "amd64": + return "x86_64" + case "386": + return "i386" + case "arm": + return "armv7hl" + case "arm64": + return "aarch64" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} diff --git a/release/dist/unixpkgs/targets.go b/release/dist/unixpkgs/targets.go index f87c56d317d9f..42bab6d3b2685 100644 --- a/release/dist/unixpkgs/targets.go +++ b/release/dist/unixpkgs/targets.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package unixpkgs - -import ( - "fmt" - "sort" - "strings" - - "tailscale.com/release/dist" - - _ "github.com/goreleaser/nfpm/v2/deb" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -type Signers struct { - Tarball dist.Signer - RPM dist.Signer -} - -func Targets(signers Signers) []dist.Target { - var ret []dist.Target - for goosgoarch := range tarballs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &tgzTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.Tarball, - }) - } - for goosgoarch := range debs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &debTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - }) - } - for goosgoarch := range rpms { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &rpmTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.RPM, - }) - } - - // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's - // an ancient architecture. - ret = append(ret, &tgzTarget{ - filenameArch: "geode", - goEnv: map[string]string{ - "GOOS": "linux", - "GOARCH": "386", - "GO386": "softfloat", - }, - signer: signers.Tarball, - }) - - sort.Slice(ret, func(i, j int) bool { - return ret[i].String() < ret[j].String() - }) - - return ret -} - -var ( - tarballs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/mips64": true, - "linux/mips64le": true, - "linux/mips": true, - "linux/mipsle": true, - "linux/riscv64": true, - // TODO: more tarballs we could distribute, but don't currently. Leaving - // out for initial parity with redo. - // "darwin/amd64": true, - // "darwin/arm64": true, - // "freebsd/amd64": true, - // "openbsd/amd64": true, - } - - debs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - "linux/mips": true, - // Debian does not support big endian mips64. Leave that out until we know - // we need it. - // "linux/mips64": true, - } - - rpms = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - // Fedora only supports little endian mipses. Maybe some other distribution - // supports big-endian? Leave them out for now. - // "linux/mips": true, - // "linux/mips64": true, - } -) - -func splitGoosGoarch(s string) (string, string) { - goos, goarch, ok := strings.Cut(s, "/") - if !ok { - panic(fmt.Sprintf("invalid target %q", s)) - } - return goos, goarch -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package unixpkgs + +import ( + "fmt" + "sort" + "strings" + + "tailscale.com/release/dist" + + _ "github.com/goreleaser/nfpm/v2/deb" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +type Signers struct { + Tarball dist.Signer + RPM dist.Signer +} + +func Targets(signers Signers) []dist.Target { + var ret []dist.Target + for goosgoarch := range tarballs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &tgzTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.Tarball, + }) + } + for goosgoarch := range debs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &debTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + }) + } + for goosgoarch := range rpms { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &rpmTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.RPM, + }) + } + + // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's + // an ancient architecture. + ret = append(ret, &tgzTarget{ + filenameArch: "geode", + goEnv: map[string]string{ + "GOOS": "linux", + "GOARCH": "386", + "GO386": "softfloat", + }, + signer: signers.Tarball, + }) + + sort.Slice(ret, func(i, j int) bool { + return ret[i].String() < ret[j].String() + }) + + return ret +} + +var ( + tarballs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/mips64": true, + "linux/mips64le": true, + "linux/mips": true, + "linux/mipsle": true, + "linux/riscv64": true, + // TODO: more tarballs we could distribute, but don't currently. Leaving + // out for initial parity with redo. + // "darwin/amd64": true, + // "darwin/arm64": true, + // "freebsd/amd64": true, + // "openbsd/amd64": true, + } + + debs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + "linux/mips": true, + // Debian does not support big endian mips64. Leave that out until we know + // we need it. + // "linux/mips64": true, + } + + rpms = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + // Fedora only supports little endian mipses. Maybe some other distribution + // supports big-endian? Leave them out for now. + // "linux/mips": true, + // "linux/mips64": true, + } +) + +func splitGoosGoarch(s string) (string, string) { + goos, goarch, ok := strings.Cut(s, "/") + if !ok { + panic(fmt.Sprintf("invalid target %q", s)) + } + return goos, goarch +} diff --git a/release/release.go b/release/release.go index 638635b6d23e9..a8d0e6b62e8d7 100644 --- a/release/release.go +++ b/release/release.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package release provides functionality for building client releases. -package release - -import "embed" - -// This contains all files in the release directory, -// notably the files needed for deb, rpm, and similar packages. -// Because we assign this to the blank identifier, it does not actually embed the files. -// However, this does cause `go mod vendor` to include the files when vendoring the package. -// -//go:embed * -var _ embed.FS +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package release provides functionality for building client releases. +package release + +import "embed" + +// This contains all files in the release directory, +// notably the files needed for deb, rpm, and similar packages. +// Because we assign this to the blank identifier, it does not actually embed the files. +// However, this does cause `go mod vendor` to include the files when vendoring the package. +// +//go:embed * +var _ embed.FS diff --git a/release/rpm/rpm.postinst.sh b/release/rpm/rpm.postinst.sh index f9c1fddfdfc73..3d264c5f60b18 100755 --- a/release/rpm/rpm.postinst.sh +++ b/release/rpm/rpm.postinst.sh @@ -1,41 +1,41 @@ -# $1 == 1 for initial installation. -# $1 == 2 for upgrades. - -if [ $1 -eq 1 ] ; then - # Normally, the tailscale-relay package would request shutdown of - # its service before uninstallation. Unfortunately, the - # tailscale-relay package we distributed doesn't have those - # scriptlets. We definitely want relaynode to be stopped when - # installing tailscaled though, so we blindly try to turn off - # relaynode here. - # - # However, we also want this package installation to look like an - # upgrade from relaynode! Therefore, if relaynode is currently - # enabled, we want to also enable tailscaled. If relaynode is - # currently running, we also want to start tailscaled. - # - # If there doesn't seem to be an active or enabled relaynode on - # the system, we follow the RPM convention for package installs, - # which is to not enable or start the service. - relaynode_enabled=0 - relaynode_running=0 - if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then - relaynode_enabled=1 - fi - if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then - relaynode_running=1 - fi - - systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : - systemctl stop tailscale-relay.service >/dev/null 2>&1 || : - - if [ $relaynode_enabled -eq 1 ]; then - systemctl enable tailscaled.service >/dev/null 2>&1 || : - else - systemctl preset tailscaled.service >/dev/null 2>&1 || : - fi - - if [ $relaynode_running -eq 1 ]; then - systemctl start tailscaled.service >/dev/null 2>&1 || : - fi -fi +# $1 == 1 for initial installation. +# $1 == 2 for upgrades. + +if [ $1 -eq 1 ] ; then + # Normally, the tailscale-relay package would request shutdown of + # its service before uninstallation. Unfortunately, the + # tailscale-relay package we distributed doesn't have those + # scriptlets. We definitely want relaynode to be stopped when + # installing tailscaled though, so we blindly try to turn off + # relaynode here. + # + # However, we also want this package installation to look like an + # upgrade from relaynode! Therefore, if relaynode is currently + # enabled, we want to also enable tailscaled. If relaynode is + # currently running, we also want to start tailscaled. + # + # If there doesn't seem to be an active or enabled relaynode on + # the system, we follow the RPM convention for package installs, + # which is to not enable or start the service. + relaynode_enabled=0 + relaynode_running=0 + if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then + relaynode_enabled=1 + fi + if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then + relaynode_running=1 + fi + + systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : + systemctl stop tailscale-relay.service >/dev/null 2>&1 || : + + if [ $relaynode_enabled -eq 1 ]; then + systemctl enable tailscaled.service >/dev/null 2>&1 || : + else + systemctl preset tailscaled.service >/dev/null 2>&1 || : + fi + + if [ $relaynode_running -eq 1 ]; then + systemctl start tailscaled.service >/dev/null 2>&1 || : + fi +fi diff --git a/release/rpm/rpm.postrm.sh b/release/rpm/rpm.postrm.sh index e19a7305cac23..d74f7e9deac77 100755 --- a/release/rpm/rpm.postrm.sh +++ b/release/rpm/rpm.postrm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl try-restart tailscaled.service >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl try-restart tailscaled.service >/dev/null 2>&1 || : +fi diff --git a/release/rpm/rpm.prerm.sh b/release/rpm/rpm.prerm.sh index eeabf3b584721..682c01bd574d8 100755 --- a/release/rpm/rpm.prerm.sh +++ b/release/rpm/rpm.prerm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : - systemctl stop tailscaled.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : + systemctl stop tailscaled.service > /dev/null 2>&1 || : +fi diff --git a/safesocket/safesocket_test.go b/safesocket/safesocket_test.go index 85b317bd6e70f..3f36a1cf6ca1f 100644 --- a/safesocket/safesocket_test.go +++ b/safesocket/safesocket_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package safesocket - -import "testing" - -func TestLocalTCPPortAndToken(t *testing.T) { - // Just test that it compiles for now (is available on all platforms). - port, token, err := LocalTCPPortAndToken() - t.Logf("got %v, %s, %v", port, token, err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safesocket + +import "testing" + +func TestLocalTCPPortAndToken(t *testing.T) { + // Just test that it compiles for now (is available on all platforms). + port, token, err := LocalTCPPortAndToken() + t.Logf("got %v, %s, %v", port, token, err) +} diff --git a/smallzstd/testdata b/smallzstd/testdata index 498b014fd8d36..76640fdc57df0 100644 --- a/smallzstd/testdata +++ b/smallzstd/testdata @@ -1,14 +1,14 @@ -{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} diff --git a/smallzstd/zstd.go b/smallzstd/zstd.go index d91afeb67e254..1d80854224359 100644 --- a/smallzstd/zstd.go +++ b/smallzstd/zstd.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package smallzstd produces zstd encoders and decoders optimized for -// low memory usage, at the expense of compression efficiency. -// -// This package is optimized primarily for the memory cost of -// compressing and decompressing data. We reduce this cost in two -// major ways: disable parallelism within the library (i.e. don't use -// multiple CPU cores to decompress), and drop the compression window -// down from the defaults of 4-16MiB, to 8kiB. -// -// Decompressors cost 2x the window size in RAM to run, so by using an -// 8kiB window, we can run ~1000 more decompressors per unit of memory -// than with the defaults. -// -// Depending on context, the benefit is either being able to run more -// decoders (e.g. in our logs processing system), or having a lower -// memory footprint when using compression in network protocols -// (e.g. in tailscaled, which should have a minimal RAM cost). -package smallzstd - -import ( - "io" - - "github.com/klauspost/compress/zstd" -) - -// WindowSize is the window size used for zstd compression. Decoder -// memory usage scales linearly with WindowSize. -const WindowSize = 8 << 10 // 8kiB - -// NewDecoder returns a zstd.Decoder configured for low memory usage, -// at the expense of decompression performance. -func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { - defaults := []zstd.DOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithDecoderConcurrency(1), - // Default is to allocate more upfront for performance. We - // prefer lower memory use and a bit of GC load. - zstd.WithDecoderLowmem(true), - // You might expect to see zstd.WithDecoderMaxMemory - // here. However, it's not terribly safe to use if you're - // doing stateless decoding, because it sets the maximum - // amount of memory the decompressed data can occupy, rather - // than the window size of the zstd stream. This means a very - // compressible piece of data might violate the max memory - // limit here, even if the window size (and thus total memory - // required to decompress the data) is small. - // - // As a result, we don't set a decoder limit here, and rely on - // the encoder below producing "cheap" streams. Callers are - // welcome to set their own max memory setting, if - // contextually there is a clearly correct value (e.g. it's - // known from the upper layer protocol that the decoded data - // can never be more than 1MiB). - } - - return zstd.NewReader(r, append(defaults, options...)...) -} - -// NewEncoder returns a zstd.Encoder configured for low memory usage, -// both during compression and at decompression time, at the expense -// of performance and compression efficiency. -func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { - defaults := []zstd.EOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithEncoderConcurrency(1), - // Default is several MiB, which bloats both encoders and - // their corresponding decoders. - zstd.WithWindowSize(WindowSize), - // Encode zero-length inputs in a way that the `zstd` utility - // can read, because interoperability is handy. - zstd.WithZeroFrames(true), - } - - return zstd.NewWriter(w, append(defaults, options...)...) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package smallzstd produces zstd encoders and decoders optimized for +// low memory usage, at the expense of compression efficiency. +// +// This package is optimized primarily for the memory cost of +// compressing and decompressing data. We reduce this cost in two +// major ways: disable parallelism within the library (i.e. don't use +// multiple CPU cores to decompress), and drop the compression window +// down from the defaults of 4-16MiB, to 8kiB. +// +// Decompressors cost 2x the window size in RAM to run, so by using an +// 8kiB window, we can run ~1000 more decompressors per unit of memory +// than with the defaults. +// +// Depending on context, the benefit is either being able to run more +// decoders (e.g. in our logs processing system), or having a lower +// memory footprint when using compression in network protocols +// (e.g. in tailscaled, which should have a minimal RAM cost). +package smallzstd + +import ( + "io" + + "github.com/klauspost/compress/zstd" +) + +// WindowSize is the window size used for zstd compression. Decoder +// memory usage scales linearly with WindowSize. +const WindowSize = 8 << 10 // 8kiB + +// NewDecoder returns a zstd.Decoder configured for low memory usage, +// at the expense of decompression performance. +func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { + defaults := []zstd.DOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithDecoderConcurrency(1), + // Default is to allocate more upfront for performance. We + // prefer lower memory use and a bit of GC load. + zstd.WithDecoderLowmem(true), + // You might expect to see zstd.WithDecoderMaxMemory + // here. However, it's not terribly safe to use if you're + // doing stateless decoding, because it sets the maximum + // amount of memory the decompressed data can occupy, rather + // than the window size of the zstd stream. This means a very + // compressible piece of data might violate the max memory + // limit here, even if the window size (and thus total memory + // required to decompress the data) is small. + // + // As a result, we don't set a decoder limit here, and rely on + // the encoder below producing "cheap" streams. Callers are + // welcome to set their own max memory setting, if + // contextually there is a clearly correct value (e.g. it's + // known from the upper layer protocol that the decoded data + // can never be more than 1MiB). + } + + return zstd.NewReader(r, append(defaults, options...)...) +} + +// NewEncoder returns a zstd.Encoder configured for low memory usage, +// both during compression and at decompression time, at the expense +// of performance and compression efficiency. +func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { + defaults := []zstd.EOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithEncoderConcurrency(1), + // Default is several MiB, which bloats both encoders and + // their corresponding decoders. + zstd.WithWindowSize(WindowSize), + // Encode zero-length inputs in a way that the `zstd` utility + // can read, because interoperability is handy. + zstd.WithZeroFrames(true), + } + + return zstd.NewWriter(w, append(defaults, options...)...) +} diff --git a/syncs/locked.go b/syncs/locked.go index abde5bca62415..d2048665dee3d 100644 --- a/syncs/locked.go +++ b/syncs/locked.go @@ -1,32 +1,32 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" -) - -// AssertLocked panics if m is not locked. -func AssertLocked(m *sync.Mutex) { - if m.TryLock() { - m.Unlock() - panic("mutex is not locked") - } -} - -// AssertRLocked panics if rw is not locked for reading or writing. -func AssertRLocked(rw *sync.RWMutex) { - if rw.TryLock() { - rw.Unlock() - panic("mutex is not locked") - } -} - -// AssertWLocked panics if rw is not locked for writing. -func AssertWLocked(rw *sync.RWMutex) { - if rw.TryRLock() { - rw.RUnlock() - panic("mutex is not rlocked") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" +) + +// AssertLocked panics if m is not locked. +func AssertLocked(m *sync.Mutex) { + if m.TryLock() { + m.Unlock() + panic("mutex is not locked") + } +} + +// AssertRLocked panics if rw is not locked for reading or writing. +func AssertRLocked(rw *sync.RWMutex) { + if rw.TryLock() { + rw.Unlock() + panic("mutex is not locked") + } +} + +// AssertWLocked panics if rw is not locked for writing. +func AssertWLocked(rw *sync.RWMutex) { + if rw.TryRLock() { + rw.RUnlock() + panic("mutex is not rlocked") + } +} diff --git a/syncs/locked_test.go b/syncs/locked_test.go index 44877be50be1a..90b36e8321d82 100644 --- a/syncs/locked_test.go +++ b/syncs/locked_test.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.13 && !go1.19 - -package syncs - -import ( - "sync" - "testing" - "time" -) - -func wantPanic(t *testing.T, fn func()) { - t.Helper() - defer func() { - recover() - }() - fn() - t.Fatal("failed to panic") -} - -func TestAssertLocked(t *testing.T) { - m := new(sync.Mutex) - wantPanic(t, func() { AssertLocked(m) }) - m.Lock() - AssertLocked(m) - m.Unlock() - wantPanic(t, func() { AssertLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertLocked(m) -} - -func TestAssertWLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertWLocked(m) }) - m.Lock() - AssertWLocked(m) - m.Unlock() - wantPanic(t, func() { AssertWLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertWLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertWLocked(m) -} - -func TestAssertRLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertRLocked(m) }) - - m.Lock() - AssertRLocked(m) - m.Unlock() - - m.RLock() - AssertRLocked(m) - m.RUnlock() - - wantPanic(t, func() { AssertRLocked(m) }) - - // Test correct handling of mutex with waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - m.RUnlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with write waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with other rlocks. - // This is a bit racy, but losing the race hurts nothing, - // and winning the race means correct test coverage. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - time.Sleep(10 * time.Millisecond) - m.RUnlock() - }() - time.Sleep(5 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.13 && !go1.19 + +package syncs + +import ( + "sync" + "testing" + "time" +) + +func wantPanic(t *testing.T, fn func()) { + t.Helper() + defer func() { + recover() + }() + fn() + t.Fatal("failed to panic") +} + +func TestAssertLocked(t *testing.T) { + m := new(sync.Mutex) + wantPanic(t, func() { AssertLocked(m) }) + m.Lock() + AssertLocked(m) + m.Unlock() + wantPanic(t, func() { AssertLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertLocked(m) +} + +func TestAssertWLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertWLocked(m) }) + m.Lock() + AssertWLocked(m) + m.Unlock() + wantPanic(t, func() { AssertWLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertWLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertWLocked(m) +} + +func TestAssertRLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertRLocked(m) }) + + m.Lock() + AssertRLocked(m) + m.Unlock() + + m.RLock() + AssertRLocked(m) + m.RUnlock() + + wantPanic(t, func() { AssertRLocked(m) }) + + // Test correct handling of mutex with waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + m.RUnlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with write waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with other rlocks. + // This is a bit racy, but losing the race hurts nothing, + // and winning the race means correct test coverage. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + time.Sleep(10 * time.Millisecond) + m.RUnlock() + }() + time.Sleep(5 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() +} diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go index 906de3ade2d5c..12edf5bfce475 100644 --- a/syncs/shardedmap.go +++ b/syncs/shardedmap.go @@ -1,138 +1,138 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" - - "golang.org/x/sys/cpu" -) - -// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined -// K-sharding function. -// -// The zero value is not safe for use; use NewShardedMap. -type ShardedMap[K comparable, V any] struct { - shardFunc func(K) int - shards []mapShard[K, V] -} - -type mapShard[K comparable, V any] struct { - mu sync.Mutex - m map[K]V - _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes -} - -// NewShardedMap returns a new ShardedMap with the given number of shards and -// sharding function. -// -// The shard func must return a integer in the range [0, shards) purely -// deterministically based on the provided K. -func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { - m := &ShardedMap[K, V]{ - shardFunc: shard, - shards: make([]mapShard[K, V], shards), - } - for i := range m.shards { - m.shards[i].m = make(map[K]V) - } - return m -} - -func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { - return &m.shards[m.shardFunc(key)] -} - -// GetOk returns m[key] and whether it was present. -func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - value, ok = shard.m[key] - return -} - -// Get returns m[key] or the zero value of V if key is not present. -func (m *ShardedMap[K, V]) Get(key K) (value V) { - value, _ = m.GetOk(key) - return -} - -// Mutate atomically mutates m[k] by calling mutator. -// -// The mutator function is called with the old value (or its zero value) and -// whether it existed in the map and it returns the new value and whether it -// should be set in the map (true) or deleted from the map (false). -// -// It returns the change in size of the map as a result of the mutation, one of -// -1 (delete), 0 (change), or 1 (addition). -func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - oldV, oldOK := shard.m[key] - newV, newOK := mutator(oldV, oldOK) - if newOK { - shard.m[key] = newV - if oldOK { - return 0 - } - return 1 - } - delete(shard.m, key) - if oldOK { - return -1 - } - return 0 -} - -// Set sets m[key] = value. -// -// present in m). -func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - shard.m[key] = value - return len(shard.m) > s0 -} - -// Delete removes key from m. -// -// It reports whether the map size shrunk (that is, whether key was present in -// the map). -func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - delete(shard.m, key) - return len(shard.m) < s0 -} - -// Contains reports whether m contains key. -func (m *ShardedMap[K, V]) Contains(key K) bool { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - _, ok := shard.m[key] - return ok -} - -// Len returns the number of elements in m. -// -// It does so by locking shards one at a time, so it's not particularly cheap, -// nor does it give a consistent snapshot of the map. It's mostly intended for -// metrics or testing. -func (m *ShardedMap[K, V]) Len() int { - n := 0 - for i := range m.shards { - shard := &m.shards[i] - shard.mu.Lock() - n += len(shard.m) - shard.mu.Unlock() - } - return n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" + + "golang.org/x/sys/cpu" +) + +// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined +// K-sharding function. +// +// The zero value is not safe for use; use NewShardedMap. +type ShardedMap[K comparable, V any] struct { + shardFunc func(K) int + shards []mapShard[K, V] +} + +type mapShard[K comparable, V any] struct { + mu sync.Mutex + m map[K]V + _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes +} + +// NewShardedMap returns a new ShardedMap with the given number of shards and +// sharding function. +// +// The shard func must return a integer in the range [0, shards) purely +// deterministically based on the provided K. +func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { + m := &ShardedMap[K, V]{ + shardFunc: shard, + shards: make([]mapShard[K, V], shards), + } + for i := range m.shards { + m.shards[i].m = make(map[K]V) + } + return m +} + +func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { + return &m.shards[m.shardFunc(key)] +} + +// GetOk returns m[key] and whether it was present. +func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + value, ok = shard.m[key] + return +} + +// Get returns m[key] or the zero value of V if key is not present. +func (m *ShardedMap[K, V]) Get(key K) (value V) { + value, _ = m.GetOk(key) + return +} + +// Mutate atomically mutates m[k] by calling mutator. +// +// The mutator function is called with the old value (or its zero value) and +// whether it existed in the map and it returns the new value and whether it +// should be set in the map (true) or deleted from the map (false). +// +// It returns the change in size of the map as a result of the mutation, one of +// -1 (delete), 0 (change), or 1 (addition). +func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + oldV, oldOK := shard.m[key] + newV, newOK := mutator(oldV, oldOK) + if newOK { + shard.m[key] = newV + if oldOK { + return 0 + } + return 1 + } + delete(shard.m, key) + if oldOK { + return -1 + } + return 0 +} + +// Set sets m[key] = value. +// +// present in m). +func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + shard.m[key] = value + return len(shard.m) > s0 +} + +// Delete removes key from m. +// +// It reports whether the map size shrunk (that is, whether key was present in +// the map). +func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + delete(shard.m, key) + return len(shard.m) < s0 +} + +// Contains reports whether m contains key. +func (m *ShardedMap[K, V]) Contains(key K) bool { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + _, ok := shard.m[key] + return ok +} + +// Len returns the number of elements in m. +// +// It does so by locking shards one at a time, so it's not particularly cheap, +// nor does it give a consistent snapshot of the map. It's mostly intended for +// metrics or testing. +func (m *ShardedMap[K, V]) Len() int { + n := 0 + for i := range m.shards { + shard := &m.shards[i] + shard.mu.Lock() + n += len(shard.m) + shard.mu.Unlock() + } + return n +} diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go index 170201c0a2b13..993ffdff875c2 100644 --- a/syncs/shardedmap_test.go +++ b/syncs/shardedmap_test.go @@ -1,81 +1,81 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import "testing" - -func TestShardedMap(t *testing.T) { - m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) - - if m.Contains(1) { - t.Errorf("got contains; want !contains") - } - if !m.Set(1, "one") { - t.Errorf("got !set; want set") - } - if m.Set(1, "one") { - t.Errorf("got set; want !set") - } - if !m.Contains(1) { - t.Errorf("got !contains; want contains") - } - if g, w := m.Get(1), "one"; g != w { - t.Errorf("got %q; want %q", g, w) - } - if _, ok := m.GetOk(1); !ok { - t.Errorf("got ok; want !ok") - } - if _, ok := m.GetOk(2); ok { - t.Errorf("got ok; want !ok") - } - if g, w := m.Len(), 1; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - if m.Delete(2) { - t.Errorf("got deleted; want !deleted") - } - if !m.Delete(1) { - t.Errorf("got !deleted; want deleted") - } - if g, w := m.Len(), 0; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - - // Mutation adding an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if ok { - t.Fatal("was okay") - } - return "ONE", true - }); v != 1 { - t.Errorf("Mutate = %v; want 1", v) - } - if g, w := m.Get(1), "ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation changing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return was + "-" + was, true - }); v != 0 { - t.Errorf("Mutate = %v; want 0", v) - } - if g, w := m.Get(1), "ONE-ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation removing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return "", false - }); v != -1 { - t.Errorf("Mutate = %v; want -1", v) - } - if g, w := m.Get(1), ""; g != w { - t.Errorf("got %q; want %q", g, w) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import "testing" + +func TestShardedMap(t *testing.T) { + m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) + + if m.Contains(1) { + t.Errorf("got contains; want !contains") + } + if !m.Set(1, "one") { + t.Errorf("got !set; want set") + } + if m.Set(1, "one") { + t.Errorf("got set; want !set") + } + if !m.Contains(1) { + t.Errorf("got !contains; want contains") + } + if g, w := m.Get(1), "one"; g != w { + t.Errorf("got %q; want %q", g, w) + } + if _, ok := m.GetOk(1); !ok { + t.Errorf("got ok; want !ok") + } + if _, ok := m.GetOk(2); ok { + t.Errorf("got ok; want !ok") + } + if g, w := m.Len(), 1; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + if m.Delete(2) { + t.Errorf("got deleted; want !deleted") + } + if !m.Delete(1) { + t.Errorf("got !deleted; want deleted") + } + if g, w := m.Len(), 0; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + + // Mutation adding an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if ok { + t.Fatal("was okay") + } + return "ONE", true + }); v != 1 { + t.Errorf("Mutate = %v; want 1", v) + } + if g, w := m.Get(1), "ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation changing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return was + "-" + was, true + }); v != 0 { + t.Errorf("Mutate = %v; want 0", v) + } + if g, w := m.Get(1), "ONE-ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation removing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return "", false + }); v != -1 { + t.Errorf("Mutate = %v; want -1", v) + } + if g, w := m.Get(1), ""; g != w { + t.Errorf("got %q; want %q", g, w) + } +} diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go index 0bb7e388eaaa8..f65c58804d44d 100644 --- a/tailcfg/proto_port_range.go +++ b/tailcfg/proto_port_range.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "errors" - "fmt" - "strconv" - "strings" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var ( - errEmptyProtocol = errors.New("empty protocol") - errEmptyString = errors.New("empty string") -) - -// ProtoPortRange is used to encode "proto:port" format. -// The following formats are supported: -// -// "*" allows all TCP, UDP and ICMP traffic on all ports. -// "" allows all TCP, UDP and ICMP traffic on the specified ports. -// "proto:*" allows traffic of the specified proto on all ports. -// "proto:" allows traffic of the specified proto on the specified port. -// -// Ports are either a single port number or a range of ports (e.g. "80-90"). -// String named protocols support names that ipproto.Proto accepts. -type ProtoPortRange struct { - // Proto is the IP protocol number. - // If Proto is 0, it means TCP+UDP+ICMP(4+6). - Proto int - Ports PortRange -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { - ppr2, err := parseProtoPortRange(string(text)) - if err != nil { - return err - } - *ppr = *ppr2 - return nil -} - -// MarshalText implements the encoding.TextMarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { - if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { - return []byte{}, nil - } - return []byte(ppr.String()), nil -} - -// String implements the stringer interface. See ProtoPortRange for the -// format. -func (ppr ProtoPortRange) String() string { - if ppr.Proto == 0 { - if ppr.Ports == PortRangeAny { - return "*" - } - } - var buf strings.Builder - if ppr.Proto != 0 { - // Proto.MarshalText is infallible. - text, _ := ipproto.Proto(ppr.Proto).MarshalText() - buf.Write(text) - buf.Write([]byte(":")) - } - pr := ppr.Ports - if pr.First == pr.Last { - fmt.Fprintf(&buf, "%d", pr.First) - } else if pr == PortRangeAny { - buf.WriteByte('*') - } else { - fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) - } - return buf.String() -} - -// ParseProtoPortRanges parses a slice of IP port range fields. -func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { - var out []ProtoPortRange - for _, p := range ips { - ppr, err := parseProtoPortRange(p) - if err != nil { - return nil, err - } - out = append(out, *ppr) - } - return out, nil -} - -func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { - if ipProtoPort == "" { - return nil, errEmptyString - } - if ipProtoPort == "*" { - return &ProtoPortRange{Ports: PortRangeAny}, nil - } - if !strings.Contains(ipProtoPort, ":") { - ipProtoPort = "*:" + ipProtoPort - } - protoStr, portRange, err := parseHostPortRange(ipProtoPort) - if err != nil { - return nil, err - } - if protoStr == "" { - return nil, errEmptyProtocol - } - - ppr := &ProtoPortRange{ - Ports: portRange, - } - if protoStr == "*" { - return ppr, nil - } - var ipProto ipproto.Proto - if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { - return nil, err - } - ppr.Proto = int(ipProto) - return ppr, nil -} - -// parseHostPortRange parses hostport as HOST:PORTS where HOST is -// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. -func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { - hostport = strings.ToLower(hostport) - colon := strings.LastIndexByte(hostport, ':') - if colon < 0 { - return "", ports, vizerror.New("hostport must contain a colon (\":\")") - } - host = hostport[:colon] - portlist := hostport[colon+1:] - - if strings.Contains(host, ",") { - return "", ports, vizerror.New("host cannot contain a comma (\",\")") - } - - if portlist == "*" { - // Special case: permit hostname:* as a port wildcard. - return host, PortRangeAny, nil - } - - if len(portlist) == 0 { - return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) - } - - if strings.Count(portlist, "-") > 1 { - return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) - } - - firstStr, lastStr, isRange := strings.Cut(portlist, "-") - - var first, last uint64 - first, err = strconv.ParseUint(firstStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) - } - - if isRange { - last, err = strconv.ParseUint(lastStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) - } - } else { - last = first - } - - if first == 0 { - return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) - } - - if first > last { - return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) - } - - return host, newPortRange(uint16(first), uint16(last)), nil -} - -func newPortRange(first, last uint16) PortRange { - return PortRange{First: first, Last: last} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var ( + errEmptyProtocol = errors.New("empty protocol") + errEmptyString = errors.New("empty string") +) + +// ProtoPortRange is used to encode "proto:port" format. +// The following formats are supported: +// +// "*" allows all TCP, UDP and ICMP traffic on all ports. +// "" allows all TCP, UDP and ICMP traffic on the specified ports. +// "proto:*" allows traffic of the specified proto on all ports. +// "proto:" allows traffic of the specified proto on the specified port. +// +// Ports are either a single port number or a range of ports (e.g. "80-90"). +// String named protocols support names that ipproto.Proto accepts. +type ProtoPortRange struct { + // Proto is the IP protocol number. + // If Proto is 0, it means TCP+UDP+ICMP(4+6). + Proto int + Ports PortRange +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { + ppr2, err := parseProtoPortRange(string(text)) + if err != nil { + return err + } + *ppr = *ppr2 + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { + if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { + return []byte{}, nil + } + return []byte(ppr.String()), nil +} + +// String implements the stringer interface. See ProtoPortRange for the +// format. +func (ppr ProtoPortRange) String() string { + if ppr.Proto == 0 { + if ppr.Ports == PortRangeAny { + return "*" + } + } + var buf strings.Builder + if ppr.Proto != 0 { + // Proto.MarshalText is infallible. + text, _ := ipproto.Proto(ppr.Proto).MarshalText() + buf.Write(text) + buf.Write([]byte(":")) + } + pr := ppr.Ports + if pr.First == pr.Last { + fmt.Fprintf(&buf, "%d", pr.First) + } else if pr == PortRangeAny { + buf.WriteByte('*') + } else { + fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) + } + return buf.String() +} + +// ParseProtoPortRanges parses a slice of IP port range fields. +func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { + var out []ProtoPortRange + for _, p := range ips { + ppr, err := parseProtoPortRange(p) + if err != nil { + return nil, err + } + out = append(out, *ppr) + } + return out, nil +} + +func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { + if ipProtoPort == "" { + return nil, errEmptyString + } + if ipProtoPort == "*" { + return &ProtoPortRange{Ports: PortRangeAny}, nil + } + if !strings.Contains(ipProtoPort, ":") { + ipProtoPort = "*:" + ipProtoPort + } + protoStr, portRange, err := parseHostPortRange(ipProtoPort) + if err != nil { + return nil, err + } + if protoStr == "" { + return nil, errEmptyProtocol + } + + ppr := &ProtoPortRange{ + Ports: portRange, + } + if protoStr == "*" { + return ppr, nil + } + var ipProto ipproto.Proto + if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { + return nil, err + } + ppr.Proto = int(ipProto) + return ppr, nil +} + +// parseHostPortRange parses hostport as HOST:PORTS where HOST is +// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. +func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { + hostport = strings.ToLower(hostport) + colon := strings.LastIndexByte(hostport, ':') + if colon < 0 { + return "", ports, vizerror.New("hostport must contain a colon (\":\")") + } + host = hostport[:colon] + portlist := hostport[colon+1:] + + if strings.Contains(host, ",") { + return "", ports, vizerror.New("host cannot contain a comma (\",\")") + } + + if portlist == "*" { + // Special case: permit hostname:* as a port wildcard. + return host, PortRangeAny, nil + } + + if len(portlist) == 0 { + return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) + } + + if strings.Count(portlist, "-") > 1 { + return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) + } + + firstStr, lastStr, isRange := strings.Cut(portlist, "-") + + var first, last uint64 + first, err = strconv.ParseUint(firstStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) + } + + if isRange { + last, err = strconv.ParseUint(lastStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) + } + } else { + last = first + } + + if first == 0 { + return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) + } + + if first > last { + return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) + } + + return host, newPortRange(uint16(first), uint16(last)), nil +} + +func newPortRange(first, last uint16) PortRange { + return PortRange{First: first, Last: last} +} diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go index 31b282641e975..59ccc9be4a1a8 100644 --- a/tailcfg/proto_port_range_test.go +++ b/tailcfg/proto_port_range_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "encoding" - "testing" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) - -func TestProtoPortRangeParsing(t *testing.T) { - pr := func(s, e uint16) PortRange { - return PortRange{First: s, Last: e} - } - tests := []struct { - in string - out ProtoPortRange - err error - }{ - {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, - {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, - {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, - { - in: "tcp:", - err: vizerror.Errorf("invalid port list: %#v", ""), - }, - { - in: ":80", - err: errEmptyProtocol, - }, - { - in: "", - err: errEmptyString, - }, - } - - for _, tc := range tests { - t.Run(tc.in, func(t *testing.T) { - var ppr ProtoPortRange - err := ppr.UnmarshalText([]byte(tc.in)) - if tc.err != err { - if err == nil || tc.err.Error() != err.Error() { - t.Fatalf("want err=%v, got %v", tc.err, err) - } - } - if ppr != tc.out { - t.Fatalf("got %v; want %v", ppr, tc.out) - } - }) - } -} - -func TestProtoPortRangeString(t *testing.T) { - tests := []struct { - input ProtoPortRange - want string - }{ - {ProtoPortRange{}, "0"}, - - // Zero protocol. - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - - // Non-zero unnamed protocol. - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - - // Non-zero named protocol. - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - for _, tc := range tests { - if got := tc.input.String(); got != tc.want { - t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestProtoPortRangeRoundTrip(t *testing.T) { - tests := []struct { - input ProtoPortRange - text string - }{ - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - - for _, tc := range tests { - out, err := tc.input.MarshalText() - if err != nil { - t.Errorf("MarshalText for %v: %v", tc.input, err) - continue - } - if got := string(out); got != tc.text { - t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) - } - var ppr ProtoPortRange - if err := ppr.UnmarshalText(out); err != nil { - t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) - continue - } - if ppr != tc.input { - t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "encoding" + "testing" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) + +func TestProtoPortRangeParsing(t *testing.T) { + pr := func(s, e uint16) PortRange { + return PortRange{First: s, Last: e} + } + tests := []struct { + in string + out ProtoPortRange + err error + }{ + {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, + {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, + {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, + { + in: "tcp:", + err: vizerror.Errorf("invalid port list: %#v", ""), + }, + { + in: ":80", + err: errEmptyProtocol, + }, + { + in: "", + err: errEmptyString, + }, + } + + for _, tc := range tests { + t.Run(tc.in, func(t *testing.T) { + var ppr ProtoPortRange + err := ppr.UnmarshalText([]byte(tc.in)) + if tc.err != err { + if err == nil || tc.err.Error() != err.Error() { + t.Fatalf("want err=%v, got %v", tc.err, err) + } + } + if ppr != tc.out { + t.Fatalf("got %v; want %v", ppr, tc.out) + } + }) + } +} + +func TestProtoPortRangeString(t *testing.T) { + tests := []struct { + input ProtoPortRange + want string + }{ + {ProtoPortRange{}, "0"}, + + // Zero protocol. + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + + // Non-zero unnamed protocol. + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + + // Non-zero named protocol. + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + for _, tc := range tests { + if got := tc.input.String(); got != tc.want { + t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestProtoPortRangeRoundTrip(t *testing.T) { + tests := []struct { + input ProtoPortRange + text string + }{ + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + + for _, tc := range tests { + out, err := tc.input.MarshalText() + if err != nil { + t.Errorf("MarshalText for %v: %v", tc.input, err) + continue + } + if got := string(out); got != tc.text { + t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) + } + var ppr ProtoPortRange + if err := ppr.UnmarshalText(out); err != nil { + t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) + continue + } + if ppr != tc.input { + t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) + } + } +} diff --git a/tailcfg/tka.go b/tailcfg/tka.go index ca7e6be76ba1e..97fdcc0db687a 100644 --- a/tailcfg/tka.go +++ b/tailcfg/tka.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// TKAInitBeginRequest submits a genesis AUM to seed the creation of the -// tailnet's key authority. -type TKAInitBeginRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // GenesisAUM is the initial (genesis) AUM that the node generated - // to bootstrap tailnet key authority state. - GenesisAUM tkatype.MarshaledAUM -} - -// TKASignInfo describes information about an existing node that needs -// to be signed into a node-key signature. -type TKASignInfo struct { - // NodeID is the ID of the node which needs a signature. It must - // correspond to NodePublic. - NodeID NodeID - // NodePublic is the node (Wireguard) public key which is being - // signed. - NodePublic key.NodePublic - - // RotationPubkey specifies the public key which may sign - // a NodeKeySignature (NKS), which rotates the node key. - // - // This is necessary so the node can rotate its node-key without - // talking to a node which holds a trusted network-lock key. - // It does this by nesting the original NKS in a 'rotation' NKS, - // which it then signs with the key corresponding to RotationPubkey. - // - // This field expects a raw ed25519 public key. - RotationPubkey []byte -} - -// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. -// This structure describes node information which must be signed to -// complete initialization of the tailnets' key authority. -type TKAInitBeginResponse struct { - // NeedSignatures specify information about the nodes in your tailnet - // which need initial signatures to function once the tailnet key - // authority is in use. The generated signatures should then be - // submitted in a /tka/init/finish RPC. - NeedSignatures []TKASignInfo -} - -// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. -// This RPC finalizes initialization of the tailnet key authority -// by submitting node-key signatures for all existing nodes. -type TKAInitFinishRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Signatures are serialized tka.NodeKeySignatures for all nodes - // in the tailnet. - Signatures map[NodeID]tkatype.MarshaledSignature - - // SupportDisablement is a disablement secret for Tailscale support. - // This is only generated if --gen-disablement-for-support is specified - // in an invocation to 'tailscale lock init'. - SupportDisablement []byte `json:",omitempty"` -} - -// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. -// This schema describes the successful enablement of the tailnet's -// key authority. -type TKAInitFinishResponse struct { - // Nothing. (yet?) -} - -// TKAInfo encodes the control plane's view of tailnet key authority (TKA) -// state. This information is transmitted as part of the MapResponse. -type TKAInfo struct { - // Head describes the hash of the latest AUM applied to the authority. - // Head is encoded as tka.AUMHash.MarshalText. - // - // If the Head state differs to that known locally, the node should perform - // synchronization via a separate RPC. - Head string `json:",omitempty"` - - // Disabled indicates the control plane believes TKA should be disabled, - // and the node should reach out to fetch a disablement - // secret. If the disablement secret verifies, then the node should then - // disable TKA locally. - // This field exists to disambiguate a nil TKAInfo in a delta mapresponse - // from a nil TKAInfo indicating TKA should be disabled. - Disabled bool `json:",omitempty"` -} - -// TKABootstrapRequest is sent by a node to get information necessary for -// enabling or disabling the tailnet key authority. -type TKABootstrapRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head), if - // network lock is enabled. - Head string -} - -// TKABootstrapResponse encodes values necessary to enable or disable -// the tailnet key authority (TKA). -type TKABootstrapResponse struct { - // GenesisAUM returns the initial AUM necessary to initialize TKA. - GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` - - // DisablementSecret encodes a secret necessary to disable TKA. - DisablementSecret []byte `json:",omitempty"` -} - -// TKASyncOfferRequest encodes a request to synchronize tailnet key authority -// state (TKA). Values of type tka.AUMHash are encoded as strings in their -// MarshalText form. -type TKASyncOfferRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). This - // corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the current head. This corresponds to tka.SyncOffer.Ancestors. - Ancestors []string -} - -// TKASyncOfferResponse encodes a response in synchronizing a node's -// tailnet key authority state. Values of type tka.AUMHash are encoded as -// strings in their MarshalText form. -type TKASyncOfferResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head). - // This corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the control plane's head. This corresponds to - // tka.SyncOffer.Ancestors. - Ancestors []string - // MissingAUMs encodes AUMs that the control plane believes the node - // is missing. - MissingAUMs []tkatype.MarshaledAUM -} - -// TKASyncSendRequest encodes AUMs that a node believes the control plane -// is missing, and notifies control of its local TKA state (specifically -// the head hash). -type TKASyncSendRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head) after - // applying any AUMs from the sync-offer response. - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // MissingAUMs encodes AUMs that the node believes the control plane - // is missing. - MissingAUMs []tkatype.MarshaledAUM - - // Interactive is true if additional error checking should be performed as - // the request is on behalf of an interactive operation (e.g., an - // administrator publishing new changes) as opposed to an automatic - // synchronization that may be reporting lost data. - Interactive bool -} - -// TKASyncSendResponse encodes the control plane's response to a node -// submitting AUMs during AUM synchronization. -type TKASyncSendResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head), - // after applying the missing AUMs. - Head string -} - -// TKADisableRequest disables network-lock across the tailnet using the -// provided disablement secret. -// -// This is the request schema for a /tka/disable noise RPC. -type TKADisableRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // DisablementSecret encodes the secret necessary to disable TKA. - DisablementSecret []byte -} - -// TKADisableResponse is the JSON response from a /tka/disable RPC. -// This schema describes the successful disablement of the tailnet's -// key authority. -type TKADisableResponse struct { - // Nothing. (yet?) -} - -// TKASubmitSignatureRequest transmits a node-key signature to the control plane. -// -// This is the request schema for a /tka/sign noise RPC. -type TKASubmitSignatureRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. The node-key which - // is being signed is embedded in Signature. - NodeKey key.NodePublic - - // Signature encodes the node-key signature being submitted. - Signature tkatype.MarshaledSignature -} - -// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. -type TKASubmitSignatureResponse struct { - // Nothing. (yet?) -} - -// TKASignaturesUsingKeyRequest asks the control plane for -// all signatures which are signed by the provided keyID. -// -// This is the request schema for a /tka/affected-sigs RPC. -type TKASignaturesUsingKeyRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // KeyID is the key we are querying using. - KeyID tkatype.KeyID -} - -// TKASignaturesUsingKeyResponse is the JSON response to -// a /tka/affected-sigs RPC. -// -// It enumerates all signatures which are signed by the -// queried keyID. -type TKASignaturesUsingKeyResponse struct { - Signatures []tkatype.MarshaledSignature -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// TKAInitBeginRequest submits a genesis AUM to seed the creation of the +// tailnet's key authority. +type TKAInitBeginRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // GenesisAUM is the initial (genesis) AUM that the node generated + // to bootstrap tailnet key authority state. + GenesisAUM tkatype.MarshaledAUM +} + +// TKASignInfo describes information about an existing node that needs +// to be signed into a node-key signature. +type TKASignInfo struct { + // NodeID is the ID of the node which needs a signature. It must + // correspond to NodePublic. + NodeID NodeID + // NodePublic is the node (Wireguard) public key which is being + // signed. + NodePublic key.NodePublic + + // RotationPubkey specifies the public key which may sign + // a NodeKeySignature (NKS), which rotates the node key. + // + // This is necessary so the node can rotate its node-key without + // talking to a node which holds a trusted network-lock key. + // It does this by nesting the original NKS in a 'rotation' NKS, + // which it then signs with the key corresponding to RotationPubkey. + // + // This field expects a raw ed25519 public key. + RotationPubkey []byte +} + +// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. +// This structure describes node information which must be signed to +// complete initialization of the tailnets' key authority. +type TKAInitBeginResponse struct { + // NeedSignatures specify information about the nodes in your tailnet + // which need initial signatures to function once the tailnet key + // authority is in use. The generated signatures should then be + // submitted in a /tka/init/finish RPC. + NeedSignatures []TKASignInfo +} + +// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. +// This RPC finalizes initialization of the tailnet key authority +// by submitting node-key signatures for all existing nodes. +type TKAInitFinishRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Signatures are serialized tka.NodeKeySignatures for all nodes + // in the tailnet. + Signatures map[NodeID]tkatype.MarshaledSignature + + // SupportDisablement is a disablement secret for Tailscale support. + // This is only generated if --gen-disablement-for-support is specified + // in an invocation to 'tailscale lock init'. + SupportDisablement []byte `json:",omitempty"` +} + +// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. +// This schema describes the successful enablement of the tailnet's +// key authority. +type TKAInitFinishResponse struct { + // Nothing. (yet?) +} + +// TKAInfo encodes the control plane's view of tailnet key authority (TKA) +// state. This information is transmitted as part of the MapResponse. +type TKAInfo struct { + // Head describes the hash of the latest AUM applied to the authority. + // Head is encoded as tka.AUMHash.MarshalText. + // + // If the Head state differs to that known locally, the node should perform + // synchronization via a separate RPC. + Head string `json:",omitempty"` + + // Disabled indicates the control plane believes TKA should be disabled, + // and the node should reach out to fetch a disablement + // secret. If the disablement secret verifies, then the node should then + // disable TKA locally. + // This field exists to disambiguate a nil TKAInfo in a delta mapresponse + // from a nil TKAInfo indicating TKA should be disabled. + Disabled bool `json:",omitempty"` +} + +// TKABootstrapRequest is sent by a node to get information necessary for +// enabling or disabling the tailnet key authority. +type TKABootstrapRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head), if + // network lock is enabled. + Head string +} + +// TKABootstrapResponse encodes values necessary to enable or disable +// the tailnet key authority (TKA). +type TKABootstrapResponse struct { + // GenesisAUM returns the initial AUM necessary to initialize TKA. + GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` + + // DisablementSecret encodes a secret necessary to disable TKA. + DisablementSecret []byte `json:",omitempty"` +} + +// TKASyncOfferRequest encodes a request to synchronize tailnet key authority +// state (TKA). Values of type tka.AUMHash are encoded as strings in their +// MarshalText form. +type TKASyncOfferRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). This + // corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the current head. This corresponds to tka.SyncOffer.Ancestors. + Ancestors []string +} + +// TKASyncOfferResponse encodes a response in synchronizing a node's +// tailnet key authority state. Values of type tka.AUMHash are encoded as +// strings in their MarshalText form. +type TKASyncOfferResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head). + // This corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the control plane's head. This corresponds to + // tka.SyncOffer.Ancestors. + Ancestors []string + // MissingAUMs encodes AUMs that the control plane believes the node + // is missing. + MissingAUMs []tkatype.MarshaledAUM +} + +// TKASyncSendRequest encodes AUMs that a node believes the control plane +// is missing, and notifies control of its local TKA state (specifically +// the head hash). +type TKASyncSendRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head) after + // applying any AUMs from the sync-offer response. + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // MissingAUMs encodes AUMs that the node believes the control plane + // is missing. + MissingAUMs []tkatype.MarshaledAUM + + // Interactive is true if additional error checking should be performed as + // the request is on behalf of an interactive operation (e.g., an + // administrator publishing new changes) as opposed to an automatic + // synchronization that may be reporting lost data. + Interactive bool +} + +// TKASyncSendResponse encodes the control plane's response to a node +// submitting AUMs during AUM synchronization. +type TKASyncSendResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head), + // after applying the missing AUMs. + Head string +} + +// TKADisableRequest disables network-lock across the tailnet using the +// provided disablement secret. +// +// This is the request schema for a /tka/disable noise RPC. +type TKADisableRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // DisablementSecret encodes the secret necessary to disable TKA. + DisablementSecret []byte +} + +// TKADisableResponse is the JSON response from a /tka/disable RPC. +// This schema describes the successful disablement of the tailnet's +// key authority. +type TKADisableResponse struct { + // Nothing. (yet?) +} + +// TKASubmitSignatureRequest transmits a node-key signature to the control plane. +// +// This is the request schema for a /tka/sign noise RPC. +type TKASubmitSignatureRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. The node-key which + // is being signed is embedded in Signature. + NodeKey key.NodePublic + + // Signature encodes the node-key signature being submitted. + Signature tkatype.MarshaledSignature +} + +// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. +type TKASubmitSignatureResponse struct { + // Nothing. (yet?) +} + +// TKASignaturesUsingKeyRequest asks the control plane for +// all signatures which are signed by the provided keyID. +// +// This is the request schema for a /tka/affected-sigs RPC. +type TKASignaturesUsingKeyRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // KeyID is the key we are querying using. + KeyID tkatype.KeyID +} + +// TKASignaturesUsingKeyResponse is the JSON response to +// a /tka/affected-sigs RPC. +// +// It enumerates all signatures which are signed by the +// queried keyID. +type TKASignaturesUsingKeyResponse struct { + Signatures []tkatype.MarshaledSignature +} diff --git a/taildrop/delete.go b/taildrop/delete.go index 7279a7687b2ec..aaef34df1a7e4 100644 --- a/taildrop/delete.go +++ b/taildrop/delete.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "container/list" - "context" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "tailscale.com/ipn" - "tailscale.com/syncs" - "tailscale.com/tstime" - "tailscale.com/types/logger" -) - -// deleteDelay is the amount of time to wait before we delete a file. -// A shorter value ensures timely deletion of deleted and partial files, while -// a longer value provides more opportunity for partial files to be resumed. -const deleteDelay = time.Hour - -// fileDeleter manages asynchronous deletion of files after deleteDelay. -type fileDeleter struct { - logf logger.Logf - clock tstime.DefaultClock - dir string - event func(string) // called for certain events; for testing only - - mu sync.Mutex - queue list.List - byName map[string]*list.Element - - emptySignal chan struct{} // signal that the queue is empty - group syncs.WaitGroup - shutdownCtx context.Context - shutdown context.CancelFunc -} - -// deleteFile is a specific file to delete after deleteDelay. -type deleteFile struct { - name string - inserted time.Time -} - -func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { - d.logf = m.opts.Logf - d.clock = m.opts.Clock - d.dir = m.opts.Dir - d.event = eventHook - - d.byName = make(map[string]*list.Element) - d.emptySignal = make(chan struct{}) - d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) - - // From a cold-start, load the list of partial and deleted files. - // - // Only run this if we have ever received at least one file - // to avoid ever touching the taildrop directory on systems (e.g., MacOS) - // that pop up a security dialog window upon first access. - if m.opts.State == nil { - return - } - if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { - return - } - d.group.Go(func() { - d.event("start full-scan") - defer d.event("end full-scan") - rangeDir(d.dir, func(de fs.DirEntry) bool { - switch { - case d.shutdownCtx.Err() != nil: - return false // terminate early - case !de.Type().IsRegular(): - return true - case strings.HasSuffix(de.Name(), partialSuffix): - // Only enqueue the file for deletion if there is no active put. - nameID := strings.TrimSuffix(de.Name(), partialSuffix) - if i := strings.LastIndexByte(nameID, '.'); i > 0 { - key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} - m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { - if !loaded { - d.Insert(de.Name()) - } - }) - } else { - d.Insert(de.Name()) - } - case strings.HasSuffix(de.Name(), deletedSuffix): - // Best-effort immediate deletion of deleted files. - name := strings.TrimSuffix(de.Name(), deletedSuffix) - if os.Remove(filepath.Join(d.dir, name)) == nil { - if os.Remove(filepath.Join(d.dir, de.Name())) == nil { - break - } - } - // Otherwise, enqueue the file for later deletion. - d.Insert(de.Name()) - } - return true - }) - }) -} - -// Insert enqueues baseName for eventual deletion. -func (d *fileDeleter) Insert(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if d.shutdownCtx.Err() != nil { - return - } - if _, ok := d.byName[baseName]; ok { - return // already queued for deletion - } - d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) - if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { - d.group.Go(func() { d.waitAndDelete(deleteDelay) }) - } -} - -// waitAndDelete is an asynchronous deletion goroutine. -// At most one waitAndDelete routine is ever running at a time. -// It is not started unless there is at least one file in the queue. -func (d *fileDeleter) waitAndDelete(wait time.Duration) { - tc, ch := d.clock.NewTimer(wait) - defer tc.Stop() // cleanup the timer resource if we stop early - d.event("start waitAndDelete") - defer d.event("end waitAndDelete") - select { - case <-d.shutdownCtx.Done(): - case <-d.emptySignal: - case now := <-ch: - d.mu.Lock() - defer d.mu.Unlock() - - // Iterate over all files to delete, and delete anything old enough. - var next *list.Element - var failed []*list.Element - for elem := d.queue.Front(); elem != nil; elem = next { - next = elem.Next() - file := elem.Value.(*deleteFile) - if now.Sub(file.inserted) < deleteDelay { - break // everything after this is recently inserted - } - - // Delete the expired file. - if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { - if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - } - if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - d.queue.Remove(elem) - delete(d.byName, file.name) - d.event("deleted " + file.name) - } - for _, elem := range failed { - elem.Value.(*deleteFile).inserted = now // retry after deleteDelay - d.queue.MoveToBack(elem) - } - - // If there are still some files to delete, retry again later. - if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { - file := d.queue.Front().Value.(*deleteFile) - retryAfter := deleteDelay - now.Sub(file.inserted) - d.group.Go(func() { d.waitAndDelete(retryAfter) }) - } - } -} - -// Remove dequeues baseName from eventual deletion. -func (d *fileDeleter) Remove(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if elem := d.byName[baseName]; elem != nil { - d.queue.Remove(elem) - delete(d.byName, baseName) - // Signal to terminate any waitAndDelete goroutines. - if d.queue.Len() == 0 { - select { - case <-d.shutdownCtx.Done(): - case d.emptySignal <- struct{}{}: - } - } - } -} - -// Shutdown shuts down the deleter. -// It blocks until all goroutines are stopped. -func (d *fileDeleter) Shutdown() { - d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown - d.shutdown() - d.mu.Unlock() - d.group.Wait() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "container/list" + "context" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/syncs" + "tailscale.com/tstime" + "tailscale.com/types/logger" +) + +// deleteDelay is the amount of time to wait before we delete a file. +// A shorter value ensures timely deletion of deleted and partial files, while +// a longer value provides more opportunity for partial files to be resumed. +const deleteDelay = time.Hour + +// fileDeleter manages asynchronous deletion of files after deleteDelay. +type fileDeleter struct { + logf logger.Logf + clock tstime.DefaultClock + dir string + event func(string) // called for certain events; for testing only + + mu sync.Mutex + queue list.List + byName map[string]*list.Element + + emptySignal chan struct{} // signal that the queue is empty + group syncs.WaitGroup + shutdownCtx context.Context + shutdown context.CancelFunc +} + +// deleteFile is a specific file to delete after deleteDelay. +type deleteFile struct { + name string + inserted time.Time +} + +func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { + d.logf = m.opts.Logf + d.clock = m.opts.Clock + d.dir = m.opts.Dir + d.event = eventHook + + d.byName = make(map[string]*list.Element) + d.emptySignal = make(chan struct{}) + d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) + + // From a cold-start, load the list of partial and deleted files. + // + // Only run this if we have ever received at least one file + // to avoid ever touching the taildrop directory on systems (e.g., MacOS) + // that pop up a security dialog window upon first access. + if m.opts.State == nil { + return + } + if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { + return + } + d.group.Go(func() { + d.event("start full-scan") + defer d.event("end full-scan") + rangeDir(d.dir, func(de fs.DirEntry) bool { + switch { + case d.shutdownCtx.Err() != nil: + return false // terminate early + case !de.Type().IsRegular(): + return true + case strings.HasSuffix(de.Name(), partialSuffix): + // Only enqueue the file for deletion if there is no active put. + nameID := strings.TrimSuffix(de.Name(), partialSuffix) + if i := strings.LastIndexByte(nameID, '.'); i > 0 { + key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} + m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { + if !loaded { + d.Insert(de.Name()) + } + }) + } else { + d.Insert(de.Name()) + } + case strings.HasSuffix(de.Name(), deletedSuffix): + // Best-effort immediate deletion of deleted files. + name := strings.TrimSuffix(de.Name(), deletedSuffix) + if os.Remove(filepath.Join(d.dir, name)) == nil { + if os.Remove(filepath.Join(d.dir, de.Name())) == nil { + break + } + } + // Otherwise, enqueue the file for later deletion. + d.Insert(de.Name()) + } + return true + }) + }) +} + +// Insert enqueues baseName for eventual deletion. +func (d *fileDeleter) Insert(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if d.shutdownCtx.Err() != nil { + return + } + if _, ok := d.byName[baseName]; ok { + return // already queued for deletion + } + d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) + if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { + d.group.Go(func() { d.waitAndDelete(deleteDelay) }) + } +} + +// waitAndDelete is an asynchronous deletion goroutine. +// At most one waitAndDelete routine is ever running at a time. +// It is not started unless there is at least one file in the queue. +func (d *fileDeleter) waitAndDelete(wait time.Duration) { + tc, ch := d.clock.NewTimer(wait) + defer tc.Stop() // cleanup the timer resource if we stop early + d.event("start waitAndDelete") + defer d.event("end waitAndDelete") + select { + case <-d.shutdownCtx.Done(): + case <-d.emptySignal: + case now := <-ch: + d.mu.Lock() + defer d.mu.Unlock() + + // Iterate over all files to delete, and delete anything old enough. + var next *list.Element + var failed []*list.Element + for elem := d.queue.Front(); elem != nil; elem = next { + next = elem.Next() + file := elem.Value.(*deleteFile) + if now.Sub(file.inserted) < deleteDelay { + break // everything after this is recently inserted + } + + // Delete the expired file. + if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { + if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + } + if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + d.queue.Remove(elem) + delete(d.byName, file.name) + d.event("deleted " + file.name) + } + for _, elem := range failed { + elem.Value.(*deleteFile).inserted = now // retry after deleteDelay + d.queue.MoveToBack(elem) + } + + // If there are still some files to delete, retry again later. + if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { + file := d.queue.Front().Value.(*deleteFile) + retryAfter := deleteDelay - now.Sub(file.inserted) + d.group.Go(func() { d.waitAndDelete(retryAfter) }) + } + } +} + +// Remove dequeues baseName from eventual deletion. +func (d *fileDeleter) Remove(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if elem := d.byName[baseName]; elem != nil { + d.queue.Remove(elem) + delete(d.byName, baseName) + // Signal to terminate any waitAndDelete goroutines. + if d.queue.Len() == 0 { + select { + case <-d.shutdownCtx.Done(): + case d.emptySignal <- struct{}{}: + } + } + } +} + +// Shutdown shuts down the deleter. +// It blocks until all goroutines are stopped. +func (d *fileDeleter) Shutdown() { + d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown + d.shutdown() + d.mu.Unlock() + d.group.Wait() +} diff --git a/taildrop/delete_test.go b/taildrop/delete_test.go index b40fa35bfb0e3..5fa4b9c374fdf 100644 --- a/taildrop/delete_test.go +++ b/taildrop/delete_test.go @@ -1,152 +1,152 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "os" - "path/filepath" - "slices" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/tstime" - "tailscale.com/util/must" -) - -func TestDeleter(t *testing.T) { - dir := t.TempDir() - must.Do(touchFile(filepath.Join(dir, "foo.partial"))) - must.Do(touchFile(filepath.Join(dir, "bar.partial"))) - must.Do(touchFile(filepath.Join(dir, "fizz"))) - must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) - must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file - - checkDirectory := func(want ...string) { - t.Helper() - var got []string - for _, de := range must.Get(os.ReadDir(dir)) { - got = append(got, de.Name()) - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("directory mismatch (-got +want):\n%s", diff) - } - } - - clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) - advance := func(d time.Duration) { - t.Helper() - t.Logf("advance: %v", d) - clock.Advance(d) - } - - eventsChan := make(chan string, 1000) - checkEvents := func(want ...string) { - t.Helper() - tm := time.NewTimer(10 * time.Second) - defer tm.Stop() - var got []string - for range want { - select { - case event := <-eventsChan: - t.Logf("event: %s", event) - got = append(got, event) - case <-tm.C: - t.Fatalf("timed out waiting for event: got %v, want %v", got, want) - } - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("events mismatch (-got +want):\n%s", diff) - } - } - eventHook := func(event string) { eventsChan <- event } - - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Clock = tstime.DefaultClock{Clock: clock} - m.opts.Dir = dir - m.opts.State = must.Get(mem.New(nil, "")) - must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) - fd.Init(&m, eventHook) - defer fd.Shutdown() - insert := func(name string) { - t.Helper() - t.Logf("insert: %v", name) - fd.Insert(name) - } - remove := func(name string) { - t.Helper() - t.Logf("remove: %v", name) - fd.Remove(name) - } - - checkEvents("start full-scan") - checkEvents("end full-scan", "start waitAndDelete") - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - - advance(deleteDelay / 2) - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - advance(deleteDelay / 2) - checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") - checkEvents("end waitAndDelete") - checkDirectory() - - must.Do(touchFile(filepath.Join(dir, "one.partial"))) - insert("one.partial") - checkEvents("start waitAndDelete") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "two.partial"))) - insert("two.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "three.partial"))) - insert("three.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "four.partial"))) - insert("four.partial") - - advance(deleteDelay / 4) - checkEvents("deleted one.partial") - checkDirectory("two.partial", "three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted two.partial") - checkDirectory("three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted three.partial") - checkDirectory("four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted four.partial") - checkDirectory() - checkEvents("end waitAndDelete") - - insert("wuzz.partial") - checkEvents("start waitAndDelete") - remove("wuzz.partial") - checkEvents("end waitAndDelete") -} - -// Test that the asynchronous full scan of the taildrop directory does not occur -// on a cold start if taildrop has never received any files. -func TestDeleterInitWithoutTaildrop(t *testing.T) { - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Dir = t.TempDir() - m.opts.State = must.Get(mem.New(nil, "")) - fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) - fd.Shutdown() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "os" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/util/must" +) + +func TestDeleter(t *testing.T) { + dir := t.TempDir() + must.Do(touchFile(filepath.Join(dir, "foo.partial"))) + must.Do(touchFile(filepath.Join(dir, "bar.partial"))) + must.Do(touchFile(filepath.Join(dir, "fizz"))) + must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) + must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file + + checkDirectory := func(want ...string) { + t.Helper() + var got []string + for _, de := range must.Get(os.ReadDir(dir)) { + got = append(got, de.Name()) + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("directory mismatch (-got +want):\n%s", diff) + } + } + + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) + advance := func(d time.Duration) { + t.Helper() + t.Logf("advance: %v", d) + clock.Advance(d) + } + + eventsChan := make(chan string, 1000) + checkEvents := func(want ...string) { + t.Helper() + tm := time.NewTimer(10 * time.Second) + defer tm.Stop() + var got []string + for range want { + select { + case event := <-eventsChan: + t.Logf("event: %s", event) + got = append(got, event) + case <-tm.C: + t.Fatalf("timed out waiting for event: got %v, want %v", got, want) + } + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("events mismatch (-got +want):\n%s", diff) + } + } + eventHook := func(event string) { eventsChan <- event } + + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Clock = tstime.DefaultClock{Clock: clock} + m.opts.Dir = dir + m.opts.State = must.Get(mem.New(nil, "")) + must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) + fd.Init(&m, eventHook) + defer fd.Shutdown() + insert := func(name string) { + t.Helper() + t.Logf("insert: %v", name) + fd.Insert(name) + } + remove := func(name string) { + t.Helper() + t.Logf("remove: %v", name) + fd.Remove(name) + } + + checkEvents("start full-scan") + checkEvents("end full-scan", "start waitAndDelete") + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + + advance(deleteDelay / 2) + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + advance(deleteDelay / 2) + checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") + checkEvents("end waitAndDelete") + checkDirectory() + + must.Do(touchFile(filepath.Join(dir, "one.partial"))) + insert("one.partial") + checkEvents("start waitAndDelete") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "two.partial"))) + insert("two.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "three.partial"))) + insert("three.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "four.partial"))) + insert("four.partial") + + advance(deleteDelay / 4) + checkEvents("deleted one.partial") + checkDirectory("two.partial", "three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted two.partial") + checkDirectory("three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted three.partial") + checkDirectory("four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted four.partial") + checkDirectory() + checkEvents("end waitAndDelete") + + insert("wuzz.partial") + checkEvents("start waitAndDelete") + remove("wuzz.partial") + checkEvents("end waitAndDelete") +} + +// Test that the asynchronous full scan of the taildrop directory does not occur +// on a cold start if taildrop has never received any files. +func TestDeleterInitWithoutTaildrop(t *testing.T) { + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Dir = t.TempDir() + m.opts.State = must.Get(mem.New(nil, "")) + fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) + fd.Shutdown() +} diff --git a/taildrop/resume_test.go b/taildrop/resume_test.go index 8758ddd29d48c..d366340eb6efa 100644 --- a/taildrop/resume_test.go +++ b/taildrop/resume_test.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "bytes" - "io" - "math/rand" - "os" - "testing" - "testing/iotest" - - "tailscale.com/util/must" -) - -func TestResume(t *testing.T) { - oldBlockSize := blockSize - defer func() { blockSize = oldBlockSize }() - blockSize = 256 - - m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() - defer m.Shutdown() - - rn := rand.New(rand.NewSource(0)) - want := make([]byte, 12345) - must.Get(io.ReadFull(rn, want)) - - t.Run("resume-noexist", func(t *testing.T) { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "foo") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - must.Get(m.PutFile("", "foo", r, offset, -1)) - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) - - t.Run("resume-retry", func(t *testing.T) { - rn := rand.New(rand.NewSource(0)) - for i := 0; true; i++ { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "bar") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) - if offset < int64(len(want)) { - r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) - } - if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { - break - } - if i > 1000 { - t.Fatalf("too many iterations to complete the test") - } - } - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "io" + "math/rand" + "os" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestResume(t *testing.T) { + oldBlockSize := blockSize + defer func() { blockSize = oldBlockSize }() + blockSize = 256 + + m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() + defer m.Shutdown() + + rn := rand.New(rand.NewSource(0)) + want := make([]byte, 12345) + must.Get(io.ReadFull(rn, want)) + + t.Run("resume-noexist", func(t *testing.T) { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "foo") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + must.Get(m.PutFile("", "foo", r, offset, -1)) + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) + + t.Run("resume-retry", func(t *testing.T) { + rn := rand.New(rand.NewSource(0)) + for i := 0; true; i++ { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "bar") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) + if offset < int64(len(want)) { + r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) + } + if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { + break + } + if i > 1000 { + t.Fatalf("too many iterations to complete the test") + } + } + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) +} diff --git a/taildrop/retrieve.go b/taildrop/retrieve.go index 527f8caed2bf5..3e37b492adc0a 100644 --- a/taildrop/retrieve.go +++ b/taildrop/retrieve.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "context" - "errors" - "io" - "io/fs" - "os" - "path/filepath" - "runtime" - "sort" - "time" - - "tailscale.com/client/tailscale/apitype" - "tailscale.com/logtail/backoff" -) - -// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. -// This always returns false when [Handler.DirectFileMode] is false. -func (m *Manager) HasFilesWaiting() (has bool) { - if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { - return false - } - - // Optimization: this is usually empty, so avoid opening - // the directory and checking. We can't cache the actual - // has-files-or-not values as the macOS/iOS client might - // in the future use+delete the files directly. So only - // keep this negative cache. - totalReceived := m.totalReceived.Load() - if totalReceived == m.emptySince.Load() { - return false - } - - // Check whether there is at least one one waiting file. - err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - has = true - return false - } - return true - }) - - // If there are no more waiting files, record totalReceived as emptySince - // so that we can short-circuit the expensive directory traversal - // if no files have been received after the start of this call. - if err == nil && !has { - m.emptySince.Store(totalReceived) - } - return has -} - -// WaitingFiles returns the list of files that have been sent by a -// peer that are waiting in [Handler.Dir]. -// This always returns nil when [Handler.DirectFileMode] is false. -func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { - if m == nil || m.opts.Dir == "" { - return nil, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, nil - } - if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - fi, err := de.Info() - if err != nil { - return true - } - ret = append(ret, apitype.WaitingFile{ - Name: filepath.Base(name), - Size: fi.Size(), - }) - } - return true - }); err != nil { - return nil, redactError(err) - } - sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) - return ret, nil -} - -// DeleteFile deletes a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) DeleteFile(baseName string) error { - if m == nil || m.opts.Dir == "" { - return ErrNoTaildrop - } - if m.opts.DirectFileMode { - return errors.New("deletes not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return err - } - var bo *backoff.Backoff - logf := m.opts.Logf - t0 := m.opts.Clock.Now() - for { - err := os.Remove(path) - if err != nil && !os.IsNotExist(err) { - err = redactError(err) - // Put a retry loop around deletes on Windows. - // - // Windows file descriptor closes are effectively asynchronous, - // as a bunch of hooks run on/after close, - // and we can't necessarily delete the file for a while after close, - // as we need to wait for everybody to be done with it. - // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. - // So try a few times but ultimately just leave a "foo.jpg.deleted" - // marker file to note that it's deleted and we clean it up later. - if runtime.GOOS == "windows" { - if bo == nil { - bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) - } - if m.opts.Clock.Since(t0) < 5*time.Second { - bo.BackOff(context.Background(), err) - continue - } - if err := touchFile(path + deletedSuffix); err != nil { - logf("peerapi: failed to leave deleted marker: %v", err) - } - m.deleter.Insert(baseName + deletedSuffix) - } - logf("peerapi: failed to DeleteFile: %v", err) - return err - } - return nil - } -} - -func touchFile(path string) error { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) - if err != nil { - return redactError(err) - } - return f.Close() -} - -// OpenFile opens a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { - if m == nil || m.opts.Dir == "" { - return nil, 0, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, 0, errors.New("opens not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return nil, 0, err - } - if _, err := os.Stat(path + deletedSuffix); err == nil { - return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) - } - f, err := os.Open(path) - if err != nil { - return nil, 0, redactError(err) - } - fi, err := f.Stat() - if err != nil { - f.Close() - return nil, 0, redactError(err) - } - return f, fi.Size(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "context" + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "sort" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/logtail/backoff" +) + +// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. +// This always returns false when [Handler.DirectFileMode] is false. +func (m *Manager) HasFilesWaiting() (has bool) { + if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { + return false + } + + // Optimization: this is usually empty, so avoid opening + // the directory and checking. We can't cache the actual + // has-files-or-not values as the macOS/iOS client might + // in the future use+delete the files directly. So only + // keep this negative cache. + totalReceived := m.totalReceived.Load() + if totalReceived == m.emptySince.Load() { + return false + } + + // Check whether there is at least one one waiting file. + err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + has = true + return false + } + return true + }) + + // If there are no more waiting files, record totalReceived as emptySince + // so that we can short-circuit the expensive directory traversal + // if no files have been received after the start of this call. + if err == nil && !has { + m.emptySince.Store(totalReceived) + } + return has +} + +// WaitingFiles returns the list of files that have been sent by a +// peer that are waiting in [Handler.Dir]. +// This always returns nil when [Handler.DirectFileMode] is false. +func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { + if m == nil || m.opts.Dir == "" { + return nil, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, nil + } + if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + fi, err := de.Info() + if err != nil { + return true + } + ret = append(ret, apitype.WaitingFile{ + Name: filepath.Base(name), + Size: fi.Size(), + }) + } + return true + }); err != nil { + return nil, redactError(err) + } + sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) + return ret, nil +} + +// DeleteFile deletes a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) DeleteFile(baseName string) error { + if m == nil || m.opts.Dir == "" { + return ErrNoTaildrop + } + if m.opts.DirectFileMode { + return errors.New("deletes not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return err + } + var bo *backoff.Backoff + logf := m.opts.Logf + t0 := m.opts.Clock.Now() + for { + err := os.Remove(path) + if err != nil && !os.IsNotExist(err) { + err = redactError(err) + // Put a retry loop around deletes on Windows. + // + // Windows file descriptor closes are effectively asynchronous, + // as a bunch of hooks run on/after close, + // and we can't necessarily delete the file for a while after close, + // as we need to wait for everybody to be done with it. + // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. + // So try a few times but ultimately just leave a "foo.jpg.deleted" + // marker file to note that it's deleted and we clean it up later. + if runtime.GOOS == "windows" { + if bo == nil { + bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) + } + if m.opts.Clock.Since(t0) < 5*time.Second { + bo.BackOff(context.Background(), err) + continue + } + if err := touchFile(path + deletedSuffix); err != nil { + logf("peerapi: failed to leave deleted marker: %v", err) + } + m.deleter.Insert(baseName + deletedSuffix) + } + logf("peerapi: failed to DeleteFile: %v", err) + return err + } + return nil + } +} + +func touchFile(path string) error { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + return redactError(err) + } + return f.Close() +} + +// OpenFile opens a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { + if m == nil || m.opts.Dir == "" { + return nil, 0, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, 0, errors.New("opens not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return nil, 0, err + } + if _, err := os.Stat(path + deletedSuffix); err == nil { + return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) + } + f, err := os.Open(path) + if err != nil { + return nil, 0, redactError(err) + } + fi, err := f.Stat() + if err != nil { + f.Close() + return nil, 0, redactError(err) + } + return f, fi.Size(), nil +} diff --git a/tempfork/gliderlabs/ssh/LICENSE b/tempfork/gliderlabs/ssh/LICENSE index 80b2b2baa7d2f..4a03f02a28185 100644 --- a/tempfork/gliderlabs/ssh/LICENSE +++ b/tempfork/gliderlabs/ssh/LICENSE @@ -1,27 +1,27 @@ -Copyright (c) 2016 Glider Labs. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Glider Labs nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Copyright (c) 2016 Glider Labs. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Glider Labs nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tempfork/gliderlabs/ssh/README.md b/tempfork/gliderlabs/ssh/README.md index ecef6b7c47895..79b5b89fa8a94 100644 --- a/tempfork/gliderlabs/ssh/README.md +++ b/tempfork/gliderlabs/ssh/README.md @@ -1,96 +1,96 @@ -# gliderlabs/ssh - -[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) -[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) -[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) -[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) -[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) -[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) - -> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member - -This Go package wraps the [crypto/ssh -package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for -building SSH servers. The goal of the API was to make it as simple as using -[net/http](https://golang.org/pkg/net/http/), so the API is very similar: - -```go - package main - - import ( - "tailscale.com/tempfork/gliderlabs/ssh" - "io" - "log" - ) - - func main() { - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - } - -``` -This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). - -## Examples - -A bunch of great examples are in the `_examples` directory. - -## Usage - -[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) - -## Contributing - -Pull requests are welcome! However, since this project is very much about API -design, please submit API changes as issues to discuss before submitting PRs. - -Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. - -## Roadmap - -* Non-session channel handlers -* Cleanup callback API -* 1.0 release -* High-level client? - -## Sponsors - -Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -## License - -[BSD](LICENSE) +# gliderlabs/ssh + +[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) +[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) +[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) +[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) +[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) +[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) + +> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member + +This Go package wraps the [crypto/ssh +package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for +building SSH servers. The goal of the API was to make it as simple as using +[net/http](https://golang.org/pkg/net/http/), so the API is very similar: + +```go + package main + + import ( + "tailscale.com/tempfork/gliderlabs/ssh" + "io" + "log" + ) + + func main() { + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + } + +``` +This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). + +## Examples + +A bunch of great examples are in the `_examples` directory. + +## Usage + +[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) + +## Contributing + +Pull requests are welcome! However, since this project is very much about API +design, please submit API changes as issues to discuss before submitting PRs. + +Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. + +## Roadmap + +* Non-session channel handlers +* Cleanup callback API +* 1.0 release +* High-level client? + +## Sponsors + +Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +## License + +[BSD](LICENSE) diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go index 3da665292a447..86a5bce7f8ebc 100644 --- a/tempfork/gliderlabs/ssh/agent.go +++ b/tempfork/gliderlabs/ssh/agent.go @@ -1,83 +1,83 @@ -package ssh - -import ( - "io" - "net" - "os" - "path" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - agentRequestType = "auth-agent-req@openssh.com" - agentChannelType = "auth-agent@openssh.com" - - agentTempDir = "auth-agent" - agentListenFile = "listener.sock" -) - -// contextKeyAgentRequest is an internal context key for storing if the -// client requested agent forwarding -var contextKeyAgentRequest = &contextKey{"auth-agent-req"} - -// SetAgentRequested sets up the session context so that AgentRequested -// returns true. -func SetAgentRequested(ctx Context) { - ctx.SetValue(contextKeyAgentRequest, true) -} - -// AgentRequested returns true if the client requested agent forwarding. -func AgentRequested(sess Session) bool { - return sess.Context().Value(contextKeyAgentRequest) == true -} - -// NewAgentListener sets up a temporary Unix socket that can be communicated -// to the session environment and used for forwarding connections. -func NewAgentListener() (net.Listener, error) { - dir, err := os.MkdirTemp("", agentTempDir) - if err != nil { - return nil, err - } - l, err := net.Listen("unix", path.Join(dir, agentListenFile)) - if err != nil { - return nil, err - } - return l, nil -} - -// ForwardAgentConnections takes connections from a listener to proxy into the -// session on the OpenSSH channel for agent connections. It blocks and services -// connections until the listener stop accepting. -func ForwardAgentConnections(l net.Listener, s Session) { - sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) - for { - conn, err := l.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) - if err != nil { - return - } - defer channel.Close() - go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) - go func() { - io.Copy(conn, channel) - conn.(*net.UnixConn).CloseWrite() - wg.Done() - }() - go func() { - io.Copy(channel, conn) - channel.CloseWrite() - wg.Done() - }() - wg.Wait() - }(conn) - } -} +package ssh + +import ( + "io" + "net" + "os" + "path" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + agentRequestType = "auth-agent-req@openssh.com" + agentChannelType = "auth-agent@openssh.com" + + agentTempDir = "auth-agent" + agentListenFile = "listener.sock" +) + +// contextKeyAgentRequest is an internal context key for storing if the +// client requested agent forwarding +var contextKeyAgentRequest = &contextKey{"auth-agent-req"} + +// SetAgentRequested sets up the session context so that AgentRequested +// returns true. +func SetAgentRequested(ctx Context) { + ctx.SetValue(contextKeyAgentRequest, true) +} + +// AgentRequested returns true if the client requested agent forwarding. +func AgentRequested(sess Session) bool { + return sess.Context().Value(contextKeyAgentRequest) == true +} + +// NewAgentListener sets up a temporary Unix socket that can be communicated +// to the session environment and used for forwarding connections. +func NewAgentListener() (net.Listener, error) { + dir, err := os.MkdirTemp("", agentTempDir) + if err != nil { + return nil, err + } + l, err := net.Listen("unix", path.Join(dir, agentListenFile)) + if err != nil { + return nil, err + } + return l, nil +} + +// ForwardAgentConnections takes connections from a listener to proxy into the +// session on the OpenSSH channel for agent connections. It blocks and services +// connections until the listener stop accepting. +func ForwardAgentConnections(l net.Listener, s Session) { + sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) + for { + conn, err := l.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) + if err != nil { + return + } + defer channel.Close() + go gossh.DiscardRequests(reqs) + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + wg.Wait() + }(conn) + } +} diff --git a/tempfork/gliderlabs/ssh/conn.go b/tempfork/gliderlabs/ssh/conn.go index ec277bf27676f..ebef8845baccb 100644 --- a/tempfork/gliderlabs/ssh/conn.go +++ b/tempfork/gliderlabs/ssh/conn.go @@ -1,55 +1,55 @@ -package ssh - -import ( - "context" - "net" - "time" -) - -type serverConn struct { - net.Conn - - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Write(p) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Read(b) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Close() (err error) { - err = c.Conn.Close() - if c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: - idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return - } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) - } -} +package ssh + +import ( + "context" + "net" + "time" +) + +type serverConn struct { + net.Conn + + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + switch { + case c.idleTimeout > 0: + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(idleDeadline) + return + } + fallthrough + default: + c.Conn.SetDeadline(c.maxDeadline) + } +} diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go index 6f7245574060d..d43de6f09c8a5 100644 --- a/tempfork/gliderlabs/ssh/context.go +++ b/tempfork/gliderlabs/ssh/context.go @@ -1,164 +1,164 @@ -package ssh - -import ( - "context" - "encoding/hex" - "net" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// contextKey is a value for use with context.WithValue. It's used as -// a pointer so it fits in an interface{} without allocation. -type contextKey struct { - name string -} - -var ( - // ContextKeyUser is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyUser = &contextKey{"user"} - - // ContextKeySessionID is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeySessionID = &contextKey{"session-id"} - - // ContextKeyPermissions is a context key for use with Contexts in this package. - // The associated value will be of type *Permissions. - ContextKeyPermissions = &contextKey{"permissions"} - - // ContextKeyClientVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyClientVersion = &contextKey{"client-version"} - - // ContextKeyServerVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyServerVersion = &contextKey{"server-version"} - - // ContextKeyLocalAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyLocalAddr = &contextKey{"local-addr"} - - // ContextKeyRemoteAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyRemoteAddr = &contextKey{"remote-addr"} - - // ContextKeyServer is a context key for use with Contexts in this package. - // The associated value will be of type *Server. - ContextKeyServer = &contextKey{"ssh-server"} - - // ContextKeyConn is a context key for use with Contexts in this package. - // The associated value will be of type gossh.ServerConn. - ContextKeyConn = &contextKey{"ssh-conn"} - - // ContextKeyPublicKey is a context key for use with Contexts in this package. - // The associated value will be of type PublicKey. - ContextKeyPublicKey = &contextKey{"public-key"} - - ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} -) - -// Context is a package specific context interface. It exposes connection -// metadata and allows new values to be easily written to it. It's used in -// authentication handlers and callbacks, and its underlying context.Context is -// exposed on Session in the session Handler. A connection-scoped lock is also -// embedded in the context to make it easier to limit operations per-connection. -type Context interface { - context.Context - sync.Locker - - // User returns the username used when establishing the SSH connection. - User() string - - // SessionID returns the session hash. - SessionID() string - - // ClientVersion returns the version reported by the client. - ClientVersion() string - - // ServerVersion returns the version reported by the server. - ServerVersion() string - - // RemoteAddr returns the remote address for this connection. - RemoteAddr() net.Addr - - // LocalAddr returns the local address for this connection. - LocalAddr() net.Addr - - // Permissions returns the Permissions object used for this connection. - Permissions() *Permissions - - // SetValue allows you to easily write new values into the underlying context. - SetValue(key, value interface{}) - - SendAuthBanner(banner string) error -} - -type sshContext struct { - context.Context - *sync.Mutex -} - -func newContext(srv *Server) (*sshContext, context.CancelFunc) { - innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} - ctx.SetValue(ContextKeyServer, srv) - perms := &Permissions{&gossh.Permissions{}} - ctx.SetValue(ContextKeyPermissions, perms) - return ctx, cancel -} - -// this is separate from newContext because we will get ConnMetadata -// at different points so it needs to be applied separately -func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { - if ctx.Value(ContextKeySessionID) != nil { - return - } - ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) - ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) - ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) - ctx.SetValue(ContextKeyUser, conn.User()) - ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) - ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) - ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) -} - -func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) -} - -func (ctx *sshContext) User() string { - return ctx.Value(ContextKeyUser).(string) -} - -func (ctx *sshContext) SessionID() string { - return ctx.Value(ContextKeySessionID).(string) -} - -func (ctx *sshContext) ClientVersion() string { - return ctx.Value(ContextKeyClientVersion).(string) -} - -func (ctx *sshContext) ServerVersion() string { - return ctx.Value(ContextKeyServerVersion).(string) -} - -func (ctx *sshContext) RemoteAddr() net.Addr { - if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { - return addr - } - return nil -} - -func (ctx *sshContext) LocalAddr() net.Addr { - return ctx.Value(ContextKeyLocalAddr).(net.Addr) -} - -func (ctx *sshContext) Permissions() *Permissions { - return ctx.Value(ContextKeyPermissions).(*Permissions) -} - -func (ctx *sshContext) SendAuthBanner(msg string) error { - return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) -} +package ssh + +import ( + "context" + "encoding/hex" + "net" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +var ( + // ContextKeyUser is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyUser = &contextKey{"user"} + + // ContextKeySessionID is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeySessionID = &contextKey{"session-id"} + + // ContextKeyPermissions is a context key for use with Contexts in this package. + // The associated value will be of type *Permissions. + ContextKeyPermissions = &contextKey{"permissions"} + + // ContextKeyClientVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyClientVersion = &contextKey{"client-version"} + + // ContextKeyServerVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyServerVersion = &contextKey{"server-version"} + + // ContextKeyLocalAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyLocalAddr = &contextKey{"local-addr"} + + // ContextKeyRemoteAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyRemoteAddr = &contextKey{"remote-addr"} + + // ContextKeyServer is a context key for use with Contexts in this package. + // The associated value will be of type *Server. + ContextKeyServer = &contextKey{"ssh-server"} + + // ContextKeyConn is a context key for use with Contexts in this package. + // The associated value will be of type gossh.ServerConn. + ContextKeyConn = &contextKey{"ssh-conn"} + + // ContextKeyPublicKey is a context key for use with Contexts in this package. + // The associated value will be of type PublicKey. + ContextKeyPublicKey = &contextKey{"public-key"} + + ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} +) + +// Context is a package specific context interface. It exposes connection +// metadata and allows new values to be easily written to it. It's used in +// authentication handlers and callbacks, and its underlying context.Context is +// exposed on Session in the session Handler. A connection-scoped lock is also +// embedded in the context to make it easier to limit operations per-connection. +type Context interface { + context.Context + sync.Locker + + // User returns the username used when establishing the SSH connection. + User() string + + // SessionID returns the session hash. + SessionID() string + + // ClientVersion returns the version reported by the client. + ClientVersion() string + + // ServerVersion returns the version reported by the server. + ServerVersion() string + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr + + // Permissions returns the Permissions object used for this connection. + Permissions() *Permissions + + // SetValue allows you to easily write new values into the underlying context. + SetValue(key, value interface{}) + + SendAuthBanner(banner string) error +} + +type sshContext struct { + context.Context + *sync.Mutex +} + +func newContext(srv *Server) (*sshContext, context.CancelFunc) { + innerCtx, cancel := context.WithCancel(context.Background()) + ctx := &sshContext{innerCtx, &sync.Mutex{}} + ctx.SetValue(ContextKeyServer, srv) + perms := &Permissions{&gossh.Permissions{}} + ctx.SetValue(ContextKeyPermissions, perms) + return ctx, cancel +} + +// this is separate from newContext because we will get ConnMetadata +// at different points so it needs to be applied separately +func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { + if ctx.Value(ContextKeySessionID) != nil { + return + } + ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) + ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) + ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) + ctx.SetValue(ContextKeyUser, conn.User()) + ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) + ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) + ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) +} + +func (ctx *sshContext) SetValue(key, value interface{}) { + ctx.Context = context.WithValue(ctx.Context, key, value) +} + +func (ctx *sshContext) User() string { + return ctx.Value(ContextKeyUser).(string) +} + +func (ctx *sshContext) SessionID() string { + return ctx.Value(ContextKeySessionID).(string) +} + +func (ctx *sshContext) ClientVersion() string { + return ctx.Value(ContextKeyClientVersion).(string) +} + +func (ctx *sshContext) ServerVersion() string { + return ctx.Value(ContextKeyServerVersion).(string) +} + +func (ctx *sshContext) RemoteAddr() net.Addr { + if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { + return addr + } + return nil +} + +func (ctx *sshContext) LocalAddr() net.Addr { + return ctx.Value(ContextKeyLocalAddr).(net.Addr) +} + +func (ctx *sshContext) Permissions() *Permissions { + return ctx.Value(ContextKeyPermissions).(*Permissions) +} + +func (ctx *sshContext) SendAuthBanner(msg string) error { + return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) +} diff --git a/tempfork/gliderlabs/ssh/context_test.go b/tempfork/gliderlabs/ssh/context_test.go index 8f71c395841c9..dcbd326b77809 100644 --- a/tempfork/gliderlabs/ssh/context_test.go +++ b/tempfork/gliderlabs/ssh/context_test.go @@ -1,49 +1,49 @@ -//go:build glidertests - -package ssh - -import "testing" - -func TestSetPermissions(t *testing.T) { - t.Parallel() - permsExt := map[string]string{ - "foo": "bar", - } - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - if _, ok := s.Permissions().Extensions["foo"]; !ok { - t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.Permissions().Extensions = permsExt - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestSetValue(t *testing.T) { - t.Parallel() - value := map[string]string{ - "foo": "bar", - } - key := "testValue" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - v := s.Context().Value(key).(map[string]string) - if v["foo"] != value["foo"] { - t.Fatalf("got %#v; want %#v", v, value) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.SetValue(key, value) - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} +//go:build glidertests + +package ssh + +import "testing" + +func TestSetPermissions(t *testing.T) { + t.Parallel() + permsExt := map[string]string{ + "foo": "bar", + } + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + if _, ok := s.Permissions().Extensions["foo"]; !ok { + t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.Permissions().Extensions = permsExt + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestSetValue(t *testing.T) { + t.Parallel() + value := map[string]string{ + "foo": "bar", + } + key := "testValue" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + v := s.Context().Value(key).(map[string]string) + if v["foo"] != value["foo"] { + t.Fatalf("got %#v; want %#v", v, value) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.SetValue(key, value) + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} diff --git a/tempfork/gliderlabs/ssh/doc.go b/tempfork/gliderlabs/ssh/doc.go index 46c47d650a06c..d139191768d55 100644 --- a/tempfork/gliderlabs/ssh/doc.go +++ b/tempfork/gliderlabs/ssh/doc.go @@ -1,45 +1,45 @@ -/* -Package ssh wraps the crypto/ssh package with a higher-level API for building -SSH servers. The goal of the API was to make it as simple as using net/http, so -the API is very similar. - -You should be able to build any SSH server using only this package, which wraps -relevant types and some functions from crypto/ssh. However, you still need to -use crypto/ssh for building SSH clients. - -ListenAndServe starts an SSH server with a given address, handler, and options. The -handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: - - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - -If you don't specify a host key, it will generate one every time. This is convenient -except you'll have to deal with clients being confused that the host key is different. -It's a better idea to generate or point to an existing key on your system: - - log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) - -Although all options have functional option helpers, another way to control the -server's behavior is by creating a custom Server: - - s := &ssh.Server{ - Addr: ":2222", - Handler: sessionHandler, - PublicKeyHandler: authHandler, - } - s.AddHostKey(hostKeySigner) - - log.Fatal(s.ListenAndServe()) - -This package automatically handles basic SSH requests like setting environment -variables, requesting PTY, and changing window size. These requests are -processed, responded to, and any relevant state is updated. This state is then -exposed to you via the Session interface. - -The one big feature missing from the Session abstraction is signals. This was -started, but not completed. Pull Requests welcome! -*/ -package ssh +/* +Package ssh wraps the crypto/ssh package with a higher-level API for building +SSH servers. The goal of the API was to make it as simple as using net/http, so +the API is very similar. + +You should be able to build any SSH server using only this package, which wraps +relevant types and some functions from crypto/ssh. However, you still need to +use crypto/ssh for building SSH clients. + +ListenAndServe starts an SSH server with a given address, handler, and options. The +handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: + + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + +If you don't specify a host key, it will generate one every time. This is convenient +except you'll have to deal with clients being confused that the host key is different. +It's a better idea to generate or point to an existing key on your system: + + log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) + +Although all options have functional option helpers, another way to control the +server's behavior is by creating a custom Server: + + s := &ssh.Server{ + Addr: ":2222", + Handler: sessionHandler, + PublicKeyHandler: authHandler, + } + s.AddHostKey(hostKeySigner) + + log.Fatal(s.ListenAndServe()) + +This package automatically handles basic SSH requests like setting environment +variables, requesting PTY, and changing window size. These requests are +processed, responded to, and any relevant state is updated. This state is then +exposed to you via the Session interface. + +The one big feature missing from the Session abstraction is signals. This was +started, but not completed. Pull Requests welcome! +*/ +package ssh diff --git a/tempfork/gliderlabs/ssh/example_test.go b/tempfork/gliderlabs/ssh/example_test.go index 61ffebbc045dc..c174bc4ae190e 100644 --- a/tempfork/gliderlabs/ssh/example_test.go +++ b/tempfork/gliderlabs/ssh/example_test.go @@ -1,50 +1,50 @@ -package ssh_test - -import ( - "errors" - "io" - "os" - - "tailscale.com/tempfork/gliderlabs/ssh" -) - -func ExampleListenAndServe() { - ssh.ListenAndServe(":2222", func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) -} - -func ExamplePasswordAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { - return pass == "secret" - }), - ) -} - -func ExampleNoPty() { - ssh.ListenAndServe(":2222", nil, ssh.NoPty()) -} - -func ExamplePublicKeyAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { - data, err := os.ReadFile("/path/to/allowed/key.pub") - if err != nil { - return err - } - allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) - if err != nil { - return err - } - if !ssh.KeysEqual(key, allowed) { - return errors.New("some error") - } - return nil - }), - ) -} - -func ExampleHostKeyFile() { - ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) -} +package ssh_test + +import ( + "errors" + "io" + "os" + + "tailscale.com/tempfork/gliderlabs/ssh" +) + +func ExampleListenAndServe() { + ssh.ListenAndServe(":2222", func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) +} + +func ExamplePasswordAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { + return pass == "secret" + }), + ) +} + +func ExampleNoPty() { + ssh.ListenAndServe(":2222", nil, ssh.NoPty()) +} + +func ExamplePublicKeyAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { + data, err := os.ReadFile("/path/to/allowed/key.pub") + if err != nil { + return err + } + allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) + if err != nil { + return err + } + if !ssh.KeysEqual(key, allowed) { + return errors.New("some error") + } + return nil + }), + ) +} + +func ExampleHostKeyFile() { + ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) +} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go index bb24909bebd2a..aa87a4f39db9e 100644 --- a/tempfork/gliderlabs/ssh/options.go +++ b/tempfork/gliderlabs/ssh/options.go @@ -1,84 +1,84 @@ -package ssh - -import ( - "os" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// PasswordAuth returns a functional option that sets PasswordHandler on the server. -func PasswordAuth(fn PasswordHandler) Option { - return func(srv *Server) error { - srv.PasswordHandler = fn - return nil - } -} - -// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. -func PublicKeyAuth(fn PublicKeyHandler) Option { - return func(srv *Server) error { - srv.PublicKeyHandler = fn - return nil - } -} - -// HostKeyFile returns a functional option that adds HostSigners to the server -// from a PEM file at filepath. -func HostKeyFile(filepath string) Option { - return func(srv *Server) error { - pemBytes, err := os.ReadFile(filepath) - if err != nil { - return err - } - - signer, err := gossh.ParsePrivateKey(pemBytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { - return func(srv *Server) error { - srv.KeyboardInteractiveHandler = fn - return nil - } -} - -// HostKeyPEM returns a functional option that adds HostSigners to the server -// from a PEM file as bytes. -func HostKeyPEM(bytes []byte) Option { - return func(srv *Server) error { - signer, err := gossh.ParsePrivateKey(bytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -// NoPty returns a functional option that sets PtyCallback to return false, -// denying PTY requests. -func NoPty() Option { - return func(srv *Server) error { - srv.PtyCallback = func(ctx Context, pty Pty) bool { - return false - } - return nil - } -} - -// WrapConn returns a functional option that sets ConnCallback on the server. -func WrapConn(fn ConnCallback) Option { - return func(srv *Server) error { - srv.ConnCallback = fn - return nil - } -} +package ssh + +import ( + "os" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// PasswordAuth returns a functional option that sets PasswordHandler on the server. +func PasswordAuth(fn PasswordHandler) Option { + return func(srv *Server) error { + srv.PasswordHandler = fn + return nil + } +} + +// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. +func PublicKeyAuth(fn PublicKeyHandler) Option { + return func(srv *Server) error { + srv.PublicKeyHandler = fn + return nil + } +} + +// HostKeyFile returns a functional option that adds HostSigners to the server +// from a PEM file at filepath. +func HostKeyFile(filepath string) Option { + return func(srv *Server) error { + pemBytes, err := os.ReadFile(filepath) + if err != nil { + return err + } + + signer, err := gossh.ParsePrivateKey(pemBytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { + return func(srv *Server) error { + srv.KeyboardInteractiveHandler = fn + return nil + } +} + +// HostKeyPEM returns a functional option that adds HostSigners to the server +// from a PEM file as bytes. +func HostKeyPEM(bytes []byte) Option { + return func(srv *Server) error { + signer, err := gossh.ParsePrivateKey(bytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +// NoPty returns a functional option that sets PtyCallback to return false, +// denying PTY requests. +func NoPty() Option { + return func(srv *Server) error { + srv.PtyCallback = func(ctx Context, pty Pty) bool { + return false + } + return nil + } +} + +// WrapConn returns a functional option that sets ConnCallback on the server. +func WrapConn(fn ConnCallback) Option { + return func(srv *Server) error { + srv.ConnCallback = fn + return nil + } +} diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 3aa2f1cf5e31b..7cf6f376c6a88 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -1,111 +1,111 @@ -//go:build glidertests - -package ssh - -import ( - "net" - "strings" - "sync/atomic" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { - for _, option := range options { - if err := srv.SetOption(option); err != nil { - t.Fatal(err) - } - } - return newTestSession(t, srv, cfg) -} - -func TestPasswordAuth(t *testing.T) { - t.Parallel() - testUser := "testuser" - testPass := "testpass" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, &gossh.ClientConfig{ - User: testUser, - Auth: []gossh.AuthMethod{ - gossh.Password(testPass), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - if ctx.User() != testUser { - t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) - } - if password != testPass { - t.Fatalf("user = %#v; want %#v", password, testPass) - } - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestPasswordAuthBadPass(t *testing.T) { - t.Parallel() - l := newLocalListener() - srv := &Server{Handler: func(s Session) {}} - srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { - return false - })) - go srv.serveOnce(l) - _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - if !strings.Contains(err.Error(), "unable to authenticate") { - t.Fatal(err) - } - } -} - -type wrappedConn struct { - net.Conn - written int32 -} - -func (c *wrappedConn) Write(p []byte) (n int, err error) { - n, err = c.Conn.Write(p) - atomic.AddInt32(&(c.written), int32(n)) - return -} - -func TestConnWrapping(t *testing.T) { - t.Parallel() - var wrapped *wrappedConn - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // nothing - }, - }, &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - return true - }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { - wrapped = &wrappedConn{conn, 0} - return wrapped - })) - defer cleanup() - if err := session.Shell(); err != nil { - t.Fatal(err) - } - if atomic.LoadInt32(&(wrapped.written)) == 0 { - t.Fatal("wrapped conn not written to") - } -} +//go:build glidertests + +package ssh + +import ( + "net" + "strings" + "sync/atomic" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { + for _, option := range options { + if err := srv.SetOption(option); err != nil { + t.Fatal(err) + } + } + return newTestSession(t, srv, cfg) +} + +func TestPasswordAuth(t *testing.T) { + t.Parallel() + testUser := "testuser" + testPass := "testpass" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, &gossh.ClientConfig{ + User: testUser, + Auth: []gossh.AuthMethod{ + gossh.Password(testPass), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + if ctx.User() != testUser { + t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) + } + if password != testPass { + t.Fatalf("user = %#v; want %#v", password, testPass) + } + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestPasswordAuthBadPass(t *testing.T) { + t.Parallel() + l := newLocalListener() + srv := &Server{Handler: func(s Session) {}} + srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { + return false + })) + go srv.serveOnce(l) + _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + if !strings.Contains(err.Error(), "unable to authenticate") { + t.Fatal(err) + } + } +} + +type wrappedConn struct { + net.Conn + written int32 +} + +func (c *wrappedConn) Write(p []byte) (n int, err error) { + n, err = c.Conn.Write(p) + atomic.AddInt32(&(c.written), int32(n)) + return +} + +func TestConnWrapping(t *testing.T) { + t.Parallel() + var wrapped *wrappedConn + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // nothing + }, + }, &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + return true + }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { + wrapped = &wrappedConn{conn, 0} + return wrapped + })) + defer cleanup() + if err := session.Shell(); err != nil { + t.Fatal(err) + } + if atomic.LoadInt32(&(wrapped.written)) == 0 { + t.Fatal("wrapped conn not written to") + } +} diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 32f633e87b58e..1086a72caf0e5 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -1,459 +1,459 @@ -package ssh - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// ErrServerClosed is returned by the Server's Serve, ListenAndServe, -// and ListenAndServeTLS methods after a call to Shutdown or Close. -var ErrServerClosed = errors.New("ssh: Server closed") - -type SubsystemHandler func(s Session) - -var DefaultSubsystemHandlers = map[string]SubsystemHandler{} - -type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) - -var DefaultRequestHandlers = map[string]RequestHandler{} - -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) - -var DefaultChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, -} - -// Server defines parameters for running an SSH server. The zero value for -// Server is a valid configuration. When both PasswordHandler and -// PublicKeyHandler are nil, no client authentication is performed. -type Server struct { - Addr string // TCP address to listen on, ":22" if empty - Handler Handler // handler to invoke, ssh.DefaultHandler if nil - HostSigners []Signer // private keys for the host key, must have at least one - Version string // server version to be sent before the initial handshake - - KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - NoClientAuthHandler NoClientAuthHandler // no client authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling - LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil - ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil - ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options - SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions - - ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures - - IdleTimeout time.Duration // connection timeout when no activity, none if empty - MaxTimeout time.Duration // absolute connection timeout, none if empty - - // ChannelHandlers allow overriding the built-in session handlers or provide - // extensions to the protocol, such as tcpip forwarding. By default only the - // "session" handler is enabled. - ChannelHandlers map[string]ChannelHandler - - // RequestHandlers allow overriding the server-level request handlers or - // provide extensions to the protocol, such as tcpip forwarding. By default - // no handlers are enabled. - RequestHandlers map[string]RequestHandler - - // SubsystemHandlers are handlers which are similar to the usual SSH command - // handlers, but handle named subsystems. - SubsystemHandlers map[string]SubsystemHandler - - listenerWg sync.WaitGroup - mu sync.RWMutex - listeners map[net.Listener]struct{} - conns map[*gossh.ServerConn]struct{} - connWg sync.WaitGroup - doneChan chan struct{} -} - -func (srv *Server) ensureHostSigner() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - if len(srv.HostSigners) == 0 { - signer, err := generateSigner() - if err != nil { - return err - } - srv.HostSigners = append(srv.HostSigners, signer) - } - return nil -} - -func (srv *Server) ensureHandlers() { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } - } - if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } - } - if srv.SubsystemHandlers == nil { - srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } - } -} - -func (srv *Server) config(ctx Context) *gossh.ServerConfig { - srv.mu.RLock() - defer srv.mu.RUnlock() - - var config *gossh.ServerConfig - if srv.ServerConfigCallback == nil { - config = &gossh.ServerConfig{} - } else { - config = srv.ServerConfigCallback(ctx) - } - for _, signer := range srv.HostSigners { - config.AddHostKey(signer) - } - if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { - config.NoClientAuth = true - } - if srv.Version != "" { - config.ServerVersion = "SSH-2.0-" + srv.Version - } - if srv.PasswordHandler != nil { - config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.PasswordHandler(ctx, string(password)); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.PublicKeyHandler != nil { - config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.PublicKeyHandler(ctx, key); err != nil { - return ctx.Permissions().Permissions, err - } - ctx.SetValue(ContextKeyPublicKey, key) - return ctx.Permissions().Permissions, nil - } - } - if srv.KeyboardInteractiveHandler != nil { - config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.NoClientAuthHandler != nil { - config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.NoClientAuthHandler(ctx); err != nil { - return ctx.Permissions().Permissions, err - } - return ctx.Permissions().Permissions, nil - } - } - return config -} - -// Handle sets the Handler for the server. -func (srv *Server) Handle(fn Handler) { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.Handler = fn -} - -// Close immediately closes all active listeners and all active -// connections. -// -// Close returns any error returned from closing the Server's -// underlying Listener(s). -func (srv *Server) Close() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.closeDoneChanLocked() - err := srv.closeListenersLocked() - for c := range srv.conns { - c.Close() - delete(srv.conns, c) - } - return err -} - -// Shutdown gracefully shuts down the server without interrupting any -// active connections. Shutdown works by first closing all open -// listeners, and then waiting indefinitely for connections to close. -// If the provided context expires before the shutdown is complete, -// then the context's error is returned. -func (srv *Server) Shutdown(ctx context.Context) error { - srv.mu.Lock() - lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() - srv.mu.Unlock() - - finished := make(chan struct{}, 1) - go func() { - srv.listenerWg.Wait() - srv.connWg.Wait() - finished <- struct{}{} - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-finished: - return lnerr - } -} - -// Serve accepts incoming connections on the Listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and then -// calls srv.Handler to handle sessions. -// -// Serve always returns a non-nil error. -func (srv *Server) Serve(l net.Listener) error { - srv.ensureHandlers() - defer l.Close() - if err := srv.ensureHostSigner(); err != nil { - return err - } - if srv.Handler == nil { - srv.Handler = DefaultHandler - } - var tempDelay time.Duration - - srv.trackListener(l, true) - defer srv.trackListener(l, false) - for { - conn, e := l.Accept() - if e != nil { - select { - case <-srv.getDoneChan(): - return ErrServerClosed - default: - } - if ne, ok := e.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - return e - } - go srv.HandleConn(conn) - } -} - -func (srv *Server) HandleConn(newConn net.Conn) { - ctx, cancel := newContext(srv) - if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(ctx, newConn) - if cbConn == nil { - newConn.Close() - return - } - newConn = cbConn - } - conn := &serverConn{ - Conn: newConn, - idleTimeout: srv.IdleTimeout, - closeCanceler: cancel, - } - if srv.MaxTimeout > 0 { - conn.maxDeadline = time.Now().Add(srv.MaxTimeout) - } - defer conn.Close() - sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) - if err != nil { - if srv.ConnectionFailedCallback != nil { - srv.ConnectionFailedCallback(conn, err) - } - return - } - - srv.trackConn(sshConn, true) - defer srv.trackConn(sshConn, false) - - ctx.SetValue(ContextKeyConn, sshConn) - applyConnMetadata(ctx, sshConn) - //go gossh.DiscardRequests(reqs) - go srv.handleRequests(ctx, reqs) - for ch := range chans { - handler := srv.ChannelHandlers[ch.ChannelType()] - if handler == nil { - handler = srv.ChannelHandlers["default"] - } - if handler == nil { - ch.Reject(gossh.UnknownChannelType, "unsupported channel type") - continue - } - go handler(srv, sshConn, ch, ctx) - } -} - -func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { - for req := range in { - handler := srv.RequestHandlers[req.Type] - if handler == nil { - handler = srv.RequestHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - /*reqCtx, cancel := context.WithCancel(ctx) - defer cancel() */ - ret, payload := handler(ctx, srv, req) - req.Reply(ret, payload) - } -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls -// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. -// ListenAndServe always returns a non-nil error. -func (srv *Server) ListenAndServe() error { - addr := srv.Addr - if addr == "" { - addr = ":22" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - return err - } - return srv.Serve(ln) -} - -// AddHostKey adds a private key as a host key. If an existing host key exists -// with the same algorithm, it is overwritten. Each server config must have at -// least one host key. -func (srv *Server) AddHostKey(key Signer) { - srv.mu.Lock() - defer srv.mu.Unlock() - - // these are later added via AddHostKey on ServerConfig, which performs the - // check for one of every algorithm. - - // This check is based on the AddHostKey method from the x/crypto/ssh - // library. This allows us to only keep one active key for each type on a - // server at once. So, if you're dynamically updating keys at runtime, this - // list will not keep growing. - for i, k := range srv.HostSigners { - if k.PublicKey().Type() == key.PublicKey().Type() { - srv.HostSigners[i] = key - return - } - } - - srv.HostSigners = append(srv.HostSigners, key) -} - -// SetOption runs a functional option against the server. -func (srv *Server) SetOption(option Option) error { - // NOTE: there is a potential race here for any option that doesn't call an - // internal method. We can't actually lock here because if something calls - // (as an example) AddHostKey, it will deadlock. - - //srv.mu.Lock() - //defer srv.mu.Unlock() - - return option(srv) -} - -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - -func (srv *Server) closeListenersLocked() error { - var err error - for ln := range srv.listeners { - if cerr := ln.Close(); cerr != nil && err == nil { - err = cerr - } - delete(srv.listeners, ln) - } - return err -} - -func (srv *Server) trackListener(ln net.Listener, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.listeners == nil { - srv.listeners = make(map[net.Listener]struct{}) - } - if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil - } - srv.listeners[ln] = struct{}{} - srv.listenerWg.Add(1) - } else { - delete(srv.listeners, ln) - srv.listenerWg.Done() - } -} - -func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.conns == nil { - srv.conns = make(map[*gossh.ServerConn]struct{}) - } - if add { - srv.conns[c] = struct{}{} - srv.connWg.Add(1) - } else { - delete(srv.conns, c) - srv.connWg.Done() - } -} +package ssh + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// ErrServerClosed is returned by the Server's Serve, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("ssh: Server closed") + +type SubsystemHandler func(s Session) + +var DefaultSubsystemHandlers = map[string]SubsystemHandler{} + +type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +var DefaultRequestHandlers = map[string]RequestHandler{} + +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, +} + +// Server defines parameters for running an SSH server. The zero value for +// Server is a valid configuration. When both PasswordHandler and +// PublicKeyHandler are nil, no client authentication is performed. +type Server struct { + Addr string // TCP address to listen on, ":22" if empty + Handler Handler // handler to invoke, ssh.DefaultHandler if nil + HostSigners []Signer // private keys for the host key, must have at least one + Version string // server version to be sent before the initial handshake + + KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler + PasswordHandler PasswordHandler // password authentication handler + PublicKeyHandler PublicKeyHandler // public key authentication handler + NoClientAuthHandler NoClientAuthHandler // no client authentication handler + PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil + ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling + LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options + SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions + + ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures + + IdleTimeout time.Duration // connection timeout when no activity, none if empty + MaxTimeout time.Duration // absolute connection timeout, none if empty + + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler + + // SubsystemHandlers are handlers which are similar to the usual SSH command + // handlers, but handle named subsystems. + SubsystemHandlers map[string]SubsystemHandler + + listenerWg sync.WaitGroup + mu sync.RWMutex + listeners map[net.Listener]struct{} + conns map[*gossh.ServerConn]struct{} + connWg sync.WaitGroup + doneChan chan struct{} +} + +func (srv *Server) ensureHostSigner() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + if len(srv.HostSigners) == 0 { + signer, err := generateSigner() + if err != nil { + return err + } + srv.HostSigners = append(srv.HostSigners, signer) + } + return nil +} + +func (srv *Server) ensureHandlers() { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v + } + } + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v + } + } + if srv.SubsystemHandlers == nil { + srv.SubsystemHandlers = map[string]SubsystemHandler{} + for k, v := range DefaultSubsystemHandlers { + srv.SubsystemHandlers[k] = v + } + } +} + +func (srv *Server) config(ctx Context) *gossh.ServerConfig { + srv.mu.RLock() + defer srv.mu.RUnlock() + + var config *gossh.ServerConfig + if srv.ServerConfigCallback == nil { + config = &gossh.ServerConfig{} + } else { + config = srv.ServerConfigCallback(ctx) + } + for _, signer := range srv.HostSigners { + config.AddHostKey(signer) + } + if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { + config.NoClientAuth = true + } + if srv.Version != "" { + config.ServerVersion = "SSH-2.0-" + srv.Version + } + if srv.PasswordHandler != nil { + config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.PasswordHandler(ctx, string(password)); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.PublicKeyHandler != nil { + config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.PublicKeyHandler(ctx, key); err != nil { + return ctx.Permissions().Permissions, err + } + ctx.SetValue(ContextKeyPublicKey, key) + return ctx.Permissions().Permissions, nil + } + } + if srv.KeyboardInteractiveHandler != nil { + config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.NoClientAuthHandler != nil { + config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.NoClientAuthHandler(ctx); err != nil { + return ctx.Permissions().Permissions, err + } + return ctx.Permissions().Permissions, nil + } + } + return config +} + +// Handle sets the Handler for the server. +func (srv *Server) Handle(fn Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.Handler = fn +} + +// Close immediately closes all active listeners and all active +// connections. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.conns { + c.Close() + delete(srv.conns, c) + } + return err +} + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, and then waiting indefinitely for connections to close. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + finished := make(chan struct{}, 1) + go func() { + srv.listenerWg.Wait() + srv.connWg.Wait() + finished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-finished: + return lnerr + } +} + +// Serve accepts incoming connections on the Listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and then +// calls srv.Handler to handle sessions. +// +// Serve always returns a non-nil error. +func (srv *Server) Serve(l net.Listener) error { + srv.ensureHandlers() + defer l.Close() + if err := srv.ensureHostSigner(); err != nil { + return err + } + if srv.Handler == nil { + srv.Handler = DefaultHandler + } + var tempDelay time.Duration + + srv.trackListener(l, true) + defer srv.trackListener(l, false) + for { + conn, e := l.Accept() + if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + return e + } + go srv.HandleConn(conn) + } +} + +func (srv *Server) HandleConn(newConn net.Conn) { + ctx, cancel := newContext(srv) + if srv.ConnCallback != nil { + cbConn := srv.ConnCallback(ctx, newConn) + if cbConn == nil { + newConn.Close() + return + } + newConn = cbConn + } + conn := &serverConn{ + Conn: newConn, + idleTimeout: srv.IdleTimeout, + closeCanceler: cancel, + } + if srv.MaxTimeout > 0 { + conn.maxDeadline = time.Now().Add(srv.MaxTimeout) + } + defer conn.Close() + sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) + if err != nil { + if srv.ConnectionFailedCallback != nil { + srv.ConnectionFailedCallback(conn, err) + } + return + } + + srv.trackConn(sshConn, true) + defer srv.trackConn(sshConn, false) + + ctx.SetValue(ContextKeyConn, sshConn) + applyConnMetadata(ctx, sshConn) + //go gossh.DiscardRequests(reqs) + go srv.handleRequests(ctx, reqs) + for ch := range chans { + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] + } + if handler == nil { + ch.Reject(gossh.UnknownChannelType, "unsupported channel type") + continue + } + go handler(srv, sshConn, ch, ctx) + } +} + +func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { + for req := range in { + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + /*reqCtx, cancel := context.WithCancel(ctx) + defer cancel() */ + ret, payload := handler(ctx, srv, req) + req.Reply(ret, payload) + } +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls +// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. +// ListenAndServe always returns a non-nil error. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":22" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +// AddHostKey adds a private key as a host key. If an existing host key exists +// with the same algorithm, it is overwritten. Each server config must have at +// least one host key. +func (srv *Server) AddHostKey(key Signer) { + srv.mu.Lock() + defer srv.mu.Unlock() + + // these are later added via AddHostKey on ServerConfig, which performs the + // check for one of every algorithm. + + // This check is based on the AddHostKey method from the x/crypto/ssh + // library. This allows us to only keep one active key for each type on a + // server at once. So, if you're dynamically updating keys at runtime, this + // list will not keep growing. + for i, k := range srv.HostSigners { + if k.PublicKey().Type() == key.PublicKey().Type() { + srv.HostSigners[i] = key + return + } + } + + srv.HostSigners = append(srv.HostSigners, key) +} + +// SetOption runs a functional option against the server. +func (srv *Server) SetOption(option Option) error { + // NOTE: there is a potential race here for any option that doesn't call an + // internal method. We can't actually lock here because if something calls + // (as an example) AddHostKey, it will deadlock. + + //srv.mu.Lock() + //defer srv.mu.Unlock() + + return option(srv) +} + +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + + return srv.getDoneChanLocked() +} + +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + return srv.doneChan +} + +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by srv.mu. + close(ch) + } +} + +func (srv *Server) closeListenersLocked() error { + var err error + for ln := range srv.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(srv.listeners, ln) + } + return err +} + +func (srv *Server) trackListener(ln net.Listener, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.listeners == nil { + srv.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(srv.listeners) == 0 && len(srv.conns) == 0 { + srv.doneChan = nil + } + srv.listeners[ln] = struct{}{} + srv.listenerWg.Add(1) + } else { + delete(srv.listeners, ln) + srv.listenerWg.Done() + } +} + +func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.conns == nil { + srv.conns = make(map[*gossh.ServerConn]struct{}) + } + if add { + srv.conns[c] = struct{}{} + srv.connWg.Add(1) + } else { + delete(srv.conns, c) + srv.connWg.Done() + } +} diff --git a/tempfork/gliderlabs/ssh/server_test.go b/tempfork/gliderlabs/ssh/server_test.go index 1a63bb4b2f3d5..177c071170c4e 100644 --- a/tempfork/gliderlabs/ssh/server_test.go +++ b/tempfork/gliderlabs/ssh/server_test.go @@ -1,128 +1,128 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "context" - "io" - "testing" - "time" -) - -func TestAddHostKey(t *testing.T) { - s := Server{} - signer, err := generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly added") - } - signer, err = generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly replaced") - } -} - -func TestServerShutdown(t *testing.T) { - l := newLocalListener() - testBytes := []byte("Hello world\n") - s := &Server{ - Handler: func(s Session) { - s.Write(testBytes) - time.Sleep(50 * time.Millisecond) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - sessDone := make(chan struct{}) - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(sessDone) - var stdout bytes.Buffer - sess.Stdout = &stdout - if err := sess.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) - } - }() - - srvDone := make(chan struct{}) - go func() { - defer close(srvDone) - err := s.Shutdown(context.Background()) - if err != nil { - t.Fatal(err) - } - }() - - timeout := time.After(2 * time.Second) - select { - case <-timeout: - t.Fatal("timeout") - return - case <-srvDone: - // TODO: add timeout for sessDone - <-sessDone - return - } -} - -func TestServerClose(t *testing.T) { - l := newLocalListener() - s := &Server{ - Handler: func(s Session) { - time.Sleep(5 * time.Second) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - - clientDoneChan := make(chan struct{}) - closeDoneChan := make(chan struct{}) - - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(clientDoneChan) - <-closeDoneChan - if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) - } - }() - - go func() { - err := s.Close() - if err != nil { - t.Fatal(err) - } - close(closeDoneChan) - }() - - timeout := time.After(100 * time.Millisecond) - select { - case <-timeout: - t.Error("timeout") - return - case <-s.getDoneChan(): - <-clientDoneChan - return - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "context" + "io" + "testing" + "time" +) + +func TestAddHostKey(t *testing.T) { + s := Server{} + signer, err := generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly added") + } + signer, err = generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly replaced") + } +} + +func TestServerShutdown(t *testing.T) { + l := newLocalListener() + testBytes := []byte("Hello world\n") + s := &Server{ + Handler: func(s Session) { + s.Write(testBytes) + time.Sleep(50 * time.Millisecond) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + sessDone := make(chan struct{}) + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(sessDone) + var stdout bytes.Buffer + sess.Stdout = &stdout + if err := sess.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + } + }() + + srvDone := make(chan struct{}) + go func() { + defer close(srvDone) + err := s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + }() + + timeout := time.After(2 * time.Second) + select { + case <-timeout: + t.Fatal("timeout") + return + case <-srvDone: + // TODO: add timeout for sessDone + <-sessDone + return + } +} + +func TestServerClose(t *testing.T) { + l := newLocalListener() + s := &Server{ + Handler: func(s Session) { + time.Sleep(5 * time.Second) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + + clientDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(clientDoneChan) + <-closeDoneChan + if err := sess.Run(""); err != nil && err != io.EOF { + t.Fatal(err) + } + }() + + go func() { + err := s.Close() + if err != nil { + t.Fatal(err) + } + close(closeDoneChan) + }() + + timeout := time.After(100 * time.Millisecond) + select { + case <-timeout: + t.Error("timeout") + return + case <-s.getDoneChan(): + <-clientDoneChan + return + } +} diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index 2f43de739d6d0..0a4a21e534401 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -1,386 +1,386 @@ -package ssh - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "sync" - - "github.com/anmitsu/go-shlex" - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// Session provides access to information about an SSH session and methods -// to read and write to the SSH channel with an embedded Channel interface from -// crypto/ssh. -// -// When Command() returns an empty slice, the user requested a shell. Otherwise -// the user is performing an exec with those command arguments. -// -// TODO: Signals -type Session interface { - gossh.Channel - - // User returns the username used when establishing the SSH connection. - User() string - - // RemoteAddr returns the net.Addr of the client side of the connection. - RemoteAddr() net.Addr - - // LocalAddr returns the net.Addr of the server side of the connection. - LocalAddr() net.Addr - - // Environ returns a copy of strings representing the environment set by the - // user for this session, in the form "key=value". - Environ() []string - - // Exit sends an exit status and then closes the session. - Exit(code int) error - - // Command returns a shell parsed slice of arguments that were provided by the - // user. Shell parsing splits the command string according to POSIX shell rules, - // which considers quoting not just whitespace. - Command() []string - - // RawCommand returns the exact command that was provided by the user. - RawCommand() string - - // Subsystem returns the subsystem requested by the user. - Subsystem() string - - // PublicKey returns the PublicKey used to authenticate. If a public key was not - // used it will return nil. - PublicKey() PublicKey - - // Context returns the connection's context. The returned context is always - // non-nil and holds the same data as the Context passed into auth - // handlers and callbacks. - // - // The context is canceled when the client's connection closes or I/O - // operation fails. - Context() context.Context - - // Permissions returns a copy of the Permissions object that was available for - // setup in the auth handlers via the Context. - Permissions() Permissions - - // Pty returns PTY information, a channel of window size changes, and a boolean - // of whether or not a PTY was accepted for this session. - Pty() (Pty, <-chan Window, bool) - - // Signals registers a channel to receive signals sent from the client. The - // channel must handle signal sends or it will block the SSH request loop. - // Registering nil will unregister the channel from signal sends. During the - // time no channel is registered signals are buffered up to a reasonable amount. - // If there are buffered signals when a channel is registered, they will be - // sent in order on the channel immediately after registering. - Signals(c chan<- Signal) - - // Break regisers a channel to receive notifications of break requests sent - // from the client. The channel must handle break requests, or it will block - // the request handling loop. Registering nil will unregister the channel. - // During the time that no channel is registered, breaks are ignored. - Break(c chan<- bool) - - // DisablePTYEmulation disables the session's default minimal PTY emulation. - // If you're setting the pty's termios settings from the Pty request, use - // this method to avoid corruption. - // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). - // A call of DisablePTYEmulation must precede any call to Write. - DisablePTYEmulation() -} - -// maxSigBufSize is how many signals will be buffered -// when there is no signal channel specified -const maxSigBufSize = 128 - -func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - ch, reqs, err := newChan.Accept() - if err != nil { - // TODO: trigger event callback - return - } - sess := &session{ - Channel: ch, - conn: conn, - handler: srv.Handler, - ptyCb: srv.PtyCallback, - sessReqCb: srv.SessionRequestCallback, - subsystemHandlers: srv.SubsystemHandlers, - ctx: ctx, - } - sess.handleRequests(reqs) -} - -type session struct { - sync.Mutex - gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context - sigCh chan<- Signal - sigBuf []Signal - breakCh chan<- bool - disablePtyEmulation bool -} - -func (sess *session) DisablePTYEmulation() { - sess.disablePtyEmulation = true -} - -func (sess *session) Write(p []byte) (n int, err error) { - if sess.pty != nil && !sess.disablePtyEmulation { - m := len(p) - // normalize \n to \r\n when pty is accepted. - // this is a hardcoded shortcut since we don't support terminal modes. - p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) - p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) - n, err = sess.Channel.Write(p) - if n > m { - n = m - } - return - } - return sess.Channel.Write(p) -} - -func (sess *session) PublicKey() PublicKey { - sessionkey := sess.ctx.Value(ContextKeyPublicKey) - if sessionkey == nil { - return nil - } - return sessionkey.(PublicKey) -} - -func (sess *session) Permissions() Permissions { - // use context permissions because its properly - // wrapped and easier to dereference - perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) - return *perms -} - -func (sess *session) Context() context.Context { - return sess.ctx -} - -func (sess *session) Exit(code int) error { - sess.Lock() - defer sess.Unlock() - if sess.exited { - return errors.New("Session.Exit called multiple times") - } - sess.exited = true - - status := struct{ Status uint32 }{uint32(code)} - _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) - if err != nil { - return err - } - return sess.Close() -} - -func (sess *session) User() string { - return sess.conn.User() -} - -func (sess *session) RemoteAddr() net.Addr { - return sess.conn.RemoteAddr() -} - -func (sess *session) LocalAddr() net.Addr { - return sess.conn.LocalAddr() -} - -func (sess *session) Environ() []string { - return append([]string(nil), sess.env...) -} - -func (sess *session) RawCommand() string { - return sess.rawCmd -} - -func (sess *session) Command() []string { - cmd, _ := shlex.Split(sess.rawCmd, true) - return append([]string(nil), cmd...) -} - -func (sess *session) Subsystem() string { - return sess.subsystem -} - -func (sess *session) Pty() (Pty, <-chan Window, bool) { - if sess.pty != nil { - return *sess.pty, sess.winch, true - } - return Pty{}, sess.winch, false -} - -func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() - sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() - } -} - -func (sess *session) Break(c chan<- bool) { - sess.Lock() - defer sess.Unlock() - sess.breakCh = c -} - -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - sess.handler(sess) - sess.Exit(0) - }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) - } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) - if !ok { - req.Reply(false, nil) - continue - } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "window-change": - if sess.pty == nil { - req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win - } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) - } - } -} +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/anmitsu/go-shlex" + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// Session provides access to information about an SSH session and methods +// to read and write to the SSH channel with an embedded Channel interface from +// crypto/ssh. +// +// When Command() returns an empty slice, the user requested a shell. Otherwise +// the user is performing an exec with those command arguments. +// +// TODO: Signals +type Session interface { + gossh.Channel + + // User returns the username used when establishing the SSH connection. + User() string + + // RemoteAddr returns the net.Addr of the client side of the connection. + RemoteAddr() net.Addr + + // LocalAddr returns the net.Addr of the server side of the connection. + LocalAddr() net.Addr + + // Environ returns a copy of strings representing the environment set by the + // user for this session, in the form "key=value". + Environ() []string + + // Exit sends an exit status and then closes the session. + Exit(code int) error + + // Command returns a shell parsed slice of arguments that were provided by the + // user. Shell parsing splits the command string according to POSIX shell rules, + // which considers quoting not just whitespace. + Command() []string + + // RawCommand returns the exact command that was provided by the user. + RawCommand() string + + // Subsystem returns the subsystem requested by the user. + Subsystem() string + + // PublicKey returns the PublicKey used to authenticate. If a public key was not + // used it will return nil. + PublicKey() PublicKey + + // Context returns the connection's context. The returned context is always + // non-nil and holds the same data as the Context passed into auth + // handlers and callbacks. + // + // The context is canceled when the client's connection closes or I/O + // operation fails. + Context() context.Context + + // Permissions returns a copy of the Permissions object that was available for + // setup in the auth handlers via the Context. + Permissions() Permissions + + // Pty returns PTY information, a channel of window size changes, and a boolean + // of whether or not a PTY was accepted for this session. + Pty() (Pty, <-chan Window, bool) + + // Signals registers a channel to receive signals sent from the client. The + // channel must handle signal sends or it will block the SSH request loop. + // Registering nil will unregister the channel from signal sends. During the + // time no channel is registered signals are buffered up to a reasonable amount. + // If there are buffered signals when a channel is registered, they will be + // sent in order on the channel immediately after registering. + Signals(c chan<- Signal) + + // Break regisers a channel to receive notifications of break requests sent + // from the client. The channel must handle break requests, or it will block + // the request handling loop. Registering nil will unregister the channel. + // During the time that no channel is registered, breaks are ignored. + Break(c chan<- bool) + + // DisablePTYEmulation disables the session's default minimal PTY emulation. + // If you're setting the pty's termios settings from the Pty request, use + // this method to avoid corruption. + // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). + // A call of DisablePTYEmulation must precede any call to Write. + DisablePTYEmulation() +} + +// maxSigBufSize is how many signals will be buffered +// when there is no signal channel specified +const maxSigBufSize = 128 + +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + ch, reqs, err := newChan.Accept() + if err != nil { + // TODO: trigger event callback + return + } + sess := &session{ + Channel: ch, + conn: conn, + handler: srv.Handler, + ptyCb: srv.PtyCallback, + sessReqCb: srv.SessionRequestCallback, + subsystemHandlers: srv.SubsystemHandlers, + ctx: ctx, + } + sess.handleRequests(reqs) +} + +type session struct { + sync.Mutex + gossh.Channel + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + winch chan Window + env []string + ptyCb PtyCallback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + sigCh chan<- Signal + sigBuf []Signal + breakCh chan<- bool + disablePtyEmulation bool +} + +func (sess *session) DisablePTYEmulation() { + sess.disablePtyEmulation = true +} + +func (sess *session) Write(p []byte) (n int, err error) { + if sess.pty != nil && !sess.disablePtyEmulation { + m := len(p) + // normalize \n to \r\n when pty is accepted. + // this is a hardcoded shortcut since we don't support terminal modes. + p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) + p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) + n, err = sess.Channel.Write(p) + if n > m { + n = m + } + return + } + return sess.Channel.Write(p) +} + +func (sess *session) PublicKey() PublicKey { + sessionkey := sess.ctx.Value(ContextKeyPublicKey) + if sessionkey == nil { + return nil + } + return sessionkey.(PublicKey) +} + +func (sess *session) Permissions() Permissions { + // use context permissions because its properly + // wrapped and easier to dereference + perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) + return *perms +} + +func (sess *session) Context() context.Context { + return sess.ctx +} + +func (sess *session) Exit(code int) error { + sess.Lock() + defer sess.Unlock() + if sess.exited { + return errors.New("Session.Exit called multiple times") + } + sess.exited = true + + status := struct{ Status uint32 }{uint32(code)} + _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) + if err != nil { + return err + } + return sess.Close() +} + +func (sess *session) User() string { + return sess.conn.User() +} + +func (sess *session) RemoteAddr() net.Addr { + return sess.conn.RemoteAddr() +} + +func (sess *session) LocalAddr() net.Addr { + return sess.conn.LocalAddr() +} + +func (sess *session) Environ() []string { + return append([]string(nil), sess.env...) +} + +func (sess *session) RawCommand() string { + return sess.rawCmd +} + +func (sess *session) Command() []string { + cmd, _ := shlex.Split(sess.rawCmd, true) + return append([]string(nil), cmd...) +} + +func (sess *session) Subsystem() string { + return sess.subsystem +} + +func (sess *session) Pty() (Pty, <-chan Window, bool) { + if sess.pty != nil { + return *sess.pty, sess.winch, true + } + return Pty{}, sess.winch, false +} + +func (sess *session) Signals(c chan<- Signal) { + sess.Lock() + defer sess.Unlock() + sess.sigCh = c + if len(sess.sigBuf) > 0 { + go func() { + for _, sig := range sess.sigBuf { + sess.sigCh <- sig + } + }() + } +} + +func (sess *session) Break(c chan<- bool) { + sess.Lock() + defer sess.Unlock() + sess.breakCh = c +} + +func (sess *session) handleRequests(reqs <-chan *gossh.Request) { + for req := range reqs { + switch req.Type { + case "shell", "exec": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.rawCmd = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + sess.handler(sess) + sess.Exit(0) + }() + case "subsystem": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.subsystem = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + handler := sess.subsystemHandlers[payload.Value] + if handler == nil { + handler = sess.subsystemHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + handler(sess) + sess.Exit(0) + }() + case "env": + if sess.handled { + req.Reply(false, nil) + continue + } + var kv struct{ Key, Value string } + gossh.Unmarshal(req.Payload, &kv) + sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) + case "signal": + var payload struct{ Signal string } + gossh.Unmarshal(req.Payload, &payload) + sess.Lock() + if sess.sigCh != nil { + sess.sigCh <- Signal(payload.Signal) + } else { + if len(sess.sigBuf) < maxSigBufSize { + sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + } + } + sess.Unlock() + case "pty-req": + if sess.handled || sess.pty != nil { + req.Reply(false, nil) + continue + } + ptyReq, ok := parsePtyRequest(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + if sess.ptyCb != nil { + ok := sess.ptyCb(sess.ctx, ptyReq) + if !ok { + req.Reply(false, nil) + continue + } + } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() + req.Reply(ok, nil) + case "window-change": + if sess.pty == nil { + req.Reply(false, nil) + continue + } + win, _, ok := parseWindow(req.Payload) + if ok { + sess.pty.Window = win + sess.winch <- win + } + req.Reply(ok, nil) + case agentRequestType: + // TODO: option/callback to allow agent forwarding + SetAgentRequested(sess.ctx) + req.Reply(true, nil) + case "break": + ok := false + sess.Lock() + if sess.breakCh != nil { + sess.breakCh <- true + ok = true + } + req.Reply(ok, nil) + sess.Unlock() + default: + // TODO: debug log + req.Reply(false, nil) + } + } +} diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index fddd67f6d41cc..a60be5ec12d4e 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -1,440 +1,440 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func (srv *Server) serveOnce(l net.Listener) error { - srv.ensureHandlers() - if err := srv.ensureHostSigner(); err != nil { - return err - } - conn, e := l.Accept() - if e != nil { - return e - } - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, - } - srv.HandleConn(conn) - return nil -} - -func newLocalListener() net.Listener { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) - } - } - return l -} - -func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - if config == nil { - config = &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - } - } - if config.HostKeyCallback == nil { - config.HostKeyCallback = gossh.InsecureIgnoreHostKey() - } - client, err := gossh.Dial("tcp", addr, config) - if err != nil { - t.Fatal(err) - } - session, err := client.NewSession() - if err != nil { - t.Fatal(err) - } - return session, client, func() { - session.Close() - client.Close() - } -} - -func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() - go srv.serveOnce(l) - return newClientSession(t, l.Addr().String(), cfg) -} - -func TestStdout(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Write(testBytes) - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) - } -} - -func TestStderr(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Stderr().Write(testBytes) - }, - }, nil) - defer cleanup() - var stderr bytes.Buffer - session.Stderr = &stderr - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stderr.Bytes(), testBytes) { - t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) - } -} - -func TestStdin(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.Copy(s, s) // stdin back into stdout - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stdin = bytes.NewBuffer(testBytes) - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) - } -} - -func TestUser(t *testing.T) { - t.Parallel() - testUser := []byte("progrium") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.WriteString(s, s.User()) - }, - }, &gossh.ClientConfig{ - User: string(testUser), - }) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testUser) { - t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) - } -} - -func TestDefaultExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExplicitExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(0) - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExitStatusNonZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(1) - }, - }, nil) - defer cleanup() - err := session.Run("") - e, ok := err.(*gossh.ExitError) - if !ok { - t.Fatalf("expected ExitError but got %T", err) - } - if e.ExitStatus() != 1 { - t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) - } -} - -func TestPty(t *testing.T) { - t.Parallel() - term := "xterm" - winWidth := 40 - winHeight := 80 - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, _, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Term != term { - t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) - } - if ptyReq.Window.Width != winWidth { - t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) - } - if ptyReq.Window.Height != winHeight { - t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) - } - close(done) - }, - }, nil) - defer cleanup() - if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - <-done -} - -func TestPtyResize(t *testing.T) { - t.Parallel() - winch0 := Window{Width: 40, Height: 80} - winch1 := Window{Width: 80, Height: 160} - winch2 := Window{Width: 20, Height: 40} - winches := make(chan Window) - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, winCh, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Window != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) - } - for win := range winCh { - winches <- win - } - close(done) - }, - }, nil) - defer cleanup() - // winch0 - if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - gotWinch := <-winches - if gotWinch != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) - } - // winch1 - winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} - ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch1 { - t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) - } - // winch2 - winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} - ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch2 { - t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) - } - session.Close() - <-done -} - -func TestSignals(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // We need to use a buffered channel here, otherwise it's possible for the - // second call to Signal to get discarded. - signals := make(chan Signal, 2) - s.Signals(signals) - - select { - case sig := <-signals: - if sig != SIGINT { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - - select { - case sig := <-signals: - if sig != SIGKILL { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - - go func() { - session.Signal(gossh.SIGINT) - session.Signal(gossh.SIGKILL) - }() - - go func() { - errChan <- session.Run("") - }() - - err := <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestBreakWithChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - breakChan := make(chan bool) - - readyToReceiveBreak := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Break(breakChan) // register a break channel with the session - readyToReceiveBreak <- true - - select { - case <-breakChan: - io.WriteString(s, "break") - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - <-readyToReceiveBreak - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != true { - t.Fatalf("expected true but got %v", ok) - } - - err = <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if !bytes.Equal(stdout.Bytes(), []byte("break")) { - t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) - } -} - -func TestBreakWithoutChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - waitUntilAfterBreakSent := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - <-waitUntilAfterBreakSent - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != false { - t.Fatalf("expected false but got %v", ok) - } - waitUntilAfterBreakSent <- true - - err = <-errChan - close(doneChan) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func (srv *Server) serveOnce(l net.Listener) error { + srv.ensureHandlers() + if err := srv.ensureHostSigner(); err != nil { + return err + } + conn, e := l.Accept() + if e != nil { + return e + } + srv.ChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + } + srv.HandleConn(conn) + return nil +} + +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on a port: %v", err)) + } + } + return l +} + +func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + if config == nil { + config = &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + } + } + if config.HostKeyCallback == nil { + config.HostKeyCallback = gossh.InsecureIgnoreHostKey() + } + client, err := gossh.Dial("tcp", addr, config) + if err != nil { + t.Fatal(err) + } + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + return session, client, func() { + session.Close() + client.Close() + } +} + +func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + l := newLocalListener() + go srv.serveOnce(l) + return newClientSession(t, l.Addr().String(), cfg) +} + +func TestStdout(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Write(testBytes) + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) + } +} + +func TestStderr(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Stderr().Write(testBytes) + }, + }, nil) + defer cleanup() + var stderr bytes.Buffer + session.Stderr = &stderr + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stderr.Bytes(), testBytes) { + t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) + } +} + +func TestStdin(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.Copy(s, s) // stdin back into stdout + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stdin = bytes.NewBuffer(testBytes) + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) + } +} + +func TestUser(t *testing.T) { + t.Parallel() + testUser := []byte("progrium") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.WriteString(s, s.User()) + }, + }, &gossh.ClientConfig{ + User: string(testUser), + }) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testUser) { + t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) + } +} + +func TestDefaultExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExplicitExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(0) + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExitStatusNonZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(1) + }, + }, nil) + defer cleanup() + err := session.Run("") + e, ok := err.(*gossh.ExitError) + if !ok { + t.Fatalf("expected ExitError but got %T", err) + } + if e.ExitStatus() != 1 { + t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) + } +} + +func TestPty(t *testing.T) { + t.Parallel() + term := "xterm" + winWidth := 40 + winHeight := 80 + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, _, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Term != term { + t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) + } + if ptyReq.Window.Width != winWidth { + t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) + } + if ptyReq.Window.Height != winHeight { + t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) + } + close(done) + }, + }, nil) + defer cleanup() + if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + <-done +} + +func TestPtyResize(t *testing.T) { + t.Parallel() + winch0 := Window{Width: 40, Height: 80} + winch1 := Window{Width: 80, Height: 160} + winch2 := Window{Width: 20, Height: 40} + winches := make(chan Window) + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, winCh, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Window != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) + } + for win := range winCh { + winches <- win + } + close(done) + }, + }, nil) + defer cleanup() + // winch0 + if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + gotWinch := <-winches + if gotWinch != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) + } + // winch1 + winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} + ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch1 { + t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) + } + // winch2 + winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} + ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch2 { + t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) + } + session.Close() + <-done +} + +func TestSignals(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // We need to use a buffered channel here, otherwise it's possible for the + // second call to Signal to get discarded. + signals := make(chan Signal, 2) + s.Signals(signals) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + + go func() { + session.Signal(gossh.SIGINT) + session.Signal(gossh.SIGKILL) + }() + + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestBreakWithChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + breakChan := make(chan bool) + + readyToReceiveBreak := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Break(breakChan) // register a break channel with the session + readyToReceiveBreak <- true + + select { + case <-breakChan: + io.WriteString(s, "break") + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + <-readyToReceiveBreak + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != true { + t.Fatalf("expected true but got %v", ok) + } + + err = <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if !bytes.Equal(stdout.Bytes(), []byte("break")) { + t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) + } +} + +func TestBreakWithoutChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + waitUntilAfterBreakSent := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + <-waitUntilAfterBreakSent + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != false { + t.Fatalf("expected false but got %v", ok) + } + waitUntilAfterBreakSent <- true + + err = <-errChan + close(doneChan) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 4216ea97ab932..644cb257d9afa 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -1,156 +1,156 @@ -package ssh - -import ( - "crypto/subtle" - "net" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -type Signal string - -// POSIX signals as listed in RFC 4254 Section 6.10. -const ( - SIGABRT Signal = "ABRT" - SIGALRM Signal = "ALRM" - SIGFPE Signal = "FPE" - SIGHUP Signal = "HUP" - SIGILL Signal = "ILL" - SIGINT Signal = "INT" - SIGKILL Signal = "KILL" - SIGPIPE Signal = "PIPE" - SIGQUIT Signal = "QUIT" - SIGSEGV Signal = "SEGV" - SIGTERM Signal = "TERM" - SIGUSR1 Signal = "USR1" - SIGUSR2 Signal = "USR2" -) - -// DefaultHandler is the default Handler used by Serve. -var DefaultHandler Handler - -// Option is a functional option handler for Server. -type Option func(*Server) error - -// Handler is a callback for handling established SSH sessions. -type Handler func(Session) - -// PublicKeyHandler is a callback for performing public key authentication. -type PublicKeyHandler func(ctx Context, key PublicKey) error - -type NoClientAuthHandler func(ctx Context) error - -type BannerHandler func(ctx Context) string - -// PasswordHandler is a callback for performing password authentication. -type PasswordHandler func(ctx Context, password string) bool - -// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. -type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool - -// PtyCallback is a hook for allowing PTY sessions. -type PtyCallback func(ctx Context, pty Pty) bool - -// SessionRequestCallback is a callback for allowing or denying SSH sessions. -type SessionRequestCallback func(sess Session, requestType string) bool - -// ConnCallback is a hook for new connections before handling. -// It allows wrapping for timeouts and limiting by returning -// the net.Conn that will be used as the underlying connection. -type ConnCallback func(ctx Context, conn net.Conn) net.Conn - -// LocalPortForwardingCallback is a hook for allowing port forwarding -type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool - -// ReversePortForwardingCallback is a hook for allowing reverse port forwarding -type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool - -// ServerConfigCallback is a hook for creating custom default server configs -type ServerConfigCallback func(ctx Context) *gossh.ServerConfig - -// ConnectionFailedCallback is a hook for reporting failed connections -// Please note: the net.Conn is likely to be closed at this point -type ConnectionFailedCallback func(conn net.Conn, err error) - -// Window represents the size of a PTY window. -// -// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 -// -// Zero dimension parameters MUST be ignored. The character/row dimensions -// override the pixel dimensions (when nonzero). Pixel dimensions refer -// to the drawable area of the window. -type Window struct { - // Width is the number of columns. - // It overrides WidthPixels. - Width int - // Height is the number of rows. - // It overrides HeightPixels. - Height int - - // WidthPixels is the drawable width of the window, in pixels. - WidthPixels int - // HeightPixels is the drawable height of the window, in pixels. - HeightPixels int -} - -// Pty represents a PTY request and configuration. -type Pty struct { - // Term is the TERM environment variable value. - Term string - - // Window is the Window sent as part of the pty-req. - Window Window - - // Modes represent a mapping of Terminal Mode opcode to value as it was - // requested by the client as part of the pty-req. These are outlined as - // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. - // - // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). - // Boolean opcodes have values 0 or 1. - Modes gossh.TerminalModes -} - -// Serve accepts incoming SSH connections on the listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and -// then calls handler to handle sessions. Handler is typically nil, in which -// case the DefaultHandler is used. -func Serve(l net.Listener, handler Handler, options ...Option) error { - srv := &Server{Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.Serve(l) -} - -// ListenAndServe listens on the TCP network address addr and then calls Serve -// with handler to handle sessions on incoming connections. Handler is typically -// nil, in which case the DefaultHandler is used. -func ListenAndServe(addr string, handler Handler, options ...Option) error { - srv := &Server{Addr: addr, Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.ListenAndServe() -} - -// Handle registers the handler as the DefaultHandler. -func Handle(handler Handler) { - DefaultHandler = handler -} - -// KeysEqual is constant time compare of the keys to avoid timing attacks. -func KeysEqual(ak, bk PublicKey) bool { - - //avoid panic if one of the keys is nil, return false instead - if ak == nil || bk == nil { - return false - } - - a := ak.Marshal() - b := bk.Marshal() - return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) -} +package ssh + +import ( + "crypto/subtle" + "net" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +// DefaultHandler is the default Handler used by Serve. +var DefaultHandler Handler + +// Option is a functional option handler for Server. +type Option func(*Server) error + +// Handler is a callback for handling established SSH sessions. +type Handler func(Session) + +// PublicKeyHandler is a callback for performing public key authentication. +type PublicKeyHandler func(ctx Context, key PublicKey) error + +type NoClientAuthHandler func(ctx Context) error + +type BannerHandler func(ctx Context) string + +// PasswordHandler is a callback for performing password authentication. +type PasswordHandler func(ctx Context, password string) bool + +// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. +type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool + +// PtyCallback is a hook for allowing PTY sessions. +type PtyCallback func(ctx Context, pty Pty) bool + +// SessionRequestCallback is a callback for allowing or denying SSH sessions. +type SessionRequestCallback func(sess Session, requestType string) bool + +// ConnCallback is a hook for new connections before handling. +// It allows wrapping for timeouts and limiting by returning +// the net.Conn that will be used as the underlying connection. +type ConnCallback func(ctx Context, conn net.Conn) net.Conn + +// LocalPortForwardingCallback is a hook for allowing port forwarding +type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool + +// ReversePortForwardingCallback is a hook for allowing reverse port forwarding +type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool + +// ServerConfigCallback is a hook for creating custom default server configs +type ServerConfigCallback func(ctx Context) *gossh.ServerConfig + +// ConnectionFailedCallback is a hook for reporting failed connections +// Please note: the net.Conn is likely to be closed at this point +type ConnectionFailedCallback func(conn net.Conn, err error) + +// Window represents the size of a PTY window. +// +// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 +// +// Zero dimension parameters MUST be ignored. The character/row dimensions +// override the pixel dimensions (when nonzero). Pixel dimensions refer +// to the drawable area of the window. +type Window struct { + // Width is the number of columns. + // It overrides WidthPixels. + Width int + // Height is the number of rows. + // It overrides HeightPixels. + Height int + + // WidthPixels is the drawable width of the window, in pixels. + WidthPixels int + // HeightPixels is the drawable height of the window, in pixels. + HeightPixels int +} + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the Window sent as part of the pty-req. + Window Window + + // Modes represent a mapping of Terminal Mode opcode to value as it was + // requested by the client as part of the pty-req. These are outlined as + // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. + // + // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). + // Boolean opcodes have values 0 or 1. + Modes gossh.TerminalModes +} + +// Serve accepts incoming SSH connections on the listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and +// then calls handler to handle sessions. Handler is typically nil, in which +// case the DefaultHandler is used. +func Serve(l net.Listener, handler Handler, options ...Option) error { + srv := &Server{Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.Serve(l) +} + +// ListenAndServe listens on the TCP network address addr and then calls Serve +// with handler to handle sessions on incoming connections. Handler is typically +// nil, in which case the DefaultHandler is used. +func ListenAndServe(addr string, handler Handler, options ...Option) error { + srv := &Server{Addr: addr, Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.ListenAndServe() +} + +// Handle registers the handler as the DefaultHandler. +func Handle(handler Handler) { + DefaultHandler = handler +} + +// KeysEqual is constant time compare of the keys to avoid timing attacks. +func KeysEqual(ak, bk PublicKey) bool { + + //avoid panic if one of the keys is nil, return false instead + if ak == nil || bk == nil { + return false + } + + a := ak.Marshal() + b := bk.Marshal() + return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) +} diff --git a/tempfork/gliderlabs/ssh/ssh_test.go b/tempfork/gliderlabs/ssh/ssh_test.go index 8772c03adea53..aa301b0489f21 100644 --- a/tempfork/gliderlabs/ssh/ssh_test.go +++ b/tempfork/gliderlabs/ssh/ssh_test.go @@ -1,17 +1,17 @@ -package ssh - -import ( - "testing" -) - -func TestKeysEqual(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("The code did panic") - } - }() - - if KeysEqual(nil, nil) { - t.Error("two nil keys should not return true") - } -} +package ssh + +import ( + "testing" +) + +func TestKeysEqual(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("The code did panic") + } + }() + + if KeysEqual(nil, nil) { + t.Error("two nil keys should not return true") + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index d30bb15ac284b..056a0c7343daf 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -1,193 +1,193 @@ -package ssh - -import ( - "io" - "log" - "net" - "strconv" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - forwardedTCPChannelType = "forwarded-tcpip" -) - -// direct-tcpip data struct as specified in RFC4254, Section 7.2 -type localForwardChannelData struct { - DestAddr string - DestPort uint32 - - OriginAddr string - OriginPort uint32 -} - -// DirectTCPIPHandler can be enabled by adding it to the server's -// ChannelHandlers under direct-tcpip. -func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - d := localForwardChannelData{} - if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { - newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) - return - } - - if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { - newChan.Reject(gossh.Prohibited, "port forwarding is disabled") - return - } - - dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) - - var dialer net.Dialer - dconn, err := dialer.DialContext(ctx, "tcp", dest) - if err != nil { - newChan.Reject(gossh.ConnectionFailed, err.Error()) - return - } - - ch, reqs, err := newChan.Accept() - if err != nil { - dconn.Close() - return - } - go gossh.DiscardRequests(reqs) - - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() -} - -type remoteForwardRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardSuccess struct { - BindPort uint32 -} - -type remoteForwardCancelRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardChannelData struct { - DestAddr string - DestPort uint32 - OriginAddr string - OriginPort uint32 -} - -// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and -// adding the HandleSSHRequest callback to the server's RequestHandlers under -// tcpip-forward and cancel-tcpip-forward. -type ForwardedTCPHandler struct { - forwards map[string]net.Listener - sync.Mutex -} - -func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { - h.Lock() - if h.forwards == nil { - h.forwards = make(map[string]net.Listener) - } - h.Unlock() - conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) - switch req.Type { - case "tcpip-forward": - var reqPayload remoteForwardRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - ln, err := net.Listen("tcp", addr) - if err != nil { - // TODO: log listen failure - return false, []byte{} - } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) - destPort, _ := strconv.Atoi(destPortStr) - h.Lock() - h.forwards[addr] = ln - h.Unlock() - go func() { - <-ctx.Done() - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - }() - go func() { - for { - c, err := ln.Accept() - if err != nil { - // TODO: log accept failure - break - } - originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) - originPort, _ := strconv.Atoi(orignPortStr) - payload := gossh.Marshal(&remoteForwardChannelData{ - DestAddr: reqPayload.BindAddr, - DestPort: uint32(destPort), - OriginAddr: originAddr, - OriginPort: uint32(originPort), - }) - go func() { - ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) - if err != nil { - // TODO: log failure to open channel - log.Println(err) - c.Close() - return - } - go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() - }() - } - h.Lock() - delete(h.forwards, addr) - h.Unlock() - }() - return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) - - case "cancel-tcpip-forward": - var reqPayload remoteForwardCancelRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - return true, nil - default: - return false, nil - } -} +package ssh + +import ( + "io" + "log" + "net" + "strconv" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + forwardedTCPChannelType = "forwarded-tcpip" +) + +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) + return + } + + if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { + newChan.Reject(gossh.Prohibited, "port forwarding is disabled") + return + } + + dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + newChan.Reject(gossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(ch, dconn) + }() + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(dconn, ch) + }() +} + +type remoteForwardRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardSuccess struct { + BindPort uint32 +} + +type remoteForwardCancelRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardChannelData struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 +} + +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// tcpip-forward and cancel-tcpip-forward. +type ForwardedTCPHandler struct { + forwards map[string]net.Listener + sync.Mutex +} + +func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + switch req.Type { + case "tcpip-forward": + var reqPayload remoteForwardRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { + return false, []byte("port forwarding is disabled") + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + ln, err := net.Listen("tcp", addr) + if err != nil { + // TODO: log listen failure + return false, []byte{} + } + _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + destPort, _ := strconv.Atoi(destPortStr) + h.Lock() + h.forwards[addr] = ln + h.Unlock() + go func() { + <-ctx.Done() + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + // TODO: log accept failure + break + } + originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) + originPort, _ := strconv.Atoi(orignPortStr) + payload := gossh.Marshal(&remoteForwardChannelData{ + DestAddr: reqPayload.BindAddr, + DestPort: uint32(destPort), + OriginAddr: originAddr, + OriginPort: uint32(originPort), + }) + go func() { + ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) + if err != nil { + // TODO: log failure to open channel + log.Println(err) + c.Close() + return + } + go gossh.DiscardRequests(reqs) + go func() { + defer ch.Close() + defer c.Close() + io.Copy(ch, c) + }() + go func() { + defer ch.Close() + defer c.Close() + io.Copy(c, ch) + }() + }() + } + h.Lock() + delete(h.forwards, addr) + h.Unlock() + }() + return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) + + case "cancel-tcpip-forward": + var reqPayload remoteForwardCancelRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + return true, nil + default: + return false, nil + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index e1d74d566c7bf..118b5d53ac4a1 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -1,85 +1,85 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "io" - "net" - "strconv" - "strings" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -var sampleServerResponse = []byte("Hello world") - -func sampleSocketServer() net.Listener { - l := newLocalListener() - - go func() { - conn, err := l.Accept() - if err != nil { - return - } - conn.Write(sampleServerResponse) - conn.Close() - }() - - return l -} - -func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() - - _, client, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) {}, - LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { - addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) - if addr != l.Addr().String() { - panic("unexpected destinationHost: " + addr) - } - return forwardingEnabled - }, - }, nil) - - return l, client, func() { - cleanup() - l.Close() - } -} - -func TestLocalPortForwardingWorks(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, true) - defer cleanup() - - conn, err := client.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) - } - result, err := io.ReadAll(conn) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, sampleServerResponse) { - t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) - } -} - -func TestLocalPortForwardingRespectsCallback(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, false) - defer cleanup() - - _, err := client.Dial("tcp", l.Addr().String()) - if err == nil { - t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) - } - if !strings.Contains(err.Error(), "port forwarding is disabled") { - t.Fatalf("Expected permission error but got %#v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "io" + "net" + "strconv" + "strings" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +var sampleServerResponse = []byte("Hello world") + +func sampleSocketServer() net.Listener { + l := newLocalListener() + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleSocketServer() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { + addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) + if addr != l.Addr().String() { + panic("unexpected destinationHost: " + addr) + } + return forwardingEnabled + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalPortForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalPortForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, false) + defer cleanup() + + _, err := client.Dial("tcp", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "port forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go index 7a6a1824109bf..e3b5716a3ab55 100644 --- a/tempfork/gliderlabs/ssh/util.go +++ b/tempfork/gliderlabs/ssh/util.go @@ -1,157 +1,157 @@ -package ssh - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/binary" - - "github.com/tailscale/golang-x-crypto/ssh" -) - -func generateSigner() (ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - return ssh.NewSignerFromKey(key) -} - -func parsePtyRequest(payload []byte) (pty Pty, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 - // 6.2. Requesting a Pseudo-Terminal - // A pseudo-terminal can be allocated for the session by sending the - // following message. - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "pty-req" - // boolean want_reply - // string TERM environment variable value (e.g., vt100) - // uint32 terminal width, characters (e.g., 80) - // uint32 terminal height, rows (e.g., 24) - // uint32 terminal width, pixels (e.g., 640) - // uint32 terminal height, pixels (e.g., 480) - // string encoded terminal modes - - // The payload starts from the TERM variable. - term, rem, ok := parseString(payload) - if !ok { - return - } - win, rem, ok := parseWindow(rem) - if !ok { - return - } - modes, ok := parseTerminalModes(rem) - if !ok { - return - } - pty = Pty{ - Term: term, - Window: win, - Modes: modes, - } - return -} - -func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 - // 8. Encoding of Terminal Modes - // - // All 'encoded terminal modes' (as passed in a pty request) are encoded - // into a byte stream. It is intended that the coding be portable - // across different environments. The stream consists of opcode- - // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 - // have a single uint32 argument. Opcodes 160 to 255 are not yet - // defined, and cause parsing to stop (they should only be used after - // any other data). The stream is terminated by opcode TTY_OP_END - // (0x00). - // - // The client SHOULD put any modes it knows about in the stream, and the - // server MAY ignore any modes it does not know about. This allows some - // degree of machine-independence, at least between systems that use a - // POSIX-like tty interface. The protocol can support other systems as - // well, but the client may need to fill reasonable values for a number - // of parameters so the server pty gets set to a reasonable mode (the - // server leaves all unspecified mode bits in their default values, and - // only some combinations make sense). - _, rem, ok := parseUint32(in) - if !ok { - return - } - const ttyOpEnd = 0 - for len(rem) > 0 { - if modes == nil { - modes = make(ssh.TerminalModes) - } - code := uint8(rem[0]) - rem = rem[1:] - if code == ttyOpEnd || code > 160 { - break - } - var val uint32 - val, rem, ok = parseUint32(rem) - if !ok { - return - } - modes[code] = val - } - ok = true - return -} - -func parseWindow(s []byte) (win Window, rem []byte, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 - // 6.7. Window Dimension Change Message - // When the window (terminal) size changes on the client side, it MAY - // send a message to the other side to inform it of the new dimensions. - - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "window-change" - // boolean FALSE - // uint32 terminal width, columns - // uint32 terminal height, rows - // uint32 terminal width, pixels - // uint32 terminal height, pixels - wCols, rem, ok := parseUint32(s) - if !ok { - return - } - hRows, rem, ok := parseUint32(rem) - if !ok { - return - } - wPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - hPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - win = Window{ - Width: int(wCols), - Height: int(hRows), - WidthPixels: int(wPixels), - HeightPixels: int(hPixels), - } - return -} - -func parseString(in []byte) (out string, rem []byte, ok bool) { - length, rem, ok := parseUint32(in) - if uint32(len(rem)) < length || !ok { - ok = false - return - } - out, rem = string(rem[:length]), rem[length:] - ok = true - return -} - -func parseUint32(in []byte) (uint32, []byte, bool) { - if len(in) < 4 { - return 0, nil, false - } - return binary.BigEndian.Uint32(in), in[4:], true -} +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/binary" + + "github.com/tailscale/golang-x-crypto/ssh" +) + +func generateSigner() (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +func parsePtyRequest(payload []byte) (pty Pty, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 + // 6.2. Requesting a Pseudo-Terminal + // A pseudo-terminal can be allocated for the session by sending the + // following message. + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "pty-req" + // boolean want_reply + // string TERM environment variable value (e.g., vt100) + // uint32 terminal width, characters (e.g., 80) + // uint32 terminal height, rows (e.g., 24) + // uint32 terminal width, pixels (e.g., 640) + // uint32 terminal height, pixels (e.g., 480) + // string encoded terminal modes + + // The payload starts from the TERM variable. + term, rem, ok := parseString(payload) + if !ok { + return + } + win, rem, ok := parseWindow(rem) + if !ok { + return + } + modes, ok := parseTerminalModes(rem) + if !ok { + return + } + pty = Pty{ + Term: term, + Window: win, + Modes: modes, + } + return +} + +func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 + // 8. Encoding of Terminal Modes + // + // All 'encoded terminal modes' (as passed in a pty request) are encoded + // into a byte stream. It is intended that the coding be portable + // across different environments. The stream consists of opcode- + // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 + // have a single uint32 argument. Opcodes 160 to 255 are not yet + // defined, and cause parsing to stop (they should only be used after + // any other data). The stream is terminated by opcode TTY_OP_END + // (0x00). + // + // The client SHOULD put any modes it knows about in the stream, and the + // server MAY ignore any modes it does not know about. This allows some + // degree of machine-independence, at least between systems that use a + // POSIX-like tty interface. The protocol can support other systems as + // well, but the client may need to fill reasonable values for a number + // of parameters so the server pty gets set to a reasonable mode (the + // server leaves all unspecified mode bits in their default values, and + // only some combinations make sense). + _, rem, ok := parseUint32(in) + if !ok { + return + } + const ttyOpEnd = 0 + for len(rem) > 0 { + if modes == nil { + modes = make(ssh.TerminalModes) + } + code := uint8(rem[0]) + rem = rem[1:] + if code == ttyOpEnd || code > 160 { + break + } + var val uint32 + val, rem, ok = parseUint32(rem) + if !ok { + return + } + modes[code] = val + } + ok = true + return +} + +func parseWindow(s []byte) (win Window, rem []byte, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 + // 6.7. Window Dimension Change Message + // When the window (terminal) size changes on the client side, it MAY + // send a message to the other side to inform it of the new dimensions. + + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "window-change" + // boolean FALSE + // uint32 terminal width, columns + // uint32 terminal height, rows + // uint32 terminal width, pixels + // uint32 terminal height, pixels + wCols, rem, ok := parseUint32(s) + if !ok { + return + } + hRows, rem, ok := parseUint32(rem) + if !ok { + return + } + wPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + hPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + win = Window{ + Width: int(wCols), + Height: int(hRows), + WidthPixels: int(wPixels), + HeightPixels: int(hPixels), + } + return +} + +func parseString(in []byte) (out string, rem []byte, ok bool) { + length, rem, ok := parseUint32(in) + if uint32(len(rem)) < length || !ok { + ok = false + return + } + out, rem = string(rem[:length]), rem[length:] + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go index f44f5d9bff299..17867d7518dd1 100644 --- a/tempfork/gliderlabs/ssh/wrap.go +++ b/tempfork/gliderlabs/ssh/wrap.go @@ -1,33 +1,33 @@ -package ssh - -import gossh "github.com/tailscale/golang-x-crypto/ssh" - -// PublicKey is an abstraction of different types of public keys. -type PublicKey interface { - gossh.PublicKey -} - -// The Permissions type holds fine-grained permissions that are specific to a -// user or a specific authentication method for a user. Permissions, except for -// "source-address", must be enforced in the server application layer, after -// successful authentication. -type Permissions struct { - *gossh.Permissions -} - -// A Signer can create signatures that verify against a public key. -type Signer interface { - gossh.Signer -} - -// ParseAuthorizedKey parses a public key from an authorized_keys file used in -// OpenSSH according to the sshd(8) manual page. -func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { - return gossh.ParseAuthorizedKey(in) -} - -// ParsePublicKey parses an SSH public key formatted for use in -// the SSH wire protocol according to RFC 4253, section 6.6. -func ParsePublicKey(in []byte) (out PublicKey, err error) { - return gossh.ParsePublicKey(in) -} +package ssh + +import gossh "github.com/tailscale/golang-x-crypto/ssh" + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + gossh.PublicKey +} + +// The Permissions type holds fine-grained permissions that are specific to a +// user or a specific authentication method for a user. Permissions, except for +// "source-address", must be enforced in the server application layer, after +// successful authentication. +type Permissions struct { + *gossh.Permissions +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + gossh.Signer +} + +// ParseAuthorizedKey parses a public key from an authorized_keys file used in +// OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + return gossh.ParseAuthorizedKey(in) +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + return gossh.ParsePublicKey(in) +} diff --git a/tempfork/heap/heap.go b/tempfork/heap/heap.go index 080b80ca5f7f0..3dfab492ad0b8 100644 --- a/tempfork/heap/heap.go +++ b/tempfork/heap/heap.go @@ -1,121 +1,121 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package heap provides heap operations for any type that implements -// heap.Interface. A heap is a tree with the property that each node is the -// minimum-valued node in its subtree. -// -// The minimum element in the tree is the root, at index 0. -// -// A heap is a common way to implement a priority queue. To build a priority -// queue, implement the Heap interface with the (negative) priority as the -// ordering for the Less method, so Push adds items while Pop removes the -// highest-priority item from the queue. The Examples include such an -// implementation; the file example_pq_test.go has the complete source. -// -// This package is a copy of the Go standard library's -// container/heap, but using generics. -package heap - -import "sort" - -// The Interface type describes the requirements -// for a type using the routines in this package. -// Any type that implements it may be used as a -// min-heap with the following invariants (established after -// Init has been called or if the data is empty or sorted): -// -// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() -// -// Note that Push and Pop in this interface are for package heap's -// implementation to call. To add and remove things from the heap, -// use heap.Push and heap.Pop. -type Interface[V any] interface { - sort.Interface - Push(x V) // add x as element Len() - Pop() V // remove and return element Len() - 1. -} - -// Init establishes the heap invariants required by the other routines in this package. -// Init is idempotent with respect to the heap invariants -// and may be called whenever the heap invariants may have been invalidated. -// The complexity is O(n) where n = h.Len(). -func Init[V any](h Interface[V]) { - // heapify - n := h.Len() - for i := n/2 - 1; i >= 0; i-- { - down(h, i, n) - } -} - -// Push pushes the element x onto the heap. -// The complexity is O(log n) where n = h.Len(). -func Push[V any](h Interface[V], x V) { - h.Push(x) - up(h, h.Len()-1) -} - -// Pop removes and returns the minimum element (according to Less) from the heap. -// The complexity is O(log n) where n = h.Len(). -// Pop is equivalent to Remove(h, 0). -func Pop[V any](h Interface[V]) V { - n := h.Len() - 1 - h.Swap(0, n) - down(h, 0, n) - return h.Pop() -} - -// Remove removes and returns the element at index i from the heap. -// The complexity is O(log n) where n = h.Len(). -func Remove[V any](h Interface[V], i int) V { - n := h.Len() - 1 - if n != i { - h.Swap(i, n) - if !down(h, i, n) { - up(h, i) - } - } - return h.Pop() -} - -// Fix re-establishes the heap ordering after the element at index i has changed its value. -// Changing the value of the element at index i and then calling Fix is equivalent to, -// but less expensive than, calling Remove(h, i) followed by a Push of the new value. -// The complexity is O(log n) where n = h.Len(). -func Fix[V any](h Interface[V], i int) { - if !down(h, i, h.Len()) { - up(h, i) - } -} - -func up[V any](h Interface[V], j int) { - for { - i := (j - 1) / 2 // parent - if i == j || !h.Less(j, i) { - break - } - h.Swap(i, j) - j = i - } -} - -func down[V any](h Interface[V], i0, n int) bool { - i := i0 - for { - j1 := 2*i + 1 - if j1 >= n || j1 < 0 { // j1 < 0 after int overflow - break - } - j := j1 // left child - if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { - j = j2 // = 2*i + 2 // right child - } - if !h.Less(j, i) { - break - } - h.Swap(i, j) - i = j - } - return i > i0 -} +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package heap provides heap operations for any type that implements +// heap.Interface. A heap is a tree with the property that each node is the +// minimum-valued node in its subtree. +// +// The minimum element in the tree is the root, at index 0. +// +// A heap is a common way to implement a priority queue. To build a priority +// queue, implement the Heap interface with the (negative) priority as the +// ordering for the Less method, so Push adds items while Pop removes the +// highest-priority item from the queue. The Examples include such an +// implementation; the file example_pq_test.go has the complete source. +// +// This package is a copy of the Go standard library's +// container/heap, but using generics. +package heap + +import "sort" + +// The Interface type describes the requirements +// for a type using the routines in this package. +// Any type that implements it may be used as a +// min-heap with the following invariants (established after +// Init has been called or if the data is empty or sorted): +// +// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() +// +// Note that Push and Pop in this interface are for package heap's +// implementation to call. To add and remove things from the heap, +// use heap.Push and heap.Pop. +type Interface[V any] interface { + sort.Interface + Push(x V) // add x as element Len() + Pop() V // remove and return element Len() - 1. +} + +// Init establishes the heap invariants required by the other routines in this package. +// Init is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = h.Len(). +func Init[V any](h Interface[V]) { + // heapify + n := h.Len() + for i := n/2 - 1; i >= 0; i-- { + down(h, i, n) + } +} + +// Push pushes the element x onto the heap. +// The complexity is O(log n) where n = h.Len(). +func Push[V any](h Interface[V], x V) { + h.Push(x) + up(h, h.Len()-1) +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = h.Len(). +// Pop is equivalent to Remove(h, 0). +func Pop[V any](h Interface[V]) V { + n := h.Len() - 1 + h.Swap(0, n) + down(h, 0, n) + return h.Pop() +} + +// Remove removes and returns the element at index i from the heap. +// The complexity is O(log n) where n = h.Len(). +func Remove[V any](h Interface[V], i int) V { + n := h.Len() - 1 + if n != i { + h.Swap(i, n) + if !down(h, i, n) { + up(h, i) + } + } + return h.Pop() +} + +// Fix re-establishes the heap ordering after the element at index i has changed its value. +// Changing the value of the element at index i and then calling Fix is equivalent to, +// but less expensive than, calling Remove(h, i) followed by a Push of the new value. +// The complexity is O(log n) where n = h.Len(). +func Fix[V any](h Interface[V], i int) { + if !down(h, i, h.Len()) { + up(h, i) + } +} + +func up[V any](h Interface[V], j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.Less(j, i) { + break + } + h.Swap(i, j) + j = i + } +} + +func down[V any](h Interface[V], i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.Less(j, i) { + break + } + h.Swap(i, j) + i = j + } + return i > i0 +} diff --git a/tka/aum_test.go b/tka/aum_test.go index 84b5674776319..4297efabff13f 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -1,253 +1,253 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/crypto/blake2s" - "tailscale.com/types/tkatype" -) - -func TestSerialization(t *testing.T) { - uint2 := uint(2) - var fakeAUMHash AUMHash - - tcs := []struct { - Name string - AUM AUM - Expect []byte - }{ - { - "AddKey", - AUM{MessageKind: AUMAddKey, Key: &Key{}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Key) - 0xa3, // |- major type 5 (map), 3 items (type Key) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0xf6, // |- major type 7 (val), value null (third value, nil) - }, - }, - { - "RemoveKey", - AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - }, - }, - { - "UpdateKey", - AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, - []byte{ - 0xa5, // major type 5 (map), 5 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) - 0x02, // |- major type 0 (int), value 2 (forth value, 2) - 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) - 0xa1, // |- major type 5 (map), 1 item (map[string]string type) - 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) - 0x61, // |- byte 'a' - 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) - 0x62, // |- byte 'b' - }, - }, - { - "Checkpoint", - AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ - LastAUMHash: &fakeAUMHash, - Keys: []Key{ - {Kind: Key25519, Public: []byte{5, 6}}, - }, - }}, - append( - append([]byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0x42, // |- major type 2 (byte string), 2 items (second value) - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x05, // |- major type 0 (int), value 5 (third key, State) - 0xa3, // |- major type 5 (map), 3 items (third value, State type) - 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) - 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) - }, - bytes.Repeat([]byte{0}, 32)...), - []byte{ - 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Keys) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa3, // |- major type 5 (map), 3 items (Key type) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x01, // |- major type 0 (int), value 1 (first value, Key25519) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (second value, 0) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0x42, // |- major type 2 (byte string), 2 items (third value) - 0x05, // |- major type 0 (int), value 5 (byte 5) - 0x06, // |- major type 0 (int), value 6 (byte 6) - }...), - }, - { - "Signature", - AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x17, // |- major type 0 (int), value 22 (third key, Signatures) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa2, // |- major type 5 (map), 2 items (Signature type) - 0x01, // |- major type 0 (int), value 1 (first key, KeyID) - 0x41, // |- major type 2 (byte string), 1 item - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (second key, Signature) - 0xf6, // |- major type 7 (val), value null (second value, nil) - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - data := []byte(tc.AUM.Serialize()) - if diff := cmp.Diff(tc.Expect, data); diff != "" { - t.Errorf("serialization differs (-want, +got):\n%s", diff) - } - - var decodedAUM AUM - if err := decodedAUM.Unserialize(data); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { - t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestAUMWeight(t *testing.T) { - var fakeKeyID [blake2s.Size]byte - testingRand(t, 1).Read(fakeKeyID[:]) - - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub, _ = testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub, Votes: 2} - - tcs := []struct { - Name string - AUM AUM - State State - Want uint - }{ - { - "Empty", - AUM{}, - State{}, - 0, - }, - { - "Key unknown", - AUM{ - Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, - }, - State{}, - 0, - }, - { - "Unary key", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - { - "Multiple keys", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, - }, - State{ - Keys: []Key{key, key2}, - }, - 4, - }, - { - "Double use", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - got := tc.AUM.Weight(tc.State) - if got != tc.Want { - t.Errorf("Weight() = %d, want %d", got, tc.Want) - } - }) - } -} - -func TestAUMHashes(t *testing.T) { - // .Hash(): a hash over everything. - // .SigHash(): a hash over everything except the signatures. - // The signatures are over a hash of the AUM, so - // using SigHash() breaks this circularity. - - aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} - sigHash1 := aum.SigHash() - aumHash1 := aum.Hash() - - aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} - sigHash2 := aum.SigHash() - aumHash2 := aum.Hash() - if len(aum.Signatures) != 1 { - t.Error("signature was removed by one of the hash functions") - } - - if !bytes.Equal(sigHash1[:], sigHash1[:]) { - t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) - } - if bytes.Equal(aumHash1[:], aumHash2[:]) { - t.Error("aum hash didnt change") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/crypto/blake2s" + "tailscale.com/types/tkatype" +) + +func TestSerialization(t *testing.T) { + uint2 := uint(2) + var fakeAUMHash AUMHash + + tcs := []struct { + Name string + AUM AUM + Expect []byte + }{ + { + "AddKey", + AUM{MessageKind: AUMAddKey, Key: &Key{}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Key) + 0xa3, // |- major type 5 (map), 3 items (type Key) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0xf6, // |- major type 7 (val), value null (third value, nil) + }, + }, + { + "RemoveKey", + AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + }, + }, + { + "UpdateKey", + AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, + []byte{ + 0xa5, // major type 5 (map), 5 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) + 0x02, // |- major type 0 (int), value 2 (forth value, 2) + 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) + 0xa1, // |- major type 5 (map), 1 item (map[string]string type) + 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) + 0x61, // |- byte 'a' + 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) + 0x62, // |- byte 'b' + }, + }, + { + "Checkpoint", + AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ + LastAUMHash: &fakeAUMHash, + Keys: []Key{ + {Kind: Key25519, Public: []byte{5, 6}}, + }, + }}, + append( + append([]byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0x42, // |- major type 2 (byte string), 2 items (second value) + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x05, // |- major type 0 (int), value 5 (third key, State) + 0xa3, // |- major type 5 (map), 3 items (third value, State type) + 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) + 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) + }, + bytes.Repeat([]byte{0}, 32)...), + []byte{ + 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Keys) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa3, // |- major type 5 (map), 3 items (Key type) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x01, // |- major type 0 (int), value 1 (first value, Key25519) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (second value, 0) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0x42, // |- major type 2 (byte string), 2 items (third value) + 0x05, // |- major type 0 (int), value 5 (byte 5) + 0x06, // |- major type 0 (int), value 6 (byte 6) + }...), + }, + { + "Signature", + AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x17, // |- major type 0 (int), value 22 (third key, Signatures) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa2, // |- major type 5 (map), 2 items (Signature type) + 0x01, // |- major type 0 (int), value 1 (first key, KeyID) + 0x41, // |- major type 2 (byte string), 1 item + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (second key, Signature) + 0xf6, // |- major type 7 (val), value null (second value, nil) + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + data := []byte(tc.AUM.Serialize()) + if diff := cmp.Diff(tc.Expect, data); diff != "" { + t.Errorf("serialization differs (-want, +got):\n%s", diff) + } + + var decodedAUM AUM + if err := decodedAUM.Unserialize(data); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { + t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestAUMWeight(t *testing.T) { + var fakeKeyID [blake2s.Size]byte + testingRand(t, 1).Read(fakeKeyID[:]) + + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub, _ = testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub, Votes: 2} + + tcs := []struct { + Name string + AUM AUM + State State + Want uint + }{ + { + "Empty", + AUM{}, + State{}, + 0, + }, + { + "Key unknown", + AUM{ + Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, + }, + State{}, + 0, + }, + { + "Unary key", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + { + "Multiple keys", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, + }, + State{ + Keys: []Key{key, key2}, + }, + 4, + }, + { + "Double use", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + got := tc.AUM.Weight(tc.State) + if got != tc.Want { + t.Errorf("Weight() = %d, want %d", got, tc.Want) + } + }) + } +} + +func TestAUMHashes(t *testing.T) { + // .Hash(): a hash over everything. + // .SigHash(): a hash over everything except the signatures. + // The signatures are over a hash of the AUM, so + // using SigHash() breaks this circularity. + + aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} + sigHash1 := aum.SigHash() + aumHash1 := aum.Hash() + + aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} + sigHash2 := aum.SigHash() + aumHash2 := aum.Hash() + if len(aum.Signatures) != 1 { + t.Error("signature was removed by one of the hash functions") + } + + if !bytes.Equal(sigHash1[:], sigHash1[:]) { + t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) + } + if bytes.Equal(aumHash1[:], aumHash2[:]) { + t.Error("aum hash didnt change") + } +} diff --git a/tka/builder.go b/tka/builder.go index 19cd340f03823..c14ba2330ae0d 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -1,180 +1,180 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "fmt" - "os" - - "tailscale.com/types/tkatype" -) - -// Types implementing Signer can sign update messages. -type Signer interface { - // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. - SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) -} - -// UpdateBuilder implements a builder for changes to the tailnet -// key authority. -// -// Finalize must be called to compute the update messages, which -// must then be applied to all Authority objects using Inform(). -type UpdateBuilder struct { - a *Authority - signer Signer - - state State - parent AUMHash - - out []AUM -} - -func (b *UpdateBuilder) mkUpdate(update AUM) error { - prevHash := make([]byte, len(b.parent)) - copy(prevHash, b.parent[:]) - update.PrevAUMHash = prevHash - - if b.signer != nil { - sigs, err := b.signer.SignAUM(update.SigHash()) - if err != nil { - return fmt.Errorf("signing failed: %v", err) - } - update.Signatures = append(update.Signatures, sigs...) - } - if err := update.StaticValidate(); err != nil { - return fmt.Errorf("generated update was invalid: %v", err) - } - state, err := b.state.applyVerifiedAUM(update) - if err != nil { - return fmt.Errorf("update cannot be applied: %v", err) - } - - b.state = state - b.parent = update.Hash() - b.out = append(b.out, update) - return nil -} - -// AddKey adds a new key to the authority. -func (b *UpdateBuilder) AddKey(key Key) error { - keyID, err := key.ID() - if err != nil { - return err - } - - if _, err := b.state.GetKey(keyID); err == nil { - return fmt.Errorf("cannot add key %v: already exists", key) - } - return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) -} - -// RemoveKey removes a key from the authority. -func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) -} - -// SetKeyVote updates the number of votes of an existing key. -func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) -} - -// SetKeyMeta updates key-value metadata stored against an existing key. -// -// TODO(tom): Provide an API to update specific values rather than the whole -// map. -func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) -} - -func (b *UpdateBuilder) generateCheckpoint() error { - // Compute the checkpoint state. - state := b.a.state - for i, update := range b.out { - var err error - if state, err = state.applyVerifiedAUM(update); err != nil { - return fmt.Errorf("applying update %d: %v", i, err) - } - } - - // Checkpoints cant specify a parent AUM. - state.LastAUMHash = nil - return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) -} - -// checkpointEvery sets how often a checkpoint AUM should be generated. -const checkpointEvery = 50 - -// Finalize returns the set of update message to actuate the update. -func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { - var ( - needCheckpoint bool = true - cursor AUMHash = b.a.Head() - ) - for i := len(b.out); i < checkpointEvery; i++ { - aum, err := storage.AUM(cursor) - if err != nil { - if err == os.ErrNotExist { - // The available chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - return nil, fmt.Errorf("reading AUM: %v", err) - } - - if aum.MessageKind == AUMCheckpoint { - needCheckpoint = false - break - } - - parent, hasParent := aum.Parent() - if !hasParent { - // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - cursor = parent - } - - if needCheckpoint { - if err := b.generateCheckpoint(); err != nil { - return nil, fmt.Errorf("generating checkpoint: %v", err) - } - } - - // Check no AUMs were applied in the meantime - if len(b.out) > 0 { - if parent, _ := b.out[0].Parent(); parent != b.a.Head() { - return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) - } - } - return b.out, nil -} - -// NewUpdater returns a builder you can use to make changes to -// the tailnet key authority. -// -// The provided signer function, if non-nil, is called with each update -// to compute and apply signatures. -// -// Updates are specified by calling methods on the returned UpdatedBuilder. -// Call Finalize() when you are done to obtain the specific update messages -// which actuate the changes. -func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { - return &UpdateBuilder{ - a: a, - signer: signer, - parent: a.Head(), - state: a.state, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "fmt" + "os" + + "tailscale.com/types/tkatype" +) + +// Types implementing Signer can sign update messages. +type Signer interface { + // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. + SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) +} + +// UpdateBuilder implements a builder for changes to the tailnet +// key authority. +// +// Finalize must be called to compute the update messages, which +// must then be applied to all Authority objects using Inform(). +type UpdateBuilder struct { + a *Authority + signer Signer + + state State + parent AUMHash + + out []AUM +} + +func (b *UpdateBuilder) mkUpdate(update AUM) error { + prevHash := make([]byte, len(b.parent)) + copy(prevHash, b.parent[:]) + update.PrevAUMHash = prevHash + + if b.signer != nil { + sigs, err := b.signer.SignAUM(update.SigHash()) + if err != nil { + return fmt.Errorf("signing failed: %v", err) + } + update.Signatures = append(update.Signatures, sigs...) + } + if err := update.StaticValidate(); err != nil { + return fmt.Errorf("generated update was invalid: %v", err) + } + state, err := b.state.applyVerifiedAUM(update) + if err != nil { + return fmt.Errorf("update cannot be applied: %v", err) + } + + b.state = state + b.parent = update.Hash() + b.out = append(b.out, update) + return nil +} + +// AddKey adds a new key to the authority. +func (b *UpdateBuilder) AddKey(key Key) error { + keyID, err := key.ID() + if err != nil { + return err + } + + if _, err := b.state.GetKey(keyID); err == nil { + return fmt.Errorf("cannot add key %v: already exists", key) + } + return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) +} + +// RemoveKey removes a key from the authority. +func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) +} + +// SetKeyVote updates the number of votes of an existing key. +func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) +} + +// SetKeyMeta updates key-value metadata stored against an existing key. +// +// TODO(tom): Provide an API to update specific values rather than the whole +// map. +func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) +} + +func (b *UpdateBuilder) generateCheckpoint() error { + // Compute the checkpoint state. + state := b.a.state + for i, update := range b.out { + var err error + if state, err = state.applyVerifiedAUM(update); err != nil { + return fmt.Errorf("applying update %d: %v", i, err) + } + } + + // Checkpoints cant specify a parent AUM. + state.LastAUMHash = nil + return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) +} + +// checkpointEvery sets how often a checkpoint AUM should be generated. +const checkpointEvery = 50 + +// Finalize returns the set of update message to actuate the update. +func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { + var ( + needCheckpoint bool = true + cursor AUMHash = b.a.Head() + ) + for i := len(b.out); i < checkpointEvery; i++ { + aum, err := storage.AUM(cursor) + if err != nil { + if err == os.ErrNotExist { + // The available chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + return nil, fmt.Errorf("reading AUM: %v", err) + } + + if aum.MessageKind == AUMCheckpoint { + needCheckpoint = false + break + } + + parent, hasParent := aum.Parent() + if !hasParent { + // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + cursor = parent + } + + if needCheckpoint { + if err := b.generateCheckpoint(); err != nil { + return nil, fmt.Errorf("generating checkpoint: %v", err) + } + } + + // Check no AUMs were applied in the meantime + if len(b.out) > 0 { + if parent, _ := b.out[0].Parent(); parent != b.a.Head() { + return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) + } + } + return b.out, nil +} + +// NewUpdater returns a builder you can use to make changes to +// the tailnet key authority. +// +// The provided signer function, if non-nil, is called with each update +// to compute and apply signatures. +// +// Updates are specified by calling methods on the returned UpdatedBuilder. +// Call Finalize() when you are done to obtain the specific update messages +// which actuate the changes. +func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { + return &UpdateBuilder{ + a: a, + signer: signer, + parent: a.Head(), + state: a.state, + } +} diff --git a/tka/builder_test.go b/tka/builder_test.go index 758fb170c0b5e..666af9ad07daf 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -1,270 +1,270 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/tkatype" -) - -type signer25519 ed25519.PrivateKey - -func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { - priv := ed25519.PrivateKey(s) - key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} - - return []tkatype.Signature{{ - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }}, nil -} - -func TestAuthorityBuilderAddKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Errorf("could not read new key: %v", err) - } -} - -func TestAuthorityBuilderRemoveKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key, key2}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.RemoveKey(key2.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the key has been removed. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderSetKeyVote(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyVote(key.MustID(), 5); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(5); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } -} - -func TestAuthorityBuilderSetKeyMeta(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { - t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { - t.Errorf("updated meta differs (-want, +got):\n%s", diff) - } -} - -func TestAuthorityBuilderMultiple(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - if err := b.SetKeyVote(key2.MustID(), 42); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) - } - if err := b.RemoveKey(key.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key2.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(42); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } - if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - for i := 0; i <= checkpointEvery; i++ { - pub2, _ := testingKey25519(t, int64(i+2)) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Fatal(err) - } - - wantKind := AUMAddKey - if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) - wantKind = AUMCheckpoint - } - lastAUM, err := storage.AUM(a.Head()) - if err != nil { - t.Fatal(err) - } - if lastAUM.MessageKind != wantKind { - t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) - } - } - - // Try starting an authority just based on storage. - a2, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open from stored AUMs: %v", err) - } - if a.Head() != a2.Head() { - t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/tkatype" +) + +type signer25519 ed25519.PrivateKey + +func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { + priv := ed25519.PrivateKey(s) + key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} + + return []tkatype.Signature{{ + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }}, nil +} + +func TestAuthorityBuilderAddKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Errorf("could not read new key: %v", err) + } +} + +func TestAuthorityBuilderRemoveKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key, key2}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.RemoveKey(key2.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the key has been removed. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderSetKeyVote(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyVote(key.MustID(), 5); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(5); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } +} + +func TestAuthorityBuilderSetKeyMeta(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { + t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { + t.Errorf("updated meta differs (-want, +got):\n%s", diff) + } +} + +func TestAuthorityBuilderMultiple(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + if err := b.SetKeyVote(key2.MustID(), 42); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) + } + if err := b.RemoveKey(key.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key2.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(42); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } + if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + for i := 0; i <= checkpointEvery; i++ { + pub2, _ := testingKey25519(t, int64(i+2)) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Fatal(err) + } + + wantKind := AUMAddKey + if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) + wantKind = AUMCheckpoint + } + lastAUM, err := storage.AUM(a.Head()) + if err != nil { + t.Fatal(err) + } + if lastAUM.MessageKind != wantKind { + t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) + } + } + + // Try starting an authority just based on storage. + a2, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open from stored AUMs: %v", err) + } + if a.Head() != a2.Head() { + t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) + } +} diff --git a/tka/deeplink.go b/tka/deeplink.go index 97bcd664b23ec..5cf24fc5c2c82 100644 --- a/tka/deeplink.go +++ b/tka/deeplink.go @@ -1,221 +1,221 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "net/url" - "strings" -) - -const ( - DeeplinkTailscaleURLScheme = "tailscale" - DeeplinkCommandSign = "sign-device" -) - -// generateHMAC computes a SHA-256 HMAC for the concatenation of components, -// using the Authority stateID as secret. -func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { - stateID, _ := a.StateIDs() - - key := make([]byte, 8) - binary.LittleEndian.PutUint64(key, stateID) - mac := hmac.New(sha256.New, key) - mac.Write([]byte(params.NodeKey)) - mac.Write([]byte(params.TLPub)) - mac.Write([]byte(params.DeviceName)) - mac.Write([]byte(params.OSName)) - mac.Write([]byte(params.LoginName)) - return mac.Sum(nil) -} - -type NewDeeplinkParams struct { - NodeKey string - TLPub string - DeviceName string - OSName string - LoginName string -} - -// NewDeeplink creates a signed deeplink using the authority's stateID as a -// secret. This deeplink can then be validated by ValidateDeeplink. -func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { - if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { - return "", fmt.Errorf("invalid node key %q", params.NodeKey) - } - if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { - return "", fmt.Errorf("invalid tlpub %q", params.TLPub) - } - if params.DeviceName == "" { - return "", fmt.Errorf("invalid device name %q", params.DeviceName) - } - if params.OSName == "" { - return "", fmt.Errorf("invalid os name %q", params.OSName) - } - if params.LoginName == "" { - return "", fmt.Errorf("invalid login name %q", params.LoginName) - } - - u := url.URL{ - Scheme: DeeplinkTailscaleURLScheme, - Host: DeeplinkCommandSign, - Path: "/v1/", - } - v := url.Values{} - v.Set("nk", params.NodeKey) - v.Set("tp", params.TLPub) - v.Set("dn", params.DeviceName) - v.Set("os", params.OSName) - v.Set("em", params.LoginName) - - hmac := a.generateHMAC(params) - v.Set("hm", hex.EncodeToString(hmac)) - - u.RawQuery = v.Encode() - return u.String(), nil -} - -type DeeplinkValidationResult struct { - IsValid bool - Error string - Version uint8 - NodeKey string - TLPub string - DeviceName string - OSName string - EmailAddress string -} - -// ValidateDeeplink validates a device signing deeplink using the authority's stateID. -// The input urlString follows this structure: -// -// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx -// -// where: -// - "nk" is the nodekey of the node being signed -// - "tp" is the tailnet lock public key -// - "dn" is the name of the node -// - "os" is the operating system of the node -// - "em" is the email address associated with the node -// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string -func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { - parsedUrl, err := url.Parse(urlString) - if err != nil { - return DeeplinkValidationResult{ - IsValid: false, - Error: err.Error(), - } - } - - if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), - } - } - - if parsedUrl.Host != DeeplinkCommandSign { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), - } - } - - path := parsedUrl.EscapedPath() - pathComponents := strings.Split(path, "/") - if len(pathComponents) != 3 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "invalid path components number found", - } - } - - if pathComponents[1] != "v1" { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), - } - } - - nodeKey := parsedUrl.Query().Get("nk") - if len(nodeKey) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing nk (NodeKey) query parameter", - } - } - - tlPub := parsedUrl.Query().Get("tp") - if len(tlPub) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing tp (TLPub) query parameter", - } - } - - deviceName := parsedUrl.Query().Get("dn") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing dn (DeviceName) query parameter", - } - } - - osName := parsedUrl.Query().Get("os") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing os (OSName) query parameter", - } - } - - emailAddress := parsedUrl.Query().Get("em") - if len(emailAddress) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing em (EmailAddress) query parameter", - } - } - - hmacString := parsedUrl.Query().Get("hm") - if len(hmacString) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing hm (HMAC) query parameter", - } - } - - computedHMAC := a.generateHMAC(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: emailAddress, - }) - - hmacHexBytes, err := hex.DecodeString(hmacString) - if err != nil { - return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} - } - - if !hmac.Equal(computedHMAC, hmacHexBytes) { - return DeeplinkValidationResult{ - IsValid: false, - Error: "hmac authentication failed", - } - } - - return DeeplinkValidationResult{ - IsValid: true, - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - EmailAddress: emailAddress, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "strings" +) + +const ( + DeeplinkTailscaleURLScheme = "tailscale" + DeeplinkCommandSign = "sign-device" +) + +// generateHMAC computes a SHA-256 HMAC for the concatenation of components, +// using the Authority stateID as secret. +func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { + stateID, _ := a.StateIDs() + + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, stateID) + mac := hmac.New(sha256.New, key) + mac.Write([]byte(params.NodeKey)) + mac.Write([]byte(params.TLPub)) + mac.Write([]byte(params.DeviceName)) + mac.Write([]byte(params.OSName)) + mac.Write([]byte(params.LoginName)) + return mac.Sum(nil) +} + +type NewDeeplinkParams struct { + NodeKey string + TLPub string + DeviceName string + OSName string + LoginName string +} + +// NewDeeplink creates a signed deeplink using the authority's stateID as a +// secret. This deeplink can then be validated by ValidateDeeplink. +func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { + if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { + return "", fmt.Errorf("invalid node key %q", params.NodeKey) + } + if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { + return "", fmt.Errorf("invalid tlpub %q", params.TLPub) + } + if params.DeviceName == "" { + return "", fmt.Errorf("invalid device name %q", params.DeviceName) + } + if params.OSName == "" { + return "", fmt.Errorf("invalid os name %q", params.OSName) + } + if params.LoginName == "" { + return "", fmt.Errorf("invalid login name %q", params.LoginName) + } + + u := url.URL{ + Scheme: DeeplinkTailscaleURLScheme, + Host: DeeplinkCommandSign, + Path: "/v1/", + } + v := url.Values{} + v.Set("nk", params.NodeKey) + v.Set("tp", params.TLPub) + v.Set("dn", params.DeviceName) + v.Set("os", params.OSName) + v.Set("em", params.LoginName) + + hmac := a.generateHMAC(params) + v.Set("hm", hex.EncodeToString(hmac)) + + u.RawQuery = v.Encode() + return u.String(), nil +} + +type DeeplinkValidationResult struct { + IsValid bool + Error string + Version uint8 + NodeKey string + TLPub string + DeviceName string + OSName string + EmailAddress string +} + +// ValidateDeeplink validates a device signing deeplink using the authority's stateID. +// The input urlString follows this structure: +// +// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx +// +// where: +// - "nk" is the nodekey of the node being signed +// - "tp" is the tailnet lock public key +// - "dn" is the name of the node +// - "os" is the operating system of the node +// - "em" is the email address associated with the node +// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string +func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { + parsedUrl, err := url.Parse(urlString) + if err != nil { + return DeeplinkValidationResult{ + IsValid: false, + Error: err.Error(), + } + } + + if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), + } + } + + if parsedUrl.Host != DeeplinkCommandSign { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), + } + } + + path := parsedUrl.EscapedPath() + pathComponents := strings.Split(path, "/") + if len(pathComponents) != 3 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "invalid path components number found", + } + } + + if pathComponents[1] != "v1" { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), + } + } + + nodeKey := parsedUrl.Query().Get("nk") + if len(nodeKey) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing nk (NodeKey) query parameter", + } + } + + tlPub := parsedUrl.Query().Get("tp") + if len(tlPub) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing tp (TLPub) query parameter", + } + } + + deviceName := parsedUrl.Query().Get("dn") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing dn (DeviceName) query parameter", + } + } + + osName := parsedUrl.Query().Get("os") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing os (OSName) query parameter", + } + } + + emailAddress := parsedUrl.Query().Get("em") + if len(emailAddress) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing em (EmailAddress) query parameter", + } + } + + hmacString := parsedUrl.Query().Get("hm") + if len(hmacString) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing hm (HMAC) query parameter", + } + } + + computedHMAC := a.generateHMAC(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: emailAddress, + }) + + hmacHexBytes, err := hex.DecodeString(hmacString) + if err != nil { + return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} + } + + if !hmac.Equal(computedHMAC, hmacHexBytes) { + return DeeplinkValidationResult{ + IsValid: false, + Error: "hmac authentication failed", + } + } + + return DeeplinkValidationResult{ + IsValid: true, + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + EmailAddress: emailAddress, + } +} diff --git a/tka/deeplink_test.go b/tka/deeplink_test.go index 397cc6917f289..03523202fed8b 100644 --- a/tka/deeplink_test.go +++ b/tka/deeplink_test.go @@ -1,52 +1,52 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "testing" -) - -func TestGenerateDeeplink(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - a, _ := Open(c.Chonk()) - - nodeKey := "nodekey:1234567890" - tlPub := "tlpub:1234567890" - deviceName := "Example Device" - osName := "iOS" - loginName := "insecure@example.com" - - deeplink, err := a.NewDeeplink(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: loginName, - }) - if err != nil { - t.Errorf("deeplink generation failed: %v", err) - } - - res := a.ValidateDeeplink(deeplink) - if !res.IsValid { - t.Errorf("deeplink validation failed: %s", res.Error) - } - if res.NodeKey != nodeKey { - t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) - } - if res.TLPub != tlPub { - t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "testing" +) + +func TestGenerateDeeplink(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + a, _ := Open(c.Chonk()) + + nodeKey := "nodekey:1234567890" + tlPub := "tlpub:1234567890" + deviceName := "Example Device" + osName := "iOS" + loginName := "insecure@example.com" + + deeplink, err := a.NewDeeplink(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: loginName, + }) + if err != nil { + t.Errorf("deeplink generation failed: %v", err) + } + + res := a.ValidateDeeplink(deeplink) + if !res.IsValid { + t.Errorf("deeplink validation failed: %s", res.Error) + } + if res.NodeKey != nodeKey { + t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) + } + if res.TLPub != tlPub { + t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) + } +} diff --git a/tka/key.go b/tka/key.go index 47218438d88ea..07736795d8e58 100644 --- a/tka/key.go +++ b/tka/key.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "errors" - "fmt" - - "github.com/hdevalence/ed25519consensus" - "tailscale.com/types/tkatype" -) - -// KeyKind describes the different varieties of a Key. -type KeyKind uint8 - -// Valid KeyKind values. -const ( - KeyInvalid KeyKind = iota - Key25519 -) - -func (k KeyKind) String() string { - switch k { - case KeyInvalid: - return "invalid" - case Key25519: - return "25519" - default: - return fmt.Sprintf("Key?<%d>", int(k)) - } -} - -// Key describes the public components of a key known to network-lock. -type Key struct { - Kind KeyKind `cbor:"1,keyasint"` - - // Votes describes the weight applied to signatures using this key. - // Weighting is used to deterministically resolve branches in the AUM - // chain (i.e. forks, where two AUMs exist with the same parent). - Votes uint `cbor:"2,keyasint"` - - // Public encodes the public key of the key. For 25519 keys, - // this is simply the point on the curve representing the public - // key. - Public []byte `cbor:"3,keyasint"` - - // Meta describes arbitrary metadata about the key. This could be - // used to store the name of the key, for instance. - Meta map[string]string `cbor:"12,keyasint,omitempty"` -} - -// Clone makes an independent copy of Key. -// -// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, -// so an implementation of Clone() must take care to preserve this. -func (k Key) Clone() Key { - out := k - - if k.Public != nil { - out.Public = make([]byte, len(k.Public)) - copy(out.Public, k.Public) - } - - if k.Meta != nil { - out.Meta = make(map[string]string, len(k.Meta)) - for k, v := range k.Meta { - out.Meta[k] = v - } - } - - return out -} - -// MustID returns the KeyID of the key, panicking if an error is -// encountered. This must only be used for tests. -func (k Key) MustID() tkatype.KeyID { - id, err := k.ID() - if err != nil { - panic(err) - } - return id -} - -// ID returns the KeyID of the key. -func (k Key) ID() (tkatype.KeyID, error) { - switch k.Kind { - // Because 25519 public keys are so short, we just use the 32-byte - // public as their 'key ID'. - case Key25519: - return tkatype.KeyID(k.Public), nil - default: - return nil, fmt.Errorf("unknown key kind: %v", k.Kind) - } -} - -// Ed25519 returns the ed25519 public key encoded by Key. An error is -// returned for keys which do not represent ed25519 public keys. -func (k Key) Ed25519() (ed25519.PublicKey, error) { - switch k.Kind { - case Key25519: - return ed25519.PublicKey(k.Public), nil - default: - return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) - } -} - -const maxMetaBytes = 512 - -func (k Key) StaticValidate() error { - if k.Votes > 4096 { - return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) - } - if k.Votes == 0 { - return errors.New("key votes must be non-zero") - } - - // We have an arbitrary upper limit on the amount - // of metadata that can be associated with a key, so - // people don't start using it as a key-value store and - // causing pathological cases due to the number + size of - // AUMs. - var metaBytes uint - for k, v := range k.Meta { - metaBytes += uint(len(k) + len(v)) - } - if metaBytes > maxMetaBytes { - return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) - } - - switch k.Kind { - case Key25519: - default: - return fmt.Errorf("unrecognized key kind: %v", k.Kind) - } - return nil -} - -// Verify returns a nil error if the signature is valid over the -// provided AUM BLAKE2s digest, using the given key. -func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { - // NOTE(tom): Even if we can compute the public from the KeyID, - // its possible for the KeyID to be attacker-controlled - // so we should use the public contained in the state machine. - switch key.Kind { - case Key25519: - if len(key.Public) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) - } - if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { - return nil - } - return errors.New("invalid signature") - - default: - return fmt.Errorf("unhandled key type: %v", key.Kind) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "errors" + "fmt" + + "github.com/hdevalence/ed25519consensus" + "tailscale.com/types/tkatype" +) + +// KeyKind describes the different varieties of a Key. +type KeyKind uint8 + +// Valid KeyKind values. +const ( + KeyInvalid KeyKind = iota + Key25519 +) + +func (k KeyKind) String() string { + switch k { + case KeyInvalid: + return "invalid" + case Key25519: + return "25519" + default: + return fmt.Sprintf("Key?<%d>", int(k)) + } +} + +// Key describes the public components of a key known to network-lock. +type Key struct { + Kind KeyKind `cbor:"1,keyasint"` + + // Votes describes the weight applied to signatures using this key. + // Weighting is used to deterministically resolve branches in the AUM + // chain (i.e. forks, where two AUMs exist with the same parent). + Votes uint `cbor:"2,keyasint"` + + // Public encodes the public key of the key. For 25519 keys, + // this is simply the point on the curve representing the public + // key. + Public []byte `cbor:"3,keyasint"` + + // Meta describes arbitrary metadata about the key. This could be + // used to store the name of the key, for instance. + Meta map[string]string `cbor:"12,keyasint,omitempty"` +} + +// Clone makes an independent copy of Key. +// +// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, +// so an implementation of Clone() must take care to preserve this. +func (k Key) Clone() Key { + out := k + + if k.Public != nil { + out.Public = make([]byte, len(k.Public)) + copy(out.Public, k.Public) + } + + if k.Meta != nil { + out.Meta = make(map[string]string, len(k.Meta)) + for k, v := range k.Meta { + out.Meta[k] = v + } + } + + return out +} + +// MustID returns the KeyID of the key, panicking if an error is +// encountered. This must only be used for tests. +func (k Key) MustID() tkatype.KeyID { + id, err := k.ID() + if err != nil { + panic(err) + } + return id +} + +// ID returns the KeyID of the key. +func (k Key) ID() (tkatype.KeyID, error) { + switch k.Kind { + // Because 25519 public keys are so short, we just use the 32-byte + // public as their 'key ID'. + case Key25519: + return tkatype.KeyID(k.Public), nil + default: + return nil, fmt.Errorf("unknown key kind: %v", k.Kind) + } +} + +// Ed25519 returns the ed25519 public key encoded by Key. An error is +// returned for keys which do not represent ed25519 public keys. +func (k Key) Ed25519() (ed25519.PublicKey, error) { + switch k.Kind { + case Key25519: + return ed25519.PublicKey(k.Public), nil + default: + return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) + } +} + +const maxMetaBytes = 512 + +func (k Key) StaticValidate() error { + if k.Votes > 4096 { + return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) + } + if k.Votes == 0 { + return errors.New("key votes must be non-zero") + } + + // We have an arbitrary upper limit on the amount + // of metadata that can be associated with a key, so + // people don't start using it as a key-value store and + // causing pathological cases due to the number + size of + // AUMs. + var metaBytes uint + for k, v := range k.Meta { + metaBytes += uint(len(k) + len(v)) + } + if metaBytes > maxMetaBytes { + return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) + } + + switch k.Kind { + case Key25519: + default: + return fmt.Errorf("unrecognized key kind: %v", k.Kind) + } + return nil +} + +// Verify returns a nil error if the signature is valid over the +// provided AUM BLAKE2s digest, using the given key. +func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { + // NOTE(tom): Even if we can compute the public from the KeyID, + // its possible for the KeyID to be attacker-controlled + // so we should use the public contained in the state machine. + switch key.Kind { + case Key25519: + if len(key.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) + } + if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { + return nil + } + return errors.New("invalid signature") + + default: + return fmt.Errorf("unhandled key type: %v", key.Kind) + } +} diff --git a/tka/key_test.go b/tka/key_test.go index aaddb2f404f10..e912f89c4f7eb 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "crypto/ed25519" - "encoding/binary" - "math/rand" - "testing" - - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// returns a random source based on the test name + extraSeed. -func testingRand(t *testing.T, extraSeed int64) *rand.Rand { - var seed int64 - if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { - panic(err) - } - return rand.New(rand.NewSource(seed + extraSeed)) -} - -// generates a 25519 private key based on the seed + test name. -func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { - pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) - if err != nil { - panic(err) - } - return pub, priv -} - -func TestVerify25519(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{ - Kind: Key25519, - Public: pub, - } - - aum := AUM{ - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2, 3, 4}, - // Signatures is set to crap so we are sure its ignored in the sigHash computation. - Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, - } - sigHash := aum.SigHash() - aum.Signatures = []tkatype.Signature{ - { - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }, - } - - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { - t.Errorf("signature verification failed: %v", err) - } - - // Make sure it fails with a different public key. - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2} - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { - t.Error("signature verification with different key did not fail") - } -} - -func TestNLPrivate(t *testing.T) { - p := key.NewNLPrivate() - pub := p.Public() - - // Test that key.NLPrivate implements Signer by making a new - // authority. - k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(&Mem{}, State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - }, p) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - // Make sure the generated genesis AUM was signed. - if got, want := len(aum.Signatures), 1; got != want { - t.Fatalf("len(signatures) = %d, want %d", got, want) - } - sigHash := aum.SigHash() - if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { - t.Error("signature did not verify") - } - - // We manually compute the keyID, so make sure its consistent with - // tka.Key.ID(). - if !bytes.Equal(k.MustID(), p.KeyID()) { - t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "crypto/ed25519" + "encoding/binary" + "math/rand" + "testing" + + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// returns a random source based on the test name + extraSeed. +func testingRand(t *testing.T, extraSeed int64) *rand.Rand { + var seed int64 + if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { + panic(err) + } + return rand.New(rand.NewSource(seed + extraSeed)) +} + +// generates a 25519 private key based on the seed + test name. +func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) + if err != nil { + panic(err) + } + return pub, priv +} + +func TestVerify25519(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{ + Kind: Key25519, + Public: pub, + } + + aum := AUM{ + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2, 3, 4}, + // Signatures is set to crap so we are sure its ignored in the sigHash computation. + Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, + } + sigHash := aum.SigHash() + aum.Signatures = []tkatype.Signature{ + { + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }, + } + + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { + t.Errorf("signature verification failed: %v", err) + } + + // Make sure it fails with a different public key. + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2} + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { + t.Error("signature verification with different key did not fail") + } +} + +func TestNLPrivate(t *testing.T) { + p := key.NewNLPrivate() + pub := p.Public() + + // Test that key.NLPrivate implements Signer by making a new + // authority. + k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} + _, aum, err := Create(&Mem{}, State{ + Keys: []Key{k}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + }, p) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + // Make sure the generated genesis AUM was signed. + if got, want := len(aum.Signatures), 1; got != want { + t.Fatalf("len(signatures) = %d, want %d", got, want) + } + sigHash := aum.SigHash() + if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { + t.Error("signature did not verify") + } + + // We manually compute the keyID, so make sure its consistent with + // tka.Key.ID(). + if !bytes.Equal(k.MustID(), p.KeyID()) { + t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) + } +} diff --git a/tka/state.go b/tka/state.go index e99b731ccb2ad..0a459bd9a1b24 100644 --- a/tka/state.go +++ b/tka/state.go @@ -1,315 +1,315 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "errors" - "fmt" - - "golang.org/x/crypto/argon2" - "tailscale.com/types/tkatype" -) - -// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. -var ErrNoSuchKey = errors.New("key not found") - -// State describes Tailnet Key Authority state at an instant in time. -// -// State is mutated by applying Authority Update Messages (AUMs), resulting -// in a new State. -type State struct { - // LastAUMHash is the blake2s digest of the last-applied AUM. - // Because AUMs are strictly ordered and form a hash chain, we - // check the previous AUM hash in an update we are applying - // is the same as the LastAUMHash. - LastAUMHash *AUMHash `cbor:"1,keyasint"` - - // DisablementSecrets are KDF-derived values which can be used - // to turn off the TKA in the event of a consensus-breaking bug. - DisablementSecrets [][]byte `cbor:"2,keyasint"` - - // Keys are the public keys of either: - // - // 1. The signing nodes currently trusted by the TKA. - // 2. Ephemeral keys that were used to generate pre-signed auth keys. - Keys []Key `cbor:"3,keyasint"` - - // StateID's are nonce's, generated on enablement and fixed for - // the lifetime of the Tailnet Key Authority. We generate 16-bytes - // worth of keyspace here just in case we come up with a cool future - // use for this. - StateID1 uint64 `cbor:"4,keyasint,omitempty"` - StateID2 uint64 `cbor:"5,keyasint,omitempty"` -} - -// GetKey returns the trusted key with the specified KeyID. -func (s State) GetKey(key tkatype.KeyID) (Key, error) { - for _, k := range s.Keys { - keyID, err := k.ID() - if err != nil { - return Key{}, err - } - - if bytes.Equal(keyID, key) { - return k, nil - } - } - - return Key{}, ErrNoSuchKey -} - -// Clone makes an independent copy of State. -// -// NOTE: There is a difference between a nil slice and an empty -// slice for encoding purposes, so an implementation of Clone() -// must take care to preserve this. -func (s State) Clone() State { - out := State{ - StateID1: s.StateID1, - StateID2: s.StateID2, - } - - if s.LastAUMHash != nil { - dupe := *s.LastAUMHash - out.LastAUMHash = &dupe - } - - if s.DisablementSecrets != nil { - out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) - for i := range s.DisablementSecrets { - out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) - copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) - } - } - - if s.Keys != nil { - out.Keys = make([]Key, len(s.Keys)) - for i := range s.Keys { - out.Keys[i] = s.Keys[i].Clone() - } - } - - return out -} - -// cloneForUpdate is like Clone, except LastAUMHash is set based -// on the hash of the given update. -func (s State) cloneForUpdate(update *AUM) State { - out := s.Clone() - aumHash := update.Hash() - out.LastAUMHash = &aumHash - return out -} - -const disablementLength = 32 - -var disablementSalt = []byte("tailscale network-lock disablement salt") - -// DisablementKDF computes a public value which can be stored in a -// key authority, but cannot be reversed to find the input secret. -// -// When the output of this function is stored in tka state (i.e. in -// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() -// with the input of this function as the argument will return true. -func DisablementKDF(secret []byte) []byte { - // time = 4 (3 recommended, booped to 4 to compensate for less memory) - // memory = 16 (32 recommended) - // threads = 4 - // keyLen = 32 (256 bits) - return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) -} - -// checkDisablement returns true for a valid disablement secret. -func (s State) checkDisablement(secret []byte) bool { - derived := DisablementKDF(secret) - for _, candidate := range s.DisablementSecrets { - if bytes.Equal(derived, candidate) { - return true - } - } - return false -} - -// parentMatches returns true if an AUM can chain to (be applied) -// to the current state. -// -// Specifically, the rules are: -// - The last AUM hash must match (transitively, this implies that this -// update follows the last update message applied to the state machine) -// - Or, the state machine knows no parent (its brand new). -func (s State) parentMatches(update AUM) bool { - if s.LastAUMHash == nil { - return true - } - return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) -} - -// applyVerifiedAUM computes a new state based on the update provided. -// -// The provided update MUST be verified: That is, the AUM must be well-formed -// (as defined by StaticValidate()), and signatures over the AUM must have -// been verified. -func (s State) applyVerifiedAUM(update AUM) (State, error) { - // Validate that the update message has the right parent. - if !s.parentMatches(update) { - return State{}, errors.New("parent AUMHash mismatch") - } - - switch update.MessageKind { - case AUMNoOp: - out := s.cloneForUpdate(&update) - return out, nil - - case AUMCheckpoint: - if update.State == nil { - return State{}, errors.New("missing checkpoint state") - } - id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 - if !id1Match || !id2Match { - return State{}, errors.New("checkpointed state has an incorrect stateID") - } - return update.State.cloneForUpdate(&update), nil - - case AUMAddKey: - if update.Key == nil { - return State{}, errors.New("no key to add provided") - } - keyID, err := update.Key.ID() - if err != nil { - return State{}, err - } - if _, err := s.GetKey(keyID); err == nil { - return State{}, errors.New("key already exists") - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys, *update.Key) - return out, nil - - case AUMUpdateKey: - k, err := s.GetKey(update.KeyID) - if err != nil { - return State{}, err - } - if update.Votes != nil { - k.Votes = *update.Votes - } - if update.Meta != nil { - k.Meta = update.Meta - } - if err := k.StaticValidate(); err != nil { - return State{}, fmt.Errorf("updated key fails validation: %v", err) - } - out := s.cloneForUpdate(&update) - for i := range out.Keys { - keyID, err := out.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(keyID, update.KeyID) { - out.Keys[i] = k - } - } - return out, nil - - case AUMRemoveKey: - idx := -1 - for i := range s.Keys { - keyID, err := s.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(update.KeyID, keyID) { - idx = i - break - } - } - if idx < 0 { - return State{}, ErrNoSuchKey - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) - return out, nil - - default: - // An AUM with an unknown message kind was received! That means - // that a future version of tailscaled added some feature we don't - // understand. - // - // The future-compatibility contract for AUM message types is that - // they must only add new features, not change the semantics of existing - // mechanisms or features. As such, old clients can safely ignore them. - out := s.cloneForUpdate(&update) - return out, nil - } -} - -// Upper bound on checkpoint elements, chosen arbitrarily. Intended to -// cap out insanely large AUMs. -const ( - maxDisablementSecrets = 32 - maxKeys = 512 -) - -// staticValidateCheckpoint validates that the state is well-formed for -// inclusion in a checkpoint AUM. -func (s *State) staticValidateCheckpoint() error { - if s.LastAUMHash != nil { - return errors.New("cannot specify a parent AUM") - } - if len(s.DisablementSecrets) == 0 { - return errors.New("at least one disablement secret required") - } - if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { - return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) - } - for i, ds := range s.DisablementSecrets { - if len(ds) != disablementLength { - return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) - } - for j, ds2 := range s.DisablementSecrets { - if i == j { - continue - } - if bytes.Equal(ds, ds2) { - return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) - } - } - } - - if len(s.Keys) == 0 { - return errors.New("at least one key is required") - } - if numKeys := len(s.Keys); numKeys > maxKeys { - return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) - } - for i, k := range s.Keys { - if err := k.StaticValidate(); err != nil { - return fmt.Errorf("key[%d]: %v", i, err) - } - } - // NOTE: The max number of keys is constrained (512), so - // O(n^2) is fine. - for i, k := range s.Keys { - for j, k2 := range s.Keys { - if i == j { - continue - } - - id1, err := k.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", i, err) - } - id2, err := k2.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", j, err) - } - - if bytes.Equal(id1, id2) { - return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) - } - } - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "errors" + "fmt" + + "golang.org/x/crypto/argon2" + "tailscale.com/types/tkatype" +) + +// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. +var ErrNoSuchKey = errors.New("key not found") + +// State describes Tailnet Key Authority state at an instant in time. +// +// State is mutated by applying Authority Update Messages (AUMs), resulting +// in a new State. +type State struct { + // LastAUMHash is the blake2s digest of the last-applied AUM. + // Because AUMs are strictly ordered and form a hash chain, we + // check the previous AUM hash in an update we are applying + // is the same as the LastAUMHash. + LastAUMHash *AUMHash `cbor:"1,keyasint"` + + // DisablementSecrets are KDF-derived values which can be used + // to turn off the TKA in the event of a consensus-breaking bug. + DisablementSecrets [][]byte `cbor:"2,keyasint"` + + // Keys are the public keys of either: + // + // 1. The signing nodes currently trusted by the TKA. + // 2. Ephemeral keys that were used to generate pre-signed auth keys. + Keys []Key `cbor:"3,keyasint"` + + // StateID's are nonce's, generated on enablement and fixed for + // the lifetime of the Tailnet Key Authority. We generate 16-bytes + // worth of keyspace here just in case we come up with a cool future + // use for this. + StateID1 uint64 `cbor:"4,keyasint,omitempty"` + StateID2 uint64 `cbor:"5,keyasint,omitempty"` +} + +// GetKey returns the trusted key with the specified KeyID. +func (s State) GetKey(key tkatype.KeyID) (Key, error) { + for _, k := range s.Keys { + keyID, err := k.ID() + if err != nil { + return Key{}, err + } + + if bytes.Equal(keyID, key) { + return k, nil + } + } + + return Key{}, ErrNoSuchKey +} + +// Clone makes an independent copy of State. +// +// NOTE: There is a difference between a nil slice and an empty +// slice for encoding purposes, so an implementation of Clone() +// must take care to preserve this. +func (s State) Clone() State { + out := State{ + StateID1: s.StateID1, + StateID2: s.StateID2, + } + + if s.LastAUMHash != nil { + dupe := *s.LastAUMHash + out.LastAUMHash = &dupe + } + + if s.DisablementSecrets != nil { + out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) + for i := range s.DisablementSecrets { + out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) + copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) + } + } + + if s.Keys != nil { + out.Keys = make([]Key, len(s.Keys)) + for i := range s.Keys { + out.Keys[i] = s.Keys[i].Clone() + } + } + + return out +} + +// cloneForUpdate is like Clone, except LastAUMHash is set based +// on the hash of the given update. +func (s State) cloneForUpdate(update *AUM) State { + out := s.Clone() + aumHash := update.Hash() + out.LastAUMHash = &aumHash + return out +} + +const disablementLength = 32 + +var disablementSalt = []byte("tailscale network-lock disablement salt") + +// DisablementKDF computes a public value which can be stored in a +// key authority, but cannot be reversed to find the input secret. +// +// When the output of this function is stored in tka state (i.e. in +// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() +// with the input of this function as the argument will return true. +func DisablementKDF(secret []byte) []byte { + // time = 4 (3 recommended, booped to 4 to compensate for less memory) + // memory = 16 (32 recommended) + // threads = 4 + // keyLen = 32 (256 bits) + return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) +} + +// checkDisablement returns true for a valid disablement secret. +func (s State) checkDisablement(secret []byte) bool { + derived := DisablementKDF(secret) + for _, candidate := range s.DisablementSecrets { + if bytes.Equal(derived, candidate) { + return true + } + } + return false +} + +// parentMatches returns true if an AUM can chain to (be applied) +// to the current state. +// +// Specifically, the rules are: +// - The last AUM hash must match (transitively, this implies that this +// update follows the last update message applied to the state machine) +// - Or, the state machine knows no parent (its brand new). +func (s State) parentMatches(update AUM) bool { + if s.LastAUMHash == nil { + return true + } + return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) +} + +// applyVerifiedAUM computes a new state based on the update provided. +// +// The provided update MUST be verified: That is, the AUM must be well-formed +// (as defined by StaticValidate()), and signatures over the AUM must have +// been verified. +func (s State) applyVerifiedAUM(update AUM) (State, error) { + // Validate that the update message has the right parent. + if !s.parentMatches(update) { + return State{}, errors.New("parent AUMHash mismatch") + } + + switch update.MessageKind { + case AUMNoOp: + out := s.cloneForUpdate(&update) + return out, nil + + case AUMCheckpoint: + if update.State == nil { + return State{}, errors.New("missing checkpoint state") + } + id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 + if !id1Match || !id2Match { + return State{}, errors.New("checkpointed state has an incorrect stateID") + } + return update.State.cloneForUpdate(&update), nil + + case AUMAddKey: + if update.Key == nil { + return State{}, errors.New("no key to add provided") + } + keyID, err := update.Key.ID() + if err != nil { + return State{}, err + } + if _, err := s.GetKey(keyID); err == nil { + return State{}, errors.New("key already exists") + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys, *update.Key) + return out, nil + + case AUMUpdateKey: + k, err := s.GetKey(update.KeyID) + if err != nil { + return State{}, err + } + if update.Votes != nil { + k.Votes = *update.Votes + } + if update.Meta != nil { + k.Meta = update.Meta + } + if err := k.StaticValidate(); err != nil { + return State{}, fmt.Errorf("updated key fails validation: %v", err) + } + out := s.cloneForUpdate(&update) + for i := range out.Keys { + keyID, err := out.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(keyID, update.KeyID) { + out.Keys[i] = k + } + } + return out, nil + + case AUMRemoveKey: + idx := -1 + for i := range s.Keys { + keyID, err := s.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(update.KeyID, keyID) { + idx = i + break + } + } + if idx < 0 { + return State{}, ErrNoSuchKey + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) + return out, nil + + default: + // An AUM with an unknown message kind was received! That means + // that a future version of tailscaled added some feature we don't + // understand. + // + // The future-compatibility contract for AUM message types is that + // they must only add new features, not change the semantics of existing + // mechanisms or features. As such, old clients can safely ignore them. + out := s.cloneForUpdate(&update) + return out, nil + } +} + +// Upper bound on checkpoint elements, chosen arbitrarily. Intended to +// cap out insanely large AUMs. +const ( + maxDisablementSecrets = 32 + maxKeys = 512 +) + +// staticValidateCheckpoint validates that the state is well-formed for +// inclusion in a checkpoint AUM. +func (s *State) staticValidateCheckpoint() error { + if s.LastAUMHash != nil { + return errors.New("cannot specify a parent AUM") + } + if len(s.DisablementSecrets) == 0 { + return errors.New("at least one disablement secret required") + } + if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { + return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) + } + for i, ds := range s.DisablementSecrets { + if len(ds) != disablementLength { + return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) + } + for j, ds2 := range s.DisablementSecrets { + if i == j { + continue + } + if bytes.Equal(ds, ds2) { + return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) + } + } + } + + if len(s.Keys) == 0 { + return errors.New("at least one key is required") + } + if numKeys := len(s.Keys); numKeys > maxKeys { + return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) + } + for i, k := range s.Keys { + if err := k.StaticValidate(); err != nil { + return fmt.Errorf("key[%d]: %v", i, err) + } + } + // NOTE: The max number of keys is constrained (512), so + // O(n^2) is fine. + for i, k := range s.Keys { + for j, k2 := range s.Keys { + if i == j { + continue + } + + id1, err := k.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", i, err) + } + id2, err := k2.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", j, err) + } + + if bytes.Equal(id1, id2) { + return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) + } + } + } + return nil +} diff --git a/tka/state_test.go b/tka/state_test.go index b8337dd8a6cb8..060bd9350dd06 100644 --- a/tka/state_test.go +++ b/tka/state_test.go @@ -1,260 +1,260 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "encoding/hex" - "errors" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func fromHex(in string) []byte { - out, err := hex.DecodeString(in) - if err != nil { - panic(err) - } - return out -} - -func hashFromHex(in string) *AUMHash { - var out AUMHash - copy(out[:], fromHex(in)) - return &out -} - -func TestCloneState(t *testing.T) { - tcs := []struct { - Name string - State State - }{ - { - "Empty", - State{}, - }, - { - "Key", - State{ - Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, - }, - }, - { - "StateID", - State{ - StateID1: 42, - StateID2: 22, - }, - }, - { - "DisablementSecrets", - State{ - DisablementSecrets: [][]byte{ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { - t.Errorf("output state differs (-want, +got):\n%s", diff) - } - - // Make sure the cloned State is the same even after - // an encode + decode into + from CBOR. - t.Run("cbor", func(t *testing.T) { - out := bytes.NewBuffer(nil) - encoder, err := cbor.CTAP2EncOptions().EncMode() - if err != nil { - t.Fatal(err) - } - if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { - t.Fatal(err) - } - - var decodedState State - if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.State, decodedState); diff != "" { - t.Errorf("decoded state differs (-want, +got):\n%s", diff) - } - }) - }) - } -} - -func TestApplyUpdatesChain(t *testing.T) { - intOne := uint(1) - tcs := []struct { - Name string - Updates []AUM - Start State - End State - }{ - { - "AddKey", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - }, - { - "RemoveKey", - []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), - }, - }, - { - "UpdateKey", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), - Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, - }, - }, - { - "ChainedKeyUpdates", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), - }, - }, - { - "Checkpoint", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - t.Fatalf("Apply message[%d] failed: %v", i, err) - } - // t.Logf("update[%d] end-state = %+v", i, state) - - updateHash := tc.Updates[i].Hash() - if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { - t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) - } - } - - if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("output state differs (+got, -want):\n%s", diff) - } - }) - } -} - -func TestApplyUpdateErrors(t *testing.T) { - tooLargeVotes := uint(99999) - tcs := []struct { - Name string - Updates []AUM - Start State - Error error - }{ - { - "AddKey exists", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - errors.New("key already exists"), - }, - { - "RemoveKey notfound", - []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey notfound", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey now fails validation", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, - errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), - }, - { - "Bad lastAUMHash", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - errors.New("parent AUMHash mismatch"), - }, - { - "Bad StateID", - []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, - errors.New("checkpointed state has an incorrect stateID"), - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - if err.Error() != tc.Error.Error() { - t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) - } else { - return - } - } - // t.Logf("update[%d] end-state = %+v", i, state) - } - - t.Errorf("did not error, expected %v", tc.Error) - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "encoding/hex" + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func fromHex(in string) []byte { + out, err := hex.DecodeString(in) + if err != nil { + panic(err) + } + return out +} + +func hashFromHex(in string) *AUMHash { + var out AUMHash + copy(out[:], fromHex(in)) + return &out +} + +func TestCloneState(t *testing.T) { + tcs := []struct { + Name string + State State + }{ + { + "Empty", + State{}, + }, + { + "Key", + State{ + Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, + }, + }, + { + "StateID", + State{ + StateID1: 42, + StateID2: 22, + }, + }, + { + "DisablementSecrets", + State{ + DisablementSecrets: [][]byte{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { + t.Errorf("output state differs (-want, +got):\n%s", diff) + } + + // Make sure the cloned State is the same even after + // an encode + decode into + from CBOR. + t.Run("cbor", func(t *testing.T) { + out := bytes.NewBuffer(nil) + encoder, err := cbor.CTAP2EncOptions().EncMode() + if err != nil { + t.Fatal(err) + } + if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { + t.Fatal(err) + } + + var decodedState State + if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.State, decodedState); diff != "" { + t.Errorf("decoded state differs (-want, +got):\n%s", diff) + } + }) + }) + } +} + +func TestApplyUpdatesChain(t *testing.T) { + intOne := uint(1) + tcs := []struct { + Name string + Updates []AUM + Start State + End State + }{ + { + "AddKey", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + }, + { + "RemoveKey", + []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), + }, + }, + { + "UpdateKey", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), + Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, + }, + }, + { + "ChainedKeyUpdates", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), + }, + }, + { + "Checkpoint", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + t.Fatalf("Apply message[%d] failed: %v", i, err) + } + // t.Logf("update[%d] end-state = %+v", i, state) + + updateHash := tc.Updates[i].Hash() + if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { + t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) + } + } + + if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("output state differs (+got, -want):\n%s", diff) + } + }) + } +} + +func TestApplyUpdateErrors(t *testing.T) { + tooLargeVotes := uint(99999) + tcs := []struct { + Name string + Updates []AUM + Start State + Error error + }{ + { + "AddKey exists", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + errors.New("key already exists"), + }, + { + "RemoveKey notfound", + []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey notfound", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey now fails validation", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, + errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), + }, + { + "Bad lastAUMHash", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + errors.New("parent AUMHash mismatch"), + }, + { + "Bad StateID", + []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, + errors.New("checkpointed state has an incorrect stateID"), + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + if err.Error() != tc.Error.Error() { + t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) + } else { + return + } + } + // t.Logf("update[%d] end-state = %+v", i, state) + } + + t.Errorf("did not error, expected %v", tc.Error) + }) + } +} diff --git a/tka/sync_test.go b/tka/sync_test.go index d214020c41af4..7250eacf7d143 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -1,377 +1,377 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestSyncOffer(t *testing.T) { - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 - A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 - `) - storage := c.Chonk() - a, err := Open(storage) - if err != nil { - t.Fatal(err) - } - got, err := a.SyncOffer(storage) - if err != nil { - t.Fatal(err) - } - - // A SyncOffer includes a selection of AUMs going backwards in the tree, - // progressively skipping more and more each iteration. - want := SyncOffer{ - Head: c.AUMHashes["A25"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - `) - a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &a1H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - headIntersection: &a2H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { - // The number of nodes in the chain is longer than ancestorSkipStart, - // so that during sync both nodes are able to find a common ancestor - // which was later than A1. - - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - | -> F1 - // Make F1 different to A9. - // hashSeed is chosen such that the hash is higher than A9. - F1.hashSeed = 7 - `) - // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] - - if bytes.Compare(f1H[:], a9H[:]) < 0 { - t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") - } - - chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["F1"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer1); diff != "" { - t.Errorf("offer1 diff (-want, +got):\n%s", diff) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["A10"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer2); diff != "" { - t.Errorf("offer2 diff (-want, +got):\n%s", diff) - } - - // Node 1 only knows about the first eight nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain but doesn't recognize the head. - t.Run("n2", func(t *testing.T) { - // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_FastForward(t *testing.T) { - // Node 1 has: A1 -> A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. As such, it should send history from the newest ancestor, - // A1 (if the chain was longer there would be one in the middle). - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so the only AUM that n2 might not have is - // A2. - want := []AUM{c.AUMs["A2"]} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - want := []AUM{ - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_Fork(t *testing.T) { - // Node 1 has: A1 -> A2 -> A3 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - | -> F1 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n1 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["F1"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n2 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestSyncSimpleE2E(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - nodeStorage := &Mem{} - node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("node Bootstrap() failed: %v", err) - } - controlStorage := c.Chonk() - control, err := Open(controlStorage) - if err != nil { - t.Fatalf("control Open() failed: %v", err) - } - - // Control knows the full chain, node only knows the genesis. Lets see - // if they can sync. - nodeOffer, err := node.SyncOffer(nodeStorage) - if err != nil { - t.Fatal(err) - } - controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) - if err != nil { - t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) - } - if err := node.Inform(nodeStorage, controlAUMs); err != nil { - t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) - } - - if cHash, nHash := control.Head(), node.Head(); cHash != nHash { - t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestSyncOffer(t *testing.T) { + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 + A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 + `) + storage := c.Chonk() + a, err := Open(storage) + if err != nil { + t.Fatal(err) + } + got, err := a.SyncOffer(storage) + if err != nil { + t.Fatal(err) + } + + // A SyncOffer includes a selection of AUMs going backwards in the tree, + // progressively skipping more and more each iteration. + want := SyncOffer{ + Head: c.AUMHashes["A25"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + `) + a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &a1H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + headIntersection: &a2H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { + // The number of nodes in the chain is longer than ancestorSkipStart, + // so that during sync both nodes are able to find a common ancestor + // which was later than A1. + + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + | -> F1 + // Make F1 different to A9. + // hashSeed is chosen such that the hash is higher than A9. + F1.hashSeed = 7 + `) + // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] + + if bytes.Compare(f1H[:], a9H[:]) < 0 { + t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") + } + + chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["F1"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer1); diff != "" { + t.Errorf("offer1 diff (-want, +got):\n%s", diff) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["A10"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer2); diff != "" { + t.Errorf("offer2 diff (-want, +got):\n%s", diff) + } + + // Node 1 only knows about the first eight nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain but doesn't recognize the head. + t.Run("n2", func(t *testing.T) { + // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_FastForward(t *testing.T) { + // Node 1 has: A1 -> A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. As such, it should send history from the newest ancestor, + // A1 (if the chain was longer there would be one in the middle). + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so the only AUM that n2 might not have is + // A2. + want := []AUM{c.AUMs["A2"]} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + want := []AUM{ + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_Fork(t *testing.T) { + // Node 1 has: A1 -> A2 -> A3 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + | -> F1 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n1 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["F1"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n2 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestSyncSimpleE2E(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + nodeStorage := &Mem{} + node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("node Bootstrap() failed: %v", err) + } + controlStorage := c.Chonk() + control, err := Open(controlStorage) + if err != nil { + t.Fatalf("control Open() failed: %v", err) + } + + // Control knows the full chain, node only knows the genesis. Lets see + // if they can sync. + nodeOffer, err := node.SyncOffer(nodeStorage) + if err != nil { + t.Fatal(err) + } + controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) + if err != nil { + t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) + } + if err := node.Inform(nodeStorage, controlAUMs); err != nil { + t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) + } + + if cHash, nHash := control.Head(), node.Head(); cHash != nHash { + t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) + } +} diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index 13d989f0c3c63..86d5642a3bd10 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -1,693 +1,693 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "golang.org/x/crypto/blake2s" -) - -// randHash derives a fake blake2s hash from the test name -// and the given seed. -func randHash(t *testing.T, seed int64) [blake2s.Size]byte { - var out [blake2s.Size]byte - testingRand(t, seed).Read(out[:]) - return out -} - -func TestImplementsChonk(t *testing.T) { - impls := []Chonk{&Mem{}, &FS{}} - t.Logf("chonks: %v", impls) -} - -func TestTailchonk_ChildAUMs(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - parentHash := randHash(t, 1) - data := []AUM{ - { - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2}, - PrevAUMHash: parentHash[:], - }, - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - - if err := chonk.CommitVerifiedAUMs(data); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - stored, err := chonk.ChildAUMs(parentHash) - if err != nil { - t.Fatalf("ChildAUMs failed: %v", err) - } - if diff := cmp.Diff(data, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonk_AUMMissing(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - var notExists AUMHash - notExists[:][0] = 42 - if _, err := chonk.AUM(notExists); err != os.ErrNotExist { - t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) - } - }) - } -} - -func TestTailchonkMem_Orphans(t *testing.T) { - chonk := Mem{} - - parentHash := randHash(t, 1) - orphan := AUM{MessageKind: AUMNoOp} - aums := []AUM{ - orphan, - // A parent is specified, so we shouldnt see it in GetOrphans() - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - if err := chonk.CommitVerifiedAUMs(aums); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - stored, err := chonk.Orphans() - if err != nil { - t.Fatalf("Orphans failed: %v", err) - } - if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } -} - -func TestTailchonk_ReadChainFromHead(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - // t.Logf("genesis hash = %X", genesis.Hash()) - // t.Logf("intermediate hash = %X", intermediate.Hash()) - // t.Logf("leaf hash = %X", leaf.Hash()) - - // Read the chain from the leaf backwards. - gotLeafs, err := chonk.Heads() - if err != nil { - t.Fatalf("Heads failed: %v", err) - } - if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { - t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) - } - - parent, _ := gotLeafs[0].Parent() - gotIntermediate, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { - t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) - } - - parent, _ = gotIntermediate.Parent() - gotGenesis, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(genesis, gotGenesis); diff != "" { - t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonkFS_Commit(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - - dir, base := chonk.aumDir(aum.Hash()) - if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { - t.Errorf("aum dir=%s, want %s", got, want) - } - if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { - t.Errorf("aum base=%s, want %s", base, want) - } - if _, err := os.Stat(filepath.Join(dir, base)); err != nil { - t.Errorf("stat of AUM file failed: %v", err) - } - if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { - t.Errorf("stat of AUM parent failed: %v", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix > 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) - } -} - -func TestTailchonkFS_CommitTime(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - ct, err := chonk.CommitTime(aum.Hash()) - if err != nil { - t.Fatalf("CommitTime() failed: %v", err) - } - if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { - t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) - } -} - -func TestTailchonkFS_PurgeAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { - t.Fatal(err) - } - - if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { - t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix == 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) - } -} - -func TestTailchonkFS_AllAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - hashes, err := chonk.AllAUMs() - if err != nil { - t.Fatal(err) - } - hashesLess := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { - t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) - } -} - -func TestMarkActiveChain(t *testing.T) { - type aumTemplate struct { - AUM AUM - } - - tcs := []struct { - name string - minChain int - chain []aumTemplate - expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. - }{ - { - name: "genesis", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 0, - }, - { - name: "simple truncate", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - { - name: "long truncate", - minChain: 5, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 2, - }, - { - name: "truncate finding checkpoint", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.chain)) - - // Build the state of the tailchonk for tests. - storage := &Mem{} - var prev AUMHash - for i := range tc.chain { - if !prev.IsZero() { - tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) - copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) - } - if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { - t.Fatal(err) - } - - h := tc.chain[i].AUM.Hash() - prev = h - verdict[h] = 0 - } - - got, err := markActiveChain(storage, verdict, tc.minChain, prev) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markActiveChain() failed: %v", err) - } - want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() - if got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - // Make sure the verdict array was marked correctly. - for i := range tc.chain { - h := tc.chain[i].AUM.Hash() - if i >= tc.expectLastActiveIdx { - if (verdict[h] & retainStateActive) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) - } - } else { - if (verdict[h] & retainStateCandidate) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) - } - } - } - }) - } -} - -func TestMarkDescendantAUMs(t *testing.T) { - c := newTestchain(t, ` - genesis -> B -> C -> C2 - | -> D - | -> E -> F -> G -> H - | -> E2 - - // tweak seeds so hashes arent identical - C.hashSeed = 1 - D.hashSeed = 2 - E.hashSeed = 3 - E2.hashSeed = 4 - `) - - verdict := make(map[AUMHash]retainState, len(c.AUMs)) - for _, a := range c.AUMs { - verdict[a.Hash()] = 0 - } - - // Mark E & C. - verdict[c.AUMHashes["C"]] = retainStateActive - verdict[c.AUMHashes["E"]] = retainStateActive - - if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { - t.Errorf("markDescendantAUMs() failed: %v", err) - } - - // Make sure the descendants got marked. - hs := c.AUMHashes - for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { - if (verdict[h] & retainStateLeaf) == 0 { - t.Errorf("%v was not marked as a descendant", h) - } - } - for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { - if (verdict[h] & retainStateLeaf) != 0 { - t.Errorf("%v was marked as a descendant and shouldnt be", h) - } - } -} - -func TestMarkAncestorIntersectionAUMs(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - tcs := []struct { - name string - chain *testChain - verdicts map[string]retainState - initialAncestor string - wantAncestor string - wantRetained []string - wantDeleted []string - }{ - { - name: "genesis", - chain: newTestchain(t, ` - A - A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - }, - wantRetained: []string{"A"}, - }, - { - name: "no adjustment", - chain: newTestchain(t, ` - DEAD -> A -> B -> C - A.template = checkpoint - B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - "B": retainStateActive, - "C": retainStateActive, - "DEAD": retainStateCandidate, - }, - wantRetained: []string{"A", "B", "C"}, - wantDeleted: []string{"DEAD"}, - }, - { - name: "fork", - chain: newTestchain(t, ` - A -> B -> C -> D - | -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"C", "D", "FORK"}, - wantDeleted: []string{"A", "B"}, - }, - { - name: "fork finding earlier checkpoint", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F - | -> FORK - A.template = checkpoint - B.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "E", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateActive, - "F": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, - wantDeleted: []string{"A"}, - }, - { - name: "fork multi", - chain: newTestchain(t, ` - A -> B -> C -> D -> E - | -> DEADFORK - C -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2 - DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "E": retainStateActive, - "FORK": retainStateYoung, - "DEADFORK": 0, - }, - wantRetained: []string{"C", "D", "E", "FORK"}, - wantDeleted: []string{"A", "B", "DEADFORK"}, - }, - { - name: "fork multi 2", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F -> G - - F -> F1 - D -> F2 - B -> F3 - - A.template = checkpoint - B.template = checkpoint - D.template = checkpoint - F.template = checkpoint - F1.hashSeed = 2 - F2.hashSeed = 3 - F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "F", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateCandidate, - "F": retainStateActive, - "G": retainStateActive, - "F1": retainStateYoung, - "F2": retainStateYoung, - "F3": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.verdicts)) - for name, v := range tc.verdicts { - verdict[tc.chain.AUMHashes[name]] = v - } - - got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) - } - if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - for _, name := range tc.wantRetained { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask == 0 { - t.Errorf("AUM %q was not retained: verdict = %v", name, v) - } - } - for _, name := range tc.wantDeleted { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask != 0 { - t.Errorf("AUM %q was retained: verdict = %v", name, v) - } - } - - if t.Failed() { - for name, hash := range tc.chain.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } - }) - } -} - -type compactingChonkFake struct { - Mem - - aumAge map[AUMHash]time.Time - t *testing.T - wantDelete []AUMHash -} - -func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { - out := make([]AUMHash, 0, len(c.Mem.aums)) - for h := range c.Mem.aums { - out = append(out, h) - } - return out, nil -} - -func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { - return c.aumAge[hash], nil -} - -func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { - lessHashes := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { - c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) - } - return nil -} - -// Avoid go vet complaining about copying a lock value -func cloneMem(src, dst *Mem) { - dst.l = sync.RWMutex{} - dst.aums = src.aums - dst.parentIndex = src.parentIndex - dst.lastActiveAncestor = src.lastActiveAncestor -} - -func TestCompact(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - // A & B are deleted because the new lastActiveAncestor advances beyond them. - // OLD is deleted because it does not match retention criteria, and - // though it is a descendant of the new lastActiveAncestor (C), it is not a - // descendant of a retained AUM. - // G, & H are retained as recent (MinChain=2) ancestors of HEAD. - // E & F are retained because they are between retained AUMs (G+) and - // their newest checkpoint ancestor. - // D is retained because it is the newest checkpoint ancestor from - // MinChain-retained AUMs. - // G2 is retained because it is a descendant of a retained AUM (G). - // F1 is retained because it is new enough by wall-clock time. - // F2 is retained because it is a descendant of a retained AUM (F1). - // C2 is retained because it is between an ancestor checkpoint and - // a retained AUM (F1). - // C is retained because it is the new lastActiveAncestor. It is the - // new lastActiveAncestor because it is the newest common checkpoint - // of all retained AUMs. - c := newTestchain(t, ` - A -> B -> C -> C2 -> D -> E -> F -> G -> H - | -> F1 -> F2 | -> G2 - | -> OLD - - // make {A,B,C,D} compaction candidates - A.template = checkpoint - B.template = checkpoint - C.template = checkpoint - D.template = checkpoint - - // tweak seeds of forks so hashes arent identical - F1.hashSeed = 1 - OLD.hashSeed = 2 - G2.hashSeed = 3 - `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) - - storage := &compactingChonkFake{ - aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, - t: t, - wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, - } - - cloneMem(c.Chonk().(*Mem), &storage.Mem) - - lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) - if err != nil { - t.Errorf("Compact() failed: %v", err) - } - if lastActiveAncestor != c.AUMHashes["C"] { - t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) - } - - if t.Failed() { - for name, hash := range c.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/crypto/blake2s" +) + +// randHash derives a fake blake2s hash from the test name +// and the given seed. +func randHash(t *testing.T, seed int64) [blake2s.Size]byte { + var out [blake2s.Size]byte + testingRand(t, seed).Read(out[:]) + return out +} + +func TestImplementsChonk(t *testing.T) { + impls := []Chonk{&Mem{}, &FS{}} + t.Logf("chonks: %v", impls) +} + +func TestTailchonk_ChildAUMs(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + parentHash := randHash(t, 1) + data := []AUM{ + { + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2}, + PrevAUMHash: parentHash[:], + }, + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + + if err := chonk.CommitVerifiedAUMs(data); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + stored, err := chonk.ChildAUMs(parentHash) + if err != nil { + t.Fatalf("ChildAUMs failed: %v", err) + } + if diff := cmp.Diff(data, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonk_AUMMissing(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + var notExists AUMHash + notExists[:][0] = 42 + if _, err := chonk.AUM(notExists); err != os.ErrNotExist { + t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) + } + }) + } +} + +func TestTailchonkMem_Orphans(t *testing.T) { + chonk := Mem{} + + parentHash := randHash(t, 1) + orphan := AUM{MessageKind: AUMNoOp} + aums := []AUM{ + orphan, + // A parent is specified, so we shouldnt see it in GetOrphans() + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + if err := chonk.CommitVerifiedAUMs(aums); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + stored, err := chonk.Orphans() + if err != nil { + t.Fatalf("Orphans failed: %v", err) + } + if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } +} + +func TestTailchonk_ReadChainFromHead(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + // t.Logf("genesis hash = %X", genesis.Hash()) + // t.Logf("intermediate hash = %X", intermediate.Hash()) + // t.Logf("leaf hash = %X", leaf.Hash()) + + // Read the chain from the leaf backwards. + gotLeafs, err := chonk.Heads() + if err != nil { + t.Fatalf("Heads failed: %v", err) + } + if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { + t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) + } + + parent, _ := gotLeafs[0].Parent() + gotIntermediate, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { + t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) + } + + parent, _ = gotIntermediate.Parent() + gotGenesis, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(genesis, gotGenesis); diff != "" { + t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonkFS_Commit(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + + dir, base := chonk.aumDir(aum.Hash()) + if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { + t.Errorf("aum dir=%s, want %s", got, want) + } + if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { + t.Errorf("aum base=%s, want %s", base, want) + } + if _, err := os.Stat(filepath.Join(dir, base)); err != nil { + t.Errorf("stat of AUM file failed: %v", err) + } + if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { + t.Errorf("stat of AUM parent failed: %v", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix > 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) + } +} + +func TestTailchonkFS_CommitTime(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + ct, err := chonk.CommitTime(aum.Hash()) + if err != nil { + t.Fatalf("CommitTime() failed: %v", err) + } + if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { + t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) + } +} + +func TestTailchonkFS_PurgeAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { + t.Fatal(err) + } + + if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { + t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix == 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) + } +} + +func TestTailchonkFS_AllAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + hashes, err := chonk.AllAUMs() + if err != nil { + t.Fatal(err) + } + hashesLess := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { + t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) + } +} + +func TestMarkActiveChain(t *testing.T) { + type aumTemplate struct { + AUM AUM + } + + tcs := []struct { + name string + minChain int + chain []aumTemplate + expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. + }{ + { + name: "genesis", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 0, + }, + { + name: "simple truncate", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + { + name: "long truncate", + minChain: 5, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 2, + }, + { + name: "truncate finding checkpoint", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.chain)) + + // Build the state of the tailchonk for tests. + storage := &Mem{} + var prev AUMHash + for i := range tc.chain { + if !prev.IsZero() { + tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) + copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) + } + if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { + t.Fatal(err) + } + + h := tc.chain[i].AUM.Hash() + prev = h + verdict[h] = 0 + } + + got, err := markActiveChain(storage, verdict, tc.minChain, prev) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markActiveChain() failed: %v", err) + } + want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() + if got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + // Make sure the verdict array was marked correctly. + for i := range tc.chain { + h := tc.chain[i].AUM.Hash() + if i >= tc.expectLastActiveIdx { + if (verdict[h] & retainStateActive) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) + } + } else { + if (verdict[h] & retainStateCandidate) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) + } + } + } + }) + } +} + +func TestMarkDescendantAUMs(t *testing.T) { + c := newTestchain(t, ` + genesis -> B -> C -> C2 + | -> D + | -> E -> F -> G -> H + | -> E2 + + // tweak seeds so hashes arent identical + C.hashSeed = 1 + D.hashSeed = 2 + E.hashSeed = 3 + E2.hashSeed = 4 + `) + + verdict := make(map[AUMHash]retainState, len(c.AUMs)) + for _, a := range c.AUMs { + verdict[a.Hash()] = 0 + } + + // Mark E & C. + verdict[c.AUMHashes["C"]] = retainStateActive + verdict[c.AUMHashes["E"]] = retainStateActive + + if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { + t.Errorf("markDescendantAUMs() failed: %v", err) + } + + // Make sure the descendants got marked. + hs := c.AUMHashes + for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { + if (verdict[h] & retainStateLeaf) == 0 { + t.Errorf("%v was not marked as a descendant", h) + } + } + for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { + if (verdict[h] & retainStateLeaf) != 0 { + t.Errorf("%v was marked as a descendant and shouldnt be", h) + } + } +} + +func TestMarkAncestorIntersectionAUMs(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + tcs := []struct { + name string + chain *testChain + verdicts map[string]retainState + initialAncestor string + wantAncestor string + wantRetained []string + wantDeleted []string + }{ + { + name: "genesis", + chain: newTestchain(t, ` + A + A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + }, + wantRetained: []string{"A"}, + }, + { + name: "no adjustment", + chain: newTestchain(t, ` + DEAD -> A -> B -> C + A.template = checkpoint + B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + "B": retainStateActive, + "C": retainStateActive, + "DEAD": retainStateCandidate, + }, + wantRetained: []string{"A", "B", "C"}, + wantDeleted: []string{"DEAD"}, + }, + { + name: "fork", + chain: newTestchain(t, ` + A -> B -> C -> D + | -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"C", "D", "FORK"}, + wantDeleted: []string{"A", "B"}, + }, + { + name: "fork finding earlier checkpoint", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F + | -> FORK + A.template = checkpoint + B.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "E", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateActive, + "F": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, + wantDeleted: []string{"A"}, + }, + { + name: "fork multi", + chain: newTestchain(t, ` + A -> B -> C -> D -> E + | -> DEADFORK + C -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2 + DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "E": retainStateActive, + "FORK": retainStateYoung, + "DEADFORK": 0, + }, + wantRetained: []string{"C", "D", "E", "FORK"}, + wantDeleted: []string{"A", "B", "DEADFORK"}, + }, + { + name: "fork multi 2", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F -> G + + F -> F1 + D -> F2 + B -> F3 + + A.template = checkpoint + B.template = checkpoint + D.template = checkpoint + F.template = checkpoint + F1.hashSeed = 2 + F2.hashSeed = 3 + F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "F", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateCandidate, + "F": retainStateActive, + "G": retainStateActive, + "F1": retainStateYoung, + "F2": retainStateYoung, + "F3": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.verdicts)) + for name, v := range tc.verdicts { + verdict[tc.chain.AUMHashes[name]] = v + } + + got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) + } + if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + for _, name := range tc.wantRetained { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask == 0 { + t.Errorf("AUM %q was not retained: verdict = %v", name, v) + } + } + for _, name := range tc.wantDeleted { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask != 0 { + t.Errorf("AUM %q was retained: verdict = %v", name, v) + } + } + + if t.Failed() { + for name, hash := range tc.chain.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } + }) + } +} + +type compactingChonkFake struct { + Mem + + aumAge map[AUMHash]time.Time + t *testing.T + wantDelete []AUMHash +} + +func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { + out := make([]AUMHash, 0, len(c.Mem.aums)) + for h := range c.Mem.aums { + out = append(out, h) + } + return out, nil +} + +func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { + return c.aumAge[hash], nil +} + +func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { + lessHashes := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { + c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) + } + return nil +} + +// Avoid go vet complaining about copying a lock value +func cloneMem(src, dst *Mem) { + dst.l = sync.RWMutex{} + dst.aums = src.aums + dst.parentIndex = src.parentIndex + dst.lastActiveAncestor = src.lastActiveAncestor +} + +func TestCompact(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + // A & B are deleted because the new lastActiveAncestor advances beyond them. + // OLD is deleted because it does not match retention criteria, and + // though it is a descendant of the new lastActiveAncestor (C), it is not a + // descendant of a retained AUM. + // G, & H are retained as recent (MinChain=2) ancestors of HEAD. + // E & F are retained because they are between retained AUMs (G+) and + // their newest checkpoint ancestor. + // D is retained because it is the newest checkpoint ancestor from + // MinChain-retained AUMs. + // G2 is retained because it is a descendant of a retained AUM (G). + // F1 is retained because it is new enough by wall-clock time. + // F2 is retained because it is a descendant of a retained AUM (F1). + // C2 is retained because it is between an ancestor checkpoint and + // a retained AUM (F1). + // C is retained because it is the new lastActiveAncestor. It is the + // new lastActiveAncestor because it is the newest common checkpoint + // of all retained AUMs. + c := newTestchain(t, ` + A -> B -> C -> C2 -> D -> E -> F -> G -> H + | -> F1 -> F2 | -> G2 + | -> OLD + + // make {A,B,C,D} compaction candidates + A.template = checkpoint + B.template = checkpoint + C.template = checkpoint + D.template = checkpoint + + // tweak seeds of forks so hashes arent identical + F1.hashSeed = 1 + OLD.hashSeed = 2 + G2.hashSeed = 3 + `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) + + storage := &compactingChonkFake{ + aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, + t: t, + wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, + } + + cloneMem(c.Chonk().(*Mem), &storage.Mem) + + lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) + if err != nil { + t.Errorf("Compact() failed: %v", err) + } + if lastActiveAncestor != c.AUMHashes["C"] { + t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) + } + + if t.Failed() { + for name, hash := range c.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } +} diff --git a/tka/tka_test.go b/tka/tka_test.go index 3438a4016f0f6..9e3c4e79d05bd 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -1,654 +1,654 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -func TestComputeChainCandidates(t *testing.T) { - c := newTestchain(t, ` - G1 -> I1 -> I2 -> I3 -> L2 - | -> L1 | -> L3 - - G2 -> L4 - - // We tweak these AUMs so they are different hashes. - G2.hashSeed = 2 - L1.hashSeed = 2 - L3.hashSeed = 2 - L4.hashSeed = 3 - `) - // Should result in 4 chains: - // G1->L1, G1->L2, G1->L3, G2->L4 - - i1H := c.AUMHashes["I1"] - got, err := computeChainCandidates(c.Chonk(), &i1H, 50) - if err != nil { - t.Fatalf("computeChainCandidates() failed: %v", err) - } - - want := []chain{ - {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { - t.Errorf("chains differ (-want, +got):\n%s", diff) - } -} - -func TestForkResolutionHash(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - // tweak hashes so L1 & L2 are not identical - L1.hashSeed = 2 - L2.hashSeed = 3 - `) - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // The fork with the lowest AUM hash should have been chosen. - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - want := l1H - if bytes.Compare(l2H[:], l1H[:]) < 0 { - want = l2H - } - - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionSigWeight(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - G1.template = addKey - L1.hashSeed = 11 - L2.signedWith = key - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optKey("key", key, priv)) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, l1H should be chosen. - // But based on the signature weight (which has higher - // precedence), it should be l2H - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionMessageType(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - | -> L3 - - G1.template = addKey - L1.hashSeed = 11 - L2.template = removeKey - L3.hashSeed = 18 - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - l3H := c.AUMHashes["L3"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, L1 or L3 should be chosen. - // But based on the preference for AUMRemoveKey messages, - // it should be L2. - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestComputeStateAt(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> I1 -> I2 - I1.template = addKey - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) - - // G1 is before the key, so there shouldn't be a key there. - state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) - if err != nil { - t.Fatalf("computeStateAt(G1) failed: %v", err) - } - if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("expected key to be missing: err = %v", err) - } - if *state.LastAUMHash != c.AUMHashes["G1"] { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) - } - - // I1 & I2 are after the key, so the computed state should contain - // the key. - for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { - state, err = computeStateAt(c.Chonk(), 500, wantHash) - if err != nil { - t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) - } - if *state.LastAUMHash != wantHash { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) - } - if _, err := state.GetKey(key.MustID()); err != nil { - t.Errorf("expected key to be present at state: err = %v", err) - } - } -} - -// fakeAUM generates an AUM structure based on the template. -// If parent is provided, PrevAUMHash is set to that value. -// -// If template is an AUM, the returned AUM is based on that. -// If template is an int, a NOOP AUM is returned, and the -// provided int can be used to tweak the resulting hash (needed -// for tests you want one AUM to be 'lower' than another, so that -// that chain is taken based on fork resolution rules). -func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { - if seed, ok := template.(int); ok { - a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - if a, ok := template.(AUM); ok { - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - panic("template must be an int or an AUM") -} - -func TestOpenAuthority(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - // /- L1 - // G1 - I1 - I2 - I3 -L2 - // \-L3 - // G2 - L4 - // - // We set the previous-known ancestor to G1, so the - // ancestor to start from should be G1. - g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) - i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} - l1, l1H := fakeAUM(t, 13, &i1H) - - i2, i2H := fakeAUM(t, 2, &i1H) - i3, i3H := fakeAUM(t, 5, &i2H) - l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) - l3, l3H := fakeAUM(t, 4, &i3H) - - g2, g2H := fakeAUM(t, 8, nil) - l4, _ := fakeAUM(t, 9, &g2H) - - // We make sure that I2 has a lower hash than L1, so - // it should take that path rather than L1. - if bytes.Compare(l1H[:], i2H[:]) < 0 { - t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") - } - // We make sure L2 has a signature with key, so it should - // take that path over L3. We assert that the L3 hash - // is less than L2 so the test will fail if the signature - // preference logic is broken. - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") - } - - // Construct the state of durable storage. - chonk := &Mem{} - err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) - if err != nil { - t.Fatal(err) - } - chonk.SetLastActiveAncestor(i1H) - - a, err := Open(chonk) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - // Should include the key added in G1 - if _, err := a.state.GetKey(key.MustID()); err != nil { - t.Errorf("missing G1 key: %v", err) - } - // The head of the chain should be L2. - if a.Head() != l2H { - t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) - } -} - -func TestOpenAuthority_EmptyErrors(t *testing.T) { - _, err := Open(&Mem{}) - if err == nil { - t.Error("Expected an error initializing an empty authority, got nil") - } -} - -func TestAuthorityHead(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - L1.hashSeed = 2 - `) - - a, _ := Open(c.Chonk()) - if got, want := a.head.Hash(), a.Head(); got != want { - t.Errorf("Hash() returned %x, want %x", got, want) - } -} - -func TestAuthorityValidDisablement(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - - a, _ := Open(c.Chonk()) - if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { - t.Error("ValidDisablement() returned false, want true") - } -} - -func TestCreateBootstrapAuthority(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - a1, genesisAUM, err := Create(&Mem{}, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - a2, err := Bootstrap(&Mem{}, genesisAUM) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - if a1.Head() != a2.Head() { - t.Fatal("created and bootstrapped authority differ") - } - - // Both authorities should trust the key laid down in the genesis state. - if !a1.KeyTrusted(key.MustID()) { - t.Error("a1 did not trust genesis key") - } - if !a2.KeyTrusted(key.MustID()) { - t.Error("a2 did not trust genesis key") - } -} - -func TestAuthorityInformNonLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 -> L3 - | -> L4 -> L5 - - G1.template = genesis - L1.hashSeed = 3 - L2.hashSeed = 2 - L4.hashSeed = 2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - // L2 does not chain from L1, disabling the isHeadChain optimization - // and forcing Inform() to take the slow path. - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestAuthorityInformLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestInteropWithNLKey(t *testing.T) { - priv1 := key.NewNLPrivate() - pub1 := priv1.Public() - pub2 := key.NewNLPrivate().Public() - pub3 := key.NewNLPrivate().Public() - - a, _, err := Create(&Mem{}, State{ - Keys: []Key{ - { - Kind: Key25519, - Votes: 1, - Public: pub1.KeyID(), - }, - { - Kind: Key25519, - Votes: 1, - Public: pub2.KeyID(), - }, - }, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, priv1) - if err != nil { - t.Errorf("tka.Create: %v", err) - return - } - - if !a.KeyTrusted(pub1.KeyID()) { - t.Error("pub1 want trusted, got untrusted") - } - if !a.KeyTrusted(pub2.KeyID()) { - t.Error("pub2 want trusted, got untrusted") - } - if a.KeyTrusted(pub3.KeyID()) { - t.Error("pub3 want untrusted, got trusted") - } -} - -func TestAuthorityCompact(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G -> A -> B -> C -> D -> E - - G.template = genesis - C.template = checkpoint2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &FS{base: t.TempDir()} - a, err := Bootstrap(storage, c.AUMs["G"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) - - // Should compact down to C -> D -> E - if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { - t.Fatal(err) - } - if a.oldestAncestor.Hash() != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) - } - - // Make sure the stored authority is still openable and resolves to the same state. - stored, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open stored authority: %v", err) - } - if stored.Head() != a.Head() { - t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) - } - t.Logf("original ancestor = %v", c.AUMHashes["G"]) - if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) - } -} - -func TestFindParentForRewrite(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - k2ID, _ := k2.ID() - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D -> E - A.template = genesis - B.template = add2 - C.template = add3 - D.template = remove2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), - optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k1 was trusted at genesis, so there's no better rewrite parent - // than the genesis. - k1ID, _ := k1.ID() - k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k1) failed: %v", err) - } - if k1P != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) - } - - // k3 was trusted at C, so B would be an ideal rewrite point. - k3ID, _ := k3.ID() - k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k3) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) - } - - // k2 was added but then removed, so HEAD is an appropriate rewrite point. - k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k2) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) - } - - // There's no appropriate point where both k2 and k3 are simultaneously not trusted, - // so the best rewrite point is the genesis AUM. - doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) - } - if doubleP != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) - } -} - -func TestMakeRetroactiveRevocation(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D - A.template = genesis - C.template = add2 - D.template = add3 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k2 was added by C, so a forking revocation should: - // - have B as a parent - // - trust the remaining keys at the time, k1 & k3. - k1ID, _ := k1.ID() - k2ID, _ := k2.ID() - k3ID, _ := k3.ID() - forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) - if err != nil { - t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) - } - if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { - t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) - } - if _, err := forkingAUM.State.GetKey(k1ID); err != nil { - t.Error("Forked state did not trust k1") - } - if _, err := forkingAUM.State.GetKey(k3ID); err != nil { - t.Error("Forked state did not trust k3") - } - if _, err := forkingAUM.State.GetKey(k2ID); err == nil { - t.Error("Forked state trusted removed-key k2") - } - - // Test that removing all trusted keys results in an error. - _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) - if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { - t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +func TestComputeChainCandidates(t *testing.T) { + c := newTestchain(t, ` + G1 -> I1 -> I2 -> I3 -> L2 + | -> L1 | -> L3 + + G2 -> L4 + + // We tweak these AUMs so they are different hashes. + G2.hashSeed = 2 + L1.hashSeed = 2 + L3.hashSeed = 2 + L4.hashSeed = 3 + `) + // Should result in 4 chains: + // G1->L1, G1->L2, G1->L3, G2->L4 + + i1H := c.AUMHashes["I1"] + got, err := computeChainCandidates(c.Chonk(), &i1H, 50) + if err != nil { + t.Fatalf("computeChainCandidates() failed: %v", err) + } + + want := []chain{ + {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { + t.Errorf("chains differ (-want, +got):\n%s", diff) + } +} + +func TestForkResolutionHash(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + // tweak hashes so L1 & L2 are not identical + L1.hashSeed = 2 + L2.hashSeed = 3 + `) + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // The fork with the lowest AUM hash should have been chosen. + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + want := l1H + if bytes.Compare(l2H[:], l1H[:]) < 0 { + want = l2H + } + + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionSigWeight(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + G1.template = addKey + L1.hashSeed = 11 + L2.signedWith = key + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optKey("key", key, priv)) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, l1H should be chosen. + // But based on the signature weight (which has higher + // precedence), it should be l2H + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionMessageType(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + | -> L3 + + G1.template = addKey + L1.hashSeed = 11 + L2.template = removeKey + L3.hashSeed = 18 + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + l3H := c.AUMHashes["L3"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, L1 or L3 should be chosen. + // But based on the preference for AUMRemoveKey messages, + // it should be L2. + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestComputeStateAt(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> I1 -> I2 + I1.template = addKey + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) + + // G1 is before the key, so there shouldn't be a key there. + state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) + if err != nil { + t.Fatalf("computeStateAt(G1) failed: %v", err) + } + if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("expected key to be missing: err = %v", err) + } + if *state.LastAUMHash != c.AUMHashes["G1"] { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) + } + + // I1 & I2 are after the key, so the computed state should contain + // the key. + for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { + state, err = computeStateAt(c.Chonk(), 500, wantHash) + if err != nil { + t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) + } + if *state.LastAUMHash != wantHash { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) + } + if _, err := state.GetKey(key.MustID()); err != nil { + t.Errorf("expected key to be present at state: err = %v", err) + } + } +} + +// fakeAUM generates an AUM structure based on the template. +// If parent is provided, PrevAUMHash is set to that value. +// +// If template is an AUM, the returned AUM is based on that. +// If template is an int, a NOOP AUM is returned, and the +// provided int can be used to tweak the resulting hash (needed +// for tests you want one AUM to be 'lower' than another, so that +// that chain is taken based on fork resolution rules). +func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { + if seed, ok := template.(int); ok { + a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + if a, ok := template.(AUM); ok { + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + panic("template must be an int or an AUM") +} + +func TestOpenAuthority(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + // /- L1 + // G1 - I1 - I2 - I3 -L2 + // \-L3 + // G2 - L4 + // + // We set the previous-known ancestor to G1, so the + // ancestor to start from should be G1. + g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) + i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} + l1, l1H := fakeAUM(t, 13, &i1H) + + i2, i2H := fakeAUM(t, 2, &i1H) + i3, i3H := fakeAUM(t, 5, &i2H) + l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) + l3, l3H := fakeAUM(t, 4, &i3H) + + g2, g2H := fakeAUM(t, 8, nil) + l4, _ := fakeAUM(t, 9, &g2H) + + // We make sure that I2 has a lower hash than L1, so + // it should take that path rather than L1. + if bytes.Compare(l1H[:], i2H[:]) < 0 { + t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") + } + // We make sure L2 has a signature with key, so it should + // take that path over L3. We assert that the L3 hash + // is less than L2 so the test will fail if the signature + // preference logic is broken. + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") + } + + // Construct the state of durable storage. + chonk := &Mem{} + err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) + if err != nil { + t.Fatal(err) + } + chonk.SetLastActiveAncestor(i1H) + + a, err := Open(chonk) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + // Should include the key added in G1 + if _, err := a.state.GetKey(key.MustID()); err != nil { + t.Errorf("missing G1 key: %v", err) + } + // The head of the chain should be L2. + if a.Head() != l2H { + t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) + } +} + +func TestOpenAuthority_EmptyErrors(t *testing.T) { + _, err := Open(&Mem{}) + if err == nil { + t.Error("Expected an error initializing an empty authority, got nil") + } +} + +func TestAuthorityHead(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + L1.hashSeed = 2 + `) + + a, _ := Open(c.Chonk()) + if got, want := a.head.Hash(), a.Head(); got != want { + t.Errorf("Hash() returned %x, want %x", got, want) + } +} + +func TestAuthorityValidDisablement(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + + a, _ := Open(c.Chonk()) + if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { + t.Error("ValidDisablement() returned false, want true") + } +} + +func TestCreateBootstrapAuthority(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + a1, genesisAUM, err := Create(&Mem{}, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + a2, err := Bootstrap(&Mem{}, genesisAUM) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + if a1.Head() != a2.Head() { + t.Fatal("created and bootstrapped authority differ") + } + + // Both authorities should trust the key laid down in the genesis state. + if !a1.KeyTrusted(key.MustID()) { + t.Error("a1 did not trust genesis key") + } + if !a2.KeyTrusted(key.MustID()) { + t.Error("a2 did not trust genesis key") + } +} + +func TestAuthorityInformNonLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 -> L3 + | -> L4 -> L5 + + G1.template = genesis + L1.hashSeed = 3 + L2.hashSeed = 2 + L4.hashSeed = 2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + // L2 does not chain from L1, disabling the isHeadChain optimization + // and forcing Inform() to take the slow path. + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestAuthorityInformLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestInteropWithNLKey(t *testing.T) { + priv1 := key.NewNLPrivate() + pub1 := priv1.Public() + pub2 := key.NewNLPrivate().Public() + pub3 := key.NewNLPrivate().Public() + + a, _, err := Create(&Mem{}, State{ + Keys: []Key{ + { + Kind: Key25519, + Votes: 1, + Public: pub1.KeyID(), + }, + { + Kind: Key25519, + Votes: 1, + Public: pub2.KeyID(), + }, + }, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, priv1) + if err != nil { + t.Errorf("tka.Create: %v", err) + return + } + + if !a.KeyTrusted(pub1.KeyID()) { + t.Error("pub1 want trusted, got untrusted") + } + if !a.KeyTrusted(pub2.KeyID()) { + t.Error("pub2 want trusted, got untrusted") + } + if a.KeyTrusted(pub3.KeyID()) { + t.Error("pub3 want untrusted, got trusted") + } +} + +func TestAuthorityCompact(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G -> A -> B -> C -> D -> E + + G.template = genesis + C.template = checkpoint2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &FS{base: t.TempDir()} + a, err := Bootstrap(storage, c.AUMs["G"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) + + // Should compact down to C -> D -> E + if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { + t.Fatal(err) + } + if a.oldestAncestor.Hash() != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) + } + + // Make sure the stored authority is still openable and resolves to the same state. + stored, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open stored authority: %v", err) + } + if stored.Head() != a.Head() { + t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) + } + t.Logf("original ancestor = %v", c.AUMHashes["G"]) + if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) + } +} + +func TestFindParentForRewrite(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + k2ID, _ := k2.ID() + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D -> E + A.template = genesis + B.template = add2 + C.template = add3 + D.template = remove2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), + optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k1 was trusted at genesis, so there's no better rewrite parent + // than the genesis. + k1ID, _ := k1.ID() + k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k1) failed: %v", err) + } + if k1P != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) + } + + // k3 was trusted at C, so B would be an ideal rewrite point. + k3ID, _ := k3.ID() + k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k3) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) + } + + // k2 was added but then removed, so HEAD is an appropriate rewrite point. + k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k2) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) + } + + // There's no appropriate point where both k2 and k3 are simultaneously not trusted, + // so the best rewrite point is the genesis AUM. + doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) + } + if doubleP != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) + } +} + +func TestMakeRetroactiveRevocation(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D + A.template = genesis + C.template = add2 + D.template = add3 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k2 was added by C, so a forking revocation should: + // - have B as a parent + // - trust the remaining keys at the time, k1 & k3. + k1ID, _ := k1.ID() + k2ID, _ := k2.ID() + k3ID, _ := k3.ID() + forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) + if err != nil { + t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) + } + if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { + t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) + } + if _, err := forkingAUM.State.GetKey(k1ID); err != nil { + t.Error("Forked state did not trust k1") + } + if _, err := forkingAUM.State.GetKey(k3ID); err != nil { + t.Error("Forked state did not trust k3") + } + if _, err := forkingAUM.State.GetKey(k2ID); err == nil { + t.Error("Forked state trusted removed-key k2") + } + + // Test that removing all trusted keys results in an error. + _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) + if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { + t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) + } +} diff --git a/tool/binaryen.rev b/tool/binaryen.rev index e0d03ab88bb4a..58c9bdf9d017f 100644 --- a/tool/binaryen.rev +++ b/tool/binaryen.rev @@ -1 +1 @@ -111 +111 diff --git a/tool/go b/tool/go index 3c99f3e2fceeb..1c53683d52f95 100755 --- a/tool/go +++ b/tool/go @@ -1,7 +1,7 @@ -#!/bin/sh -# -# This script acts like the "go" command, but uses Tailscale's -# currently-desired version from https://github.com/tailscale/go, -# downloading it first if necessary. - -exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" +#!/bin/sh +# +# This script acts like the "go" command, but uses Tailscale's +# currently-desired version from https://github.com/tailscale/go, +# downloading it first if necessary. + +exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" diff --git a/tool/gocross/env.go b/tool/gocross/env.go index 249476dc1b5a3..9d8a4f1b390b4 100644 --- a/tool/gocross/env.go +++ b/tool/gocross/env.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "fmt" - "os" - "sort" - "strings" -) - -// Environment starts from an initial set of environment variables, and tracks -// mutations to the environment. It can then apply those mutations to the -// environment, or produce debugging output that illustrates the changes it -// would make. -type Environment struct { - init map[string]string - set map[string]string - unset map[string]bool - - setenv func(string, string) error - unsetenv func(string) error -} - -// NewEnvironment returns an Environment initialized from os.Environ. -func NewEnvironment() *Environment { - init := map[string]string{} - for _, env := range os.Environ() { - fs := strings.SplitN(env, "=", 2) - if len(fs) != 2 { - panic("bad environ provided") - } - init[fs[0]] = fs[1] - } - - return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) -} - -func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { - return &Environment{ - init: init, - set: map[string]string{}, - unset: map[string]bool{}, - setenv: setenv, - unsetenv: unsetenv, - } -} - -// Set sets the environment variable k to v. -func (e *Environment) Set(k, v string) { - e.set[k] = v - delete(e.unset, k) -} - -// Unset removes the environment variable k. -func (e *Environment) Unset(k string) { - delete(e.set, k) - e.unset[k] = true -} - -// IsSet reports whether the environment variable k is set. -func (e *Environment) IsSet(k string) bool { - if e.unset[k] { - return false - } - if _, ok := e.init[k]; ok { - return true - } - if _, ok := e.set[k]; ok { - return true - } - return false -} - -// Get returns the value of the environment variable k, or defaultVal if it is -// not set. -func (e *Environment) Get(k, defaultVal string) string { - if e.unset[k] { - return defaultVal - } - if v, ok := e.set[k]; ok { - return v - } - if v, ok := e.init[k]; ok { - return v - } - return defaultVal -} - -// Apply applies all pending mutations to the environment. -func (e *Environment) Apply() error { - for k, v := range e.set { - if err := e.setenv(k, v); err != nil { - return fmt.Errorf("setting %q: %v", k, err) - } - e.init[k] = v - delete(e.set, k) - } - for k := range e.unset { - if err := e.unsetenv(k); err != nil { - return fmt.Errorf("unsetting %q: %v", k, err) - } - delete(e.init, k) - delete(e.unset, k) - } - return nil -} - -// Diff returns a string describing the pending mutations to the environment. -func (e *Environment) Diff() string { - lines := make([]string, 0, len(e.set)+len(e.unset)) - for k, v := range e.set { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) - } else { - lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) - } - } - for k := range e.unset { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) - } else { - lines = append(lines, fmt.Sprintf("%s= (was )", k)) - } - } - sort.Strings(lines) - return strings.Join(lines, "\n") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "os" + "sort" + "strings" +) + +// Environment starts from an initial set of environment variables, and tracks +// mutations to the environment. It can then apply those mutations to the +// environment, or produce debugging output that illustrates the changes it +// would make. +type Environment struct { + init map[string]string + set map[string]string + unset map[string]bool + + setenv func(string, string) error + unsetenv func(string) error +} + +// NewEnvironment returns an Environment initialized from os.Environ. +func NewEnvironment() *Environment { + init := map[string]string{} + for _, env := range os.Environ() { + fs := strings.SplitN(env, "=", 2) + if len(fs) != 2 { + panic("bad environ provided") + } + init[fs[0]] = fs[1] + } + + return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) +} + +func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { + return &Environment{ + init: init, + set: map[string]string{}, + unset: map[string]bool{}, + setenv: setenv, + unsetenv: unsetenv, + } +} + +// Set sets the environment variable k to v. +func (e *Environment) Set(k, v string) { + e.set[k] = v + delete(e.unset, k) +} + +// Unset removes the environment variable k. +func (e *Environment) Unset(k string) { + delete(e.set, k) + e.unset[k] = true +} + +// IsSet reports whether the environment variable k is set. +func (e *Environment) IsSet(k string) bool { + if e.unset[k] { + return false + } + if _, ok := e.init[k]; ok { + return true + } + if _, ok := e.set[k]; ok { + return true + } + return false +} + +// Get returns the value of the environment variable k, or defaultVal if it is +// not set. +func (e *Environment) Get(k, defaultVal string) string { + if e.unset[k] { + return defaultVal + } + if v, ok := e.set[k]; ok { + return v + } + if v, ok := e.init[k]; ok { + return v + } + return defaultVal +} + +// Apply applies all pending mutations to the environment. +func (e *Environment) Apply() error { + for k, v := range e.set { + if err := e.setenv(k, v); err != nil { + return fmt.Errorf("setting %q: %v", k, err) + } + e.init[k] = v + delete(e.set, k) + } + for k := range e.unset { + if err := e.unsetenv(k); err != nil { + return fmt.Errorf("unsetting %q: %v", k, err) + } + delete(e.init, k) + delete(e.unset, k) + } + return nil +} + +// Diff returns a string describing the pending mutations to the environment. +func (e *Environment) Diff() string { + lines := make([]string, 0, len(e.set)+len(e.unset)) + for k, v := range e.set { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) + } else { + lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) + } + } + for k := range e.unset { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) + } else { + lines = append(lines, fmt.Sprintf("%s= (was )", k)) + } + } + sort.Strings(lines) + return strings.Join(lines, "\n") +} diff --git a/tool/gocross/env_test.go b/tool/gocross/env_test.go index 9a797530d72cd..001487bb8e1a6 100644 --- a/tool/gocross/env_test.go +++ b/tool/gocross/env_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestEnv(t *testing.T) { - - var ( - init = map[string]string{ - "FOO": "bar", - } - - wasSet = map[string]string{} - wasUnset = map[string]bool{} - - setenv = func(k, v string) error { - wasSet[k] = v - return nil - } - unsetenv = func(k string) error { - wasUnset[k] = true - return nil - } - ) - - env := newEnvironmentForTest(init, setenv, unsetenv) - - if got, want := env.Get("FOO", ""), "bar"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - - if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), false; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - - env.Set("BAR", "quux") - if got, want := env.Get("BAR", ""), "quux"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), true; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - diff := "BAR=quux (was )" - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Set("FOO", "foo2") - if got, want := env.Get("FOO", ""), "foo2"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO=foo2 (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Unset("FOO") - if got, want := env.Get("FOO", "default"), "default"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), false; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO= (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - if err := env.Apply(); err != nil { - t.Fatalf("env.Apply() failed: %v", err) - } - - wantSet := map[string]string{"BAR": "quux"} - wantUnset := map[string]bool{"FOO": true} - - if diff := cmp.Diff(wasSet, wantSet); diff != "" { - t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) - } - if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { - t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestEnv(t *testing.T) { + + var ( + init = map[string]string{ + "FOO": "bar", + } + + wasSet = map[string]string{} + wasUnset = map[string]bool{} + + setenv = func(k, v string) error { + wasSet[k] = v + return nil + } + unsetenv = func(k string) error { + wasUnset[k] = true + return nil + } + ) + + env := newEnvironmentForTest(init, setenv, unsetenv) + + if got, want := env.Get("FOO", ""), "bar"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + + if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), false; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + + env.Set("BAR", "quux") + if got, want := env.Get("BAR", ""), "quux"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), true; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + diff := "BAR=quux (was )" + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Set("FOO", "foo2") + if got, want := env.Get("FOO", ""), "foo2"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO=foo2 (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Unset("FOO") + if got, want := env.Get("FOO", "default"), "default"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), false; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO= (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + if err := env.Apply(); err != nil { + t.Fatalf("env.Apply() failed: %v", err) + } + + wantSet := map[string]string{"BAR": "quux"} + wantUnset := map[string]bool{"FOO": true} + + if diff := cmp.Diff(wasSet, wantSet); diff != "" { + t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) + } + if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { + t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) + } +} diff --git a/tool/gocross/exec_other.go b/tool/gocross/exec_other.go index ec9663df7c7d9..8d4df0db334dd 100644 --- a/tool/gocross/exec_other.go +++ b/tool/gocross/exec_other.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !unix - -package main - -import ( - "os" - "os/exec" -) - -func doExec(cmd string, args []string, env []string) error { - c := exec.Command(cmd, args...) - c.Env = env - c.Stdin = os.Stdin - c.Stdout = os.Stdout - c.Stderr = os.Stderr - return c.Run() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !unix + +package main + +import ( + "os" + "os/exec" +) + +func doExec(cmd string, args []string, env []string) error { + c := exec.Command(cmd, args...) + c.Env = env + c.Stdin = os.Stdin + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return c.Run() +} diff --git a/tool/gocross/exec_unix.go b/tool/gocross/exec_unix.go index eeffd5f939aab..79cbf764ad2f6 100644 --- a/tool/gocross/exec_unix.go +++ b/tool/gocross/exec_unix.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package main - -import "golang.org/x/sys/unix" - -func doExec(cmd string, args []string, env []string) error { - return unix.Exec(cmd, args, env) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package main + +import "golang.org/x/sys/unix" + +func doExec(cmd string, args []string, env []string) error { + return unix.Exec(cmd, args, env) +} diff --git a/tool/helm b/tool/helm index 8cbc2f2065ead..3f9a9dfd5ba21 100755 --- a/tool/helm +++ b/tool/helm @@ -1,69 +1,69 @@ -#!/usr/bin/env bash - -# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-helm" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/helm.rev" - - got_rev="" - if [[ -x "${cachedir}/helm" ]]; then - got_rev=$("${cachedir}/helm" version --short) - got_rev="${got_rev#v}" # trim the leading 'v' - got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' - - - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_helm="$(which -a helm | grep /nix/store | head -1)" - nix_helm="${nix_helm%/helm}" - nix_helm_rev="${nix_helm##*-}" - if [[ "$nix_helm_rev" != "$want_rev" ]]; then - echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_helm" "$cachedir" - else - # works for linux and darwin - # https://github.com/helm/helm/releases - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="amd64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-helm:$PATH" -exec "$HOME/.cache/tailscale-helm/helm" "$@" +#!/usr/bin/env bash + +# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-helm" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/helm.rev" + + got_rev="" + if [[ -x "${cachedir}/helm" ]]; then + got_rev=$("${cachedir}/helm" version --short) + got_rev="${got_rev#v}" # trim the leading 'v' + got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' + + + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_helm="$(which -a helm | grep /nix/store | head -1)" + nix_helm="${nix_helm%/helm}" + nix_helm_rev="${nix_helm##*-}" + if [[ "$nix_helm_rev" != "$want_rev" ]]; then + echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_helm" "$cachedir" + else + # works for linux and darwin + # https://github.com/helm/helm/releases + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="amd64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-helm:$PATH" +exec "$HOME/.cache/tailscale-helm/helm" "$@" diff --git a/tool/helm.rev b/tool/helm.rev index 0d0e48dd05bb9..c10780c628ad5 100644 --- a/tool/helm.rev +++ b/tool/helm.rev @@ -1 +1 @@ -3.13.1 +3.13.1 diff --git a/tool/node b/tool/node index 7e96826f34988..310140ae5bfa0 100755 --- a/tool/node +++ b/tool/node @@ -1,65 +1,65 @@ -#!/usr/bin/env bash -# Run a command with our local node install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-node" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/node.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/node" ]]; then - got_rev=$("${cachedir}/bin/node" --version) - got_rev="${got_rev#v}" # trim the leading 'v' - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_node="$(which -a node | grep /nix/store | head -1)" - nix_node="${nix_node%/bin/node}" - nix_node_rev="${nix_node##*-}" - if [[ "$nix_node_rev" != "$want_rev" ]]; then - echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_node" "$cachedir" - else - # works for "linux" and "darwin" - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="x64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-node/bin:$PATH" -exec "$HOME/.cache/tailscale-node/bin/node" "$@" +#!/usr/bin/env bash +# Run a command with our local node install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-node" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/node.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/node" ]]; then + got_rev=$("${cachedir}/bin/node" --version) + got_rev="${got_rev#v}" # trim the leading 'v' + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_node="$(which -a node | grep /nix/store | head -1)" + nix_node="${nix_node%/bin/node}" + nix_node_rev="${nix_node##*-}" + if [[ "$nix_node_rev" != "$want_rev" ]]; then + echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_node" "$cachedir" + else + # works for "linux" and "darwin" + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="x64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-node/bin:$PATH" +exec "$HOME/.cache/tailscale-node/bin/node" "$@" diff --git a/tool/wasm-opt b/tool/wasm-opt index 88d332f0b2ca4..08f3e5bfbb841 100755 --- a/tool/wasm-opt +++ b/tool/wasm-opt @@ -1,74 +1,74 @@ -#!/bin/sh -# -# This script acts like the "wasm-opt" command from the Binaryen toolchain, but -# uses Tailscale's currently-desired version, downloading it first if necessary. - -set -eu - -BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" -read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" -# This works for Linux and Darwin, which is sufficient -# (we do not build for other targets). -OS=$(uname -s | tr A-Z a-z) -if [ "$OS" = "darwin" ]; then - # Binaryen uses the name "macos". - OS="macos" -fi -ARCH="$(uname -m)" -if [ "$ARCH" = "aarch64" ]; then - # Binaryen uses the name "arm64". - ARCH="arm64" -fi - -install_binaryen() { - BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" - install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL -} - -install_tool() { - TOOL=$1 - REV=$2 - TOOLCHAIN=$3 - URL=$4 - - archive="$TOOLCHAIN-$REV.tar.gz" - mark="$TOOLCHAIN.extracted" - extracted= - [ ! -e "$mark" ] || read -r extracted junk <$mark - - if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then - # Already extracted, continue silently - return 0 - fi - echo "" - - rm -f "$archive.new" "$TOOLCHAIN.extracted" - if [ ! -e "$archive" ]; then - log "Need to download $TOOL '$REV' from $URL." - curl -f -L -o "$archive.new" $URL - rm -f "$archive" - mv "$archive.new" "$archive" - fi - - log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 - rm -rf "$TOOLCHAIN" - mkdir -p "$TOOLCHAIN" - (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") - echo "$REV" >$mark -} - -log() { - echo "$@" >&2 -} - -if [ "${BINARYEN_DIR}" = "SKIP" ] || - [ "${OS}" != "macos" -a "${OS}" != "linux" ] || - [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then - log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." - log "Using existing wasm-opt (`which wasm-opt`)." - exec wasm-opt "$@" -fi - -install_binaryen - -"$BINARYEN_DIR/bin/wasm-opt" "$@" +#!/bin/sh +# +# This script acts like the "wasm-opt" command from the Binaryen toolchain, but +# uses Tailscale's currently-desired version, downloading it first if necessary. + +set -eu + +BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" +read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" +# This works for Linux and Darwin, which is sufficient +# (we do not build for other targets). +OS=$(uname -s | tr A-Z a-z) +if [ "$OS" = "darwin" ]; then + # Binaryen uses the name "macos". + OS="macos" +fi +ARCH="$(uname -m)" +if [ "$ARCH" = "aarch64" ]; then + # Binaryen uses the name "arm64". + ARCH="arm64" +fi + +install_binaryen() { + BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" + install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL +} + +install_tool() { + TOOL=$1 + REV=$2 + TOOLCHAIN=$3 + URL=$4 + + archive="$TOOLCHAIN-$REV.tar.gz" + mark="$TOOLCHAIN.extracted" + extracted= + [ ! -e "$mark" ] || read -r extracted junk <$mark + + if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then + # Already extracted, continue silently + return 0 + fi + echo "" + + rm -f "$archive.new" "$TOOLCHAIN.extracted" + if [ ! -e "$archive" ]; then + log "Need to download $TOOL '$REV' from $URL." + curl -f -L -o "$archive.new" $URL + rm -f "$archive" + mv "$archive.new" "$archive" + fi + + log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 + rm -rf "$TOOLCHAIN" + mkdir -p "$TOOLCHAIN" + (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") + echo "$REV" >$mark +} + +log() { + echo "$@" >&2 +} + +if [ "${BINARYEN_DIR}" = "SKIP" ] || + [ "${OS}" != "macos" -a "${OS}" != "linux" ] || + [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then + log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." + log "Using existing wasm-opt (`which wasm-opt`)." + exec wasm-opt "$@" +fi + +install_binaryen + +"$BINARYEN_DIR/bin/wasm-opt" "$@" diff --git a/tool/yarn b/tool/yarn index 6bb01d2f223de..6357beda61cb9 100755 --- a/tool/yarn +++ b/tool/yarn @@ -1,43 +1,43 @@ -#!/usr/bin/env bash -# Run a command with our local yarn install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - ./tool/node --version >/dev/null # Ensure node is unpacked and ready - - cachedir="$HOME/.cache/tailscale-yarn" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "./tool/yarn.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/yarn" ]]; then - got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - mkdir -p "$cachedir" - curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi -) - -# Deliberately not using cachedir here, to keep the environment -# completely pristine for execution of yarn. -export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" -exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" +#!/usr/bin/env bash +# Run a command with our local yarn install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + ./tool/node --version >/dev/null # Ensure node is unpacked and ready + + cachedir="$HOME/.cache/tailscale-yarn" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "./tool/yarn.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/yarn" ]]; then + got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + mkdir -p "$cachedir" + curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi +) + +# Deliberately not using cachedir here, to keep the environment +# completely pristine for execution of yarn. +export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" +exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" diff --git a/tool/yarn.rev b/tool/yarn.rev index 736c4acbded70..de5856e86ba27 100644 --- a/tool/yarn.rev +++ b/tool/yarn.rev @@ -1 +1 @@ -1.22.19 +1.22.19 diff --git a/tsnet/example/tshello/tshello.go b/tsnet/example/tshello/tshello.go index 2110c4d9699d8..0cadcdd837d99 100644 --- a/tsnet/example/tshello/tshello.go +++ b/tsnet/example/tshello/tshello.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "crypto/tls" - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", ":80", "address to listen on") -) - -func main() { - flag.Parse() - s := new(tsnet.Server) - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - if *addr == ":443" { - ln = tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, - }) - } - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, world!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", ":80", "address to listen on") +) + +func main() { + flag.Parse() + s := new(tsnet.Server) + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + if *addr == ":443" { + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: lc.GetCertificate, + }) + } + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, world!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} diff --git a/tsnet/example/tsnet-http-client/tsnet-http-client.go b/tsnet/example/tsnet-http-client/tsnet-http-client.go index cda52eef75ac1..9666fe9992745 100644 --- a/tsnet/example/tsnet-http-client/tsnet-http-client.go +++ b/tsnet/example/tsnet-http-client/tsnet-http-client.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "path/filepath" - - "tailscale.com/tsnet" -) - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) - os.Exit(2) - } - flag.Parse() - - if flag.NArg() != 1 { - flag.Usage() - } - tailnetURL := flag.Arg(0) - - s := new(tsnet.Server) - defer s.Close() - - if err := s.Start(); err != nil { - log.Fatal(err) - } - - cli := s.HTTPClient() - - resp, err := cli.Get(tailnetURL) - if err != nil { - log.Fatal(err) - } - - resp.Write(os.Stdout) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "path/filepath" + + "tailscale.com/tsnet" +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) + os.Exit(2) + } + flag.Parse() + + if flag.NArg() != 1 { + flag.Usage() + } + tailnetURL := flag.Arg(0) + + s := new(tsnet.Server) + defer s.Close() + + if err := s.Start(); err != nil { + log.Fatal(err) + } + + cli := s.HTTPClient() + + resp, err := cli.Get(tailnetURL) + if err != nil { + log.Fatal(err) + } + + resp.Write(os.Stdout) +} diff --git a/tsnet/example/web-client/web-client.go b/tsnet/example/web-client/web-client.go index dee7fedfab2ba..541efbaedf3d3 100644 --- a/tsnet/example/web-client/web-client.go +++ b/tsnet/example/web-client/web-client.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The web-client command demonstrates serving the Tailscale web client over tsnet. -package main - -import ( - "flag" - "log" - "net/http" - - "tailscale.com/client/web" - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") -) - -func main() { - flag.Parse() - - s := &tsnet.Server{RunWebClient: true} - defer s.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - // Serve the Tailscale web client. - ws, err := web.NewServer(web.ServerOpts{ - Mode: web.LoginServerMode, - LocalClient: lc, - }) - if err != nil { - log.Fatal(err) - } - defer ws.Shutdown() - log.Printf("Serving Tailscale web client on http://%s", *addr) - if err := http.ListenAndServe(*addr, ws); err != nil { - if err != http.ErrServerClosed { - log.Fatal(err) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The web-client command demonstrates serving the Tailscale web client over tsnet. +package main + +import ( + "flag" + "log" + "net/http" + + "tailscale.com/client/web" + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") +) + +func main() { + flag.Parse() + + s := &tsnet.Server{RunWebClient: true} + defer s.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + // Serve the Tailscale web client. + ws, err := web.NewServer(web.ServerOpts{ + Mode: web.LoginServerMode, + LocalClient: lc, + }) + if err != nil { + log.Fatal(err) + } + defer ws.Shutdown() + log.Printf("Serving Tailscale web client on http://%s", *addr) + if err := http.ListenAndServe(*addr, ws); err != nil { + if err != http.ErrServerClosed { + log.Fatal(err) + } + } +} diff --git a/tsnet/example_tshello_test.go b/tsnet/example_tshello_test.go index 4dec482339e2c..d534bcfd1f1d4 100644 --- a/tsnet/example_tshello_test.go +++ b/tsnet/example_tshello_test.go @@ -1,72 +1,72 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsnet_test - -import ( - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} - -// Example_tshello is a full example on using tsnet. When you run this program it will print -// an authentication link. Open it in your favorite web browser and add it to your tailnet -// like any other machine. Open another terminal window and try to ping it: -// -// $ ping tshello -c 2 -// PING tshello (100.105.183.159) 56(84) bytes of data. -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms -// -// Then connect to it using curl: -// -// $ curl http://tshello -//

Hello, world!

-//

You are Xe from pneuma (100.78.40.86:49214)

-// -// From here you can do anything you want with the Go standard library HTTP stack, or anything -// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). -func Example_tshello() { - var ( - addr = flag.String("addr", ":80", "address to listen on") - hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") - ) - - flag.Parse() - s := new(tsnet.Server) - s.Hostname = *hostname - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, tailnet!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsnet_test + +import ( + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} + +// Example_tshello is a full example on using tsnet. When you run this program it will print +// an authentication link. Open it in your favorite web browser and add it to your tailnet +// like any other machine. Open another terminal window and try to ping it: +// +// $ ping tshello -c 2 +// PING tshello (100.105.183.159) 56(84) bytes of data. +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms +// +// Then connect to it using curl: +// +// $ curl http://tshello +//

Hello, world!

+//

You are Xe from pneuma (100.78.40.86:49214)

+// +// From here you can do anything you want with the Go standard library HTTP stack, or anything +// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). +func Example_tshello() { + var ( + addr = flag.String("addr", ":80", "address to listen on") + hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") + ) + + flag.Parse() + s := new(tsnet.Server) + s.Hostname = *hostname + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, tailnet!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} diff --git a/tstest/allocs.go b/tstest/allocs.go index a6d9c79f69ff7..f15a00508d87f 100644 --- a/tstest/allocs.go +++ b/tstest/allocs.go @@ -1,50 +1,50 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "fmt" - "runtime" - "testing" - "time" -) - -// MinAllocsPerRun asserts that f can run with no more than target allocations. -// It runs f up to 1000 times or 5s, whichever happens first. -// If f has executed more than target allocations on every run, it returns a non-nil error. -// -// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores -// it before returning. -func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) - - var memstats runtime.MemStats - var min, max, sum uint64 - start := time.Now() - var iters int - for { - runtime.ReadMemStats(&memstats) - startMallocs := memstats.Mallocs - f() - runtime.ReadMemStats(&memstats) - mallocs := memstats.Mallocs - startMallocs - // TODO: if mallocs < target, return an error? See discussion in #3204. - if mallocs <= target { - return nil - } - if min == 0 || mallocs < min { - min = mallocs - } - if mallocs > max { - max = mallocs - } - sum += mallocs - iters++ - if iters == 1000 || time.Since(start) > 5*time.Second { - break - } - } - - return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "fmt" + "runtime" + "testing" + "time" +) + +// MinAllocsPerRun asserts that f can run with no more than target allocations. +// It runs f up to 1000 times or 5s, whichever happens first. +// If f has executed more than target allocations on every run, it returns a non-nil error. +// +// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores +// it before returning. +func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + + var memstats runtime.MemStats + var min, max, sum uint64 + start := time.Now() + var iters int + for { + runtime.ReadMemStats(&memstats) + startMallocs := memstats.Mallocs + f() + runtime.ReadMemStats(&memstats) + mallocs := memstats.Mallocs - startMallocs + // TODO: if mallocs < target, return an error? See discussion in #3204. + if mallocs <= target { + return nil + } + if min == 0 || mallocs < min { + min = mallocs + } + if mallocs > max { + max = mallocs + } + sum += mallocs + iters++ + if iters == 1000 || time.Since(start) > 5*time.Second { + break + } + } + + return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) +} diff --git a/tstest/archtest/qemu_test.go b/tstest/archtest/qemu_test.go index cea3b4b8e9b53..8b59ae5d9fee1 100644 --- a/tstest/archtest/qemu_test.go +++ b/tstest/archtest/qemu_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && amd64 && !race - -package archtest - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "strings" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestInQemu(t *testing.T) { - t.Parallel() - type Arch struct { - Goarch string // GOARCH value - Qarch string // qemu name - } - arches := []Arch{ - {"arm", "arm"}, - {"arm64", "aarch64"}, - {"mips", "mips"}, - {"mipsle", "mipsel"}, - {"mips64", "mips64"}, - {"mips64le", "mips64el"}, - {"386", "386"}, - } - inCI := cibuild.On() - for _, arch := range arches { - arch := arch - t.Run(arch.Goarch, func(t *testing.T) { - t.Parallel() - qemuUser := "qemu-" + arch.Qarch - execVia := qemuUser - if arch.Goarch == "386" { - execVia = "" // amd64 can run it fine - } else { - look, err := exec.LookPath(qemuUser) - if err != nil { - if inCI { - t.Fatalf("in CI and qemu not available: %v", err) - } - t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) - } - t.Logf("using %v", look) - } - cmd := exec.Command("go", - "test", - "--exec="+execVia, - "-v", - "tailscale.com/tstest/archtest", - ) - cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) - out, err := cmd.CombinedOutput() - if err != nil { - if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { - t.Skip("skipping; qemu too old. use 5.x.") - } - t.Errorf("failed: %s", out) - } - sub := fmt.Sprintf("I am linux/%s", arch.Goarch) - if !bytes.Contains(out, []byte(sub)) { - t.Errorf("output didn't contain %q: %s", sub, out) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && amd64 && !race + +package archtest + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestInQemu(t *testing.T) { + t.Parallel() + type Arch struct { + Goarch string // GOARCH value + Qarch string // qemu name + } + arches := []Arch{ + {"arm", "arm"}, + {"arm64", "aarch64"}, + {"mips", "mips"}, + {"mipsle", "mipsel"}, + {"mips64", "mips64"}, + {"mips64le", "mips64el"}, + {"386", "386"}, + } + inCI := cibuild.On() + for _, arch := range arches { + arch := arch + t.Run(arch.Goarch, func(t *testing.T) { + t.Parallel() + qemuUser := "qemu-" + arch.Qarch + execVia := qemuUser + if arch.Goarch == "386" { + execVia = "" // amd64 can run it fine + } else { + look, err := exec.LookPath(qemuUser) + if err != nil { + if inCI { + t.Fatalf("in CI and qemu not available: %v", err) + } + t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) + } + t.Logf("using %v", look) + } + cmd := exec.Command("go", + "test", + "--exec="+execVia, + "-v", + "tailscale.com/tstest/archtest", + ) + cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) + out, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { + t.Skip("skipping; qemu too old. use 5.x.") + } + t.Errorf("failed: %s", out) + } + sub := fmt.Sprintf("I am linux/%s", arch.Goarch) + if !bytes.Contains(out, []byte(sub)) { + t.Errorf("output didn't contain %q: %s", sub, out) + } + }) + } +} diff --git a/tstest/clock.go b/tstest/clock.go index 48684957ec421..ee7523430ff54 100644 --- a/tstest/clock.go +++ b/tstest/clock.go @@ -1,694 +1,694 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "container/heap" - "sync" - "time" - - "tailscale.com/tstime" - "tailscale.com/util/mak" -) - -// ClockOpts is used to configure the initial settings for a Clock. Once the -// settings are configured as desired, call NewClock to get the resulting Clock. -type ClockOpts struct { - // Start is the starting time for the Clock. When FollowRealTime is false, - // Start is also the value that will be returned by the first call - // to Clock.Now. - Start time.Time - // Step is the amount of time the Clock will advance whenever Clock.Now is - // called. If set to zero, the Clock will only advance when Clock.Advance is - // called and/or if FollowRealTime is true. - // - // FollowRealTime and Step cannot be enabled at the same time. - Step time.Duration - - // TimerChannelSize configures the maximum buffered ticks that are - // permitted in the channel of any Timer and Ticker created by this Clock. - // The special value 0 means to use the default of 1. The buffer may need to - // be increased if time is advanced by more than a single tick and proper - // functioning of the test requires that the ticks are not lost. - TimerChannelSize int - - // FollowRealTime makes the simulated time increment along with real time. - // It is a compromise between determinism and the difficulty of explicitly - // managing the simulated time via Step or Clock.Advance. When - // FollowRealTime is set, calls to Now() and PeekNow() will add the - // elapsed real-world time to the simulated time. - // - // FollowRealTime and Step cannot be enabled at the same time. - FollowRealTime bool -} - -// NewClock creates a Clock with the specified settings. To create a -// Clock with only the default settings, new(Clock) is equivalent, except that -// the start time will not be computed until one of the receivers is called. -func NewClock(co ClockOpts) *Clock { - if co.FollowRealTime && co.Step != 0 { - panic("only one of FollowRealTime and Step are allowed in NewClock") - } - - return newClockInternal(co, nil) -} - -// newClockInternal creates a Clock with the specified settings and allows -// specifying a non-standard realTimeClock. -func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { - if !co.FollowRealTime && rtClock != nil { - panic("rtClock can only be set with FollowRealTime enabled") - } - - if co.FollowRealTime && rtClock == nil { - rtClock = new(tstime.StdClock) - } - - c := &Clock{ - start: co.Start, - realTimeClock: rtClock, - step: co.Step, - timerChannelSize: co.TimerChannelSize, - } - c.init() // init now to capture the current time when co.Start.IsZero() - return c -} - -// Clock is a testing clock that advances every time its Now method is -// called, beginning at its start time. If no start time is specified using -// ClockBuilder, an arbitrary start time will be selected when the Clock is -// created and can be retrieved by calling Clock.Start(). -type Clock struct { - // start is the first value returned by Now. It must not be modified after - // init is called. - start time.Time - - // realTimeClock, if not nil, indicates that the Clock shall move forward - // according to realTimeClock + the accumulated calls to Advance. This can - // make writing tests easier that require some control over the clock but do - // not need exact control over the clock. While step can also be used for - // this purpose, it is harder to control how quickly time moves using step. - realTimeClock tstime.Clock - - initOnce sync.Once - mu sync.Mutex - - // step is how much to advance with each Now call. - step time.Duration - // present is the last value returned by Now (and will be returned again by - // PeekNow). - present time.Time - // realTime is the time from realTimeClock corresponding to the current - // value of present. - realTime time.Time - // skipStep indicates that the next call to Now should not add step to - // present. This occurs after initialization and after Advance. - skipStep bool - // timerChannelSize is the buffer size to use for channels created by - // NewTimer and NewTicker. - timerChannelSize int - - events eventManager -} - -func (c *Clock) init() { - c.initOnce.Do(func() { - if c.realTimeClock != nil { - c.realTime = c.realTimeClock.Now() - } - if c.start.IsZero() { - if c.realTime.IsZero() { - c.start = time.Now() - } else { - c.start = c.realTime - } - } - if c.timerChannelSize == 0 { - c.timerChannelSize = 1 - } - c.present = c.start - c.skipStep = true - c.events.AdvanceTo(c.present) - }) -} - -// Now returns the virtual clock's current time, and advances it -// according to its step configuration. -func (c *Clock) Now() time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - step := c.step - if c.skipStep { - step = 0 - c.skipStep = false - } - c.advanceLocked(rt, step) - - return c.present -} - -func (c *Clock) maybeGetRealTime() time.Time { - if c.realTimeClock == nil { - return time.Time{} - } - return c.realTimeClock.Now() -} - -func (c *Clock) advanceLocked(now time.Time, add time.Duration) { - if !now.IsZero() { - add += now.Sub(c.realTime) - c.realTime = now - } - if add == 0 { - return - } - c.present = c.present.Add(add) - c.events.AdvanceTo(c.present) -} - -// PeekNow returns the last time reported by Now. If Now has never been called, -// PeekNow returns the same value as GetStart. -func (c *Clock) PeekNow() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.present -} - -// Advance moves simulated time forward or backwards by a relative amount. Any -// Timer or Ticker that is waiting will fire at the requested point in simulated -// time. Advance returns the new simulated time. If this Clock follows real time -// then the next call to Now will equal the return value of Advance + the -// elapsed time since calling Advance. Otherwise, the next call to Now will -// equal the return value of Advance, regardless of the current step. -func (c *Clock) Advance(d time.Duration) time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - - c.advanceLocked(rt, d) - return c.present -} - -// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker -// that is waiting will fire at the requested point in simulated time. If this -// Clock follows real time then the next call to Now will equal t + the elapsed -// time since calling Advance. Otherwise, the next call to Now will equal t, -// regardless of the configured step. -func (c *Clock) AdvanceTo(t time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - c.realTime = rt - c.present = t - c.events.AdvanceTo(c.present) -} - -// GetStart returns the initial simulated time when this Clock was created. -func (c *Clock) GetStart() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.start -} - -// GetStep returns the amount that simulated time advances on every call to Now. -func (c *Clock) GetStep() time.Duration { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.step -} - -// SetStep updates the amount that simulated time advances on every call to Now. -func (c *Clock) SetStep(d time.Duration) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.step = d -} - -// SetTimerChannelSize changes the channel size for any Timer or Ticker created -// in the future. It does not affect those that were already created. -func (c *Clock) SetTimerChannelSize(n int) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.timerChannelSize = n -} - -// NewTicker returns a Ticker that uses this Clock for accessing the current -// time. -func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Ticker{ - nextTrigger: c.present.Add(d), - period: d, - em: &c.events, - } - t.init(c.timerChannelSize) - return t, t.C -} - -// NewTimer returns a Timer that uses this Clock for accessing the current -// time. -func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, nil) - return t, t.C -} - -// AfterFunc returns a Timer that calls f when it fires, using this Clock for -// accessing the current time. -func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, f) - return t -} - -// Since subtracts specified duration from Now(). -func (c *Clock) Since(t time.Time) time.Duration { - return c.Now().Sub(t) -} - -// eventHandler offers a common interface for Timer and Ticker events to avoid -// code duplication in eventManager. -type eventHandler interface { - // Fire signals the event. The provided time is written to the event's - // channel as the current time. The return value is the next time this event - // should fire, otherwise if it is zero then the event will be removed from - // the eventManager. - Fire(time.Time) time.Time -} - -// event tracks details about an upcoming Timer or Ticker firing. -type event struct { - position int // The current index in the heap, needed for heap.Fix and heap.Remove. - when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. - eh eventHandler -} - -// eventManager tracks pending events created by Timer and Ticker. eventManager -// implements heap.Interface for efficient lookups of the next event. -type eventManager struct { - // clock is a real time clock for scheduling events with. When clock is nil, - // events only fire when AdvanceTo is called by the simulated clock that - // this eventManager belongs to. When clock is not nil, events may fire when - // timer triggers. - clock tstime.Clock - - mu sync.Mutex - now time.Time - heap []*event - reverseLookup map[eventHandler]*event - - // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to - // the time represented by clock. In other words, if clock is real world - // time, then if an event is scheduled 1 second into the future in the - // simulated time, then the event will trigger after 1 second of actual test - // execution time (unless the test advances simulated time, in which case - // the timer is updated accordingly). This makes tests easier to write in - // situations where the simulated time only needs to be partially - // controlled, and the test writer wishes for simulated time to pass with an - // offset but still synchronized with the real world. - // - // In the future, this could be extended to allow simulated time to run at a - // multiple of real world time. - timer tstime.TimerController -} - -func (em *eventManager) handleTimer() { - rt := em.clock.Now() - em.AdvanceTo(rt) -} - -// Push implements heap.Interface.Push and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Push(x any) { - e, ok := x.(*event) - if !ok { - panic("incorrect event type") - } - if e == nil { - panic("nil event") - } - - mak.Set(&em.reverseLookup, e.eh, e) - e.position = len(em.heap) - em.heap = append(em.heap, e) -} - -// Pop implements heap.Interface.Pop and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Pop() any { - e := em.heap[len(em.heap)-1] - em.heap = em.heap[:len(em.heap)-1] - delete(em.reverseLookup, e.eh) - return e -} - -// Len implements sort.Interface.Len and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Len() int { - return len(em.heap) -} - -// Less implements sort.Interface.Less and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Less(i, j int) bool { - return em.heap[i].when.Before(em.heap[j].when) -} - -// Swap implements sort.Interface.Swap and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Swap(i, j int) { - em.heap[i], em.heap[j] = em.heap[j], em.heap[i] - em.heap[i].position = i - em.heap[j].position = j -} - -// Reschedule adds/updates/deletes an event in the heap, whichever -// operation is applicable (use a zero time to delete). -func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - e, ok := em.reverseLookup[eh] - if !ok { - if t.IsZero() { - // eh is not scheduled and also not active, so do nothing. - return - } - // eh is not scheduled but is active, so add it. - heap.Push(em, &event{ - when: t, - eh: eh, - }) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). - return - } - - if t.IsZero() { - // e is scheduled but not active, so remove it. - heap.Remove(em, e.position) - return - } - - // e is scheduled and active, so update it. - e.when = t - heap.Fix(em, e.position) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). -} - -// AdvanceTo updates the current time to tm and fires all events scheduled -// before or equal to tm. When an event fires, it may request rescheduling and -// the rescheduled events will be combined with the other existing events that -// are waiting, and will be run in the unified ordering. A poorly behaved event -// may theoretically prevent this from ever completing, but both Timer and -// Ticker require positive steps into the future. -func (em *eventManager) AdvanceTo(tm time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - em.processEventsLocked(tm) - em.now = tm -} - -// Now returns the cached current time. It is intended for use by a Timer or -// Ticker that needs to convert a relative time to an absolute time. -func (em *eventManager) Now() time.Time { - em.mu.Lock() - defer em.mu.Unlock() - return em.now -} - -func (em *eventManager) processEventsLocked(tm time.Time) { - for len(em.heap) > 0 && !em.heap[0].when.After(tm) { - // Ideally some jitter would be added here but it's difficult to do so - // in a deterministic fashion. - em.now = em.heap[0].when - - if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { - em.heap[0].when = nextFire - heap.Fix(em, 0) - } else { - heap.Pop(em) - } - } -} - -func (em *eventManager) updateTimerLocked() { - if em.clock == nil { - return - } - if len(em.heap) == 0 { - if em.timer != nil { - em.timer.Stop() - } - return - } - - timeToEvent := em.heap[0].when.Sub(em.now) - if em.timer == nil { - em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) - return - } - em.timer.Reset(timeToEvent) -} - -// Ticker is a time.Ticker lookalike for use in tests that need to control when -// events fire. Ticker could be made standalone in future but for now is -// expected to be paired with a Clock and created by Clock.NewTicker. -type Ticker struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - c chan<- time.Time // The writer side of C. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time - - // period is the duration that is added to nextTrigger when the ticker - // fires. - period time.Duration -} - -func (t *Ticker) init(channelSize int) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.c = c - t.C = c - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Ticker) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - select { - case t.c <- curTime: - default: - } - t.nextTrigger = t.nextTrigger.Add(t.period) - - return t.nextTrigger -} - -// Reset adjusts the Ticker's period to d and reschedules the next fire time to -// the current simulated time + d. -func (t *Ticker) Reset(d time.Duration) { - if d <= 0 { - // The standard time.Ticker requires a positive period. - panic("non-positive period for Ticker.Reset") - } - - now := t.em.Now() - - t.mu.Lock() - t.resetLocked(now.Add(d), d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire -// time to nextTrigger. -func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - if d <= 0 { - panic("non-positive period for ResetAbsolute") - } - - t.mu.Lock() - t.resetLocked(nextTrigger, d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { - t.nextTrigger = nextTrigger - t.period = d -} - -// Stop deactivates the Ticker. -func (t *Ticker) Stop() { - t.mu.Lock() - t.nextTrigger = time.Time{} - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// Timer is a time.Timer lookalike for use in tests that need to control when -// events fire. Timer could be made standalone in future but for now must be -// paired with a Clock and created by Clock.NewTimer. -type Timer struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - f func(time.Time) // The function to call when the timer expires. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time -} - -func (t *Timer) init(channelSize int, afterFunc func()) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.C = c - if afterFunc == nil { - t.f = func(curTime time.Time) { - select { - case c <- curTime: - default: - } - } - } else { - t.f = func(_ time.Time) { afterFunc() } - } - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Timer) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - t.nextTrigger = time.Time{} - t.f(curTime) - return time.Time{} -} - -// Reset reschedules the next fire time to the current simulated time + d. -// Reset reports whether the timer was still active before the reset. -func (t *Timer) Reset(d time.Duration) bool { - if d <= 0 { - // The standard time.Timer requires a positive delay. - panic("non-positive delay for Timer.Reset") - } - - return t.reset(t.em.Now().Add(d)) -} - -// ResetAbsolute reschedules the next fire time to nextTrigger. -// ResetAbsolute reports whether the timer was still active before the reset. -func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - - return t.reset(nextTrigger) -} - -// Stop deactivates the Timer. Stop reports whether the timer was active before -// stopping. -func (t *Timer) Stop() bool { - return t.reset(time.Time{}) -} - -func (t *Timer) reset(nextTrigger time.Time) bool { - t.mu.Lock() - wasActive := !t.nextTrigger.IsZero() - t.nextTrigger = nextTrigger - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) - return wasActive -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "container/heap" + "sync" + "time" + + "tailscale.com/tstime" + "tailscale.com/util/mak" +) + +// ClockOpts is used to configure the initial settings for a Clock. Once the +// settings are configured as desired, call NewClock to get the resulting Clock. +type ClockOpts struct { + // Start is the starting time for the Clock. When FollowRealTime is false, + // Start is also the value that will be returned by the first call + // to Clock.Now. + Start time.Time + // Step is the amount of time the Clock will advance whenever Clock.Now is + // called. If set to zero, the Clock will only advance when Clock.Advance is + // called and/or if FollowRealTime is true. + // + // FollowRealTime and Step cannot be enabled at the same time. + Step time.Duration + + // TimerChannelSize configures the maximum buffered ticks that are + // permitted in the channel of any Timer and Ticker created by this Clock. + // The special value 0 means to use the default of 1. The buffer may need to + // be increased if time is advanced by more than a single tick and proper + // functioning of the test requires that the ticks are not lost. + TimerChannelSize int + + // FollowRealTime makes the simulated time increment along with real time. + // It is a compromise between determinism and the difficulty of explicitly + // managing the simulated time via Step or Clock.Advance. When + // FollowRealTime is set, calls to Now() and PeekNow() will add the + // elapsed real-world time to the simulated time. + // + // FollowRealTime and Step cannot be enabled at the same time. + FollowRealTime bool +} + +// NewClock creates a Clock with the specified settings. To create a +// Clock with only the default settings, new(Clock) is equivalent, except that +// the start time will not be computed until one of the receivers is called. +func NewClock(co ClockOpts) *Clock { + if co.FollowRealTime && co.Step != 0 { + panic("only one of FollowRealTime and Step are allowed in NewClock") + } + + return newClockInternal(co, nil) +} + +// newClockInternal creates a Clock with the specified settings and allows +// specifying a non-standard realTimeClock. +func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { + if !co.FollowRealTime && rtClock != nil { + panic("rtClock can only be set with FollowRealTime enabled") + } + + if co.FollowRealTime && rtClock == nil { + rtClock = new(tstime.StdClock) + } + + c := &Clock{ + start: co.Start, + realTimeClock: rtClock, + step: co.Step, + timerChannelSize: co.TimerChannelSize, + } + c.init() // init now to capture the current time when co.Start.IsZero() + return c +} + +// Clock is a testing clock that advances every time its Now method is +// called, beginning at its start time. If no start time is specified using +// ClockBuilder, an arbitrary start time will be selected when the Clock is +// created and can be retrieved by calling Clock.Start(). +type Clock struct { + // start is the first value returned by Now. It must not be modified after + // init is called. + start time.Time + + // realTimeClock, if not nil, indicates that the Clock shall move forward + // according to realTimeClock + the accumulated calls to Advance. This can + // make writing tests easier that require some control over the clock but do + // not need exact control over the clock. While step can also be used for + // this purpose, it is harder to control how quickly time moves using step. + realTimeClock tstime.Clock + + initOnce sync.Once + mu sync.Mutex + + // step is how much to advance with each Now call. + step time.Duration + // present is the last value returned by Now (and will be returned again by + // PeekNow). + present time.Time + // realTime is the time from realTimeClock corresponding to the current + // value of present. + realTime time.Time + // skipStep indicates that the next call to Now should not add step to + // present. This occurs after initialization and after Advance. + skipStep bool + // timerChannelSize is the buffer size to use for channels created by + // NewTimer and NewTicker. + timerChannelSize int + + events eventManager +} + +func (c *Clock) init() { + c.initOnce.Do(func() { + if c.realTimeClock != nil { + c.realTime = c.realTimeClock.Now() + } + if c.start.IsZero() { + if c.realTime.IsZero() { + c.start = time.Now() + } else { + c.start = c.realTime + } + } + if c.timerChannelSize == 0 { + c.timerChannelSize = 1 + } + c.present = c.start + c.skipStep = true + c.events.AdvanceTo(c.present) + }) +} + +// Now returns the virtual clock's current time, and advances it +// according to its step configuration. +func (c *Clock) Now() time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + step := c.step + if c.skipStep { + step = 0 + c.skipStep = false + } + c.advanceLocked(rt, step) + + return c.present +} + +func (c *Clock) maybeGetRealTime() time.Time { + if c.realTimeClock == nil { + return time.Time{} + } + return c.realTimeClock.Now() +} + +func (c *Clock) advanceLocked(now time.Time, add time.Duration) { + if !now.IsZero() { + add += now.Sub(c.realTime) + c.realTime = now + } + if add == 0 { + return + } + c.present = c.present.Add(add) + c.events.AdvanceTo(c.present) +} + +// PeekNow returns the last time reported by Now. If Now has never been called, +// PeekNow returns the same value as GetStart. +func (c *Clock) PeekNow() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.present +} + +// Advance moves simulated time forward or backwards by a relative amount. Any +// Timer or Ticker that is waiting will fire at the requested point in simulated +// time. Advance returns the new simulated time. If this Clock follows real time +// then the next call to Now will equal the return value of Advance + the +// elapsed time since calling Advance. Otherwise, the next call to Now will +// equal the return value of Advance, regardless of the current step. +func (c *Clock) Advance(d time.Duration) time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + + c.advanceLocked(rt, d) + return c.present +} + +// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker +// that is waiting will fire at the requested point in simulated time. If this +// Clock follows real time then the next call to Now will equal t + the elapsed +// time since calling Advance. Otherwise, the next call to Now will equal t, +// regardless of the configured step. +func (c *Clock) AdvanceTo(t time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + c.realTime = rt + c.present = t + c.events.AdvanceTo(c.present) +} + +// GetStart returns the initial simulated time when this Clock was created. +func (c *Clock) GetStart() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.start +} + +// GetStep returns the amount that simulated time advances on every call to Now. +func (c *Clock) GetStep() time.Duration { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.step +} + +// SetStep updates the amount that simulated time advances on every call to Now. +func (c *Clock) SetStep(d time.Duration) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.step = d +} + +// SetTimerChannelSize changes the channel size for any Timer or Ticker created +// in the future. It does not affect those that were already created. +func (c *Clock) SetTimerChannelSize(n int) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.timerChannelSize = n +} + +// NewTicker returns a Ticker that uses this Clock for accessing the current +// time. +func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Ticker{ + nextTrigger: c.present.Add(d), + period: d, + em: &c.events, + } + t.init(c.timerChannelSize) + return t, t.C +} + +// NewTimer returns a Timer that uses this Clock for accessing the current +// time. +func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, nil) + return t, t.C +} + +// AfterFunc returns a Timer that calls f when it fires, using this Clock for +// accessing the current time. +func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, f) + return t +} + +// Since subtracts specified duration from Now(). +func (c *Clock) Since(t time.Time) time.Duration { + return c.Now().Sub(t) +} + +// eventHandler offers a common interface for Timer and Ticker events to avoid +// code duplication in eventManager. +type eventHandler interface { + // Fire signals the event. The provided time is written to the event's + // channel as the current time. The return value is the next time this event + // should fire, otherwise if it is zero then the event will be removed from + // the eventManager. + Fire(time.Time) time.Time +} + +// event tracks details about an upcoming Timer or Ticker firing. +type event struct { + position int // The current index in the heap, needed for heap.Fix and heap.Remove. + when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. + eh eventHandler +} + +// eventManager tracks pending events created by Timer and Ticker. eventManager +// implements heap.Interface for efficient lookups of the next event. +type eventManager struct { + // clock is a real time clock for scheduling events with. When clock is nil, + // events only fire when AdvanceTo is called by the simulated clock that + // this eventManager belongs to. When clock is not nil, events may fire when + // timer triggers. + clock tstime.Clock + + mu sync.Mutex + now time.Time + heap []*event + reverseLookup map[eventHandler]*event + + // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to + // the time represented by clock. In other words, if clock is real world + // time, then if an event is scheduled 1 second into the future in the + // simulated time, then the event will trigger after 1 second of actual test + // execution time (unless the test advances simulated time, in which case + // the timer is updated accordingly). This makes tests easier to write in + // situations where the simulated time only needs to be partially + // controlled, and the test writer wishes for simulated time to pass with an + // offset but still synchronized with the real world. + // + // In the future, this could be extended to allow simulated time to run at a + // multiple of real world time. + timer tstime.TimerController +} + +func (em *eventManager) handleTimer() { + rt := em.clock.Now() + em.AdvanceTo(rt) +} + +// Push implements heap.Interface.Push and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Push(x any) { + e, ok := x.(*event) + if !ok { + panic("incorrect event type") + } + if e == nil { + panic("nil event") + } + + mak.Set(&em.reverseLookup, e.eh, e) + e.position = len(em.heap) + em.heap = append(em.heap, e) +} + +// Pop implements heap.Interface.Pop and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Pop() any { + e := em.heap[len(em.heap)-1] + em.heap = em.heap[:len(em.heap)-1] + delete(em.reverseLookup, e.eh) + return e +} + +// Len implements sort.Interface.Len and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Len() int { + return len(em.heap) +} + +// Less implements sort.Interface.Less and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Less(i, j int) bool { + return em.heap[i].when.Before(em.heap[j].when) +} + +// Swap implements sort.Interface.Swap and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Swap(i, j int) { + em.heap[i], em.heap[j] = em.heap[j], em.heap[i] + em.heap[i].position = i + em.heap[j].position = j +} + +// Reschedule adds/updates/deletes an event in the heap, whichever +// operation is applicable (use a zero time to delete). +func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + e, ok := em.reverseLookup[eh] + if !ok { + if t.IsZero() { + // eh is not scheduled and also not active, so do nothing. + return + } + // eh is not scheduled but is active, so add it. + heap.Push(em, &event{ + when: t, + eh: eh, + }) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). + return + } + + if t.IsZero() { + // e is scheduled but not active, so remove it. + heap.Remove(em, e.position) + return + } + + // e is scheduled and active, so update it. + e.when = t + heap.Fix(em, e.position) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). +} + +// AdvanceTo updates the current time to tm and fires all events scheduled +// before or equal to tm. When an event fires, it may request rescheduling and +// the rescheduled events will be combined with the other existing events that +// are waiting, and will be run in the unified ordering. A poorly behaved event +// may theoretically prevent this from ever completing, but both Timer and +// Ticker require positive steps into the future. +func (em *eventManager) AdvanceTo(tm time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + em.processEventsLocked(tm) + em.now = tm +} + +// Now returns the cached current time. It is intended for use by a Timer or +// Ticker that needs to convert a relative time to an absolute time. +func (em *eventManager) Now() time.Time { + em.mu.Lock() + defer em.mu.Unlock() + return em.now +} + +func (em *eventManager) processEventsLocked(tm time.Time) { + for len(em.heap) > 0 && !em.heap[0].when.After(tm) { + // Ideally some jitter would be added here but it's difficult to do so + // in a deterministic fashion. + em.now = em.heap[0].when + + if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { + em.heap[0].when = nextFire + heap.Fix(em, 0) + } else { + heap.Pop(em) + } + } +} + +func (em *eventManager) updateTimerLocked() { + if em.clock == nil { + return + } + if len(em.heap) == 0 { + if em.timer != nil { + em.timer.Stop() + } + return + } + + timeToEvent := em.heap[0].when.Sub(em.now) + if em.timer == nil { + em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) + return + } + em.timer.Reset(timeToEvent) +} + +// Ticker is a time.Ticker lookalike for use in tests that need to control when +// events fire. Ticker could be made standalone in future but for now is +// expected to be paired with a Clock and created by Clock.NewTicker. +type Ticker struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + c chan<- time.Time // The writer side of C. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time + + // period is the duration that is added to nextTrigger when the ticker + // fires. + period time.Duration +} + +func (t *Ticker) init(channelSize int) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.c = c + t.C = c + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Ticker) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + select { + case t.c <- curTime: + default: + } + t.nextTrigger = t.nextTrigger.Add(t.period) + + return t.nextTrigger +} + +// Reset adjusts the Ticker's period to d and reschedules the next fire time to +// the current simulated time + d. +func (t *Ticker) Reset(d time.Duration) { + if d <= 0 { + // The standard time.Ticker requires a positive period. + panic("non-positive period for Ticker.Reset") + } + + now := t.em.Now() + + t.mu.Lock() + t.resetLocked(now.Add(d), d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire +// time to nextTrigger. +func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + if d <= 0 { + panic("non-positive period for ResetAbsolute") + } + + t.mu.Lock() + t.resetLocked(nextTrigger, d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { + t.nextTrigger = nextTrigger + t.period = d +} + +// Stop deactivates the Ticker. +func (t *Ticker) Stop() { + t.mu.Lock() + t.nextTrigger = time.Time{} + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// Timer is a time.Timer lookalike for use in tests that need to control when +// events fire. Timer could be made standalone in future but for now must be +// paired with a Clock and created by Clock.NewTimer. +type Timer struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + f func(time.Time) // The function to call when the timer expires. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time +} + +func (t *Timer) init(channelSize int, afterFunc func()) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.C = c + if afterFunc == nil { + t.f = func(curTime time.Time) { + select { + case c <- curTime: + default: + } + } + } else { + t.f = func(_ time.Time) { afterFunc() } + } + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Timer) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + t.nextTrigger = time.Time{} + t.f(curTime) + return time.Time{} +} + +// Reset reschedules the next fire time to the current simulated time + d. +// Reset reports whether the timer was still active before the reset. +func (t *Timer) Reset(d time.Duration) bool { + if d <= 0 { + // The standard time.Timer requires a positive delay. + panic("non-positive delay for Timer.Reset") + } + + return t.reset(t.em.Now().Add(d)) +} + +// ResetAbsolute reschedules the next fire time to nextTrigger. +// ResetAbsolute reports whether the timer was still active before the reset. +func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + + return t.reset(nextTrigger) +} + +// Stop deactivates the Timer. Stop reports whether the timer was active before +// stopping. +func (t *Timer) Stop() bool { + return t.reset(time.Time{}) +} + +func (t *Timer) reset(nextTrigger time.Time) bool { + t.mu.Lock() + wasActive := !t.nextTrigger.IsZero() + t.nextTrigger = nextTrigger + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) + return wasActive +} diff --git a/tstest/deptest/deptest_test.go b/tstest/deptest/deptest_test.go index 3b7b2dde91dec..ebafa56849efb 100644 --- a/tstest/deptest/deptest_test.go +++ b/tstest/deptest/deptest_test.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deptest - -import "testing" - -func TestImports(t *testing.T) { - ImportAliasCheck(t, "../../") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deptest + +import "testing" + +func TestImports(t *testing.T) { + ImportAliasCheck(t, "../../") +} diff --git a/tstest/integration/gen_deps.go b/tstest/integration/gen_deps.go index ab5cc0448b54d..23bb95ee56a9f 100644 --- a/tstest/integration/gen_deps.go +++ b/tstest/integration/gen_deps.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "log" - "os" - "os/exec" - "strings" -) - -func main() { - for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { - generate(goos) - } -} - -func generate(goos string) { - var x struct { - Imports []string - } - cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") - cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") - j, err := cmd.Output() - if err != nil { - log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) - } - if err := json.Unmarshal(j, &x); err != nil { - log.Fatal(err) - } - var out bytes.Buffer - out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by gen_deps.go; DO NOT EDIT. - -package integration - -import ( - // And depend on a bunch of tailscaled innards, for Go's test caching. - // Otherwise cmd/go never sees that we depend on these packages' - // transitive deps when we run "go install tailscaled" in a child - // process and can cache a prior success when a dependency changes. -`) - for _, dep := range x.Imports { - if !strings.Contains(dep, ".") { - // Omit standard library deps. - continue - } - fmt.Fprintf(&out, "\t_ %q\n", dep) - } - fmt.Fprintf(&out, ")\n") - - filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) - err = os.WriteFile(filename, out.Bytes(), 0644) - if err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "strings" +) + +func main() { + for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { + generate(goos) + } +} + +func generate(goos string) { + var x struct { + Imports []string + } + cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") + cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") + j, err := cmd.Output() + if err != nil { + log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) + } + if err := json.Unmarshal(j, &x); err != nil { + log.Fatal(err) + } + var out bytes.Buffer + out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen_deps.go; DO NOT EDIT. + +package integration + +import ( + // And depend on a bunch of tailscaled innards, for Go's test caching. + // Otherwise cmd/go never sees that we depend on these packages' + // transitive deps when we run "go install tailscaled" in a child + // process and can cache a prior success when a dependency changes. +`) + for _, dep := range x.Imports { + if !strings.Contains(dep, ".") { + // Omit standard library deps. + continue + } + fmt.Fprintf(&out, "\t_ %q\n", dep) + } + fmt.Fprintf(&out, ")\n") + + filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) + err = os.WriteFile(filename, out.Bytes(), 0644) + if err != nil { + log.Fatal(err) + } +} diff --git a/tstest/integration/vms/README.md b/tstest/integration/vms/README.md index 766d8e5741e6d..519c3d000fb63 100644 --- a/tstest/integration/vms/README.md +++ b/tstest/integration/vms/README.md @@ -1,95 +1,95 @@ -# End-to-End VM-based Integration Testing - -This test spins up a bunch of common linux distributions and then tries to get -them to connect to a -[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) -server. - -## Running - -This test currently only runs on Linux. - -This test depends on the following command line tools: - -- [qemu](https://www.qemu.org/) -- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) -- [openssh](https://www.openssh.com/) - -This test also requires the following: - -- about 10 GB of temporary storage -- about 10 GB of cached VM images -- at least 4 GB of ram for virtual machines -- hardware virtualization support - ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS -- the `kvm` module to be loaded (`modprobe kvm`) -- the user running these tests must have access to `/dev/kvm` (being in the - `kvm` group should suffice) - -The `--no-s3` flag is needed to disable downloads from S3, which require -credentials. However keep in mind that some distributions do not use stable URLs -for each individual image artifact, so there may be spurious test failures as a -result. - -If you are using [Nix](https://nixos.org), you can run all of the tests with the -correct command line tools using this command: - -```console -$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" -``` - -Keep the timeout high for the first run, especially if you are not downloading -VM images from S3. The mirrors we pull images from have download rate limits and -will take a while to download. - -Because of the hardware requirements of this test, this test will not run -without the `--run-vm-tests` flag set. - -## Other Fun Flags - -This test's behavior is customized with command line flags. - -### Don't Download Images From S3 - -If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor -of downloading the images directly from upstream sources, which may cause the -test to fail in odd places. - -### Distribution Picking - -This test runs on a large number of distributions. By default it tries to run -everything, which may or may not be ideal for you. If you only want to test a -subset of distributions, you can use the `--distro-regex` flag to match a subset -of distributions using a [regular expression](https://golang.org/pkg/regexp/) -such as like this: - -```console -$ go test -run-vm-tests -distro-regex centos -``` - -This would run all tests on all versions of CentOS. - -```console -$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' -``` - -This would run all tests on all versions of Debian and Ubuntu. - -### Ram Limiting - -This test uses a lot of memory. In order to avoid making machines run out of -memory running this test, a semaphore is used to limit how many megabytes of ram -are being used at once. By default this semaphore is set to 4096 MB of ram -(about 4 gigabytes). You can customize this with the `--ram-limit` flag: - -```console -$ go test --run-vm-tests --ram-limit 2048 -$ go test --run-vm-tests --ram-limit 65536 -``` - -The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The -second example will set the limit to 65536 MB of ram (about 65 gigabytes). -Please be careful with this flag, improper usage of it is known to cause the -Linux out-of-memory killer to engage. Try to keep it within 50-75% of your -machine's available ram (there is some overhead involved with the -virtualization) to be on the safe side. +# End-to-End VM-based Integration Testing + +This test spins up a bunch of common linux distributions and then tries to get +them to connect to a +[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) +server. + +## Running + +This test currently only runs on Linux. + +This test depends on the following command line tools: + +- [qemu](https://www.qemu.org/) +- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) +- [openssh](https://www.openssh.com/) + +This test also requires the following: + +- about 10 GB of temporary storage +- about 10 GB of cached VM images +- at least 4 GB of ram for virtual machines +- hardware virtualization support + ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS +- the `kvm` module to be loaded (`modprobe kvm`) +- the user running these tests must have access to `/dev/kvm` (being in the + `kvm` group should suffice) + +The `--no-s3` flag is needed to disable downloads from S3, which require +credentials. However keep in mind that some distributions do not use stable URLs +for each individual image artifact, so there may be spurious test failures as a +result. + +If you are using [Nix](https://nixos.org), you can run all of the tests with the +correct command line tools using this command: + +```console +$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" +``` + +Keep the timeout high for the first run, especially if you are not downloading +VM images from S3. The mirrors we pull images from have download rate limits and +will take a while to download. + +Because of the hardware requirements of this test, this test will not run +without the `--run-vm-tests` flag set. + +## Other Fun Flags + +This test's behavior is customized with command line flags. + +### Don't Download Images From S3 + +If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor +of downloading the images directly from upstream sources, which may cause the +test to fail in odd places. + +### Distribution Picking + +This test runs on a large number of distributions. By default it tries to run +everything, which may or may not be ideal for you. If you only want to test a +subset of distributions, you can use the `--distro-regex` flag to match a subset +of distributions using a [regular expression](https://golang.org/pkg/regexp/) +such as like this: + +```console +$ go test -run-vm-tests -distro-regex centos +``` + +This would run all tests on all versions of CentOS. + +```console +$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' +``` + +This would run all tests on all versions of Debian and Ubuntu. + +### Ram Limiting + +This test uses a lot of memory. In order to avoid making machines run out of +memory running this test, a semaphore is used to limit how many megabytes of ram +are being used at once. By default this semaphore is set to 4096 MB of ram +(about 4 gigabytes). You can customize this with the `--ram-limit` flag: + +```console +$ go test --run-vm-tests --ram-limit 2048 +$ go test --run-vm-tests --ram-limit 65536 +``` + +The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The +second example will set the limit to 65536 MB of ram (about 65 gigabytes). +Please be careful with this flag, improper usage of it is known to cause the +Linux out-of-memory killer to engage. Try to keep it within 50-75% of your +machine's available ram (there is some overhead involved with the +virtualization) to be on the safe side. diff --git a/tstest/integration/vms/distros.hujson b/tstest/integration/vms/distros.hujson index 5634d6d678562..049091ed50e6e 100644 --- a/tstest/integration/vms/distros.hujson +++ b/tstest/integration/vms/distros.hujson @@ -1,39 +1,39 @@ -// NOTE(Xe): If you run into issues getting the autoconfig to work, run -// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC -// client with a command like this: -// -// $ vncviewer :0 -// -// On NixOS you can get away with something like this: -// -// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' -// -// Login as root with the password root. Then look in -// /var/log/cloud-init-output.log for what you messed up. -[ - { - "Name": "ubuntu-18-04", - "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", - "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "ubuntu-20-04", - "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", - "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "nixos-21-11", - "URL": "channel:nixos-21.11", - "SHA256Sum": "lolfakesha", - "MemoryMegs": 512, - "PackageManager": "nix", - "InitSystem": "systemd", - "HostGenerated": true - }, -] +// NOTE(Xe): If you run into issues getting the autoconfig to work, run +// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC +// client with a command like this: +// +// $ vncviewer :0 +// +// On NixOS you can get away with something like this: +// +// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' +// +// Login as root with the password root. Then look in +// /var/log/cloud-init-output.log for what you messed up. +[ + { + "Name": "ubuntu-18-04", + "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", + "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "ubuntu-20-04", + "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", + "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "nixos-21-11", + "URL": "channel:nixos-21.11", + "SHA256Sum": "lolfakesha", + "MemoryMegs": 512, + "PackageManager": "nix", + "InitSystem": "systemd", + "HostGenerated": true + }, +] diff --git a/tstest/integration/vms/distros_test.go b/tstest/integration/vms/distros_test.go index db3bae793b367..462aa2a6bc825 100644 --- a/tstest/integration/vms/distros_test.go +++ b/tstest/integration/vms/distros_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "testing" -) - -func TestDistrosGotLoaded(t *testing.T) { - if len(Distros) == 0 { - t.Fatal("no distros were loaded") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "testing" +) + +func TestDistrosGotLoaded(t *testing.T) { + if len(Distros) == 0 { + t.Fatal("no distros were loaded") + } +} diff --git a/tstest/integration/vms/dns_tester.go b/tstest/integration/vms/dns_tester.go index be7d7ee6d69c8..50b39bb5f1fa1 100644 --- a/tstest/integration/vms/dns_tester.go +++ b/tstest/integration/vms/dns_tester.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command dns_tester exists in order to perform tests of our DNS -// configuration stack. This was written because the state of DNS -// in our target environments is so diverse that we need a little tool -// to do this test for us. -package main - -import ( - "context" - "encoding/json" - "flag" - "net" - "os" - "time" -) - -func main() { - flag.Parse() - target := flag.Arg(0) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCount := 0 - wait := 25 * time.Millisecond - for range make([]struct{}, 5) { - err := lookup(ctx, target) - if err != nil { - errCount++ - time.Sleep(wait) - wait = wait * 2 - continue - } - - break - } -} - -func lookup(ctx context.Context, target string) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - hosts, err := net.LookupHost(target) - if err != nil { - return err - } - - json.NewEncoder(os.Stdout).Encode(hosts) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command dns_tester exists in order to perform tests of our DNS +// configuration stack. This was written because the state of DNS +// in our target environments is so diverse that we need a little tool +// to do this test for us. +package main + +import ( + "context" + "encoding/json" + "flag" + "net" + "os" + "time" +) + +func main() { + flag.Parse() + target := flag.Arg(0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCount := 0 + wait := 25 * time.Millisecond + for range make([]struct{}, 5) { + err := lookup(ctx, target) + if err != nil { + errCount++ + time.Sleep(wait) + wait = wait * 2 + continue + } + + break + } +} + +func lookup(ctx context.Context, target string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + hosts, err := net.LookupHost(target) + if err != nil { + return err + } + + json.NewEncoder(os.Stdout).Encode(hosts) + return nil +} diff --git a/tstest/integration/vms/doc.go b/tstest/integration/vms/doc.go index 3008493ea1a33..6093b53ac8ed5 100644 --- a/tstest/integration/vms/doc.go +++ b/tstest/integration/vms/doc.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package vms does VM-based integration/functional tests by using -// qemu and a bank of pre-made VM images. -package vms +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vms does VM-based integration/functional tests by using +// qemu and a bank of pre-made VM images. +package vms diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 620276ac26491..1e080414d72e7 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "bytes" - "context" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "os/exec" - "path" - "path/filepath" - "strconv" - "sync" - "testing" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/net/proxy" - "tailscale.com/tailcfg" - "tailscale.com/tstest/integration" - "tailscale.com/tstest/integration/testcontrol" - "tailscale.com/types/dnstype" -) - -type Harness struct { - testerDialer proxy.Dialer - testerDir string - binaryDir string - cli string - daemon string - pubKey string - signer ssh.Signer - cs *testcontrol.Server - loginServerURL string - testerV4 netip.Addr - ipMu *sync.Mutex - ipMap map[string]ipMapping -} - -func newHarness(t *testing.T) *Harness { - dir := t.TempDir() - bindHost := deriveBindhost(t) - ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) - if err != nil { - t.Fatalf("can't make TCP listener: %v", err) - } - t.Cleanup(func() { - ln.Close() - }) - t.Logf("host:port: %s", ln.Addr()) - - cs := &testcontrol.Server{ - DNSConfig: &tailcfg.DNSConfig{ - // TODO: this is wrong. - // It is also only one of many configurations. - // Figure out how to scale it up. - Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, - Domains: []string{"record"}, - Proxied: true, - ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, - }, - } - - derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) - cs.DERPMap = derpMap - - var ( - ipMu sync.Mutex - ipMap = map[string]ipMapping{} - ) - - mux := http.NewServeMux() - mux.Handle("/", cs) - - lc := &integration.LogCatcher{} - if *verboseLogcatcher { - lc.UseLogf(t.Logf) - t.Cleanup(func() { - lc.UseLogf(nil) // do not log after test is complete - }) - } - mux.Handle("/c/", lc) - - // This handler will let the virtual machines tell the host information about that VM. - // This is used to maintain a list of port->IP address mappings that are known to be - // working. This allows later steps to connect over SSH. This returns no response to - // clients because no response is needed. - mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { - ipMu.Lock() - defer ipMu.Unlock() - - name := path.Base(r.URL.Path) - host, _, _ := net.SplitHostPort(r.RemoteAddr) - port, err := strconv.Atoi(name) - if err != nil { - log.Panicf("bad port: %v", port) - } - distro := r.UserAgent() - ipMap[distro] = ipMapping{distro, port, host} - t.Logf("%s: %v", name, host) - }) - - hs := &http.Server{Handler: mux} - go hs.Serve(ln) - - cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") - cmd.Dir = dir - if out, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("ssh-keygen: %v, %s", err, out) - } - pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) - if err != nil { - t.Fatalf("can't read ssh key: %v", err) - } - - privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) - if err != nil { - t.Fatalf("can't read ssh private key: %v", err) - } - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - t.Fatalf("can't parse private key: %v", err) - } - - loginServer := fmt.Sprintf("http://%s", ln.Addr()) - t.Logf("loginServer: %s", loginServer) - - h := &Harness{ - pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), - daemon: integration.TailscaledBinary(t), - signer: signer, - loginServerURL: loginServer, - cs: cs, - ipMu: &ipMu, - ipMap: ipMap, - } - - h.makeTestNode(t, loginServer) - - return h -} - -func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { - t.Helper() - - args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) - - cmd := exec.Command(h.cli, args...) - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatal(err) - } - - return out -} - -// makeTestNode creates a userspace tailscaled running in netstack mode that -// enables us to make connections to and from the tailscale network being -// tested. This mutates the Harness to allow tests to dial into the tailscale -// network as well as control the tester's tailscaled. -func (h *Harness) makeTestNode(t *testing.T, controlURL string) { - dir := t.TempDir() - h.testerDir = dir - - port, err := getProbablyFreePortNumber() - if err != nil { - t.Fatalf("can't get free port: %v", err) - } - - cmd := exec.Command( - h.daemon, - "--tun=userspace-networking", - "--state="+filepath.Join(dir, "state.json"), - "--socket="+filepath.Join(dir, "sock"), - fmt.Sprintf("--socks5-server=localhost:%d", port), - ) - - cmd.Env = append( - os.Environ(), - "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), - "TS_LOG_TARGET="+h.loginServerURL, - ) - - err = cmd.Start() - if err != nil { - t.Fatalf("can't start tailscaled: %v", err) - } - - t.Cleanup(func() { - cmd.Process.Kill() - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - ticker := time.NewTicker(100 * time.Millisecond) - -outer: - for { - select { - case <-ctx.Done(): - t.Fatal("timed out waiting for tailscaled to come up") - return - case <-ticker.C: - conn, err := net.Dial("unix", filepath.Join(dir, "sock")) - if err != nil { - continue - } - - conn.Close() - break outer - } - } - - run(t, dir, h.cli, - "--socket="+filepath.Join(dir, "sock"), - "up", - "--login-server="+controlURL, - "--hostname=tester", - ) - - dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) - if err != nil { - t.Fatalf("can't make netstack proxy dialer: %v", err) - } - h.testerDialer = dialer - h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) -} - -func bytes2Netaddr(inp []byte) netip.Addr { - return netip.MustParseAddr(string(bytes.TrimSpace(inp))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "bytes" + "context" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "os/exec" + "path" + "path/filepath" + "strconv" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/proxy" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/dnstype" +) + +type Harness struct { + testerDialer proxy.Dialer + testerDir string + binaryDir string + cli string + daemon string + pubKey string + signer ssh.Signer + cs *testcontrol.Server + loginServerURL string + testerV4 netip.Addr + ipMu *sync.Mutex + ipMap map[string]ipMapping +} + +func newHarness(t *testing.T) *Harness { + dir := t.TempDir() + bindHost := deriveBindhost(t) + ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) + if err != nil { + t.Fatalf("can't make TCP listener: %v", err) + } + t.Cleanup(func() { + ln.Close() + }) + t.Logf("host:port: %s", ln.Addr()) + + cs := &testcontrol.Server{ + DNSConfig: &tailcfg.DNSConfig{ + // TODO: this is wrong. + // It is also only one of many configurations. + // Figure out how to scale it up. + Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, + Domains: []string{"record"}, + Proxied: true, + ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, + }, + } + + derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) + cs.DERPMap = derpMap + + var ( + ipMu sync.Mutex + ipMap = map[string]ipMapping{} + ) + + mux := http.NewServeMux() + mux.Handle("/", cs) + + lc := &integration.LogCatcher{} + if *verboseLogcatcher { + lc.UseLogf(t.Logf) + t.Cleanup(func() { + lc.UseLogf(nil) // do not log after test is complete + }) + } + mux.Handle("/c/", lc) + + // This handler will let the virtual machines tell the host information about that VM. + // This is used to maintain a list of port->IP address mappings that are known to be + // working. This allows later steps to connect over SSH. This returns no response to + // clients because no response is needed. + mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { + ipMu.Lock() + defer ipMu.Unlock() + + name := path.Base(r.URL.Path) + host, _, _ := net.SplitHostPort(r.RemoteAddr) + port, err := strconv.Atoi(name) + if err != nil { + log.Panicf("bad port: %v", port) + } + distro := r.UserAgent() + ipMap[distro] = ipMapping{distro, port, host} + t.Logf("%s: %v", name, host) + }) + + hs := &http.Server{Handler: mux} + go hs.Serve(ln) + + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("ssh-keygen: %v, %s", err, out) + } + pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) + if err != nil { + t.Fatalf("can't read ssh key: %v", err) + } + + privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) + if err != nil { + t.Fatalf("can't read ssh private key: %v", err) + } + + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + t.Fatalf("can't parse private key: %v", err) + } + + loginServer := fmt.Sprintf("http://%s", ln.Addr()) + t.Logf("loginServer: %s", loginServer) + + h := &Harness{ + pubKey: string(pubkey), + binaryDir: integration.BinaryDir(t), + cli: integration.TailscaleBinary(t), + daemon: integration.TailscaledBinary(t), + signer: signer, + loginServerURL: loginServer, + cs: cs, + ipMu: &ipMu, + ipMap: ipMap, + } + + h.makeTestNode(t, loginServer) + + return h +} + +func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { + t.Helper() + + args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) + + cmd := exec.Command(h.cli, args...) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + + return out +} + +// makeTestNode creates a userspace tailscaled running in netstack mode that +// enables us to make connections to and from the tailscale network being +// tested. This mutates the Harness to allow tests to dial into the tailscale +// network as well as control the tester's tailscaled. +func (h *Harness) makeTestNode(t *testing.T, controlURL string) { + dir := t.TempDir() + h.testerDir = dir + + port, err := getProbablyFreePortNumber() + if err != nil { + t.Fatalf("can't get free port: %v", err) + } + + cmd := exec.Command( + h.daemon, + "--tun=userspace-networking", + "--state="+filepath.Join(dir, "state.json"), + "--socket="+filepath.Join(dir, "sock"), + fmt.Sprintf("--socks5-server=localhost:%d", port), + ) + + cmd.Env = append( + os.Environ(), + "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), + "TS_LOG_TARGET="+h.loginServerURL, + ) + + err = cmd.Start() + if err != nil { + t.Fatalf("can't start tailscaled: %v", err) + } + + t.Cleanup(func() { + cmd.Process.Kill() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ticker := time.NewTicker(100 * time.Millisecond) + +outer: + for { + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for tailscaled to come up") + return + case <-ticker.C: + conn, err := net.Dial("unix", filepath.Join(dir, "sock")) + if err != nil { + continue + } + + conn.Close() + break outer + } + } + + run(t, dir, h.cli, + "--socket="+filepath.Join(dir, "sock"), + "up", + "--login-server="+controlURL, + "--hostname=tester", + ) + + dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) + if err != nil { + t.Fatalf("can't make netstack proxy dialer: %v", err) + } + h.testerDialer = dialer + h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) +} + +func bytes2Netaddr(inp []byte) netip.Addr { + return netip.MustParseAddr(string(bytes.TrimSpace(inp))) +} diff --git a/tstest/integration/vms/nixos_test.go b/tstest/integration/vms/nixos_test.go index 06a14e4f6cc21..c2998ff3c087c 100644 --- a/tstest/integration/vms/nixos_test.go +++ b/tstest/integration/vms/nixos_test.go @@ -1,231 +1,231 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "flag" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - "text/template" - - "tailscale.com/types/logger" -) - -var ( - verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") -) - -/* - NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than - other distros due to NixOS' determinism. Normally NixOS wants packages to - be defined in either an overlay, a custom packageOverrides or even - yolo-inline as a part of the system configuration. This is going to have - us take a different approach compared to other distributions. The overall - plan here is as following: - - 1. make the binaries as normal - 2. template in their paths as raw strings to the nixos system module - 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` - 4. pass that to the steps that make the virtual machine - - It doesn't really make sense for us to use a premade virtual machine image - for this as that will make it harder to deterministically create the image. -*/ - -const nixosConfigTemplate = ` -# NOTE(Xe): This template is going to be heavily commented. - -# All NixOS modules are functions. Here is the function prelude for this NixOS -# module that defines the system. It is a function that takes in an attribute -# set (effectively a map[string]nix.Value) and destructures it to some variables: -{ - # other NixOS settings as defined in other modules - config, - - # nixpkgs, which is basically the standard library of NixOS - pkgs, - - # the path to some system-scoped NixOS modules that aren't imported by default - modulesPath, - - # the rest of the arguments don't matter - ... -}: - -# Nix's syntax was inspired by Haskell and other functional languages, so the -# let .. in pattern is used to create scoped variables: -let - # Define the package (derivation) for Tailscale based on the binaries we - # just built for this test: - testTailscale = pkgs.stdenv.mkDerivation { - # The name of the package. This usually includes a version however it - # doesn't matter here. - name = "tailscale-test"; - - # The path on disk to the "source code" of the package, in this case it is - # the path to the binaries that are built. This needs to be the raw - # unquoted slash-separated path, not a string containing the path because Nix - # has a special path type. - src = {{.BinPath}}; - - # We only need to worry about the install phase because we've already - # built the binaries. - phases = "installPhase"; - - # We need to wrap tailscaled such that it has iptables in its $PATH. - nativeBuildInputs = [ pkgs.makeWrapper ]; - - # The install instructions for this package ('' ''defines a multi-line string). - # The with statement lets us bring in values into scope as if they were - # defined in the current scope. - installPhase = with pkgs; '' - # This is bash. - - # Make the output folders for the package (systemd unit and binary folders). - mkdir -p $out/bin - - # Install tailscale{,d} - cp $src/tailscale $out/bin/tailscale - cp $src/tailscaled $out/bin/tailscaled - - # Wrap tailscaled with the ip and iptables commands. - wrapProgram $out/bin/tailscaled --prefix PATH : ${ - lib.makeBinPath [ iproute iptables ] - } - - # Install systemd unit. - cp $src/systemd/tailscaled.service . - sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service - install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service - ''; - }; -in { - # This is a QEMU VM. This module has a lot of common qemu VM settings so you - # don't have to set them manually. - imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; - - # We need virtio support to boot. - boot.initrd.availableKernelModules = - [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; - boot.initrd.kernelModules = [ ]; - boot.kernelModules = [ ]; - boot.extraModulePackages = [ ]; - - # Curl is needed for one of the steps in cloud-final - systemd.services.cloud-final.path = with pkgs; [ curl ]; - - # Curl is needed for one of the integration tests - environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; - - # yolo, this vm can sudo freely. - security.sudo.wheelNeedsPassword = false; - - # Enable cloud-init so we can set VM hostnames and the like the same as other - # distros. This will also take care of SSH keys. It's pretty handy. - services.cloud-init = { - enable = true; - ext4.enable = true; - }; - - # We want sshd running. - services.openssh.enable = true; - - # Tailscale settings: - services.tailscale = { - # We want Tailscale to start at boot. - enable = true; - - # Use the Tailscale package we just assembled. - package = testTailscale; - }; - - # Override TS_LOG_TARGET to our private logcatcher. - systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; -}` - -func (h *Harness) copyUnit(t *testing.T) { - t.Helper() - - data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") - if err != nil { - t.Fatal(err) - } - os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) - err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) - if err != nil { - t.Fatal(err) - } -} - -func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { - if d.Name == "nixos-unstable" { - t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") - } - - h.copyUnit(t) - dir := t.TempDir() - fname := filepath.Join(dir, d.Name+".nix") - fout, err := os.Create(fname) - if err != nil { - t.Fatal(err) - } - - tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) - err = tmpl.Execute(fout, struct { - BinPath string - LogTarget string - }{ - BinPath: h.binaryDir, - LogTarget: h.loginServerURL, - }) - if err != nil { - t.Fatal(err) - } - - err = fout.Close() - if err != nil { - t.Fatal(err) - } - - outpath := filepath.Join(cdir, "nixos") - os.MkdirAll(outpath, 0755) - - t.Cleanup(func() { - os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC - }) - - cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) - if *verboseNixOutput { - cmd.Stdout = logger.FuncWriter(t.Logf) - cmd.Stderr = logger.FuncWriter(t.Logf) - } else { - fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) - t.Logf("writing nix logs to %s", fname) - fout, err := os.Create(fname) - if err != nil { - t.Fatalf("can't make log file for nix build: %v", err) - } - cmd.Stdout = fout - cmd.Stderr = fout - defer fout.Close() - } - cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) - cmd.Dir = outpath - t.Logf("running %s %#v", "nixos-generate", cmd.Args) - if err := cmd.Run(); err != nil { - t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) - } - - if !*verboseNixOutput { - t.Log("done") - } - - return filepath.Join(outpath, d.Name, "nixos.qcow2") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "text/template" + + "tailscale.com/types/logger" +) + +var ( + verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") +) + +/* + NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than + other distros due to NixOS' determinism. Normally NixOS wants packages to + be defined in either an overlay, a custom packageOverrides or even + yolo-inline as a part of the system configuration. This is going to have + us take a different approach compared to other distributions. The overall + plan here is as following: + + 1. make the binaries as normal + 2. template in their paths as raw strings to the nixos system module + 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` + 4. pass that to the steps that make the virtual machine + + It doesn't really make sense for us to use a premade virtual machine image + for this as that will make it harder to deterministically create the image. +*/ + +const nixosConfigTemplate = ` +# NOTE(Xe): This template is going to be heavily commented. + +# All NixOS modules are functions. Here is the function prelude for this NixOS +# module that defines the system. It is a function that takes in an attribute +# set (effectively a map[string]nix.Value) and destructures it to some variables: +{ + # other NixOS settings as defined in other modules + config, + + # nixpkgs, which is basically the standard library of NixOS + pkgs, + + # the path to some system-scoped NixOS modules that aren't imported by default + modulesPath, + + # the rest of the arguments don't matter + ... +}: + +# Nix's syntax was inspired by Haskell and other functional languages, so the +# let .. in pattern is used to create scoped variables: +let + # Define the package (derivation) for Tailscale based on the binaries we + # just built for this test: + testTailscale = pkgs.stdenv.mkDerivation { + # The name of the package. This usually includes a version however it + # doesn't matter here. + name = "tailscale-test"; + + # The path on disk to the "source code" of the package, in this case it is + # the path to the binaries that are built. This needs to be the raw + # unquoted slash-separated path, not a string containing the path because Nix + # has a special path type. + src = {{.BinPath}}; + + # We only need to worry about the install phase because we've already + # built the binaries. + phases = "installPhase"; + + # We need to wrap tailscaled such that it has iptables in its $PATH. + nativeBuildInputs = [ pkgs.makeWrapper ]; + + # The install instructions for this package ('' ''defines a multi-line string). + # The with statement lets us bring in values into scope as if they were + # defined in the current scope. + installPhase = with pkgs; '' + # This is bash. + + # Make the output folders for the package (systemd unit and binary folders). + mkdir -p $out/bin + + # Install tailscale{,d} + cp $src/tailscale $out/bin/tailscale + cp $src/tailscaled $out/bin/tailscaled + + # Wrap tailscaled with the ip and iptables commands. + wrapProgram $out/bin/tailscaled --prefix PATH : ${ + lib.makeBinPath [ iproute iptables ] + } + + # Install systemd unit. + cp $src/systemd/tailscaled.service . + sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service + install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service + ''; + }; +in { + # This is a QEMU VM. This module has a lot of common qemu VM settings so you + # don't have to set them manually. + imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; + + # We need virtio support to boot. + boot.initrd.availableKernelModules = + [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; + boot.initrd.kernelModules = [ ]; + boot.kernelModules = [ ]; + boot.extraModulePackages = [ ]; + + # Curl is needed for one of the steps in cloud-final + systemd.services.cloud-final.path = with pkgs; [ curl ]; + + # Curl is needed for one of the integration tests + environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; + + # yolo, this vm can sudo freely. + security.sudo.wheelNeedsPassword = false; + + # Enable cloud-init so we can set VM hostnames and the like the same as other + # distros. This will also take care of SSH keys. It's pretty handy. + services.cloud-init = { + enable = true; + ext4.enable = true; + }; + + # We want sshd running. + services.openssh.enable = true; + + # Tailscale settings: + services.tailscale = { + # We want Tailscale to start at boot. + enable = true; + + # Use the Tailscale package we just assembled. + package = testTailscale; + }; + + # Override TS_LOG_TARGET to our private logcatcher. + systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; +}` + +func (h *Harness) copyUnit(t *testing.T) { + t.Helper() + + data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") + if err != nil { + t.Fatal(err) + } + os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) + err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) + if err != nil { + t.Fatal(err) + } +} + +func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { + if d.Name == "nixos-unstable" { + t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") + } + + h.copyUnit(t) + dir := t.TempDir() + fname := filepath.Join(dir, d.Name+".nix") + fout, err := os.Create(fname) + if err != nil { + t.Fatal(err) + } + + tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) + err = tmpl.Execute(fout, struct { + BinPath string + LogTarget string + }{ + BinPath: h.binaryDir, + LogTarget: h.loginServerURL, + }) + if err != nil { + t.Fatal(err) + } + + err = fout.Close() + if err != nil { + t.Fatal(err) + } + + outpath := filepath.Join(cdir, "nixos") + os.MkdirAll(outpath, 0755) + + t.Cleanup(func() { + os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC + }) + + cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) + if *verboseNixOutput { + cmd.Stdout = logger.FuncWriter(t.Logf) + cmd.Stderr = logger.FuncWriter(t.Logf) + } else { + fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) + t.Logf("writing nix logs to %s", fname) + fout, err := os.Create(fname) + if err != nil { + t.Fatalf("can't make log file for nix build: %v", err) + } + cmd.Stdout = fout + cmd.Stderr = fout + defer fout.Close() + } + cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) + cmd.Dir = outpath + t.Logf("running %s %#v", "nixos-generate", cmd.Args) + if err := cmd.Run(); err != nil { + t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) + } + + if !*verboseNixOutput { + t.Log("done") + } + + return filepath.Join(outpath, d.Name, "nixos.qcow2") +} diff --git a/tstest/integration/vms/regex_flag.go b/tstest/integration/vms/regex_flag.go index 195f7c7718b7c..02e399ecdfaad 100644 --- a/tstest/integration/vms/regex_flag.go +++ b/tstest/integration/vms/regex_flag.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import "regexp" - -type regexValue struct { - r *regexp.Regexp -} - -func (r *regexValue) String() string { - if r.r == nil { - return "" - } - - return r.r.String() -} - -func (r *regexValue) Set(val string) error { - if rex, err := regexp.Compile(val); err != nil { - return err - } else { - r.r = rex - return nil - } -} - -func (r regexValue) Unwrap() *regexp.Regexp { return r.r } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import "regexp" + +type regexValue struct { + r *regexp.Regexp +} + +func (r *regexValue) String() string { + if r.r == nil { + return "" + } + + return r.r.String() +} + +func (r *regexValue) Set(val string) error { + if rex, err := regexp.Compile(val); err != nil { + return err + } else { + r.r = rex + return nil + } +} + +func (r regexValue) Unwrap() *regexp.Regexp { return r.r } diff --git a/tstest/integration/vms/regex_flag_test.go b/tstest/integration/vms/regex_flag_test.go index 790894080a7d5..0f4e5f8f7bdec 100644 --- a/tstest/integration/vms/regex_flag_test.go +++ b/tstest/integration/vms/regex_flag_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "flag" - "testing" -) - -func TestRegexFlag(t *testing.T) { - var v regexValue - fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) - fs.Var(&v, "regex", "regex to parse") - - const want = `.*` - fs.Parse([]string{"-regex", want}) - if v.Unwrap().String() != want { - t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "flag" + "testing" +) + +func TestRegexFlag(t *testing.T) { + var v regexValue + fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) + fs.Var(&v, "regex", "regex to parse") + + const want = `.*` + fs.Parse([]string{"-regex", want}) + if v.Unwrap().String() != want { + t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) + } +} diff --git a/tstest/integration/vms/runner.nix b/tstest/integration/vms/runner.nix index 8d4c0a25dc5f6..ac569cf658cb1 100644 --- a/tstest/integration/vms/runner.nix +++ b/tstest/integration/vms/runner.nix @@ -1,89 +1,89 @@ -# This is a NixOS module to allow a machine to act as an integration test -# runner. This is used for the end-to-end VM test suite. - -{ lib, config, pkgs, ... }: - -{ - # The GitHub Actions self-hosted runner service. - services.github-runner = { - enable = true; - url = "https://github.com/tailscale/tailscale"; - replace = true; - extraLabels = [ "vm_integration_test" ]; - - # Justifications for the packages: - extraPackages = with pkgs; [ - # The test suite is written in Go. - go - - # This contains genisoimage, which is needed to create cloud-init - # seeds. - cdrkit - - # This package is the virtual machine hypervisor we use in tests. - qemu - - # This package contains tools like `ssh-keygen`. - openssh - - # The C compiler so cgo builds work. - gcc - - # The package manager Nix, just in case. - nix - - # Used to generate a NixOS image for testing. - nixos-generators - - # Used to extract things. - gnutar - - # Used to decompress things. - lzma - ]; - - # Customize this to include your GitHub username so we can track - # who is running which node. - name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; - - # Replace this with the path to the GitHub Actions runner token on - # your disk. - tokenFile = "/run/decrypted/ts-oss-ghaction-token"; - }; - - # A user account so there is a home directory and so they have kvm - # access. Please don't change this account name. - users.users.ghrunner = { - createHome = true; - isSystemUser = true; - extraGroups = [ "kvm" ]; - }; - - # The default github-runner service sets a lot of isolation features - # that attempt to limit the damage that malicious code can use. - # Unfortunately we rely on some "dangerous" features to do these tests, - # so this shim will peel some of them away. - systemd.services.github-runner = { - serviceConfig = { - # We need access to /dev to poke /dev/kvm. - PrivateDevices = lib.mkForce false; - - # /dev/kvm is how qemu creates a virtual machine with KVM. - DeviceAllow = lib.mkForce [ "/dev/kvm" ]; - - # Ensure the service has KVM permissions with the `kvm` group. - ExtraGroups = [ "kvm" ]; - - # The service runs as a dynamic user by default. This makes it hard - # to persistently store things in /var/lib/ghrunner. This line - # disables the dynamic user feature. - DynamicUser = lib.mkForce false; - - # Run this service as our ghrunner user. - User = "ghrunner"; - - # We need access to /var/lib/ghrunner to store VM images. - ProtectSystem = lib.mkForce null; - }; - }; -} +# This is a NixOS module to allow a machine to act as an integration test +# runner. This is used for the end-to-end VM test suite. + +{ lib, config, pkgs, ... }: + +{ + # The GitHub Actions self-hosted runner service. + services.github-runner = { + enable = true; + url = "https://github.com/tailscale/tailscale"; + replace = true; + extraLabels = [ "vm_integration_test" ]; + + # Justifications for the packages: + extraPackages = with pkgs; [ + # The test suite is written in Go. + go + + # This contains genisoimage, which is needed to create cloud-init + # seeds. + cdrkit + + # This package is the virtual machine hypervisor we use in tests. + qemu + + # This package contains tools like `ssh-keygen`. + openssh + + # The C compiler so cgo builds work. + gcc + + # The package manager Nix, just in case. + nix + + # Used to generate a NixOS image for testing. + nixos-generators + + # Used to extract things. + gnutar + + # Used to decompress things. + lzma + ]; + + # Customize this to include your GitHub username so we can track + # who is running which node. + name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; + + # Replace this with the path to the GitHub Actions runner token on + # your disk. + tokenFile = "/run/decrypted/ts-oss-ghaction-token"; + }; + + # A user account so there is a home directory and so they have kvm + # access. Please don't change this account name. + users.users.ghrunner = { + createHome = true; + isSystemUser = true; + extraGroups = [ "kvm" ]; + }; + + # The default github-runner service sets a lot of isolation features + # that attempt to limit the damage that malicious code can use. + # Unfortunately we rely on some "dangerous" features to do these tests, + # so this shim will peel some of them away. + systemd.services.github-runner = { + serviceConfig = { + # We need access to /dev to poke /dev/kvm. + PrivateDevices = lib.mkForce false; + + # /dev/kvm is how qemu creates a virtual machine with KVM. + DeviceAllow = lib.mkForce [ "/dev/kvm" ]; + + # Ensure the service has KVM permissions with the `kvm` group. + ExtraGroups = [ "kvm" ]; + + # The service runs as a dynamic user by default. This makes it hard + # to persistently store things in /var/lib/ghrunner. This line + # disables the dynamic user feature. + DynamicUser = lib.mkForce false; + + # Run this service as our ghrunner user. + User = "ghrunner"; + + # We need access to /var/lib/ghrunner to store VM images. + ProtectSystem = lib.mkForce null; + }; + }; +} diff --git a/tstest/integration/vms/squid.conf b/tstest/integration/vms/squid.conf index e43c5cd1f41d4..29d32bd6d8606 100644 --- a/tstest/integration/vms/squid.conf +++ b/tstest/integration/vms/squid.conf @@ -1,39 +1,39 @@ -pid_filename /run/squid.pid -cache_dir ufs /tmp/squid/cache 500 16 256 -maximum_object_size 4096 KB -coredump_dir /tmp/squid/core -visible_hostname localhost -cache_access_log /tmp/squid/access.log -cache_log /tmp/squid/cache.log - -# Access Control lists -acl localhost src 127.0.0.1 ::1 -acl manager proto cache_object -acl SSL_ports port 443 -acl Safe_ports port 80 # http -acl Safe_ports port 21 # ftp -acl Safe_ports port 443 # https -acl Safe_ports port 70 # gopher -acl Safe_ports port 210 # wais -acl Safe_ports port 1025-65535 # unregistered ports -acl Safe_ports port 280 # http-mgmt -acl Safe_ports port 488 # gss-http -acl Safe_ports port 591 # filemaker -acl Safe_ports port 777 # multiling http -acl CONNECT method CONNECT - -http_access allow localhost -http_access deny all -forwarded_for on - -# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB -sslcrtd_children 5 - -http_port 127.0.0.1:3128 \ - ssl-bump \ - generate-host-certificates=on \ - dynamic_cert_mem_cache_size=4MB \ - cert=/tmp/squid/myca-mitm.pem - -ssl_bump stare all # mimic the Client Hello, drop unsupported extensions +pid_filename /run/squid.pid +cache_dir ufs /tmp/squid/cache 500 16 256 +maximum_object_size 4096 KB +coredump_dir /tmp/squid/core +visible_hostname localhost +cache_access_log /tmp/squid/access.log +cache_log /tmp/squid/cache.log + +# Access Control lists +acl localhost src 127.0.0.1 ::1 +acl manager proto cache_object +acl SSL_ports port 443 +acl Safe_ports port 80 # http +acl Safe_ports port 21 # ftp +acl Safe_ports port 443 # https +acl Safe_ports port 70 # gopher +acl Safe_ports port 210 # wais +acl Safe_ports port 1025-65535 # unregistered ports +acl Safe_ports port 280 # http-mgmt +acl Safe_ports port 488 # gss-http +acl Safe_ports port 591 # filemaker +acl Safe_ports port 777 # multiling http +acl CONNECT method CONNECT + +http_access allow localhost +http_access deny all +forwarded_for on + +# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB +sslcrtd_children 5 + +http_port 127.0.0.1:3128 \ + ssl-bump \ + generate-host-certificates=on \ + dynamic_cert_mem_cache_size=4MB \ + cert=/tmp/squid/myca-mitm.pem + +ssl_bump stare all # mimic the Client Hello, drop unsupported extensions ssl_bump bump all # terminate and establish new TLS connection \ No newline at end of file diff --git a/tstest/integration/vms/top_level_test.go b/tstest/integration/vms/top_level_test.go index 1b9c10e29297a..c107fd89cc886 100644 --- a/tstest/integration/vms/top_level_test.go +++ b/tstest/integration/vms/top_level_test.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "context" - "testing" - "time" - - "github.com/pkg/sftp" - expect "github.com/tailscale/goexpect" -) - -func TestRunUbuntu1804(t *testing.T) { - testOneDistribution(t, 0, Distros[0]) -} - -func TestRunUbuntu2004(t *testing.T) { - testOneDistribution(t, 1, Distros[1]) -} - -func TestRunNixos2111(t *testing.T) { - t.Parallel() - testOneDistribution(t, 2, Distros[2]) -} - -// TestMITMProxy is a smoke test for derphttp through a MITM proxy. -// Encountering such proxies is unfortunately commonplace in more -// traditional enterprise networks. -// -// We invoke tailscale netcheck because the networking check is done -// by tailscale rather than tailscaled, making it easier to configure -// the proxy. -// -// To provide the actual MITM server, we use squid. -func TestMITMProxy(t *testing.T) { - t.Parallel() - setupTests(t) - distro := Distros[2] // nixos-21.11 - - if distroRex.Unwrap().MatchString(distro.Name) { - t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) - } else { - t.Skip("regex not matched") - } - - ctx, done := context.WithCancel(context.Background()) - t.Cleanup(done) - - h := newHarness(t) - - err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) - if err != nil { - t.Fatalf("can't acquire ram semaphore: %v", err) - } - t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) - - vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) - vm.waitStartup(t) - - ipm := h.waitForIPMap(t, vm, distro) - _, cli := h.setupSSHShell(t, distro, ipm) - - sftpCli, err := sftp.NewClient(cli) - if err != nil { - t.Fatalf("can't connect over sftp to copy binaries: %v", err) - } - defer sftpCli.Close() - - // Initialize a squid installation. - // - // A few things of note here: - // - The first thing we do is append the nsslcrtd_program stanza to the config. - // This must be an absolute path and is based on the nix path of the squid derivation, - // so we compute and write it out here. - // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then - // invoke squid with -z to have it fill in the rest. - // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create - // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. - // - There were some perms issues, so i yeeted 0777. Its only a test anyway - copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, - &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, - &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, - &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, - &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, - &expect.BExp{R: `Done`}, - &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, - &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Start the squid server. - runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon - // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This - // eldritchian incantation lets us wait till squid is up. - &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Uncomment to help debugging this test if it fails. - // - // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ - // &expect.BSnd{S: "sudo ifconfig\n"}, - // &expect.BSnd{S: "sudo ip link\n"}, - // &expect.BSnd{S: "sudo ip route\n"}, - // &expect.BSnd{S: "ps -aux\n"}, - // &expect.BSnd{S: "netstat -a\n"}, - // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, - // &expect.BExp{R: `Success.`}, - // }) - - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, - &expect.BExp{R: `IPv4: yes`}, - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "context" + "testing" + "time" + + "github.com/pkg/sftp" + expect "github.com/tailscale/goexpect" +) + +func TestRunUbuntu1804(t *testing.T) { + testOneDistribution(t, 0, Distros[0]) +} + +func TestRunUbuntu2004(t *testing.T) { + testOneDistribution(t, 1, Distros[1]) +} + +func TestRunNixos2111(t *testing.T) { + t.Parallel() + testOneDistribution(t, 2, Distros[2]) +} + +// TestMITMProxy is a smoke test for derphttp through a MITM proxy. +// Encountering such proxies is unfortunately commonplace in more +// traditional enterprise networks. +// +// We invoke tailscale netcheck because the networking check is done +// by tailscale rather than tailscaled, making it easier to configure +// the proxy. +// +// To provide the actual MITM server, we use squid. +func TestMITMProxy(t *testing.T) { + t.Parallel() + setupTests(t) + distro := Distros[2] // nixos-21.11 + + if distroRex.Unwrap().MatchString(distro.Name) { + t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) + } else { + t.Skip("regex not matched") + } + + ctx, done := context.WithCancel(context.Background()) + t.Cleanup(done) + + h := newHarness(t) + + err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) + if err != nil { + t.Fatalf("can't acquire ram semaphore: %v", err) + } + t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) + + vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) + vm.waitStartup(t) + + ipm := h.waitForIPMap(t, vm, distro) + _, cli := h.setupSSHShell(t, distro, ipm) + + sftpCli, err := sftp.NewClient(cli) + if err != nil { + t.Fatalf("can't connect over sftp to copy binaries: %v", err) + } + defer sftpCli.Close() + + // Initialize a squid installation. + // + // A few things of note here: + // - The first thing we do is append the nsslcrtd_program stanza to the config. + // This must be an absolute path and is based on the nix path of the squid derivation, + // so we compute and write it out here. + // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then + // invoke squid with -z to have it fill in the rest. + // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create + // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. + // - There were some perms issues, so i yeeted 0777. Its only a test anyway + copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, + &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, + &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, + &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, + &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, + &expect.BExp{R: `Done`}, + &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, + &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Start the squid server. + runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon + // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This + // eldritchian incantation lets us wait till squid is up. + &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Uncomment to help debugging this test if it fails. + // + // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ + // &expect.BSnd{S: "sudo ifconfig\n"}, + // &expect.BSnd{S: "sudo ip link\n"}, + // &expect.BSnd{S: "sudo ip route\n"}, + // &expect.BSnd{S: "ps -aux\n"}, + // &expect.BSnd{S: "netstat -a\n"}, + // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, + // &expect.BExp{R: `Success.`}, + // }) + + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, + &expect.BExp{R: `IPv4: yes`}, + }) +} diff --git a/tstest/integration/vms/udp_tester.go b/tstest/integration/vms/udp_tester.go index 14c8c6ed0c7a5..be44aa9636103 100644 --- a/tstest/integration/vms/udp_tester.go +++ b/tstest/integration/vms/udp_tester.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command udp_tester exists because all of these distros being tested don't -// have a consistent tool for doing UDP traffic. This is a very hacked up tool -// that does that UDP traffic so these tests can be done. -package main - -import ( - "flag" - "io" - "log" - "net" - "os" -) - -var ( - client = flag.String("client", "", "host:port to connect to for sending UDP") - server = flag.String("server", "", "host:port to bind to for receiving UDP") -) - -func main() { - flag.Parse() - - if *client == "" && *server == "" { - log.Fatal("specify -client or -server") - } - - if *client != "" { - conn, err := net.Dial("udp", *client) - if err != nil { - log.Fatalf("can't dial %s: %v", *client, err) - } - log.Printf("dialed to %s", conn.RemoteAddr()) - defer conn.Close() - - buf := make([]byte, 2048) - n, err := os.Stdin.Read(buf) - if err != nil && err != io.EOF { - log.Fatalf("can't read from stdin: %v", err) - } - - nn, err := conn.Write(buf[:n]) - if err != nil { - log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) - } - - if n == nn { - return - } - - log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) - } - - if *server != "" { - addr, err := net.ResolveUDPAddr("udp", *server) - if err != nil { - log.Fatalf("can't resolve %s: %v", *server, err) - } - ln, err := net.ListenUDP("udp", addr) - if err != nil { - log.Fatalf("can't listen %s: %v", *server, err) - } - defer ln.Close() - - buf := make([]byte, 2048) - - n, _, err := ln.ReadFromUDP(buf) - if err != nil { - log.Fatal(err) - } - - os.Stdout.Write(buf[:n]) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command udp_tester exists because all of these distros being tested don't +// have a consistent tool for doing UDP traffic. This is a very hacked up tool +// that does that UDP traffic so these tests can be done. +package main + +import ( + "flag" + "io" + "log" + "net" + "os" +) + +var ( + client = flag.String("client", "", "host:port to connect to for sending UDP") + server = flag.String("server", "", "host:port to bind to for receiving UDP") +) + +func main() { + flag.Parse() + + if *client == "" && *server == "" { + log.Fatal("specify -client or -server") + } + + if *client != "" { + conn, err := net.Dial("udp", *client) + if err != nil { + log.Fatalf("can't dial %s: %v", *client, err) + } + log.Printf("dialed to %s", conn.RemoteAddr()) + defer conn.Close() + + buf := make([]byte, 2048) + n, err := os.Stdin.Read(buf) + if err != nil && err != io.EOF { + log.Fatalf("can't read from stdin: %v", err) + } + + nn, err := conn.Write(buf[:n]) + if err != nil { + log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) + } + + if n == nn { + return + } + + log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) + } + + if *server != "" { + addr, err := net.ResolveUDPAddr("udp", *server) + if err != nil { + log.Fatalf("can't resolve %s: %v", *server, err) + } + ln, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatalf("can't listen %s: %v", *server, err) + } + defer ln.Close() + + buf := make([]byte, 2048) + + n, _, err := ln.ReadFromUDP(buf) + if err != nil { + log.Fatal(err) + } + + os.Stdout.Write(buf[:n]) + } +} diff --git a/tstest/log_test.go b/tstest/log_test.go index a8cb62cf5ccf2..51a5743c2c7f2 100644 --- a/tstest/log_test.go +++ b/tstest/log_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "reflect" - "testing" -) - -func TestLogLineTracker(t *testing.T) { - const ( - l1 = "line 1: %s" - l2 = "line 2: %s" - l3 = "line 3: %s" - ) - - lt := NewLogLineTracker(t.Logf, []string{l1, l2}) - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l3, "hi") - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "hi") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "bye") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l2, "hi") - - if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "reflect" + "testing" +) + +func TestLogLineTracker(t *testing.T) { + const ( + l1 = "line 1: %s" + l2 = "line 2: %s" + l3 = "line 3: %s" + ) + + lt := NewLogLineTracker(t.Logf, []string{l1, l2}) + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l3, "hi") + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "hi") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "bye") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l2, "hi") + + if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } +} diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 851f1c56dcf8d..c427d6692a29c 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -1,156 +1,156 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "fmt" - "net/netip" - "sync" - "time" - - "tailscale.com/util/mak" -) - -// FirewallType is the type of filtering a stateful firewall -// does. Values express different modes defined by RFC 4787. -type FirewallType int - -const ( - // AddressAndPortDependentFirewall specifies a destination - // address-and-port dependent firewall. Outbound traffic to an - // ip:port authorizes traffic from that ip:port exactly, and - // nothing else. - AddressAndPortDependentFirewall FirewallType = iota - // AddressDependentFirewall specifies a destination address - // dependent firewall. Once outbound traffic has been seen to an - // IP address, that IP address can talk back from any port. - AddressDependentFirewall - // EndpointIndependentFirewall specifies a destination endpoint - // independent firewall. Once outbound traffic has been seen from - // a source, anyone can talk back to that source. - EndpointIndependentFirewall -) - -// fwKey is the lookup key for a firewall session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out -// some fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type fwKey struct { - src netip.AddrPort - dst netip.AddrPort -} - -// key returns an fwKey for the given src and dst, trimmed according -// to the FirewallType. fwKeys are always constructed from the -// "outbound" point of view (i.e. src is the "trusted" side of the -// world), it's the caller's responsibility to swap src and dst in the -// call to key when processing packets inbound from the "untrusted" -// world. -func (s FirewallType) key(src, dst netip.AddrPort) fwKey { - k := fwKey{src: src} - switch s { - case EndpointIndependentFirewall: - case AddressDependentFirewall: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentFirewall: - k.dst = dst - default: - panic(fmt.Sprintf("unknown firewall selectivity %v", s)) - } - return k -} - -// DefaultSessionTimeout is the default timeout for a firewall -// session. -const DefaultSessionTimeout = 30 * time.Second - -// Firewall is a simple stateful firewall that allows all outbound -// traffic and filters inbound traffic based on recently seen outbound -// traffic. Its HandlePacket method should be attached to a Machine to -// give it a stateful firewall. -type Firewall struct { - // SessionTimeout is the lifetime of idle sessions in the firewall - // state. Packets transiting from the TrustedInterface reset the - // session lifetime to SessionTimeout. If zero, - // DefaultSessionTimeout is used. - SessionTimeout time.Duration - // Type specifies how precisely return traffic must match - // previously seen outbound traffic to be allowed. Defaults to - // AddressAndPortDependentFirewall. - Type FirewallType - // TrustedInterface is an optional interface that is considered - // trusted in addition to PacketConns local to the Machine. All - // other interfaces can only respond to traffic from - // TrustedInterface or the local host. - TrustedInterface *Interface - // TimeNow is a function returning the current time. If nil, - // time.Now is used. - TimeNow func() time.Time - - // TODO: refresh directionality: outbound-only, both - - mu sync.Mutex - seen map[fwKey]time.Time // session -> deadline -} - -func (f *Firewall) timeNow() time.Time { - if f.TimeNow != nil { - return f.TimeNow() - } - return time.Now() -} - -// Reset drops all firewall state, forgetting all flows. -func (f *Firewall) Reset() { - f.mu.Lock() - defer f.mu.Unlock() - f.seen = nil -} - -func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - k := f.Type.key(p.Src, p.Dst) - mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) - p.Trace("firewall out ok") - return p -} - -func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - // reverse src and dst because the session table is from the POV - // of outbound packets. - k := f.Type.key(p.Dst, p.Src) - now := f.timeNow() - if now.After(f.seen[k]) { - p.Trace("firewall drop") - return nil - } - p.Trace("firewall in ok") - return p -} - -func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { - if iif == f.TrustedInterface { - // Treat just like a locally originated packet - return f.HandleOut(p, oif) - } - if oif != f.TrustedInterface { - // Not a possible return packet from our trusted interface, drop. - p.Trace("firewall drop, unexpected oif") - return nil - } - // Otherwise, a session must exist, same as HandleIn. - return f.HandleIn(p, iif) -} - -func (f *Firewall) sessionTimeoutLocked() time.Duration { - if f.SessionTimeout == 0 { - return DefaultSessionTimeout - } - return f.SessionTimeout -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "fmt" + "net/netip" + "sync" + "time" + + "tailscale.com/util/mak" +) + +// FirewallType is the type of filtering a stateful firewall +// does. Values express different modes defined by RFC 4787. +type FirewallType int + +const ( + // AddressAndPortDependentFirewall specifies a destination + // address-and-port dependent firewall. Outbound traffic to an + // ip:port authorizes traffic from that ip:port exactly, and + // nothing else. + AddressAndPortDependentFirewall FirewallType = iota + // AddressDependentFirewall specifies a destination address + // dependent firewall. Once outbound traffic has been seen to an + // IP address, that IP address can talk back from any port. + AddressDependentFirewall + // EndpointIndependentFirewall specifies a destination endpoint + // independent firewall. Once outbound traffic has been seen from + // a source, anyone can talk back to that source. + EndpointIndependentFirewall +) + +// fwKey is the lookup key for a firewall session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out +// some fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type fwKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// key returns an fwKey for the given src and dst, trimmed according +// to the FirewallType. fwKeys are always constructed from the +// "outbound" point of view (i.e. src is the "trusted" side of the +// world), it's the caller's responsibility to swap src and dst in the +// call to key when processing packets inbound from the "untrusted" +// world. +func (s FirewallType) key(src, dst netip.AddrPort) fwKey { + k := fwKey{src: src} + switch s { + case EndpointIndependentFirewall: + case AddressDependentFirewall: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentFirewall: + k.dst = dst + default: + panic(fmt.Sprintf("unknown firewall selectivity %v", s)) + } + return k +} + +// DefaultSessionTimeout is the default timeout for a firewall +// session. +const DefaultSessionTimeout = 30 * time.Second + +// Firewall is a simple stateful firewall that allows all outbound +// traffic and filters inbound traffic based on recently seen outbound +// traffic. Its HandlePacket method should be attached to a Machine to +// give it a stateful firewall. +type Firewall struct { + // SessionTimeout is the lifetime of idle sessions in the firewall + // state. Packets transiting from the TrustedInterface reset the + // session lifetime to SessionTimeout. If zero, + // DefaultSessionTimeout is used. + SessionTimeout time.Duration + // Type specifies how precisely return traffic must match + // previously seen outbound traffic to be allowed. Defaults to + // AddressAndPortDependentFirewall. + Type FirewallType + // TrustedInterface is an optional interface that is considered + // trusted in addition to PacketConns local to the Machine. All + // other interfaces can only respond to traffic from + // TrustedInterface or the local host. + TrustedInterface *Interface + // TimeNow is a function returning the current time. If nil, + // time.Now is used. + TimeNow func() time.Time + + // TODO: refresh directionality: outbound-only, both + + mu sync.Mutex + seen map[fwKey]time.Time // session -> deadline +} + +func (f *Firewall) timeNow() time.Time { + if f.TimeNow != nil { + return f.TimeNow() + } + return time.Now() +} + +// Reset drops all firewall state, forgetting all flows. +func (f *Firewall) Reset() { + f.mu.Lock() + defer f.mu.Unlock() + f.seen = nil +} + +func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + k := f.Type.key(p.Src, p.Dst) + mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) + p.Trace("firewall out ok") + return p +} + +func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + // reverse src and dst because the session table is from the POV + // of outbound packets. + k := f.Type.key(p.Dst, p.Src) + now := f.timeNow() + if now.After(f.seen[k]) { + p.Trace("firewall drop") + return nil + } + p.Trace("firewall in ok") + return p +} + +func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { + if iif == f.TrustedInterface { + // Treat just like a locally originated packet + return f.HandleOut(p, oif) + } + if oif != f.TrustedInterface { + // Not a possible return packet from our trusted interface, drop. + p.Trace("firewall drop, unexpected oif") + return nil + } + // Otherwise, a session must exist, same as HandleIn. + return f.HandleIn(p, iif) +} + +func (f *Firewall) sessionTimeoutLocked() time.Duration { + if f.SessionTimeout == 0 { + return DefaultSessionTimeout + } + return f.SessionTimeout +} diff --git a/tstest/natlab/nat.go b/tstest/natlab/nat.go index 36b1322cdb62c..d756c5bf11833 100644 --- a/tstest/natlab/nat.go +++ b/tstest/natlab/nat.go @@ -1,252 +1,252 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - "time" -) - -// mapping is the state of an allocated NAT session. -type mapping struct { - lanSrc netip.AddrPort - lanDst netip.AddrPort - wanSrc netip.AddrPort - deadline time.Time - - // pc is a PacketConn that reserves an outbound port on the NAT's - // WAN interface. We do this because ListenPacket already has - // random port selection logic built in. Additionally this means - // that concurrent use of ListenPacket for connections originating - // from the NAT box won't conflict with NAT mappings, since both - // use PacketConn to reserve ports on the machine. - pc net.PacketConn -} - -// NATType is the mapping behavior of a NAT device. Values express -// different modes defined by RFC 4787. -type NATType int - -const ( - // EndpointIndependentNAT specifies a destination endpoint - // independent NAT. All traffic from a source ip:port gets mapped - // to a single WAN ip:port. - EndpointIndependentNAT NATType = iota - // AddressDependentNAT specifies a destination address dependent - // NAT. Every distinct destination IP gets its own WAN ip:port - // allocation. - AddressDependentNAT - // AddressAndPortDependentNAT specifies a destination - // address-and-port dependent NAT. Every distinct destination - // ip:port gets its own WAN ip:port allocation. - AddressAndPortDependentNAT -) - -// natKey is the lookup key for a NAT session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some -// fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type natKey struct { - src, dst netip.AddrPort -} - -func (t NATType) key(src, dst netip.AddrPort) natKey { - k := natKey{src: src} - switch t { - case EndpointIndependentNAT: - case AddressDependentNAT: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentNAT: - k.dst = dst - default: - panic(fmt.Sprintf("unknown NAT type %v", t)) - } - return k -} - -// DefaultMappingTimeout is the default timeout for a NAT mapping. -const DefaultMappingTimeout = 30 * time.Second - -// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with -// optional builtin firewall. -type SNAT44 struct { - // Machine is the machine to which this NAT is attached. Altered - // packets are injected back into this Machine for processing. - Machine *Machine - // ExternalInterface is the "WAN" interface of Machine. Packets - // from other sources get NATed onto this interface. - ExternalInterface *Interface - // Type specifies the mapping allocation behavior for this NAT. - Type NATType - // MappingTimeout is the lifetime of individual NAT sessions. Once - // a session expires, the mapped port effectively "closes" to new - // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. - MappingTimeout time.Duration - // Firewall is an optional packet handler that will be invoked as - // a firewall during NAT translation. The firewall always sees - // packets in their "LAN form", i.e. before translation in the - // outbound direction and after translation in the inbound - // direction. - Firewall PacketHandler - // TimeNow is a function that returns the current time. If - // nil, time.Now is used. - TimeNow func() time.Time - - mu sync.Mutex - byLAN map[natKey]*mapping // lookup by outbound packet tuple - byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only -} - -func (n *SNAT44) timeNow() time.Time { - if n.TimeNow != nil { - return n.TimeNow() - } - return time.Now() -} - -func (n *SNAT44) mappingTimeout() time.Duration { - if n.MappingTimeout == 0 { - return DefaultMappingTimeout - } - return n.MappingTimeout -} - -func (n *SNAT44) initLocked() { - if n.byLAN == nil { - n.byLAN = map[natKey]*mapping{} - n.byWAN = map[netip.AddrPort]*mapping{} - } - if n.ExternalInterface.Machine() != n.Machine { - panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) - } -} - -func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { - // NATs don't affect locally originated packets. - if n.Firewall != nil { - return n.Firewall.HandleOut(p, oif) - } - return p -} - -func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { - if iif != n.ExternalInterface { - // NAT can't apply, defer to firewall. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - now := n.timeNow() - mapping := n.byWAN[p.Dst] - if mapping == nil || now.After(mapping.deadline) { - // NAT didn't hit, defer to firewall or allow in for local - // socket handling. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - p.Dst = mapping.lanSrc - p.Trace("dnat to %v", p.Dst) - // Don't process firewall here. We mutated the packet such that - // it's no longer destined locally, so we'll get reinvoked as - // HandleForward and need to process the altered packet there. - return p -} - -func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { - switch { - case oif == n.ExternalInterface: - if p.Src.Addr() == oif.V4() { - // Packet already NATed and is just retraversing Forward, - // don't touch it again. - return p - } - - if n.Firewall != nil { - p2 := n.Firewall.HandleForward(p, iif, oif) - if p2 == nil { - // firewall dropped, done - return nil - } - if !p.Equivalent(p2) { - // firewall mutated packet? Weird, but okay. - return p2 - } - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - k := n.Type.key(p.Src, p.Dst) - now := n.timeNow() - m := n.byLAN[k] - if m == nil || now.After(m.deadline) { - pc, wanAddr := n.allocateMappedPort() - m = &mapping{ - lanSrc: p.Src, - lanDst: p.Dst, - wanSrc: wanAddr, - pc: pc, - } - n.byLAN[k] = m - n.byWAN[wanAddr] = m - } - m.deadline = now.Add(n.mappingTimeout()) - p.Src = m.wanSrc - p.Trace("snat from %v", p.Src) - return p - case iif == n.ExternalInterface: - // Packet was already un-NAT-ed, we just need to either - // firewall it or let it through. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return p - default: - // No NAT applies, invoke firewall or drop. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return nil - } -} - -func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { - // Clean up old entries before trying to allocate, to free up any - // expired ports. - n.gc() - - ip := n.ExternalInterface.V4() - pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) - if err != nil { - panic(fmt.Sprintf("ran out of NAT ports: %v", err)) - } - addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) - return pc, addr -} - -func (n *SNAT44) gc() { - now := n.timeNow() - for _, m := range n.byLAN { - if !now.After(m.deadline) { - continue - } - m.pc.Close() - delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) - delete(n.byWAN, m.wanSrc) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" +) + +// mapping is the state of an allocated NAT session. +type mapping struct { + lanSrc netip.AddrPort + lanDst netip.AddrPort + wanSrc netip.AddrPort + deadline time.Time + + // pc is a PacketConn that reserves an outbound port on the NAT's + // WAN interface. We do this because ListenPacket already has + // random port selection logic built in. Additionally this means + // that concurrent use of ListenPacket for connections originating + // from the NAT box won't conflict with NAT mappings, since both + // use PacketConn to reserve ports on the machine. + pc net.PacketConn +} + +// NATType is the mapping behavior of a NAT device. Values express +// different modes defined by RFC 4787. +type NATType int + +const ( + // EndpointIndependentNAT specifies a destination endpoint + // independent NAT. All traffic from a source ip:port gets mapped + // to a single WAN ip:port. + EndpointIndependentNAT NATType = iota + // AddressDependentNAT specifies a destination address dependent + // NAT. Every distinct destination IP gets its own WAN ip:port + // allocation. + AddressDependentNAT + // AddressAndPortDependentNAT specifies a destination + // address-and-port dependent NAT. Every distinct destination + // ip:port gets its own WAN ip:port allocation. + AddressAndPortDependentNAT +) + +// natKey is the lookup key for a NAT session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some +// fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type natKey struct { + src, dst netip.AddrPort +} + +func (t NATType) key(src, dst netip.AddrPort) natKey { + k := natKey{src: src} + switch t { + case EndpointIndependentNAT: + case AddressDependentNAT: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentNAT: + k.dst = dst + default: + panic(fmt.Sprintf("unknown NAT type %v", t)) + } + return k +} + +// DefaultMappingTimeout is the default timeout for a NAT mapping. +const DefaultMappingTimeout = 30 * time.Second + +// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with +// optional builtin firewall. +type SNAT44 struct { + // Machine is the machine to which this NAT is attached. Altered + // packets are injected back into this Machine for processing. + Machine *Machine + // ExternalInterface is the "WAN" interface of Machine. Packets + // from other sources get NATed onto this interface. + ExternalInterface *Interface + // Type specifies the mapping allocation behavior for this NAT. + Type NATType + // MappingTimeout is the lifetime of individual NAT sessions. Once + // a session expires, the mapped port effectively "closes" to new + // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. + MappingTimeout time.Duration + // Firewall is an optional packet handler that will be invoked as + // a firewall during NAT translation. The firewall always sees + // packets in their "LAN form", i.e. before translation in the + // outbound direction and after translation in the inbound + // direction. + Firewall PacketHandler + // TimeNow is a function that returns the current time. If + // nil, time.Now is used. + TimeNow func() time.Time + + mu sync.Mutex + byLAN map[natKey]*mapping // lookup by outbound packet tuple + byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only +} + +func (n *SNAT44) timeNow() time.Time { + if n.TimeNow != nil { + return n.TimeNow() + } + return time.Now() +} + +func (n *SNAT44) mappingTimeout() time.Duration { + if n.MappingTimeout == 0 { + return DefaultMappingTimeout + } + return n.MappingTimeout +} + +func (n *SNAT44) initLocked() { + if n.byLAN == nil { + n.byLAN = map[natKey]*mapping{} + n.byWAN = map[netip.AddrPort]*mapping{} + } + if n.ExternalInterface.Machine() != n.Machine { + panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) + } +} + +func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { + // NATs don't affect locally originated packets. + if n.Firewall != nil { + return n.Firewall.HandleOut(p, oif) + } + return p +} + +func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { + if iif != n.ExternalInterface { + // NAT can't apply, defer to firewall. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + now := n.timeNow() + mapping := n.byWAN[p.Dst] + if mapping == nil || now.After(mapping.deadline) { + // NAT didn't hit, defer to firewall or allow in for local + // socket handling. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + p.Dst = mapping.lanSrc + p.Trace("dnat to %v", p.Dst) + // Don't process firewall here. We mutated the packet such that + // it's no longer destined locally, so we'll get reinvoked as + // HandleForward and need to process the altered packet there. + return p +} + +func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { + switch { + case oif == n.ExternalInterface: + if p.Src.Addr() == oif.V4() { + // Packet already NATed and is just retraversing Forward, + // don't touch it again. + return p + } + + if n.Firewall != nil { + p2 := n.Firewall.HandleForward(p, iif, oif) + if p2 == nil { + // firewall dropped, done + return nil + } + if !p.Equivalent(p2) { + // firewall mutated packet? Weird, but okay. + return p2 + } + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + k := n.Type.key(p.Src, p.Dst) + now := n.timeNow() + m := n.byLAN[k] + if m == nil || now.After(m.deadline) { + pc, wanAddr := n.allocateMappedPort() + m = &mapping{ + lanSrc: p.Src, + lanDst: p.Dst, + wanSrc: wanAddr, + pc: pc, + } + n.byLAN[k] = m + n.byWAN[wanAddr] = m + } + m.deadline = now.Add(n.mappingTimeout()) + p.Src = m.wanSrc + p.Trace("snat from %v", p.Src) + return p + case iif == n.ExternalInterface: + // Packet was already un-NAT-ed, we just need to either + // firewall it or let it through. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return p + default: + // No NAT applies, invoke firewall or drop. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return nil + } +} + +func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { + // Clean up old entries before trying to allocate, to free up any + // expired ports. + n.gc() + + ip := n.ExternalInterface.V4() + pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) + if err != nil { + panic(fmt.Sprintf("ran out of NAT ports: %v", err)) + } + addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) + return pc, addr +} + +func (n *SNAT44) gc() { + now := n.timeNow() + for _, m := range n.byLAN { + if !now.After(m.deadline) { + continue + } + m.pc.Close() + delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) + delete(n.byWAN, m.wanSrc) + } +} diff --git a/tstest/tstest.go b/tstest/tstest.go index 118aa382749ae..2d0d1351e293a 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstest provides utilities for use in unit tests. -package tstest - -import ( - "context" - "os" - "strconv" - "strings" - "sync/atomic" - "testing" - "time" - - "tailscale.com/envknob" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/cibuild" -) - -// Replace replaces the value of target with val. -// The old value is restored when the test ends. -func Replace[T any](t testing.TB, target *T, val T) { - t.Helper() - if target == nil { - t.Fatalf("Replace: nil pointer") - panic("unreachable") // pacify staticcheck - } - old := *target - t.Cleanup(func() { - *target = old - }) - - *target = val - return -} - -// WaitFor retries try for up to maxWait. -// It returns nil once try returns nil the first time. -// If maxWait passes without success, it returns try's last error. -func WaitFor(maxWait time.Duration, try func() error) error { - bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) - deadline := time.Now().Add(maxWait) - var err error - for time.Now().Before(deadline) { - err = try() - if err == nil { - break - } - bo.BackOff(context.Background(), err) - } - return err -} - -var testNum atomic.Int32 - -// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to -// "n/m" and this test execution number in the process mod m is not equal to n-1. -// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 -// for the four jobs. -func Shard(t testing.TB) { - e := os.Getenv("TS_TEST_SHARD") - a, b, ok := strings.Cut(e, "/") - if !ok { - return - } - wantShard, _ := strconv.ParseInt(a, 10, 32) - shards, _ := strconv.ParseInt(b, 10, 32) - if wantShard == 0 || shards == 0 { - return - } - - shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 - if shard != int32(wantShard) { - t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) - } -} - -// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD -// environment variable isn't set. -func SkipOnUnshardedCI(t testing.TB) { - if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { - t.Skip("skipping on CI without TS_TEST_SHARD") - } -} - -var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") - -// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. -func Parallel(t *testing.T) { - if !serializeParallel() { - t.Parallel() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstest provides utilities for use in unit tests. +package tstest + +import ( + "context" + "os" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "tailscale.com/envknob" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/cibuild" +) + +// Replace replaces the value of target with val. +// The old value is restored when the test ends. +func Replace[T any](t testing.TB, target *T, val T) { + t.Helper() + if target == nil { + t.Fatalf("Replace: nil pointer") + panic("unreachable") // pacify staticcheck + } + old := *target + t.Cleanup(func() { + *target = old + }) + + *target = val + return +} + +// WaitFor retries try for up to maxWait. +// It returns nil once try returns nil the first time. +// If maxWait passes without success, it returns try's last error. +func WaitFor(maxWait time.Duration, try func() error) error { + bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) + deadline := time.Now().Add(maxWait) + var err error + for time.Now().Before(deadline) { + err = try() + if err == nil { + break + } + bo.BackOff(context.Background(), err) + } + return err +} + +var testNum atomic.Int32 + +// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to +// "n/m" and this test execution number in the process mod m is not equal to n-1. +// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 +// for the four jobs. +func Shard(t testing.TB) { + e := os.Getenv("TS_TEST_SHARD") + a, b, ok := strings.Cut(e, "/") + if !ok { + return + } + wantShard, _ := strconv.ParseInt(a, 10, 32) + shards, _ := strconv.ParseInt(b, 10, 32) + if wantShard == 0 || shards == 0 { + return + } + + shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 + if shard != int32(wantShard) { + t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) + } +} + +// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD +// environment variable isn't set. +func SkipOnUnshardedCI(t testing.TB) { + if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { + t.Skip("skipping on CI without TS_TEST_SHARD") + } +} + +var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") + +// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. +func Parallel(t *testing.T) { + if !serializeParallel() { + t.Parallel() + } +} diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go index 20a9f7bf1faa2..e988d5d5624b6 100644 --- a/tstest/tstest_test.go +++ b/tstest/tstest_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import "testing" - -func TestReplace(t *testing.T) { - before := "before" - done := false - t.Run("replace", func(t *testing.T) { - Replace(t, &before, "after") - if before != "after" { - t.Errorf("before = %q; want %q", before, "after") - } - done = true - }) - if !done { - t.Fatal("subtest didn't run") - } - if before != "before" { - t.Errorf("before = %q; want %q", before, "before") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import "testing" + +func TestReplace(t *testing.T) { + before := "before" + done := false + t.Run("replace", func(t *testing.T) { + Replace(t, &before, "after") + if before != "after" { + t.Errorf("before = %q; want %q", before, "after") + } + done = true + }) + if !done { + t.Fatal("subtest didn't run") + } + if before != "before" { + t.Errorf("before = %q; want %q", before, "before") + } +} diff --git a/tstime/mono/mono.go b/tstime/mono/mono.go index 94dca7d79b6bb..260e02b0fb0f3 100644 --- a/tstime/mono/mono.go +++ b/tstime/mono/mono.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mono provides fast monotonic time. -// On most platforms, mono.Now is about 2x faster than time.Now. -// However, time.Now is really fast, and nicer to use. -// -// For almost all purposes, you should use time.Now. -// -// Package mono exists because we get the current time multiple -// times per network packet, at which point it makes a -// measurable difference. -package mono - -import ( - "fmt" - "sync/atomic" - "time" -) - -// Time is the number of nanoseconds elapsed since an unspecified reference start time. -type Time int64 - -// Now returns the current monotonic time. -func Now() Time { - // On a newly started machine, the monotonic clock might be very near zero. - // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. - // The corresponding package time expression never does, if the wall clock is correct. - // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. - const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds - return Time(int64(time.Since(baseWall)) + baseOffset) -} - -// Since returns the time elapsed since t. -func Since(t Time) time.Duration { - return time.Duration(Now() - t) -} - -// Sub returns t-n, the duration from n to t. -func (t Time) Sub(n Time) time.Duration { - return time.Duration(t - n) -} - -// Add returns t+d. -func (t Time) Add(d time.Duration) Time { - return t + Time(d) -} - -// After reports t > n, whether t is after n. -func (t Time) After(n Time) bool { - return t > n -} - -// Before reports t < n, whether t is before n. -func (t Time) Before(n Time) bool { - return t < n -} - -// IsZero reports whether t == 0. -func (t Time) IsZero() bool { - return t == 0 -} - -// StoreAtomic does an atomic store *t = new. -func (t *Time) StoreAtomic(new Time) { - atomic.StoreInt64((*int64)(t), int64(new)) -} - -// LoadAtomic does an atomic load *t. -func (t *Time) LoadAtomic() Time { - return Time(atomic.LoadInt64((*int64)(t))) -} - -// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. -var ( - baseWall time.Time - baseMono Time -) - -func init() { - baseWall = time.Now() - baseMono = Now() -} - -// String prints t, including an estimated equivalent wall clock. -// This is best-effort only, for rough debugging purposes only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) String() string { - return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) -} - -// WallTime returns an approximate wall time that corresponded to t. -func (t Time) WallTime() time.Time { - if !t.IsZero() { - return baseWall.Add(t.Sub(baseMono)).Truncate(0) - } - return time.Time{} -} - -// MarshalJSON formats t for JSON as if it were a time.Time. -// We format Time this way for backwards-compatibility. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) MarshalJSON() ([]byte, error) { - tt := t.WallTime() - return tt.MarshalJSON() -} - -// UnmarshalJSON sets t according to data. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -func (t *Time) UnmarshalJSON(data []byte) error { - var tt time.Time - err := tt.UnmarshalJSON(data) - if err != nil { - return err - } - if tt.IsZero() { - *t = 0 - return nil - } - *t = baseMono.Add(tt.Sub(baseWall)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mono provides fast monotonic time. +// On most platforms, mono.Now is about 2x faster than time.Now. +// However, time.Now is really fast, and nicer to use. +// +// For almost all purposes, you should use time.Now. +// +// Package mono exists because we get the current time multiple +// times per network packet, at which point it makes a +// measurable difference. +package mono + +import ( + "fmt" + "sync/atomic" + "time" +) + +// Time is the number of nanoseconds elapsed since an unspecified reference start time. +type Time int64 + +// Now returns the current monotonic time. +func Now() Time { + // On a newly started machine, the monotonic clock might be very near zero. + // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. + // The corresponding package time expression never does, if the wall clock is correct. + // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. + const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds + return Time(int64(time.Since(baseWall)) + baseOffset) +} + +// Since returns the time elapsed since t. +func Since(t Time) time.Duration { + return time.Duration(Now() - t) +} + +// Sub returns t-n, the duration from n to t. +func (t Time) Sub(n Time) time.Duration { + return time.Duration(t - n) +} + +// Add returns t+d. +func (t Time) Add(d time.Duration) Time { + return t + Time(d) +} + +// After reports t > n, whether t is after n. +func (t Time) After(n Time) bool { + return t > n +} + +// Before reports t < n, whether t is before n. +func (t Time) Before(n Time) bool { + return t < n +} + +// IsZero reports whether t == 0. +func (t Time) IsZero() bool { + return t == 0 +} + +// StoreAtomic does an atomic store *t = new. +func (t *Time) StoreAtomic(new Time) { + atomic.StoreInt64((*int64)(t), int64(new)) +} + +// LoadAtomic does an atomic load *t. +func (t *Time) LoadAtomic() Time { + return Time(atomic.LoadInt64((*int64)(t))) +} + +// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. +var ( + baseWall time.Time + baseMono Time +) + +func init() { + baseWall = time.Now() + baseMono = Now() +} + +// String prints t, including an estimated equivalent wall clock. +// This is best-effort only, for rough debugging purposes only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) String() string { + return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) +} + +// WallTime returns an approximate wall time that corresponded to t. +func (t Time) WallTime() time.Time { + if !t.IsZero() { + return baseWall.Add(t.Sub(baseMono)).Truncate(0) + } + return time.Time{} +} + +// MarshalJSON formats t for JSON as if it were a time.Time. +// We format Time this way for backwards-compatibility. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) MarshalJSON() ([]byte, error) { + tt := t.WallTime() + return tt.MarshalJSON() +} + +// UnmarshalJSON sets t according to data. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +func (t *Time) UnmarshalJSON(data []byte) error { + var tt time.Time + err := tt.UnmarshalJSON(data) + if err != nil { + return err + } + if tt.IsZero() { + *t = 0 + return nil + } + *t = baseMono.Add(tt.Sub(baseWall)) + return nil +} diff --git a/tstime/rate/rate.go b/tstime/rate/rate.go index 19dc26e6ae8a7..f0473862a2890 100644 --- a/tstime/rate/rate.go +++ b/tstime/rate/rate.go @@ -1,90 +1,90 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This is a modified, simplified version of code from golang.org/x/time/rate. - -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package rate provides a rate limiter. -package rate - -import ( - "sync" - "time" - - "tailscale.com/tstime/mono" -) - -// Limit defines the maximum frequency of some events. -// Limit is represented as number of events per second. -// A zero Limit is invalid. -type Limit float64 - -// Every converts a minimum time interval between events to a Limit. -func Every(interval time.Duration) Limit { - if interval <= 0 { - panic("invalid interval") - } - return 1 / Limit(interval.Seconds()) -} - -// A Limiter controls how frequently events are allowed to happen. -// It implements a [token bucket] of a particular size b, -// initially full and refilled at rate r tokens per second. -// Informally, in any large enough time interval, -// the Limiter limits the rate to r tokens per second, -// with a maximum burst size of b events. -// Use NewLimiter to create non-zero Limiters. -// -// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket -type Limiter struct { - limit Limit - burst float64 - mu sync.Mutex // protects following fields - tokens float64 // number of tokens currently in bucket - last mono.Time // the last time the limiter's tokens field was updated -} - -// NewLimiter returns a new Limiter that allows events up to rate r and permits -// bursts of at most b tokens. -func NewLimiter(r Limit, b int) *Limiter { - if b < 1 { - panic("bad burst, must be at least 1") - } - return &Limiter{limit: r, burst: float64(b)} -} - -// Allow reports whether an event may happen now. -func (lim *Limiter) Allow() bool { - return lim.allow(mono.Now()) -} - -func (lim *Limiter) allow(now mono.Time) bool { - lim.mu.Lock() - defer lim.mu.Unlock() - - // If time has moved backwards, look around awkwardly and pretend nothing happened. - if now.Before(lim.last) { - lim.last = now - } - - // Calculate the new number of tokens available due to the passage of time. - elapsed := now.Sub(lim.last) - tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() - if tokens > lim.burst { - tokens = lim.burst - } - - // Consume a token. - tokens-- - - // Update state. - ok := tokens >= 0 - if ok { - lim.last = now - lim.tokens = tokens - } - return ok -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This is a modified, simplified version of code from golang.org/x/time/rate. + +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "sync" + "time" + + "tailscale.com/tstime/mono" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit is invalid. +type Limit float64 + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + panic("invalid interval") + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a [token bucket] of a particular size b, +// initially full and refilled at rate r tokens per second. +// Informally, in any large enough time interval, +// the Limiter limits the rate to r tokens per second, +// with a maximum burst size of b events. +// Use NewLimiter to create non-zero Limiters. +// +// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket +type Limiter struct { + limit Limit + burst float64 + mu sync.Mutex // protects following fields + tokens float64 // number of tokens currently in bucket + last mono.Time // the last time the limiter's tokens field was updated +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + if b < 1 { + panic("bad burst, must be at least 1") + } + return &Limiter{limit: r, burst: float64(b)} +} + +// Allow reports whether an event may happen now. +func (lim *Limiter) Allow() bool { + return lim.allow(mono.Now()) +} + +func (lim *Limiter) allow(now mono.Time) bool { + lim.mu.Lock() + defer lim.mu.Unlock() + + // If time has moved backwards, look around awkwardly and pretend nothing happened. + if now.Before(lim.last) { + lim.last = now + } + + // Calculate the new number of tokens available due to the passage of time. + elapsed := now.Sub(lim.last) + tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() + if tokens > lim.burst { + tokens = lim.burst + } + + // Consume a token. + tokens-- + + // Update state. + ok := tokens >= 0 + if ok { + lim.last = now + lim.tokens = tokens + } + return ok +} diff --git a/tstime/tstime.go b/tstime/tstime.go index 22616bca7a47a..1c006355f8726 100644 --- a/tstime/tstime.go +++ b/tstime/tstime.go @@ -1,185 +1,185 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstime defines Tailscale-specific time utilities. -package tstime - -import ( - "context" - "strconv" - "strings" - "time" -) - -// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). -func Parse3339(s string) (time.Time, error) { - return time.Parse(time.RFC3339, s) -} - -// Parse3339B is Parse3339 but for byte slices. -func Parse3339B(b []byte) (time.Time, error) { - var t time.Time - if err := t.UnmarshalText(b); err != nil { - return Parse3339(string(b)) // reproduce same error message - } - return t, nil -} - -// ParseDuration is more expressive than [time.ParseDuration], -// also accepting 'd' (days) and 'w' (weeks) literals. -func ParseDuration(s string) (time.Duration, error) { - for { - end := strings.IndexAny(s, "dw") - if end < 0 { - break - } - start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) - n, err := strconv.Atoi(s[start:end]) - if err != nil { - return 0, err - } - hours := 24 - if s[end] == 'w' { - hours *= 7 - } - s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" - } - return time.ParseDuration(s) -} - -// Sleep is like [time.Sleep] but returns early upon context cancelation. -// It reports whether the full sleep duration was achieved. -func Sleep(ctx context.Context, d time.Duration) bool { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return false - case <-timer.C: - return true - } -} - -// DefaultClock is a wrapper around a Clock. -// It uses StdClock by default if Clock is nil. -type DefaultClock struct{ Clock } - -// TODO: We should make the methods of DefaultClock inlineable -// so that we can optimize for the common case where c.Clock == nil. - -func (c DefaultClock) Now() time.Time { - if c.Clock == nil { - return time.Now() - } - return c.Clock.Now() -} -func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTimer(d) - return t, t.C - } - return c.Clock.NewTimer(d) -} -func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTicker(d) - return t, t.C - } - return c.Clock.NewTicker(d) -} -func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { - if c.Clock == nil { - return time.AfterFunc(d, f) - } - return c.Clock.AfterFunc(d, f) -} -func (c DefaultClock) Since(t time.Time) time.Duration { - if c.Clock == nil { - return time.Since(t) - } - return c.Clock.Since(t) -} - -// Clock offers a subset of the functionality from the std/time package. -// Normally, applications will use the StdClock implementation that calls the -// appropriate std/time exported funcs. The advantage of using Clock is that -// tests can substitute a different implementation, allowing the test to control -// time precisely, something required for certain types of tests to be possible -// at all, speeds up execution by not needing to sleep, and can dramatically -// reduce the risk of flakes due to tests executing too slowly or quickly. -type Clock interface { - // Now returns the current time, as in time.Now. - Now() time.Time - // NewTimer returns a timer whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTimer as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTimer(d time.Duration) (TimerController, <-chan time.Time) - // NewTicker returns a ticker whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTicker as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTicker(d time.Duration) (TickerController, <-chan time.Time) - // AfterFunc returns a ticker whose notion of the current time is controlled - // by this Clock. When the ticker expires, it will call the provided func. - // It follows the semantics of time.AfterFunc. - AfterFunc(d time.Duration, f func()) TimerController - // Since returns the time elapsed since t. - // It follows the semantics of time.Since. - Since(t time.Time) time.Duration -} - -// TickerController offers the receivers of a time.Ticker to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TickerController interface { - // Reset follows the same semantics as with time.Ticker.Reset. - Reset(d time.Duration) - // Stop follows the same semantics as with time.Ticker.Stop. - Stop() -} - -// TimerController offers the receivers of a time.Timer to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TimerController interface { - // Reset follows the same semantics as with time.Timer.Reset. - Reset(d time.Duration) bool - // Stop follows the same semantics as with time.Timer.Stop. - Stop() bool -} - -// StdClock is a simple implementation of Clock using the relevant funcs in the -// std/time package. -type StdClock struct{} - -// Now calls time.Now. -func (StdClock) Now() time.Time { - return time.Now() -} - -// NewTimer calls time.NewTimer. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - t := time.NewTimer(d) - return t, t.C -} - -// NewTicker calls time.NewTicker. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - t := time.NewTicker(d) - return t, t.C -} - -// AfterFunc calls time.AfterFunc. -func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { - return time.AfterFunc(d, f) -} - -// Since calls time.Since. -func (StdClock) Since(t time.Time) time.Duration { - return time.Since(t) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstime defines Tailscale-specific time utilities. +package tstime + +import ( + "context" + "strconv" + "strings" + "time" +) + +// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). +func Parse3339(s string) (time.Time, error) { + return time.Parse(time.RFC3339, s) +} + +// Parse3339B is Parse3339 but for byte slices. +func Parse3339B(b []byte) (time.Time, error) { + var t time.Time + if err := t.UnmarshalText(b); err != nil { + return Parse3339(string(b)) // reproduce same error message + } + return t, nil +} + +// ParseDuration is more expressive than [time.ParseDuration], +// also accepting 'd' (days) and 'w' (weeks) literals. +func ParseDuration(s string) (time.Duration, error) { + for { + end := strings.IndexAny(s, "dw") + if end < 0 { + break + } + start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) + n, err := strconv.Atoi(s[start:end]) + if err != nil { + return 0, err + } + hours := 24 + if s[end] == 'w' { + hours *= 7 + } + s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" + } + return time.ParseDuration(s) +} + +// Sleep is like [time.Sleep] but returns early upon context cancelation. +// It reports whether the full sleep duration was achieved. +func Sleep(ctx context.Context, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +// DefaultClock is a wrapper around a Clock. +// It uses StdClock by default if Clock is nil. +type DefaultClock struct{ Clock } + +// TODO: We should make the methods of DefaultClock inlineable +// so that we can optimize for the common case where c.Clock == nil. + +func (c DefaultClock) Now() time.Time { + if c.Clock == nil { + return time.Now() + } + return c.Clock.Now() +} +func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTimer(d) + return t, t.C + } + return c.Clock.NewTimer(d) +} +func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTicker(d) + return t, t.C + } + return c.Clock.NewTicker(d) +} +func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { + if c.Clock == nil { + return time.AfterFunc(d, f) + } + return c.Clock.AfterFunc(d, f) +} +func (c DefaultClock) Since(t time.Time) time.Duration { + if c.Clock == nil { + return time.Since(t) + } + return c.Clock.Since(t) +} + +// Clock offers a subset of the functionality from the std/time package. +// Normally, applications will use the StdClock implementation that calls the +// appropriate std/time exported funcs. The advantage of using Clock is that +// tests can substitute a different implementation, allowing the test to control +// time precisely, something required for certain types of tests to be possible +// at all, speeds up execution by not needing to sleep, and can dramatically +// reduce the risk of flakes due to tests executing too slowly or quickly. +type Clock interface { + // Now returns the current time, as in time.Now. + Now() time.Time + // NewTimer returns a timer whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTimer as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTimer(d time.Duration) (TimerController, <-chan time.Time) + // NewTicker returns a ticker whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTicker as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTicker(d time.Duration) (TickerController, <-chan time.Time) + // AfterFunc returns a ticker whose notion of the current time is controlled + // by this Clock. When the ticker expires, it will call the provided func. + // It follows the semantics of time.AfterFunc. + AfterFunc(d time.Duration, f func()) TimerController + // Since returns the time elapsed since t. + // It follows the semantics of time.Since. + Since(t time.Time) time.Duration +} + +// TickerController offers the receivers of a time.Ticker to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TickerController interface { + // Reset follows the same semantics as with time.Ticker.Reset. + Reset(d time.Duration) + // Stop follows the same semantics as with time.Ticker.Stop. + Stop() +} + +// TimerController offers the receivers of a time.Timer to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TimerController interface { + // Reset follows the same semantics as with time.Timer.Reset. + Reset(d time.Duration) bool + // Stop follows the same semantics as with time.Timer.Stop. + Stop() bool +} + +// StdClock is a simple implementation of Clock using the relevant funcs in the +// std/time package. +type StdClock struct{} + +// Now calls time.Now. +func (StdClock) Now() time.Time { + return time.Now() +} + +// NewTimer calls time.NewTimer. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + t := time.NewTimer(d) + return t, t.C +} + +// NewTicker calls time.NewTicker. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + t := time.NewTicker(d) + return t, t.C +} + +// AfterFunc calls time.AfterFunc. +func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { + return time.AfterFunc(d, f) +} + +// Since calls time.Since. +func (StdClock) Since(t time.Time) time.Duration { + return time.Since(t) +} diff --git a/tstime/tstime_test.go b/tstime/tstime_test.go index 1169408b69b29..3ffeaf0fff1b8 100644 --- a/tstime/tstime_test.go +++ b/tstime/tstime_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstime - -import ( - "testing" - "time" -) - -func TestParseDuration(t *testing.T) { - tests := []struct { - in string - want time.Duration - }{ - {"1h", time.Hour}, - {"1d", 24 * time.Hour}, - {"365d", 365 * 24 * time.Hour}, - {"12345d", 12345 * 24 * time.Hour}, - {"67890d", 67890 * 24 * time.Hour}, - {"100d", 100 * 24 * time.Hour}, - {"1d1d", 48 * time.Hour}, - {"1h1d", 25 * time.Hour}, - {"1d1h", 25 * time.Hour}, - {"1w", 7 * 24 * time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1y", 0}, - {"", 0}, - } - for _, tt := range tests { - if got, _ := ParseDuration(tt.in); got != tt.want { - t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstime + +import ( + "testing" + "time" +) + +func TestParseDuration(t *testing.T) { + tests := []struct { + in string + want time.Duration + }{ + {"1h", time.Hour}, + {"1d", 24 * time.Hour}, + {"365d", 365 * 24 * time.Hour}, + {"12345d", 12345 * 24 * time.Hour}, + {"67890d", 67890 * 24 * time.Hour}, + {"100d", 100 * 24 * time.Hour}, + {"1d1d", 48 * time.Hour}, + {"1h1d", 25 * time.Hour}, + {"1d1h", 25 * time.Hour}, + {"1w", 7 * 24 * time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1y", 0}, + {"", 0}, + } + for _, tt := range tests { + if got, _ := ParseDuration(tt.in); got != tt.want { + t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) + } + } +} diff --git a/tsweb/debug_test.go b/tsweb/debug_test.go index 504ec06ba20ab..2a68ab6fb27b9 100644 --- a/tsweb/debug_test.go +++ b/tsweb/debug_test.go @@ -1,208 +1,208 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsweb - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "runtime" - "strings" - "testing" -) - -func TestDebugger(t *testing.T) { - mux := http.NewServeMux() - - dbg1 := Debugger(mux) - if dbg1 == nil { - t.Fatal("didn't get a debugger from mux") - } - - dbg2 := Debugger(mux) - if dbg2 != dbg1 { - t.Fatal("Debugger returned different debuggers for the same mux") - } - - t.Run("cpu_pprof", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping second long test") - } - switch runtime.GOOS { - case "linux", "darwin": - default: - t.Skipf("skipping test on %v", runtime.GOOS) - } - req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) - req.RemoteAddr = "100.101.102.103:1234" - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - res := rec.Result() - if res.StatusCode != 200 { - t.Errorf("unexpected %v", res.Status) - } - }) -} - -func get(m http.Handler, path, srcIP string) (int, string) { - req := httptest.NewRequest("GET", path, nil) - req.RemoteAddr = srcIP + ":1234" - rec := httptest.NewRecorder() - m.ServeHTTP(rec, req) - return rec.Result().StatusCode, rec.Body.String() -} - -const ( - tsIP = "100.100.100.100" - pubIP = "8.8.8.8" -) - -func TestDebuggerKV(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.KV("Donuts", 42) - dbg.KV("Secret code", "hunter2") - val := "red" - dbg.KVFunc("Condition", func() any { return val }) - - code, _ := get(mux, "/debug/", pubIP) - if code != 403 { - t.Fatalf("debug access wasn't denied, got %v", code) - } - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - val = "green" - code, body = get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Condition", "green"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerURL(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.URL("https://www.tailscale.com", "Homepage") - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"https://www.tailscale.com", "Homepage"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerSection(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Section(func(w io.Writer, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - }) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - want := `Test output 100.100.100.100:1234` - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func TestDebuggerHandle(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - })) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"/debug/check", "Consistency check"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - code, _ = get(mux, "/debug/check", pubIP) - if code != 403 { - t.Fatal("/debug/check should be protected, but isn't") - } - - code, body = get(mux, "/debug/check", tsIP) - if code != 200 { - t.Fatal("/debug/check denied debug access") - } - want := "Test output " + tsIP - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func ExampleDebugHandler_Handle() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Registers /debug/flushcache with the given handler, and adds a - // link to /debug/ with the description "Flush caches". - dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) -} - -func ExampleDebugHandler_KV() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds two list items to /debug/, showing that the condition is - // red and there are 42 donuts. - dbg.KV("Condition", "red") - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_KVFunc() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds an count of page renders to /debug/. Note this example - // isn't concurrency-safe. - views := 0 - dbg.KVFunc("Debug pageviews", func() any { - views = views + 1 - return views - }) - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_URL() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Links to the Tailscale website from /debug/. - dbg.URL("https://www.tailscale.com", "Homepage") -} - -func ExampleDebugHandler_Section() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds a section to /debug/ that dumps the HTTP request of the - // visitor. - dbg.Section(func(w io.Writer, r *http.Request) { - io.WriteString(w, "

Dump of your HTTP request

") - fmt.Fprintf(w, "%#v", r) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsweb + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "testing" +) + +func TestDebugger(t *testing.T) { + mux := http.NewServeMux() + + dbg1 := Debugger(mux) + if dbg1 == nil { + t.Fatal("didn't get a debugger from mux") + } + + dbg2 := Debugger(mux) + if dbg2 != dbg1 { + t.Fatal("Debugger returned different debuggers for the same mux") + } + + t.Run("cpu_pprof", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping second long test") + } + switch runtime.GOOS { + case "linux", "darwin": + default: + t.Skipf("skipping test on %v", runtime.GOOS) + } + req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) + req.RemoteAddr = "100.101.102.103:1234" + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != 200 { + t.Errorf("unexpected %v", res.Status) + } + }) +} + +func get(m http.Handler, path, srcIP string) (int, string) { + req := httptest.NewRequest("GET", path, nil) + req.RemoteAddr = srcIP + ":1234" + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + return rec.Result().StatusCode, rec.Body.String() +} + +const ( + tsIP = "100.100.100.100" + pubIP = "8.8.8.8" +) + +func TestDebuggerKV(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.KV("Donuts", 42) + dbg.KV("Secret code", "hunter2") + val := "red" + dbg.KVFunc("Condition", func() any { return val }) + + code, _ := get(mux, "/debug/", pubIP) + if code != 403 { + t.Fatalf("debug access wasn't denied, got %v", code) + } + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + val = "green" + code, body = get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Condition", "green"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerURL(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.URL("https://www.tailscale.com", "Homepage") + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"https://www.tailscale.com", "Homepage"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerSection(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Section(func(w io.Writer, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + }) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + want := `Test output 100.100.100.100:1234` + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func TestDebuggerHandle(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + })) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"/debug/check", "Consistency check"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + code, _ = get(mux, "/debug/check", pubIP) + if code != 403 { + t.Fatal("/debug/check should be protected, but isn't") + } + + code, body = get(mux, "/debug/check", tsIP) + if code != 200 { + t.Fatal("/debug/check denied debug access") + } + want := "Test output " + tsIP + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func ExampleDebugHandler_Handle() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Registers /debug/flushcache with the given handler, and adds a + // link to /debug/ with the description "Flush caches". + dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) +} + +func ExampleDebugHandler_KV() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds two list items to /debug/, showing that the condition is + // red and there are 42 donuts. + dbg.KV("Condition", "red") + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_KVFunc() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds an count of page renders to /debug/. Note this example + // isn't concurrency-safe. + views := 0 + dbg.KVFunc("Debug pageviews", func() any { + views = views + 1 + return views + }) + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_URL() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Links to the Tailscale website from /debug/. + dbg.URL("https://www.tailscale.com", "Homepage") +} + +func ExampleDebugHandler_Section() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds a section to /debug/ that dumps the HTTP request of the + // visitor. + dbg.Section(func(w io.Writer, r *http.Request) { + io.WriteString(w, "

Dump of your HTTP request

") + fmt.Fprintf(w, "%#v", r) + }) +} diff --git a/tsweb/promvarz/promvarz_test.go b/tsweb/promvarz/promvarz_test.go index 7f9b3396ed3c9..a3f4e66f11a42 100644 --- a/tsweb/promvarz/promvarz_test.go +++ b/tsweb/promvarz/promvarz_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package promvarz - -import ( - "expvar" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/testutil" -) - -var ( - testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") - testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) -) - -func TestHandler(t *testing.T) { - testVar1.Set(42) - testVar2.Set(4242) - - svr := httptest.NewServer(http.HandlerFunc(Handler)) - defer svr.Close() - - want := ` - # TYPE promvarz_test_expvar gauge - promvarz_test_expvar 42 - # TYPE promvarz_test_native gauge - promvarz_test_native 4242 - ` - if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package promvarz + +import ( + "expvar" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +var ( + testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") + testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) +) + +func TestHandler(t *testing.T) { + testVar1.Set(42) + testVar2.Set(4242) + + svr := httptest.NewServer(http.HandlerFunc(Handler)) + defer svr.Close() + + want := ` + # TYPE promvarz_test_expvar gauge + promvarz_test_expvar 42 + # TYPE promvarz_test_native gauge + promvarz_test_native 4242 + ` + if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { + t.Error(err) + } +} diff --git a/types/appctype/appconnector_test.go b/types/appctype/appconnector_test.go index 8aef135b4a876..390d1776a3280 100644 --- a/types/appctype/appconnector_test.go +++ b/types/appctype/appconnector_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package appctype - -import ( - "encoding/json" - "net/netip" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tailcfg" - "tailscale.com/util/must" -) - -var golden = `{ - "dnat": { - "opaqueid1": { - "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], - "to": ["example.org"], - "ip": ["*"] - } - }, - "sniProxy": { - "opaqueid2": { - "addrs": ["::"], - "ip": ["tcp:443"], - "allowedDomains": ["*"] - } - }, - "advertiseRoutes": true -}` - -func TestGolden(t *testing.T) { - wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }} - - wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { - Addrs: []netip.Addr{netip.MustParseAddr("::")}, - IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, - AllowedDomains: []string{"*"}, - }} - - var config AppConnectorConfig - if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { - t.Fatalf("failed to decode golden config: %v", err) - } - - if !config.AdvertiseRoutes { - t.Fatalf("expected AdvertiseRoutes to be true, got false") - } - - assertEqual(t, "DNAT", config.DNAT, wantDNAT) - assertEqual(t, "SNI", config.SNIProxy, wantSNI) -} - -func TestRoundTrip(t *testing.T) { - var config AppConnectorConfig - must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) - b := must.Get(json.Marshal(config)) - var config2 AppConnectorConfig - must.Do(json.Unmarshal(b, &config2)) - assertEqual(t, "DNAT", config.DNAT, config2.DNAT) -} - -func assertEqual(t *testing.T, name string, a, b any) { - var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { - return a.Compare(b) == 0 - }) - t.Helper() - if diff := cmp.Diff(a, b, addrComparer); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appctype + +import ( + "encoding/json" + "net/netip" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +var golden = `{ + "dnat": { + "opaqueid1": { + "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], + "to": ["example.org"], + "ip": ["*"] + } + }, + "sniProxy": { + "opaqueid2": { + "addrs": ["::"], + "ip": ["tcp:443"], + "allowedDomains": ["*"] + } + }, + "advertiseRoutes": true +}` + +func TestGolden(t *testing.T) { + wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }} + + wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { + Addrs: []netip.Addr{netip.MustParseAddr("::")}, + IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, + AllowedDomains: []string{"*"}, + }} + + var config AppConnectorConfig + if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { + t.Fatalf("failed to decode golden config: %v", err) + } + + if !config.AdvertiseRoutes { + t.Fatalf("expected AdvertiseRoutes to be true, got false") + } + + assertEqual(t, "DNAT", config.DNAT, wantDNAT) + assertEqual(t, "SNI", config.SNIProxy, wantSNI) +} + +func TestRoundTrip(t *testing.T) { + var config AppConnectorConfig + must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) + b := must.Get(json.Marshal(config)) + var config2 AppConnectorConfig + must.Do(json.Unmarshal(b, &config2)) + assertEqual(t, "DNAT", config.DNAT, config2.DNAT) +} + +func assertEqual(t *testing.T, name string, a, b any) { + var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { + return a.Compare(b) == 0 + }) + t.Helper() + if diff := cmp.Diff(a, b, addrComparer); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } +} diff --git a/types/dnstype/dnstype.go b/types/dnstype/dnstype.go index 6cc91c999e8d4..b7f5b9d02fe47 100644 --- a/types/dnstype/dnstype.go +++ b/types/dnstype/dnstype.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dnstype defines types for working with DNS. -package dnstype - -//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true - -import ( - "net/netip" - "slices" -) - -// Resolver is the configuration for one DNS resolver. -type Resolver struct { - // Addr is the address of the DNS resolver, one of: - // - A plain IP address for a "classic" UDP+TCP DNS resolver. - // This is the common format as sent by the control plane. - // - An IP:port, for tests. - // - "https://resolver.com/path" for DNS over HTTPS; currently - // as of 2022-09-08 only used for certain well-known resolvers - // (see the publicdns package) for which the IP addresses to dial DoH are - // known ahead of time, so bootstrap DNS resolution is not required. - // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This - // is implemented in the PeerAPI for exit nodes and app connectors. - // - [TODO] "tls://resolver.com" for DNS over TCP+TLS - Addr string `json:",omitempty"` - - // BootstrapResolution is an optional suggested resolution for the - // DoT/DoH resolver, if the resolver URL does not reference an IP - // address directly. - // BootstrapResolution may be empty, in which case clients should - // look up the DoT/DoH server using their local "classic" DNS - // resolver. - // - // As of 2022-09-08, BootstrapResolution is not yet used. - BootstrapResolution []netip.Addr `json:",omitempty"` -} - -// IPPort returns r.Addr as an IP address and port if either -// r.Addr is an IP address (the common case) or if r.Addr -// is an IP:port (as done in tests). -func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { - if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { - // Fast path to avoid ParseIP error allocation for obviously not IP - // cases. - return - } - if ip, err := netip.ParseAddr(r.Addr); err == nil { - return netip.AddrPortFrom(ip, 53), true - } - if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { - return ipp, true - } - return -} - -// Equal reports whether r and other are equal. -func (r *Resolver) Equal(other *Resolver) bool { - if r == nil || other == nil { - return r == other - } - if r == other { - return true - } - - return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dnstype defines types for working with DNS. +package dnstype + +//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true + +import ( + "net/netip" + "slices" +) + +// Resolver is the configuration for one DNS resolver. +type Resolver struct { + // Addr is the address of the DNS resolver, one of: + // - A plain IP address for a "classic" UDP+TCP DNS resolver. + // This is the common format as sent by the control plane. + // - An IP:port, for tests. + // - "https://resolver.com/path" for DNS over HTTPS; currently + // as of 2022-09-08 only used for certain well-known resolvers + // (see the publicdns package) for which the IP addresses to dial DoH are + // known ahead of time, so bootstrap DNS resolution is not required. + // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This + // is implemented in the PeerAPI for exit nodes and app connectors. + // - [TODO] "tls://resolver.com" for DNS over TCP+TLS + Addr string `json:",omitempty"` + + // BootstrapResolution is an optional suggested resolution for the + // DoT/DoH resolver, if the resolver URL does not reference an IP + // address directly. + // BootstrapResolution may be empty, in which case clients should + // look up the DoT/DoH server using their local "classic" DNS + // resolver. + // + // As of 2022-09-08, BootstrapResolution is not yet used. + BootstrapResolution []netip.Addr `json:",omitempty"` +} + +// IPPort returns r.Addr as an IP address and port if either +// r.Addr is an IP address (the common case) or if r.Addr +// is an IP:port (as done in tests). +func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { + if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { + // Fast path to avoid ParseIP error allocation for obviously not IP + // cases. + return + } + if ip, err := netip.ParseAddr(r.Addr); err == nil { + return netip.AddrPortFrom(ip, 53), true + } + if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { + return ipp, true + } + return +} + +// Equal reports whether r and other are equal. +func (r *Resolver) Equal(other *Resolver) bool { + if r == nil || other == nil { + return r == other + } + if r == other { + return true + } + + return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) +} diff --git a/types/empty/message.go b/types/empty/message.go index 5ada7f40202af..dc8eb4cc2dc37 100644 --- a/types/empty/message.go +++ b/types/empty/message.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package empty defines an empty struct type. -package empty - -// Message is an empty message. Its purpose is to be used as pointer -// type where nil and non-nil distinguish whether it's set. This is -// used instead of a bool when we want to marshal it as a JSON empty -// object (or null) for the future ability to add other fields, at -// which point callers would define a new struct and not use -// empty.Message. -type Message struct{} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package empty defines an empty struct type. +package empty + +// Message is an empty message. Its purpose is to be used as pointer +// type where nil and non-nil distinguish whether it's set. This is +// used instead of a bool when we want to marshal it as a JSON empty +// object (or null) for the future ability to add other fields, at +// which point callers would define a new struct and not use +// empty.Message. +type Message struct{} diff --git a/types/flagtype/flagtype.go b/types/flagtype/flagtype.go index c76b16353a280..be160dee82a21 100644 --- a/types/flagtype/flagtype.go +++ b/types/flagtype/flagtype.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flagtype defines flag.Value types. -package flagtype - -import ( - "errors" - "flag" - "fmt" - "math" - "strconv" - "strings" -) - -type portValue struct{ n *uint16 } - -func PortValue(dst *uint16, defaultPort uint16) flag.Value { - *dst = defaultPort - return portValue{dst} -} - -func (p portValue) String() string { - if p.n == nil { - return "" - } - return fmt.Sprint(*p.n) -} -func (p portValue) Set(v string) error { - if v == "" { - return errors.New("can't be the empty string") - } - if strings.Contains(v, ":") { - return errors.New("expecting just a port number, without a colon") - } - n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message - if err != nil { - return fmt.Errorf("not a valid number") - } - if n > math.MaxUint16 { - return errors.New("out of range for port number") - } - *p.n = uint16(n) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flagtype defines flag.Value types. +package flagtype + +import ( + "errors" + "flag" + "fmt" + "math" + "strconv" + "strings" +) + +type portValue struct{ n *uint16 } + +func PortValue(dst *uint16, defaultPort uint16) flag.Value { + *dst = defaultPort + return portValue{dst} +} + +func (p portValue) String() string { + if p.n == nil { + return "" + } + return fmt.Sprint(*p.n) +} +func (p portValue) Set(v string) error { + if v == "" { + return errors.New("can't be the empty string") + } + if strings.Contains(v, ":") { + return errors.New("expecting just a port number, without a colon") + } + n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message + if err != nil { + return fmt.Errorf("not a valid number") + } + if n > math.MaxUint16 { + return errors.New("out of range for port number") + } + *p.n = uint16(n) + return nil +} diff --git a/types/ipproto/ipproto.go b/types/ipproto/ipproto.go index 97fc4f3dd89e8..b5333eb56ace0 100644 --- a/types/ipproto/ipproto.go +++ b/types/ipproto/ipproto.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ipproto contains IP Protocol constants. -package ipproto - -import ( - "fmt" - "strconv" - - "tailscale.com/util/nocasemaps" - "tailscale.com/util/vizerror" -) - -// Version describes the IP address version. -type Version uint8 - -// Valid Version values. -const ( - Version4 = 4 - Version6 = 6 -) - -func (p Version) String() string { - switch p { - case Version4: - return "IPv4" - case Version6: - return "IPv6" - default: - return fmt.Sprintf("Version-%d", int(p)) - } -} - -// Proto is an IP subprotocol as defined by the IANA protocol -// numbers list -// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), -// or the special values Unknown or Fragment. -type Proto uint8 - -const ( - // Unknown represents an unknown or unsupported protocol; it's - // deliberately the zero value. Strictly speaking the zero - // value is IPv6 hop-by-hop extensions, but we don't support - // those, so this is still technically correct. - Unknown Proto = 0x00 - - // Values from the IANA registry. - ICMPv4 Proto = 0x01 - IGMP Proto = 0x02 - ICMPv6 Proto = 0x3a - TCP Proto = 0x06 - UDP Proto = 0x11 - DCCP Proto = 0x21 - GRE Proto = 0x2f - SCTP Proto = 0x84 - - // TSMP is the Tailscale Message Protocol (our ICMP-ish - // thing), an IP protocol used only between Tailscale nodes - // (still encrypted by WireGuard) that communicates why things - // failed, etc. - // - // Proto number 99 is reserved for "any private encryption - // scheme". We never accept these from the host OS stack nor - // send them to the host network stack. It's only used between - // nodes. - TSMP Proto = 99 - - // Fragment represents any non-first IP fragment, for which we - // don't have the sub-protocol header (and therefore can't - // figure out what the sub-protocol is). - // - // 0xFF is reserved in the IANA registry, so we steal it for - // internal use. - Fragment Proto = 0xFF -) - -// Deprecated: use MarshalText instead. -func (p Proto) String() string { - switch p { - case Unknown: - return "Unknown" - case Fragment: - return "Frag" - case ICMPv4: - return "ICMPv4" - case IGMP: - return "IGMP" - case ICMPv6: - return "ICMPv6" - case UDP: - return "UDP" - case TCP: - return "TCP" - case SCTP: - return "SCTP" - case TSMP: - return "TSMP" - case GRE: - return "GRE" - case DCCP: - return "DCCP" - default: - return fmt.Sprintf("IPProto-%d", int(p)) - } -} - -// Prefer names from -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// unless otherwise noted. -var ( - // preferredNames is the set of protocol names that re produced by - // MarshalText, and are the preferred representation. - preferredNames = map[Proto]string{ - 51: "ah", - DCCP: "dccp", - 8: "egp", - 50: "esp", - 47: "gre", - ICMPv4: "icmp", - IGMP: "igmp", - 9: "igp", - 4: "ipv4", - ICMPv6: "ipv6-icmp", - SCTP: "sctp", - TCP: "tcp", - UDP: "udp", - } - - // acceptedNames is the set of protocol names that are accepted by - // UnmarshalText. - acceptedNames = map[string]Proto{ - "ah": 51, - "dccp": DCCP, - "egp": 8, - "esp": 50, - "gre": 47, - "icmp": ICMPv4, - "icmpv4": ICMPv4, - "icmpv6": ICMPv6, - "igmp": IGMP, - "igp": 9, - "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" - "ipv4": 4, - "ipv6-icmp": ICMPv6, - "sctp": SCTP, - "tcp": TCP, - "tsmp": TSMP, - "udp": UDP, - } -) - -// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p -// is set to 0. If an error occurs, p is unchanged. -func (p *Proto) UnmarshalText(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - - if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { - *p = Proto(u) - return nil - } - - if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { - *p = newP - return nil - } - - return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) -} - -// MarshalText implements encoding.TextMarshaler. -func (p Proto) MarshalText() ([]byte, error) { - if s, ok := preferredNames[p]; ok { - return []byte(s), nil - } - return []byte(strconv.Itoa(int(p))), nil -} - -// MarshalJSON implements json.Marshaler. -func (p Proto) MarshalJSON() ([]byte, error) { - return []byte(strconv.Itoa(int(p))), nil -} - -// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to -// 0. If an error occurs, p is unchanged. The input must be a JSON number or an -// accepted string name. -func (p *Proto) UnmarshalJSON(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - if b[0] == '"' { - b = b[1 : len(b)-1] - } - return p.UnmarshalText(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipproto contains IP Protocol constants. +package ipproto + +import ( + "fmt" + "strconv" + + "tailscale.com/util/nocasemaps" + "tailscale.com/util/vizerror" +) + +// Version describes the IP address version. +type Version uint8 + +// Valid Version values. +const ( + Version4 = 4 + Version6 = 6 +) + +func (p Version) String() string { + switch p { + case Version4: + return "IPv4" + case Version6: + return "IPv6" + default: + return fmt.Sprintf("Version-%d", int(p)) + } +} + +// Proto is an IP subprotocol as defined by the IANA protocol +// numbers list +// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), +// or the special values Unknown or Fragment. +type Proto uint8 + +const ( + // Unknown represents an unknown or unsupported protocol; it's + // deliberately the zero value. Strictly speaking the zero + // value is IPv6 hop-by-hop extensions, but we don't support + // those, so this is still technically correct. + Unknown Proto = 0x00 + + // Values from the IANA registry. + ICMPv4 Proto = 0x01 + IGMP Proto = 0x02 + ICMPv6 Proto = 0x3a + TCP Proto = 0x06 + UDP Proto = 0x11 + DCCP Proto = 0x21 + GRE Proto = 0x2f + SCTP Proto = 0x84 + + // TSMP is the Tailscale Message Protocol (our ICMP-ish + // thing), an IP protocol used only between Tailscale nodes + // (still encrypted by WireGuard) that communicates why things + // failed, etc. + // + // Proto number 99 is reserved for "any private encryption + // scheme". We never accept these from the host OS stack nor + // send them to the host network stack. It's only used between + // nodes. + TSMP Proto = 99 + + // Fragment represents any non-first IP fragment, for which we + // don't have the sub-protocol header (and therefore can't + // figure out what the sub-protocol is). + // + // 0xFF is reserved in the IANA registry, so we steal it for + // internal use. + Fragment Proto = 0xFF +) + +// Deprecated: use MarshalText instead. +func (p Proto) String() string { + switch p { + case Unknown: + return "Unknown" + case Fragment: + return "Frag" + case ICMPv4: + return "ICMPv4" + case IGMP: + return "IGMP" + case ICMPv6: + return "ICMPv6" + case UDP: + return "UDP" + case TCP: + return "TCP" + case SCTP: + return "SCTP" + case TSMP: + return "TSMP" + case GRE: + return "GRE" + case DCCP: + return "DCCP" + default: + return fmt.Sprintf("IPProto-%d", int(p)) + } +} + +// Prefer names from +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// unless otherwise noted. +var ( + // preferredNames is the set of protocol names that re produced by + // MarshalText, and are the preferred representation. + preferredNames = map[Proto]string{ + 51: "ah", + DCCP: "dccp", + 8: "egp", + 50: "esp", + 47: "gre", + ICMPv4: "icmp", + IGMP: "igmp", + 9: "igp", + 4: "ipv4", + ICMPv6: "ipv6-icmp", + SCTP: "sctp", + TCP: "tcp", + UDP: "udp", + } + + // acceptedNames is the set of protocol names that are accepted by + // UnmarshalText. + acceptedNames = map[string]Proto{ + "ah": 51, + "dccp": DCCP, + "egp": 8, + "esp": 50, + "gre": 47, + "icmp": ICMPv4, + "icmpv4": ICMPv4, + "icmpv6": ICMPv6, + "igmp": IGMP, + "igp": 9, + "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" + "ipv4": 4, + "ipv6-icmp": ICMPv6, + "sctp": SCTP, + "tcp": TCP, + "tsmp": TSMP, + "udp": UDP, + } +) + +// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p +// is set to 0. If an error occurs, p is unchanged. +func (p *Proto) UnmarshalText(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + + if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { + *p = Proto(u) + return nil + } + + if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { + *p = newP + return nil + } + + return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) +} + +// MarshalText implements encoding.TextMarshaler. +func (p Proto) MarshalText() ([]byte, error) { + if s, ok := preferredNames[p]; ok { + return []byte(s), nil + } + return []byte(strconv.Itoa(int(p))), nil +} + +// MarshalJSON implements json.Marshaler. +func (p Proto) MarshalJSON() ([]byte, error) { + return []byte(strconv.Itoa(int(p))), nil +} + +// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to +// 0. If an error occurs, p is unchanged. The input must be a JSON number or an +// accepted string name. +func (p *Proto) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + if b[0] == '"' { + b = b[1 : len(b)-1] + } + return p.UnmarshalText(b) +} diff --git a/types/key/chal.go b/types/key/chal.go index da15dd1f8a01d..742ac5479e4a1 100644 --- a/types/key/chal.go +++ b/types/key/chal.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "errors" - - "go4.org/mem" - "tailscale.com/types/structs" -) - -const ( - // chalPublicHexPrefix is the prefix used to identify a - // hex-encoded challenge public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - chalPublicHexPrefix = "chalpub:" -) - -// ChallengePrivate is a challenge key, used to test whether clients control a -// key they want to prove ownership of. -// -// A ChallengePrivate is ephemeral and not serialized to the disk or network. -type ChallengePrivate struct { - _ structs.Incomparable // because == isn't constant-time - k [32]byte -} - -// NewChallenge creates and returns a new node private key. -func NewChallenge() ChallengePrivate { - return ChallengePrivate(NewNode()) -} - -// Public returns the ChallengePublic for k. -// Panics if ChallengePublic is zero. -func (k ChallengePrivate) Public() ChallengePublic { - pub := NodePrivate(k).Public() - return ChallengePublic(pub) -} - -// MarshalText implements encoding.TextMarshaler, but by returning an error. -// It shouldn't need to be marshalled anywhere. -func (k ChallengePrivate) MarshalText() ([]byte, error) { - return nil, errors.New("refusing to marshal") -} - -// SealToChallenge is like SealTo, but for a ChallengePublic. -func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { - return k.SealTo(NodePublic(p), cleartext) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by NodePrivate.SealToChallenge, and returns the inner cleartext if -// ciphertext is a valid box from p to k. -func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return NodePrivate(k).OpenFrom(p, ciphertext) -} - -// ChallengePublic is the public portion of a ChallengePrivate. -type ChallengePublic struct { - k [32]byte -} - -// String returns the output of MarshalText as a string. -func (k ChallengePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k ChallengePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (k *ChallengePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) -} - -// IsZero reports whether k is the zero value. -func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "errors" + + "go4.org/mem" + "tailscale.com/types/structs" +) + +const ( + // chalPublicHexPrefix is the prefix used to identify a + // hex-encoded challenge public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + chalPublicHexPrefix = "chalpub:" +) + +// ChallengePrivate is a challenge key, used to test whether clients control a +// key they want to prove ownership of. +// +// A ChallengePrivate is ephemeral and not serialized to the disk or network. +type ChallengePrivate struct { + _ structs.Incomparable // because == isn't constant-time + k [32]byte +} + +// NewChallenge creates and returns a new node private key. +func NewChallenge() ChallengePrivate { + return ChallengePrivate(NewNode()) +} + +// Public returns the ChallengePublic for k. +// Panics if ChallengePublic is zero. +func (k ChallengePrivate) Public() ChallengePublic { + pub := NodePrivate(k).Public() + return ChallengePublic(pub) +} + +// MarshalText implements encoding.TextMarshaler, but by returning an error. +// It shouldn't need to be marshalled anywhere. +func (k ChallengePrivate) MarshalText() ([]byte, error) { + return nil, errors.New("refusing to marshal") +} + +// SealToChallenge is like SealTo, but for a ChallengePublic. +func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { + return k.SealTo(NodePublic(p), cleartext) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by NodePrivate.SealToChallenge, and returns the inner cleartext if +// ciphertext is a valid box from p to k. +func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return NodePrivate(k).OpenFrom(p, ciphertext) +} + +// ChallengePublic is the public portion of a ChallengePrivate. +type ChallengePublic struct { + k [32]byte +} + +// String returns the output of MarshalText as a string. +func (k ChallengePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k ChallengePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (k *ChallengePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) +} + +// IsZero reports whether k is the zero value. +func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } diff --git a/types/key/control.go b/types/key/control.go index a84359771bcab..96021249ba047 100644 --- a/types/key/control.go +++ b/types/key/control.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import "encoding/json" - -// ControlPrivate is a Tailscale control plane private key. -// -// It is functionally equivalent to a MachinePrivate, but serializes -// to JSON as a byte array rather than a typed string, because our -// control plane database stores the key that way. -// -// Deprecated: this type should only be used in Tailscale's control -// plane, where existing database serializations require this -// less-good serialization format to persist. Other control plane -// implementations can use MachinePrivate with no downsides. -type ControlPrivate struct { - mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need -} - -// NewControl generates and returns a new control plane private key. -func NewControl() ControlPrivate { - return ControlPrivate{NewMachine()} -} - -// IsZero reports whether k is the zero value. -func (k ControlPrivate) IsZero() bool { - return k.mkey.IsZero() -} - -// Public returns the MachinePublic for k. -// Panics if ControlPrivate is zero. -func (k ControlPrivate) Public() MachinePublic { - return k.mkey.Public() -} - -// MarshalJSON implements json.Marshaler. -func (k ControlPrivate) MarshalJSON() ([]byte, error) { - return json.Marshal(k.mkey.k) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { - return json.Unmarshal(bs, &k.mkey.k) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - return k.mkey.SealTo(p, cleartext) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - return k.mkey.SharedKey(p) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return k.mkey.OpenFrom(p, ciphertext) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import "encoding/json" + +// ControlPrivate is a Tailscale control plane private key. +// +// It is functionally equivalent to a MachinePrivate, but serializes +// to JSON as a byte array rather than a typed string, because our +// control plane database stores the key that way. +// +// Deprecated: this type should only be used in Tailscale's control +// plane, where existing database serializations require this +// less-good serialization format to persist. Other control plane +// implementations can use MachinePrivate with no downsides. +type ControlPrivate struct { + mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need +} + +// NewControl generates and returns a new control plane private key. +func NewControl() ControlPrivate { + return ControlPrivate{NewMachine()} +} + +// IsZero reports whether k is the zero value. +func (k ControlPrivate) IsZero() bool { + return k.mkey.IsZero() +} + +// Public returns the MachinePublic for k. +// Panics if ControlPrivate is zero. +func (k ControlPrivate) Public() MachinePublic { + return k.mkey.Public() +} + +// MarshalJSON implements json.Marshaler. +func (k ControlPrivate) MarshalJSON() ([]byte, error) { + return json.Marshal(k.mkey.k) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { + return json.Unmarshal(bs, &k.mkey.k) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + return k.mkey.SealTo(p, cleartext) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + return k.mkey.SharedKey(p) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return k.mkey.OpenFrom(p, ciphertext) +} diff --git a/types/key/control_test.go b/types/key/control_test.go index 06e0f36d50bcf..a98a586f3ba5a 100644 --- a/types/key/control_test.go +++ b/types/key/control_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "encoding/json" - "testing" -) - -func TestControlKey(t *testing.T) { - serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` - want := ControlPrivate{ - MachinePrivate{ - k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, - }, - } - - var got struct { - PrivateKey ControlPrivate - } - if err := json.Unmarshal([]byte(serialized), &got); err != nil { - t.Fatalf("decoding serialized ControlPrivate: %v", err) - } - - if !got.PrivateKey.mkey.Equal(want.mkey) { - t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) - } - - bs, err := json.Marshal(got) - if err != nil { - t.Fatalf("json reserialization of ControlPrivate failed: %v", err) - } - - if got, want := string(bs), serialized; got != want { - t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "encoding/json" + "testing" +) + +func TestControlKey(t *testing.T) { + serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` + want := ControlPrivate{ + MachinePrivate{ + k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, + }, + } + + var got struct { + PrivateKey ControlPrivate + } + if err := json.Unmarshal([]byte(serialized), &got); err != nil { + t.Fatalf("decoding serialized ControlPrivate: %v", err) + } + + if !got.PrivateKey.mkey.Equal(want.mkey) { + t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) + } + + bs, err := json.Marshal(got) + if err != nil { + t.Fatalf("json reserialization of ControlPrivate failed: %v", err) + } + + if got, want := string(bs), serialized; got != want { + t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) + } +} diff --git a/types/key/disco_test.go b/types/key/disco_test.go index c9d60c82874f8..c62c13cbf8970 100644 --- a/types/key/disco_test.go +++ b/types/key/disco_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "testing" -) - -func TestDiscoKey(t *testing.T) { - k := NewDisco() - if k.IsZero() { - t.Fatal("DiscoPrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("DiscoPublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if !bytes.HasPrefix(bs, []byte("discokey:")) { - t.Fatalf("serialization of public discokey %s has wrong prefix", p) - } - - z := DiscoPublic{} - if !z.IsZero() { - t.Fatal("IsZero(DiscoPublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestDiscoSerialization(t *testing.T) { - serialized := `{ - "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - pub := DiscoPublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type key struct { - Pub DiscoPublic - } - - var a key - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestDiscoShared(t *testing.T) { - k1, k2 := NewDisco(), NewDisco() - s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) - if !s1.Equal(s2) { - t.Error("k1.Shared(k2) != k2.Shared(k1)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestDiscoKey(t *testing.T) { + k := NewDisco() + if k.IsZero() { + t.Fatal("DiscoPrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("DiscoPublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if !bytes.HasPrefix(bs, []byte("discokey:")) { + t.Fatalf("serialization of public discokey %s has wrong prefix", p) + } + + z := DiscoPublic{} + if !z.IsZero() { + t.Fatal("IsZero(DiscoPublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestDiscoSerialization(t *testing.T) { + serialized := `{ + "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + pub := DiscoPublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type key struct { + Pub DiscoPublic + } + + var a key + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestDiscoShared(t *testing.T) { + k1, k2 := NewDisco(), NewDisco() + s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) + if !s1.Equal(s2) { + t.Error("k1.Shared(k2) != k2.Shared(k1)") + } +} diff --git a/types/key/machine.go b/types/key/machine.go index 0dc02574c510d..a05f3cc1f5735 100644 --- a/types/key/machine.go +++ b/types/key/machine.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "crypto/subtle" - "encoding/hex" - - "go4.org/mem" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/nacl/box" - "tailscale.com/types/structs" -) - -const ( - // machinePrivateHexPrefix is the prefix used to identify a - // hex-encoded machine private key. - // - // This prefix name is a little unfortunate, in that it comes from - // WireGuard's own key types. Unfortunately we're stuck with it for - // machine keys, because we serialize them to disk with this prefix. - machinePrivateHexPrefix = "privkey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" -) - -// MachinePrivate is a machine key, used for communication with the -// Tailscale coordination server. -type MachinePrivate struct { - _ structs.Incomparable // == isn't constant-time - k [32]byte -} - -// NewMachine creates and returns a new machine private key. -func NewMachine() MachinePrivate { - var ret MachinePrivate - rand(ret.k[:]) - clamp25519Private(ret.k[:]) - return ret -} - -// IsZero reports whether k is the zero value. -func (k MachinePrivate) IsZero() bool { - return k.Equal(MachinePrivate{}) -} - -// Equal reports whether k and other are the same key. -func (k MachinePrivate) Equal(other MachinePrivate) bool { - return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 -} - -// Public returns the MachinePublic for k. -// Panics if MachinePrivate is zero. -func (k MachinePrivate) Public() MachinePublic { - if k.IsZero() { - panic("can't take the public key of a zero MachinePrivate") - } - var ret MachinePublic - curve25519.ScalarBaseMult(&ret.k, &k.k) - return ret -} - -// AppendText implements encoding.TextAppender. -func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePrivate) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePrivate) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePrivate, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePrivate) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - if k.IsZero() || p.IsZero() { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - var shared MachinePrecomputedSharedKey - box.Precompute(&shared.k, &p.k, &k.k) - return shared -} - -// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. -type MachinePrecomputedSharedKey struct { - k [32]byte -} - -// Seal wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) using the shared key k as generated -// by MachinePrivate.SharedKey. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) -} - -// Open opens the NaCl box ciphertext, which must be a value created by -// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the -// inner cleartext if ciphertext is a valid box for the shared key k. -func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - if k.IsZero() || p.IsZero() { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) -} - -// MachinePublic is the public portion of a a MachinePrivate. -type MachinePublic struct { - k [32]byte -} - -// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. -// -// This should be used only when deserializing a MachinePublic from a -// binary protocol. -func MachinePublicFromRaw32(raw mem.RO) MachinePublic { - if raw.Len() != 32 { - panic("input has wrong size") - } - var ret MachinePublic - raw.Copy(ret.k[:]) - return ret -} - -// ParseMachinePublicUntyped parses an untyped 64-character hex value -// as a MachinePublic. -// -// Deprecated: this function is risky to use, because it cannot verify -// that the hex string was intended to be a MachinePublic. This can -// lead to accidentally decoding one type of key as another. For new -// uses that don't require backwards compatibility with the untyped -// string format, please use MarshalText/UnmarshalText. -func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { - var ret MachinePublic - if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { - return MachinePublic{}, err - } - return ret, nil -} - -// IsZero reports whether k is the zero value. -func (k MachinePublic) IsZero() bool { - return k == MachinePublic{} -} - -// ShortString returns the Tailscale conventional debug representation -// of a public key: the first five base64 digits of the key, in square -// brackets. -func (k MachinePublic) ShortString() string { - return debug32(k.k) -} - -// UntypedHexString returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require backwards -// compatibility with the untyped string format, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedHexString() string { - return hex.EncodeToString(k.k[:]) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// String returns the output of MarshalText as a string. -func (k MachinePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k MachinePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "crypto/subtle" + "encoding/hex" + + "go4.org/mem" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" + "tailscale.com/types/structs" +) + +const ( + // machinePrivateHexPrefix is the prefix used to identify a + // hex-encoded machine private key. + // + // This prefix name is a little unfortunate, in that it comes from + // WireGuard's own key types. Unfortunately we're stuck with it for + // machine keys, because we serialize them to disk with this prefix. + machinePrivateHexPrefix = "privkey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" +) + +// MachinePrivate is a machine key, used for communication with the +// Tailscale coordination server. +type MachinePrivate struct { + _ structs.Incomparable // == isn't constant-time + k [32]byte +} + +// NewMachine creates and returns a new machine private key. +func NewMachine() MachinePrivate { + var ret MachinePrivate + rand(ret.k[:]) + clamp25519Private(ret.k[:]) + return ret +} + +// IsZero reports whether k is the zero value. +func (k MachinePrivate) IsZero() bool { + return k.Equal(MachinePrivate{}) +} + +// Equal reports whether k and other are the same key. +func (k MachinePrivate) Equal(other MachinePrivate) bool { + return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 +} + +// Public returns the MachinePublic for k. +// Panics if MachinePrivate is zero. +func (k MachinePrivate) Public() MachinePublic { + if k.IsZero() { + panic("can't take the public key of a zero MachinePrivate") + } + var ret MachinePublic + curve25519.ScalarBaseMult(&ret.k, &k.k) + return ret +} + +// AppendText implements encoding.TextAppender. +func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePrivate) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePrivate) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePrivate, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePrivate) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + if k.IsZero() || p.IsZero() { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + var shared MachinePrecomputedSharedKey + box.Precompute(&shared.k, &p.k, &k.k) + return shared +} + +// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. +type MachinePrecomputedSharedKey struct { + k [32]byte +} + +// Seal wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) using the shared key k as generated +// by MachinePrivate.SharedKey. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) +} + +// Open opens the NaCl box ciphertext, which must be a value created by +// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the +// inner cleartext if ciphertext is a valid box for the shared key k. +func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + if k.IsZero() || p.IsZero() { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) +} + +// MachinePublic is the public portion of a a MachinePrivate. +type MachinePublic struct { + k [32]byte +} + +// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. +// +// This should be used only when deserializing a MachinePublic from a +// binary protocol. +func MachinePublicFromRaw32(raw mem.RO) MachinePublic { + if raw.Len() != 32 { + panic("input has wrong size") + } + var ret MachinePublic + raw.Copy(ret.k[:]) + return ret +} + +// ParseMachinePublicUntyped parses an untyped 64-character hex value +// as a MachinePublic. +// +// Deprecated: this function is risky to use, because it cannot verify +// that the hex string was intended to be a MachinePublic. This can +// lead to accidentally decoding one type of key as another. For new +// uses that don't require backwards compatibility with the untyped +// string format, please use MarshalText/UnmarshalText. +func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { + var ret MachinePublic + if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { + return MachinePublic{}, err + } + return ret, nil +} + +// IsZero reports whether k is the zero value. +func (k MachinePublic) IsZero() bool { + return k == MachinePublic{} +} + +// ShortString returns the Tailscale conventional debug representation +// of a public key: the first five base64 digits of the key, in square +// brackets. +func (k MachinePublic) ShortString() string { + return debug32(k.k) +} + +// UntypedHexString returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require backwards +// compatibility with the untyped string format, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedHexString() string { + return hex.EncodeToString(k.k[:]) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// String returns the output of MarshalText as a string. +func (k MachinePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k MachinePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) +} diff --git a/types/key/machine_test.go b/types/key/machine_test.go index f797ff087f090..157df9e4356b1 100644 --- a/types/key/machine_test.go +++ b/types/key/machine_test.go @@ -1,119 +1,119 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "strings" - "testing" -) - -func TestMachineKey(t *testing.T) { - k := NewMachine() - if k.IsZero() { - t.Fatal("MachinePrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("MachinePublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { - t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) - } - - z := MachinePublic{} - if !z.IsZero() { - t.Fatal("IsZero(MachinePublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestMachineSerialization(t *testing.T) { - serialized := `{ - "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", - "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - // Carefully check that the expected serialized data decodes and - // reencodes to the expected keys. These types are serialized to - // disk all over the place and need to be stable. - priv := MachinePrivate{ - k: [32]uint8{ - 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, - 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, - 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, - }, - } - pub := MachinePublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type keypair struct { - Priv MachinePrivate - Pub MachinePublic - } - - var a keypair - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if !a.Priv.Equal(priv) { - t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestSealViaSharedKey(t *testing.T) { - // encrypt a message from a to b - a := NewMachine() - b := NewMachine() - apub, bpub := a.Public(), b.Public() - - shared := a.SharedKey(bpub) - - const clear = "the eagle flies at midnight" - enc := shared.Seal([]byte(clear)) - - back, ok := b.OpenFrom(apub, enc) - if !ok { - t.Fatal("failed to decrypt") - } - if string(back) != clear { - t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) - } - - backShared, ok := shared.Open(enc) - if !ok { - t.Fatal("failed to decrypt from shared key") - } - if string(backShared) != clear { - t.Errorf("Open got %q; want cleartext %q", back, clear) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestMachineKey(t *testing.T) { + k := NewMachine() + if k.IsZero() { + t.Fatal("MachinePrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("MachinePublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { + t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) + } + + z := MachinePublic{} + if !z.IsZero() { + t.Fatal("IsZero(MachinePublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestMachineSerialization(t *testing.T) { + serialized := `{ + "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", + "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + // Carefully check that the expected serialized data decodes and + // reencodes to the expected keys. These types are serialized to + // disk all over the place and need to be stable. + priv := MachinePrivate{ + k: [32]uint8{ + 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, + 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, + 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, + }, + } + pub := MachinePublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type keypair struct { + Priv MachinePrivate + Pub MachinePublic + } + + var a keypair + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if !a.Priv.Equal(priv) { + t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestSealViaSharedKey(t *testing.T) { + // encrypt a message from a to b + a := NewMachine() + b := NewMachine() + apub, bpub := a.Public(), b.Public() + + shared := a.SharedKey(bpub) + + const clear = "the eagle flies at midnight" + enc := shared.Seal([]byte(clear)) + + back, ok := b.OpenFrom(apub, enc) + if !ok { + t.Fatal("failed to decrypt") + } + if string(back) != clear { + t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) + } + + backShared, ok := shared.Open(enc) + if !ok { + t.Fatal("failed to decrypt from shared key") + } + if string(backShared) != clear { + t.Errorf("Open got %q; want cleartext %q", back, clear) + } +} diff --git a/types/key/nl_test.go b/types/key/nl_test.go index 2e10d04acc58b..75b7765a19ea1 100644 --- a/types/key/nl_test.go +++ b/types/key/nl_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "testing" -) - -func TestNLPrivate(t *testing.T) { - p := NewNLPrivate() - - encoded, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - var decoded NLPrivate - if err := decoded.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decoded.k[:], p.k[:]) { - t.Error("decoded and generated NLPrivate bytes differ") - } - - // Test NLPublic - pub := p.Public() - encoded, err = pub.MarshalText() - if err != nil { - t.Fatal(err) - } - var decodedPub NLPublic - if err := decodedPub.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ") - } - - // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' - decodedPub = NLPublic{} - if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "testing" +) + +func TestNLPrivate(t *testing.T) { + p := NewNLPrivate() + + encoded, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + var decoded NLPrivate + if err := decoded.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded.k[:], p.k[:]) { + t.Error("decoded and generated NLPrivate bytes differ") + } + + // Test NLPublic + pub := p.Public() + encoded, err = pub.MarshalText() + if err != nil { + t.Fatal(err) + } + var decodedPub NLPublic + if err := decodedPub.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ") + } + + // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' + decodedPub = NLPublic{} + if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") + } +} diff --git a/types/lazy/unsync.go b/types/lazy/unsync.go index ca46f9c7bbad3..0f89ce4f6935a 100644 --- a/types/lazy/unsync.go +++ b/types/lazy/unsync.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -// GValue is a lazily computed value. -// -// Use either Get or GetErr, depending on whether your fill function returns an -// error. -// -// Recursive use of a GValue from its own fill function will panic. -// -// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, -// which isn't strictly true if you provide your own synchronization between -// goroutines, but in practice most of our callers have been using it within -// a single goroutine.) -type GValue[T any] struct { - done bool - calling bool - V T - err error -} - -// Set attempts to set z's value to val, and reports whether it succeeded. -// Set only succeeds if none of Get/GetErr/Set have been called before. -func (z *GValue[T]) Set(v T) bool { - if z.done { - return false - } - if z.calling { - panic("Set while Get fill is running") - } - z.V = v - z.done = true - return true -} - -// MustSet sets z's value to val, or panics if z already has a value. -func (z *GValue[T]) MustSet(val T) { - if !z.Set(val) { - panic("Set after already filled") - } -} - -// Get returns z's value, calling fill to compute it if necessary. -// f is called at most once. -func (z *GValue[T]) Get(fill func() T) T { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V = fill() - z.done = true - z.calling = false - } - return z.V -} - -// GetErr returns z's value, calling fill to compute it if necessary. -// f is called at most once, and z remembers both of fill's outputs. -func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V, z.err = fill() - z.done = true - z.calling = false - } - return z.V, z.err -} - -// GFunc wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's result on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFunc[T any](fill func() T) func() T { - var v GValue[T] - return func() T { - return v.Get(fill) - } -} - -// SyncFuncErr wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's results on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFuncErr[T any](fill func() (T, error)) func() (T, error) { - var v GValue[T] - return func() (T, error) { - return v.GetErr(fill) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +// GValue is a lazily computed value. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// Recursive use of a GValue from its own fill function will panic. +// +// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, +// which isn't strictly true if you provide your own synchronization between +// goroutines, but in practice most of our callers have been using it within +// a single goroutine.) +type GValue[T any] struct { + done bool + calling bool + V T + err error +} + +// Set attempts to set z's value to val, and reports whether it succeeded. +// Set only succeeds if none of Get/GetErr/Set have been called before. +func (z *GValue[T]) Set(v T) bool { + if z.done { + return false + } + if z.calling { + panic("Set while Get fill is running") + } + z.V = v + z.done = true + return true +} + +// MustSet sets z's value to val, or panics if z already has a value. +func (z *GValue[T]) MustSet(val T) { + if !z.Set(val) { + panic("Set after already filled") + } +} + +// Get returns z's value, calling fill to compute it if necessary. +// f is called at most once. +func (z *GValue[T]) Get(fill func() T) T { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V = fill() + z.done = true + z.calling = false + } + return z.V +} + +// GetErr returns z's value, calling fill to compute it if necessary. +// f is called at most once, and z remembers both of fill's outputs. +func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V, z.err = fill() + z.done = true + z.calling = false + } + return z.V, z.err +} + +// GFunc wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's result on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFunc[T any](fill func() T) func() T { + var v GValue[T] + return func() T { + return v.Get(fill) + } +} + +// SyncFuncErr wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's results on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFuncErr[T any](fill func() (T, error)) func() (T, error) { + var v GValue[T] + return func() (T, error) { + return v.GetErr(fill) + } +} diff --git a/types/lazy/unsync_test.go b/types/lazy/unsync_test.go index d8b870dbeb8a8..f0d2494d12b6e 100644 --- a/types/lazy/unsync_test.go +++ b/types/lazy/unsync_test.go @@ -1,140 +1,140 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -import ( - "errors" - "testing" -) - -func fortyTwo() int { return 42 } - -func TestGValue(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueErr(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got, err := lt.GetErr(func() (int, error) { - return 42, nil - }) - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - var lterr GValue[int] - wantErr := errors.New("test error") - n = int(testing.AllocsPerRun(1000, func() { - got, err := lterr.GetErr(func() (int, error) { - return 0, wantErr - }) - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueSet(t *testing.T) { - var lt GValue[int] - if !lt.Set(42) { - t.Fatalf("Set failed") - } - if lt.Set(43) { - t.Fatalf("Set succeeded after first Set") - } - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueMustSet(t *testing.T) { - var lt GValue[int] - lt.MustSet(42) - defer func() { - if e := recover(); e == nil { - t.Errorf("unexpected success; want panic") - } - }() - lt.MustSet(43) -} - -func TestGValueRecursivePanic(t *testing.T) { - defer func() { - if e := recover(); e != nil { - t.Logf("got panic, as expected") - } else { - t.Errorf("unexpected success; want panic") - } - }() - v := GValue[int]{} - v.Get(func() int { - return v.Get(func() int { return 42 }) - }) -} - -func TestGFunc(t *testing.T) { - f := GFunc(fortyTwo) - - n := int(testing.AllocsPerRun(1000, func() { - got := f() - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGFuncErr(t *testing.T) { - f := GFuncErr(func() (int, error) { - return 42, nil - }) - n := int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - wantErr := errors.New("test error") - f = GFuncErr(func() (int, error) { - return 0, wantErr - }) - n = int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "testing" +) + +func fortyTwo() int { return 42 } + +func TestGValue(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueErr(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := lt.GetErr(func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var lterr GValue[int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := lterr.GetErr(func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueSet(t *testing.T) { + var lt GValue[int] + if !lt.Set(42) { + t.Fatalf("Set failed") + } + if lt.Set(43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueMustSet(t *testing.T) { + var lt GValue[int] + lt.MustSet(42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + lt.MustSet(43) +} + +func TestGValueRecursivePanic(t *testing.T) { + defer func() { + if e := recover(); e != nil { + t.Logf("got panic, as expected") + } else { + t.Errorf("unexpected success; want panic") + } + }() + v := GValue[int]{} + v.Get(func() int { + return v.Get(func() int { return 42 }) + }) +} + +func TestGFunc(t *testing.T) { + f := GFunc(fortyTwo) + + n := int(testing.AllocsPerRun(1000, func() { + got := f() + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGFuncErr(t *testing.T) { + f := GFuncErr(func() (int, error) { + return 42, nil + }) + n := int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + wantErr := errors.New("test error") + f = GFuncErr(func() (int, error) { + return 0, wantErr + }) + n = int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} diff --git a/types/logger/rusage.go b/types/logger/rusage.go index ebe0e972d7749..3943636d6e255 100644 --- a/types/logger/rusage.go +++ b/types/logger/rusage.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "fmt" - "runtime" -) - -// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds -// a prefixed log message to each line with the current binary memory usage -// and max RSS. -func RusagePrefixLog(logf Logf) Logf { - return func(f string, argv ...any) { - var m runtime.MemStats - runtime.ReadMemStats(&m) - goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) - maxRSS := rusageMaxRSS() - pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) - logf(pf, argv...) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "fmt" + "runtime" +) + +// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds +// a prefixed log message to each line with the current binary memory usage +// and max RSS. +func RusagePrefixLog(logf Logf) Logf { + return func(f string, argv ...any) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) + maxRSS := rusageMaxRSS() + pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) + logf(pf, argv...) + } +} diff --git a/types/logger/rusage_stub.go b/types/logger/rusage_stub.go index a228b086557fb..f646f1e1eee7f 100644 --- a/types/logger/rusage_stub.go +++ b/types/logger/rusage_stub.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows || wasm || plan9 || tamago - -package logger - -func rusageMaxRSS() float64 { - // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. - return 0 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows || wasm || plan9 || tamago + +package logger + +func rusageMaxRSS() float64 { + // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. + return 0 +} diff --git a/types/logger/rusage_syscall.go b/types/logger/rusage_syscall.go index 19488aef1e800..2871b66c6bb24 100644 --- a/types/logger/rusage_syscall.go +++ b/types/logger/rusage_syscall.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package logger - -import ( - "runtime" - - "golang.org/x/sys/unix" -) - -func rusageMaxRSS() float64 { - var ru unix.Rusage - err := unix.Getrusage(unix.RUSAGE_SELF, &ru) - if err != nil { - return 0 - } - - rss := float64(ru.Maxrss) - if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { - rss /= 1 << 20 // ru_maxrss is bytes on darwin - } else { - // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) - rss /= 1 << 10 - } - return rss -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package logger + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +func rusageMaxRSS() float64 { + var ru unix.Rusage + err := unix.Getrusage(unix.RUSAGE_SELF, &ru) + if err != nil { + return 0 + } + + rss := float64(ru.Maxrss) + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + rss /= 1 << 20 // ru_maxrss is bytes on darwin + } else { + // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) + rss /= 1 << 10 + } + return rss +} diff --git a/types/logger/tokenbucket.go b/types/logger/tokenbucket.go index 2407e01a7abc4..83d4059c2af00 100644 --- a/types/logger/tokenbucket.go +++ b/types/logger/tokenbucket.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "time" -) - -// tokenBucket is a simple token bucket style rate limiter. - -// It's similar in function to golang.org/x/time/rate.Limiter, which we -// can't use because: -// - It doesn't give access to the number of accumulated tokens, which we -// need for implementing hysteresis; -// - It doesn't let us provide our own time function, which we need for -// implementing proper unit tests. -// -// rate.Limiter is also much more complex than necessary, but that wouldn't -// be enough to disqualify it on its own. -// -// Unlike rate.Limiter, this token bucket does not attempt to -// do any locking of its own. Don't try to access it reentrantly. -// That's fine inside this types/logger package because we already have -// locking at a higher level. -type tokenBucket struct { - remaining int - max int - tick time.Duration - t time.Time -} - -func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { - return &tokenBucket{max, max, tick, now} -} - -func (tb *tokenBucket) Get() bool { - if tb.remaining > 0 { - tb.remaining-- - return true - } - return false -} - -func (tb *tokenBucket) Refund(n int) { - b := tb.remaining + n - if b > tb.max { - tb.remaining = tb.max - } else { - tb.remaining = b - } -} - -func (tb *tokenBucket) AdvanceTo(t time.Time) { - diff := t.Sub(tb.t) - - // only use up whole ticks. The remainder will be used up - // next time. - ticks := int(diff / tb.tick) - tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) - - tb.Refund(ticks) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "time" +) + +// tokenBucket is a simple token bucket style rate limiter. + +// It's similar in function to golang.org/x/time/rate.Limiter, which we +// can't use because: +// - It doesn't give access to the number of accumulated tokens, which we +// need for implementing hysteresis; +// - It doesn't let us provide our own time function, which we need for +// implementing proper unit tests. +// +// rate.Limiter is also much more complex than necessary, but that wouldn't +// be enough to disqualify it on its own. +// +// Unlike rate.Limiter, this token bucket does not attempt to +// do any locking of its own. Don't try to access it reentrantly. +// That's fine inside this types/logger package because we already have +// locking at a higher level. +type tokenBucket struct { + remaining int + max int + tick time.Duration + t time.Time +} + +func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { + return &tokenBucket{max, max, tick, now} +} + +func (tb *tokenBucket) Get() bool { + if tb.remaining > 0 { + tb.remaining-- + return true + } + return false +} + +func (tb *tokenBucket) Refund(n int) { + b := tb.remaining + n + if b > tb.max { + tb.remaining = tb.max + } else { + tb.remaining = b + } +} + +func (tb *tokenBucket) AdvanceTo(t time.Time) { + diff := t.Sub(tb.t) + + // only use up whole ticks. The remainder will be used up + // next time. + ticks := int(diff / tb.tick) + tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) + + tb.Refund(ticks) +} diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go index 56002628e94e0..f2fa2bda92366 100644 --- a/types/netlogtype/netlogtype.go +++ b/types/netlogtype/netlogtype.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netlogtype defines types for network logging. -package netlogtype - -import ( - "net/netip" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/ipproto" -) - -// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both -// the v1 and v2 "json" packages. - -// Message is the log message that captures network traffic. -type Message struct { - NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" - - Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive - End time.Time `json:"end" cbor:"13,keyasint"` // inclusive - - VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` - SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` - ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` - PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` -} - -const ( - messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` - maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 - maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` - minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` - - // MaxMessageJSONSize is the overhead size of Message when it is - // serialized as JSON assuming that each traffic map is populated. - MaxMessageJSONSize = len(messageJSON) - - maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` - maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort - maxJSONProto = `255` - maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` - maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount - maxJSONCount = `18446744073709551615` - - // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts - // when it is serialized as JSON, assuming no superfluous whitespace. - // It does not include the trailing comma that often appears when - // this object is nested within an array. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsJSONSize = len(maxJSONConnCounts) - - maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" - maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort - maxCBORProto = "\x18\xff" - maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" - maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount - maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" - - // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts - // when it is serialized as CBOR. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsCBORSize = len(maxCBORConnCounts) -) - -// ConnectionCounts is a flattened struct of both a connection and counts. -type ConnectionCounts struct { - Connection - Counts -} - -// Connection is a 5-tuple of proto, source and destination IP and port. -type Connection struct { - Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` - Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` - Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` -} - -func (c Connection) IsZero() bool { return c == Connection{} } - -// Counts are statistics about a particular connection. -type Counts struct { - TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` - TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` - RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` - RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` -} - -func (c Counts) IsZero() bool { return c == Counts{} } - -// Add adds the counts from both c1 and c2. -func (c1 Counts) Add(c2 Counts) Counts { - c1.TxPackets += c2.TxPackets - c1.TxBytes += c2.TxBytes - c1.RxPackets += c2.RxPackets - c1.RxBytes += c2.RxBytes - return c1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netlogtype defines types for network logging. +package netlogtype + +import ( + "net/netip" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" +) + +// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both +// the v1 and v2 "json" packages. + +// Message is the log message that captures network traffic. +type Message struct { + NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" + + Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive + End time.Time `json:"end" cbor:"13,keyasint"` // inclusive + + VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` + SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` + ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` + PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` +} + +const ( + messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 + maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` + minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` + + // MaxMessageJSONSize is the overhead size of Message when it is + // serialized as JSON assuming that each traffic map is populated. + MaxMessageJSONSize = len(messageJSON) + + maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` + maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort + maxJSONProto = `255` + maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` + maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount + maxJSONCount = `18446744073709551615` + + // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts + // when it is serialized as JSON, assuming no superfluous whitespace. + // It does not include the trailing comma that often appears when + // this object is nested within an array. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsJSONSize = len(maxJSONConnCounts) + + maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" + maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort + maxCBORProto = "\x18\xff" + maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount + maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" + + // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts + // when it is serialized as CBOR. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsCBORSize = len(maxCBORConnCounts) +) + +// ConnectionCounts is a flattened struct of both a connection and counts. +type ConnectionCounts struct { + Connection + Counts +} + +// Connection is a 5-tuple of proto, source and destination IP and port. +type Connection struct { + Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` + Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` + Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` +} + +func (c Connection) IsZero() bool { return c == Connection{} } + +// Counts are statistics about a particular connection. +type Counts struct { + TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` + TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` + RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` + RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` +} + +func (c Counts) IsZero() bool { return c == Counts{} } + +// Add adds the counts from both c1 and c2. +func (c1 Counts) Add(c2 Counts) Counts { + c1.TxPackets += c2.TxPackets + c1.TxBytes += c2.TxBytes + c1.RxPackets += c2.RxPackets + c1.RxBytes += c2.RxBytes + return c1 +} diff --git a/types/netlogtype/netlogtype_test.go b/types/netlogtype/netlogtype_test.go index 1fa604b317de4..7f29090c5f757 100644 --- a/types/netlogtype/netlogtype_test.go +++ b/types/netlogtype/netlogtype_test.go @@ -1,39 +1,39 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netlogtype - -import ( - "encoding/json" - "math" - "net/netip" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "tailscale.com/util/must" -) - -func TestMaxSize(t *testing.T) { - maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) - maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) - cc := ConnectionCounts{ - // NOTE: These composite literals are deliberately unkeyed so that - // added fields result in a build failure here. - // Newly added fields should result in an update to both - // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. - Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, - Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, - } - - outJSON := must.Get(json.Marshal(cc)) - if string(outJSON) != maxJSONConnCounts { - t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) - } - - outCBOR := must.Get(cbor.Marshal(cc)) - maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map - if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { - t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netlogtype + +import ( + "encoding/json" + "math" + "net/netip" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "tailscale.com/util/must" +) + +func TestMaxSize(t *testing.T) { + maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) + maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) + cc := ConnectionCounts{ + // NOTE: These composite literals are deliberately unkeyed so that + // added fields result in a build failure here. + // Newly added fields should result in an update to both + // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. + Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, + Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, + } + + outJSON := must.Get(json.Marshal(cc)) + if string(outJSON) != maxJSONConnCounts { + t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) + } + + outCBOR := must.Get(cbor.Marshal(cc)) + maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map + if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { + t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) + } +} diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index 910b6bc21fc8d..e7e2d19575c44 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -1,318 +1,318 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmap - -import ( - "encoding/hex" - "net/netip" - "testing" - - "go4.org/mem" - "tailscale.com/net/netaddr" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func testNodeKey(b byte) (ret key.NodePublic) { - var bs [key.NodePublicRawLen]byte - for i := range bs { - bs[i] = b - } - return key.NodePublicFromRaw32(mem.B(bs[:])) -} - -func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { - b, err := hex.DecodeString(hexPrefix) - if err != nil { - panic(err) - } - // this function is used with short hexes, so zero-extend the raw - // value. - var bs [32]byte - copy(bs[:], b) - return key.DiscoPublicFromRaw32(mem.B(bs[:])) -} - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func eps(s ...string) []netip.AddrPort { - var eps []netip.AddrPort - for _, ep := range s { - eps = append(eps, netip.MustParseAddrPort(ep)) - } - return eps -} - -func TestNetworkMapConcise(t *testing.T) { - for _, tt := range []struct { - name string - nm *NetworkMap - want string - }{ - { - name: "basic", - nm: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - Key: testNodeKey(3), - DERP: "127.3.3.40:4", - Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), - }, - }), - }, - want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(1000, func() { - got = tt.nm.Concise() - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestConciseDiffFrom(t *testing.T) { - for _, tt := range []struct { - name string - a, b *NetworkMap - want string - }{ - { - name: "no_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "", - }, - { - name: "header_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(2), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", - }, - { - name: "peer_add", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_remove", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_port_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), - }, - }), - }, - want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", - }, - { - name: "disco_key_only_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("f00f00f00f"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("ba4ba4ba4b"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(50, func() { - got = tt.b.ConciseDiffFrom(tt.a) - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestPeerIndexByNodeID(t *testing.T) { - var nilPtr *NetworkMap - if nilPtr.PeerIndexByNodeID(123) != -1 { - t.Errorf("nil PeerIndexByNodeID should return -1") - } - var nm NetworkMap - const min = 2 - const max = 10000 - const hole = max / 2 - for nid := tailcfg.NodeID(2); nid <= max; nid++ { - if nid == hole { - continue - } - nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) - } - for want, nv := range nm.Peers { - got := nm.PeerIndexByNodeID(nv.ID()) - if got != want { - t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) - } - } - for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { - if got := nm.PeerIndexByNodeID(miss); got != -1 { - t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmap + +import ( + "encoding/hex" + "net/netip" + "testing" + + "go4.org/mem" + "tailscale.com/net/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func testNodeKey(b byte) (ret key.NodePublic) { + var bs [key.NodePublicRawLen]byte + for i := range bs { + bs[i] = b + } + return key.NodePublicFromRaw32(mem.B(bs[:])) +} + +func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { + b, err := hex.DecodeString(hexPrefix) + if err != nil { + panic(err) + } + // this function is used with short hexes, so zero-extend the raw + // value. + var bs [32]byte + copy(bs[:], b) + return key.DiscoPublicFromRaw32(mem.B(bs[:])) +} + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func eps(s ...string) []netip.AddrPort { + var eps []netip.AddrPort + for _, ep := range s { + eps = append(eps, netip.MustParseAddrPort(ep)) + } + return eps +} + +func TestNetworkMapConcise(t *testing.T) { + for _, tt := range []struct { + name string + nm *NetworkMap + want string + }{ + { + name: "basic", + nm: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + Key: testNodeKey(3), + DERP: "127.3.3.40:4", + Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), + }, + }), + }, + want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(1000, func() { + got = tt.nm.Concise() + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestConciseDiffFrom(t *testing.T) { + for _, tt := range []struct { + name string + a, b *NetworkMap + want string + }{ + { + name: "no_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "", + }, + { + name: "header_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(2), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", + }, + { + name: "peer_add", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_remove", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_port_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), + }, + }), + }, + want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", + }, + { + name: "disco_key_only_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("f00f00f00f"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("ba4ba4ba4b"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(50, func() { + got = tt.b.ConciseDiffFrom(tt.a) + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestPeerIndexByNodeID(t *testing.T) { + var nilPtr *NetworkMap + if nilPtr.PeerIndexByNodeID(123) != -1 { + t.Errorf("nil PeerIndexByNodeID should return -1") + } + var nm NetworkMap + const min = 2 + const max = 10000 + const hole = max / 2 + for nid := tailcfg.NodeID(2); nid <= max; nid++ { + if nid == hole { + continue + } + nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) + } + for want, nv := range nm.Peers { + got := nm.PeerIndexByNodeID(nv.ID()) + if got != want { + t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) + } + } + for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { + if got := nm.PeerIndexByNodeID(miss); got != -1 { + t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) + } + } +} diff --git a/types/nettype/nettype.go b/types/nettype/nettype.go index 8930c36d845b6..5d3d303c38a0d 100644 --- a/types/nettype/nettype.go +++ b/types/nettype/nettype.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package nettype defines an interface that doesn't exist in the Go net package. -package nettype - -import ( - "context" - "io" - "net" - "net/netip" - "time" -) - -// PacketListener defines the ListenPacket method as implemented -// by net.ListenConfig, net.ListenPacket, and tstest/natlab. -type PacketListener interface { - ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) -} - -type PacketListenerWithNetIP interface { - ListenPacket(ctx context.Context, network, address string) (PacketConn, error) -} - -// Std implements PacketListener using the Go net package's ListenPacket func. -type Std struct{} - -func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - var conf net.ListenConfig - return conf.ListenPacket(ctx, network, address) -} - -// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort -// write/read methods. -type PacketConn interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) - ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) - io.Closer - LocalAddr() net.Addr - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} - -func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { - return packetListenerAdapter{ln} -} - -type packetListenerAdapter struct { - PacketListener -} - -func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { - pc, err := a.PacketListener.ListenPacket(ctx, network, address) - if err != nil { - return nil, err - } - return pc.(PacketConn), nil -} - -// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. -type ConnPacketConn interface { - net.Conn - net.PacketConn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package nettype defines an interface that doesn't exist in the Go net package. +package nettype + +import ( + "context" + "io" + "net" + "net/netip" + "time" +) + +// PacketListener defines the ListenPacket method as implemented +// by net.ListenConfig, net.ListenPacket, and tstest/natlab. +type PacketListener interface { + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +type PacketListenerWithNetIP interface { + ListenPacket(ctx context.Context, network, address string) (PacketConn, error) +} + +// Std implements PacketListener using the Go net package's ListenPacket func. +type Std struct{} + +func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + var conf net.ListenConfig + return conf.ListenPacket(ctx, network, address) +} + +// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort +// write/read methods. +type PacketConn interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) + ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) + io.Closer + LocalAddr() net.Addr + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { + return packetListenerAdapter{ln} +} + +type packetListenerAdapter struct { + PacketListener +} + +func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { + pc, err := a.PacketListener.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + return pc.(PacketConn), nil +} + +// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. +type ConnPacketConn interface { + net.Conn + net.PacketConn +} diff --git a/types/preftype/netfiltermode.go b/types/preftype/netfiltermode.go index 5756e50968fa5..273e173444365 100644 --- a/types/preftype/netfiltermode.go +++ b/types/preftype/netfiltermode.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package preftype is a leaf package containing types for various -// preferences. -package preftype - -import "fmt" - -// NetfilterMode is the firewall management mode to use when -// programming the Linux network stack. -type NetfilterMode int - -// These numbers are persisted to disk in JSON files and thus can't be -// renumbered or repurposed. -const ( - NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state - NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them - NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains -) - -func ParseNetfilterMode(s string) (NetfilterMode, error) { - switch s { - case "off": - return NetfilterOff, nil - case "nodivert": - return NetfilterNoDivert, nil - case "on": - return NetfilterOn, nil - default: - return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) - } -} - -func (m NetfilterMode) String() string { - switch m { - case NetfilterOff: - return "off" - case NetfilterNoDivert: - return "nodivert" - case NetfilterOn: - return "on" - default: - return "???" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package preftype is a leaf package containing types for various +// preferences. +package preftype + +import "fmt" + +// NetfilterMode is the firewall management mode to use when +// programming the Linux network stack. +type NetfilterMode int + +// These numbers are persisted to disk in JSON files and thus can't be +// renumbered or repurposed. +const ( + NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state + NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them + NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains +) + +func ParseNetfilterMode(s string) (NetfilterMode, error) { + switch s { + case "off": + return NetfilterOff, nil + case "nodivert": + return NetfilterNoDivert, nil + case "on": + return NetfilterOn, nil + default: + return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) + } +} + +func (m NetfilterMode) String() string { + switch m { + case NetfilterOff: + return "off" + case NetfilterNoDivert: + return "nodivert" + case NetfilterOn: + return "on" + default: + return "???" + } +} diff --git a/types/ptr/ptr.go b/types/ptr/ptr.go index beb955bf00b61..beb17bee8ee0e 100644 --- a/types/ptr/ptr.go +++ b/types/ptr/ptr.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ptr contains the ptr.To function. -package ptr - -// To returns a pointer to a shallow copy of v. -func To[T any](v T) *T { - return &v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ptr contains the ptr.To function. +package ptr + +// To returns a pointer to a shallow copy of v. +func To[T any](v T) *T { + return &v +} diff --git a/types/structs/structs.go b/types/structs/structs.go index bac6b29917318..47c359f0caa0f 100644 --- a/types/structs/structs.go +++ b/types/structs/structs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package structs contains the Incomparable type. -package structs - -// Incomparable is a zero-width incomparable type. If added as the -// first field in a struct, it marks that struct as not comparable -// (can't do == or be a map key) and usually doesn't add any width to -// the struct (unless the struct has only small fields). -// -// Be making a struct incomparable, you can prevent misuse (prevent -// people from using ==), but also you can shrink generated binaries, -// as the compiler can omit equality funcs from the binary. -type Incomparable [0]func() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package structs contains the Incomparable type. +package structs + +// Incomparable is a zero-width incomparable type. If added as the +// first field in a struct, it marks that struct as not comparable +// (can't do == or be a map key) and usually doesn't add any width to +// the struct (unless the struct has only small fields). +// +// Be making a struct incomparable, you can prevent misuse (prevent +// people from using ==), but also you can shrink generated binaries, +// as the compiler can omit equality funcs from the binary. +type Incomparable [0]func() diff --git a/types/tkatype/tkatype.go b/types/tkatype/tkatype.go index aca6f144303d0..6ad51f6a90240 100644 --- a/types/tkatype/tkatype.go +++ b/types/tkatype/tkatype.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tkatype defines types for working with the tka package. -// -// Do not add extra dependencies to this package unless they are tiny, -// because this package encodes wire types that should be lightweight to use. -package tkatype - -// KeyID references a verification key stored in the key authority. A keyID -// uniquely identifies a key. KeyIDs are all 32 bytes. -// -// For 25519 keys: We just use the 32-byte public key. -// -// Even though this is a 32-byte value, we use a byte slice because -// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. -// Encoding as a byte slice allows us to change the size in the future if we -// ever need to. -type KeyID []byte - -// MarshaledSignature represents a marshaled tka.NodeKeySignature. -type MarshaledSignature []byte - -// MarshaledAUM represents a marshaled tka.AUM. -type MarshaledAUM []byte - -// AUMSigHash represents the BLAKE2s digest of an Authority Update -// Message (AUM), sans any signatures. -type AUMSigHash [32]byte - -// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), -// sans the Signature field if present. -type NKSSigHash [32]byte - -// Signature describes a signature over an AUM, which can be verified -// using the key referenced by KeyID. -type Signature struct { - KeyID KeyID `cbor:"1,keyasint"` - Signature []byte `cbor:"2,keyasint"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tkatype defines types for working with the tka package. +// +// Do not add extra dependencies to this package unless they are tiny, +// because this package encodes wire types that should be lightweight to use. +package tkatype + +// KeyID references a verification key stored in the key authority. A keyID +// uniquely identifies a key. KeyIDs are all 32 bytes. +// +// For 25519 keys: We just use the 32-byte public key. +// +// Even though this is a 32-byte value, we use a byte slice because +// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. +// Encoding as a byte slice allows us to change the size in the future if we +// ever need to. +type KeyID []byte + +// MarshaledSignature represents a marshaled tka.NodeKeySignature. +type MarshaledSignature []byte + +// MarshaledAUM represents a marshaled tka.AUM. +type MarshaledAUM []byte + +// AUMSigHash represents the BLAKE2s digest of an Authority Update +// Message (AUM), sans any signatures. +type AUMSigHash [32]byte + +// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), +// sans the Signature field if present. +type NKSSigHash [32]byte + +// Signature describes a signature over an AUM, which can be verified +// using the key referenced by KeyID. +type Signature struct { + KeyID KeyID `cbor:"1,keyasint"` + Signature []byte `cbor:"2,keyasint"` +} diff --git a/types/tkatype/tkatype_test.go b/types/tkatype/tkatype_test.go index bff90807240e1..c81891b9ce103 100644 --- a/types/tkatype/tkatype_test.go +++ b/types/tkatype/tkatype_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tkatype - -import ( - "encoding/json" - "testing" - - "golang.org/x/crypto/blake2s" -) - -func TestSigHashSize(t *testing.T) { - var sigHash AUMSigHash - if len(sigHash) != blake2s.Size { - t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) - } - - var nksHash NKSSigHash - if len(nksHash) != blake2s.Size { - t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) - } -} - -func TestMarshaledSignatureJSON(t *testing.T) { - sig := MarshaledSignature("abcdef") - j, err := json.Marshal(sig) - if err != nil { - t.Fatal(err) - } - const encoded = `"YWJjZGVm"` - if string(j) != encoded { - t.Errorf("got JSON %q; want %q", j, encoded) - } - - var back MarshaledSignature - if err := json.Unmarshal([]byte(encoded), &back); err != nil { - t.Fatal(err) - } - if string(back) != string(sig) { - t.Errorf("decoded JSON back to %q; want %q", back, sig) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tkatype + +import ( + "encoding/json" + "testing" + + "golang.org/x/crypto/blake2s" +) + +func TestSigHashSize(t *testing.T) { + var sigHash AUMSigHash + if len(sigHash) != blake2s.Size { + t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) + } + + var nksHash NKSSigHash + if len(nksHash) != blake2s.Size { + t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) + } +} + +func TestMarshaledSignatureJSON(t *testing.T) { + sig := MarshaledSignature("abcdef") + j, err := json.Marshal(sig) + if err != nil { + t.Fatal(err) + } + const encoded = `"YWJjZGVm"` + if string(j) != encoded { + t.Errorf("got JSON %q; want %q", j, encoded) + } + + var back MarshaledSignature + if err := json.Unmarshal([]byte(encoded), &back); err != nil { + t.Fatal(err) + } + if string(back) != string(sig) { + t.Errorf("decoded JSON back to %q; want %q", back, sig) + } +} diff --git a/util/cibuild/cibuild.go b/util/cibuild/cibuild.go index c3dee61548b42..c1e337f9a142a 100644 --- a/util/cibuild/cibuild.go +++ b/util/cibuild/cibuild.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cibuild reports runtime CI information. -package cibuild - -import "os" - -// On reports whether the current binary is executing on a CI system. -func On() bool { - // CI env variable is set by GitHub. - // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables - return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cibuild reports runtime CI information. +package cibuild + +import "os" + +// On reports whether the current binary is executing on a CI system. +func On() bool { + // CI env variable is set by GitHub. + // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables + return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" +} diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go index e32c90830e6a7..464dc5dc3cadf 100644 --- a/util/cstruct/cstruct.go +++ b/util/cstruct/cstruct.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cstruct provides a helper for decoding binary data that is in the -// form of a padded C structure. -package cstruct - -import ( - "errors" - "io" - - "github.com/josharian/native" -) - -// Size of a pointer-typed value, in bits -const pointerSize = 32 << (^uintptr(0) >> 63) - -// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on -// a 16- or 8-bit architecture any time soon. -const is64Bit = pointerSize == 64 - -// Decoder reads and decodes padded fields from a slice of bytes. All fields -// are decoded with native endianness. -// -// Methods of a Decoder do not return errors, but rather store any error within -// the Decoder. The first error can be obtained via the Err method; after the -// first error, methods will return the zero value for their type. -type Decoder struct { - b []byte - off int - err error - dbuf [8]byte // for decoding -} - -// NewDecoder creates a Decoder from a byte slice. -func NewDecoder(b []byte) *Decoder { - return &Decoder{b: b} -} - -var errUnsupportedSize = errors.New("unsupported size") - -func padBytes(offset, size int) int { - if offset == 0 || size == 1 { - return 0 - } - remainder := offset % size - return size - remainder -} - -func (d *Decoder) getField(b []byte) error { - size := len(b) - - // We only support fields that are multiples of 2 (or 1-sized) - if size != 1 && size&1 == 1 { - return errUnsupportedSize - } - - // Fields are aligned to their size - padBytes := padBytes(d.off, size) - if d.off+size+padBytes > len(d.b) { - return io.EOF - } - d.off += padBytes - - copy(b, d.b[d.off:d.off+size]) - d.off += size - return nil -} - -// Err returns the first error that was encountered by this Decoder. -func (d *Decoder) Err() error { - return d.err -} - -// Offset returns the current read offset for data in the buffer. -func (d *Decoder) Offset() int { - return d.off -} - -// Byte returns a single byte from the buffer. -func (d *Decoder) Byte() byte { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:1]); err != nil { - d.err = err - return 0 - } - return d.dbuf[0] -} - -// Byte returns a number of bytes from the buffer based on the size of the -// input slice. No padding is applied. -// -// If an error is encountered or this Decoder has previously encountered an -// error, no changes are made to the provided buffer. -func (d *Decoder) Bytes(b []byte) { - if d.err != nil { - return - } - - // No padding for byte slices - size := len(b) - if d.off+size >= len(d.b) { - d.err = io.EOF - return - } - copy(b, d.b[d.off:d.off+size]) - d.off += size -} - -// Uint16 returns a uint16 decoded from the buffer. -func (d *Decoder) Uint16() uint16 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:2]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint16(d.dbuf[0:2]) -} - -// Uint32 returns a uint32 decoded from the buffer. -func (d *Decoder) Uint32() uint32 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:4]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint32(d.dbuf[0:4]) -} - -// Uint64 returns a uint64 decoded from the buffer. -func (d *Decoder) Uint64() uint64 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:8]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint64(d.dbuf[0:8]) -} - -// Uintptr returns a uintptr decoded from the buffer. -func (d *Decoder) Uintptr() uintptr { - if d.err != nil { - return 0 - } - - if is64Bit { - return uintptr(d.Uint64()) - } else { - return uintptr(d.Uint32()) - } -} - -// Int16 returns a int16 decoded from the buffer. -func (d *Decoder) Int16() int16 { - return int16(d.Uint16()) -} - -// Int32 returns a int32 decoded from the buffer. -func (d *Decoder) Int32() int32 { - return int32(d.Uint32()) -} - -// Int64 returns a int64 decoded from the buffer. -func (d *Decoder) Int64() int64 { - return int64(d.Uint64()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cstruct provides a helper for decoding binary data that is in the +// form of a padded C structure. +package cstruct + +import ( + "errors" + "io" + + "github.com/josharian/native" +) + +// Size of a pointer-typed value, in bits +const pointerSize = 32 << (^uintptr(0) >> 63) + +// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on +// a 16- or 8-bit architecture any time soon. +const is64Bit = pointerSize == 64 + +// Decoder reads and decodes padded fields from a slice of bytes. All fields +// are decoded with native endianness. +// +// Methods of a Decoder do not return errors, but rather store any error within +// the Decoder. The first error can be obtained via the Err method; after the +// first error, methods will return the zero value for their type. +type Decoder struct { + b []byte + off int + err error + dbuf [8]byte // for decoding +} + +// NewDecoder creates a Decoder from a byte slice. +func NewDecoder(b []byte) *Decoder { + return &Decoder{b: b} +} + +var errUnsupportedSize = errors.New("unsupported size") + +func padBytes(offset, size int) int { + if offset == 0 || size == 1 { + return 0 + } + remainder := offset % size + return size - remainder +} + +func (d *Decoder) getField(b []byte) error { + size := len(b) + + // We only support fields that are multiples of 2 (or 1-sized) + if size != 1 && size&1 == 1 { + return errUnsupportedSize + } + + // Fields are aligned to their size + padBytes := padBytes(d.off, size) + if d.off+size+padBytes > len(d.b) { + return io.EOF + } + d.off += padBytes + + copy(b, d.b[d.off:d.off+size]) + d.off += size + return nil +} + +// Err returns the first error that was encountered by this Decoder. +func (d *Decoder) Err() error { + return d.err +} + +// Offset returns the current read offset for data in the buffer. +func (d *Decoder) Offset() int { + return d.off +} + +// Byte returns a single byte from the buffer. +func (d *Decoder) Byte() byte { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:1]); err != nil { + d.err = err + return 0 + } + return d.dbuf[0] +} + +// Byte returns a number of bytes from the buffer based on the size of the +// input slice. No padding is applied. +// +// If an error is encountered or this Decoder has previously encountered an +// error, no changes are made to the provided buffer. +func (d *Decoder) Bytes(b []byte) { + if d.err != nil { + return + } + + // No padding for byte slices + size := len(b) + if d.off+size >= len(d.b) { + d.err = io.EOF + return + } + copy(b, d.b[d.off:d.off+size]) + d.off += size +} + +// Uint16 returns a uint16 decoded from the buffer. +func (d *Decoder) Uint16() uint16 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:2]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint16(d.dbuf[0:2]) +} + +// Uint32 returns a uint32 decoded from the buffer. +func (d *Decoder) Uint32() uint32 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:4]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint32(d.dbuf[0:4]) +} + +// Uint64 returns a uint64 decoded from the buffer. +func (d *Decoder) Uint64() uint64 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:8]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint64(d.dbuf[0:8]) +} + +// Uintptr returns a uintptr decoded from the buffer. +func (d *Decoder) Uintptr() uintptr { + if d.err != nil { + return 0 + } + + if is64Bit { + return uintptr(d.Uint64()) + } else { + return uintptr(d.Uint32()) + } +} + +// Int16 returns a int16 decoded from the buffer. +func (d *Decoder) Int16() int16 { + return int16(d.Uint16()) +} + +// Int32 returns a int32 decoded from the buffer. +func (d *Decoder) Int32() int32 { + return int32(d.Uint32()) +} + +// Int64 returns a int64 decoded from the buffer. +func (d *Decoder) Int64() int64 { + return int64(d.Uint64()) +} diff --git a/util/cstruct/cstruct_example_test.go b/util/cstruct/cstruct_example_test.go index a36cbf9f0caa3..17032267b9dc6 100644 --- a/util/cstruct/cstruct_example_test.go +++ b/util/cstruct/cstruct_example_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Only built on 64-bit platforms to avoid complexity - -//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 - -package cstruct - -import "fmt" - -// This test provides a semi-realistic example of how you can -// use this package to decode a C structure. -func ExampleDecoder() { - // Our example C structure: - // struct mystruct { - // char *p; - // char c; - // /* implicit: char _pad[3]; */ - // int x; - // }; - // - // The Go structure definition: - type myStruct struct { - Ptr uintptr - Ch byte - Intval uint32 - } - - // Our "in-memory" version of the above structure - buf := []byte{ - 1, 2, 3, 4, 0, 0, 0, 0, // ptr - 5, // ch - 99, 99, 99, // padding - 78, 6, 0, 0, // x - } - d := NewDecoder(buf) - - // Decode the structure; if one of these function returns an error, - // then subsequent decoder functions will return the zero value. - var x myStruct - x.Ptr = d.Uintptr() - x.Ch = d.Byte() - x.Intval = d.Uint32() - - // Note that per the Go language spec: - // [...] when evaluating the operands of an expression, assignment, - // or return statement, all function calls, method calls, and - // (channel) communication operations are evaluated in lexical - // left-to-right order - // - // Since each field is assigned via a function call, one could use the - // following snippet to decode the struct. - // x := myStruct{ - // Ptr: d.Uintptr(), - // Ch: d.Byte(), - // Intval: d.Uint32(), - // } - // - // However, this means that reordering the fields in the initialization - // statement–normally a semantically identical operation–would change - // the way the structure is parsed. Thus we do it as above with - // explicit ordering. - - // After finishing with the decoder, check errors - if err := d.Err(); err != nil { - panic(err) - } - - // Print the decoder offset and structure - fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) - // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Only built on 64-bit platforms to avoid complexity + +//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 + +package cstruct + +import "fmt" + +// This test provides a semi-realistic example of how you can +// use this package to decode a C structure. +func ExampleDecoder() { + // Our example C structure: + // struct mystruct { + // char *p; + // char c; + // /* implicit: char _pad[3]; */ + // int x; + // }; + // + // The Go structure definition: + type myStruct struct { + Ptr uintptr + Ch byte + Intval uint32 + } + + // Our "in-memory" version of the above structure + buf := []byte{ + 1, 2, 3, 4, 0, 0, 0, 0, // ptr + 5, // ch + 99, 99, 99, // padding + 78, 6, 0, 0, // x + } + d := NewDecoder(buf) + + // Decode the structure; if one of these function returns an error, + // then subsequent decoder functions will return the zero value. + var x myStruct + x.Ptr = d.Uintptr() + x.Ch = d.Byte() + x.Intval = d.Uint32() + + // Note that per the Go language spec: + // [...] when evaluating the operands of an expression, assignment, + // or return statement, all function calls, method calls, and + // (channel) communication operations are evaluated in lexical + // left-to-right order + // + // Since each field is assigned via a function call, one could use the + // following snippet to decode the struct. + // x := myStruct{ + // Ptr: d.Uintptr(), + // Ch: d.Byte(), + // Intval: d.Uint32(), + // } + // + // However, this means that reordering the fields in the initialization + // statement–normally a semantically identical operation–would change + // the way the structure is parsed. Thus we do it as above with + // explicit ordering. + + // After finishing with the decoder, check errors + if err := d.Err(); err != nil { + panic(err) + } + + // Print the decoder offset and structure + fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) + // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} +} diff --git a/util/deephash/debug.go b/util/deephash/debug.go index ff417e5835178..50b3d5605f327 100644 --- a/util/deephash/debug.go +++ b/util/deephash/debug.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build deephash_debug - -package deephash - -import "fmt" - -func (h *hasher) HashBytes(b []byte) { - fmt.Printf("B(%q)+", b) - h.Block512.HashBytes(b) -} -func (h *hasher) HashString(s string) { - fmt.Printf("S(%q)+", s) - h.Block512.HashString(s) -} -func (h *hasher) HashUint8(n uint8) { - fmt.Printf("U8(%d)+", n) - h.Block512.HashUint8(n) -} -func (h *hasher) HashUint16(n uint16) { - fmt.Printf("U16(%d)+", n) - h.Block512.HashUint16(n) -} -func (h *hasher) HashUint32(n uint32) { - fmt.Printf("U32(%d)+", n) - h.Block512.HashUint32(n) -} -func (h *hasher) HashUint64(n uint64) { - fmt.Printf("U64(%d)+", n) - h.Block512.HashUint64(n) -} -func (h *hasher) Sum(b []byte) []byte { - fmt.Println("FIN") - return h.Block512.Sum(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build deephash_debug + +package deephash + +import "fmt" + +func (h *hasher) HashBytes(b []byte) { + fmt.Printf("B(%q)+", b) + h.Block512.HashBytes(b) +} +func (h *hasher) HashString(s string) { + fmt.Printf("S(%q)+", s) + h.Block512.HashString(s) +} +func (h *hasher) HashUint8(n uint8) { + fmt.Printf("U8(%d)+", n) + h.Block512.HashUint8(n) +} +func (h *hasher) HashUint16(n uint16) { + fmt.Printf("U16(%d)+", n) + h.Block512.HashUint16(n) +} +func (h *hasher) HashUint32(n uint32) { + fmt.Printf("U32(%d)+", n) + h.Block512.HashUint32(n) +} +func (h *hasher) HashUint64(n uint64) { + fmt.Printf("U64(%d)+", n) + h.Block512.HashUint64(n) +} +func (h *hasher) Sum(b []byte) []byte { + fmt.Println("FIN") + return h.Block512.Sum(b) +} diff --git a/util/deephash/pointer.go b/util/deephash/pointer.go index 71b11d7ff1d75..aafae47a23673 100644 --- a/util/deephash/pointer.go +++ b/util/deephash/pointer.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deephash - -import ( - "net/netip" - "reflect" - "time" - "unsafe" -) - -// unsafePointer is an untyped pointer. -// It is the caller's responsibility to call operations on the correct type. -// -// This pointer only ever points to a small set of kinds or types: -// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, -// or a pointer to memory that is directly hashable. -// -// Arrays are represented as pointers to the first element. -// Structs are represented as pointers to the first field. -// Slices are represented as pointers to a slice header. -// Pointers are represented as pointers to a pointer. -// -// We do not support direct operations on maps and interfaces, and instead -// rely on pointer.asValue to convert the pointer back to a reflect.Value. -// Conversion of an unsafe.Pointer to reflect.Value guarantees that the -// read-only flag in the reflect.Value is unpopulated, avoiding panics that may -// otherwise have occurred since the value was obtained from an unexported field. -type unsafePointer struct{ p unsafe.Pointer } - -func unsafePointerOf(v reflect.Value) unsafePointer { - return unsafePointer{v.UnsafePointer()} -} -func (p unsafePointer) isNil() bool { - return p.p == nil -} - -// pointerElem dereferences a pointer. -// p must point to a pointer. -func (p unsafePointer) pointerElem() unsafePointer { - return unsafePointer{*(*unsafe.Pointer)(p.p)} -} - -// sliceLen returns the slice length. -// p must point to a slice. -func (p unsafePointer) sliceLen() int { - return (*reflect.SliceHeader)(p.p).Len -} - -// sliceArray returns a pointer to the underlying slice array. -// p must point to a slice. -func (p unsafePointer) sliceArray() unsafePointer { - return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} -} - -// arrayIndex returns a pointer to an element in the array. -// p must point to an array. -func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} -} - -// structField returns a pointer to a field in a struct. -// p must pointer to a struct. -func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, offset)} -} - -// asString casts p as a *string. -func (p unsafePointer) asString() *string { - return (*string)(p.p) -} - -// asTime casts p as a *time.Time. -func (p unsafePointer) asTime() *time.Time { - return (*time.Time)(p.p) -} - -// asAddr casts p as a *netip.Addr. -func (p unsafePointer) asAddr() *netip.Addr { - return (*netip.Addr)(p.p) -} - -// asValue casts p as a reflect.Value containing a pointer to value of t. -func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { - return reflect.NewAt(typ, p.p) -} - -// asMemory returns the memory pointer at by p for a specified size. -func (p unsafePointer) asMemory(size uintptr) []byte { - return unsafe.Slice((*byte)(p.p), size) -} - -// visitStack is a stack of pointers visited. -// Pointers are pushed onto the stack when visited, and popped when leaving. -// The integer value is the depth at which the pointer was visited. -// The length of this stack should be zero after every hashing operation. -type visitStack map[unsafe.Pointer]int - -func (v visitStack) seen(p unsafe.Pointer) (int, bool) { - idx, ok := v[p] - return idx, ok -} - -func (v *visitStack) push(p unsafe.Pointer) { - if *v == nil { - *v = make(map[unsafe.Pointer]int) - } - (*v)[p] = len(*v) -} - -func (v visitStack) pop(p unsafe.Pointer) { - delete(v, p) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deephash + +import ( + "net/netip" + "reflect" + "time" + "unsafe" +) + +// unsafePointer is an untyped pointer. +// It is the caller's responsibility to call operations on the correct type. +// +// This pointer only ever points to a small set of kinds or types: +// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, +// or a pointer to memory that is directly hashable. +// +// Arrays are represented as pointers to the first element. +// Structs are represented as pointers to the first field. +// Slices are represented as pointers to a slice header. +// Pointers are represented as pointers to a pointer. +// +// We do not support direct operations on maps and interfaces, and instead +// rely on pointer.asValue to convert the pointer back to a reflect.Value. +// Conversion of an unsafe.Pointer to reflect.Value guarantees that the +// read-only flag in the reflect.Value is unpopulated, avoiding panics that may +// otherwise have occurred since the value was obtained from an unexported field. +type unsafePointer struct{ p unsafe.Pointer } + +func unsafePointerOf(v reflect.Value) unsafePointer { + return unsafePointer{v.UnsafePointer()} +} +func (p unsafePointer) isNil() bool { + return p.p == nil +} + +// pointerElem dereferences a pointer. +// p must point to a pointer. +func (p unsafePointer) pointerElem() unsafePointer { + return unsafePointer{*(*unsafe.Pointer)(p.p)} +} + +// sliceLen returns the slice length. +// p must point to a slice. +func (p unsafePointer) sliceLen() int { + return (*reflect.SliceHeader)(p.p).Len +} + +// sliceArray returns a pointer to the underlying slice array. +// p must point to a slice. +func (p unsafePointer) sliceArray() unsafePointer { + return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} +} + +// arrayIndex returns a pointer to an element in the array. +// p must point to an array. +func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} +} + +// structField returns a pointer to a field in a struct. +// p must pointer to a struct. +func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, offset)} +} + +// asString casts p as a *string. +func (p unsafePointer) asString() *string { + return (*string)(p.p) +} + +// asTime casts p as a *time.Time. +func (p unsafePointer) asTime() *time.Time { + return (*time.Time)(p.p) +} + +// asAddr casts p as a *netip.Addr. +func (p unsafePointer) asAddr() *netip.Addr { + return (*netip.Addr)(p.p) +} + +// asValue casts p as a reflect.Value containing a pointer to value of t. +func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, p.p) +} + +// asMemory returns the memory pointer at by p for a specified size. +func (p unsafePointer) asMemory(size uintptr) []byte { + return unsafe.Slice((*byte)(p.p), size) +} + +// visitStack is a stack of pointers visited. +// Pointers are pushed onto the stack when visited, and popped when leaving. +// The integer value is the depth at which the pointer was visited. +// The length of this stack should be zero after every hashing operation. +type visitStack map[unsafe.Pointer]int + +func (v visitStack) seen(p unsafe.Pointer) (int, bool) { + idx, ok := v[p] + return idx, ok +} + +func (v *visitStack) push(p unsafe.Pointer) { + if *v == nil { + *v = make(map[unsafe.Pointer]int) + } + (*v)[p] = len(*v) +} + +func (v visitStack) pop(p unsafe.Pointer) { + delete(v, p) +} diff --git a/util/deephash/pointer_norace.go b/util/deephash/pointer_norace.go index 4993720002460..f98a70f6a18e5 100644 --- a/util/deephash/pointer_norace.go +++ b/util/deephash/pointer_norace.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package deephash - -import "reflect" - -type pointer = unsafePointer - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package deephash + +import "reflect" + +type pointer = unsafePointer + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } diff --git a/util/deephash/pointer_race.go b/util/deephash/pointer_race.go index 93a358b6df358..c638c7d39f393 100644 --- a/util/deephash/pointer_race.go +++ b/util/deephash/pointer_race.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package deephash - -import ( - "fmt" - "net/netip" - "reflect" - "time" -) - -// pointer is a typed pointer that performs safety checks for every operation. -type pointer struct { - unsafePointer - t reflect.Type // type of pointed-at value; may be nil - n uintptr // size of valid memory after p -} - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { - assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) - te := v.Type().Elem() - return pointer{unsafePointerOf(v), te, te.Size()} -} - -func (p pointer) pointerElem() pointer { - assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) - te := p.t.Elem() - return pointer{p.unsafePointer.pointerElem(), te, te.Size()} -} - -func (p pointer) sliceLen() int { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - return p.unsafePointer.sliceLen() -} - -func (p pointer) sliceArray() pointer { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - n := p.sliceLen() - assert(n >= 0, "got negative slice length %d", n) - ta := reflect.ArrayOf(n, p.t.Elem()) - return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} -} - -func (p pointer) arrayIndex(index int, size uintptr) pointer { - assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) - assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) - assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) - te := p.t.Elem() - return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} -} - -func (p pointer) structField(index int, offset, size uintptr) pointer { - assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) - assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) - assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) - if index < 0 { - return pointer{p.unsafePointer.structField(index, offset, size), nil, size} - } - sf := p.t.Field(index) - t := sf.Type - assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) - assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) - return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} -} - -func (p pointer) asString() *string { - assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) - return p.unsafePointer.asString() -} - -func (p pointer) asTime() *time.Time { - assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) - return p.unsafePointer.asTime() -} - -func (p pointer) asAddr() *netip.Addr { - assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) - return p.unsafePointer.asAddr() -} - -func (p pointer) asValue(typ reflect.Type) reflect.Value { - assert(p.t == typ, "got %v, want %v", p.t, typ) - return p.unsafePointer.asValue(typ) -} - -func (p pointer) asMemory(size uintptr) []byte { - assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) - return p.unsafePointer.asMemory(size) -} - -func assert(b bool, f string, a ...any) { - if !b { - panic(fmt.Sprintf(f, a...)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package deephash + +import ( + "fmt" + "net/netip" + "reflect" + "time" +) + +// pointer is a typed pointer that performs safety checks for every operation. +type pointer struct { + unsafePointer + t reflect.Type // type of pointed-at value; may be nil + n uintptr // size of valid memory after p +} + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { + assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) + te := v.Type().Elem() + return pointer{unsafePointerOf(v), te, te.Size()} +} + +func (p pointer) pointerElem() pointer { + assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) + te := p.t.Elem() + return pointer{p.unsafePointer.pointerElem(), te, te.Size()} +} + +func (p pointer) sliceLen() int { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + return p.unsafePointer.sliceLen() +} + +func (p pointer) sliceArray() pointer { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + n := p.sliceLen() + assert(n >= 0, "got negative slice length %d", n) + ta := reflect.ArrayOf(n, p.t.Elem()) + return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} +} + +func (p pointer) arrayIndex(index int, size uintptr) pointer { + assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) + assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) + assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) + te := p.t.Elem() + return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} +} + +func (p pointer) structField(index int, offset, size uintptr) pointer { + assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) + assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) + assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) + if index < 0 { + return pointer{p.unsafePointer.structField(index, offset, size), nil, size} + } + sf := p.t.Field(index) + t := sf.Type + assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) + assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) + return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} +} + +func (p pointer) asString() *string { + assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) + return p.unsafePointer.asString() +} + +func (p pointer) asTime() *time.Time { + assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) + return p.unsafePointer.asTime() +} + +func (p pointer) asAddr() *netip.Addr { + assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) + return p.unsafePointer.asAddr() +} + +func (p pointer) asValue(typ reflect.Type) reflect.Value { + assert(p.t == typ, "got %v, want %v", p.t, typ) + return p.unsafePointer.asValue(typ) +} + +func (p pointer) asMemory(size uintptr) []byte { + assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) + return p.unsafePointer.asMemory(size) +} + +func assert(b bool, f string, a ...any) { + if !b { + panic(fmt.Sprintf(f, a...)) + } +} diff --git a/util/deephash/testtype/testtype.go b/util/deephash/testtype/testtype.go index 2df38da8777ff..3c90053d6dfd5 100644 --- a/util/deephash/testtype/testtype.go +++ b/util/deephash/testtype/testtype.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testtype contains types for testing deephash. -package testtype - -import "time" - -type UnexportedAddressableTime struct { - t time.Time -} - -func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { - return &UnexportedAddressableTime{t: t} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testtype contains types for testing deephash. +package testtype + +import "time" + +type UnexportedAddressableTime struct { + t time.Time +} + +func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { + return &UnexportedAddressableTime{t: t} +} diff --git a/util/dirwalk/dirwalk.go b/util/dirwalk/dirwalk.go index a05ee3553ad90..811766892896a 100644 --- a/util/dirwalk/dirwalk.go +++ b/util/dirwalk/dirwalk.go @@ -1,53 +1,53 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dirwalk contains code to walk a directory. -package dirwalk - -import ( - "io" - "io/fs" - "os" - - "go4.org/mem" -) - -var osWalkShallow func(name mem.RO, fn WalkFunc) error - -// WalkFunc is the callback type used with WalkShallow. -// -// The name and de are only valid for the duration of func's call -// and should not be retained. -type WalkFunc func(name mem.RO, de fs.DirEntry) error - -// WalkShallow reads the entries in the named directory and calls fn for each. -// It does not recurse into subdirectories. -// -// If fn returns an error, iteration stops and WalkShallow returns that value. -// -// On Linux, WalkShallow does not allocate, so long as certain methods on the -// WalkFunc's DirEntry are not called which necessarily allocate. -func WalkShallow(dirName mem.RO, fn WalkFunc) error { - if f := osWalkShallow; f != nil { - return f(dirName, fn) - } - of, err := os.Open(dirName.StringCopy()) - if err != nil { - return err - } - defer of.Close() - for { - fis, err := of.ReadDir(100) - for _, de := range fis { - if err := fn(mem.S(de.Name()), de); err != nil { - return err - } - } - if err != nil { - if err == io.EOF { - return nil - } - return err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dirwalk contains code to walk a directory. +package dirwalk + +import ( + "io" + "io/fs" + "os" + + "go4.org/mem" +) + +var osWalkShallow func(name mem.RO, fn WalkFunc) error + +// WalkFunc is the callback type used with WalkShallow. +// +// The name and de are only valid for the duration of func's call +// and should not be retained. +type WalkFunc func(name mem.RO, de fs.DirEntry) error + +// WalkShallow reads the entries in the named directory and calls fn for each. +// It does not recurse into subdirectories. +// +// If fn returns an error, iteration stops and WalkShallow returns that value. +// +// On Linux, WalkShallow does not allocate, so long as certain methods on the +// WalkFunc's DirEntry are not called which necessarily allocate. +func WalkShallow(dirName mem.RO, fn WalkFunc) error { + if f := osWalkShallow; f != nil { + return f(dirName, fn) + } + of, err := os.Open(dirName.StringCopy()) + if err != nil { + return err + } + defer of.Close() + for { + fis, err := of.ReadDir(100) + for _, de := range fis { + if err := fn(mem.S(de.Name()), de); err != nil { + return err + } + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} diff --git a/util/dirwalk/dirwalk_linux.go b/util/dirwalk/dirwalk_linux.go index 7147831452d38..256467ebd8ac5 100644 --- a/util/dirwalk/dirwalk_linux.go +++ b/util/dirwalk/dirwalk_linux.go @@ -1,167 +1,167 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "io/fs" - "os" - "path/filepath" - "sync" - "syscall" - "unsafe" - - "go4.org/mem" - "golang.org/x/sys/unix" -) - -func init() { - osWalkShallow = linuxWalkShallow -} - -var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} - -func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { - const blockSize = 8 << 10 - buf := make([]byte, blockSize) // stack-allocated; doesn't escape - - nameb := mem.Append(buf[:0], dirName) - nameb = append(nameb, 0) - - fd, err := sysOpen(nameb) - if err != nil { - return err - } - defer syscall.Close(fd) - - bufp := 0 // starting read position in buf - nbuf := 0 // end valid data in buf - - de := dirEntPool.Get().(*linuxDirEnt) - defer de.cleanAndPutInPool() - de.root = dirName - - for { - if bufp >= nbuf { - bufp = 0 - nbuf, err = readDirent(fd, buf) - if err != nil { - return err - } - if nbuf <= 0 { - return nil - } - } - consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) - bufp += consumed - if len(name) == 0 || string(name) == "." || string(name) == ".." { - continue - } - de.name = mem.B(name) - if err := fn(de.name, de); err != nil { - return err - } - } -} - -type linuxDirEnt struct { - root mem.RO - d syscall.Dirent - name mem.RO -} - -func (de *linuxDirEnt) cleanAndPutInPool() { - de.root = mem.RO{} - de.name = mem.RO{} - dirEntPool.Put(de) -} - -func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } -func (de *linuxDirEnt) Info() (fs.FileInfo, error) { - return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) -} -func (de *linuxDirEnt) IsDir() bool { - return de.d.Type == syscall.DT_DIR -} -func (de *linuxDirEnt) Type() fs.FileMode { - switch de.d.Type { - case syscall.DT_BLK: - return fs.ModeDevice // shrug - case syscall.DT_CHR: - return fs.ModeCharDevice - case syscall.DT_DIR: - return fs.ModeDir - case syscall.DT_FIFO: - return fs.ModeNamedPipe - case syscall.DT_LNK: - return fs.ModeSymlink - case syscall.DT_REG: - return 0 - case syscall.DT_SOCK: - return fs.ModeSocket - default: - return fs.ModeIrregular // shrug - } -} - -func direntNamlen(dirent *syscall.Dirent) int { - const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) - limit := dirent.Reclen - fixedHdr - const dirNameLen = 256 // sizeof syscall.Dirent.Name - if limit > dirNameLen { - limit = dirNameLen - } - for i := uint16(0); i < limit; i++ { - if dirent.Name[i] == 0 { - return int(i) - } - } - panic("failed to find terminating 0 byte in dirent") -} - -func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { - // golang.org/issue/37269 - copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) - if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { - panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) - } - if len(buf) < int(dirent.Reclen) { - panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) - } - consumed = int(dirent.Reclen) - if dirent.Ino == 0 { // File absent in directory. - return - } - name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) - return -} - -func sysOpen(name []byte) (fd int, err error) { - if len(name) == 0 || name[len(name)-1] != 0 { - return 0, syscall.EINVAL - } - var dirfd int = unix.AT_FDCWD - for { - r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), - uintptr(unsafe.Pointer(&name[0])), 0) - if e1 == 0 { - return int(r0), nil - } - if e1 == syscall.EINTR { - // Since https://golang.org/doc/go1.14#runtime we - // need to loop on EINTR on more places. - continue - } - return 0, syscall.Errno(e1) - } -} - -func readDirent(fd int, buf []byte) (n int, err error) { - for { - nbuf, err := syscall.ReadDirent(fd, buf) - if err != syscall.EINTR { - return nbuf, err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + "syscall" + "unsafe" + + "go4.org/mem" + "golang.org/x/sys/unix" +) + +func init() { + osWalkShallow = linuxWalkShallow +} + +var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} + +func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { + const blockSize = 8 << 10 + buf := make([]byte, blockSize) // stack-allocated; doesn't escape + + nameb := mem.Append(buf[:0], dirName) + nameb = append(nameb, 0) + + fd, err := sysOpen(nameb) + if err != nil { + return err + } + defer syscall.Close(fd) + + bufp := 0 // starting read position in buf + nbuf := 0 // end valid data in buf + + de := dirEntPool.Get().(*linuxDirEnt) + defer de.cleanAndPutInPool() + de.root = dirName + + for { + if bufp >= nbuf { + bufp = 0 + nbuf, err = readDirent(fd, buf) + if err != nil { + return err + } + if nbuf <= 0 { + return nil + } + } + consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) + bufp += consumed + if len(name) == 0 || string(name) == "." || string(name) == ".." { + continue + } + de.name = mem.B(name) + if err := fn(de.name, de); err != nil { + return err + } + } +} + +type linuxDirEnt struct { + root mem.RO + d syscall.Dirent + name mem.RO +} + +func (de *linuxDirEnt) cleanAndPutInPool() { + de.root = mem.RO{} + de.name = mem.RO{} + dirEntPool.Put(de) +} + +func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } +func (de *linuxDirEnt) Info() (fs.FileInfo, error) { + return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) +} +func (de *linuxDirEnt) IsDir() bool { + return de.d.Type == syscall.DT_DIR +} +func (de *linuxDirEnt) Type() fs.FileMode { + switch de.d.Type { + case syscall.DT_BLK: + return fs.ModeDevice // shrug + case syscall.DT_CHR: + return fs.ModeCharDevice + case syscall.DT_DIR: + return fs.ModeDir + case syscall.DT_FIFO: + return fs.ModeNamedPipe + case syscall.DT_LNK: + return fs.ModeSymlink + case syscall.DT_REG: + return 0 + case syscall.DT_SOCK: + return fs.ModeSocket + default: + return fs.ModeIrregular // shrug + } +} + +func direntNamlen(dirent *syscall.Dirent) int { + const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) + limit := dirent.Reclen - fixedHdr + const dirNameLen = 256 // sizeof syscall.Dirent.Name + if limit > dirNameLen { + limit = dirNameLen + } + for i := uint16(0); i < limit; i++ { + if dirent.Name[i] == 0 { + return int(i) + } + } + panic("failed to find terminating 0 byte in dirent") +} + +func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { + // golang.org/issue/37269 + copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) + if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { + panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) + } + if len(buf) < int(dirent.Reclen) { + panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) + } + consumed = int(dirent.Reclen) + if dirent.Ino == 0 { // File absent in directory. + return + } + name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) + return +} + +func sysOpen(name []byte) (fd int, err error) { + if len(name) == 0 || name[len(name)-1] != 0 { + return 0, syscall.EINVAL + } + var dirfd int = unix.AT_FDCWD + for { + r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), + uintptr(unsafe.Pointer(&name[0])), 0) + if e1 == 0 { + return int(r0), nil + } + if e1 == syscall.EINTR { + // Since https://golang.org/doc/go1.14#runtime we + // need to loop on EINTR on more places. + continue + } + return 0, syscall.Errno(e1) + } +} + +func readDirent(fd int, buf []byte) (n int, err error) { + for { + nbuf, err := syscall.ReadDirent(fd, buf) + if err != syscall.EINTR { + return nbuf, err + } + } +} diff --git a/util/dirwalk/dirwalk_test.go b/util/dirwalk/dirwalk_test.go index e2e41f634947e..15ebc13dd404d 100644 --- a/util/dirwalk/dirwalk_test.go +++ b/util/dirwalk/dirwalk_test.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "sort" - "testing" - - "go4.org/mem" - "tailscale.com/tstest" -) - -func TestWalkShallowOSSpecific(t *testing.T) { - if osWalkShallow == nil { - t.Skip("no OS-specific implementation") - } - testWalkShallow(t, false) -} - -func TestWalkShallowPortable(t *testing.T) { - testWalkShallow(t, true) -} - -func testWalkShallow(t *testing.T, portable bool) { - if portable { - tstest.Replace(t, &osWalkShallow, nil) - } - d := t.TempDir() - - t.Run("basics", func(t *testing.T) { - if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { - t.Fatal(err) - } - if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { - t.Fatal(err) - } - - var got []string - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { - var size int64 - if fi, err := de.Info(); err != nil { - t.Errorf("Info stat error on %q: %v", de.Name(), err) - } else if !fi.IsDir() { - size = fi.Size() - } - got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) - return nil - }); err != nil { - t.Fatal(err) - } - sort.Strings(got) - want := []string{ - `"bar" "bar" dir=false type=0 size=2`, - `"baz" "baz" dir=true type=2147483648 size=0`, - `"foo" "foo" dir=false type=0 size=1`, - } - if !reflect.DeepEqual(got, want) { - t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) - } - }) - - t.Run("err_not_exist", func(t *testing.T) { - err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { - return nil - }) - if !os.IsNotExist(err) { - t.Errorf("unexpected error: %v", err) - } - }) - - t.Run("allocs", func(t *testing.T) { - allocs := int(testing.AllocsPerRun(1000, func() { - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { - t.Fatal(err) - } - })) - t.Logf("allocs = %v", allocs) - if !portable && runtime.GOOS == "linux" && allocs != 0 { - t.Errorf("unexpected allocs: got %v, want 0", allocs) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "sort" + "testing" + + "go4.org/mem" + "tailscale.com/tstest" +) + +func TestWalkShallowOSSpecific(t *testing.T) { + if osWalkShallow == nil { + t.Skip("no OS-specific implementation") + } + testWalkShallow(t, false) +} + +func TestWalkShallowPortable(t *testing.T) { + testWalkShallow(t, true) +} + +func testWalkShallow(t *testing.T, portable bool) { + if portable { + tstest.Replace(t, &osWalkShallow, nil) + } + d := t.TempDir() + + t.Run("basics", func(t *testing.T) { + if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { + t.Fatal(err) + } + + var got []string + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { + var size int64 + if fi, err := de.Info(); err != nil { + t.Errorf("Info stat error on %q: %v", de.Name(), err) + } else if !fi.IsDir() { + size = fi.Size() + } + got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) + return nil + }); err != nil { + t.Fatal(err) + } + sort.Strings(got) + want := []string{ + `"bar" "bar" dir=false type=0 size=2`, + `"baz" "baz" dir=true type=2147483648 size=0`, + `"foo" "foo" dir=false type=0 size=1`, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) + } + }) + + t.Run("err_not_exist", func(t *testing.T) { + err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { + return nil + }) + if !os.IsNotExist(err) { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("allocs", func(t *testing.T) { + allocs := int(testing.AllocsPerRun(1000, func() { + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { + t.Fatal(err) + } + })) + t.Logf("allocs = %v", allocs) + if !portable && runtime.GOOS == "linux" && allocs != 0 { + t.Errorf("unexpected allocs: got %v, want 0", allocs) + } + }) +} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 24c61b37cd399..9758b07586613 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,93 +1,93 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The goroutines package contains utilities for getting active goroutines. -package goroutines - -import ( - "bytes" - "fmt" - "runtime" - "strconv" -) - -// ScrubbedGoroutineDump returns either the current goroutine's stack or all -// goroutines' stacks, but with the actual values of arguments scrubbed out, -// lest it contain some private key material. -func ScrubbedGoroutineDump(all bool) []byte { - var buf []byte - // Grab stacks multiple times into increasingly larger buffer sizes - // to minimize the risk that we blow past our iOS memory limit. - for size := 1 << 10; size <= 1<<20; size += 1 << 10 { - buf = make([]byte, size) - buf = buf[:runtime.Stack(buf, all)] - if len(buf) < size { - // It fit. - break - } - } - return scrubHex(buf) -} - -func scrubHex(buf []byte) []byte { - saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) - - foreachHexAddress(buf, func(in []byte) { - if string(in) == "0x0" { - return - } - if v, ok := saw[string(in)]; ok { - for i := range in { - in[i] = '_' - } - copy(in, v) - return - } - inStr := string(in) - u64, err := strconv.ParseUint(string(in[2:]), 16, 64) - for i := range in { - in[i] = '_' - } - if err != nil { - in[0] = '?' - return - } - v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) - saw[inStr] = v - copy(in, v) - }) - return buf -} - -var ohx = []byte("0x") - -// foreachHexAddress calls f with each subslice of b that matches -// regexp `0x[0-9a-f]*`. -func foreachHexAddress(b []byte, f func([]byte)) { - for len(b) > 0 { - i := bytes.Index(b, ohx) - if i == -1 { - return - } - b = b[i:] - hx := hexPrefix(b) - f(hx) - b = b[len(hx):] - } -} - -func hexPrefix(b []byte) []byte { - for i, c := range b { - if i < 2 { - continue - } - if !isHexByte(c) { - return b[:i] - } - } - return b -} - -func isHexByte(b byte) bool { - return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The goroutines package contains utilities for getting active goroutines. +package goroutines + +import ( + "bytes" + "fmt" + "runtime" + "strconv" +) + +// ScrubbedGoroutineDump returns either the current goroutine's stack or all +// goroutines' stacks, but with the actual values of arguments scrubbed out, +// lest it contain some private key material. +func ScrubbedGoroutineDump(all bool) []byte { + var buf []byte + // Grab stacks multiple times into increasingly larger buffer sizes + // to minimize the risk that we blow past our iOS memory limit. + for size := 1 << 10; size <= 1<<20; size += 1 << 10 { + buf = make([]byte, size) + buf = buf[:runtime.Stack(buf, all)] + if len(buf) < size { + // It fit. + break + } + } + return scrubHex(buf) +} + +func scrubHex(buf []byte) []byte { + saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) + + foreachHexAddress(buf, func(in []byte) { + if string(in) == "0x0" { + return + } + if v, ok := saw[string(in)]; ok { + for i := range in { + in[i] = '_' + } + copy(in, v) + return + } + inStr := string(in) + u64, err := strconv.ParseUint(string(in[2:]), 16, 64) + for i := range in { + in[i] = '_' + } + if err != nil { + in[0] = '?' + return + } + v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) + saw[inStr] = v + copy(in, v) + }) + return buf +} + +var ohx = []byte("0x") + +// foreachHexAddress calls f with each subslice of b that matches +// regexp `0x[0-9a-f]*`. +func foreachHexAddress(b []byte, f func([]byte)) { + for len(b) > 0 { + i := bytes.Index(b, ohx) + if i == -1 { + return + } + b = b[i:] + hx := hexPrefix(b) + f(hx) + b = b[len(hx):] + } +} + +func hexPrefix(b []byte) []byte { + for i, c := range b { + if i < 2 { + continue + } + if !isHexByte(c) { + return b[:i] + } + } + return b +} + +func isHexByte(b byte) bool { + return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' +} diff --git a/util/goroutines/goroutines_test.go b/util/goroutines/goroutines_test.go index df6560fe5e20b..ae17c399ca274 100644 --- a/util/goroutines/goroutines_test.go +++ b/util/goroutines/goroutines_test.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package goroutines - -import "testing" - -func TestScrubbedGoroutineDump(t *testing.T) { - t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) -} - -func TestScrubHex(t *testing.T) { - tests := []struct { - in, want string - }{ - {"foo", "foo"}, - {"", ""}, - {"0x", "?_"}, - {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, - {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, - {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, - } - for _, tt := range tests { - got := scrubHex([]byte(tt.in)) - if string(got) != tt.want { - t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import "testing" + +func TestScrubbedGoroutineDump(t *testing.T) { + t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) +} + +func TestScrubHex(t *testing.T) { + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"", ""}, + {"0x", "?_"}, + {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, + {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, + {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, + } + for _, tt := range tests { + got := scrubHex([]byte(tt.in)) + if string(got) != tt.want { + t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) + } + } +} diff --git a/util/groupmember/groupmember.go b/util/groupmember/groupmember.go index 38431a7ff8791..d604168169022 100644 --- a/util/groupmember/groupmember.go +++ b/util/groupmember/groupmember.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package groupmember verifies group membership of the provided user on the -// local system. -package groupmember - -import ( - "os/user" - "slices" -) - -// IsMemberOfGroup reports whether the provided user is a member of -// the provided system group. -func IsMemberOfGroup(group, userName string) (bool, error) { - u, err := user.Lookup(userName) - if err != nil { - return false, err - } - g, err := user.LookupGroup(group) - if err != nil { - return false, err - } - ugids, err := u.GroupIds() - if err != nil { - return false, err - } - return slices.Contains(ugids, g.Gid), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package groupmember verifies group membership of the provided user on the +// local system. +package groupmember + +import ( + "os/user" + "slices" +) + +// IsMemberOfGroup reports whether the provided user is a member of +// the provided system group. +func IsMemberOfGroup(group, userName string) (bool, error) { + u, err := user.Lookup(userName) + if err != nil { + return false, err + } + g, err := user.LookupGroup(group) + if err != nil { + return false, err + } + ugids, err := u.GroupIds() + if err != nil { + return false, err + } + return slices.Contains(ugids, g.Gid), nil +} diff --git a/util/hashx/block512.go b/util/hashx/block512.go index dd69ccd35637c..e637c0c030653 100644 --- a/util/hashx/block512.go +++ b/util/hashx/block512.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package hashx provides a concrete implementation of [hash.Hash] -// that operates on a particular block size. -package hashx - -import ( - "encoding/binary" - "fmt" - "hash" - "unsafe" -) - -var _ hash.Hash = (*Block512)(nil) - -// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. -// It has efficient methods for hashing fixed-width integers. -// -// A hashing algorithm that operates on 512-bit block sizes should be used. -// The hash still operates correctly even with misaligned block sizes, -// but operates less efficiently. -// -// Example algorithms with 512-bit block sizes include: -// - MD4 (https://golang.org/x/crypto/md4) -// - MD5 (https://golang.org/pkg/crypto/md5) -// - BLAKE2s (https://golang.org/x/crypto/blake2s) -// - BLAKE3 -// - RIPEMD (https://golang.org/x/crypto/ripemd160) -// - SHA-0 -// - SHA-1 (https://golang.org/pkg/crypto/sha1) -// - SHA-2 (https://golang.org/pkg/crypto/sha256) -// - Whirlpool -// -// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters -// for a list of hash functions and their block sizes. -// -// Block512 assumes that [hash.Hash.Write] never fails and -// never allows the provided buffer to escape. -type Block512 struct { - hash.Hash - - x [512 / 8]byte - nx int -} - -// New512 constructs a new Block512 that wraps h. -// -// It reports an error if the block sizes do not match. -// Misaligned block sizes perform poorly, but execute correctly. -// The error may be ignored if performance is not a concern. -func New512(h hash.Hash) (*Block512, error) { - b := &Block512{Hash: h} - if len(b.x)%h.BlockSize() != 0 { - return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) - } - return b, nil -} - -// Write hashes the contents of b. -func (h *Block512) Write(b []byte) (int, error) { - h.HashBytes(b) - return len(b), nil -} - -// Sum appends the current hash to b and returns the resulting slice. -// -// It flushes any partially completed blocks to the underlying [hash.Hash], -// which may cause future operations to be misaligned and less efficient -// until [Block512.Reset] is called. -func (h *Block512) Sum(b []byte) []byte { - if h.nx > 0 { - h.Hash.Write(h.x[:h.nx]) - h.nx = 0 - } - - // Unfortunately hash.Hash.Sum always causes the input to escape since - // escape analysis cannot prove anything past an interface method call. - // Assuming h already escapes, we call Sum with h.x first, - // and then copy the result to b. - sum := h.Hash.Sum(h.x[:0]) - return append(b, sum...) -} - -// Reset resets Block512 to its initial state. -// It recursively resets the underlying [hash.Hash]. -func (h *Block512) Reset() { - h.Hash.Reset() - h.nx = 0 -} - -// HashUint8 hashes n as a 1-byte integer. -func (h *Block512) HashUint8(n uint8) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-1 { - h.x[h.nx] = n - h.nx += 1 - } else { - h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } - -// HashUint16 hashes n as a 2-byte little-endian integer. -func (h *Block512) HashUint16(n uint16) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-2 { - binary.LittleEndian.PutUint16(h.x[h.nx:], n) - h.nx += 2 - } else { - h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } - -// HashUint32 hashes n as a 4-byte little-endian integer. -func (h *Block512) HashUint32(n uint32) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-4 { - binary.LittleEndian.PutUint32(h.x[h.nx:], n) - h.nx += 4 - } else { - h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } - -// HashUint64 hashes n as a 8-byte little-endian integer. -func (h *Block512) HashUint64(n uint64) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-8 { - binary.LittleEndian.PutUint64(h.x[h.nx:], n) - h.nx += 8 - } else { - h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } - -func (h *Block512) hashUint(n uint64, i int) { - for ; i > 0; i-- { - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - h.x[h.nx] = byte(n) - h.nx += 1 - n >>= 8 - } -} - -// HashBytes hashes the contents of b. -// It does not explicitly hash the length separately. -func (h *Block512) HashBytes(b []byte) { - // Nearly identical to sha256.digest.Write. - if h.nx > 0 { - n := copy(h.x[h.nx:], b) - h.nx += n - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - b = b[n:] - } - if len(b) >= len(h.x) { - n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) - h.Hash.Write(b[:n]) - b = b[n:] - } - if len(b) > 0 { - h.nx = copy(h.x[:], b) - } -} - -// HashString hashes the contents of s. -// It does not explicitly hash the length separately. -func (h *Block512) HashString(s string) { - // TODO: Avoid unsafe when standard hashers implement io.StringWriter. - // See https://go.dev/issue/38776. - type stringHeader struct { - p unsafe.Pointer - n int - } - p := (*stringHeader)(unsafe.Pointer(&s)) - b := unsafe.Slice((*byte)(p.p), p.n) - h.HashBytes(b) -} - -// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package hashx provides a concrete implementation of [hash.Hash] +// that operates on a particular block size. +package hashx + +import ( + "encoding/binary" + "fmt" + "hash" + "unsafe" +) + +var _ hash.Hash = (*Block512)(nil) + +// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. +// It has efficient methods for hashing fixed-width integers. +// +// A hashing algorithm that operates on 512-bit block sizes should be used. +// The hash still operates correctly even with misaligned block sizes, +// but operates less efficiently. +// +// Example algorithms with 512-bit block sizes include: +// - MD4 (https://golang.org/x/crypto/md4) +// - MD5 (https://golang.org/pkg/crypto/md5) +// - BLAKE2s (https://golang.org/x/crypto/blake2s) +// - BLAKE3 +// - RIPEMD (https://golang.org/x/crypto/ripemd160) +// - SHA-0 +// - SHA-1 (https://golang.org/pkg/crypto/sha1) +// - SHA-2 (https://golang.org/pkg/crypto/sha256) +// - Whirlpool +// +// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters +// for a list of hash functions and their block sizes. +// +// Block512 assumes that [hash.Hash.Write] never fails and +// never allows the provided buffer to escape. +type Block512 struct { + hash.Hash + + x [512 / 8]byte + nx int +} + +// New512 constructs a new Block512 that wraps h. +// +// It reports an error if the block sizes do not match. +// Misaligned block sizes perform poorly, but execute correctly. +// The error may be ignored if performance is not a concern. +func New512(h hash.Hash) (*Block512, error) { + b := &Block512{Hash: h} + if len(b.x)%h.BlockSize() != 0 { + return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) + } + return b, nil +} + +// Write hashes the contents of b. +func (h *Block512) Write(b []byte) (int, error) { + h.HashBytes(b) + return len(b), nil +} + +// Sum appends the current hash to b and returns the resulting slice. +// +// It flushes any partially completed blocks to the underlying [hash.Hash], +// which may cause future operations to be misaligned and less efficient +// until [Block512.Reset] is called. +func (h *Block512) Sum(b []byte) []byte { + if h.nx > 0 { + h.Hash.Write(h.x[:h.nx]) + h.nx = 0 + } + + // Unfortunately hash.Hash.Sum always causes the input to escape since + // escape analysis cannot prove anything past an interface method call. + // Assuming h already escapes, we call Sum with h.x first, + // and then copy the result to b. + sum := h.Hash.Sum(h.x[:0]) + return append(b, sum...) +} + +// Reset resets Block512 to its initial state. +// It recursively resets the underlying [hash.Hash]. +func (h *Block512) Reset() { + h.Hash.Reset() + h.nx = 0 +} + +// HashUint8 hashes n as a 1-byte integer. +func (h *Block512) HashUint8(n uint8) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-1 { + h.x[h.nx] = n + h.nx += 1 + } else { + h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } + +// HashUint16 hashes n as a 2-byte little-endian integer. +func (h *Block512) HashUint16(n uint16) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-2 { + binary.LittleEndian.PutUint16(h.x[h.nx:], n) + h.nx += 2 + } else { + h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } + +// HashUint32 hashes n as a 4-byte little-endian integer. +func (h *Block512) HashUint32(n uint32) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-4 { + binary.LittleEndian.PutUint32(h.x[h.nx:], n) + h.nx += 4 + } else { + h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } + +// HashUint64 hashes n as a 8-byte little-endian integer. +func (h *Block512) HashUint64(n uint64) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-8 { + binary.LittleEndian.PutUint64(h.x[h.nx:], n) + h.nx += 8 + } else { + h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } + +func (h *Block512) hashUint(n uint64, i int) { + for ; i > 0; i-- { + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + h.x[h.nx] = byte(n) + h.nx += 1 + n >>= 8 + } +} + +// HashBytes hashes the contents of b. +// It does not explicitly hash the length separately. +func (h *Block512) HashBytes(b []byte) { + // Nearly identical to sha256.digest.Write. + if h.nx > 0 { + n := copy(h.x[h.nx:], b) + h.nx += n + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + b = b[n:] + } + if len(b) >= len(h.x) { + n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) + h.Hash.Write(b[:n]) + b = b[n:] + } + if len(b) > 0 { + h.nx = copy(h.x[:], b) + } +} + +// HashString hashes the contents of s. +// It does not explicitly hash the length separately. +func (h *Block512) HashString(s string) { + // TODO: Avoid unsafe when standard hashers implement io.StringWriter. + // See https://go.dev/issue/38776. + type stringHeader struct { + p unsafe.Pointer + n int + } + p := (*stringHeader)(unsafe.Pointer(&s)) + b := unsafe.Slice((*byte)(p.p), p.n) + h.HashBytes(b) +} + +// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? diff --git a/util/httphdr/httphdr.go b/util/httphdr/httphdr.go index b78b165c65701..852e28b8fae03 100644 --- a/util/httphdr/httphdr.go +++ b/util/httphdr/httphdr.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httphdr implements functionality for parsing and formatting -// standard HTTP headers. -package httphdr - -import ( - "bytes" - "strconv" - "strings" -) - -// Range is a range of bytes within some content. -type Range struct { - // Start is the starting offset. - // It is zero if Length is negative; it must not be negative. - Start int64 - // Length is the length of the content. - // It is zero if the length extends to the end of the content. - // It is negative if the length is relative to the end (e.g., last 5 bytes). - Length int64 -} - -// ows is optional whitespace. -const ows = " \t" // per RFC 7230, section 3.2.3 - -// ParseRange parses a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func ParseRange(hdr string) (ranges []Range, ok bool) { - // Grammar per RFC 7233, appendix D: - // Range = byte-ranges-specifier | other-ranges-specifier - // byte-ranges-specifier = bytes-unit "=" byte-range-set - // bytes-unit = "bytes" - // byte-range-set = - // *("," OWS) - // (byte-range-spec | suffix-byte-range-spec) - // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) - // byte-range-spec = first-byte-pos "-" [last-byte-pos] - // suffix-byte-range-spec = "-" suffix-length - // We do not support other-ranges-specifier. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - units, elems, hasUnits := strings.Cut(hdr, "=") - elems = strings.TrimLeft(elems, ","+ows) - for _, elem := range strings.Split(elems, ",") { - elem = strings.Trim(elem, ows) // per RFC 7230, section 7 - switch { - case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length - n, ok := parseNumber(strings.TrimPrefix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{0, -n}) - case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" - n, ok := parseNumber(strings.TrimSuffix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{n, 0}) - default: // i.e., first-byte-pos "-" last-byte-pos - prefix, suffix, hasDash := strings.Cut(elem, "-") - n, ok2 := parseNumber(prefix) - m, ok3 := parseNumber(suffix) - if !hasDash || !ok2 || !ok3 || m < n { - return ranges, false - } - ranges = append(ranges, Range{n, m - n + 1}) - } - } - return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 -} - -// FormatRange formats a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func FormatRange(ranges []Range) (hdr string, ok bool) { - b := []byte("bytes=") - for _, r := range ranges { - switch { - case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) - b = append(b, ',') - case r.Length == 0: // i.e., first-byte-pos "-" - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = append(b, ',') - case r.Length < 0: // i.e., "-" suffix-length - if r.Start != 0 { - return string(b), false - } - b = append(b, '-') - b = strconv.AppendUint(b, uint64(-r.Length), 10) - b = append(b, ',') - default: - return string(b), false - } - } - return string(bytes.TrimRight(b, ",")), len(ranges) > 0 -} - -// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If only the completeLength is specified, then start and length are both zero. -// -// Otherwise, the parses the start and length and the optional completeLength, -// which is -1 if unspecified. The start is non-negative and the length is positive. -func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { - // Grammar per RFC 7233, appendix D: - // Content-Range = byte-content-range | other-content-range - // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) - // bytes-unit = "bytes" - // byte-range-resp = byte-range "/" (complete-length | "*") - // unsatisfied-range = "*/" complete-length - // byte-range = first-byte-pos "-" last-byte-pos - // We do not support other-content-range. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") - suffix, unsatisfied := strings.CutPrefix(suffix, "*/") - if unsatisfied { // i.e., unsatisfied-range - n, ok := parseNumber(suffix) - if !ok { - return start, length, completeLength, false - } - completeLength = n - } else { // i.e., byte-range "/" (complete-length | "*") - prefix, suffix, hasDash := strings.Cut(suffix, "-") - middle, suffix, hasSlash := strings.Cut(suffix, "/") - n, ok0 := parseNumber(prefix) - m, ok1 := parseNumber(middle) - o, ok2 := parseNumber(suffix) - if suffix == "*" { - o, ok2 = -1, true - } - if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { - return start, length, completeLength, false - } - start = n - length = m - n + 1 - completeLength = o - } - return start, length, completeLength, hasUnits -} - -// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If start and length are non-positive, then it encodes just the completeLength, -// which must be a non-negative value. -// -// Otherwise, it encodes the start and length as a byte-range, -// and optionally emits the complete length if it is non-negative. -// The length must be positive (as RFC 7233 uses inclusive end offsets). -func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { - b := []byte("bytes ") - switch { - case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range - b = append(b, "*/"...) - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = true - case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") - b = strconv.AppendUint(b, uint64(start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(start+length-1), 10) - b = append(b, '/') - if completeLength >= 0 { - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = completeLength >= start+length && start+length > 0 - } else { - b = append(b, '*') - ok = true - } - } - return string(b), ok -} - -// parseNumber parses s as an unsigned decimal integer. -// It parses according to the 1*DIGIT grammar, which allows leading zeros. -func parseNumber(s string) (int64, bool) { - suffix := strings.TrimLeft(s, "0123456789") - prefix := s[:len(s)-len(suffix)] - n, err := strconv.ParseInt(prefix, 10, 64) - return n, suffix == "" && err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httphdr implements functionality for parsing and formatting +// standard HTTP headers. +package httphdr + +import ( + "bytes" + "strconv" + "strings" +) + +// Range is a range of bytes within some content. +type Range struct { + // Start is the starting offset. + // It is zero if Length is negative; it must not be negative. + Start int64 + // Length is the length of the content. + // It is zero if the length extends to the end of the content. + // It is negative if the length is relative to the end (e.g., last 5 bytes). + Length int64 +} + +// ows is optional whitespace. +const ows = " \t" // per RFC 7230, section 3.2.3 + +// ParseRange parses a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func ParseRange(hdr string) (ranges []Range, ok bool) { + // Grammar per RFC 7233, appendix D: + // Range = byte-ranges-specifier | other-ranges-specifier + // byte-ranges-specifier = bytes-unit "=" byte-range-set + // bytes-unit = "bytes" + // byte-range-set = + // *("," OWS) + // (byte-range-spec | suffix-byte-range-spec) + // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) + // byte-range-spec = first-byte-pos "-" [last-byte-pos] + // suffix-byte-range-spec = "-" suffix-length + // We do not support other-ranges-specifier. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + units, elems, hasUnits := strings.Cut(hdr, "=") + elems = strings.TrimLeft(elems, ","+ows) + for _, elem := range strings.Split(elems, ",") { + elem = strings.Trim(elem, ows) // per RFC 7230, section 7 + switch { + case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length + n, ok := parseNumber(strings.TrimPrefix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{0, -n}) + case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" + n, ok := parseNumber(strings.TrimSuffix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{n, 0}) + default: // i.e., first-byte-pos "-" last-byte-pos + prefix, suffix, hasDash := strings.Cut(elem, "-") + n, ok2 := parseNumber(prefix) + m, ok3 := parseNumber(suffix) + if !hasDash || !ok2 || !ok3 || m < n { + return ranges, false + } + ranges = append(ranges, Range{n, m - n + 1}) + } + } + return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 +} + +// FormatRange formats a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func FormatRange(ranges []Range) (hdr string, ok bool) { + b := []byte("bytes=") + for _, r := range ranges { + switch { + case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) + b = append(b, ',') + case r.Length == 0: // i.e., first-byte-pos "-" + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = append(b, ',') + case r.Length < 0: // i.e., "-" suffix-length + if r.Start != 0 { + return string(b), false + } + b = append(b, '-') + b = strconv.AppendUint(b, uint64(-r.Length), 10) + b = append(b, ',') + default: + return string(b), false + } + } + return string(bytes.TrimRight(b, ",")), len(ranges) > 0 +} + +// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If only the completeLength is specified, then start and length are both zero. +// +// Otherwise, the parses the start and length and the optional completeLength, +// which is -1 if unspecified. The start is non-negative and the length is positive. +func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { + // Grammar per RFC 7233, appendix D: + // Content-Range = byte-content-range | other-content-range + // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) + // bytes-unit = "bytes" + // byte-range-resp = byte-range "/" (complete-length | "*") + // unsatisfied-range = "*/" complete-length + // byte-range = first-byte-pos "-" last-byte-pos + // We do not support other-content-range. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") + suffix, unsatisfied := strings.CutPrefix(suffix, "*/") + if unsatisfied { // i.e., unsatisfied-range + n, ok := parseNumber(suffix) + if !ok { + return start, length, completeLength, false + } + completeLength = n + } else { // i.e., byte-range "/" (complete-length | "*") + prefix, suffix, hasDash := strings.Cut(suffix, "-") + middle, suffix, hasSlash := strings.Cut(suffix, "/") + n, ok0 := parseNumber(prefix) + m, ok1 := parseNumber(middle) + o, ok2 := parseNumber(suffix) + if suffix == "*" { + o, ok2 = -1, true + } + if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { + return start, length, completeLength, false + } + start = n + length = m - n + 1 + completeLength = o + } + return start, length, completeLength, hasUnits +} + +// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If start and length are non-positive, then it encodes just the completeLength, +// which must be a non-negative value. +// +// Otherwise, it encodes the start and length as a byte-range, +// and optionally emits the complete length if it is non-negative. +// The length must be positive (as RFC 7233 uses inclusive end offsets). +func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { + b := []byte("bytes ") + switch { + case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range + b = append(b, "*/"...) + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = true + case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") + b = strconv.AppendUint(b, uint64(start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(start+length-1), 10) + b = append(b, '/') + if completeLength >= 0 { + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = completeLength >= start+length && start+length > 0 + } else { + b = append(b, '*') + ok = true + } + } + return string(b), ok +} + +// parseNumber parses s as an unsigned decimal integer. +// It parses according to the 1*DIGIT grammar, which allows leading zeros. +func parseNumber(s string) (int64, bool) { + suffix := strings.TrimLeft(s, "0123456789") + prefix := s[:len(s)-len(suffix)] + n, err := strconv.ParseInt(prefix, 10, 64) + return n, suffix == "" && err == nil +} diff --git a/util/httphdr/httphdr_test.go b/util/httphdr/httphdr_test.go index 77ec0c3247d3e..81feeaca080d8 100644 --- a/util/httphdr/httphdr_test.go +++ b/util/httphdr/httphdr_test.go @@ -1,96 +1,96 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httphdr - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func valOk[T any](v T, ok bool) (out struct { - V T - Ok bool -}) { - out.V = v - out.Ok = ok - return out -} - -func TestRange(t *testing.T) { - tests := []struct { - in string - want []Range - wantOk bool - roundtrip bool - }{ - {"", nil, false, false}, - {"1-3", nil, false, false}, - {"units=1-3", []Range{{1, 3}}, false, false}, - {"bytes=1-3", []Range{{1, 3}}, true, true}, - {"bytes=#-3", nil, false, false}, - {"bytes=#-", nil, false, false}, - {"bytes=13", nil, false, false}, - {"bytes=1-#", nil, false, false}, - {"bytes=-#", nil, false, false}, - {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, - {"bytes=1-1", []Range{{1, 1}}, true, true}, - {"bytes=01-01", []Range{{1, 1}}, true, false}, - {"bytes=1-0", nil, false, false}, - {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, - {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, - {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, - } - - for _, tt := range tests { - got, gotOk := ParseRange(tt.in) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) - } - if tt.roundtrip { - got, gotOk := FormatRange(tt.want) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) - } - } - } -} - -type contentRange struct{ Start, Length, CompleteLength int64 } - -func TestContentRange(t *testing.T) { - tests := []struct { - in string - want contentRange - wantOk bool - roundtrip bool - }{ - {"", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, - {"units 5-6/*", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{}, false, false}, - {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, - {"bytes 5-4/*", contentRange{}, false, false}, - {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, - {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, - {"bytes 5-5/5", contentRange{}, false, false}, - {"bytes #-5/6", contentRange{}, false, false}, - {"bytes 5-#/6", contentRange{}, false, false}, - {"bytes 5-5/#", contentRange{}, false, false}, - } - - for _, tt := range tests { - start, length, completeLength, gotOk := ParseContentRange(tt.in) - got := contentRange{start, length, completeLength} - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) - } - if tt.roundtrip { - got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httphdr + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func valOk[T any](v T, ok bool) (out struct { + V T + Ok bool +}) { + out.V = v + out.Ok = ok + return out +} + +func TestRange(t *testing.T) { + tests := []struct { + in string + want []Range + wantOk bool + roundtrip bool + }{ + {"", nil, false, false}, + {"1-3", nil, false, false}, + {"units=1-3", []Range{{1, 3}}, false, false}, + {"bytes=1-3", []Range{{1, 3}}, true, true}, + {"bytes=#-3", nil, false, false}, + {"bytes=#-", nil, false, false}, + {"bytes=13", nil, false, false}, + {"bytes=1-#", nil, false, false}, + {"bytes=-#", nil, false, false}, + {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, + {"bytes=1-1", []Range{{1, 1}}, true, true}, + {"bytes=01-01", []Range{{1, 1}}, true, false}, + {"bytes=1-0", nil, false, false}, + {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, + {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, + {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, + } + + for _, tt := range tests { + got, gotOk := ParseRange(tt.in) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) + } + if tt.roundtrip { + got, gotOk := FormatRange(tt.want) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) + } + } + } +} + +type contentRange struct{ Start, Length, CompleteLength int64 } + +func TestContentRange(t *testing.T) { + tests := []struct { + in string + want contentRange + wantOk bool + roundtrip bool + }{ + {"", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, + {"units 5-6/*", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{}, false, false}, + {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, + {"bytes 5-4/*", contentRange{}, false, false}, + {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, + {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, + {"bytes 5-5/5", contentRange{}, false, false}, + {"bytes #-5/6", contentRange{}, false, false}, + {"bytes 5-#/6", contentRange{}, false, false}, + {"bytes 5-5/#", contentRange{}, false, false}, + } + + for _, tt := range tests { + start, length, completeLength, gotOk := ParseContentRange(tt.in) + got := contentRange{start, length, completeLength} + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) + } + if tt.roundtrip { + got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) + } + } + } +} diff --git a/util/httpm/httpm.go b/util/httpm/httpm.go index 05292f0fa1fa2..a9a691b8a69e2 100644 --- a/util/httpm/httpm.go +++ b/util/httpm/httpm.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httpm has shorter names for HTTP method constants. -// -// Some background: originally Go didn't have http.MethodGet, http.MethodPost -// and life was good and people just wrote readable "GET" and "POST". But then -// in a moment of weakness Brad and others maintaining net/http caved and let -// the http.MethodFoo constants be added and code's been less readable since. -// Now the substance of the method name is hidden away at the end after -// "http.Method" and they all blend together and it's hard to read code using -// them. -// -// This package is a compromise. It provides constants, but shorter and closer -// to how it used to look. It does violate Go style -// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says -// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and -// POST, which are already defined as all caps. -// -// It would be tempting to make these constants be typed but then they wouldn't -// be assignable to things in net/http that just want string. Oh well. -package httpm - -const ( - GET = "GET" - HEAD = "HEAD" - POST = "POST" - PUT = "PUT" - PATCH = "PATCH" - DELETE = "DELETE" - CONNECT = "CONNECT" - OPTIONS = "OPTIONS" - TRACE = "TRACE" - SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html - BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httpm has shorter names for HTTP method constants. +// +// Some background: originally Go didn't have http.MethodGet, http.MethodPost +// and life was good and people just wrote readable "GET" and "POST". But then +// in a moment of weakness Brad and others maintaining net/http caved and let +// the http.MethodFoo constants be added and code's been less readable since. +// Now the substance of the method name is hidden away at the end after +// "http.Method" and they all blend together and it's hard to read code using +// them. +// +// This package is a compromise. It provides constants, but shorter and closer +// to how it used to look. It does violate Go style +// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says +// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and +// POST, which are already defined as all caps. +// +// It would be tempting to make these constants be typed but then they wouldn't +// be assignable to things in net/http that just want string. Oh well. +package httpm + +const ( + GET = "GET" + HEAD = "HEAD" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + CONNECT = "CONNECT" + OPTIONS = "OPTIONS" + TRACE = "TRACE" + SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html + BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 +) diff --git a/util/httpm/httpm_test.go b/util/httpm/httpm_test.go index cbe327d956083..0c71edc2f3c42 100644 --- a/util/httpm/httpm_test.go +++ b/util/httpm/httpm_test.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httpm - -import ( - "os" - "os/exec" - "path/filepath" - "strings" - "testing" -) - -func TestUsedConsistently(t *testing.T) { - dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - rootDir := filepath.Join(dir, "../..") - - // If we don't have a .git directory, we're not in a git checkout (e.g. - // a downstream package); skip this test. - if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { - t.Skipf("skipping test since .git doesn't exist: %v", err) - } - - cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") - cmd.Dir = rootDir - matches, _ := cmd.Output() - for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { - switch fn { - case "util/httpm/httpm.go", "util/httpm/httpm_test.go": - continue - } - t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httpm + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestUsedConsistently(t *testing.T) { + dir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + rootDir := filepath.Join(dir, "../..") + + // If we don't have a .git directory, we're not in a git checkout (e.g. + // a downstream package); skip this test. + if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { + t.Skipf("skipping test since .git doesn't exist: %v", err) + } + + cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") + cmd.Dir = rootDir + matches, _ := cmd.Output() + for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { + switch fn { + case "util/httpm/httpm.go", "util/httpm/httpm_test.go": + continue + } + t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) + } +} diff --git a/util/jsonutil/types.go b/util/jsonutil/types.go index 2ee53f44a1037..057473249f258 100644 --- a/util/jsonutil/types.go +++ b/util/jsonutil/types.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsonutil - -// Bytes is a byte slice in a json-encoded struct. -// encoding/json assumes that []byte fields are hex-encoded. -// Bytes are not hex-encoded; they are treated the same as strings. -// This can avoid unnecessary allocations due to a round trip through strings. -type Bytes []byte - -func (b *Bytes) UnmarshalText(text []byte) error { - // Copy the contexts of text. - *b = append(*b, text...) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonutil + +// Bytes is a byte slice in a json-encoded struct. +// encoding/json assumes that []byte fields are hex-encoded. +// Bytes are not hex-encoded; they are treated the same as strings. +// This can avoid unnecessary allocations due to a round trip through strings. +type Bytes []byte + +func (b *Bytes) UnmarshalText(text []byte) error { + // Copy the contexts of text. + *b = append(*b, text...) + return nil +} diff --git a/util/jsonutil/unmarshal.go b/util/jsonutil/unmarshal.go index 13aea0c87ff30..b1eb4ea873e67 100644 --- a/util/jsonutil/unmarshal.go +++ b/util/jsonutil/unmarshal.go @@ -1,89 +1,89 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsonutil provides utilities to improve JSON performance. -// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs -// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. -package jsonutil - -import ( - "bytes" - "encoding/json" - "sync" -) - -// decoder is a re-usable json decoder. -type decoder struct { - dec *json.Decoder - r *bytes.Reader -} - -var readerPool = sync.Pool{ - New: func() any { - return bytes.NewReader(nil) - }, -} - -var decoderPool = sync.Pool{ - New: func() any { - var d decoder - d.r = readerPool.Get().(*bytes.Reader) - d.dec = json.NewDecoder(d.r) - return &d - }, -} - -// Unmarshal is similar to encoding/json.Unmarshal. -// There are three major differences: -// -// On error, encoding/json.Unmarshal zeros v. -// This Unmarshal may leave partial data in v. -// Always check the error before using v! -// (Future improvements may remove this bug.) -// -// The errors they return don't always match perfectly. -// If you do error matching more precise than err != nil, -// don't use this Unmarshal. -// -// This Unmarshal allocates considerably less memory. -func Unmarshal(b []byte, v any) error { - d := decoderPool.Get().(*decoder) - d.r.Reset(b) - off := d.dec.InputOffset() - err := d.dec.Decode(v) - d.r.Reset(nil) // don't keep a reference to b - // In case of error, report the offset in this byte slice, - // instead of in the totality of all bytes this decoder has processed. - // It is not possible to make all errors match json.Unmarshal exactly, - // but we can at least try. - switch jsonerr := err.(type) { - case *json.SyntaxError: - jsonerr.Offset -= off - case *json.UnmarshalTypeError: - jsonerr.Offset -= off - case nil: - // json.Unmarshal fails if there's any extra junk in the input. - // json.Decoder does not; see https://github.com/golang/go/issues/36225. - // We need to check for anything left over in the buffer. - if d.dec.More() { - // TODO: Provide a better error message. - // Unfortunately, we can't set the msg field. - // The offset doesn't perfectly match json: - // Ours is at the end of the valid data, - // and theirs is at the beginning of the extra data after whitespace. - // Close enough, though. - err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} - - // TODO: zero v. This is hard; see encoding/json.indirect. - } - } - if err == nil { - decoderPool.Put(d) - } else { - // There might be junk left in the decoder's buffer. - // There's no way to flush it, no Reset method. - // Abandoned the decoder but reuse the reader. - readerPool.Put(d.r) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonutil provides utilities to improve JSON performance. +// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs +// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. +package jsonutil + +import ( + "bytes" + "encoding/json" + "sync" +) + +// decoder is a re-usable json decoder. +type decoder struct { + dec *json.Decoder + r *bytes.Reader +} + +var readerPool = sync.Pool{ + New: func() any { + return bytes.NewReader(nil) + }, +} + +var decoderPool = sync.Pool{ + New: func() any { + var d decoder + d.r = readerPool.Get().(*bytes.Reader) + d.dec = json.NewDecoder(d.r) + return &d + }, +} + +// Unmarshal is similar to encoding/json.Unmarshal. +// There are three major differences: +// +// On error, encoding/json.Unmarshal zeros v. +// This Unmarshal may leave partial data in v. +// Always check the error before using v! +// (Future improvements may remove this bug.) +// +// The errors they return don't always match perfectly. +// If you do error matching more precise than err != nil, +// don't use this Unmarshal. +// +// This Unmarshal allocates considerably less memory. +func Unmarshal(b []byte, v any) error { + d := decoderPool.Get().(*decoder) + d.r.Reset(b) + off := d.dec.InputOffset() + err := d.dec.Decode(v) + d.r.Reset(nil) // don't keep a reference to b + // In case of error, report the offset in this byte slice, + // instead of in the totality of all bytes this decoder has processed. + // It is not possible to make all errors match json.Unmarshal exactly, + // but we can at least try. + switch jsonerr := err.(type) { + case *json.SyntaxError: + jsonerr.Offset -= off + case *json.UnmarshalTypeError: + jsonerr.Offset -= off + case nil: + // json.Unmarshal fails if there's any extra junk in the input. + // json.Decoder does not; see https://github.com/golang/go/issues/36225. + // We need to check for anything left over in the buffer. + if d.dec.More() { + // TODO: Provide a better error message. + // Unfortunately, we can't set the msg field. + // The offset doesn't perfectly match json: + // Ours is at the end of the valid data, + // and theirs is at the beginning of the extra data after whitespace. + // Close enough, though. + err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} + + // TODO: zero v. This is hard; see encoding/json.indirect. + } + } + if err == nil { + decoderPool.Put(d) + } else { + // There might be junk left in the decoder's buffer. + // There's no way to flush it, no Reset method. + // Abandoned the decoder but reuse the reader. + readerPool.Put(d.r) + } + return err +} diff --git a/util/lineread/lineread.go b/util/lineread/lineread.go index 2a7486e0a4fec..6b01d2b69ffd7 100644 --- a/util/lineread/lineread.go +++ b/util/lineread/lineread.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package lineread reads lines from files. It's not fancy, but it got repetitive. -package lineread - -import ( - "bufio" - "io" - "os" -) - -// File opens name and calls fn for each line. It returns an error if the Open failed -// or once fn returns an error. -func File(name string, fn func(line []byte) error) error { - f, err := os.Open(name) - if err != nil { - return err - } - defer f.Close() - return Reader(f, fn) -} - -// Reader calls fn for each line. -// If fn returns an error, Reader stops reading and returns that error. -// Reader may also return errors encountered reading and parsing from r. -// To stop reading early, use a sentinel "stop" error value and ignore -// it when returned from Reader. -func Reader(r io.Reader, fn func(line []byte) error) error { - bs := bufio.NewScanner(r) - for bs.Scan() { - if err := fn(bs.Bytes()); err != nil { - return err - } - } - return bs.Err() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineread reads lines from files. It's not fancy, but it got repetitive. +package lineread + +import ( + "bufio" + "io" + "os" +) + +// File opens name and calls fn for each line. It returns an error if the Open failed +// or once fn returns an error. +func File(name string, fn func(line []byte) error) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + return Reader(f, fn) +} + +// Reader calls fn for each line. +// If fn returns an error, Reader stops reading and returns that error. +// Reader may also return errors encountered reading and parsing from r. +// To stop reading early, use a sentinel "stop" error value and ignore +// it when returned from Reader. +func Reader(r io.Reader, fn func(line []byte) error) error { + bs := bufio.NewScanner(r) + for bs.Scan() { + if err := fn(bs.Bytes()); err != nil { + return err + } + } + return bs.Err() +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go index 04f179199fb6b..ee2cbd1b227f4 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && linux - -// Package linuxfwtest contains tests for the linuxfw package. Go does not -// support cgo in tests, and we don't want the main package to have a cgo -// dependency, so we put all the tests here and call them from the main package -// in tests intead. -package linuxfwtest - -import ( - "testing" - "unsafe" -) - -/* -#include // socket() -*/ -import "C" - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - want := unsafe.Sizeof(C.socklen_t(0)) - if want != si.SizeofSocklen { - t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && linux + +// Package linuxfwtest contains tests for the linuxfw package. Go does not +// support cgo in tests, and we don't want the main package to have a cgo +// dependency, so we put all the tests here and call them from the main package +// in tests intead. +package linuxfwtest + +import ( + "testing" + "unsafe" +) + +/* +#include // socket() +*/ +import "C" + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + want := unsafe.Sizeof(C.socklen_t(0)) + if want != si.SizeofSocklen { + t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) + } +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go index d5e297da7b965..6e95699001d4b 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !cgo || !linux - -package linuxfwtest - -import ( - "testing" -) - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - t.Skip("not supported without cgo") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !cgo || !linux + +package linuxfwtest + +import ( + "testing" +) + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + t.Skip("not supported without cgo") +} diff --git a/util/linuxfw/nftables_types.go b/util/linuxfw/nftables_types.go index a8c5a0730dbd3..b6e24d2a67b5b 100644 --- a/util/linuxfw/nftables_types.go +++ b/util/linuxfw/nftables_types.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// TODO(#8502): add support for more architectures -//go:build linux && (arm64 || amd64) - -package linuxfw - -import ( - "github.com/google/nftables/expr" - "github.com/google/nftables/xt" -) - -var metaKeyNames = map[expr.MetaKey]string{ - expr.MetaKeyLEN: "LEN", - expr.MetaKeyPROTOCOL: "PROTOCOL", - expr.MetaKeyPRIORITY: "PRIORITY", - expr.MetaKeyMARK: "MARK", - expr.MetaKeyIIF: "IIF", - expr.MetaKeyOIF: "OIF", - expr.MetaKeyIIFNAME: "IIFNAME", - expr.MetaKeyOIFNAME: "OIFNAME", - expr.MetaKeyIIFTYPE: "IIFTYPE", - expr.MetaKeyOIFTYPE: "OIFTYPE", - expr.MetaKeySKUID: "SKUID", - expr.MetaKeySKGID: "SKGID", - expr.MetaKeyNFTRACE: "NFTRACE", - expr.MetaKeyRTCLASSID: "RTCLASSID", - expr.MetaKeySECMARK: "SECMARK", - expr.MetaKeyNFPROTO: "NFPROTO", - expr.MetaKeyL4PROTO: "L4PROTO", - expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", - expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", - expr.MetaKeyPKTTYPE: "PKTTYPE", - expr.MetaKeyCPU: "CPU", - expr.MetaKeyIIFGROUP: "IIFGROUP", - expr.MetaKeyOIFGROUP: "OIFGROUP", - expr.MetaKeyCGROUP: "CGROUP", - expr.MetaKeyPRANDOM: "PRANDOM", -} - -var cmpOpNames = map[expr.CmpOp]string{ - expr.CmpOpEq: "EQ", - expr.CmpOpNeq: "NEQ", - expr.CmpOpLt: "LT", - expr.CmpOpLte: "LTE", - expr.CmpOpGt: "GT", - expr.CmpOpGte: "GTE", -} - -var verdictNames = map[expr.VerdictKind]string{ - expr.VerdictReturn: "RETURN", - expr.VerdictGoto: "GOTO", - expr.VerdictJump: "JUMP", - expr.VerdictBreak: "BREAK", - expr.VerdictContinue: "CONTINUE", - expr.VerdictDrop: "DROP", - expr.VerdictAccept: "ACCEPT", - expr.VerdictStolen: "STOLEN", - expr.VerdictQueue: "QUEUE", - expr.VerdictRepeat: "REPEAT", - expr.VerdictStop: "STOP", -} - -var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ - expr.PayloadLoad: "LOAD", - expr.PayloadWrite: "WRITE", -} - -var payloadBaseNames = map[expr.PayloadBase]string{ - expr.PayloadBaseLLHeader: "ll-header", - expr.PayloadBaseNetworkHeader: "network-header", - expr.PayloadBaseTransportHeader: "transport-header", -} - -var packetTypeNames = map[int]string{ - 0 /* PACKET_HOST */ : "unicast", - 1 /* PACKET_BROADCAST */ : "broadcast", - 2 /* PACKET_MULTICAST */ : "multicast", -} - -var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ - xt.AddrTypeUnspec: "unspec", - xt.AddrTypeUnicast: "unicast", - xt.AddrTypeLocal: "local", - xt.AddrTypeBroadcast: "broadcast", - xt.AddrTypeAnycast: "anycast", - xt.AddrTypeMulticast: "multicast", - xt.AddrTypeBlackhole: "blackhole", - xt.AddrTypeUnreachable: "unreachable", - xt.AddrTypeProhibit: "prohibit", - xt.AddrTypeThrow: "throw", - xt.AddrTypeNat: "nat", - xt.AddrTypeXresolve: "xresolve", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(#8502): add support for more architectures +//go:build linux && (arm64 || amd64) + +package linuxfw + +import ( + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" +) + +var metaKeyNames = map[expr.MetaKey]string{ + expr.MetaKeyLEN: "LEN", + expr.MetaKeyPROTOCOL: "PROTOCOL", + expr.MetaKeyPRIORITY: "PRIORITY", + expr.MetaKeyMARK: "MARK", + expr.MetaKeyIIF: "IIF", + expr.MetaKeyOIF: "OIF", + expr.MetaKeyIIFNAME: "IIFNAME", + expr.MetaKeyOIFNAME: "OIFNAME", + expr.MetaKeyIIFTYPE: "IIFTYPE", + expr.MetaKeyOIFTYPE: "OIFTYPE", + expr.MetaKeySKUID: "SKUID", + expr.MetaKeySKGID: "SKGID", + expr.MetaKeyNFTRACE: "NFTRACE", + expr.MetaKeyRTCLASSID: "RTCLASSID", + expr.MetaKeySECMARK: "SECMARK", + expr.MetaKeyNFPROTO: "NFPROTO", + expr.MetaKeyL4PROTO: "L4PROTO", + expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", + expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", + expr.MetaKeyPKTTYPE: "PKTTYPE", + expr.MetaKeyCPU: "CPU", + expr.MetaKeyIIFGROUP: "IIFGROUP", + expr.MetaKeyOIFGROUP: "OIFGROUP", + expr.MetaKeyCGROUP: "CGROUP", + expr.MetaKeyPRANDOM: "PRANDOM", +} + +var cmpOpNames = map[expr.CmpOp]string{ + expr.CmpOpEq: "EQ", + expr.CmpOpNeq: "NEQ", + expr.CmpOpLt: "LT", + expr.CmpOpLte: "LTE", + expr.CmpOpGt: "GT", + expr.CmpOpGte: "GTE", +} + +var verdictNames = map[expr.VerdictKind]string{ + expr.VerdictReturn: "RETURN", + expr.VerdictGoto: "GOTO", + expr.VerdictJump: "JUMP", + expr.VerdictBreak: "BREAK", + expr.VerdictContinue: "CONTINUE", + expr.VerdictDrop: "DROP", + expr.VerdictAccept: "ACCEPT", + expr.VerdictStolen: "STOLEN", + expr.VerdictQueue: "QUEUE", + expr.VerdictRepeat: "REPEAT", + expr.VerdictStop: "STOP", +} + +var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ + expr.PayloadLoad: "LOAD", + expr.PayloadWrite: "WRITE", +} + +var payloadBaseNames = map[expr.PayloadBase]string{ + expr.PayloadBaseLLHeader: "ll-header", + expr.PayloadBaseNetworkHeader: "network-header", + expr.PayloadBaseTransportHeader: "transport-header", +} + +var packetTypeNames = map[int]string{ + 0 /* PACKET_HOST */ : "unicast", + 1 /* PACKET_BROADCAST */ : "broadcast", + 2 /* PACKET_MULTICAST */ : "multicast", +} + +var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ + xt.AddrTypeUnspec: "unspec", + xt.AddrTypeUnicast: "unicast", + xt.AddrTypeLocal: "local", + xt.AddrTypeBroadcast: "broadcast", + xt.AddrTypeAnycast: "anycast", + xt.AddrTypeMulticast: "multicast", + xt.AddrTypeBlackhole: "blackhole", + xt.AddrTypeUnreachable: "unreachable", + xt.AddrTypeProhibit: "prohibit", + xt.AddrTypeThrow: "throw", + xt.AddrTypeNat: "nat", + xt.AddrTypeXresolve: "xresolve", +} diff --git a/util/mak/mak.go b/util/mak/mak.go index b0d64daa422d4..b421fb0ed5a55 100644 --- a/util/mak/mak.go +++ b/util/mak/mak.go @@ -1,70 +1,70 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak helps make maps. It contains generic helpers to make/assign -// things, notably to maps, but also slices. -package mak - -import ( - "fmt" - "reflect" -) - -// Set populates an entry in a map, making the map if necessary. -// -// That is, it assigns (*m)[k] = v, making *m if it was nil. -func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { - if *m == nil { - *m = make(map[K]V) - } - (*m)[k] = v -} - -// NonNil takes a pointer to a Go data structure -// (currently only a slice or a map) and makes sure it's non-nil for -// JSON serialization. (In particular, JavaScript clients usually want -// the field to be defined after they decode the JSON.) -// -// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. -func NonNil(ptr any) { - if ptr == nil { - panic("nil interface") - } - rv := reflect.ValueOf(ptr) - if rv.Kind() != reflect.Ptr { - panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) - } - if rv.Pointer() == 0 { - panic("nil pointer") - } - rv = rv.Elem() - if rv.Pointer() != 0 { - return - } - switch rv.Type().Kind() { - case reflect.Slice: - rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) - case reflect.Map: - rv.Set(reflect.MakeMap(rv.Type())) - } -} - -// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { - if *slicePtr != nil { - return - } - *slicePtr = make([]T, 0) -} - -// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { - if *mapPtr != nil { - return - } - *mapPtr = make(M) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak helps make maps. It contains generic helpers to make/assign +// things, notably to maps, but also slices. +package mak + +import ( + "fmt" + "reflect" +) + +// Set populates an entry in a map, making the map if necessary. +// +// That is, it assigns (*m)[k] = v, making *m if it was nil. +func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { + if *m == nil { + *m = make(map[K]V) + } + (*m)[k] = v +} + +// NonNil takes a pointer to a Go data structure +// (currently only a slice or a map) and makes sure it's non-nil for +// JSON serialization. (In particular, JavaScript clients usually want +// the field to be defined after they decode the JSON.) +// +// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. +func NonNil(ptr any) { + if ptr == nil { + panic("nil interface") + } + rv := reflect.ValueOf(ptr) + if rv.Kind() != reflect.Ptr { + panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) + } + if rv.Pointer() == 0 { + panic("nil pointer") + } + rv = rv.Elem() + if rv.Pointer() != 0 { + return + } + switch rv.Type().Kind() { + case reflect.Slice: + rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) + case reflect.Map: + rv.Set(reflect.MakeMap(rv.Type())) + } +} + +// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { + if *slicePtr != nil { + return + } + *slicePtr = make([]T, 0) +} + +// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { + if *mapPtr != nil { + return + } + *mapPtr = make(M) +} diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go index dc1d7e93d7b19..4de499a9d5040 100644 --- a/util/mak/mak_test.go +++ b/util/mak/mak_test.go @@ -1,88 +1,88 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak contains code to help make things. -package mak - -import ( - "reflect" - "testing" -) - -type M map[string]int - -func TestSet(t *testing.T) { - t.Run("unnamed", func(t *testing.T) { - var m map[string]int - Set(&m, "foo", 42) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := map[string]int{ - "foo": 42, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) - t.Run("named", func(t *testing.T) { - var m M - Set(&m, "foo", 1) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := M{ - "foo": 1, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) -} - -func TestNonNil(t *testing.T) { - var s []string - NonNil(&s) - if len(s) != 0 { - t.Errorf("slice len = %d; want 0", len(s)) - } - if s == nil { - t.Error("slice still nil") - } - - s = append(s, "foo") - NonNil(&s) - if len(s) != 1 { - t.Errorf("len = %d; want 1", len(s)) - } - if s[0] != "foo" { - t.Errorf("value = %q; want foo", s) - } - - var m map[string]string - NonNil(&m) - if len(m) != 0 { - t.Errorf("map len = %d; want 0", len(s)) - } - if m == nil { - t.Error("map still nil") - } -} - -func TestNonNilMapForJSON(t *testing.T) { - type M map[string]int - var m M - NonNilMapForJSON(&m) - if m == nil { - t.Fatal("still nil") - } -} - -func TestNonNilSliceForJSON(t *testing.T) { - type S []int - var s S - NonNilSliceForJSON(&s) - if s == nil { - t.Fatal("still nil") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak contains code to help make things. +package mak + +import ( + "reflect" + "testing" +) + +type M map[string]int + +func TestSet(t *testing.T) { + t.Run("unnamed", func(t *testing.T) { + var m map[string]int + Set(&m, "foo", 42) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := map[string]int{ + "foo": 42, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) + t.Run("named", func(t *testing.T) { + var m M + Set(&m, "foo", 1) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := M{ + "foo": 1, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) +} + +func TestNonNil(t *testing.T) { + var s []string + NonNil(&s) + if len(s) != 0 { + t.Errorf("slice len = %d; want 0", len(s)) + } + if s == nil { + t.Error("slice still nil") + } + + s = append(s, "foo") + NonNil(&s) + if len(s) != 1 { + t.Errorf("len = %d; want 1", len(s)) + } + if s[0] != "foo" { + t.Errorf("value = %q; want foo", s) + } + + var m map[string]string + NonNil(&m) + if len(m) != 0 { + t.Errorf("map len = %d; want 0", len(s)) + } + if m == nil { + t.Error("map still nil") + } +} + +func TestNonNilMapForJSON(t *testing.T) { + type M map[string]int + var m M + NonNilMapForJSON(&m) + if m == nil { + t.Fatal("still nil") + } +} + +func TestNonNilSliceForJSON(t *testing.T) { + type S []int + var s S + NonNilSliceForJSON(&s) + if s == nil { + t.Fatal("still nil") + } +} diff --git a/util/multierr/multierr.go b/util/multierr/multierr.go index 5ec36f644b73c..93ca068f56532 100644 --- a/util/multierr/multierr.go +++ b/util/multierr/multierr.go @@ -1,136 +1,136 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package multierr provides a simple multiple-error type. -// It was inspired by github.com/go-multierror/multierror. -package multierr - -import ( - "errors" - "slices" - "strings" -) - -// An Error represents multiple errors. -type Error struct { - errs []error -} - -// Error implements the error interface. -func (e Error) Error() string { - s := new(strings.Builder) - s.WriteString("multiple errors:") - for _, err := range e.errs { - s.WriteString("\n\t") - s.WriteString(err.Error()) - } - return s.String() -} - -// Errors returns a slice containing all errors in e. -func (e Error) Errors() []error { - return slices.Clone(e.errs) -} - -// Unwrap returns the underlying errors as-is. -func (e Error) Unwrap() []error { - // Do not clone since Unwrap requires callers to not mutate the slice. - // See the documentation in the Go "errors" package. - return e.errs -} - -// New returns an error composed from errs. -// Some errors in errs get special treatment: -// - nil errors are discarded -// - errors of type Error are expanded into the top level -// -// If the resulting slice has length 0, New returns nil. -// If the resulting slice has length 1, New returns that error. -// If the resulting slice has length > 1, New returns that slice as an Error. -func New(errs ...error) error { - // First count the number of errors to avoid allocating. - var n int - var errFirst error - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - n += len(e.errs) - if errFirst == nil && len(e.errs) > 0 { - errFirst = e.errs[0] - } - default: - n++ - if errFirst == nil { - errFirst = e - } - } - } - if n <= 1 { - return errFirst // nil if n == 0 - } - - // More than one error, allocate slice and construct the multi-error. - dst := make([]error, 0, n) - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - dst = append(dst, e.errs...) - default: - dst = append(dst, e) - } - } - return Error{errs: dst} -} - -// Is reports whether any error in e matches target. -func (e Error) Is(target error) bool { - for _, err := range e.errs { - if errors.Is(err, target) { - return true - } - } - return false -} - -// As finds the first error in e that matches target, and if any is found, -// sets target to that error value and returns true. Otherwise, it returns false. -func (e Error) As(target any) bool { - for _, err := range e.errs { - if ok := errors.As(err, target); ok { - return true - } - } - return false -} - -// Range performs a pre-order, depth-first iteration of the error tree -// by successively unwrapping all error values. -// For each iteration it calls fn with the current error value and -// stops iteration if it ever reports false. -func Range(err error, fn func(error) bool) bool { - if err == nil { - return true - } - if !fn(err) { - return false - } - switch err := err.(type) { - case interface{ Unwrap() error }: - if err := err.Unwrap(); err != nil { - if !Range(err, fn) { - return false - } - } - case interface{ Unwrap() []error }: - for _, err := range err.Unwrap() { - if !Range(err, fn) { - return false - } - } - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package multierr provides a simple multiple-error type. +// It was inspired by github.com/go-multierror/multierror. +package multierr + +import ( + "errors" + "slices" + "strings" +) + +// An Error represents multiple errors. +type Error struct { + errs []error +} + +// Error implements the error interface. +func (e Error) Error() string { + s := new(strings.Builder) + s.WriteString("multiple errors:") + for _, err := range e.errs { + s.WriteString("\n\t") + s.WriteString(err.Error()) + } + return s.String() +} + +// Errors returns a slice containing all errors in e. +func (e Error) Errors() []error { + return slices.Clone(e.errs) +} + +// Unwrap returns the underlying errors as-is. +func (e Error) Unwrap() []error { + // Do not clone since Unwrap requires callers to not mutate the slice. + // See the documentation in the Go "errors" package. + return e.errs +} + +// New returns an error composed from errs. +// Some errors in errs get special treatment: +// - nil errors are discarded +// - errors of type Error are expanded into the top level +// +// If the resulting slice has length 0, New returns nil. +// If the resulting slice has length 1, New returns that error. +// If the resulting slice has length > 1, New returns that slice as an Error. +func New(errs ...error) error { + // First count the number of errors to avoid allocating. + var n int + var errFirst error + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + n += len(e.errs) + if errFirst == nil && len(e.errs) > 0 { + errFirst = e.errs[0] + } + default: + n++ + if errFirst == nil { + errFirst = e + } + } + } + if n <= 1 { + return errFirst // nil if n == 0 + } + + // More than one error, allocate slice and construct the multi-error. + dst := make([]error, 0, n) + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + dst = append(dst, e.errs...) + default: + dst = append(dst, e) + } + } + return Error{errs: dst} +} + +// Is reports whether any error in e matches target. +func (e Error) Is(target error) bool { + for _, err := range e.errs { + if errors.Is(err, target) { + return true + } + } + return false +} + +// As finds the first error in e that matches target, and if any is found, +// sets target to that error value and returns true. Otherwise, it returns false. +func (e Error) As(target any) bool { + for _, err := range e.errs { + if ok := errors.As(err, target); ok { + return true + } + } + return false +} + +// Range performs a pre-order, depth-first iteration of the error tree +// by successively unwrapping all error values. +// For each iteration it calls fn with the current error value and +// stops iteration if it ever reports false. +func Range(err error, fn func(error) bool) bool { + if err == nil { + return true + } + if !fn(err) { + return false + } + switch err := err.(type) { + case interface{ Unwrap() error }: + if err := err.Unwrap(); err != nil { + if !Range(err, fn) { + return false + } + } + case interface{ Unwrap() []error }: + for _, err := range err.Unwrap() { + if !Range(err, fn) { + return false + } + } + } + return true +} diff --git a/util/must/must.go b/util/must/must.go index 056986fcac915..21965daa9b038 100644 --- a/util/must/must.go +++ b/util/must/must.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package must assists in calling functions that must succeed. -// -// Example usage: -// -// var target = must.Get(url.Parse(...)) -// must.Do(close()) -package must - -// Do panics if err is non-nil. -func Do(err error) { - if err != nil { - panic(err) - } -} - -// Get returns v as is. It panics if err is non-nil. -func Get[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package must assists in calling functions that must succeed. +// +// Example usage: +// +// var target = must.Get(url.Parse(...)) +// must.Do(close()) +package must + +// Do panics if err is non-nil. +func Do(err error) { + if err != nil { + panic(err) + } +} + +// Get returns v as is. It panics if err is non-nil. +func Get[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} diff --git a/util/osdiag/mksyscall.go b/util/osdiag/mksyscall.go index f20be7f92da7f..bcbe113b051cd 100644 --- a/util/osdiag/mksyscall.go +++ b/util/osdiag/mksyscall.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx -//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW -//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols -//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo -//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx +//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols +//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo +//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath diff --git a/util/osdiag/osdiag_windows_test.go b/util/osdiag/osdiag_windows_test.go index 776852a345f2b..b29b602ccb73c 100644 --- a/util/osdiag/osdiag_windows_test.go +++ b/util/osdiag/osdiag_windows_test.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -import ( - "errors" - "fmt" - "maps" - "strings" - "testing" - - "golang.org/x/sys/windows/registry" -) - -func makeLongBinaryValue() []byte { - buf := make([]byte, maxBinaryValueLen*2) - for i, _ := range buf { - buf[i] = byte(i % 0xFF) - } - return buf -} - -var testData = map[string]any{ - "": "I am the default", - "StringEmpty": "", - "StringShort": "Hello", - "StringLong": strings.Repeat("7", initialValueBufLen+1), - "MultiStringEmpty": []string{}, - "MultiStringSingle": []string{"Foo"}, - "MultiStringSingleEmpty": []string{""}, - "MultiString": []string{"Foo", "Bar", "Baz"}, - "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, - "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, - "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, - "DWord": uint32(0x12345678), - "QWord": uint64(0x123456789abcdef0), - "BinaryEmpty": []byte{}, - "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, - "BinaryLong": makeLongBinaryValue(), -} - -const ( - keyNameTest = `SOFTWARE\Tailscale Test` - subKeyNameTest = "SubKey" -) - -func setValues(t *testing.T, k registry.Key) { - for vk, v := range testData { - var err error - switch tv := v.(type) { - case string: - err = k.SetStringValue(vk, tv) - case []string: - err = k.SetStringsValue(vk, tv) - case uint32: - err = k.SetDWordValue(vk, tv) - case uint64: - err = k.SetQWordValue(vk, tv) - case []byte: - err = k.SetBinaryValue(vk, tv) - default: - t.Fatalf("Unknown type") - } - - if err != nil { - t.Fatalf("Error setting %q: %v", vk, err) - } - } -} - -func TestRegistrySupportInfo(t *testing.T) { - // Make sure the key doesn't exist yet - k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) - switch { - case err == nil: - k.Close() - t.Fatalf("Test key already exists") - case !errors.Is(err, registry.ErrNotExist): - t.Fatal(err) - } - - func() { - k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test key: %v", err) - } - defer k.Close() - - setValues(t, k) - - sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test subkey: %v", err) - } - defer sk.Close() - - setValues(t, sk) - }() - - t.Cleanup(func() { - registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) - registry.DeleteKey(registry.CURRENT_USER, keyNameTest) - }) - - wantValuesData := maps.Clone(testData) - wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] - - wantKeyData := make(map[string]any) - maps.Copy(wantKeyData, wantValuesData) - wantSubKeyData := make(map[string]any) - maps.Copy(wantSubKeyData, wantValuesData) - wantKeyData[subKeyNameTest] = wantSubKeyData - - wantData := map[string]any{ - "HKCU\\" + keyNameTest: wantKeyData, - } - - gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) - if err != nil { - t.Errorf("getRegistrySupportInfo error: %v", err) - } - - want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) - if want != got { - t.Errorf("Compare error: want\n%s,\ngot %s", want, got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +import ( + "errors" + "fmt" + "maps" + "strings" + "testing" + + "golang.org/x/sys/windows/registry" +) + +func makeLongBinaryValue() []byte { + buf := make([]byte, maxBinaryValueLen*2) + for i, _ := range buf { + buf[i] = byte(i % 0xFF) + } + return buf +} + +var testData = map[string]any{ + "": "I am the default", + "StringEmpty": "", + "StringShort": "Hello", + "StringLong": strings.Repeat("7", initialValueBufLen+1), + "MultiStringEmpty": []string{}, + "MultiStringSingle": []string{"Foo"}, + "MultiStringSingleEmpty": []string{""}, + "MultiString": []string{"Foo", "Bar", "Baz"}, + "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, + "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, + "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, + "DWord": uint32(0x12345678), + "QWord": uint64(0x123456789abcdef0), + "BinaryEmpty": []byte{}, + "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, + "BinaryLong": makeLongBinaryValue(), +} + +const ( + keyNameTest = `SOFTWARE\Tailscale Test` + subKeyNameTest = "SubKey" +) + +func setValues(t *testing.T, k registry.Key) { + for vk, v := range testData { + var err error + switch tv := v.(type) { + case string: + err = k.SetStringValue(vk, tv) + case []string: + err = k.SetStringsValue(vk, tv) + case uint32: + err = k.SetDWordValue(vk, tv) + case uint64: + err = k.SetQWordValue(vk, tv) + case []byte: + err = k.SetBinaryValue(vk, tv) + default: + t.Fatalf("Unknown type") + } + + if err != nil { + t.Fatalf("Error setting %q: %v", vk, err) + } + } +} + +func TestRegistrySupportInfo(t *testing.T) { + // Make sure the key doesn't exist yet + k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) + switch { + case err == nil: + k.Close() + t.Fatalf("Test key already exists") + case !errors.Is(err, registry.ErrNotExist): + t.Fatal(err) + } + + func() { + k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test key: %v", err) + } + defer k.Close() + + setValues(t, k) + + sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test subkey: %v", err) + } + defer sk.Close() + + setValues(t, sk) + }() + + t.Cleanup(func() { + registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) + registry.DeleteKey(registry.CURRENT_USER, keyNameTest) + }) + + wantValuesData := maps.Clone(testData) + wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] + + wantKeyData := make(map[string]any) + maps.Copy(wantKeyData, wantValuesData) + wantSubKeyData := make(map[string]any) + maps.Copy(wantSubKeyData, wantValuesData) + wantKeyData[subKeyNameTest] = wantSubKeyData + + wantData := map[string]any{ + "HKCU\\" + keyNameTest: wantKeyData, + } + + gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) + if err != nil { + t.Errorf("getRegistrySupportInfo error: %v", err) + } + + want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) + if want != got { + t.Errorf("Compare error: want\n%s,\ngot %s", want, got) + } +} diff --git a/util/osshare/filesharingstatus_noop.go b/util/osshare/filesharingstatus_noop.go index 6be4131a991d6..7f2b131904ea9 100644 --- a/util/osshare/filesharingstatus_noop.go +++ b/util/osshare/filesharingstatus_noop.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package osshare - -import ( - "tailscale.com/types/logger" -) - -func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package osshare + +import ( + "tailscale.com/types/logger" +) + +func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} diff --git a/util/pidowner/pidowner.go b/util/pidowner/pidowner.go index 62ea85d780b07..56bb640b785dd 100644 --- a/util/pidowner/pidowner.go +++ b/util/pidowner/pidowner.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package pidowner handles lookups from process ID to its owning user. -package pidowner - -import ( - "errors" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -var ErrProcessNotFound = errors.New("process not found") - -// OwnerOfPID returns the user ID that owns the given process ID. -// -// The returned user ID is suitable to passing to os/user.LookupId. -// -// The returned error will be ErrNotImplemented for operating systems where -// this isn't supported. -func OwnerOfPID(pid int) (userID string, err error) { - return ownerOfPID(pid) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package pidowner handles lookups from process ID to its owning user. +package pidowner + +import ( + "errors" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +var ErrProcessNotFound = errors.New("process not found") + +// OwnerOfPID returns the user ID that owns the given process ID. +// +// The returned user ID is suitable to passing to os/user.LookupId. +// +// The returned error will be ErrNotImplemented for operating systems where +// this isn't supported. +func OwnerOfPID(pid int) (userID string, err error) { + return ownerOfPID(pid) +} diff --git a/util/pidowner/pidowner_noimpl.go b/util/pidowner/pidowner_noimpl.go index a631e3f249896..50add492fda76 100644 --- a/util/pidowner/pidowner_noimpl.go +++ b/util/pidowner/pidowner_noimpl.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux - -package pidowner - -func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !linux + +package pidowner + +func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } diff --git a/util/pidowner/pidowner_windows.go b/util/pidowner/pidowner_windows.go index c7b2512a497ed..dbf13ac8135f1 100644 --- a/util/pidowner/pidowner_windows.go +++ b/util/pidowner/pidowner_windows.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "syscall" - - "golang.org/x/sys/windows" -) - -func ownerOfPID(pid int) (userID string, err error) { - procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) - if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist - return "", ErrProcessNotFound - } - if err != nil { - return "", fmt.Errorf("OpenProcess: %T %#v", err, err) - } - defer windows.CloseHandle(procHnd) - - var tok windows.Token - if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { - return "", fmt.Errorf("OpenProcessToken: %w", err) - } - - tokUser, err := tok.GetTokenUser() - if err != nil { - return "", fmt.Errorf("GetTokenUser: %w", err) - } - - sid := tokUser.User.Sid - return sid.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package pidowner + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/windows" +) + +func ownerOfPID(pid int) (userID string, err error) { + procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) + if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist + return "", ErrProcessNotFound + } + if err != nil { + return "", fmt.Errorf("OpenProcess: %T %#v", err, err) + } + defer windows.CloseHandle(procHnd) + + var tok windows.Token + if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { + return "", fmt.Errorf("OpenProcessToken: %w", err) + } + + tokUser, err := tok.GetTokenUser() + if err != nil { + return "", fmt.Errorf("GetTokenUser: %w", err) + } + + sid := tokUser.User.Sid + return sid.String(), nil +} diff --git a/util/precompress/precompress.go b/util/precompress/precompress.go index e9bebb333e2af..6d1a26efdd767 100644 --- a/util/precompress/precompress.go +++ b/util/precompress/precompress.go @@ -1,129 +1,129 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package precompress provides build- and serving-time support for -// precompressed static resources, to avoid the cost of repeatedly compressing -// unchanging resources. -package precompress - -import ( - "bytes" - "compress/gzip" - "io" - "io/fs" - "net/http" - "os" - "path" - "path/filepath" - - "github.com/andybalholm/brotli" - "golang.org/x/sync/errgroup" - "tailscale.com/tsweb" -) - -// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so -// that they can be later served with OpenPrecompressedFile. -func PrecompressDir(dirPath string, options Options) error { - var eg errgroup.Group - err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - if !compressibleExtensions[filepath.Ext(p)] { - return nil - } - p = path.Join(dirPath, p) - if options.ProgressFn != nil { - options.ProgressFn(p) - } - - eg.Go(func() error { - return Precompress(p, options) - }) - return nil - }) - if err != nil { - return err - } - return eg.Wait() -} - -type Options struct { - // FastCompression controls whether compression should be optimized for - // speed rather than size. - FastCompression bool - // ProgressFn, if non-nil, is invoked when a file in the directory is about - // to be compressed. - ProgressFn func(path string) -} - -// OpenPrecompressedFile opens a file from fs, preferring compressed versions -// generated by PrecompressDir if possible. -func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { - if tsweb.AcceptsEncoding(r, "br") { - if f, err := fs.Open(path + ".br"); err == nil { - w.Header().Set("Content-Encoding", "br") - return f, nil - } - } - if tsweb.AcceptsEncoding(r, "gzip") { - if f, err := fs.Open(path + ".gz"); err == nil { - w.Header().Set("Content-Encoding", "gzip") - return f, nil - } - } - - return fs.Open(path) -} - -var compressibleExtensions = map[string]bool{ - ".js": true, - ".css": true, -} - -func Precompress(path string, options Options) error { - contents, err := os.ReadFile(path) - if err != nil { - return err - } - fi, err := os.Lstat(path) - if err != nil { - return err - } - - gzipLevel := gzip.BestCompression - if options.FastCompression { - gzipLevel = gzip.BestSpeed - } - err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return gzip.NewWriterLevel(w, gzipLevel) - }, path+".gz", fi.Mode()) - if err != nil { - return err - } - brotliLevel := brotli.BestCompression - if options.FastCompression { - brotliLevel = brotli.BestSpeed - } - return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return brotli.NewWriterLevel(w, brotliLevel), nil - }, path+".br", fi.Mode()) -} - -func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { - var buf bytes.Buffer - compressedWriter, err := compressedWriterCreator(&buf) - if err != nil { - return err - } - if _, err := compressedWriter.Write(contents); err != nil { - return err - } - if err := compressedWriter.Close(); err != nil { - return err - } - return os.WriteFile(outputPath, buf.Bytes(), outputMode) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package precompress provides build- and serving-time support for +// precompressed static resources, to avoid the cost of repeatedly compressing +// unchanging resources. +package precompress + +import ( + "bytes" + "compress/gzip" + "io" + "io/fs" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/andybalholm/brotli" + "golang.org/x/sync/errgroup" + "tailscale.com/tsweb" +) + +// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so +// that they can be later served with OpenPrecompressedFile. +func PrecompressDir(dirPath string, options Options) error { + var eg errgroup.Group + err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + if !compressibleExtensions[filepath.Ext(p)] { + return nil + } + p = path.Join(dirPath, p) + if options.ProgressFn != nil { + options.ProgressFn(p) + } + + eg.Go(func() error { + return Precompress(p, options) + }) + return nil + }) + if err != nil { + return err + } + return eg.Wait() +} + +type Options struct { + // FastCompression controls whether compression should be optimized for + // speed rather than size. + FastCompression bool + // ProgressFn, if non-nil, is invoked when a file in the directory is about + // to be compressed. + ProgressFn func(path string) +} + +// OpenPrecompressedFile opens a file from fs, preferring compressed versions +// generated by PrecompressDir if possible. +func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { + if tsweb.AcceptsEncoding(r, "br") { + if f, err := fs.Open(path + ".br"); err == nil { + w.Header().Set("Content-Encoding", "br") + return f, nil + } + } + if tsweb.AcceptsEncoding(r, "gzip") { + if f, err := fs.Open(path + ".gz"); err == nil { + w.Header().Set("Content-Encoding", "gzip") + return f, nil + } + } + + return fs.Open(path) +} + +var compressibleExtensions = map[string]bool{ + ".js": true, + ".css": true, +} + +func Precompress(path string, options Options) error { + contents, err := os.ReadFile(path) + if err != nil { + return err + } + fi, err := os.Lstat(path) + if err != nil { + return err + } + + gzipLevel := gzip.BestCompression + if options.FastCompression { + gzipLevel = gzip.BestSpeed + } + err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriterLevel(w, gzipLevel) + }, path+".gz", fi.Mode()) + if err != nil { + return err + } + brotliLevel := brotli.BestCompression + if options.FastCompression { + brotliLevel = brotli.BestSpeed + } + return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return brotli.NewWriterLevel(w, brotliLevel), nil + }, path+".br", fi.Mode()) +} + +func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { + var buf bytes.Buffer + compressedWriter, err := compressedWriterCreator(&buf) + if err != nil { + return err + } + if _, err := compressedWriter.Write(contents); err != nil { + return err + } + if err := compressedWriter.Close(); err != nil { + return err + } + return os.WriteFile(outputPath, buf.Bytes(), outputMode) +} diff --git a/util/quarantine/quarantine.go b/util/quarantine/quarantine.go index 488465ba055bb..7ad65a81d69ee 100644 --- a/util/quarantine/quarantine.go +++ b/util/quarantine/quarantine.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package quarantine sets platform specific "quarantine" attributes on files -// that are received from other hosts. -package quarantine - -import "os" - -// SetOnFile sets the platform-specific quarantine attribute (if any) on the -// provided file. -func SetOnFile(f *os.File) error { - return setQuarantineAttr(f) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package quarantine sets platform specific "quarantine" attributes on files +// that are received from other hosts. +package quarantine + +import "os" + +// SetOnFile sets the platform-specific quarantine attribute (if any) on the +// provided file. +func SetOnFile(f *os.File) error { + return setQuarantineAttr(f) +} diff --git a/util/quarantine/quarantine_darwin.go b/util/quarantine/quarantine_darwin.go index b7757f3346809..35405d9cc7a87 100644 --- a/util/quarantine/quarantine_darwin.go +++ b/util/quarantine/quarantine_darwin.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "fmt" - "os" - "strings" - "time" - - "github.com/google/uuid" - "golang.org/x/sys/unix" -) - -func setQuarantineAttr(f *os.File) error { - sc, err := f.SyscallConn() - if err != nil { - return err - } - - now := time.Now() - - // We uppercase the UUID to match what other applications on macOS do - id := strings.ToUpper(uuid.New().String()) - - // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when - // receiving a file. - quarantineType := "0001" - - // This format is under-documented, but the following links contain a - // reasonably comprehensive overview: - // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ - // https://nixhacker.com/security-protection-in-macos-1/ - // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html - attrData := fmt.Sprintf("%s;%x;%s;%s", - quarantineType, // quarantine value - now.Unix(), // time in hex - "Tailscale", // application - id, // UUID - ) - - var innerErr error - err = sc.Control(func(fd uintptr) { - innerErr = unix.Fsetxattr( - int(fd), - "com.apple.quarantine", // attr - []byte(attrData), - 0, - ) - }) - if err != nil { - return err - } - return innerErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/sys/unix" +) + +func setQuarantineAttr(f *os.File) error { + sc, err := f.SyscallConn() + if err != nil { + return err + } + + now := time.Now() + + // We uppercase the UUID to match what other applications on macOS do + id := strings.ToUpper(uuid.New().String()) + + // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when + // receiving a file. + quarantineType := "0001" + + // This format is under-documented, but the following links contain a + // reasonably comprehensive overview: + // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ + // https://nixhacker.com/security-protection-in-macos-1/ + // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html + attrData := fmt.Sprintf("%s;%x;%s;%s", + quarantineType, // quarantine value + now.Unix(), // time in hex + "Tailscale", // application + id, // UUID + ) + + var innerErr error + err = sc.Control(func(fd uintptr) { + innerErr = unix.Fsetxattr( + int(fd), + "com.apple.quarantine", // attr + []byte(attrData), + 0, + ) + }) + if err != nil { + return err + } + return innerErr +} diff --git a/util/quarantine/quarantine_default.go b/util/quarantine/quarantine_default.go index 65a14ed26fa97..65954a4d25415 100644 --- a/util/quarantine/quarantine_default.go +++ b/util/quarantine/quarantine_default.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !darwin && !windows - -package quarantine - -import ( - "os" -) - -func setQuarantineAttr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !darwin && !windows + +package quarantine + +import ( + "os" +) + +func setQuarantineAttr(f *os.File) error { + return nil +} diff --git a/util/quarantine/quarantine_windows.go b/util/quarantine/quarantine_windows.go index 3052c2c6dfab5..6fdf4e699b75b 100644 --- a/util/quarantine/quarantine_windows.go +++ b/util/quarantine/quarantine_windows.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "os" - "strings" -) - -func setQuarantineAttr(f *os.File) error { - // Documentation on this can be found here: - // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 - // - // Additional information can be found at: - // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ - // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 - content := strings.Join([]string{ - "[ZoneTransfer]", - - // "URLZONE_INTERNET" - // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) - "ZoneId=3", - - // TODO(andrew): should/could we add ReferrerUrl or HostUrl? - }, "\r\n") - - return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "os" + "strings" +) + +func setQuarantineAttr(f *os.File) error { + // Documentation on this can be found here: + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 + // + // Additional information can be found at: + // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ + // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 + content := strings.Join([]string{ + "[ZoneTransfer]", + + // "URLZONE_INTERNET" + // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) + "ZoneId=3", + + // TODO(andrew): should/could we add ReferrerUrl or HostUrl? + }, "\r\n") + + return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) +} diff --git a/util/race/race_test.go b/util/race/race_test.go index 17ea764591503..d3838271226ac 100644 --- a/util/race/race_test.go +++ b/util/race/race_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package race - -import ( - "context" - "errors" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestRaceSuccess1(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "success" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return want, nil - }, func(context.Context) (string, error) { - t.Fatal("should not be called") - return "", nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceRetry(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return "", errors.New("some error") - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceTimeout(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - // Block forever - <-ctx.Done() - return "", ctx.Err() - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceError(t *testing.T) { - tstest.ResourceCheck(t) - - err1 := errors.New("error 1") - err2 := errors.New("error 2") - - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - return "", err1 - }, func(context.Context) (string, error) { - return "", err2 - }) - - _, err := rh.Start(context.Background()) - if !errors.Is(err, err1) { - t.Errorf("wanted err to contain err1; got %v", err) - } - if !errors.Is(err, err2) { - t.Errorf("wanted err to contain err2; got %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package race + +import ( + "context" + "errors" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestRaceSuccess1(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "success" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return want, nil + }, func(context.Context) (string, error) { + t.Fatal("should not be called") + return "", nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceRetry(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return "", errors.New("some error") + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceTimeout(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + // Block forever + <-ctx.Done() + return "", ctx.Err() + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceError(t *testing.T) { + tstest.ResourceCheck(t) + + err1 := errors.New("error 1") + err2 := errors.New("error 2") + + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + return "", err1 + }, func(context.Context) (string, error) { + return "", err2 + }) + + _, err := rh.Start(context.Background()) + if !errors.Is(err, err1) { + t.Errorf("wanted err to contain err1; got %v", err) + } + if !errors.Is(err, err2) { + t.Errorf("wanted err to contain err2; got %v", err) + } +} diff --git a/util/racebuild/off.go b/util/racebuild/off.go index a0dba0f32c052..8f4fe998fb4bb 100644 --- a/util/racebuild/off.go +++ b/util/racebuild/off.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package racebuild - -const On = false +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package racebuild + +const On = false diff --git a/util/racebuild/on.go b/util/racebuild/on.go index c60bca2e6f8df..69ae2bcae4239 100644 --- a/util/racebuild/on.go +++ b/util/racebuild/on.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package racebuild - -const On = true +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package racebuild + +const On = true diff --git a/util/racebuild/racebuild.go b/util/racebuild/racebuild.go index c1a43eb96a376..d061276cb8a0a 100644 --- a/util/racebuild/racebuild.go +++ b/util/racebuild/racebuild.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package racebuild exports a constant about whether the current binary -// was built with the race detector. -package racebuild +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package racebuild exports a constant about whether the current binary +// was built with the race detector. +package racebuild diff --git a/util/rands/rands.go b/util/rands/rands.go index dcd75c5f37158..d83e1e55898dc 100644 --- a/util/rands/rands.go +++ b/util/rands/rands.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package rands contains utility functions for randomness. -package rands - -import ( - crand "crypto/rand" - "encoding/hex" -) - -// HexString returns a string of n cryptographically random lowercase -// hex characters. -// -// That is, HexString(3) returns something like "0fc", containing 12 -// bits of randomness. -func HexString(n int) string { - nb := n / 2 - if n%2 == 1 { - nb++ - } - b := make([]byte, nb) - crand.Read(b) - return hex.EncodeToString(b)[:n] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rands contains utility functions for randomness. +package rands + +import ( + crand "crypto/rand" + "encoding/hex" +) + +// HexString returns a string of n cryptographically random lowercase +// hex characters. +// +// That is, HexString(3) returns something like "0fc", containing 12 +// bits of randomness. +func HexString(n int) string { + nb := n / 2 + if n%2 == 1 { + nb++ + } + b := make([]byte, nb) + crand.Read(b) + return hex.EncodeToString(b)[:n] +} diff --git a/util/rands/rands_test.go b/util/rands/rands_test.go index ec339f94bace7..5813f2bb46763 100644 --- a/util/rands/rands_test.go +++ b/util/rands/rands_test.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package rands - -import "testing" - -func TestHexString(t *testing.T) { - for i := 0; i <= 8; i++ { - s := HexString(i) - if len(s) != i { - t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rands + +import "testing" + +func TestHexString(t *testing.T) { + for i := 0; i <= 8; i++ { + s := HexString(i) + if len(s) != i { + t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) + } + } +} diff --git a/util/set/handle.go b/util/set/handle.go index 61b4eb93d8b4d..471ceeba2d523 100644 --- a/util/set/handle.go +++ b/util/set/handle.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -// HandleSet is a set of T. -// -// It is not safe for concurrent use. -type HandleSet[T any] map[Handle]T - -// Handle is an opaque comparable value that's used as the map key in a -// HandleSet. The only way to get one is to call HandleSet.Add. -type Handle struct { - v *byte -} - -// Add adds the element (map value) e to the set. -// -// It returns the handle (map key) with which e can be removed, using a map -// delete. -func (s *HandleSet[T]) Add(e T) Handle { - h := Handle{new(byte)} - if *s == nil { - *s = make(HandleSet[T]) - } - (*s)[h] = e - return h -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +// HandleSet is a set of T. +// +// It is not safe for concurrent use. +type HandleSet[T any] map[Handle]T + +// Handle is an opaque comparable value that's used as the map key in a +// HandleSet. The only way to get one is to call HandleSet.Add. +type Handle struct { + v *byte +} + +// Add adds the element (map value) e to the set. +// +// It returns the handle (map key) with which e can be removed, using a map +// delete. +func (s *HandleSet[T]) Add(e T) Handle { + h := Handle{new(byte)} + if *s == nil { + *s = make(HandleSet[T]) + } + (*s)[h] = e + return h +} diff --git a/util/set/slice_test.go b/util/set/slice_test.go index ca57e52e8cbc3..9134c296292d3 100644 --- a/util/set/slice_test.go +++ b/util/set/slice_test.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -import ( - "testing" - - qt "github.com/frankban/quicktest" -) - -func TestSliceSet(t *testing.T) { - c := qt.New(t) - - var ss Slice[int] - c.Check(len(ss.slice), qt.Equals, 0) - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - c.Check(ss.Contains(1), qt.Equals, true) - c.Check(ss.Contains(2), qt.Equals, false) - - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(2) - ss.Add(3) - ss.Add(4) - ss.Add(5) - ss.Add(6) - ss.Add(7) - ss.Add(8) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(9) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - - ss.Remove(4) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 8) - c.Assert(ss.Contains(4), qt.IsFalse) - - // Ensure that the order of insertion is maintained - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) - ss.Add(4) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - c.Assert(ss.Contains(4), qt.IsTrue) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) - - ss.Add(1, 234, 556) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestSliceSet(t *testing.T) { + c := qt.New(t) + + var ss Slice[int] + c.Check(len(ss.slice), qt.Equals, 0) + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + c.Check(ss.Contains(1), qt.Equals, true) + c.Check(ss.Contains(2), qt.Equals, false) + + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(2) + ss.Add(3) + ss.Add(4) + ss.Add(5) + ss.Add(6) + ss.Add(7) + ss.Add(8) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(9) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + + ss.Remove(4) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 8) + c.Assert(ss.Contains(4), qt.IsFalse) + + // Ensure that the order of insertion is maintained + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) + ss.Add(4) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + c.Assert(ss.Contains(4), qt.IsTrue) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) + + ss.Add(1, 234, 556) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) +} diff --git a/util/sysresources/memory.go b/util/sysresources/memory.go index 8bf784e13d831..7363155cdb2ae 100644 --- a/util/sysresources/memory.go +++ b/util/sysresources/memory.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -// TotalMemory returns the total accessible system memory, in bytes. If the -// value cannot be determined, then 0 will be returned. -func TotalMemory() uint64 { - return totalMemoryImpl() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +// TotalMemory returns the total accessible system memory, in bytes. If the +// value cannot be determined, then 0 will be returned. +func TotalMemory() uint64 { + return totalMemoryImpl() +} diff --git a/util/sysresources/memory_bsd.go b/util/sysresources/memory_bsd.go index 39d3a18a972f1..26850dce652ff 100644 --- a/util/sysresources/memory_bsd.go +++ b/util/sysresources/memory_bsd.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd || openbsd || dragonfly || netbsd - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.physmem") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd || openbsd || dragonfly || netbsd + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.physmem") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_darwin.go b/util/sysresources/memory_darwin.go index 2f74b6cecd7f3..e07bac0cd7f9b 100644 --- a/util/sysresources/memory_darwin.go +++ b/util/sysresources/memory_darwin.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.memsize") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.memsize") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_linux.go b/util/sysresources/memory_linux.go index f3c51469fcc6c..0239b0e80d62a 100644 --- a/util/sysresources/memory_linux.go +++ b/util/sysresources/memory_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - var info unix.Sysinfo_t - - if err := unix.Sysinfo(&info); err != nil { - return 0 - } - - // uint64 casts are required since these might be uint32s - return uint64(info.Totalram) * uint64(info.Unit) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + var info unix.Sysinfo_t + + if err := unix.Sysinfo(&info); err != nil { + return 0 + } + + // uint64 casts are required since these might be uint32s + return uint64(info.Totalram) * uint64(info.Unit) +} diff --git a/util/sysresources/memory_unsupported.go b/util/sysresources/memory_unsupported.go index f80ef4e6ebfe8..0fde256e0543d 100644 --- a/util/sysresources/memory_unsupported.go +++ b/util/sysresources/memory_unsupported.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) - -package sysresources - -func totalMemoryImpl() uint64 { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) + +package sysresources + +func totalMemoryImpl() uint64 { return 0 } diff --git a/util/sysresources/sysresources.go b/util/sysresources/sysresources.go index 1cce164a74730..32d972ab15513 100644 --- a/util/sysresources/sysresources.go +++ b/util/sysresources/sysresources.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sysresources provides OS-independent methods of determining the -// resources available to the current system. -package sysresources +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sysresources provides OS-independent methods of determining the +// resources available to the current system. +package sysresources diff --git a/util/sysresources/sysresources_test.go b/util/sysresources/sysresources_test.go index af96620421bae..331ad913bfba1 100644 --- a/util/sysresources/sysresources_test.go +++ b/util/sysresources/sysresources_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -import ( - "runtime" - "testing" -) - -func TestTotalMemory(t *testing.T) { - switch runtime.GOOS { - case "linux": - case "freebsd", "openbsd", "dragonfly", "netbsd": - case "darwin": - default: - t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) - } - - mem := TotalMemory() - if mem == 0 { - t.Fatal("wanted TotalMemory > 0") - } - t.Logf("total memory: %v bytes", mem) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +import ( + "runtime" + "testing" +) + +func TestTotalMemory(t *testing.T) { + switch runtime.GOOS { + case "linux": + case "freebsd", "openbsd", "dragonfly", "netbsd": + case "darwin": + default: + t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) + } + + mem := TotalMemory() + if mem == 0 { + t.Fatal("wanted TotalMemory > 0") + } + t.Logf("total memory: %v bytes", mem) +} diff --git a/util/systemd/doc.go b/util/systemd/doc.go index 296f74e9d4cd6..0c28e182354ec 100644 --- a/util/systemd/doc.go +++ b/util/systemd/doc.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/* -Package systemd contains a minimal wrapper around systemd-notify to enable -applications to signal readiness and status to systemd. - -This package will only have effect on Linux systems running Tailscale in a -systemd unit with the Type=notify flag set. On other operating systems (or -when running in a Linux distro without being run from inside systemd) this -package will become a no-op. -*/ -package systemd +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/* +Package systemd contains a minimal wrapper around systemd-notify to enable +applications to signal readiness and status to systemd. + +This package will only have effect on Linux systems running Tailscale in a +systemd unit with the Type=notify flag set. On other operating systems (or +when running in a Linux distro without being run from inside systemd) this +package will become a no-op. +*/ +package systemd diff --git a/util/systemd/systemd_linux.go b/util/systemd/systemd_linux.go index 34d6daff39e3b..909cfcb20ac6e 100644 --- a/util/systemd/systemd_linux.go +++ b/util/systemd/systemd_linux.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package systemd - -import ( - "errors" - "log" - "os" - "sync" - - "github.com/mdlayher/sdnotify" -) - -var getNotifyOnce struct { - sync.Once - v *sdnotify.Notifier -} - -type logOnce struct { - sync.Once -} - -func (l *logOnce) logf(format string, args ...any) { - l.Once.Do(func() { - log.Printf(format, args...) - }) -} - -var ( - readyOnce = &logOnce{} - statusOnce = &logOnce{} -) - -func notifier() *sdnotify.Notifier { - getNotifyOnce.Do(func() { - var err error - getNotifyOnce.v, err = sdnotify.New() - // Not exist means probably not running under systemd, so don't log. - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Printf("systemd: systemd-notifier error: %v", err) - } - }) - return getNotifyOnce.v -} - -// Ready signals readiness to systemd. This will unblock service dependents from starting. -func Ready() { - err := notifier().Notify(sdnotify.Ready) - if err != nil { - readyOnce.logf("systemd: error notifying: %v", err) - } -} - -// Status sends a single line status update to systemd so that information shows up -// in systemctl output. For example: -// -// $ systemctl status tailscale -// ● tailscale.service - Tailscale client daemon -// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) -// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago -// Main PID: 26741 (.tailscaled-wra) -// Status: "Connected; user@host.domain.tld; 100.101.102.103" -// IP: 0B in, 0B out -// Tasks: 22 (limit: 4915) -// Memory: 30.9M -// CPU: 2min 38.469s -// CGroup: /system.slice/tailscale.service -// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 -func Status(format string, args ...any) { - err := notifier().Notify(sdnotify.Statusf(format, args...)) - if err != nil { - statusOnce.logf("systemd: error notifying: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package systemd + +import ( + "errors" + "log" + "os" + "sync" + + "github.com/mdlayher/sdnotify" +) + +var getNotifyOnce struct { + sync.Once + v *sdnotify.Notifier +} + +type logOnce struct { + sync.Once +} + +func (l *logOnce) logf(format string, args ...any) { + l.Once.Do(func() { + log.Printf(format, args...) + }) +} + +var ( + readyOnce = &logOnce{} + statusOnce = &logOnce{} +) + +func notifier() *sdnotify.Notifier { + getNotifyOnce.Do(func() { + var err error + getNotifyOnce.v, err = sdnotify.New() + // Not exist means probably not running under systemd, so don't log. + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Printf("systemd: systemd-notifier error: %v", err) + } + }) + return getNotifyOnce.v +} + +// Ready signals readiness to systemd. This will unblock service dependents from starting. +func Ready() { + err := notifier().Notify(sdnotify.Ready) + if err != nil { + readyOnce.logf("systemd: error notifying: %v", err) + } +} + +// Status sends a single line status update to systemd so that information shows up +// in systemctl output. For example: +// +// $ systemctl status tailscale +// ● tailscale.service - Tailscale client daemon +// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) +// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago +// Main PID: 26741 (.tailscaled-wra) +// Status: "Connected; user@host.domain.tld; 100.101.102.103" +// IP: 0B in, 0B out +// Tasks: 22 (limit: 4915) +// Memory: 30.9M +// CPU: 2min 38.469s +// CGroup: /system.slice/tailscale.service +// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 +func Status(format string, args ...any) { + err := notifier().Notify(sdnotify.Statusf(format, args...)) + if err != nil { + statusOnce.logf("systemd: error notifying: %v", err) + } +} diff --git a/util/systemd/systemd_nonlinux.go b/util/systemd/systemd_nonlinux.go index d8b20665fb7ba..36214020ce566 100644 --- a/util/systemd/systemd_nonlinux.go +++ b/util/systemd/systemd_nonlinux.go @@ -1,9 +1,9 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package systemd - -func Ready() {} -func Status(string, ...any) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package systemd + +func Ready() {} +func Status(string, ...any) {} diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 02c688803a943..12ada9003052b 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testenv provides utility functions for tests. It does not depend on -// the `testing` package to allow usage in non-test code. -package testenv - -import ( - "flag" - - "tailscale.com/types/lazy" -) - -var lazyInTest lazy.SyncValue[bool] - -// InTest reports whether the current binary is a test binary. -func InTest() bool { - return lazyInTest.Get(func() bool { - return flag.Lookup("test.v") != nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testenv provides utility functions for tests. It does not depend on +// the `testing` package to allow usage in non-test code. +package testenv + +import ( + "flag" + + "tailscale.com/types/lazy" +) + +var lazyInTest lazy.SyncValue[bool] + +// InTest reports whether the current binary is a test binary. +func InTest() bool { + return lazyInTest.Get(func() bool { + return flag.Lookup("test.v") != nil + }) +} diff --git a/util/truncate/truncate_test.go b/util/truncate/truncate_test.go index 6ead55a6ae76e..c0d9e6e14df99 100644 --- a/util/truncate/truncate_test.go +++ b/util/truncate/truncate_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package truncate_test - -import ( - "testing" - - "tailscale.com/util/truncate" -) - -func TestString(t *testing.T) { - tests := []struct { - input string - size int - want string - }{ - {"", 1000, ""}, // n > length - {"abc", 4, "abc"}, // n > length - {"abc", 3, "abc"}, // n == length - {"abcdefg", 4, "abcd"}, // n < length, safe - {"abcdefg", 0, ""}, // n < length, safe - {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary - {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune - {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte - } - - for _, tc := range tests { - got := truncate.String(tc.input, tc.size) - if got != tc.want { - t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package truncate_test + +import ( + "testing" + + "tailscale.com/util/truncate" +) + +func TestString(t *testing.T) { + tests := []struct { + input string + size int + want string + }{ + {"", 1000, ""}, // n > length + {"abc", 4, "abc"}, // n > length + {"abc", 3, "abc"}, // n == length + {"abcdefg", 4, "abcd"}, // n < length, safe + {"abcdefg", 0, ""}, // n < length, safe + {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary + {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune + {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte + } + + for _, tc := range tests { + got := truncate.String(tc.input, tc.size) + if got != tc.want { + t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) + } + } +} diff --git a/util/uniq/slice.go b/util/uniq/slice.go index fb46cc491f5d7..4ab933a9d82d1 100644 --- a/util/uniq/slice.go +++ b/util/uniq/slice.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package uniq provides removal of adjacent duplicate elements in slices. -// It is similar to the unix command uniq. -package uniq - -// ModifySlice removes adjacent duplicate elements from the given slice. It -// adjusts the length of the slice appropriately and zeros the tail. -// -// ModifySlice does O(len(*slice)) operations. -func ModifySlice[E comparable](slice *[]E) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if (*slice)[i] == (*slice)[dst] { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} - -// ModifySliceFunc is the same as ModifySlice except that it allows using a -// custom comparison function. -// -// eq should report whether the two provided elements are equal. -func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if eq((*slice)[dst], (*slice)[i]) { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package uniq provides removal of adjacent duplicate elements in slices. +// It is similar to the unix command uniq. +package uniq + +// ModifySlice removes adjacent duplicate elements from the given slice. It +// adjusts the length of the slice appropriately and zeros the tail. +// +// ModifySlice does O(len(*slice)) operations. +func ModifySlice[E comparable](slice *[]E) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if (*slice)[i] == (*slice)[dst] { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} + +// ModifySliceFunc is the same as ModifySlice except that it allows using a +// custom comparison function. +// +// eq should report whether the two provided elements are equal. +func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if eq((*slice)[dst], (*slice)[i]) { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} diff --git a/util/winutil/authenticode/mksyscall.go b/util/winutil/authenticode/mksyscall.go index 7c6b33973de8e..8b7cabe6e4d7f 100644 --- a/util/winutil/authenticode/mksyscall.go +++ b/util/winutil/authenticode/mksyscall.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package authenticode - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 -//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 -//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext -//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash -//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext -//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext -//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose -//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam -//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature -//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package authenticode + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 +//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 +//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext +//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash +//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext +//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext +//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose +//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam +//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature +//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW diff --git a/util/winutil/policy/policy_windows.go b/util/winutil/policy/policy_windows.go index 4674696fa101d..89142951f8bd5 100644 --- a/util/winutil/policy/policy_windows.go +++ b/util/winutil/policy/policy_windows.go @@ -1,155 +1,155 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains higher-level abstractions for accessing Windows enterprise policies. -package policy - -import ( - "time" - - "tailscale.com/util/winutil" -) - -// PreferenceOptionPolicy is a policy that governs whether a boolean variable -// is forcibly assigned an administrator-defined value, or allowed to receive -// a user-defined value. -type PreferenceOptionPolicy int - -const ( - showChoiceByPolicy PreferenceOptionPolicy = iota - neverByPolicy - alwaysByPolicy -) - -// Show returns if the UI option that controls the choice administered by this -// policy should be shown. Currently this is true if and only if the policy is -// showChoiceByPolicy. -func (p PreferenceOptionPolicy) Show() bool { - return p == showChoiceByPolicy -} - -// ShouldEnable checks if the choice administered by this policy should be -// enabled. If the administrator has chosen a setting, the administrator's -// setting is returned, otherwise userChoice is returned. -func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { - switch p { - case neverByPolicy: - return false - case alwaysByPolicy: - return true - default: - return userChoice - } -} - -// GetPreferenceOptionPolicy loads a policy from the registry that can be -// managed by an enterprise policy management system and allows administrative -// overrides of users' choices in a way that we do not want tailcontrol to have -// the authority to set. It describes user-decides/always/never options, where -// "always" and "never" remove the user's ability to make a selection. If not -// present or set to a different value, "user-decides" is the default. -func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return showChoiceByPolicy - } - switch opt { - case "always": - return alwaysByPolicy - case "never": - return neverByPolicy - default: - return showChoiceByPolicy - } -} - -// VisibilityPolicy is a policy that controls whether or not a particular -// component of a user interface is to be shown. -type VisibilityPolicy byte - -const ( - visibleByPolicy VisibilityPolicy = 'v' - hiddenByPolicy VisibilityPolicy = 'h' -) - -// Show reports whether the UI option administered by this policy should be shown. -// Currently this is true if and only if the policy is visibleByPolicy. -func (p VisibilityPolicy) Show() bool { - return p == visibleByPolicy -} - -// GetVisibilityPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes show/hide decisions -// for UI elements. The registry value should be a string set to "show" (return -// true) or "hide" (return true). If not present or set to a different value, -// "show" (return false) is the default. -func GetVisibilityPolicy(name string) VisibilityPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return visibleByPolicy - } - switch opt { - case "hide": - return hiddenByPolicy - default: - return visibleByPolicy - } -} - -// GetDurationPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes a duration for some -// action. The registry value should be a string that time.ParseDuration -// understands. If the registry value is "" or can not be processed, -// defaultValue is returned instead. -func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return defaultValue - } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { - return defaultValue - } - return v -} - -// SelectControlURL returns the ControlURL to use based on a value in -// the registry (LoginURL) and the one on disk (in the GUI's -// prefs.conf). If both are empty, it returns a default value. (It -// always return a non-empty value) -// -// See https://github.com/tailscale/tailscale/issues/2798 for some background. -func SelectControlURL(reg, disk string) string { - const def = "https://controlplane.tailscale.com" - - // Prior to Dec 2020's commit 739b02e6, the installer - // wrote a LoginURL value of https://login.tailscale.com to the registry. - const oldRegDef = "https://login.tailscale.com" - - // If they have an explicit value in the registry, use it, - // unless it's an old default value from an old installer. - // Then we have to see which is better. - if reg != "" { - if reg != oldRegDef { - // Something explicit in the registry that we didn't - // set ourselves by the installer. - return reg - } - if disk == "" { - // Something in the registry is better than nothing on disk. - return reg - } - if disk != def && disk != oldRegDef { - // The value in the registry is the old - // default (login.tailscale.com) but the value - // on disk is neither our old nor new default - // value, so it must be some custom thing that - // the user cares about. Prefer the disk value. - return disk - } - } - if disk != "" { - return disk - } - return def -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains higher-level abstractions for accessing Windows enterprise policies. +package policy + +import ( + "time" + + "tailscale.com/util/winutil" +) + +// PreferenceOptionPolicy is a policy that governs whether a boolean variable +// is forcibly assigned an administrator-defined value, or allowed to receive +// a user-defined value. +type PreferenceOptionPolicy int + +const ( + showChoiceByPolicy PreferenceOptionPolicy = iota + neverByPolicy + alwaysByPolicy +) + +// Show returns if the UI option that controls the choice administered by this +// policy should be shown. Currently this is true if and only if the policy is +// showChoiceByPolicy. +func (p PreferenceOptionPolicy) Show() bool { + return p == showChoiceByPolicy +} + +// ShouldEnable checks if the choice administered by this policy should be +// enabled. If the administrator has chosen a setting, the administrator's +// setting is returned, otherwise userChoice is returned. +func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { + switch p { + case neverByPolicy: + return false + case alwaysByPolicy: + return true + default: + return userChoice + } +} + +// GetPreferenceOptionPolicy loads a policy from the registry that can be +// managed by an enterprise policy management system and allows administrative +// overrides of users' choices in a way that we do not want tailcontrol to have +// the authority to set. It describes user-decides/always/never options, where +// "always" and "never" remove the user's ability to make a selection. If not +// present or set to a different value, "user-decides" is the default. +func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return showChoiceByPolicy + } + switch opt { + case "always": + return alwaysByPolicy + case "never": + return neverByPolicy + default: + return showChoiceByPolicy + } +} + +// VisibilityPolicy is a policy that controls whether or not a particular +// component of a user interface is to be shown. +type VisibilityPolicy byte + +const ( + visibleByPolicy VisibilityPolicy = 'v' + hiddenByPolicy VisibilityPolicy = 'h' +) + +// Show reports whether the UI option administered by this policy should be shown. +// Currently this is true if and only if the policy is visibleByPolicy. +func (p VisibilityPolicy) Show() bool { + return p == visibleByPolicy +} + +// GetVisibilityPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes show/hide decisions +// for UI elements. The registry value should be a string set to "show" (return +// true) or "hide" (return true). If not present or set to a different value, +// "show" (return false) is the default. +func GetVisibilityPolicy(name string) VisibilityPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return visibleByPolicy + } + switch opt { + case "hide": + return hiddenByPolicy + default: + return visibleByPolicy + } +} + +// GetDurationPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes a duration for some +// action. The registry value should be a string that time.ParseDuration +// understands. If the registry value is "" or can not be processed, +// defaultValue is returned instead. +func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return defaultValue + } + v, err := time.ParseDuration(opt) + if err != nil || v < 0 { + return defaultValue + } + return v +} + +// SelectControlURL returns the ControlURL to use based on a value in +// the registry (LoginURL) and the one on disk (in the GUI's +// prefs.conf). If both are empty, it returns a default value. (It +// always return a non-empty value) +// +// See https://github.com/tailscale/tailscale/issues/2798 for some background. +func SelectControlURL(reg, disk string) string { + const def = "https://controlplane.tailscale.com" + + // Prior to Dec 2020's commit 739b02e6, the installer + // wrote a LoginURL value of https://login.tailscale.com to the registry. + const oldRegDef = "https://login.tailscale.com" + + // If they have an explicit value in the registry, use it, + // unless it's an old default value from an old installer. + // Then we have to see which is better. + if reg != "" { + if reg != oldRegDef { + // Something explicit in the registry that we didn't + // set ourselves by the installer. + return reg + } + if disk == "" { + // Something in the registry is better than nothing on disk. + return reg + } + if disk != def && disk != oldRegDef { + // The value in the registry is the old + // default (login.tailscale.com) but the value + // on disk is neither our old nor new default + // value, so it must be some custom thing that + // the user cares about. Prefer the disk value. + return disk + } + } + if disk != "" { + return disk + } + return def +} diff --git a/util/winutil/policy/policy_windows_test.go b/util/winutil/policy/policy_windows_test.go index ebfd185deaaf2..cf2390c568cce 100644 --- a/util/winutil/policy/policy_windows_test.go +++ b/util/winutil/policy/policy_windows_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package policy - -import "testing" - -func TestSelectControlURL(t *testing.T) { - tests := []struct { - reg, disk, want string - }{ - // Modern default case. - {"", "", "https://controlplane.tailscale.com"}, - - // For a user who installed prior to Dec 2020, with - // stuff in their registry. - {"https://login.tailscale.com", "", "https://login.tailscale.com"}, - - // Ignore pre-Dec'20 LoginURL from installer if prefs - // prefs overridden manually to an on-prem control - // server. - {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, - - // Something unknown explicitly set in the registry always wins. - {"http://explicit-reg", "", "http://explicit-reg"}, - {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, - {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, - {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, - - // If nothing in the registry, disk wins. - {"", "http://on-prem", "http://on-prem"}, - } - for _, tt := range tests { - if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { - t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package policy + +import "testing" + +func TestSelectControlURL(t *testing.T) { + tests := []struct { + reg, disk, want string + }{ + // Modern default case. + {"", "", "https://controlplane.tailscale.com"}, + + // For a user who installed prior to Dec 2020, with + // stuff in their registry. + {"https://login.tailscale.com", "", "https://login.tailscale.com"}, + + // Ignore pre-Dec'20 LoginURL from installer if prefs + // prefs overridden manually to an on-prem control + // server. + {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, + + // Something unknown explicitly set in the registry always wins. + {"http://explicit-reg", "", "http://explicit-reg"}, + {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, + {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, + {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, + + // If nothing in the registry, disk wins. + {"", "http://on-prem", "http://on-prem"}, + } + for _, tt := range tests { + if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { + t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) + } + } +} diff --git a/version/.gitignore b/version/.gitignore index 8878450fa4364..58d19bfc27c97 100644 --- a/version/.gitignore +++ b/version/.gitignore @@ -1,10 +1,10 @@ -describe.txt -long.txt -short.txt -gitcommit.txt -extragitcommit.txt -version-info.sh -version.h -version.xcconfig -ver.go -version +describe.txt +long.txt +short.txt +gitcommit.txt +extragitcommit.txt +version-info.sh +version.h +version.xcconfig +ver.go +version diff --git a/version/cmdname.go b/version/cmdname.go index 9f85ef96d427f..51e065438e3a5 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -1,139 +1,139 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios - -package version - -import ( - "bytes" - "encoding/hex" - "errors" - "io" - "os" - "path" - "path/filepath" - "strings" -) - -// CmdName returns either the base name of the current binary -// using os.Executable. If os.Executable fails (it shouldn't), then -// "cmd" is returned. -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return cmdName(e) -} - -func cmdName(exe string) string { - // fallbackName, the lowercase basename of the executable, is what we return if - // we can't find the Go module metadata embedded in the file. - fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) - - var ret string - info, err := findModuleInfo(exe) - if err != nil { - return fallbackName - } - // v is like: - // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... - for _, line := range strings.Split(info, "\n") { - if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" - ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath - break - } - } - if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { - // The tailscale-ipn.exe binary for internal build system packaging reasons - // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. - // Ignore that name and use "tailscale-ipn" instead. - return fallbackName - } - if ret == "" { - return fallbackName - } - return ret -} - -// findModuleInfo returns the Go module info from the executable file. -func findModuleInfo(file string) (s string, err error) { - f, err := os.Open(file) - if err != nil { - return "", err - } - defer f.Close() - // Scan through f until we find infoStart. - buf := make([]byte, 65536) - start, err := findOffset(f, buf, infoStart) - if err != nil { - return "", err - } - start += int64(len(infoStart)) - // Seek to the end of infoStart and scan for infoEnd. - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - end, err := findOffset(f, buf, infoEnd) - if err != nil { - return "", err - } - length := end - start - // As of Aug 2021, tailscaled's mod info was about 2k. - if length > int64(len(buf)) { - return "", errors.New("mod info too large") - } - // We have located modinfo. Read it into buf. - buf = buf[:length] - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - _, err = io.ReadFull(f, buf) - if err != nil { - return "", err - } - return string(buf), nil -} - -// findOffset finds the absolute offset of needle in f, -// starting at f's current read position, -// using temporary buffer buf. -func findOffset(f *os.File, buf, needle []byte) (int64, error) { - for { - // Fill buf and look within it. - n, err := f.Read(buf) - if err != nil { - return -1, err - } - i := bytes.Index(buf[:n], needle) - if i < 0 { - // Not found. Rewind a little bit in case we happened to end halfway through needle. - rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) - if err != nil { - return -1, err - } - // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. - _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) - if err == io.EOF { - return -1, err - } - continue - } - // Found! Figure out exactly where. - cur, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return -1, err - } - return cur - int64(n) + int64(i), nil - } -} - -// These constants are taken from rsc.io/goversion. - -var ( - infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") - infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios + +package version + +import ( + "bytes" + "encoding/hex" + "errors" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +// CmdName returns either the base name of the current binary +// using os.Executable. If os.Executable fails (it shouldn't), then +// "cmd" is returned. +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return cmdName(e) +} + +func cmdName(exe string) string { + // fallbackName, the lowercase basename of the executable, is what we return if + // we can't find the Go module metadata embedded in the file. + fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) + + var ret string + info, err := findModuleInfo(exe) + if err != nil { + return fallbackName + } + // v is like: + // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... + for _, line := range strings.Split(info, "\n") { + if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" + ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath + break + } + } + if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { + // The tailscale-ipn.exe binary for internal build system packaging reasons + // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. + // Ignore that name and use "tailscale-ipn" instead. + return fallbackName + } + if ret == "" { + return fallbackName + } + return ret +} + +// findModuleInfo returns the Go module info from the executable file. +func findModuleInfo(file string) (s string, err error) { + f, err := os.Open(file) + if err != nil { + return "", err + } + defer f.Close() + // Scan through f until we find infoStart. + buf := make([]byte, 65536) + start, err := findOffset(f, buf, infoStart) + if err != nil { + return "", err + } + start += int64(len(infoStart)) + // Seek to the end of infoStart and scan for infoEnd. + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + end, err := findOffset(f, buf, infoEnd) + if err != nil { + return "", err + } + length := end - start + // As of Aug 2021, tailscaled's mod info was about 2k. + if length > int64(len(buf)) { + return "", errors.New("mod info too large") + } + // We have located modinfo. Read it into buf. + buf = buf[:length] + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + _, err = io.ReadFull(f, buf) + if err != nil { + return "", err + } + return string(buf), nil +} + +// findOffset finds the absolute offset of needle in f, +// starting at f's current read position, +// using temporary buffer buf. +func findOffset(f *os.File, buf, needle []byte) (int64, error) { + for { + // Fill buf and look within it. + n, err := f.Read(buf) + if err != nil { + return -1, err + } + i := bytes.Index(buf[:n], needle) + if i < 0 { + // Not found. Rewind a little bit in case we happened to end halfway through needle. + rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) + if err != nil { + return -1, err + } + // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. + _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) + if err == io.EOF { + return -1, err + } + continue + } + // Found! Figure out exactly where. + cur, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return -1, err + } + return cur - int64(n) + int64(i), nil + } +} + +// These constants are taken from rsc.io/goversion. + +var ( + infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") + infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") +) diff --git a/version/cmdname_ios.go b/version/cmdname_ios.go index 5e338944c6916..6bfed38b64226 100644 --- a/version/cmdname_ios.go +++ b/version/cmdname_ios.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios - -package version - -import ( - "os" -) - -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return e -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios + +package version + +import ( + "os" +) + +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return e +} diff --git a/version/cmp_test.go b/version/cmp_test.go index 59153f0dd15d0..e244d5e16fe22 100644 --- a/version/cmp_test.go +++ b/version/cmp_test.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tstest" - "tailscale.com/version" -) - -func TestParse(t *testing.T) { - parse := version.ExportParse - type parsed = version.ExportParsed - - tests := []struct { - version string - parsed parsed - want bool - }{ - {"1", parsed{Major: 1}, true}, - {"1.2", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, - {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"date.20200612", parsed{Datestamp: 20200612}, true}, - {"borkbork", parsed{}, false}, - {"1a.2.3", parsed{}, false}, - {"", parsed{}, false}, - } - - for _, test := range tests { - gotParsed, got := parse(test.version) - if got != test.want { - t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) - } - if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { - t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) - } - err := tstest.MinAllocsPerRun(t, 0, func() { - gotParsed, got = parse(test.version) - }) - if err != nil { - t.Errorf("parse(%q): %v", test.version, err) - } - } -} - -func TestAtLeast(t *testing.T) { - tests := []struct { - v, m string - want bool - }{ - {"1", "1", true}, - {"1.2", "1", true}, - {"1.2.3", "1", true}, - {"1.2.3-4", "1", true}, - {"0.98-0", "0.98", true}, - {"0.97.1-216", "0.98", false}, - {"0.94", "0.98", false}, - {"0.98", "0.98", true}, - {"0.98.0-0", "0.98", true}, - {"1.2.3-4", "1.2.4-4", false}, - {"1.2.3-4", "1.2.3-4", true}, - {"date.20200612", "date.20200612", true}, - {"date.20200701", "date.20200612", true}, - {"date.20200501", "date.20200612", false}, - } - - for _, test := range tests { - got := version.AtLeast(test.v, test.m) - if got != test.want { - t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tstest" + "tailscale.com/version" +) + +func TestParse(t *testing.T) { + parse := version.ExportParse + type parsed = version.ExportParsed + + tests := []struct { + version string + parsed parsed + want bool + }{ + {"1", parsed{Major: 1}, true}, + {"1.2", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, + {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"date.20200612", parsed{Datestamp: 20200612}, true}, + {"borkbork", parsed{}, false}, + {"1a.2.3", parsed{}, false}, + {"", parsed{}, false}, + } + + for _, test := range tests { + gotParsed, got := parse(test.version) + if got != test.want { + t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) + } + if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { + t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) + } + err := tstest.MinAllocsPerRun(t, 0, func() { + gotParsed, got = parse(test.version) + }) + if err != nil { + t.Errorf("parse(%q): %v", test.version, err) + } + } +} + +func TestAtLeast(t *testing.T) { + tests := []struct { + v, m string + want bool + }{ + {"1", "1", true}, + {"1.2", "1", true}, + {"1.2.3", "1", true}, + {"1.2.3-4", "1", true}, + {"0.98-0", "0.98", true}, + {"0.97.1-216", "0.98", false}, + {"0.94", "0.98", false}, + {"0.98", "0.98", true}, + {"0.98.0-0", "0.98", true}, + {"1.2.3-4", "1.2.4-4", false}, + {"1.2.3-4", "1.2.3-4", true}, + {"date.20200612", "date.20200612", true}, + {"date.20200701", "date.20200612", true}, + {"date.20200501", "date.20200612", false}, + } + + for _, test := range tests { + got := version.AtLeast(test.v, test.m) + if got != test.want { + t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) + } + } +} diff --git a/version/export_test.go b/version/export_test.go index fabba13e8ba55..8e8ce5ecb2129 100644 --- a/version/export_test.go +++ b/version/export_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -var ( - ExportParse = parse - ExportFindModuleInfo = findModuleInfo - ExportCmdName = cmdName -) - -type ( - ExportParsed = parsed -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +var ( + ExportParse = parse + ExportFindModuleInfo = findModuleInfo + ExportCmdName = cmdName +) + +type ( + ExportParsed = parsed +) diff --git a/version/print.go b/version/print.go index e3bfc38efa16c..7d8554279f255 100644 --- a/version/print.go +++ b/version/print.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -import ( - "fmt" - "runtime" - "strings" - - "tailscale.com/types/lazy" -) - -var stringLazy = lazy.SyncFunc(func() string { - var ret strings.Builder - ret.WriteString(Short()) - ret.WriteByte('\n') - if IsUnstableBuild() { - fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") - } - if gitCommit() != "" { - fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) - } - if extraGitCommitStamp != "" { - fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) - } - fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) - return strings.TrimSpace(ret.String()) -}) - -func String() string { - return stringLazy() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import ( + "fmt" + "runtime" + "strings" + + "tailscale.com/types/lazy" +) + +var stringLazy = lazy.SyncFunc(func() string { + var ret strings.Builder + ret.WriteString(Short()) + ret.WriteByte('\n') + if IsUnstableBuild() { + fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") + } + if gitCommit() != "" { + fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) + } + if extraGitCommitStamp != "" { + fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) + } + fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + return strings.TrimSpace(ret.String()) +}) + +func String() string { + return stringLazy() +} diff --git a/version/race.go b/version/race.go index bc3ca8db6b6dd..e1dc76591ebf4 100644 --- a/version/race.go +++ b/version/race.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return true } diff --git a/version/race_off.go b/version/race_off.go index d55288d9cc962..6db901974bb77 100644 --- a/version/race_off.go +++ b/version/race_off.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return false } diff --git a/version/version_test.go b/version/version_test.go index 4d676f9f5ea1f..a515650586cc4 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "bytes" - "os" - "testing" - - ts "tailscale.com" - "tailscale.com/version" -) - -func TestAlpineTag(t *testing.T) { - if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } - if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } -} - -func readAlpineTag(t *testing.T, file string) string { - f, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - for _, line := range bytes.Split(f, []byte{'\n'}) { - line = bytes.TrimSpace(line) - _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) - if !ok { - continue - } - return string(suf) - } - return "" -} - -func TestShortAllocs(t *testing.T) { - allocs := int(testing.AllocsPerRun(10000, func() { - _ = version.Short() - })) - if allocs > 0 { - t.Errorf("allocs = %v; want 0", allocs) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "bytes" + "os" + "testing" + + ts "tailscale.com" + "tailscale.com/version" +) + +func TestAlpineTag(t *testing.T) { + if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } + if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } +} + +func readAlpineTag(t *testing.T, file string) string { + f, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + for _, line := range bytes.Split(f, []byte{'\n'}) { + line = bytes.TrimSpace(line) + _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) + if !ok { + continue + } + return string(suf) + } + return "" +} + +func TestShortAllocs(t *testing.T) { + allocs := int(testing.AllocsPerRun(10000, func() { + _ = version.Short() + })) + if allocs > 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} diff --git a/wgengine/bench/bench.go b/wgengine/bench/bench.go index b94930ee50c11..8695f18d15899 100644 --- a/wgengine/bench/bench.go +++ b/wgengine/bench/bench.go @@ -1,409 +1,409 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "bufio" - "io" - "log" - "net" - "net/http" - "net/http/pprof" - "net/netip" - "os" - "strconv" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const PayloadSize = 1000 -const ICMPMinSize = 24 - -var Addr1 = netip.MustParsePrefix("100.64.1.1/32") -var Addr2 = netip.MustParsePrefix("100.64.1.2/32") - -func main() { - var logf logger.Logf = log.Printf - log.SetFlags(0) - - debugMux := newDebugMux() - go runDebugServer(debugMux, "0.0.0.0:8999") - - mode, err := strconv.Atoi(os.Args[1]) - if err != nil { - log.Fatalf("%q: %v", os.Args[1], err) - } - - traf := NewTrafficGen(nil) - - // Sample test results below are using GOMAXPROCS=2 (for some - // tests, including wireguard-go, higher GOMAXPROCS goes slower) - // on apenwarr's old Linux box: - // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz - // My 2019 Mac Mini is about 20% faster on most tests. - - switch mode { - // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) - case 1: - setupTrivialNoAllocTest(logf, traf) - - // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) - case 2: - setupTrivialTest(logf, traf) - - // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) - case 11: - setupBlockingChannelTest(logf, traf) - - // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) - // (much faster on macOS??) - case 12: - setupNonblockingChannelTest(logf, traf) - - // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) - // (much faster on macOS??) - case 13: - setupDoubleChannelTest(logf, traf) - - // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) - case 21: - setupUDPTest(logf, traf) - - // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) - case 31: - setupBatchTCPTest(logf, traf) - - // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) - case 101: - setupWGTest(nil, logf, traf, Addr1, Addr2) - - default: - log.Fatalf("provide a valid test number (0..n)") - } - - logf("initialized ok.") - traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) - - var cur, prev Snapshot - var pps int64 - i := 0 - for { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec == 0 { - logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) - } else { - logf("%v @%7d pkt/s", d, pps) - } - } - - pps = traf.Adjust() - } -} - -func newDebugMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - return mux -} - -func runDebugServer(mux *http.ServeMux, addr string) { - srv := &http.Server{ - Addr: addr, - Handler: mux, - } - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } -} - -// The absolute minimal test of the traffic generator: have it fill -// a packet buffer, then absorb it again. Zero packet loss. -func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { - go func() { - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Almost the same, but this time allocate a fresh buffer each time -// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. -func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { - go func() { - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Pass packets through a blocking channel between sender and receiver. -// Still zero packet loss since the sender stops when the channel is full. -// Max speed depends on channel length (I'm not sure why). -func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - ch <- b[0 : n+16] - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as setupBlockingChannelTest, but now we drop packets whenever the -// channel is full. Max speed is about the same as the above test, but -// now with nonzero packet loss. -func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as above, but at an intermediate blocking channel and goroutine -// to make things a little more like wireguard-go. Roughly 20% slower than -// the single-channel version. -func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - ch2 := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // intermediary - for b := range ch { - ch2 <- b - } - close(ch2) - }() - - go func() { - // receiver - for b := range ch2 { - traf.GotPacket(b, 16) - } - }() -} - -// Instead of a channel, pass packets through a UDP socket. -func setupUDPTest(logf logger.Logf, traf *TrafficGen) { - la, err := net.ResolveUDPAddr("udp", ":0") - if err != nil { - log.Fatalf("resolve: %v", err) - } - - s1, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen1: %v", err) - } - s2, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen2: %v", err) - } - - a2 := s2.LocalAddr() - - // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, - // which is what returns from .LocalAddr() above. We have to - // force it to localhost instead. - a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") - - s1.SetWriteBuffer(1024 * 1024) - s2.SetReadBuffer(1024 * 1024) - - go func() { - // transmitter - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - s1.WriteTo(b[16:n+16], a2) - } - }() - - go func() { - // receiver - b := make([]byte, 1600) - for traf.Running() { - // Use ReadFrom instead of Read, to be more like - // how wireguard-go does it, even though we're not - // going to actually look at the address. - n, _, err := s2.ReadFrom(b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} - -// Instead of a channel, pass packets through a TCP socket. -// TCP is a single stream, so we can amortize one syscall across -// multiple packets. 10x amortization seems to make it go ~10x faster, -// as expected, getting us close to the speed of the channel tests above. -// There's also zero packet loss. -func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { - sl, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("listen: %v", err) - } - - var slCloseOnce sync.Once - slClose := func() { - slCloseOnce.Do(func() { - sl.Close() - }) - } - - s1, err := net.Dial("tcp", sl.Addr().String()) - if err != nil { - log.Fatalf("dial: %v", err) - } - - s2, err := sl.Accept() - if err != nil { - log.Fatalf("accept: %v", err) - } - - s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) - s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) - - ch := make(chan int) - - go func() { - // transmitter - defer slClose() - defer s1.Close() - - bs1 := bufio.NewWriterSize(s1, 1024*1024) - - b := make([]byte, 1600) - i := 0 - for { - i += 1 - n := traf.Generate(b, 16) - if n == 0 { - break - } - if i == 1 { - ch <- n - } - bs1.Write(b[16 : n+16]) - - // TODO: this is a pretty half-baked batching - // function, which we'd never want to employ in - // a real-life program. - // - // In real life, we'd probably want to flush - // immediately when there are no more packets to - // generate, and queue up only if we fall behind. - // - // In our case however, we just want to see the - // technical benefits of batching 10 syscalls - // into 1, so a fixed ratio makes more sense. - if (i % 10) == 0 { - bs1.Flush() - } - } - }() - - go func() { - // receiver - defer slClose() - defer s2.Close() - - bs2 := bufio.NewReaderSize(s2, 1024*1024) - - // Find out the packet size (we happen to know they're - // all the same size) - packetSize := <-ch - - b := make([]byte, packetSize) - for traf.Running() { - // TODO: can't use ReadFrom() here, which is - // unfair compared to UDP. (ReadFrom for UDP - // apparently allocates memory per packet, which - // this test does not.) - n, err := io.ReadFull(bs2, b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "bufio" + "io" + "log" + "net" + "net/http" + "net/http/pprof" + "net/netip" + "os" + "strconv" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const PayloadSize = 1000 +const ICMPMinSize = 24 + +var Addr1 = netip.MustParsePrefix("100.64.1.1/32") +var Addr2 = netip.MustParsePrefix("100.64.1.2/32") + +func main() { + var logf logger.Logf = log.Printf + log.SetFlags(0) + + debugMux := newDebugMux() + go runDebugServer(debugMux, "0.0.0.0:8999") + + mode, err := strconv.Atoi(os.Args[1]) + if err != nil { + log.Fatalf("%q: %v", os.Args[1], err) + } + + traf := NewTrafficGen(nil) + + // Sample test results below are using GOMAXPROCS=2 (for some + // tests, including wireguard-go, higher GOMAXPROCS goes slower) + // on apenwarr's old Linux box: + // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz + // My 2019 Mac Mini is about 20% faster on most tests. + + switch mode { + // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) + case 1: + setupTrivialNoAllocTest(logf, traf) + + // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) + case 2: + setupTrivialTest(logf, traf) + + // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) + case 11: + setupBlockingChannelTest(logf, traf) + + // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) + // (much faster on macOS??) + case 12: + setupNonblockingChannelTest(logf, traf) + + // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) + // (much faster on macOS??) + case 13: + setupDoubleChannelTest(logf, traf) + + // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) + case 21: + setupUDPTest(logf, traf) + + // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) + case 31: + setupBatchTCPTest(logf, traf) + + // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) + case 101: + setupWGTest(nil, logf, traf, Addr1, Addr2) + + default: + log.Fatalf("provide a valid test number (0..n)") + } + + logf("initialized ok.") + traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) + + var cur, prev Snapshot + var pps int64 + i := 0 + for { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec == 0 { + logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) + } else { + logf("%v @%7d pkt/s", d, pps) + } + } + + pps = traf.Adjust() + } +} + +func newDebugMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func runDebugServer(mux *http.ServeMux, addr string) { + srv := &http.Server{ + Addr: addr, + Handler: mux, + } + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } +} + +// The absolute minimal test of the traffic generator: have it fill +// a packet buffer, then absorb it again. Zero packet loss. +func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { + go func() { + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Almost the same, but this time allocate a fresh buffer each time +// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. +func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { + go func() { + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Pass packets through a blocking channel between sender and receiver. +// Still zero packet loss since the sender stops when the channel is full. +// Max speed depends on channel length (I'm not sure why). +func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + ch <- b[0 : n+16] + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as setupBlockingChannelTest, but now we drop packets whenever the +// channel is full. Max speed is about the same as the above test, but +// now with nonzero packet loss. +func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as above, but at an intermediate blocking channel and goroutine +// to make things a little more like wireguard-go. Roughly 20% slower than +// the single-channel version. +func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + ch2 := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // intermediary + for b := range ch { + ch2 <- b + } + close(ch2) + }() + + go func() { + // receiver + for b := range ch2 { + traf.GotPacket(b, 16) + } + }() +} + +// Instead of a channel, pass packets through a UDP socket. +func setupUDPTest(logf logger.Logf, traf *TrafficGen) { + la, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + log.Fatalf("resolve: %v", err) + } + + s1, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen1: %v", err) + } + s2, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen2: %v", err) + } + + a2 := s2.LocalAddr() + + // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, + // which is what returns from .LocalAddr() above. We have to + // force it to localhost instead. + a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") + + s1.SetWriteBuffer(1024 * 1024) + s2.SetReadBuffer(1024 * 1024) + + go func() { + // transmitter + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + s1.WriteTo(b[16:n+16], a2) + } + }() + + go func() { + // receiver + b := make([]byte, 1600) + for traf.Running() { + // Use ReadFrom instead of Read, to be more like + // how wireguard-go does it, even though we're not + // going to actually look at the address. + n, _, err := s2.ReadFrom(b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} + +// Instead of a channel, pass packets through a TCP socket. +// TCP is a single stream, so we can amortize one syscall across +// multiple packets. 10x amortization seems to make it go ~10x faster, +// as expected, getting us close to the speed of the channel tests above. +// There's also zero packet loss. +func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { + sl, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("listen: %v", err) + } + + var slCloseOnce sync.Once + slClose := func() { + slCloseOnce.Do(func() { + sl.Close() + }) + } + + s1, err := net.Dial("tcp", sl.Addr().String()) + if err != nil { + log.Fatalf("dial: %v", err) + } + + s2, err := sl.Accept() + if err != nil { + log.Fatalf("accept: %v", err) + } + + s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) + s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) + + ch := make(chan int) + + go func() { + // transmitter + defer slClose() + defer s1.Close() + + bs1 := bufio.NewWriterSize(s1, 1024*1024) + + b := make([]byte, 1600) + i := 0 + for { + i += 1 + n := traf.Generate(b, 16) + if n == 0 { + break + } + if i == 1 { + ch <- n + } + bs1.Write(b[16 : n+16]) + + // TODO: this is a pretty half-baked batching + // function, which we'd never want to employ in + // a real-life program. + // + // In real life, we'd probably want to flush + // immediately when there are no more packets to + // generate, and queue up only if we fall behind. + // + // In our case however, we just want to see the + // technical benefits of batching 10 syscalls + // into 1, so a fixed ratio makes more sense. + if (i % 10) == 0 { + bs1.Flush() + } + } + }() + + go func() { + // receiver + defer slClose() + defer s2.Close() + + bs2 := bufio.NewReaderSize(s2, 1024*1024) + + // Find out the packet size (we happen to know they're + // all the same size) + packetSize := <-ch + + b := make([]byte, packetSize) + for traf.Running() { + // TODO: can't use ReadFrom() here, which is + // unfair compared to UDP. (ReadFrom for UDP + // apparently allocates memory per packet, which + // this test does not.) + n, err := io.ReadFull(bs2, b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} diff --git a/wgengine/bench/bench_test.go b/wgengine/bench/bench_test.go index 42571d0557115..4fae86c0580ba 100644 --- a/wgengine/bench/bench_test.go +++ b/wgengine/bench/bench_test.go @@ -1,108 +1,108 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "fmt" - "testing" - "time" - - "tailscale.com/types/logger" -) - -func BenchmarkTrivialNoAlloc(b *testing.B) { - run(b, setupTrivialNoAllocTest) -} -func BenchmarkTrivial(b *testing.B) { - run(b, setupTrivialTest) -} - -func BenchmarkBlockingChannel(b *testing.B) { - run(b, setupBlockingChannelTest) -} - -func BenchmarkNonblockingChannel(b *testing.B) { - run(b, setupNonblockingChannelTest) -} - -func BenchmarkDoubleChannel(b *testing.B) { - run(b, setupDoubleChannelTest) -} - -func BenchmarkUDP(b *testing.B) { - run(b, setupUDPTest) -} - -func BenchmarkBatchTCP(b *testing.B) { - run(b, setupBatchTCPTest) -} - -func BenchmarkWireGuardTest(b *testing.B) { - b.Skip("https://github.com/tailscale/tailscale/issues/2716") - run(b, func(logf logger.Logf, traf *TrafficGen) { - setupWGTest(b, logf, traf, Addr1, Addr2) - }) -} - -type SetupFunc func(logger.Logf, *TrafficGen) - -func run(b *testing.B, setup SetupFunc) { - sizes := []int{ - ICMPMinSize + 8, - ICMPMinSize + 100, - ICMPMinSize + 1000, - } - - for _, size := range sizes { - b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { - runOnce(b, setup, size) - }) - } -} - -func runOnce(b *testing.B, setup SetupFunc, payload int) { - b.StopTimer() - b.ReportAllocs() - - var logf logger.Logf = b.Logf - if !testing.Verbose() { - logf = logger.Discard - } - - traf := NewTrafficGen(b.StartTimer) - setup(logf, traf) - - logf("initialized. (n=%v)", b.N) - b.SetBytes(int64(payload)) - - traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) - - var cur, prev Snapshot - var pps int64 - i := 0 - for traf.Running() { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec != 0 { - logf("%v @%7d pkt/sec", d, pps) - } - } - - pps = traf.Adjust() - } - - cur = traf.Snap() - d := cur.Sub(prev) - loss := float64(d.LostPackets) / float64(d.RxPackets) - - b.ReportMetric(loss*100, "%lost") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "fmt" + "testing" + "time" + + "tailscale.com/types/logger" +) + +func BenchmarkTrivialNoAlloc(b *testing.B) { + run(b, setupTrivialNoAllocTest) +} +func BenchmarkTrivial(b *testing.B) { + run(b, setupTrivialTest) +} + +func BenchmarkBlockingChannel(b *testing.B) { + run(b, setupBlockingChannelTest) +} + +func BenchmarkNonblockingChannel(b *testing.B) { + run(b, setupNonblockingChannelTest) +} + +func BenchmarkDoubleChannel(b *testing.B) { + run(b, setupDoubleChannelTest) +} + +func BenchmarkUDP(b *testing.B) { + run(b, setupUDPTest) +} + +func BenchmarkBatchTCP(b *testing.B) { + run(b, setupBatchTCPTest) +} + +func BenchmarkWireGuardTest(b *testing.B) { + b.Skip("https://github.com/tailscale/tailscale/issues/2716") + run(b, func(logf logger.Logf, traf *TrafficGen) { + setupWGTest(b, logf, traf, Addr1, Addr2) + }) +} + +type SetupFunc func(logger.Logf, *TrafficGen) + +func run(b *testing.B, setup SetupFunc) { + sizes := []int{ + ICMPMinSize + 8, + ICMPMinSize + 100, + ICMPMinSize + 1000, + } + + for _, size := range sizes { + b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { + runOnce(b, setup, size) + }) + } +} + +func runOnce(b *testing.B, setup SetupFunc, payload int) { + b.StopTimer() + b.ReportAllocs() + + var logf logger.Logf = b.Logf + if !testing.Verbose() { + logf = logger.Discard + } + + traf := NewTrafficGen(b.StartTimer) + setup(logf, traf) + + logf("initialized. (n=%v)", b.N) + b.SetBytes(int64(payload)) + + traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) + + var cur, prev Snapshot + var pps int64 + i := 0 + for traf.Running() { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec != 0 { + logf("%v @%7d pkt/sec", d, pps) + } + } + + pps = traf.Adjust() + } + + cur = traf.Snap() + d := cur.Sub(prev) + loss := float64(d.LostPackets) / float64(d.RxPackets) + + b.ReportMetric(loss*100, "%lost") +} diff --git a/wgengine/bench/trafficgen.go b/wgengine/bench/trafficgen.go index 9de3c2e6bbc4b..ce79c616f86ed 100644 --- a/wgengine/bench/trafficgen.go +++ b/wgengine/bench/trafficgen.go @@ -1,259 +1,259 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/binary" - "fmt" - "log" - "net/netip" - "sync" - "time" - - "tailscale.com/net/packet" - "tailscale.com/types/ipproto" -) - -type Snapshot struct { - WhenNsec int64 // current time - timeAcc int64 // accumulated time (+NSecPerTx per transmit) - - LastSeqTx int64 // last sequence number sent - LastSeqRx int64 // last sequence number received - TotalLost int64 // packets out-of-order or lost so far - TotalOOO int64 // packets out-of-order so far - TotalBytesRx int64 // total bytes received so far -} - -type Delta struct { - DurationNsec int64 - TxPackets int64 - RxPackets int64 - LostPackets int64 - OOOPackets int64 - Bytes int64 -} - -func (b Snapshot) Sub(a Snapshot) Delta { - return Delta{ - DurationNsec: b.WhenNsec - a.WhenNsec, - TxPackets: b.LastSeqTx - a.LastSeqTx, - RxPackets: (b.LastSeqRx - a.LastSeqRx) - - (b.TotalLost - a.TotalLost) + - (b.TotalOOO - a.TotalOOO), - LostPackets: b.TotalLost - a.TotalLost, - OOOPackets: b.TotalOOO - a.TotalOOO, - Bytes: b.TotalBytesRx - a.TotalBytesRx, - } -} - -func (d Delta) String() string { - return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", - d.TxPackets, d.RxPackets, d.LostPackets, - float64(d.LostPackets)*100/float64(d.TxPackets), - d.OOOPackets, - float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) -} - -type TrafficGen struct { - mu sync.Mutex - cur, prev Snapshot // snapshots used for rate control - buf []byte // pre-generated packet buffer - done bool // true if the test has completed - - onFirstPacket func() // function to call on first received packet - - // maxPackets is the max packets to receive (not send) before - // ending the test. If it's zero, the test runs forever. - maxPackets int64 - - // nsPerPacket is the target average nanoseconds between packets. - // It's initially zero, which means transmit as fast as the - // caller wants to go. - nsPerPacket int64 - - // ppsHistory is the observed packets-per-second from recent - // samples. - ppsHistory [5]int64 -} - -// NewTrafficGen creates a new, initially locked, TrafficGen. -// Until Start() is called, Generate() will block forever. -func NewTrafficGen(onFirstPacket func()) *TrafficGen { - t := TrafficGen{ - onFirstPacket: onFirstPacket, - } - - // initially locked, until first Start() - t.mu.Lock() - - return &t -} - -// Start starts the traffic generator. It assumes mu is already locked, -// and unlocks it. -func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { - h12 := packet.ICMP4Header{ - IP4Header: packet.IP4Header{ - IPProto: ipproto.ICMPv4, - IPID: 0, - Src: src, - Dst: dst, - }, - Type: packet.ICMP4EchoRequest, - Code: packet.ICMP4NoCode, - } - - // ensure there's room for ICMP header plus sequence number - if bytesPerPacket < ICMPMinSize+8 { - log.Fatalf("bytesPerPacket must be > 24+8") - } - - t.maxPackets = maxPackets - - payload := make([]byte, bytesPerPacket-ICMPMinSize) - t.buf = packet.Generate(h12, payload) - - t.mu.Unlock() -} - -func (t *TrafficGen) Snap() Snapshot { - t.mu.Lock() - defer t.mu.Unlock() - - t.cur.WhenNsec = time.Now().UnixNano() - return t.cur -} - -func (t *TrafficGen) Running() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return !t.done -} - -// Generate produces the next packet in the sequence. It sleeps if -// it's too soon for the next packet to be sent. -// -// The generated packet is placed into buf at offset ofs, for compatibility -// with the wireguard-go conventions. -// -// The return value is the number of bytes generated in the packet, or 0 -// if the test has finished running. -func (t *TrafficGen) Generate(b []byte, ofs int) int { - t.mu.Lock() - - now := time.Now().UnixNano() - if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { - t.cur.timeAcc = now - 1 - } - if t.cur.timeAcc >= now { - // too soon - t.mu.Unlock() - time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) - t.mu.Lock() - - now = t.cur.timeAcc - } - if t.done { - t.mu.Unlock() - return 0 - } - - t.cur.timeAcc += t.nsPerPacket - t.cur.LastSeqTx += 1 - t.cur.WhenNsec = now - seq := t.cur.LastSeqTx - - t.mu.Unlock() - - copy(b[ofs:], t.buf) - binary.BigEndian.PutUint64( - b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], - uint64(seq)) - - return len(t.buf) -} - -// GotPacket processes a packet that came back on the receive side. -func (t *TrafficGen) GotPacket(b []byte, ofs int) { - t.mu.Lock() - defer t.mu.Unlock() - - s := &t.cur - seq := int64(binary.BigEndian.Uint64( - b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) - if seq > s.LastSeqRx { - if s.LastSeqRx > 0 { - // only count lost packets after the very first - // successful one. - s.TotalLost += seq - s.LastSeqRx - 1 - } - s.LastSeqRx = seq - } else { - s.TotalOOO += 1 - } - - // +1 packet since we only start counting after the first one - if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { - t.done = true - } - s.TotalBytesRx += int64(len(b) - ofs) - - f := t.onFirstPacket - t.onFirstPacket = nil - if f != nil { - f() - } -} - -// Adjust tunes the transmit rate based on the received packets. -// The goal is to converge on the fastest transmit rate that still has -// minimal packet loss. Returns the new target rate in packets/sec. -// -// We need to play this guessing game in order to balance out tx and rx -// rates when there's a lossy network between them. Otherwise we can end -// up using 99% of the CPU to blast out transmitted packets and leaving only -// 1% to receive them, leading to a misleading throughput calculation. -// -// Call this function multiple times per second. -func (t *TrafficGen) Adjust() (pps int64) { - t.mu.Lock() - defer t.mu.Unlock() - - d := t.cur.Sub(t.prev) - - // don't adjust rate until the first full period *after* receiving - // the first packet. This skips any handshake time in the underlying - // transport. - if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { - t.prev = t.cur - return 0 // no estimate yet, continue at max speed - } - - pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) - - // We use a rate selection algorithm based loosely on TCP BBR. - // Basically, we set the transmit rate to be a bit higher than - // the best observed transmit rate in the last several time - // periods. This guarantees some packet loss, but should converge - // quickly on a rate near the sustainable maximum. - bestPPS := pps - for _, p := range t.ppsHistory { - if p > bestPPS { - bestPPS = p - } - } - if pps > 0 && t.prev.WhenNsec > 0 { - copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) - t.ppsHistory[0] = pps - } - if bestPPS > 0 { - pps = bestPPS * 103 / 100 - t.nsPerPacket = int64(1e9 / pps) - } - t.prev = t.cur - - return pps -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/binary" + "fmt" + "log" + "net/netip" + "sync" + "time" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +type Snapshot struct { + WhenNsec int64 // current time + timeAcc int64 // accumulated time (+NSecPerTx per transmit) + + LastSeqTx int64 // last sequence number sent + LastSeqRx int64 // last sequence number received + TotalLost int64 // packets out-of-order or lost so far + TotalOOO int64 // packets out-of-order so far + TotalBytesRx int64 // total bytes received so far +} + +type Delta struct { + DurationNsec int64 + TxPackets int64 + RxPackets int64 + LostPackets int64 + OOOPackets int64 + Bytes int64 +} + +func (b Snapshot) Sub(a Snapshot) Delta { + return Delta{ + DurationNsec: b.WhenNsec - a.WhenNsec, + TxPackets: b.LastSeqTx - a.LastSeqTx, + RxPackets: (b.LastSeqRx - a.LastSeqRx) - + (b.TotalLost - a.TotalLost) + + (b.TotalOOO - a.TotalOOO), + LostPackets: b.TotalLost - a.TotalLost, + OOOPackets: b.TotalOOO - a.TotalOOO, + Bytes: b.TotalBytesRx - a.TotalBytesRx, + } +} + +func (d Delta) String() string { + return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", + d.TxPackets, d.RxPackets, d.LostPackets, + float64(d.LostPackets)*100/float64(d.TxPackets), + d.OOOPackets, + float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) +} + +type TrafficGen struct { + mu sync.Mutex + cur, prev Snapshot // snapshots used for rate control + buf []byte // pre-generated packet buffer + done bool // true if the test has completed + + onFirstPacket func() // function to call on first received packet + + // maxPackets is the max packets to receive (not send) before + // ending the test. If it's zero, the test runs forever. + maxPackets int64 + + // nsPerPacket is the target average nanoseconds between packets. + // It's initially zero, which means transmit as fast as the + // caller wants to go. + nsPerPacket int64 + + // ppsHistory is the observed packets-per-second from recent + // samples. + ppsHistory [5]int64 +} + +// NewTrafficGen creates a new, initially locked, TrafficGen. +// Until Start() is called, Generate() will block forever. +func NewTrafficGen(onFirstPacket func()) *TrafficGen { + t := TrafficGen{ + onFirstPacket: onFirstPacket, + } + + // initially locked, until first Start() + t.mu.Lock() + + return &t +} + +// Start starts the traffic generator. It assumes mu is already locked, +// and unlocks it. +func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { + h12 := packet.ICMP4Header{ + IP4Header: packet.IP4Header{ + IPProto: ipproto.ICMPv4, + IPID: 0, + Src: src, + Dst: dst, + }, + Type: packet.ICMP4EchoRequest, + Code: packet.ICMP4NoCode, + } + + // ensure there's room for ICMP header plus sequence number + if bytesPerPacket < ICMPMinSize+8 { + log.Fatalf("bytesPerPacket must be > 24+8") + } + + t.maxPackets = maxPackets + + payload := make([]byte, bytesPerPacket-ICMPMinSize) + t.buf = packet.Generate(h12, payload) + + t.mu.Unlock() +} + +func (t *TrafficGen) Snap() Snapshot { + t.mu.Lock() + defer t.mu.Unlock() + + t.cur.WhenNsec = time.Now().UnixNano() + return t.cur +} + +func (t *TrafficGen) Running() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return !t.done +} + +// Generate produces the next packet in the sequence. It sleeps if +// it's too soon for the next packet to be sent. +// +// The generated packet is placed into buf at offset ofs, for compatibility +// with the wireguard-go conventions. +// +// The return value is the number of bytes generated in the packet, or 0 +// if the test has finished running. +func (t *TrafficGen) Generate(b []byte, ofs int) int { + t.mu.Lock() + + now := time.Now().UnixNano() + if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { + t.cur.timeAcc = now - 1 + } + if t.cur.timeAcc >= now { + // too soon + t.mu.Unlock() + time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) + t.mu.Lock() + + now = t.cur.timeAcc + } + if t.done { + t.mu.Unlock() + return 0 + } + + t.cur.timeAcc += t.nsPerPacket + t.cur.LastSeqTx += 1 + t.cur.WhenNsec = now + seq := t.cur.LastSeqTx + + t.mu.Unlock() + + copy(b[ofs:], t.buf) + binary.BigEndian.PutUint64( + b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], + uint64(seq)) + + return len(t.buf) +} + +// GotPacket processes a packet that came back on the receive side. +func (t *TrafficGen) GotPacket(b []byte, ofs int) { + t.mu.Lock() + defer t.mu.Unlock() + + s := &t.cur + seq := int64(binary.BigEndian.Uint64( + b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) + if seq > s.LastSeqRx { + if s.LastSeqRx > 0 { + // only count lost packets after the very first + // successful one. + s.TotalLost += seq - s.LastSeqRx - 1 + } + s.LastSeqRx = seq + } else { + s.TotalOOO += 1 + } + + // +1 packet since we only start counting after the first one + if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { + t.done = true + } + s.TotalBytesRx += int64(len(b) - ofs) + + f := t.onFirstPacket + t.onFirstPacket = nil + if f != nil { + f() + } +} + +// Adjust tunes the transmit rate based on the received packets. +// The goal is to converge on the fastest transmit rate that still has +// minimal packet loss. Returns the new target rate in packets/sec. +// +// We need to play this guessing game in order to balance out tx and rx +// rates when there's a lossy network between them. Otherwise we can end +// up using 99% of the CPU to blast out transmitted packets and leaving only +// 1% to receive them, leading to a misleading throughput calculation. +// +// Call this function multiple times per second. +func (t *TrafficGen) Adjust() (pps int64) { + t.mu.Lock() + defer t.mu.Unlock() + + d := t.cur.Sub(t.prev) + + // don't adjust rate until the first full period *after* receiving + // the first packet. This skips any handshake time in the underlying + // transport. + if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { + t.prev = t.cur + return 0 // no estimate yet, continue at max speed + } + + pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) + + // We use a rate selection algorithm based loosely on TCP BBR. + // Basically, we set the transmit rate to be a bit higher than + // the best observed transmit rate in the last several time + // periods. This guarantees some packet loss, but should converge + // quickly on a rate near the sustainable maximum. + bestPPS := pps + for _, p := range t.ppsHistory { + if p > bestPPS { + bestPPS = p + } + } + if pps > 0 && t.prev.WhenNsec > 0 { + copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) + t.ppsHistory[0] = pps + } + if bestPPS > 0 { + pps = bestPPS * 103 / 100 + t.nsPerPacket = int64(1e9 / pps) + } + t.prev = t.cur + + return pps +} diff --git a/wgengine/capture/capture.go b/wgengine/capture/capture.go index 01f79ea9f5485..6ea5a9549b4f1 100644 --- a/wgengine/capture/capture.go +++ b/wgengine/capture/capture.go @@ -1,238 +1,238 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package capture formats packet logging into a debug pcap stream. -package capture - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net/http" - "sync" - "time" - - _ "embed" - - "tailscale.com/net/packet" - "tailscale.com/util/set" -) - -//go:embed ts-dissector.lua -var DissectorLua string - -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) - -var bufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -const flushPeriod = 100 * time.Millisecond - -func writePcapHeader(w io.Writer) { - binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number - binary.Write(w, binary.LittleEndian, uint16(2)) // version major - binary.Write(w, binary.LittleEndian, uint16(4)) // version minor - binary.Write(w, binary.LittleEndian, uint32(0)) // this zone - binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures - binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len - binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 -} - -func writePktHeader(w *bytes.Buffer, when time.Time, length int) { - s := when.Unix() - us := when.UnixMicro() - (s * 1000000) - - binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds - binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds - binary.Write(w, binary.LittleEndian, uint32(length)) // length present - binary.Write(w, binary.LittleEndian, uint32(length)) // total length -} - -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. -const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { - ctx, c := context.WithCancel(context.Background()) - return &Sink{ - ctx: ctx, - ctxCancel: c, - } -} - -// Type Sink handles callbacks with packets to be logged, -// formatting them into a pcap stream which is mirrored to -// all registered outputs. -type Sink struct { - ctx context.Context - ctxCancel context.CancelFunc - - mu sync.Mutex - outputs set.HandleSet[io.Writer] - flushTimer *time.Timer // or nil if none running -} - -// RegisterOutput connects an output to this sink, which -// will be written to with a pcap stream as packets are logged. -// A function is returned which unregisters the output when -// called. -// -// If w implements io.Closer, it will be closed upon error -// or when the sink is closed. If w implements http.Flusher, -// it will be flushed periodically. -func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { - select { - case <-s.ctx.Done(): - return func() {} - default: - } - - writePcapHeader(w) - s.mu.Lock() - hnd := s.outputs.Add(w) - s.mu.Unlock() - - return func() { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.outputs, hnd) - } -} - -// NumOutputs returns the number of outputs registered with the sink. -func (s *Sink) NumOutputs() int { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.outputs) -} - -// Close shuts down the sink. Future calls to LogPacket -// are ignored, and any registered output that implements -// io.Closer is closed. -func (s *Sink) Close() error { - s.ctxCancel() - s.mu.Lock() - defer s.mu.Unlock() - if s.flushTimer != nil { - s.flushTimer.Stop() - s.flushTimer = nil - } - - for _, o := range s.outputs { - if o, ok := o.(io.Closer); ok { - o.Close() - } - } - s.outputs = nil - return nil -} - -// WaitCh returns a channel which blocks until -// the sink is closed. -func (s *Sink) WaitCh() <-chan struct{} { - return s.ctx.Done() -} - -func customDataLen(meta packet.CaptureMeta) int { - length := 4 - if meta.DidSNAT { - length += meta.OriginalSrc.Addr().BitLen() / 8 - } - if meta.DidDNAT { - length += meta.OriginalDst.Addr().BitLen() / 8 - } - return length -} - -// LogPacket is called to insert a packet into the capture. -// -// This function does not take ownership of the provided data slice. -func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { - select { - case <-s.ctx.Done(): - return - default: - } - - extraLen := customDataLen(meta) - b := bufferPool.Get().(*bytes.Buffer) - b.Reset() - b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) - defer bufferPool.Put(b) - - writePktHeader(b, when, len(data)+extraLen) - - // Custom tailscale debugging data - binary.Write(b, binary.LittleEndian, uint16(path)) - if meta.DidSNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) - b.Write(meta.OriginalSrc.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 - } - if meta.DidDNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) - b.Write(meta.OriginalDst.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 - } - - b.Write(data) - - s.mu.Lock() - defer s.mu.Unlock() - - var hadError []set.Handle - for hnd, o := range s.outputs { - if _, err := o.Write(b.Bytes()); err != nil { - hadError = append(hadError, hnd) - continue - } - } - for _, hnd := range hadError { - if o, ok := s.outputs[hnd].(io.Closer); ok { - o.Close() - } - delete(s.outputs, hnd) - } - - if s.flushTimer == nil { - s.flushTimer = time.AfterFunc(flushPeriod, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, o := range s.outputs { - if f, ok := o.(http.Flusher); ok { - f.Flush() - } - } - s.flushTimer = nil - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package capture formats packet logging into a debug pcap stream. +package capture + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net/http" + "sync" + "time" + + _ "embed" + + "tailscale.com/net/packet" + "tailscale.com/util/set" +) + +//go:embed ts-dissector.lua +var DissectorLua string + +// Callback describes a function which is called to +// record packets when debugging packet-capture. +// Such callbacks must not take ownership of the +// provided data slice: it may only copy out of it +// within the lifetime of the function. +type Callback func(Path, time.Time, []byte, packet.CaptureMeta) + +var bufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +const flushPeriod = 100 * time.Millisecond + +func writePcapHeader(w io.Writer) { + binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number + binary.Write(w, binary.LittleEndian, uint16(2)) // version major + binary.Write(w, binary.LittleEndian, uint16(4)) // version minor + binary.Write(w, binary.LittleEndian, uint32(0)) // this zone + binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures + binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len + binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 +} + +func writePktHeader(w *bytes.Buffer, when time.Time, length int) { + s := when.Unix() + us := when.UnixMicro() - (s * 1000000) + + binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds + binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds + binary.Write(w, binary.LittleEndian, uint32(length)) // length present + binary.Write(w, binary.LittleEndian, uint32(length)) // total length +} + +// Path describes where in the data path the packet was captured. +type Path uint8 + +// Valid Path values. +const ( + // FromLocal indicates the packet was logged as it traversed the FromLocal path: + // i.e.: A packet from the local system into the TUN. + FromLocal Path = 0 + // FromPeer indicates the packet was logged upon reception from a remote peer. + FromPeer Path = 1 + // SynthesizedToLocal indicates the packet was generated from within tailscaled, + // and is being routed to the local machine's network stack. + SynthesizedToLocal Path = 2 + // SynthesizedToPeer indicates the packet was generated from within tailscaled, + // and is being routed to a remote Wireguard peer. + SynthesizedToPeer Path = 3 + + // PathDisco indicates the packet is information about a disco frame. + PathDisco Path = 254 +) + +// New creates a new capture sink. +func New() *Sink { + ctx, c := context.WithCancel(context.Background()) + return &Sink{ + ctx: ctx, + ctxCancel: c, + } +} + +// Type Sink handles callbacks with packets to be logged, +// formatting them into a pcap stream which is mirrored to +// all registered outputs. +type Sink struct { + ctx context.Context + ctxCancel context.CancelFunc + + mu sync.Mutex + outputs set.HandleSet[io.Writer] + flushTimer *time.Timer // or nil if none running +} + +// RegisterOutput connects an output to this sink, which +// will be written to with a pcap stream as packets are logged. +// A function is returned which unregisters the output when +// called. +// +// If w implements io.Closer, it will be closed upon error +// or when the sink is closed. If w implements http.Flusher, +// it will be flushed periodically. +func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { + select { + case <-s.ctx.Done(): + return func() {} + default: + } + + writePcapHeader(w) + s.mu.Lock() + hnd := s.outputs.Add(w) + s.mu.Unlock() + + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.outputs, hnd) + } +} + +// NumOutputs returns the number of outputs registered with the sink. +func (s *Sink) NumOutputs() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.outputs) +} + +// Close shuts down the sink. Future calls to LogPacket +// are ignored, and any registered output that implements +// io.Closer is closed. +func (s *Sink) Close() error { + s.ctxCancel() + s.mu.Lock() + defer s.mu.Unlock() + if s.flushTimer != nil { + s.flushTimer.Stop() + s.flushTimer = nil + } + + for _, o := range s.outputs { + if o, ok := o.(io.Closer); ok { + o.Close() + } + } + s.outputs = nil + return nil +} + +// WaitCh returns a channel which blocks until +// the sink is closed. +func (s *Sink) WaitCh() <-chan struct{} { + return s.ctx.Done() +} + +func customDataLen(meta packet.CaptureMeta) int { + length := 4 + if meta.DidSNAT { + length += meta.OriginalSrc.Addr().BitLen() / 8 + } + if meta.DidDNAT { + length += meta.OriginalDst.Addr().BitLen() / 8 + } + return length +} + +// LogPacket is called to insert a packet into the capture. +// +// This function does not take ownership of the provided data slice. +func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { + select { + case <-s.ctx.Done(): + return + default: + } + + extraLen := customDataLen(meta) + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) + defer bufferPool.Put(b) + + writePktHeader(b, when, len(data)+extraLen) + + // Custom tailscale debugging data + binary.Write(b, binary.LittleEndian, uint16(path)) + if meta.DidSNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) + b.Write(meta.OriginalSrc.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 + } + if meta.DidDNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) + b.Write(meta.OriginalDst.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 + } + + b.Write(data) + + s.mu.Lock() + defer s.mu.Unlock() + + var hadError []set.Handle + for hnd, o := range s.outputs { + if _, err := o.Write(b.Bytes()); err != nil { + hadError = append(hadError, hnd) + continue + } + } + for _, hnd := range hadError { + if o, ok := s.outputs[hnd].(io.Closer); ok { + o.Close() + } + delete(s.outputs, hnd) + } + + if s.flushTimer == nil { + s.flushTimer = time.AfterFunc(flushPeriod, func() { + s.mu.Lock() + defer s.mu.Unlock() + for _, o := range s.outputs { + if f, ok := o.(http.Flusher); ok { + f.Flush() + } + } + s.flushTimer = nil + }) + } +} diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index 58359acdd51f2..f2e85dcd57002 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "errors" - "net" - "net/netip" - "sync" - "syscall" - "time" -) - -// blockForeverConn is a net.PacketConn whose reads block until it is closed. -type blockForeverConn struct { - mu sync.Mutex - cond *sync.Cond - closed bool -} - -func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, netip.AddrPort{}, net.ErrClosed -} - -func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { - // Silently drop writes. - return len(p), nil -} - -func (c *blockForeverConn) LocalAddr() net.Addr { - // Return a *net.UDPAddr because lots of code assumes that it will. - return new(net.UDPAddr) -} - -func (c *blockForeverConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - c.closed = true - c.cond.Broadcast() - return nil -} - -func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "syscall" + "time" +) + +// blockForeverConn is a net.PacketConn whose reads block until it is closed. +type blockForeverConn struct { + mu sync.Mutex + cond *sync.Cond + closed bool +} + +func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, netip.AddrPort{}, net.ErrClosed +} + +func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { + // Silently drop writes. + return len(p), nil +} + +func (c *blockForeverConn) LocalAddr() net.Addr { + // Return a *net.UDPAddr because lots of code assumes that it will. + return new(net.UDPAddr) +} + +func (c *blockForeverConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return net.ErrClosed + } + c.closed = true + c.cond.Broadcast() + return nil +} + +func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } diff --git a/wgengine/magicsock/endpoint_default.go b/wgengine/magicsock/endpoint_default.go index 9ffeef5f8a7bf..1ed6e5e0e2399 100644 --- a/wgengine/magicsock/endpoint_default.go +++ b/wgengine/magicsock/endpoint_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !wasm && !plan9 - -package magicsock - -import ( - "errors" - "syscall" -) - -// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to -// errors.Is while avoiding an allocation per call. -var errHOSTUNREACH error = syscall.EHOSTUNREACH - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, and for unknown -// errors always reports false. -func isBadEndpointErr(err error) bool { - return errors.Is(err, errHOSTUNREACH) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm && !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to +// errors.Is while avoiding an allocation per call. +var errHOSTUNREACH error = syscall.EHOSTUNREACH + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, and for unknown +// errors always reports false. +func isBadEndpointErr(err error) bool { + return errors.Is(err, errHOSTUNREACH) +} diff --git a/wgengine/magicsock/endpoint_stub.go b/wgengine/magicsock/endpoint_stub.go index 9a5c9d937560c..a209c352bfe5e 100644 --- a/wgengine/magicsock/endpoint_stub.go +++ b/wgengine/magicsock/endpoint_stub.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 - -package magicsock - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, but covers known -// cases. -func isBadEndpointErr(err error) bool { - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 + +package magicsock + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, but covers known +// cases. +func isBadEndpointErr(err error) bool { + return false +} diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index e2ac926b43060..5caddd1a06960 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -1,248 +1,248 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "net/netip" - "slices" - "sync" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tempfork/heap" - "tailscale.com/util/mak" - "tailscale.com/util/set" -) - -const ( - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second - - // endpointTrackerMaxPerAddr is how many cached addresses we track for - // a given netip.Addr. This allows e.g. restricting the number of STUN - // endpoints we cache (which usually have the same netip.Addr but - // different ports). - // - // The value of 6 is chosen because we can advertise up to 3 endpoints - // based on the STUN IP: - // 1. The STUN endpoint itself (EndpointSTUN) - // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) - // 3. The STUN IP with a portmapped port (EndpointPortmapped) - // - // Storing 6 endpoints in the cache means we can store up to 2 previous - // sets of endpoints. - endpointTrackerMaxPerAddr = 6 -) - -// endpointTrackerEntry is an entry in an endpointHeap that stores the state of -// a given cached endpoint. -type endpointTrackerEntry struct { - // endpoint is the cached endpoint. - endpoint tailcfg.Endpoint - // until is the time until which this endpoint is being cached. - until time.Time - // index is the index within the containing endpointHeap. - index int -} - -// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in -// ascending order by the 'until' expiry time (i.e. oldest first). -type endpointHeap []*endpointTrackerEntry - -var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) - -// Len implements heap.Interface. -func (eh endpointHeap) Len() int { return len(eh) } - -// Less implements heap.Interface. -func (eh endpointHeap) Less(i, j int) bool { - // We want to store items so that the lowest item in the heap is the - // oldest, so that heap.Pop()-ing from the endpointHeap will remove the - // oldest entry. - return eh[i].until.Before(eh[j].until) -} - -// Swap implements heap.Interface. -func (eh endpointHeap) Swap(i, j int) { - eh[i], eh[j] = eh[j], eh[i] - eh[i].index = i - eh[j].index = j -} - -// Push implements heap.Interface. -func (eh *endpointHeap) Push(item *endpointTrackerEntry) { - n := len(*eh) - item.index = n - *eh = append(*eh, item) -} - -// Pop implements heap.Interface. -func (eh *endpointHeap) Pop() *endpointTrackerEntry { - old := *eh - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.index = -1 // for safety - *eh = old[0 : n-1] - return item -} - -// Min returns a pointer to the minimum element in the heap, without removing -// it. Since this is a min-heap ordered by the 'until' field, this returns the -// chronologically "earliest" element in the heap. -// -// Len() must be non-zero. -func (eh endpointHeap) Min() *endpointTrackerEntry { - return eh[0] -} - -// endpointTracker caches endpoints that are advertised to peers. This allows -// peers to still reach this node if there's a temporary endpoint flap; rather -// than withdrawing an endpoint and then re-advertising it the next time we run -// a netcheck, we keep advertising the endpoint until it's not present for a -// defined timeout. -// -// See tailscale/tailscale#7877 for more information. -type endpointTracker struct { - mu sync.Mutex - endpoints map[netip.Addr]*endpointHeap -} - -// update takes as input the current sent of discovered endpoints and the -// current time, and returns the set of endpoints plus any previous-cached and -// non-expired endpoints that should be advertised to peers. -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Extend endpoints that already exist in the cache. We do this before - // we remove expired endpoints, below, so we don't remove something - // that would otherwise have survived by extending. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.extendLocked(ep, until) - } - - // Now that we've extended existing endpoints, remove everything that - // has expired. - et.removeExpiredLocked(now) - - // Add entries from the input set of endpoints into the cache; we do - // this after removing expired ones so that we can store as many as - // possible, with space freed by the entries removed after expiry. - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Finally, add entries to the return array that aren't already there. - epsPlusCached = eps - for _, heap := range et.endpoints { - for _, ep := range *heap { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(ep.endpoint.Addr) { - continue - } else if now.After(ep.until) { - // Defense-in-depth; should never happen since - // we removed expired entries above, but ignore - // it anyway. - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - } - - return epsPlusCached -} - -// extendLocked will update the expiry time of the provided endpoint in the -// cache, if it is present. If it is not present, nothing will be done. -// -// et.mu must be held. -func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - epHeap, found := et.endpoints[key] - if !found { - return - } - - // Find the entry for this exact address; this loop is quick since we - // bound the number of items in the heap. - // - // TODO(andrew): this means we iterate over the entire heap once per - // endpoint; even if the heap is small, if we have a lot of input - // endpoints this can be expensive? - for i, entry := range *epHeap { - if entry.endpoint == ep { - entry.until = until - heap.Fix(epHeap, i) - return - } - } -} - -// addLocked will store the provided endpoint(s) in the cache for a fixed -// period of time, ensuring that the size of the endpoint cache remains below -// the maximum. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - - // Create or get the heap for this endpoint's addr - epHeap := et.endpoints[key] - if epHeap == nil { - epHeap = new(endpointHeap) - mak.Set(&et.endpoints, key, epHeap) - } - - // Find the entry for this exact address; this loop is quick - // since we bound the number of items in the heap. - found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { - return v.endpoint == ep - }) - if !found { - // Add address to heap; either the endpoint is new, or the heap - // was newly-created and thus empty. - heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) - } - - // Now that we've added everything, pop from our heap until we're below - // the limit. This is a min-heap, so popping removes the lowest (and - // thus oldest) endpoint. - for epHeap.Len() > endpointTrackerMaxPerAddr { - heap.Pop(epHeap) - } -} - -// removeExpired will remove all expired entries from the cache. -// -// et.mu must be held. -func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, epHeap := range et.endpoints { - // The minimum element is oldest/earliest endpoint; repeatedly - // pop from the heap while it's in the past. - for epHeap.Len() > 0 { - minElem := epHeap.Min() - if now.After(minElem.until) { - heap.Pop(epHeap) - } else { - break - } - } - - if epHeap.Len() == 0 { - // Free up space in the map by removing the empty heap. - delete(et.endpoints, k) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "slices" + "sync" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/heap" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + // endpointTrackerLifetime is how long we continue advertising an + // endpoint after we last see it. This is intentionally chosen to be + // slightly longer than a full netcheck period. + endpointTrackerLifetime = 5*time.Minute + 10*time.Second + + // endpointTrackerMaxPerAddr is how many cached addresses we track for + // a given netip.Addr. This allows e.g. restricting the number of STUN + // endpoints we cache (which usually have the same netip.Addr but + // different ports). + // + // The value of 6 is chosen because we can advertise up to 3 endpoints + // based on the STUN IP: + // 1. The STUN endpoint itself (EndpointSTUN) + // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) + // 3. The STUN IP with a portmapped port (EndpointPortmapped) + // + // Storing 6 endpoints in the cache means we can store up to 2 previous + // sets of endpoints. + endpointTrackerMaxPerAddr = 6 +) + +// endpointTrackerEntry is an entry in an endpointHeap that stores the state of +// a given cached endpoint. +type endpointTrackerEntry struct { + // endpoint is the cached endpoint. + endpoint tailcfg.Endpoint + // until is the time until which this endpoint is being cached. + until time.Time + // index is the index within the containing endpointHeap. + index int +} + +// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in +// ascending order by the 'until' expiry time (i.e. oldest first). +type endpointHeap []*endpointTrackerEntry + +var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) + +// Len implements heap.Interface. +func (eh endpointHeap) Len() int { return len(eh) } + +// Less implements heap.Interface. +func (eh endpointHeap) Less(i, j int) bool { + // We want to store items so that the lowest item in the heap is the + // oldest, so that heap.Pop()-ing from the endpointHeap will remove the + // oldest entry. + return eh[i].until.Before(eh[j].until) +} + +// Swap implements heap.Interface. +func (eh endpointHeap) Swap(i, j int) { + eh[i], eh[j] = eh[j], eh[i] + eh[i].index = i + eh[j].index = j +} + +// Push implements heap.Interface. +func (eh *endpointHeap) Push(item *endpointTrackerEntry) { + n := len(*eh) + item.index = n + *eh = append(*eh, item) +} + +// Pop implements heap.Interface. +func (eh *endpointHeap) Pop() *endpointTrackerEntry { + old := *eh + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *eh = old[0 : n-1] + return item +} + +// Min returns a pointer to the minimum element in the heap, without removing +// it. Since this is a min-heap ordered by the 'until' field, this returns the +// chronologically "earliest" element in the heap. +// +// Len() must be non-zero. +func (eh endpointHeap) Min() *endpointTrackerEntry { + return eh[0] +} + +// endpointTracker caches endpoints that are advertised to peers. This allows +// peers to still reach this node if there's a temporary endpoint flap; rather +// than withdrawing an endpoint and then re-advertising it the next time we run +// a netcheck, we keep advertising the endpoint until it's not present for a +// defined timeout. +// +// See tailscale/tailscale#7877 for more information. +type endpointTracker struct { + mu sync.Mutex + endpoints map[netip.Addr]*endpointHeap +} + +// update takes as input the current sent of discovered endpoints and the +// current time, and returns the set of endpoints plus any previous-cached and +// non-expired endpoints that should be advertised to peers. +func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { + var inputEps set.Slice[netip.AddrPort] + for _, ep := range eps { + inputEps.Add(ep.Addr) + } + + et.mu.Lock() + defer et.mu.Unlock() + + // Extend endpoints that already exist in the cache. We do this before + // we remove expired endpoints, below, so we don't remove something + // that would otherwise have survived by extending. + until := now.Add(endpointTrackerLifetime) + for _, ep := range eps { + et.extendLocked(ep, until) + } + + // Now that we've extended existing endpoints, remove everything that + // has expired. + et.removeExpiredLocked(now) + + // Add entries from the input set of endpoints into the cache; we do + // this after removing expired ones so that we can store as many as + // possible, with space freed by the entries removed after expiry. + for _, ep := range eps { + et.addLocked(now, ep, until) + } + + // Finally, add entries to the return array that aren't already there. + epsPlusCached = eps + for _, heap := range et.endpoints { + for _, ep := range *heap { + // If the endpoint was in the input list, or has expired, skip it. + if inputEps.Contains(ep.endpoint.Addr) { + continue + } else if now.After(ep.until) { + // Defense-in-depth; should never happen since + // we removed expired entries above, but ignore + // it anyway. + continue + } + + // We haven't seen this endpoint; add to the return array + epsPlusCached = append(epsPlusCached, ep.endpoint) + } + } + + return epsPlusCached +} + +// extendLocked will update the expiry time of the provided endpoint in the +// cache, if it is present. If it is not present, nothing will be done. +// +// et.mu must be held. +func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + epHeap, found := et.endpoints[key] + if !found { + return + } + + // Find the entry for this exact address; this loop is quick since we + // bound the number of items in the heap. + // + // TODO(andrew): this means we iterate over the entire heap once per + // endpoint; even if the heap is small, if we have a lot of input + // endpoints this can be expensive? + for i, entry := range *epHeap { + if entry.endpoint == ep { + entry.until = until + heap.Fix(epHeap, i) + return + } + } +} + +// addLocked will store the provided endpoint(s) in the cache for a fixed +// period of time, ensuring that the size of the endpoint cache remains below +// the maximum. +// +// et.mu must be held. +func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + + // Create or get the heap for this endpoint's addr + epHeap := et.endpoints[key] + if epHeap == nil { + epHeap = new(endpointHeap) + mak.Set(&et.endpoints, key, epHeap) + } + + // Find the entry for this exact address; this loop is quick + // since we bound the number of items in the heap. + found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { + return v.endpoint == ep + }) + if !found { + // Add address to heap; either the endpoint is new, or the heap + // was newly-created and thus empty. + heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) + } + + // Now that we've added everything, pop from our heap until we're below + // the limit. This is a min-heap, so popping removes the lowest (and + // thus oldest) endpoint. + for epHeap.Len() > endpointTrackerMaxPerAddr { + heap.Pop(epHeap) + } +} + +// removeExpired will remove all expired entries from the cache. +// +// et.mu must be held. +func (et *endpointTracker) removeExpiredLocked(now time.Time) { + for k, epHeap := range et.endpoints { + // The minimum element is oldest/earliest endpoint; repeatedly + // pop from the heap while it's in the past. + for epHeap.Len() > 0 { + minElem := epHeap.Min() + if now.After(minElem.until) { + heap.Pop(epHeap) + } else { + break + } + } + + if epHeap.Len() == 0 { + // Free up space in the map by removing the empty heap. + delete(et.endpoints, k) + } + } +} diff --git a/wgengine/magicsock/magicsock_unix_test.go b/wgengine/magicsock/magicsock_unix_test.go index 9ad8cab93330b..b0700a8ebe870 100644 --- a/wgengine/magicsock/magicsock_unix_test.go +++ b/wgengine/magicsock/magicsock_unix_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package magicsock - -import ( - "net" - "syscall" - "testing" - - "tailscale.com/types/nettype" -) - -func TestTrySetSocketBuffer(t *testing.T) { - c, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) - } - defer c.Close() - - rc, err := c.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - getBufs := func() (int, int) { - var rcv, snd int - rc.Control(func(fd uintptr) { - rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - t.Errorf("getsockopt(SO_RCVBUF): %v", err) - } - snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - t.Errorf("getsockopt(SO_SNDBUF): %v", err) - } - }) - return rcv, snd - } - - curRcv, curSnd := getBufs() - - trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) - - newRcv, newSnd := getBufs() - - if curRcv > newRcv { - t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) - } - if curSnd > newSnd { - t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) - } - - // On many systems we may not increase the value, particularly running as a - // regular user, so log the information for manual verification. - t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) - t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package magicsock + +import ( + "net" + "syscall" + "testing" + + "tailscale.com/types/nettype" +) + +func TestTrySetSocketBuffer(t *testing.T) { + c, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rc, err := c.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + getBufs := func() (int, int) { + var rcv, snd int + rc.Control(func(fd uintptr) { + rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) + if err != nil { + t.Errorf("getsockopt(SO_RCVBUF): %v", err) + } + snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + t.Errorf("getsockopt(SO_SNDBUF): %v", err) + } + }) + return rcv, snd + } + + curRcv, curSnd := getBufs() + + trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) + + newRcv, newSnd := getBufs() + + if curRcv > newRcv { + t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) + } + if curSnd > newSnd { + t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) + } + + // On many systems we may not increase the value, particularly running as a + // regular user, so log the information for manual verification. + t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) + t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) +} diff --git a/wgengine/magicsock/peermtu_darwin.go b/wgengine/magicsock/peermtu_darwin.go index b2a1ed217b2b8..a0a1aacb55f5f 100644 --- a/wgengine/magicsock/peermtu_darwin.go +++ b/wgengine/magicsock/peermtu_darwin.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package magicsock - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return unix.IP_DONTFRAG - } - return unix.IPV6_DONTFRAG -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := 1 - if enable == false { - optArg = 0 - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == 1 { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package magicsock + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return unix.IP_DONTFRAG + } + return unix.IPV6_DONTFRAG +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := 1 + if enable == false { + optArg = 0 + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == 1 { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_linux.go b/wgengine/magicsock/peermtu_linux.go index d32ead0991953..b76f30f081042 100644 --- a/wgengine/magicsock/peermtu_linux.go +++ b/wgengine/magicsock/peermtu_linux.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !android - -package magicsock - -import ( - "syscall" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return syscall.IP_MTU_DISCOVER - } - return syscall.IPV6_MTU_DISCOVER -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := syscall.IP_PMTUDISC_DO - if enable == false { - optArg = syscall.IP_PMTUDISC_DONT - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == syscall.IP_PMTUDISC_DO { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +package magicsock + +import ( + "syscall" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return syscall.IP_MTU_DISCOVER + } + return syscall.IPV6_MTU_DISCOVER +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := syscall.IP_PMTUDISC_DO + if enable == false { + optArg = syscall.IP_PMTUDISC_DONT + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == syscall.IP_PMTUDISC_DO { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_unix.go b/wgengine/magicsock/peermtu_unix.go index 59e808ee75e34..eec3d744f3ded 100644 --- a/wgengine/magicsock/peermtu_unix.go +++ b/wgengine/magicsock/peermtu_unix.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (darwin && !ios) || (linux && !android) - -package magicsock - -import ( - "syscall" -) - -// getIPProto returns the value of the get/setsockopt proto argument necessary -// to set an IP sockopt that corresponds with the string network, which must be -// "udp4" or "udp6". -func getIPProto(network string) int { - if network == "udp4" { - return syscall.IPPROTO_IP - } - return syscall.IPPROTO_IPV6 -} - -// connControl allows the caller to run a system call on the socket underlying -// Conn specified by the string network, which must be "udp4" or "udp6". If the -// pconn type implements the syscall method, this function returns the value of -// of the system call fn called with the fd of the socket as its arg (or the -// error from rc.Control() if that fails). Otherwise it returns the error -// errUnsupportedConnType. -func (c *Conn) connControl(network string, fn func(fd uintptr)) error { - pconn := c.pconn4.pconn - if network == "udp6" { - pconn = c.pconn6.pconn - } - sc, ok := pconn.(syscall.Conn) - if !ok { - return errUnsupportedConnType - } - rc, err := sc.SyscallConn() - if err != nil { - return err - } - return rc.Control(fn) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (darwin && !ios) || (linux && !android) + +package magicsock + +import ( + "syscall" +) + +// getIPProto returns the value of the get/setsockopt proto argument necessary +// to set an IP sockopt that corresponds with the string network, which must be +// "udp4" or "udp6". +func getIPProto(network string) int { + if network == "udp4" { + return syscall.IPPROTO_IP + } + return syscall.IPPROTO_IPV6 +} + +// connControl allows the caller to run a system call on the socket underlying +// Conn specified by the string network, which must be "udp4" or "udp6". If the +// pconn type implements the syscall method, this function returns the value of +// of the system call fn called with the fd of the socket as its arg (or the +// error from rc.Control() if that fails). Otherwise it returns the error +// errUnsupportedConnType. +func (c *Conn) connControl(network string, fn func(fd uintptr)) error { + pconn := c.pconn4.pconn + if network == "udp6" { + pconn = c.pconn6.pconn + } + sc, ok := pconn.(syscall.Conn) + if !ok { + return errUnsupportedConnType + } + rc, err := sc.SyscallConn() + if err != nil { + return err + } + return rc.Control(fn) +} diff --git a/wgengine/mem_ios.go b/wgengine/mem_ios.go index 975dfca611fbb..cc266ea3aadc8 100644 --- a/wgengine/mem_ios.go +++ b/wgengine/mem_ios.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgengine - -import ( - "github.com/tailscale/wireguard-go/device" -) - -// iOS has a very restrictive memory limit on network extensions. -// Reduce the maximum amount of memory that wireguard-go can allocate -// to avoid getting killed. - -func init() { - device.QueueStagedSize = 64 - device.QueueOutboundSize = 64 - device.QueueInboundSize = 64 - device.QueueHandshakeSize = 64 - device.PreallocatedBuffersPerPool = 64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" +) + +// iOS has a very restrictive memory limit on network extensions. +// Reduce the maximum amount of memory that wireguard-go can allocate +// to avoid getting killed. + +func init() { + device.QueueStagedSize = 64 + device.QueueOutboundSize = 64 + device.QueueInboundSize = 64 + device.QueueHandshakeSize = 64 + device.PreallocatedBuffersPerPool = 64 +} diff --git a/wgengine/netstack/netstack_linux.go b/wgengine/netstack/netstack_linux.go index 9e27b7819dc4d..a0bfb44567da7 100644 --- a/wgengine/netstack/netstack_linux.go +++ b/wgengine/netstack/netstack_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstack - -import ( - "os/exec" - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - setAmbientCapsRaw = func(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_RAW}, - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstack + +import ( + "os/exec" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + setAmbientCapsRaw = func(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_RAW}, + } + } +} diff --git a/wgengine/router/runner.go b/wgengine/router/runner.go index 7ba633344f601..8fa068e335e66 100644 --- a/wgengine/router/runner.go +++ b/wgengine/router/runner.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package router - -import ( - "errors" - "fmt" - "os" - "os/exec" - "strconv" - "strings" - "syscall" - - "golang.org/x/sys/unix" -) - -// commandRunner abstracts helpers to run OS commands. It exists -// purely to swap out osCommandRunner (below) with a fake runner in -// tests. -type commandRunner interface { - run(...string) error - output(...string) ([]byte, error) -} - -type osCommandRunner struct { - // ambientCapNetAdmin determines whether commands are executed with - // CAP_NET_ADMIN. - // CAP_NET_ADMIN is required when running as non-root and executing cmds - // like `ip rule`. Even if our process has the capability, we need to - // explicitly grant it to the new process. - // We specifically need this for Synology DSM7 where tailscaled no longer - // runs as root. - ambientCapNetAdmin bool -} - -// errCode extracts and returns the process exit code from err, or -// zero if err is nil. -func errCode(err error) int { - if err == nil { - return 0 - } - var e *exec.ExitError - if ok := errors.As(err, &e); ok { - return e.ExitCode() - } - s := err.Error() - if strings.HasPrefix(s, "exitcode:") { - code, err := strconv.Atoi(s[9:]) - if err == nil { - return code - } - } - return -42 -} - -func (o osCommandRunner) run(args ...string) error { - _, err := o.output(args...) - return err -} - -func (o osCommandRunner) output(args ...string) ([]byte, error) { - if len(args) == 0 { - return nil, errors.New("cmd: no argv[0]") - } - - cmd := exec.Command(args[0], args[1:]...) - cmd.Env = append(os.Environ(), "LC_ALL=C") - if o.ambientCapNetAdmin { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, - } - } - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) - } - - return out, nil -} - -type runGroup struct { - OkCode []int // error codes that are acceptable, other than 0, if any - Runner commandRunner // the runner that actually runs our commands - ErrAcc error // first error encountered, if any -} - -func newRunGroup(okCode []int, runner commandRunner) *runGroup { - return &runGroup{ - OkCode: okCode, - Runner: runner, - } -} - -func (rg *runGroup) okCode(err error) bool { - got := errCode(err) - for _, want := range rg.OkCode { - if got == want { - return true - } - } - return false -} - -func (rg *runGroup) Output(args ...string) []byte { - b, err := rg.Runner.output(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } - return b -} - -func (rg *runGroup) Run(args ...string) { - err := rg.Runner.run(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package router + +import ( + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + "syscall" + + "golang.org/x/sys/unix" +) + +// commandRunner abstracts helpers to run OS commands. It exists +// purely to swap out osCommandRunner (below) with a fake runner in +// tests. +type commandRunner interface { + run(...string) error + output(...string) ([]byte, error) +} + +type osCommandRunner struct { + // ambientCapNetAdmin determines whether commands are executed with + // CAP_NET_ADMIN. + // CAP_NET_ADMIN is required when running as non-root and executing cmds + // like `ip rule`. Even if our process has the capability, we need to + // explicitly grant it to the new process. + // We specifically need this for Synology DSM7 where tailscaled no longer + // runs as root. + ambientCapNetAdmin bool +} + +// errCode extracts and returns the process exit code from err, or +// zero if err is nil. +func errCode(err error) int { + if err == nil { + return 0 + } + var e *exec.ExitError + if ok := errors.As(err, &e); ok { + return e.ExitCode() + } + s := err.Error() + if strings.HasPrefix(s, "exitcode:") { + code, err := strconv.Atoi(s[9:]) + if err == nil { + return code + } + } + return -42 +} + +func (o osCommandRunner) run(args ...string) error { + _, err := o.output(args...) + return err +} + +func (o osCommandRunner) output(args ...string) ([]byte, error) { + if len(args) == 0 { + return nil, errors.New("cmd: no argv[0]") + } + + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "LC_ALL=C") + if o.ambientCapNetAdmin { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, + } + } + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) + } + + return out, nil +} + +type runGroup struct { + OkCode []int // error codes that are acceptable, other than 0, if any + Runner commandRunner // the runner that actually runs our commands + ErrAcc error // first error encountered, if any +} + +func newRunGroup(okCode []int, runner commandRunner) *runGroup { + return &runGroup{ + OkCode: okCode, + Runner: runner, + } +} + +func (rg *runGroup) okCode(err error) bool { + got := errCode(err) + for _, want := range rg.OkCode { + if got == want { + return true + } + } + return false +} + +func (rg *runGroup) Output(args ...string) []byte { + b, err := rg.Runner.output(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } + return b +} + +func (rg *runGroup) Run(args ...string) { + err := rg.Runner.run(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } +} diff --git a/wgengine/watchdog_js.go b/wgengine/watchdog_js.go index 9dcb29c4ee556..872ce36d5fd5d 100644 --- a/wgengine/watchdog_js.go +++ b/wgengine/watchdog_js.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js - -package wgengine - -import "tailscale.com/net/dns/resolver" - -type watchdogEngine struct { - Engine - wrap Engine -} - -func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { - return nil, false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js + +package wgengine + +import "tailscale.com/net/dns/resolver" + +type watchdogEngine struct { + Engine + wrap Engine +} + +func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { + return nil, false +} diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index 9b83998cb4232..80fa159e38972 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "io" - "sort" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" - "tailscale.com/util/multierr" -) - -// NewDevice returns a wireguard-go Device configured for Tailscale use. -func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { - ret := device.NewDevice(tunDev, bind, logger) - ret.DisableSomeRoamingForBrokenMobileSemantics() - return ret -} - -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := multierr.New(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - -// ReconfigDevice replaces the existing device configuration with cfg. -func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { - defer func() { - if err != nil { - logf("wgcfg.Reconfig failed: %v", err) - } - }() - - prev, err := DeviceConfig(d) - if err != nil { - return err - } - - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() - - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return multierr.New(setErr, toErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "io" + "sort" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +// NewDevice returns a wireguard-go Device configured for Tailscale use. +func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { + ret := device.NewDevice(tunDev, bind, logger) + ret.DisableSomeRoamingForBrokenMobileSemantics() + return ret +} + +func DeviceConfig(d *device.Device) (*Config, error) { + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcGetOperation(w) + w.Close() + }() + cfg, fromErr := FromUAPI(r) + r.Close() + getErr := <-errc + err := multierr.New(getErr, fromErr) + if err != nil { + return nil, err + } + sort.Slice(cfg.Peers, func(i, j int) bool { + return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) + }) + return cfg, nil +} + +// ReconfigDevice replaces the existing device configuration with cfg. +func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { + defer func() { + if err != nil { + logf("wgcfg.Reconfig failed: %v", err) + } + }() + + prev, err := DeviceConfig(d) + if err != nil { + return err + } + + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcSetOperation(r) + r.Close() + }() + + toErr := cfg.ToUAPI(logf, w, prev) + w.Close() + setErr := <-errc + return multierr.New(setErr, toErr) +} diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index c54ad16d9e8b2..d54282e4bdf04 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -1,261 +1,261 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "os" - "sort" - "strings" - "sync" - "testing" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - k2, pk2 := newK() - ip2 := netip.MustParsePrefix("10.0.0.2/32") - - k3, _ := newK() - ip3 := netip.MustParsePrefix("10.0.0.3/32") - - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } - - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { - t.Fatal(err) - } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 add new peer", func(t *testing.T) { - cfg1.Peers = append(cfg1.Peers, Peer{ - PublicKey: k3, - AllowedIPs: []netip.Prefix{ip3}, - }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") - } - }) - - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) -} - -// TODO: replace with a loopback tunnel -type nilTun struct { - events chan tun.Event - closed chan struct{} -} - -func newNilTun() tun.Device { - return &nilTun{ - events: make(chan tun.Event), - closed: make(chan struct{}), - } -} - -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() <-chan tun.Event { return t.events } - -func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Write(data [][]byte, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Close() error { - close(t.events) - close(t.closed) - return nil -} - -func (t *nilTun) BatchSize() int { return 1 } - -// A noopBind is a conn.Bind that does no actual binding work. -type noopBind struct{} - -func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - return nil, 1, nil -} -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } -func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return dummyEndpoint(s), nil -} -func (noopBind) BatchSize() int { return 1 } - -// A dummyEndpoint is a string holding the endpoint destination. -type dummyEndpoint string - -func (e dummyEndpoint) ClearSrc() {} -func (e dummyEndpoint) SrcToString() string { return "" } -func (e dummyEndpoint) DstToString() string { return string(e) } -func (e dummyEndpoint) DstToBytes() []byte { return nil } -func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "bytes" + "io" + "net/netip" + "os" + "sort" + "strings" + "sync" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestDeviceConfig(t *testing.T) { + newK := func() (key.NodePublic, key.NodePrivate) { + t.Helper() + k := key.NewNode() + return k.Public(), k + } + k1, pk1 := newK() + ip1 := netip.MustParsePrefix("10.0.0.1/32") + + k2, pk2 := newK() + ip2 := netip.MustParsePrefix("10.0.0.2/32") + + k3, _ := newK() + ip3 := netip.MustParsePrefix("10.0.0.3/32") + + cfg1 := &Config{ + PrivateKey: pk1, + Peers: []Peer{{ + PublicKey: k2, + AllowedIPs: []netip.Prefix{ip2}, + }}, + } + + cfg2 := &Config{ + PrivateKey: pk2, + Peers: []Peer{{ + PublicKey: k1, + AllowedIPs: []netip.Prefix{ip1}, + PersistentKeepalive: 5, + }}, + } + + device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) + device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) + defer device1.Close() + defer device2.Close() + + cmp := func(t *testing.T, d *device.Device, want *Config) { + t.Helper() + got, err := DeviceConfig(d) + if err != nil { + t.Fatal(err) + } + prev := new(Config) + gotbuf := new(strings.Builder) + err = got.ToUAPI(t.Logf, gotbuf, prev) + gotStr := gotbuf.String() + if err != nil { + t.Errorf("got.ToUAPI(): error: %v", err) + return + } + wantbuf := new(strings.Builder) + err = want.ToUAPI(t.Logf, wantbuf, prev) + wantStr := wantbuf.String() + if err != nil { + t.Errorf("want.ToUAPI(): error: %v", err) + return + } + if gotStr != wantStr { + buf := new(bytes.Buffer) + w := bufio.NewWriter(buf) + if err := d.IpcGetOperation(w); err != nil { + t.Errorf("on error, could not IpcGetOperation: %v", err) + } + w.Flush() + t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) + } + } + + t.Run("device1 config", func(t *testing.T) { + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device2 config", func(t *testing.T) { + if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device2, cfg2) + }) + + // This is only to test that Config and Reconfig are properly synchronized. + t.Run("device2 config/reconfig", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + ReconfigDevice(device2, cfg2, t.Logf) + wg.Done() + }() + + go func() { + DeviceConfig(device2) + wg.Done() + }() + + wg.Wait() + }) + + t.Run("device1 modify peer", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 replace endpoint", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 add new peer", func(t *testing.T) { + cfg1.Peers = append(cfg1.Peers, Peer{ + PublicKey: k3, + AllowedIPs: []netip.Prefix{ip3}, + }) + sort.Slice(cfg1.Peers, func(i, j int) bool { + return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) + }) + + origCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + peer0 := func(cfg *Config) Peer { + p, ok := cfg.PeerWithKey(k2) + if !ok { + t.Helper() + t.Fatal("failed to look up peer 2") + } + return p + } + peersEqual := func(p, q Peer) bool { + return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) + } + if !peersEqual(peer0(origCfg), peer0(newCfg)) { + t.Error("reconfig modified old peer") + } + }) + + t.Run("device1 remove peer", func(t *testing.T) { + removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey + cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + _, ok := newCfg.PeerWithKey(removeKey) + if ok { + t.Error("reconfig failed to remove peer") + } + }) +} + +// TODO: replace with a loopback tunnel +type nilTun struct { + events chan tun.Event + closed chan struct{} +} + +func newNilTun() tun.Device { + return &nilTun{ + events: make(chan tun.Event), + closed: make(chan struct{}), + } +} + +func (t *nilTun) File() *os.File { return nil } +func (t *nilTun) Flush() error { return nil } +func (t *nilTun) MTU() (int, error) { return 1420, nil } +func (t *nilTun) Name() (string, error) { return "niltun", nil } +func (t *nilTun) Events() <-chan tun.Event { return t.events } + +func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Write(data [][]byte, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Close() error { + close(t.events) + close(t.closed) + return nil +} + +func (t *nilTun) BatchSize() int { return 1 } + +// A noopBind is a conn.Bind that does no actual binding work. +type noopBind struct{} + +func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 1, nil +} +func (noopBind) Close() error { return nil } +func (noopBind) SetMark(mark uint32) error { return nil } +func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } +func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return dummyEndpoint(s), nil +} +func (noopBind) BatchSize() int { return 1 } + +// A dummyEndpoint is a string holding the endpoint destination. +type dummyEndpoint string + +func (e dummyEndpoint) ClearSrc() {} +func (e dummyEndpoint) SrcToString() string { return "" } +func (e dummyEndpoint) DstToString() string { return string(e) } +func (e dummyEndpoint) DstToBytes() []byte { return nil } +func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } +func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index 553aaecbb7171..ec3d008f7de97 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. -func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. - var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. - peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" + + "go4.org/mem" + "tailscale.com/types/key" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: %q", e.why, e.offender) +} + +func parseEndpoint(s string) (host string, port uint16, err error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return "", 0, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return "", 0, &ParseError{"Invalid endpoint host", host} + } + uport, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return "", 0, err + } + } else { + return "", 0, err + } + host = host[1 : len(host)-1] + } + return host, uint16(uport), nil +} + +// memROCut separates a mem.RO at the separator if it exists, otherwise +// it returns two empty ROs and reports that it was not found. +func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { + if i := mem.IndexByte(s, sep); i >= 0 { + return s.SliceTo(i), s.SliceFrom(i + 1), true + } + found = false + return +} + +// FromUAPI generates a Config from r. +// r should be generated by calling device.IpcGetOperation; +// it is not compatible with other uapi streams. +func FromUAPI(r io.Reader) (*Config, error) { + cfg := new(Config) + var peer *Peer // current peer being operated on + deviceConfig := true + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := mem.B(scanner.Bytes()) + if line.Len() == 0 { + continue + } + key, value, ok := memROCut(line, '=') + if !ok { + return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) + } + valueBytes := scanner.Bytes()[key.Len()+1:] + + if key.EqualString("public_key") { + if deviceConfig { + deviceConfig = false + } + // Load/create the peer we are now configuring. + var err error + peer, err = cfg.handlePublicKeyLine(valueBytes) + if err != nil { + return nil, err + } + continue + } + + var err error + if deviceConfig { + err = cfg.handleDeviceLine(key, value, valueBytes) + } else { + err = cfg.handlePeerLine(peer, key, value, valueBytes) + } + if err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return cfg, nil +} + +func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("private_key"): + // wireguard-go guarantees not to send zero value; private keys are already clamped. + var err error + cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) + if err != nil { + return err + } + case k.EqualString("listen_port") || k.EqualString("fwmark"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} + +func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { + p := Peer{} + var err error + p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) + if err != nil { + return nil, err + } + cfg.Peers = append(cfg.Peers, p) + return &cfg.Peers[len(cfg.Peers)-1], nil +} + +func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("endpoint"): + nk, err := key.ParseNodePublicUntyped(value) + if err != nil { + return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) + } + // nk ought to equal peer.PublicKey. + // Under some rare circumstances, it might not. See corp issue #3016. + // Even if that happens, don't stop early, so that we can recover from it. + // Instead, note the value of nk so we can fix as needed. + peer.WGEndpoint = nk + case k.EqualString("persistent_keepalive_interval"): + n, err := mem.ParseUint(value, 10, 16) + if err != nil { + return err + } + peer.PersistentKeepalive = uint16(n) + case k.EqualString("allowed_ip"): + ipp := netip.Prefix{} + err := ipp.UnmarshalText(valueBytes) + if err != nil { + return err + } + peer.AllowedIPs = append(peer.AllowedIPs, ipp) + case k.EqualString("protocol_version"): + if !value.EqualString("1") { + return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) + } + case k.EqualString("replace_allowed_ips") || + k.EqualString("preshared_key") || + k.EqualString("last_handshake_time_sec") || + k.EqualString("last_handshake_time_nsec") || + k.EqualString("tx_bytes") || + k.EqualString("rx_bytes"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} diff --git a/wgengine/winnet/winnet_windows.go b/wgengine/winnet/winnet_windows.go index 01e38517d2d64..283ce5ad17b68 100644 --- a/wgengine/winnet/winnet_windows.go +++ b/wgengine/winnet/winnet_windows.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package winnet - -import ( - "fmt" - "syscall" - "unsafe" - - "github.com/go-ole/go-ole" -) - -func (v *INetworkConnection) GetAdapterId() (string, error) { - buf := ole.GUID{} - hr, _, _ := syscall.Syscall( - v.VTable().GetAdapterId, - 2, - uintptr(unsafe.Pointer(v)), - uintptr(unsafe.Pointer(&buf)), - 0) - if hr != 0 { - return "", fmt.Errorf("GetAdapterId failed: %08x", hr) - } - return buf.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winnet + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/go-ole/go-ole" +) + +func (v *INetworkConnection) GetAdapterId() (string, error) { + buf := ole.GUID{} + hr, _, _ := syscall.Syscall( + v.VTable().GetAdapterId, + 2, + uintptr(unsafe.Pointer(v)), + uintptr(unsafe.Pointer(&buf)), + 0) + if hr != 0 { + return "", fmt.Errorf("GetAdapterId failed: %08x", hr) + } + return buf.String(), nil +} diff --git a/words/words.go b/words/words.go index 18efb75d77506..b373ffef6541f 100644 --- a/words/words.go +++ b/words/words.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package words contains accessors for some nice words. -package words - -import ( - "bytes" - _ "embed" - "strings" - "sync" -) - -//go:embed tails.txt -var tailsTxt []byte - -//go:embed scales.txt -var scalesTxt []byte - -var ( - once sync.Once - tails, scales []string -) - -// Tails returns words about tails. -func Tails() []string { - once.Do(initWords) - return tails -} - -// Scales returns words about scales. -func Scales() []string { - once.Do(initWords) - return scales -} - -func initWords() { - tails = parseWords(tailsTxt) - scales = parseWords(scalesTxt) -} - -func parseWords(txt []byte) []string { - n := bytes.Count(txt, []byte{'\n'}) - ret := make([]string, 0, n) - for len(txt) > 0 { - word := txt - i := bytes.IndexByte(txt, '\n') - if i != -1 { - word, txt = word[:i], txt[i+1:] - } else { - txt = nil - } - if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { - ret = append(ret, word) - } - } - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package words contains accessors for some nice words. +package words + +import ( + "bytes" + _ "embed" + "strings" + "sync" +) + +//go:embed tails.txt +var tailsTxt []byte + +//go:embed scales.txt +var scalesTxt []byte + +var ( + once sync.Once + tails, scales []string +) + +// Tails returns words about tails. +func Tails() []string { + once.Do(initWords) + return tails +} + +// Scales returns words about scales. +func Scales() []string { + once.Do(initWords) + return scales +} + +func initWords() { + tails = parseWords(tailsTxt) + scales = parseWords(scalesTxt) +} + +func parseWords(txt []byte) []string { + n := bytes.Count(txt, []byte{'\n'}) + ret := make([]string, 0, n) + for len(txt) > 0 { + word := txt + i := bytes.IndexByte(txt, '\n') + if i != -1 { + word, txt = word[:i], txt[i+1:] + } else { + txt = nil + } + if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { + ret = append(ret, word) + } + } + return ret +} diff --git a/words/words_test.go b/words/words_test.go index e96c234d7b84b..a9691792a5c00 100644 --- a/words/words_test.go +++ b/words/words_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package words - -import ( - "strings" - "testing" -) - -func TestWords(t *testing.T) { - test := func(t *testing.T, words []string) { - t.Helper() - if len(words) == 0 { - t.Error("no words") - } - seen := map[string]bool{} - for _, w := range words { - if seen[w] { - t.Errorf("dup word %q", w) - } - seen[w] = true - if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { - t.Errorf("malformed word %q", w) - } - } - } - t.Run("tails", func(t *testing.T) { test(t, Tails()) }) - t.Run("scales", func(t *testing.T) { test(t, Scales()) }) - t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) -} - -func nonASCIILower(r rune) bool { - if 'a' <= r && r <= 'z' { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package words + +import ( + "strings" + "testing" +) + +func TestWords(t *testing.T) { + test := func(t *testing.T, words []string) { + t.Helper() + if len(words) == 0 { + t.Error("no words") + } + seen := map[string]bool{} + for _, w := range words { + if seen[w] { + t.Errorf("dup word %q", w) + } + seen[w] = true + if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { + t.Errorf("malformed word %q", w) + } + } + } + t.Run("tails", func(t *testing.T) { test(t, Tails()) }) + t.Run("scales", func(t *testing.T) { test(t, Scales()) }) + t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) +} + +func nonASCIILower(r rune) bool { + if 'a' <= r && r <= 'z' { + return false + } + return true +} From 1aef3e83b8ce88c186d16314fd14b68825137b87 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Dec 2024 15:45:48 -0800 Subject: [PATCH 173/179] health: fix TestHealthMetric to pass on release branch Fixes #14302 Change-Id: I9fd893a97711c72b713fe5535f2ccb93fadf7452 Signed-off-by: Brad Fitzpatrick (cherry picked from commit dc6728729e903e83d7bc91de51dc38e115d79624) --- health/health_test.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/health/health_test.go b/health/health_test.go index 69e586066cdd6..ebdddc988edc7 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -14,6 +14,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/opt" "tailscale.com/util/usermetric" + "tailscale.com/version" ) func TestAppendWarnableDebugFlags(t *testing.T) { @@ -352,6 +353,11 @@ func TestShowUpdateWarnable(t *testing.T) { } func TestHealthMetric(t *testing.T) { + unstableBuildWarning := 0 + if version.IsUnstableBuild() { + unstableBuildWarning = 1 + } + tests := []struct { desc string check bool @@ -361,20 +367,20 @@ func TestHealthMetric(t *testing.T) { }{ // When running in dev, and not initialising the client, there will be two warnings // by default: - // - is-using-unstable-version + // - is-using-unstable-version (except on the release branch) // - wantrunning-false { desc: "base-warnings", check: true, cv: nil, - wantMetricCount: 2, + wantMetricCount: unstableBuildWarning + 1, }, // with: update-available { desc: "update-warning", check: true, cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, - wantMetricCount: 3, + wantMetricCount: unstableBuildWarning + 2, }, } for _, tt := range tests { From c80eb698d5057b04d826b5ae2004d4c464ae28f6 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Dec 2024 15:26:30 -0800 Subject: [PATCH 174/179] VERSION.txt: this is v1.78.1 Change-Id: I3588027fee8460b27c357d3a656f769fda151ccc Signed-off-by: Brad Fitzpatrick --- VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION.txt b/VERSION.txt index 79e15fd49370a..d9741f66cacd7 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.77.0 +1.78.1 From 3e3d5d8c6861330d44e3b0ac5648f6b2a392ee2a Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Tue, 10 Dec 2024 18:19:25 +0000 Subject: [PATCH 175/179] hostinfo: fix testing in container (#14330) (#14337) Previously this unit test failed if it was run in a container. Update the assert to focus on exactly the condition we are trying to assert: the package type should only be 'container' if we use the build tag. Updates #14317 Signed-off-by: Tom Proctor (cherry picked from commit 06c5e83c204b29496e67a8184d9ed7791c05b23c) --- hostinfo/hostinfo_linux_test.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/hostinfo/hostinfo_linux_test.go b/hostinfo/hostinfo_linux_test.go index c8bd2abbeb230..0286fadf329ab 100644 --- a/hostinfo/hostinfo_linux_test.go +++ b/hostinfo/hostinfo_linux_test.go @@ -35,8 +35,12 @@ remotes/origin/QTSFW_5.0.0` } } -func TestInContainer(t *testing.T) { - if got := inContainer(); !got.EqualBool(false) { - t.Errorf("inContainer = %v; want false due to absence of ts_package_container build tag", got) +func TestPackageTypeNotContainer(t *testing.T) { + var got string + if packageType != nil { + got = packageType() + } + if got == "container" { + t.Fatal("packageType = container; should only happen if build tag ts_package_container is set") } } From 6e0f168db07abe3ed7ca5a206b65087415708153 Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Wed, 11 Dec 2024 15:30:43 +0000 Subject: [PATCH 176/179] cmd/containerboot: fix nil pointer exception (cherry-pick of #14357, #14358) (#14359) * cmd/containerboot: guard kubeClient against nil dereference (#14357) A method on kc was called unconditionally, even if was not initialized, leading to a nil pointer dereference when TS_SERVE_CONFIG was set outside Kubernetes. Add a guard symmetric with other uses of the kubeClient. Signed-off-by: Bjorn Neergaard (cherry picked from commit 8b1d01161bbca8a26c2a50208444087c9fa2b3f1) * cmd/containerboot: don't attempt to write kube Secret in non-kube environments (#14358) Signed-off-by: Irbe Krumina (cherry picked from commit 0cc071f15409071f2649c3e142eceaf7cabff560) * cmd/containerboot: don't attempt to patch a Secret field without permissions (#14365) Signed-off-by: Irbe Krumina (cherry picked from commit 6e552f66a0289f6309477fb024019b62a251da16) Updates tailscale/tailscale#14354 --- cmd/containerboot/kube.go | 1 + cmd/containerboot/main.go | 6 ++++-- cmd/containerboot/serve.go | 6 ++++-- cmd/containerboot/settings.go | 1 + 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 643eef385ee0c..4d00687ee4566 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -24,6 +24,7 @@ import ( type kubeClient struct { kubeclient.Client stateSecret string + canPatch bool // whether the client has permissions to patch Kubernetes Secrets } func newKubeClient(root string, stateSecret string) (*kubeClient, error) { diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index ad1c0db201aa5..7411ea9496cfd 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -331,8 +331,10 @@ authLoop: if err := client.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { log.Fatalf("failed to unset serve config: %v", err) } - if err := kc.storeHTTPSEndpoint(ctx, ""); err != nil { - log.Fatalf("failed to update HTTPS endpoint in tailscale state: %v", err) + if hasKubeStateStore(cfg) { + if err := kc.storeHTTPSEndpoint(ctx, ""); err != nil { + log.Fatalf("failed to update HTTPS endpoint in tailscale state: %v", err) + } } } diff --git a/cmd/containerboot/serve.go b/cmd/containerboot/serve.go index 29ee7347f0c14..14c7f00d7450f 100644 --- a/cmd/containerboot/serve.go +++ b/cmd/containerboot/serve.go @@ -72,8 +72,10 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan if err := updateServeConfig(ctx, sc, certDomain, lc); err != nil { log.Fatalf("serve proxy: error updating serve config: %v", err) } - if err := kc.storeHTTPSEndpoint(ctx, certDomain); err != nil { - log.Fatalf("serve proxy: error storing HTTPS endpoint: %v", err) + if kc != nil && kc.canPatch { + if err := kc.storeHTTPSEndpoint(ctx, certDomain); err != nil { + log.Fatalf("serve proxy: error storing HTTPS endpoint: %v", err) + } } prevServeConfig = sc } diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index 4fae58584cec7..cc8641909dafe 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -214,6 +214,7 @@ func (cfg *settings) setupKube(ctx context.Context, kc *kubeClient) error { return fmt.Errorf("some Kubernetes permissions are missing, please check your RBAC configuration: %v", err) } cfg.KubernetesCanPatch = canPatch + kc.canPatch = canPatch s, err := kc.GetSecret(ctx, cfg.KubeSecret) if err != nil { From 3037dc793c6e738fee3cf36d8da24a9f54b1790d Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Wed, 11 Dec 2024 12:06:06 -0600 Subject: [PATCH 177/179] VERSION.txt: this is v1.78.2 Signed-off-by: Nick Khyl --- VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION.txt b/VERSION.txt index d9741f66cacd7..3bc3034360c70 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.78.1 +1.78.2 From 1b41fdeddb6598b35ba15cc6b07740e0cc0e8411 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Wed, 11 Dec 2024 14:26:51 -0600 Subject: [PATCH 178/179] VERSION.txt: this is v1.78.3 Signed-off-by: Nick Khyl --- VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION.txt b/VERSION.txt index 3bc3034360c70..2ea5ecd85abfc 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.78.2 +1.78.3 From 38bbe0193997373cfd1676757bff4ec6546000b3 Mon Sep 17 00:00:00 2001 From: ChandonPierre Date: Fri, 13 Dec 2024 21:57:27 -0500 Subject: [PATCH 179/179] fix(ci): update makefile target changed in 44c8892 --- .github/workflows/publish-image.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-image.yaml b/.github/workflows/publish-image.yaml index d34954cf8f57e..363d5c661a297 100644 --- a/.github/workflows/publish-image.yaml +++ b/.github/workflows/publish-image.yaml @@ -57,7 +57,7 @@ jobs: - name: Publish k8s-operator shell: bash run: | - REPOS="ghcr.io/${{ github.repository }}/k8s-operator" TARGET="operator" ./build_docker.sh + REPOS="ghcr.io/${{ github.repository }}/k8s-operator" TARGET="k8s-operator" ./build_docker.sh - name: Publish k8s-nameserver shell: bash