From 5f51212fd6150dcb9af2e2a2088deed6808377b9 Mon Sep 17 00:00:00 2001 From: Petar Maymounkov Date: Fri, 23 May 2025 21:05:55 -0700 Subject: [PATCH] fix race condition in driver closure --- rpc/driver.go | 48 +++++++++++++++++++++++++++--------------------- rpc/rpc_test.go | 2 +- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/rpc/driver.go b/rpc/driver.go index ac0d7f4..03b94f3 100644 --- a/rpc/driver.go +++ b/rpc/driver.go @@ -121,21 +121,27 @@ func ReceiveChan(ctx context.Context, t Transport) <-chan *Message { // Create a channel for sending messages through a Transport. Creates // a thread that won't exit until the returned channel is closed. // Does not close the underlying Transport. -func SendChan(t Transport, onErr func(uint32, error)) chan<- *Message { - ret := make(chan *Message, 1) - go func(c <-chan *Message) { +func SendChan(t Transport, onErr func(uint32, error)) (chan<- *Message, chan<- struct{}) { + ch := make(chan *Message, 1) + chClose := make(chan struct{}) + go func(ch <-chan *Message, cancel <-chan struct{}) { for { - if m, ok := <-c; !ok { + select { + case <-cancel: return - } else { - xid := m.Xid() - if err := t.Send(m); err != nil && onErr != nil { - onErr(xid, err) + case m, ok := <-ch: + if !ok { + return + } else { + xid := m.Xid() + if err := t.Send(m); err != nil && onErr != nil { + onErr(xid, err) + } } } } - }(ret) - return ret + }(ch, chClose) + return ch, chClose } // RPC driver implements all Transport-agnostic logic for handling @@ -184,13 +190,14 @@ type Driver struct { // If non-nil, all panics arising from service method implementations are passed to PanicHandle. PanicHandler PanicHandler - srv RpcSrv - ctx context.Context - cancel context.CancelFunc - out chan<- *Message - in <-chan *Message - cs CallSet - started int32 + srv RpcSrv + ctx context.Context + cancel context.CancelFunc + out chan<- *Message + outClose chan<- struct{} + in <-chan *Message + cs CallSet + started int32 } // PanicHandler defines a handler for panics arising from service method implementations. @@ -213,7 +220,7 @@ func (r *Driver) logXdr(t xdr.XdrType, f string, args ...interface{}) { var out bytes.Buffer fmt.Fprintf(&out, f, args...) out.WriteByte('\n') - t.XdrMarshal(xdr.XdrPrint{&out}, "") + t.XdrMarshal(xdr.XdrPrint{Out: &out}, "") r.Log.Write(out.Bytes()) } @@ -235,13 +242,13 @@ func NewDriver(ctx context.Context, t Transport) *Driver { cancel: cancel, in: ReceiveChan(ctx, t), } - ret.out = SendChan(t, func(xid uint32, _ error) { + ret.out, ret.outClose = SendChan(t, func(xid uint32, _ error) { ret.cs.Cancel(xid, SEND_ERR) }) go func() { <-ctx.Done() t.Close() - close(ret.out) + close(ret.outClose) }() return &ret @@ -260,7 +267,6 @@ func (r *Driver) Close() { } func (r *Driver) safeSend(ctx context.Context, m *Message) (ok bool) { - defer func() { recover() }() select { case r.out <- m: return true diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index b3f8aeb..8e1ee58 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -69,7 +69,7 @@ func TestChannels(t *testing.T) { tx1, tx2 := rpc.NewStreamTransport(cs[0]), rpc.NewStreamTransport(cs[1]) r := rpc.ReceiveChan(ctx, tx1) defer tx1.Close() - s := rpc.SendChan(tx2, nil) + s, _ := rpc.SendChan(tx2, nil) go func() { defer close(s) defer tx2.Close()