diff --git a/conn.go b/conn.go index 9de994a..8ce19d1 100644 --- a/conn.go +++ b/conn.go @@ -166,9 +166,7 @@ func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { } func (c *Conn) close(cause error) error { - c.sending.Lock() c.mu.Lock() - defer c.sending.Unlock() defer c.mu.Unlock() if c.closed { @@ -249,6 +247,17 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err c.sending.Lock() defer c.sending.Unlock() + // double check the error isn't due to being closed while sending. + defer func() { + if err != nil { + c.mu.Lock() + if c.closed { + err = ErrClosed + } + c.mu.Unlock() + } + }() + // m.request.ID could be changed, so we store a copy to correctly // clean up pending var id ID diff --git a/conn_test.go b/conn_test.go index 56e0350..aa32fd9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -118,38 +118,77 @@ func TestConn_DisconnectNotify(t *testing.T) { } func TestConn_Close(t *testing.T) { - t.Run("waiting for response", func(t *testing.T) { - connA, connB := net.Pipe() - nodeA := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, - ) - defer nodeA.Close() - nodeB := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connB), - noopHandler{}, - ) - defer nodeB.Close() - - ready := make(chan struct{}) - done := make(chan struct{}) - go func() { - close(ready) - err := nodeB.Call(context.Background(), "m", nil, nil) - if err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + cases := []struct { + name string + run func(*testing.T, context.Context, *jsonrpc2.Conn) + }{{ + name: "during Call", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := conn.Call(ctx, "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := conn.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) } - close(done) - }() - // Wait for the request to be sent before we close the connection. - <-ready - if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - assertDisconnect(t, nodeB, connB) - <-done - }) + <-done + }, + }, { + name: "during Wait", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + call, err := conn.DispatchCall(ctx, "m", nil, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if err := call.Wait(ctx, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }, { + name: "during Dispatch", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if _, err := conn.DispatchCall(ctx, "m", nil, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + tc.run(t, ctx, nodeB) + + assertDisconnect(t, nodeB, connB) + }) + } } func testParams(t *testing.T, want *json.RawMessage, fn func(c *jsonrpc2.Conn) error) {