From c5d79b7c377225fb7876cdfed5959835fc793cf5 Mon Sep 17 00:00:00 2001 From: Sam Herrmann Date: Wed, 18 Jan 2023 16:09:49 -0500 Subject: [PATCH] Add ability to omit params member from request The JSON-RPC 2.0 specification allows the params member of a request to be omitted [1]. Before this commit, this library did not allow the params member to be omitted. When the params argument of the Conn.Call or Conn.Notify method was set to nil, then Request.Params was set to the JSON encoding of nil which is null. This commit adds a ConnOption named OmitNilParams. If OmitNilParams is applied on Conn and Conn.Call or Conn.Notify are invoked with their params argument set to nil, then the params member in the JSON encoding of Request is omitted. If the OmitNilParams option is not applied on Conn then the previous behavior is maintained. In other words, the changes in this commit are backwards compatible. References [1]: https://www.jsonrpc.org/specification#request_object --- conn_opt.go | 8 ++++ conn_opt_test.go | 113 +++++++++++++++++++++++++++++++++++++++++++++++ jsonrpc2.go | 23 +++++++--- 3 files changed, 137 insertions(+), 7 deletions(-) diff --git a/conn_opt.go b/conn_opt.go index 423cf80..b4db9ba 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -109,3 +109,11 @@ func SetLogger(logger Logger) ConnOpt { c.logger = logger } } + +// OmitNilParams instructs Conn.Call and Conn.Notify to omit the params member +// from the JSON-RPC request if invoked with their params argument set to nil. +func OmitNilParams() ConnOpt { + return func(c *Conn) { + c.omitNilParams = true + } +} diff --git a/conn_opt_test.go b/conn_opt_test.go index df53a1a..bc6db37 100644 --- a/conn_opt_test.go +++ b/conn_opt_test.go @@ -3,9 +3,12 @@ package jsonrpc2_test import ( "bufio" "context" + "encoding/json" + "fmt" "io" "log" "net" + "sync" "testing" "github.com/sourcegraph/jsonrpc2" @@ -51,3 +54,113 @@ func TestSetLogger(t *testing.T) { t.Fatalf("got %q, want %q", got, want) } } + +func TestOmitNilParams(t *testing.T) { + rawJSONMessage := func(v string) *json.RawMessage { + b := []byte(v) + return (*json.RawMessage)(&b) + } + + type testCase struct { + connOpt jsonrpc2.ConnOpt + sendParams interface{} + wantParams *json.RawMessage + } + + testCases := []testCase{ + { + sendParams: nil, + wantParams: rawJSONMessage("null"), + }, + { + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + { + connOpt: jsonrpc2.OmitNilParams(), + sendParams: nil, + wantParams: nil, + }, + { + connOpt: jsonrpc2.OmitNilParams(), + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + } + + assert := func(got *json.RawMessage, want *json.RawMessage) error { + // Assert pointers. + if got == nil || want == nil { + if got != want { + return fmt.Errorf("got %v, want %v", got, want) + } + return nil + } + { + // If pointers are not nil, then assert values. + got := string(*got) + want := string(*want) + if got != want { + return fmt.Errorf("got %q, want %q", got, want) + } + } + return nil + } + + newClientServer := func(handler jsonrpc2.Handler, connOpt jsonrpc2.ConnOpt) (client *jsonrpc2.Conn, server *jsonrpc2.Conn) { + ctx := context.Background() + connA, connB := net.Pipe() + client = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), + noopHandler{}, + connOpt, + ) + server = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + handler, + connOpt, + ) + return client, server + } + + for i, tc := range testCases { + + t.Run(fmt.Sprintf("test case %v", i), func(t *testing.T) { + t.Run("call", func(t *testing.T) { + handler := jsonrpc2.HandlerWithError(func(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) (result interface{}, err error) { + return nil, assert(r.Params, tc.wantParams) + }) + + client, server := newClientServer(handler, tc.connOpt) + defer client.Close() + defer server.Close() + + if err := client.Call(context.Background(), "f", tc.sendParams, nil); err != nil { + t.Fatal(err) + } + }) + t.Run("notify", func(t *testing.T) { + wg := &sync.WaitGroup{} + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + err := assert(req.Params, tc.wantParams) + if err != nil { + t.Error(err) + } + wg.Done() + }) + + client, server := newClientServer(handler, tc.connOpt) + defer client.Close() + defer server.Close() + + wg.Add(1) + if err := client.Notify(context.Background(), "f", tc.sendParams); err != nil { + t.Fatal(err) + } + wg.Wait() + }) + }) + } +} diff --git a/jsonrpc2.go b/jsonrpc2.go index 0bfcc71..ae2416d 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -159,8 +159,8 @@ func (r *Request) UnmarshalJSON(data []byte) error { return nil } -// SetParams sets r.Params to the JSON representation of v. If JSON -// marshaling fails, it returns an error. +// SetParams sets r.Params to the JSON representation of v. If JSON marshaling +// fails, it returns an error. Beware that the JSON encoding of nil is null. func (r *Request) SetParams(v interface{}) error { b, err := json.Marshal(v) if err != nil { @@ -367,7 +367,8 @@ type Conn struct { disconnect chan struct{} - logger Logger + logger Logger + omitNilParams bool // Set by ConnOpt funcs. onRecv []func(*Request, *Response) @@ -511,8 +512,12 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface // otherwise use Call. func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { req := &Request{Method: method} - if err := req.SetParams(params); err != nil { - return Waiter{}, err + if c.omitNilParams && params == nil { + req.Params = nil + } else { + if err := req.SetParams(params); err != nil { + return Waiter{}, err + } } for _, opt := range opts { if opt == nil { @@ -569,8 +574,12 @@ var jsonNull = json.RawMessage("null") // notifications do not have responses). func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { req := &Request{Method: method, Notif: true} - if err := req.SetParams(params); err != nil { - return err + if c.omitNilParams && params == nil { + req.Params = nil + } else { + if err := req.SetParams(params); err != nil { + return err + } } for _, opt := range opts { if opt == nil {