diff --git a/handler_with_error.go b/handler_with_error.go index 11cd945..81c5546 100644 --- a/handler_with_error.go +++ b/handler_with_error.go @@ -3,7 +3,6 @@ package jsonrpc2 import ( "context" "log" - "reflect" ) // HandlerWithError implements Handler by calling the func for each @@ -22,9 +21,6 @@ func (h HandlerWithError) Handle(ctx context.Context, conn *Conn, req *Request) resp := &Response{ID: req.ID} if err == nil { - if isNilValue(result) { - result = struct{}{} - } err = resp.SetResult(result) } if err != nil { @@ -41,17 +37,3 @@ func (h HandlerWithError) Handle(ctx context.Context, conn *Conn, req *Request) } } } - -// isNilValue tests if an interface is empty, because an empty interface does -// not encode any information, we can't encode it in JSON so that the proxy -// knows it's a response, not a request. -func isNilValue(resp interface{}) bool { - if resp == nil { - return true - } - kind := reflect.TypeOf(resp).Kind() - value := reflect.ValueOf(resp) - nilPtr := kind == reflect.Ptr && value.IsNil() - nilSlice := kind == reflect.Slice && value.IsNil() - return nilPtr || nilSlice -} diff --git a/jsonrpc2.go b/jsonrpc2.go index 3c3fa76..7b921f3 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -3,6 +3,7 @@ package jsonrpc2 import ( + "bytes" "context" "encoding/json" "errors" @@ -131,6 +132,9 @@ func (r *Response) MarshalJSON() ([]byte, error) { if r == nil { return nil, errors.New("can't marshal nil *jsonrpc2.Response") } + if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { + return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") + } b, err := json.Marshal(*r) if err != nil { return nil, err @@ -373,7 +377,10 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface if err != nil { return err } - if result != nil && call.response.Result != nil { + if result != nil { + if call.response.Result == nil { + call.response.Result = &jsonNull + } // TODO(sqs): error handling if err := json.Unmarshal(*call.response.Result, result); err != nil { return err @@ -386,6 +393,8 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface } } +var jsonNull = json.RawMessage("null") + // Notify is like Call, but it returns when the notification request // is sent (without waiting for a response, because JSON-RPC // notifications do not have responses). @@ -540,15 +549,16 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error { // The presence of these fields distinguishes between the 2 // message types. type msg struct { - Method *string `json:"method"` - Result interface{} `json:"result"` - Error interface{} `json:"error"` + ID interface{} `json:"id"` + Method *string `json:"method"` + Result anyValueWithExplicitNull `json:"result"` + Error interface{} `json:"error"` } var isRequest, isResponse bool checkType := func(m *msg) error { mIsRequest := m.Method != nil - mIsResponse := m.Result != nil || m.Error != nil + mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { return errors.New("jsonrpc2: unable to determine message type (request or response)") } @@ -590,7 +600,34 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error { case !isRequest && isResponse: v = &m.response } - return json.Unmarshal(data, v) + if err := json.Unmarshal(data, v); err != nil { + return err + } + if !isRequest && isResponse && m.response.Error == nil && m.response.Result == nil { + m.response.Result = &jsonNull + } + return nil +} + +// anyValueWithExplicitNull is used to distinguish {} from +// {"result":null} by anyMessage's JSON unmarshaler. +type anyValueWithExplicitNull struct { + null bool // JSON "null" + value interface{} +} + +func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if string(data) == "null" { + *v = anyValueWithExplicitNull{null: true} + return nil + } + *v = anyValueWithExplicitNull{} + return json.Unmarshal(data, &v.value) } var ( diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index ab36d8d..5aa826c 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -24,18 +24,19 @@ func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { if err != nil { t.Fatal(err) } - if want := `"jsonrpc":"2.0"`; !strings.Contains(string(b), want) { - t.Errorf("got %s, want it to include the string %s", b, want) + if want := `{"method":"","id":0,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) } } func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { - b, err := json.Marshal(&jsonrpc2.Response{}) + null := json.RawMessage("null") + b, err := json.Marshal(&jsonrpc2.Response{Result: &null}) if err != nil { t.Fatal(err) } - if want := `"jsonrpc":"2.0"`; !strings.Contains(string(b), want) { - t.Errorf("got %s, want it to include the string %s", b, want) + if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) } } diff --git a/object_test.go b/object_test.go index 593ceb2..906d99b 100644 --- a/object_test.go +++ b/object_test.go @@ -8,18 +8,24 @@ import ( func TestAnyMessage(t *testing.T) { tests := map[string]struct { - request, response bool + request, response, invalid bool }{ // Single messages - `{}`: {}, - `{"foo":"bar"}`: {}, + `{}`: {invalid: true}, + `{"foo":"bar"}`: {invalid: true}, `{"method":"m"}`: {request: true}, `{"result":123}`: {response: true}, + `{"result":null}`: {response: true}, `{"error":{"code":456,"message":"m"}}`: {response: true}, } for s, want := range tests { var m anyMessage - json.Unmarshal([]byte(s), &m) + if err := json.Unmarshal([]byte(s), &m); err != nil { + if !want.invalid { + t.Errorf("%s: error: %s", s, err) + } + continue + } if (m.request != nil) != want.request { t.Errorf("%s: got request %v, want %v", s, m.request != nil, want.request) } @@ -30,6 +36,7 @@ func TestAnyMessage(t *testing.T) { } func TestMessageCodec(t *testing.T) { + obj := json.RawMessage(`{"foo":"bar"}`) tests := []struct { v, vempty interface{} }{ @@ -38,8 +45,8 @@ func TestMessageCodec(t *testing.T) { vempty: &Request{ID: ID{Num: 123}}, }, { - v: &Response{ID: ID{Num: 123}}, - vempty: &Response{ID: ID{Num: 123}}, + v: &Response{ID: ID{Num: 123}, Result: &obj}, + vempty: &Response{ID: ID{Num: 123}, Result: &obj}, }, } for _, test := range tests {