diff --git a/jsonrpc2.go b/jsonrpc2.go index 4885b05..32bc98c 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -366,11 +366,10 @@ type Conn struct { h Handler - mu sync.Mutex - shutdown bool - closing bool - seq uint64 - pending map[ID]*call + mu sync.Mutex + closed bool + seq uint64 + pending map[ID]*call sending sync.Mutex @@ -417,13 +416,29 @@ func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOp // Close closes the JSON-RPC connection. The connection may not be // used after it has been closed. func (c *Conn) Close() error { + return c.close(nil) +} + +func (c *Conn) close(cause error) error { + c.sending.Lock() c.mu.Lock() - if c.shutdown || c.closing { - c.mu.Unlock() + defer c.sending.Unlock() + defer c.mu.Unlock() + + if c.closed { return ErrClosed } - c.closing = true - c.mu.Unlock() + + for _, call := range c.pending { + close(call.done) + } + + if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF { + c.logger.Printf("jsonrpc2: protocol error: %v\n", cause) + } + + close(c.disconnect) + c.closed = true return c.stream.Close() } @@ -436,7 +451,7 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err var id ID c.mu.Lock() - if c.shutdown || c.closing { + if c.closed { c.mu.Unlock() return nil, ErrClosed } @@ -675,28 +690,7 @@ func (c *Conn) readMessages(ctx context.Context) { } } } - - c.sending.Lock() - c.mu.Lock() - c.shutdown = true - closing := c.closing - if err == io.EOF { - if closing { - err = ErrClosed - } else { - err = io.ErrUnexpectedEOF - } - } - for _, call := range c.pending { - call.done <- err - close(call.done) - } - c.mu.Unlock() - c.sending.Unlock() - if err != io.ErrUnexpectedEOF && !closing { - c.logger.Printf("jsonrpc2: protocol error: %v\n", err) - } - close(c.disconnect) + c.close(err) } // call represents a JSON-RPC call over its entire lifecycle. diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index ca600db..b68f3c1 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -314,80 +314,82 @@ type noopHandler struct{} func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {} -type readWriteCloser struct { - read, write func(p []byte) (n int, err error) -} +func TestConn_DisconnectNotify(t *testing.T) { -func (x readWriteCloser) Read(p []byte) (n int, err error) { - return x.read(p) -} - -func (x readWriteCloser) Write(p []byte) (n int, err error) { - return x.write(p) -} - -func (readWriteCloser) Close() error { return nil } - -func eof(p []byte) (n int, err error) { - return 0, io.EOF -} - -func TestConn_DisconnectNotify_EOF(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } -} - -func TestConn_DisconnectNotify_Close(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - if err := c.Close(); err != nil { - t.Error(err) - } - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } -} - -func TestConn_DisconnectNotify_Close_async(t *testing.T) { - done := make(chan struct{}) - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - go func() { - if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Run("EOF", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + // By closing connA, connB receives io.EOF + if err := connA.Close(); err != nil { t.Error(err) } - close(done) - }() - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } - <-done + assertDisconnect(t, c, connB) + }) + + t.Run("Close", func(t *testing.T) { + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + if err := c.Close(); err != nil { + t.Error(err) + } + assertDisconnect(t, c, connB) + }) + + t.Run("Close async", func(t *testing.T) { + done := make(chan struct{}) + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + go func() { + if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) + } + close(done) + }() + assertDisconnect(t, c, connB) + <-done + }) + + t.Run("protocol error", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + connA.Write([]byte("invalid json")) + assertDisconnect(t, c, connB) + }) } -func TestConn_Close_waitingForResponse(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), noopHandler{}) - done := make(chan struct{}) - go func() { - if err := c.Call(context.Background(), "m", nil, nil); err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) +func TestConn_Close(t *testing.T) { + t.Run("waiting for response", func(t *testing.T) { + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + context.Background(), + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + context.Background(), + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := nodeB.Call(context.Background(), "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) } - close(done) - }() - if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } - <-done + assertDisconnect(t, nodeB, connB) + <-done + }) } func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMaker streamMaker, opts ...jsonrpc2.ConnOpt) error { @@ -399,3 +401,17 @@ func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMake jsonrpc2.NewConn(ctx, streamMaker(conn), h, opts...) } } + +func assertDisconnect(t *testing.T, c *jsonrpc2.Conn, conn io.Writer) { + select { + case <-c.DisconnectNotify(): + case <-time.After(200 * time.Millisecond): + t.Fatal("no disconnect notification") + } + // Assert that conn is closed by trying to write to it. + _, got := conn.Write(nil) + want := io.ErrClosedPipe + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +}