7
7
"bytes"
8
8
"context"
9
9
"encoding/binary"
10
- "errors"
11
10
"flag"
12
11
"fmt"
13
12
"io"
@@ -450,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
450
449
return
451
450
}
452
451
453
- func runTestQuery (tb testing.TB , port uint16 , request []byte , modify func (* forwarder )) ([]byte , error ) {
452
+ func runTestQuery (tb testing.TB , request []byte , modify func (* forwarder ), ports ... uint16 ) ([]byte , error ) {
454
453
netMon , err := netmon .New (tb .Logf )
455
454
if err != nil {
456
455
tb .Fatal (err )
@@ -464,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
464
463
modify (fwd )
465
464
}
466
465
467
- rr := resolverAndDelay {
468
- name : & dnstype.Resolver {Addr : fmt .Sprintf ("127.0.0.1:%d" , port )},
466
+ resolvers := make ([]resolverAndDelay , len (ports ))
467
+ for i , port := range ports {
468
+ resolvers [i ].name = & dnstype.Resolver {Addr : fmt .Sprintf ("127.0.0.1:%d" , port )}
469
469
}
470
470
471
471
rpkt := packet {
@@ -477,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
477
477
rchan := make (chan packet , 1 )
478
478
ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
479
479
tb .Cleanup (cancel )
480
- err = fwd .forwardWithDestChan (ctx , rpkt , rchan , rr )
480
+ err = fwd .forwardWithDestChan (ctx , rpkt , rchan , resolvers ... )
481
481
select {
482
482
case res := <- rchan :
483
483
return res .bs , err
@@ -486,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
486
486
}
487
487
}
488
488
489
- func mustRunTestQuery (tb testing.TB , port uint16 , request []byte , modify func (* forwarder )) []byte {
490
- resp , err := runTestQuery (tb , port , request , modify )
489
+ // makeTestRequest returns a new TypeA request for the given domain.
490
+ func makeTestRequest (tb testing.TB , domain string ) []byte {
491
+ tb .Helper ()
492
+ name := dns .MustNewName (domain )
493
+ builder := dns .NewBuilder (nil , dns.Header {})
494
+ builder .StartQuestions ()
495
+ builder .Question (dns.Question {
496
+ Name : name ,
497
+ Type : dns .TypeA ,
498
+ Class : dns .ClassINET ,
499
+ })
500
+ request , err := builder .Finish ()
501
+ if err != nil {
502
+ tb .Fatal (err )
503
+ }
504
+ return request
505
+ }
506
+
507
+ // makeTestResponse returns a new Type A response for the given domain,
508
+ // with the specified status code and zero or more addresses.
509
+ func makeTestResponse (tb testing.TB , domain string , code dns.RCode , addrs ... netip.Addr ) []byte {
510
+ tb .Helper ()
511
+ name := dns .MustNewName (domain )
512
+ builder := dns .NewBuilder (nil , dns.Header {
513
+ Response : true ,
514
+ Authoritative : true ,
515
+ RCode : code ,
516
+ })
517
+ builder .StartQuestions ()
518
+ q := dns.Question {
519
+ Name : name ,
520
+ Type : dns .TypeA ,
521
+ Class : dns .ClassINET ,
522
+ }
523
+ builder .Question (q )
524
+ if len (addrs ) > 0 {
525
+ builder .StartAnswers ()
526
+ for _ , addr := range addrs {
527
+ builder .AResource (dns.ResourceHeader {
528
+ Name : q .Name ,
529
+ Class : q .Class ,
530
+ TTL : 120 ,
531
+ }, dns.AResource {
532
+ A : addr .As4 (),
533
+ })
534
+ }
535
+ }
536
+ response , err := builder .Finish ()
537
+ if err != nil {
538
+ tb .Fatal (err )
539
+ }
540
+ return response
541
+ }
542
+
543
+ func mustRunTestQuery (tb testing.TB , request []byte , modify func (* forwarder ), ports ... uint16 ) []byte {
544
+ resp , err := runTestQuery (tb , request , modify , ports ... )
491
545
if err != nil {
492
546
tb .Fatalf ("error making request: %v" , err )
493
547
}
@@ -516,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
516
570
}
517
571
})
518
572
519
- resp := mustRunTestQuery (t , port , request , nil )
573
+ resp := mustRunTestQuery (t , request , nil , port )
520
574
if ! bytes .Equal (resp , largeResponse ) {
521
575
t .Errorf ("invalid response\n got: %+v\n want: %+v" , resp , largeResponse )
522
576
}
@@ -554,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
554
608
}
555
609
})
556
610
557
- resp := mustRunTestQuery (t , port , request , nil )
611
+ resp := mustRunTestQuery (t , request , nil , port )
558
612
if ! bytes .Equal (resp , largeResponse ) {
559
613
t .Errorf ("invalid response\n got: %+v\n want: %+v" , resp , largeResponse )
560
614
}
@@ -585,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
585
639
}
586
640
})
587
641
588
- resp := mustRunTestQuery (t , port , request , func (fwd * forwarder ) {
642
+ resp := mustRunTestQuery (t , request , func (fwd * forwarder ) {
589
643
// Disable retries for this test.
590
644
fwd .controlKnobs = & controlknobs.Knobs {}
591
645
fwd .controlKnobs .DisableDNSForwarderTCPRetries .Store (true )
592
- })
646
+ }, port )
593
647
594
648
wantResp := append ([]byte (nil ), largeResponse [:maxResponseBytes ]... )
595
649
@@ -613,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
613
667
const domain = "error-response.tailscale.com."
614
668
615
669
// Our response is a SERVFAIL
616
- response := func () []byte {
617
- name := dns .MustNewName (domain )
618
-
619
- builder := dns .NewBuilder (nil , dns.Header {
620
- Response : true ,
621
- RCode : dns .RCodeServerFailure ,
622
- })
623
- builder .StartQuestions ()
624
- builder .Question (dns.Question {
625
- Name : name ,
626
- Type : dns .TypeA ,
627
- Class : dns .ClassINET ,
628
- })
629
- response , err := builder .Finish ()
630
- if err != nil {
631
- t .Fatal (err )
632
- }
633
- return response
634
- }()
670
+ response := makeTestResponse (t , domain , dns .RCodeServerFailure )
635
671
636
672
// Our request is a single A query for the domain in the answer, above.
637
- request := func () []byte {
638
- builder := dns .NewBuilder (nil , dns.Header {})
639
- builder .StartQuestions ()
640
- builder .Question (dns.Question {
641
- Name : dns .MustNewName (domain ),
642
- Type : dns .TypeA ,
643
- Class : dns .ClassINET ,
644
- })
645
- request , err := builder .Finish ()
646
- if err != nil {
647
- t .Fatal (err )
648
- }
649
- return request
650
- }()
673
+ request := makeTestRequest (t , domain )
651
674
652
675
var sawRequest atomic.Bool
653
676
port := runDNSServer (t , nil , response , func (isTCP bool , gotRequest []byte ) {
@@ -657,14 +680,141 @@ func TestForwarderTCPFallbackError(t *testing.T) {
657
680
}
658
681
})
659
682
660
- _ , err := runTestQuery (t , port , request , nil )
683
+ resp , err := runTestQuery (t , request , nil , port )
661
684
if ! sawRequest .Load () {
662
685
t .Error ("did not see DNS request" )
663
686
}
664
- if err == nil {
665
- t .Error ("wanted error, got nil" )
666
- } else if ! errors .Is (err , errServerFailure ) {
667
- t .Errorf ("wanted errServerFailure, got: %v" , err )
687
+ if err != nil {
688
+ t .Fatalf ("wanted nil, got %v" , err )
689
+ }
690
+ var parser dns.Parser
691
+ respHeader , err := parser .Start (resp )
692
+ if err != nil {
693
+ t .Fatalf ("parser.Start() failed: %v" , err )
694
+ }
695
+ if got , want := respHeader .RCode , dns .RCodeServerFailure ; got != want {
696
+ t .Errorf ("wanted %v, got %v" , want , got )
697
+ }
698
+ }
699
+
700
+ // Test to ensure that if we have more than one resolver, and at least one of them
701
+ // returns a successful response, we propagate it.
702
+ func TestForwarderWithManyResolvers (t * testing.T ) {
703
+ enableDebug (t )
704
+
705
+ const domain = "example.com."
706
+ request := makeTestRequest (t , domain )
707
+
708
+ tests := []struct {
709
+ name string
710
+ responses [][]byte // upstream responses
711
+ wantResponses [][]byte // we should receive one of these from the forwarder
712
+ }{
713
+ {
714
+ name : "Success" ,
715
+ responses : [][]byte { // All upstream servers returned successful, but different, response.
716
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
717
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.2" )),
718
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.3" )),
719
+ },
720
+ wantResponses : [][]byte { // We may forward whichever response is received first.
721
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
722
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.2" )),
723
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.3" )),
724
+ },
725
+ },
726
+ {
727
+ name : "ServFail" ,
728
+ responses : [][]byte { // All upstream servers returned a SERVFAIL.
729
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
730
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
731
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
732
+ },
733
+ wantResponses : [][]byte {
734
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
735
+ },
736
+ },
737
+ {
738
+ name : "ServFail+Success" ,
739
+ responses : [][]byte { // All upstream servers fail except for one.
740
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
741
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
742
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
743
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
744
+ },
745
+ wantResponses : [][]byte { // We should forward the successful response.
746
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
747
+ },
748
+ },
749
+ {
750
+ name : "NXDomain" ,
751
+ responses : [][]byte { // All upstream servers returned NXDOMAIN.
752
+ makeTestResponse (t , domain , dns .RCodeNameError ),
753
+ makeTestResponse (t , domain , dns .RCodeNameError ),
754
+ makeTestResponse (t , domain , dns .RCodeNameError ),
755
+ },
756
+ wantResponses : [][]byte {
757
+ makeTestResponse (t , domain , dns .RCodeNameError ),
758
+ },
759
+ },
760
+ {
761
+ name : "NXDomain+Success" ,
762
+ responses : [][]byte { // All upstream servers returned NXDOMAIN except for one.
763
+ makeTestResponse (t , domain , dns .RCodeNameError ),
764
+ makeTestResponse (t , domain , dns .RCodeNameError ),
765
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
766
+ },
767
+ wantResponses : [][]byte { // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
768
+ makeTestResponse (t , domain , dns .RCodeNameError ),
769
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
770
+ },
771
+ },
772
+ {
773
+ name : "Refused" ,
774
+ responses : [][]byte { // All upstream servers return different failures.
775
+ makeTestResponse (t , domain , dns .RCodeRefused ),
776
+ makeTestResponse (t , domain , dns .RCodeRefused ),
777
+ makeTestResponse (t , domain , dns .RCodeRefused ),
778
+ makeTestResponse (t , domain , dns .RCodeRefused ),
779
+ makeTestResponse (t , domain , dns .RCodeRefused ),
780
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
781
+ },
782
+ wantResponses : [][]byte { // Refused is not considered to be an error and can be forwarded.
783
+ makeTestResponse (t , domain , dns .RCodeRefused ),
784
+ makeTestResponse (t , domain , dns .RCodeSuccess , netip .MustParseAddr ("127.0.0.1" )),
785
+ },
786
+ },
787
+ {
788
+ name : "MixFail" ,
789
+ responses : [][]byte { // All upstream servers return different failures.
790
+ makeTestResponse (t , domain , dns .RCodeServerFailure ),
791
+ makeTestResponse (t , domain , dns .RCodeNameError ),
792
+ makeTestResponse (t , domain , dns .RCodeRefused ),
793
+ },
794
+ wantResponses : [][]byte { // Both NXDomain and Refused can be forwarded.
795
+ makeTestResponse (t , domain , dns .RCodeNameError ),
796
+ makeTestResponse (t , domain , dns .RCodeRefused ),
797
+ },
798
+ },
799
+ }
800
+
801
+ for _ , tt := range tests {
802
+ t .Run (tt .name , func (t * testing.T ) {
803
+ ports := make ([]uint16 , len (tt .responses ))
804
+ for i := range tt .responses {
805
+ ports [i ] = runDNSServer (t , nil , tt .responses [i ], func (isTCP bool , gotRequest []byte ) {})
806
+ }
807
+ gotResponse , err := runTestQuery (t , request , nil , ports ... )
808
+ if err != nil {
809
+ t .Fatalf ("wanted nil, got %v" , err )
810
+ }
811
+ responseOk := slices .ContainsFunc (tt .wantResponses , func (wantResponse []byte ) bool {
812
+ return slices .Equal (gotResponse , wantResponse )
813
+ })
814
+ if ! responseOk {
815
+ t .Errorf ("invalid response\n got: %+v\n want: %+v" , gotResponse , tt .wantResponses [0 ])
816
+ }
817
+ })
668
818
}
669
819
}
670
820
@@ -713,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
713
863
port := runDNSServer (t , nil , response , func (isTCP bool , gotRequest []byte ) {
714
864
})
715
865
716
- res , err := runTestQuery (t , port , request , nil )
866
+ res , err := runTestQuery (t , request , nil , port )
717
867
if err != nil {
718
868
t .Fatal (err )
719
869
}
0 commit comments