Skip to content

Commit b873e76

Browse files
committed
Merge branch 'main' of https://github.com/gorilla/websocket
2 parents bc9f200 + 5e00238 commit b873e76

10 files changed

+159
-64
lines changed

client.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
305305
})
306306
}
307307

308+
// Close the network connection when returning an error. The variable
309+
// netConn is set to nil before the success return at the end of the
310+
// function.
308311
defer func() {
309312
if netConn != nil {
310-
netConn.Close()
313+
// It's safe to ignore the error from Close() because this code is
314+
// only executed when returning a more important error to the
315+
// application.
316+
_ = netConn.Close()
311317
}
312318
}()
313319

@@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
398404
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
399405
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
400406

401-
netConn.SetDeadline(time.Time{})
402-
netConn = nil // to avoid close in defer.
407+
if err := netConn.SetDeadline(time.Time{}); err != nil {
408+
return nil, resp, err
409+
}
410+
411+
// Success! Set netConn to nil to stop the deferred function above from
412+
// closing the network connection.
413+
netConn = nil
414+
403415
return conn, resp, nil
404416
}
405417

client_server_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ func TestNoUpgrade(t *testing.T) {
522522
}
523523
resp.Body.Close()
524524
if u := resp.Header.Get("Upgrade"); u != "websocket" {
525-
t.Errorf("Uprade response header is %q, want %q", u, "websocket")
525+
t.Errorf("Upgrade response header is %q, want %q", u, "websocket")
526526
}
527527
if resp.StatusCode != http.StatusUpgradeRequired {
528528
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
@@ -578,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) {
578578

579579
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
580580
w.WriteHeader(expectedStatus)
581-
io.WriteString(w, expectedBody)
581+
_, _ = io.WriteString(w, expectedBody)
582582
}))
583583
defer s.Close()
584584

@@ -828,7 +828,7 @@ func TestSocksProxyDial(t *testing.T) {
828828
}
829829
defer c1.Close()
830830

831-
c1.SetDeadline(time.Now().Add(30 * time.Second))
831+
_ = c1.SetDeadline(time.Now().Add(30 * time.Second))
832832

833833
buf := make([]byte, 32)
834834
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
@@ -867,10 +867,10 @@ func TestSocksProxyDial(t *testing.T) {
867867
defer c2.Close()
868868
done := make(chan struct{})
869869
go func() {
870-
io.Copy(c1, c2)
870+
_, _ = io.Copy(c1, c2)
871871
close(done)
872872
}()
873-
io.Copy(c2, c1)
873+
_, _ = io.Copy(c2, c1)
874874
<-done
875875
}()
876876

compression.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
3434
"\x01\x00\x00\xff\xff"
3535

3636
fr, _ := flateReaderPool.Get().(io.ReadCloser)
37-
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
37+
mr := io.MultiReader(r, strings.NewReader(tail))
38+
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
39+
// Reset never fails, but handle error in case that changes.
40+
fr = flate.NewReader(mr)
41+
}
3842
return &flateReadWrapper{fr}
3943
}
4044

compression_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
2222
if m > n {
2323
m = n
2424
}
25-
w.Write(p[:m])
25+
_, _ = w.Write(p[:m])
2626
p = p[m:]
2727
}
2828
if b.String() != data[:len(data)-len(w.p)] {
@@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
4646
messages := textMessages(100)
4747
b.ResetTimer()
4848
for i := 0; i < b.N; i++ {
49-
c.WriteMessage(TextMessage, messages[i%len(messages)])
49+
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
5050
}
5151
b.ReportAllocs()
5252
}
@@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
5959
c.newCompressionWriter = compressNoContextTakeover
6060
b.ResetTimer()
6161
for i := 0; i < b.N; i++ {
62-
c.WriteMessage(TextMessage, messages[i%len(messages)])
62+
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
6363
}
6464
b.ReportAllocs()
6565
}

conn.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,9 @@ func (c *Conn) read(n int) ([]byte, error) {
391391
if err == io.EOF {
392392
err = errUnexpectedEOF
393393
}
394-
c.br.Discard(len(p))
394+
// Discard is guaranteed to succeed because the number of bytes to discard
395+
// is less than or equal to the number of bytes buffered.
396+
_, _ = c.br.Discard(len(p))
395397
return p, err
396398
}
397399

@@ -412,7 +414,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
412414
return ErrNilNetConn
413415
}
414416

415-
c.conn.SetWriteDeadline(deadline)
417+
if err := c.conn.SetWriteDeadline(deadline); err != nil {
418+
return c.writeFatal(err)
419+
}
416420
if len(buf1) == 0 {
417421
_, err = c.conn.Write(buf0)
418422
} else {
@@ -422,7 +426,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
422426
return c.writeFatal(err)
423427
}
424428
if frameType == CloseMessage {
425-
c.writeFatal(ErrCloseSent)
429+
_ = c.writeFatal(ErrCloseSent)
426430
}
427431
return nil
428432
}
@@ -464,21 +468,27 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
464468
maskBytes(key, 0, buf[6:])
465469
}
466470

467-
d := 1000 * time.Hour
468-
if !deadline.IsZero() {
469-
d = deadline.Sub(time.Now())
471+
if deadline.IsZero() {
472+
// No timeout for zero time.
473+
<-c.mu
474+
} else {
475+
d := time.Until(deadline)
470476
if d < 0 {
471477
return errWriteTimeout
472478
}
479+
select {
480+
case <-c.mu:
481+
default:
482+
timer := time.NewTimer(d)
483+
select {
484+
case <-c.mu:
485+
timer.Stop()
486+
case <-timer.C:
487+
return errWriteTimeout
488+
}
489+
}
473490
}
474491

475-
timer := time.NewTimer(d)
476-
select {
477-
case <-c.mu:
478-
timer.Stop()
479-
case <-timer.C:
480-
return errWriteTimeout
481-
}
482492
defer func() { c.mu <- struct{}{} }()
483493

484494
c.writeErrMu.Lock()
@@ -491,13 +501,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
491501
return ErrNilNetConn
492502
}
493503

494-
c.conn.SetWriteDeadline(deadline)
495-
_, err = c.conn.Write(buf)
496-
if err != nil {
504+
if err := c.conn.SetWriteDeadline(deadline); err != nil {
505+
return c.writeFatal(err)
506+
}
507+
if _, err = c.conn.Write(buf); err != nil {
497508
return c.writeFatal(err)
498509
}
499510
if messageType == CloseMessage {
500-
c.writeFatal(ErrCloseSent)
511+
_ = c.writeFatal(ErrCloseSent)
501512
}
502513
return err
503514
}
@@ -670,7 +681,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
670681
}
671682

672683
if final {
673-
w.endMessage(errWriteClosed)
684+
_ = w.endMessage(errWriteClosed)
674685
return nil
675686
}
676687

@@ -865,7 +876,7 @@ func (c *Conn) advanceFrame() (int, error) {
865876
rsv2 := p[0]&rsv2Bit != 0
866877
rsv3 := p[0]&rsv3Bit != 0
867878
mask := p[1]&maskBit != 0
868-
c.setReadRemaining(int64(p[1] & 0x7f))
879+
_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0
869880

870881
c.readDecompress = false
871882
if rsv1 {
@@ -970,7 +981,8 @@ func (c *Conn) advanceFrame() (int, error) {
970981
}
971982

972983
if c.readLimit > 0 && c.readLength > c.readLimit {
973-
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
984+
// Make a best effort to send a close message describing the problem.
985+
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
974986
return noFrame, ErrReadLimit
975987
}
976988

@@ -982,7 +994,7 @@ func (c *Conn) advanceFrame() (int, error) {
982994
var payload []byte
983995
if c.readRemaining > 0 {
984996
payload, err = c.read(int(c.readRemaining))
985-
c.setReadRemaining(0)
997+
_ = c.setReadRemaining(0) // will not fail because argument is >= 0
986998
if err != nil {
987999
return noFrame, err
9881000
}
@@ -1032,7 +1044,8 @@ func (c *Conn) handleProtocolError(message string) error {
10321044
if len(data) > maxControlFramePayloadSize {
10331045
data = data[:maxControlFramePayloadSize]
10341046
}
1035-
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
1047+
// Make a best effor to send a close message describing the problem.
1048+
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
10361049
return errors.New("websocket: " + message)
10371050
}
10381051

@@ -1111,7 +1124,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
11111124
}
11121125
rem := c.readRemaining
11131126
rem -= int64(n)
1114-
c.setReadRemaining(rem)
1127+
_ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
11151128
if c.readRemaining > 0 && c.readErr == io.EOF {
11161129
c.readErr = errUnexpectedEOF
11171130
}
@@ -1211,7 +1224,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
12111224
if h == nil {
12121225
h = func(code int, text string) error {
12131226
message := FormatCloseMessage(code, "")
1214-
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
1227+
// Make a best effor to send the close message.
1228+
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
12151229
return nil
12161230
}
12171231
}
@@ -1239,7 +1253,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
12391253
}
12401254
if h == nil {
12411255
h = func(message string) error {
1242-
// Make a best effort to send the pong mesage.
1256+
// Make a best effort to send the pong message.
12431257
_ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
12441258
return nil
12451259
}

conn_broadcast_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
6969
select {
7070
case msg := <-c.msgCh:
7171
if msg.prepared != nil {
72-
c.conn.WritePreparedMessage(msg.prepared)
72+
_ = c.conn.WritePreparedMessage(msg.prepared)
7373
} else {
74-
c.conn.WriteMessage(TextMessage, msg.payload)
74+
_ = c.conn.WriteMessage(TextMessage, msg.payload)
7575
}
7676
val := atomic.AddInt32(&b.count, 1)
7777
if val%int32(numConns) == 0 {

0 commit comments

Comments
 (0)