From 5d80b29f441bbb48e87bca31a6c9ecf3f92449a3 Mon Sep 17 00:00:00 2001 From: Keegan Carruthers-Smith Date: Wed, 7 Jun 2023 08:40:20 +0200 Subject: [PATCH] conn: do not lock sending when closing (#70) The sending mutex may be blocked due to the underlying conn blocking. If that happens then we can't call close since close also attempts to hold the sending mutex. Sending mutex is only used for serializing sends and doesn't protect the fields close reads/writes. I believe we introduced locking it so we would return ErrClosed. This commit instead introduces a check in send which rechecks if we have since closed while attempting to send. Test Plan: expanded the test coverage --- conn.go | 13 ++++++- conn_test.go | 101 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 81 insertions(+), 33 deletions(-) diff --git a/conn.go b/conn.go index 9de994a..8ce19d1 100644 --- a/conn.go +++ b/conn.go @@ -166,9 +166,7 @@ func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { } func (c *Conn) close(cause error) error { - c.sending.Lock() c.mu.Lock() - defer c.sending.Unlock() defer c.mu.Unlock() if c.closed { @@ -249,6 +247,17 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err c.sending.Lock() defer c.sending.Unlock() + // double check the error isn't due to being closed while sending. + defer func() { + if err != nil { + c.mu.Lock() + if c.closed { + err = ErrClosed + } + c.mu.Unlock() + } + }() + // m.request.ID could be changed, so we store a copy to correctly // clean up pending var id ID diff --git a/conn_test.go b/conn_test.go index 56e0350..aa32fd9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -118,38 +118,77 @@ func TestConn_DisconnectNotify(t *testing.T) { } func TestConn_Close(t *testing.T) { - t.Run("waiting for response", func(t *testing.T) { - connA, connB := net.Pipe() - nodeA := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, - ) - defer nodeA.Close() - nodeB := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connB), - noopHandler{}, - ) - defer nodeB.Close() - - ready := make(chan struct{}) - done := make(chan struct{}) - go func() { - close(ready) - err := nodeB.Call(context.Background(), "m", nil, nil) - if err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + cases := []struct { + name string + run func(*testing.T, context.Context, *jsonrpc2.Conn) + }{{ + name: "during Call", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := conn.Call(ctx, "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := conn.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) } - close(done) - }() - // Wait for the request to be sent before we close the connection. - <-ready - if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - assertDisconnect(t, nodeB, connB) - <-done - }) + <-done + }, + }, { + name: "during Wait", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + call, err := conn.DispatchCall(ctx, "m", nil, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if err := call.Wait(ctx, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }, { + name: "during Dispatch", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if _, err := conn.DispatchCall(ctx, "m", nil, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + tc.run(t, ctx, nodeB) + + assertDisconnect(t, nodeB, connB) + }) + } } func testParams(t *testing.T, want *json.RawMessage, fn func(c *jsonrpc2.Conn) error) {