diff --git a/jsonrpc2.go b/jsonrpc2.go index 4885b05..b855390 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -366,11 +366,10 @@ type Conn struct { h Handler - mu sync.Mutex - shutdown bool - closing bool - seq uint64 - pending map[ID]*call + mu sync.Mutex + closed bool + seq uint64 + pending map[ID]*call sending sync.Mutex @@ -417,14 +416,35 @@ func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOp // Close closes the JSON-RPC connection. The connection may not be // used after it has been closed. func (c *Conn) Close() error { + return c.close(nil) +} + +func (c *Conn) close(cause error) error { + c.sending.Lock() c.mu.Lock() - if c.shutdown || c.closing { - c.mu.Unlock() + defer c.sending.Unlock() + defer c.mu.Unlock() + + if c.closed { return ErrClosed } - c.closing = true - c.mu.Unlock() - return c.stream.Close() + + for _, call := range c.pending { + call.done <- cause + close(call.done) + } + + if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF { + c.logger.Printf("jsonrpc2: protocol error: %v\n", cause) + } + + if err := c.stream.Close(); err != nil { + return err + } + + close(c.disconnect) + c.closed = true + return nil } func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { @@ -436,7 +456,7 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err var id ID c.mu.Lock() - if c.shutdown || c.closing { + if c.closed { c.mu.Unlock() return nil, ErrClosed } @@ -675,28 +695,7 @@ func (c *Conn) readMessages(ctx context.Context) { } } } - - c.sending.Lock() - c.mu.Lock() - c.shutdown = true - closing := c.closing - if err == io.EOF { - if closing { - err = ErrClosed - } else { - err = io.ErrUnexpectedEOF - } - } - for _, call := range c.pending { - call.done <- err - close(call.done) - } - c.mu.Unlock() - c.sending.Unlock() - if err != io.ErrUnexpectedEOF && !closing { - c.logger.Printf("jsonrpc2: protocol error: %v\n", err) - } - close(c.disconnect) + c.close(err) } // call represents a JSON-RPC call over its entire lifecycle. diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index ca600db..a8dec6e 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -390,6 +390,23 @@ func TestConn_Close_waitingForResponse(t *testing.T) { <-done } +func TestConn_DisconnectNotify_protocol_error(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(connB, jsonrpc2.VarintObjectCodec{}), nil) + connA.Write([]byte("invalid json")) + select { + case <-c.DisconnectNotify(): + case <-time.After(200 * time.Millisecond): + t.Fatal("no disconnect notification") + } + // Assert that the underlying connection is closed by trying to write to it. + _, got := connB.Write(nil) + want := io.ErrClosedPipe + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMaker streamMaker, opts ...jsonrpc2.ConnOpt) error { for { conn, err := lis.Accept()