diff --git a/conn_opt.go b/conn_opt.go index 423cf80..b4db9ba 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -109,3 +109,11 @@ func SetLogger(logger Logger) ConnOpt { c.logger = logger } } + +// OmitNilParams instructs Conn.Call and Conn.Notify to omit the params member +// from the JSON-RPC request if invoked with their params argument set to nil. +func OmitNilParams() ConnOpt { + return func(c *Conn) { + c.omitNilParams = true + } +} diff --git a/conn_opt_test.go b/conn_opt_test.go index df53a1a..bc6db37 100644 --- a/conn_opt_test.go +++ b/conn_opt_test.go @@ -3,9 +3,12 @@ package jsonrpc2_test import ( "bufio" "context" + "encoding/json" + "fmt" "io" "log" "net" + "sync" "testing" "github.com/sourcegraph/jsonrpc2" @@ -51,3 +54,113 @@ func TestSetLogger(t *testing.T) { t.Fatalf("got %q, want %q", got, want) } } + +func TestOmitNilParams(t *testing.T) { + rawJSONMessage := func(v string) *json.RawMessage { + b := []byte(v) + return (*json.RawMessage)(&b) + } + + type testCase struct { + connOpt jsonrpc2.ConnOpt + sendParams interface{} + wantParams *json.RawMessage + } + + testCases := []testCase{ + { + sendParams: nil, + wantParams: rawJSONMessage("null"), + }, + { + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + { + connOpt: jsonrpc2.OmitNilParams(), + sendParams: nil, + wantParams: nil, + }, + { + connOpt: jsonrpc2.OmitNilParams(), + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + } + + assert := func(got *json.RawMessage, want *json.RawMessage) error { + // Assert pointers. + if got == nil || want == nil { + if got != want { + return fmt.Errorf("got %v, want %v", got, want) + } + return nil + } + { + // If pointers are not nil, then assert values. + got := string(*got) + want := string(*want) + if got != want { + return fmt.Errorf("got %q, want %q", got, want) + } + } + return nil + } + + newClientServer := func(handler jsonrpc2.Handler, connOpt jsonrpc2.ConnOpt) (client *jsonrpc2.Conn, server *jsonrpc2.Conn) { + ctx := context.Background() + connA, connB := net.Pipe() + client = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), + noopHandler{}, + connOpt, + ) + server = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + handler, + connOpt, + ) + return client, server + } + + for i, tc := range testCases { + + t.Run(fmt.Sprintf("test case %v", i), func(t *testing.T) { + t.Run("call", func(t *testing.T) { + handler := jsonrpc2.HandlerWithError(func(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) (result interface{}, err error) { + return nil, assert(r.Params, tc.wantParams) + }) + + client, server := newClientServer(handler, tc.connOpt) + defer client.Close() + defer server.Close() + + if err := client.Call(context.Background(), "f", tc.sendParams, nil); err != nil { + t.Fatal(err) + } + }) + t.Run("notify", func(t *testing.T) { + wg := &sync.WaitGroup{} + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + err := assert(req.Params, tc.wantParams) + if err != nil { + t.Error(err) + } + wg.Done() + }) + + client, server := newClientServer(handler, tc.connOpt) + defer client.Close() + defer server.Close() + + wg.Add(1) + if err := client.Notify(context.Background(), "f", tc.sendParams); err != nil { + t.Fatal(err) + } + wg.Wait() + }) + }) + } +} diff --git a/jsonrpc2.go b/jsonrpc2.go index 0bfcc71..ae2416d 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -159,8 +159,8 @@ func (r *Request) UnmarshalJSON(data []byte) error { return nil } -// SetParams sets r.Params to the JSON representation of v. If JSON -// marshaling fails, it returns an error. +// SetParams sets r.Params to the JSON representation of v. If JSON marshaling +// fails, it returns an error. Beware that the JSON encoding of nil is null. func (r *Request) SetParams(v interface{}) error { b, err := json.Marshal(v) if err != nil { @@ -367,7 +367,8 @@ type Conn struct { disconnect chan struct{} - logger Logger + logger Logger + omitNilParams bool // Set by ConnOpt funcs. onRecv []func(*Request, *Response) @@ -511,8 +512,12 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface // otherwise use Call. func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { req := &Request{Method: method} - if err := req.SetParams(params); err != nil { - return Waiter{}, err + if c.omitNilParams && params == nil { + req.Params = nil + } else { + if err := req.SetParams(params); err != nil { + return Waiter{}, err + } } for _, opt := range opts { if opt == nil { @@ -569,8 +574,12 @@ var jsonNull = json.RawMessage("null") // notifications do not have responses). func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { req := &Request{Method: method, Notif: true} - if err := req.SetParams(params); err != nil { - return err + if c.omitNilParams && params == nil { + req.Params = nil + } else { + if err := req.SetParams(params); err != nil { + return err + } } for _, opt := range opts { if opt == nil {