diff --git a/jsonrpc2.go b/jsonrpc2.go index 3c3fa76..9ec1297 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -3,6 +3,7 @@ package jsonrpc2 import ( + "bytes" "context" "encoding/json" "errors" @@ -540,15 +541,16 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error { // The presence of these fields distinguishes between the 2 // message types. type msg struct { - Method *string `json:"method"` - Result interface{} `json:"result"` - Error interface{} `json:"error"` + ID interface{} `json:"id"` + Method *string `json:"method"` + Result anyValueWithExplicitNull `json:"result"` + Error interface{} `json:"error"` } var isRequest, isResponse bool checkType := func(m *msg) error { mIsRequest := m.Method != nil - mIsResponse := m.Result != nil || m.Error != nil + mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { return errors.New("jsonrpc2: unable to determine message type (request or response)") } @@ -593,6 +595,27 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, v) } +// anyValueWithExplicitNull is used to distinguish {} from +// {"result":null} by anyMessage's JSON unmarshaler. +type anyValueWithExplicitNull struct { + null bool // JSON "null" + value interface{} +} + +func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if string(data) == "null" { + *v = anyValueWithExplicitNull{null: true} + return nil + } + *v = anyValueWithExplicitNull{} + return json.Unmarshal(data, &v.value) +} + var ( errInvalidRequestJSON = errors.New("jsonrpc2: request must be either a JSON object or JSON array") errInvalidResponseJSON = errors.New("jsonrpc2: response must be either a JSON object or JSON array") diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index ab36d8d..5aa826c 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -24,18 +24,19 @@ func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { if err != nil { t.Fatal(err) } - if want := `"jsonrpc":"2.0"`; !strings.Contains(string(b), want) { - t.Errorf("got %s, want it to include the string %s", b, want) + if want := `{"method":"","id":0,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) } } func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { - b, err := json.Marshal(&jsonrpc2.Response{}) + null := json.RawMessage("null") + b, err := json.Marshal(&jsonrpc2.Response{Result: &null}) if err != nil { t.Fatal(err) } - if want := `"jsonrpc":"2.0"`; !strings.Contains(string(b), want) { - t.Errorf("got %s, want it to include the string %s", b, want) + if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) } } diff --git a/object_test.go b/object_test.go index 593ceb2..8a90233 100644 --- a/object_test.go +++ b/object_test.go @@ -15,6 +15,7 @@ func TestAnyMessage(t *testing.T) { `{"foo":"bar"}`: {}, `{"method":"m"}`: {request: true}, `{"result":123}`: {response: true}, + `{"result":null}`: {response: true}, `{"error":{"code":456,"message":"m"}}`: {response: true}, } for s, want := range tests {