From 6416f80f8f28c5210789f98c04f725e624cc7317 Mon Sep 17 00:00:00 2001 From: Quinn Slack Date: Sun, 6 Nov 2016 07:12:05 -0800 Subject: [PATCH] support string request IDs (in addition to numeric request IDs) --- conn_opt.go | 12 ++++---- jsonrpc2.go | 78 ++++++++++++++++++++++++++++++++++++++---------- jsonrpc2_test.go | 28 +++++++++-------- 3 files changed, 84 insertions(+), 34 deletions(-) diff --git a/conn_opt.go b/conn_opt.go index 3f7dea9..a2f184f 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -30,7 +30,7 @@ func LogMessages(log *log.Logger) ConnOpt { // request method in OnSend for responses. var ( mu sync.Mutex - reqMethods = map[uint64]string{} + reqMethods = map[ID]string{} ) OnRecv(func(req *Request, resp *Response) { @@ -44,7 +44,7 @@ func LogMessages(log *log.Logger) ConnOpt { if req.Notif { log.Printf("--> notif: %s: %s", req.Method, params) } else { - log.Printf("--> request #%d: %s: %s", req.ID, req.Method, params) + log.Printf("--> request #%s: %s: %s", req.ID, req.Method, params) } case resp != nil: @@ -57,7 +57,7 @@ func LogMessages(log *log.Logger) ConnOpt { switch { case resp.Result != nil: result, _ := json.Marshal(resp.Result) - log.Printf("--> result #%d: %s: %s", resp.ID, method, result) + log.Printf("--> result #%s: %s: %s", resp.ID, method, result) case resp.Error != nil: err, _ := json.Marshal(resp.Error) log.Printf("--> error #%d: %s: %s", resp.ID, method, err) @@ -71,7 +71,7 @@ func LogMessages(log *log.Logger) ConnOpt { if req.Notif { log.Printf("<-- notif: %s: %s", req.Method, params) } else { - log.Printf("<-- request #%d: %s: %s", req.ID, req.Method, params) + log.Printf("<-- request #%s: %s: %s", req.ID, req.Method, params) } case resp != nil: @@ -85,10 +85,10 @@ func LogMessages(log *log.Logger) ConnOpt { if resp.Result != nil { result, _ := json.Marshal(resp.Result) - log.Printf("<-- result #%d: %s: %s", resp.ID, method, result) + log.Printf("<-- result #%s: %s: %s", resp.ID, method, result) } else { err, _ := json.Marshal(resp.Error) - log.Printf("<-- error #%d: %s: %s", resp.ID, method, err) + log.Printf("<-- error #%s: %s: %s", resp.ID, method, err) } } })(c) diff --git a/jsonrpc2.go b/jsonrpc2.go index 7f3a37e..76e8d0e 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -38,7 +38,7 @@ type JSONRPC2 interface { type Request struct { Method string `json:"method"` Params *json.RawMessage `json:"params,omitempty"` - ID uint64 `json:"id"` + ID ID `json:"id"` Meta *json.RawMessage `json:"meta,omitempty"` Notif bool `json:"-"` } @@ -52,7 +52,7 @@ func (r *Request) MarshalJSON() ([]byte, error) { r2 := struct { Method string `json:"method"` Params *json.RawMessage `json:"params,omitempty"` - ID *uint64 `json:"id,omitempty"` + ID *ID `json:"id,omitempty"` Meta *json.RawMessage `json:"meta,omitempty"` JSONRPC string `json:"jsonrpc"` }{ @@ -73,7 +73,7 @@ func (r *Request) UnmarshalJSON(data []byte) error { Method string `json:"method"` Params *json.RawMessage `json:"params,omitempty"` Meta *json.RawMessage `json:"meta,omitempty"` - ID *uint64 `json:"id"` + ID *ID `json:"id"` } if err := json.Unmarshal(data, &r2); err != nil { return err @@ -82,7 +82,7 @@ func (r *Request) UnmarshalJSON(data []byte) error { r.Params = r2.Params r.Meta = r2.Meta if r2.ID == nil { - r.ID = 0 + r.ID = ID{} r.Notif = true } else { r.ID = *r2.ID @@ -116,7 +116,7 @@ func (r *Request) SetMeta(v interface{}) error { // Response represents a JSON-RPC response. See // http://www.jsonrpc.org/specification#response_object. type Response struct { - ID uint64 `json:"id"` + ID ID `json:"id"` Result *json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` @@ -193,6 +193,52 @@ type Handler interface { Handle(context.Context, *Conn, *Request) } +// ID represents a JSON-RPC 2.0 request ID, which may be either a +// string or number (or null, which is unsupported). +type ID struct { + // At most one of Num or Str may be nonzero. If both are zero + // valued, then IsNum specifies which field's value is to be used + // as the ID. + Num uint64 + Str string + + // IsString controls whether the Num or Str field's value should be + // used as the ID, when both are zero valued. It must always be + // set to true if the request ID is a string. + IsString bool +} + +func (id ID) String() string { + if id.IsString { + return strconv.Quote(id.Str) + } + return strconv.FormatUint(id.Num, 10) +} + +// MarshalJSON implements json.Marshaler. +func (id ID) MarshalJSON() ([]byte, error) { + if id.IsString { + return json.Marshal(id.Str) + } + return json.Marshal(id.Num) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (id *ID) UnmarshalJSON(data []byte) error { + // Support both uint64 and string IDs. + var v uint64 + if err := json.Unmarshal(data, &v); err == nil { + *id = ID{Num: v} + return nil + } + var v2 string + if err := json.Unmarshal(data, &v2); err != nil { + return err + } + *id = ID{Str: v2, IsString: true} + return nil +} + // Conn is a JSON-RPC client/server connection. The JSON-RPC protocol // is symmetric, so a Conn runs on both ends of a client-server // connection. @@ -206,7 +252,7 @@ type Conn struct { shutdown bool closing bool seq uint64 - pending map[uint64]*call + pending map[ID]*call sending sync.Mutex @@ -235,7 +281,7 @@ func NewConn(ctx context.Context, conn io.ReadWriteCloser, h Handler, opt ...Con conn: conn, w: bufio.NewWriter(conn), h: h, - pending: map[uint64]*call{}, + pending: map[ID]*call{}, disconnect: make(chan struct{}), } for _, opt := range opt { @@ -273,8 +319,8 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (*call, error var cc *call if m.request != nil && wait { cc = &call{request: m.request, seq: c.seq, done: make(chan error)} - c.pending[c.seq] = cc // use next seq as call ID - m.request.ID = c.seq + c.pending[ID{Num: c.seq}] = cc // use next seq as call ID + m.request.ID.Num = c.seq c.seq++ } c.mu.Unlock() @@ -293,7 +339,7 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (*call, error c.w.Flush() if cc != nil { c.mu.Lock() - delete(c.pending, cc.seq) + delete(c.pending, ID{Num: cc.seq}) c.mu.Unlock() } return nil, err @@ -358,7 +404,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}, op } // Reply sends a successful response with a result. -func (c *Conn) Reply(ctx context.Context, id uint64, result interface{}) error { +func (c *Conn) Reply(ctx context.Context, id ID, result interface{}) error { resp := &Response{ID: id} if err := resp.SetResult(result); err != nil { return err @@ -368,7 +414,7 @@ func (c *Conn) Reply(ctx context.Context, id uint64, result interface{}) error { } // ReplyWithError sends a response with an error. -func (c *Conn) ReplyWithError(ctx context.Context, id uint64, respErr *Error) error { +func (c *Conn) ReplyWithError(ctx context.Context, id ID, respErr *Error) error { _, err := c.send(ctx, &anyMessage{response: &Response{ID: id, Error: respErr}}, false) return err } @@ -409,10 +455,10 @@ func (c *Conn) readMessages(ctx context.Context, r *bufio.Reader) { case m.response != nil: resp := m.response if resp != nil { - seq := resp.ID + id := resp.ID c.mu.Lock() - call := c.pending[seq] - delete(c.pending, seq) + call := c.pending[id] + delete(c.pending, id) c.mu.Unlock() if call != nil { @@ -430,7 +476,7 @@ func (c *Conn) readMessages(ctx context.Context, r *bufio.Reader) { switch { case call == nil: - log.Printf("jsonrpc2: ignoring response %d with no corresponding request", seq) + log.Printf("jsonrpc2: ignoring response #%s with no corresponding request", id) case resp.Error != nil: call.done <- resp.Error diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index ae770c5..167c2e1 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -37,9 +37,11 @@ func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { func TestResponseMarshalJSON_Notif(t *testing.T) { tests := map[*Request]bool{ - &Request{ID: 0}: true, - &Request{ID: 1}: true, - &Request{Notif: true}: false, + &Request{ID: ID{Num: 0}}: true, + &Request{ID: ID{Num: 1}}: true, + &Request{ID: ID{Str: "", IsString: true}}: true, + &Request{ID: ID{Str: "a", IsString: true}}: true, + &Request{Notif: true}: false, } for r, wantIDKey := range tests { b, err := json.Marshal(r) @@ -55,9 +57,11 @@ func TestResponseMarshalJSON_Notif(t *testing.T) { func TestResponseUnmarshalJSON_Notif(t *testing.T) { tests := map[string]bool{ - `{"method":"f","id":0}`: false, - `{"method":"f","id":1}`: false, - `{"method":"f"}`: true, + `{"method":"f","id":0}`: false, + `{"method":"f","id":1}`: false, + `{"method":"f","id":"a"}`: false, + `{"method":"f","id":""}`: false, + `{"method":"f"}`: true, } for s, want := range tests { var r Request @@ -77,11 +81,11 @@ func (h *testHandlerA) Handle(ctx context.Context, conn *Conn, req *Request) { if req.Notif { return // notification } - if err := conn.Reply(ctx, req.ID, fmt.Sprintf("hello, #%d: %s", req.ID, *req.Params)); err != nil { + if err := conn.Reply(ctx, req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { h.t.Error(err) } - if err := conn.Notify(ctx, "m", fmt.Sprintf("notif for #%d", req.ID)); err != nil { + if err := conn.Notify(ctx, "m", fmt.Sprintf("notif for #%s", req.ID)); err != nil { h.t.Error(err) } } @@ -273,12 +277,12 @@ func TestMessageCodec(t *testing.T) { v, vempty interface{} }{ { - v: &Request{ID: 123}, - vempty: &Request{ID: 123}, + v: &Request{ID: ID{Num: 123}}, + vempty: &Request{ID: ID{Num: 123}}, }, { - v: &Response{ID: 123}, - vempty: &Response{ID: 123}, + v: &Response{ID: ID{Num: 123}}, + vempty: &Response{ID: ID{Num: 123}}, }, } for _, test := range tests {