From 6ce8eb0749f8708d90727bb6d7398555e380402a Mon Sep 17 00:00:00 2001 From: Keegan Carruthers-Smith Date: Mon, 5 Jun 2023 15:59:48 +0200 Subject: [PATCH] conn: do not lock sending when closing The sending mutex may be blocked due to the underlying conn blocking. If that happens then we can't call close since close also attempts to hold the sending mutex. Sending mutex is only used for serializing sends and doesn't protect the fields close reads/writes. I believe we introduced locking it so we would return ErrClosed. This commit instead introduces a check in send which rechecks if we have since closed while attempting to send. Test Plan: expanded the test coverage --- conn.go | 13 ++++++- conn_test.go | 101 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 81 insertions(+), 33 deletions(-) diff --git a/conn.go b/conn.go index 9de994a..8ce19d1 100644 --- a/conn.go +++ b/conn.go @@ -166,9 +166,7 @@ func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { } func (c *Conn) close(cause error) error { - c.sending.Lock() c.mu.Lock() - defer c.sending.Unlock() defer c.mu.Unlock() if c.closed { @@ -249,6 +247,17 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err c.sending.Lock() defer c.sending.Unlock() + // double check the error isn't due to being closed while sending. + defer func() { + if err != nil { + c.mu.Lock() + if c.closed { + err = ErrClosed + } + c.mu.Unlock() + } + }() + // m.request.ID could be changed, so we store a copy to correctly // clean up pending var id ID diff --git a/conn_test.go b/conn_test.go index 56e0350..aa32fd9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -118,38 +118,77 @@ func TestConn_DisconnectNotify(t *testing.T) { } func TestConn_Close(t *testing.T) { - t.Run("waiting for response", func(t *testing.T) { - connA, connB := net.Pipe() - nodeA := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, - ) - defer nodeA.Close() - nodeB := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connB), - noopHandler{}, - ) - defer nodeB.Close() - - ready := make(chan struct{}) - done := make(chan struct{}) - go func() { - close(ready) - err := nodeB.Call(context.Background(), "m", nil, nil) - if err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + cases := []struct { + name string + run func(*testing.T, context.Context, *jsonrpc2.Conn) + }{{ + name: "during Call", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := conn.Call(ctx, "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := conn.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) } - close(done) - }() - // Wait for the request to be sent before we close the connection. - <-ready - if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - assertDisconnect(t, nodeB, connB) - <-done - }) + <-done + }, + }, { + name: "during Wait", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + call, err := conn.DispatchCall(ctx, "m", nil, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if err := call.Wait(ctx, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }, { + name: "during Dispatch", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if _, err := conn.DispatchCall(ctx, "m", nil, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + tc.run(t, ctx, nodeB) + + assertDisconnect(t, nodeB, connB) + }) + } } func testParams(t *testing.T, want *json.RawMessage, fn func(c *jsonrpc2.Conn) error) {