Skip to content

Commit d5a0390

Browse files
committed
fix: race in close grpc transport
1 parent 257fead commit d5a0390

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

transport/gun/gun.go

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"sync"
1919
"time"
2020

21-
"github.com/metacubex/mihomo/common/atomic"
2221
"github.com/metacubex/mihomo/common/buf"
2322
"github.com/metacubex/mihomo/common/pool"
2423
"github.com/metacubex/mihomo/component/ech"
@@ -42,16 +41,19 @@ type DialFn = func(ctx context.Context, network, addr string) (net.Conn, error)
4241

4342
type Conn struct {
4443
initFn func() (io.ReadCloser, netAddr, error)
45-
writer io.Writer
44+
writer io.Writer // writer must not nil
4645
closer io.Closer
4746
netAddr
4847

49-
reader io.ReadCloser
50-
once sync.Once
51-
closed atomic.Bool
52-
err error
53-
remain int
54-
br *bufio.Reader
48+
initOnce sync.Once
49+
initErr error
50+
reader io.ReadCloser
51+
br *bufio.Reader
52+
remain int
53+
54+
closeMutex sync.Mutex
55+
closed bool
56+
5557
// deadlines
5658
deadline *time.Timer
5759
}
@@ -65,25 +67,29 @@ type Config struct {
6567
func (g *Conn) initReader() {
6668
reader, addr, err := g.initFn()
6769
if err != nil {
68-
g.err = err
70+
g.initErr = err
6971
if closer, ok := g.writer.(io.Closer); ok {
7072
closer.Close()
7173
}
7274
return
7375
}
7476
g.netAddr = addr
7577

76-
if !g.closed.Load() {
77-
g.reader = reader
78-
g.br = bufio.NewReader(reader)
79-
} else {
80-
reader.Close()
78+
g.closeMutex.Lock()
79+
defer g.closeMutex.Unlock()
80+
if g.closed { // if g.Close() be called between g.initFn(), direct close the initFn returned reader
81+
_ = reader.Close()
82+
g.initErr = net.ErrClosed
83+
return
8184
}
85+
86+
g.reader = reader
87+
g.br = bufio.NewReader(reader)
8288
}
8389

8490
func (g *Conn) Init() error {
85-
g.once.Do(g.initReader)
86-
return g.err
91+
g.initOnce.Do(g.initReader)
92+
return g.initErr
8793
}
8894

8995
func (g *Conn) Read(b []byte) (n int, err error) {
@@ -100,8 +106,6 @@ func (g *Conn) Read(b []byte) (n int, err error) {
100106
n, err = io.ReadFull(g.br, b[:size])
101107
g.remain -= n
102108
return
103-
} else if g.reader == nil {
104-
return 0, net.ErrClosed
105109
}
106110

107111
// 0x00 grpclength(uint32) 0x0A uleb128 payload
@@ -147,8 +151,8 @@ func (g *Conn) Write(b []byte) (n int, err error) {
147151
buf.Write(b)
148152

149153
_, err = g.writer.Write(buf.Bytes())
150-
if err == io.ErrClosedPipe && g.err != nil {
151-
err = g.err
154+
if err == io.ErrClosedPipe && g.initErr != nil {
155+
err = g.initErr
152156
}
153157

154158
if flusher, ok := g.writer.(http.Flusher); ok {
@@ -170,8 +174,8 @@ func (g *Conn) WriteBuffer(buffer *buf.Buffer) error {
170174
binary.PutUvarint(header[6:], uint64(dataLen))
171175
_, err := g.writer.Write(buffer.Bytes())
172176

173-
if err == io.ErrClosedPipe && g.err != nil {
174-
err = g.err
177+
if err == io.ErrClosedPipe && g.initErr != nil {
178+
err = g.initErr
175179
}
176180

177181
if flusher, ok := g.writer.(http.Flusher); ok {
@@ -186,7 +190,17 @@ func (g *Conn) FrontHeadroom() int {
186190
}
187191

188192
func (g *Conn) Close() error {
189-
g.closed.Store(true)
193+
g.initOnce.Do(func() { // if initReader not called, it should not be run anymore
194+
g.initErr = net.ErrClosed
195+
})
196+
197+
g.closeMutex.Lock()
198+
defer g.closeMutex.Unlock()
199+
if g.closed {
200+
return nil
201+
}
202+
g.closed = true
203+
190204
var errorArr []error
191205

192206
if reader := g.reader; reader != nil {

0 commit comments

Comments
 (0)