From 4fb7cd90793ee6ab445f466b900e6bffb9b63d78 Mon Sep 17 00:00:00 2001 From: Keegan Carruthers-Smith Date: Tue, 31 Jan 2017 11:08:53 +0200 Subject: [PATCH] Add CallOpt SetID SetID allows a caller to control the ID of the request. Previously it was impossible to set the ID of a call. --- call_opt.go | 10 +++++++ jsonrpc2.go | 14 +++++++--- jsonrpc2_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/call_opt.go b/call_opt.go index f2585ec..b554bac 100644 --- a/call_opt.go +++ b/call_opt.go @@ -18,3 +18,13 @@ func Meta(meta interface{}) CallOption { return r.SetMeta(meta) }) } + +// PickID returns a call option which sets the ID on a request. Care must be +// taken to ensure there are no conflicts with any previously picked ID, nor +// with the default sequence ID. +func PickID(id ID) CallOption { + return callOptionFunc(func(r *Request) error { + r.ID = id + return nil + }) +} diff --git a/jsonrpc2.go b/jsonrpc2.go index bc4bf38..29ae130 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -332,6 +332,10 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (cc *call, er c.sending.Lock() defer c.sending.Unlock() + // m.request.ID could be changed, so we store a copy to correctly + // clean up pending + var id ID + c.mu.Lock() if c.shutdown || c.closing { c.mu.Unlock() @@ -342,8 +346,12 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (cc *call, er // responses. if m.request != nil && wait { cc = &call{request: m.request, seq: c.seq, done: make(chan error, 1)} - c.pending[ID{Num: c.seq}] = cc // use next seq as call ID - m.request.ID.Num = c.seq + if !m.request.ID.IsString && m.request.ID.Num == 0 { + // unset, use next seq as call ID + m.request.ID.Num = c.seq + } + id = m.request.ID + c.pending[id] = cc c.seq++ } c.mu.Unlock() @@ -364,7 +372,7 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (cc *call, er if err != nil { if cc != nil { c.mu.Lock() - delete(c.pending, ID{Num: cc.seq}) + delete(c.pending, id) c.mu.Unlock() } } diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index 5aa826c..6bcf2c2 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -208,6 +208,76 @@ func testClientServer(ctx context.Context, t *testing.T, stream jsonrpc2.ObjectS hb.mu.Unlock() } +func inMemoryPeerConns() (io.ReadWriteCloser, io.ReadWriteCloser) { + sr, cw := io.Pipe() + cr, sw := io.Pipe() + return &pipeReadWriteCloser{sr, sw}, &pipeReadWriteCloser{cr, cw} +} + +type pipeReadWriteCloser struct { + *io.PipeReader + *io.PipeWriter +} + +func (c *pipeReadWriteCloser) Close() error { + err1 := c.PipeReader.Close() + err2 := c.PipeWriter.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type handlerFunc func(context.Context, *jsonrpc2.Conn, *jsonrpc2.Request) + +func (h handlerFunc) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + h(ctx, conn, req) +} + +func TestPickID(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a, b := inMemoryPeerConns() + defer a.Close() + defer b.Close() + + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + if err := conn.Reply(ctx, req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { + t.Error(err) + } + }) + connA := jsonrpc2.NewConn(ctx, jsonrpc2.NewBufferedStream(a, jsonrpc2.VSCodeObjectCodec{}), handler) + connB := jsonrpc2.NewConn(ctx, jsonrpc2.NewBufferedStream(b, jsonrpc2.VSCodeObjectCodec{}), noopHandler{}) + defer connA.Close() + defer connB.Close() + + const n = 100 + for i := 0; i < n; i++ { + var opts []jsonrpc2.CallOption + id := jsonrpc2.ID{Num: uint64(i)} + + // This is the actual test, every 3rd request we specify the + // ID and ensure we get a response with the correct ID echoed + // back + if i%3 == 0 { + id = jsonrpc2.ID{ + Str: fmt.Sprintf("helloworld-%d", i/3), + IsString: true, + } + opts = append(opts, jsonrpc2.PickID(id)) + } + + var got string + if err := connB.Call(ctx, "f", []int32{1, 2, 3}, &got, opts...); err != nil { + t.Fatal(err) + } + if want := fmt.Sprintf("hello, #%s: [1,2,3]", id); got != want { + t.Errorf("got result %q, want %q", got, want) + } + } +} + type noopHandler struct{} func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {}