diff --git a/jsonrpc2.go b/jsonrpc2.go index 7b921f3..bc4bf38 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -43,10 +43,7 @@ type Request struct { // MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" // property. -func (r *Request) MarshalJSON() ([]byte, error) { - if r == nil { - return nil, errors.New("can't marshal nil *jsonrpc2.Request") - } +func (r Request) MarshalJSON() ([]byte, error) { r2 := struct { Method string `json:"method"` Params *json.RawMessage `json:"params,omitempty"` @@ -73,11 +70,22 @@ func (r *Request) UnmarshalJSON(data []byte) error { Meta *json.RawMessage `json:"meta,omitempty"` ID *ID `json:"id"` } + + // Detect if the "params" field is JSON "null" or just not present + // by seeing if the field gets overwritten to nil. + r2.Params = &json.RawMessage{} + if err := json.Unmarshal(data, &r2); err != nil { return err } r.Method = r2.Method - r.Params = r2.Params + if r2.Params == nil { + r.Params = &jsonNull + } else if len(*r2.Params) == 0 { + r.Params = nil + } else { + r.Params = r2.Params + } r.Meta = r2.Meta if r2.ID == nil { r.ID = ID{} @@ -115,7 +123,7 @@ func (r *Request) SetMeta(v interface{}) error { // http://www.jsonrpc.org/specification#response_object. type Response struct { ID ID `json:"id"` - Result *json.RawMessage `json:"result,omitempty"` + Result *json.RawMessage `json:"result"` Error *Error `json:"error,omitempty"` // SPEC NOTE: The spec says "If there was an error in detecting @@ -128,14 +136,12 @@ type Response struct { // MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" // property. -func (r *Response) MarshalJSON() ([]byte, error) { - if r == nil { - return nil, errors.New("can't marshal nil *jsonrpc2.Response") - } +func (r Response) MarshalJSON() ([]byte, error) { if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") } - b, err := json.Marshal(*r) + type tmpType Response // avoid infinite MarshalJSON recursion + b, err := json.Marshal(tmpType(r)) if err != nil { return nil, err } @@ -143,6 +149,25 @@ func (r *Response) MarshalJSON() ([]byte, error) { return b, nil } +// UnmarshalJSON implements json.Unmarshaler. +func (r *Response) UnmarshalJSON(data []byte) error { + type tmpType Response + + // Detect if the "result" field is JSON "null" or just not present + // by seeing if the field gets overwritten to nil. + *r = Response{Result: &json.RawMessage{}} + + if err := json.Unmarshal(data, (*tmpType)(r)); err != nil { + return err + } + if r.Result == nil { // JSON "null" + r.Result = &jsonNull + } else if len(*r.Result) == 0 { + r.Result = nil + } + return nil +} + // SetResult sets r.Result to the JSON representation of v. If JSON // marshaling fails, it returns an error. func (r *Response) SetResult(v interface{}) error { @@ -531,7 +556,7 @@ type anyMessage struct { response *Response } -func (m *anyMessage) MarshalJSON() ([]byte, error) { +func (m anyMessage) MarshalJSON() ([]byte, error) { var v interface{} switch { case m.request != nil && m.response == nil: diff --git a/object_test.go b/object_test.go index 906d99b..3572422 100644 --- a/object_test.go +++ b/object_test.go @@ -1,6 +1,7 @@ package jsonrpc2 import ( + "bytes" "encoding/json" "reflect" "testing" @@ -35,32 +36,93 @@ func TestAnyMessage(t *testing.T) { } } -func TestMessageCodec(t *testing.T) { +func TestRequest_MarshalUnmarshalJSON(t *testing.T) { + null := json.RawMessage("null") obj := json.RawMessage(`{"foo":"bar"}`) tests := []struct { - v, vempty interface{} + data []byte + want Request }{ { - v: &Request{ID: ID{Num: 123}}, - vempty: &Request{ID: ID{Num: 123}}, + data: []byte(`{"method":"m","params":{"foo":"bar"},"id":123,"jsonrpc":"2.0"}`), + want: Request{ID: ID{Num: 123}, Method: "m", Params: &obj}, }, { - v: &Response{ID: ID{Num: 123}, Result: &obj}, - vempty: &Response{ID: ID{Num: 123}, Result: &obj}, + data: []byte(`{"method":"m","params":null,"id":123,"jsonrpc":"2.0"}`), + want: Request{ID: ID{Num: 123}, Method: "m", Params: &null}, + }, + { + data: []byte(`{"method":"m","id":123,"jsonrpc":"2.0"}`), + want: Request{ID: ID{Num: 123}, Method: "m", Params: nil}, }, } for _, test := range tests { - b, err := json.Marshal(test.v) + var got Request + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) if err != nil { - t.Fatal(err) + t.Error(err) + continue } - - if err := json.Unmarshal(b, test.vempty); err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(test.vempty, test.v) { - t.Errorf("got %+v, want %+v", test.vempty, test.v) + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) + } + } +} + +func TestResponse_MarshalUnmarshalJSON(t *testing.T) { + null := json.RawMessage("null") + obj := json.RawMessage(`{"foo":"bar"}`) + tests := []struct { + data []byte + want Response + error bool + }{ + { + data: []byte(`{"id":123,"result":{"foo":"bar"},"jsonrpc":"2.0"}`), + want: Response{ID: ID{Num: 123}, Result: &obj}, + }, + { + data: []byte(`{"id":123,"result":null,"jsonrpc":"2.0"}`), + want: Response{ID: ID{Num: 123}, Result: &null}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0"}`), + want: Response{ID: ID{Num: 123}, Result: nil}, + error: true, // either result or error field must be set + }, + } + for _, test := range tests { + var got Response + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) + if err != nil { + if test.error { + continue + } + t.Error(err) + continue + } + if test.error { + t.Errorf("%q: expected error", test.data) + continue + } + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) } } }