diff --git a/call_opt_test.go b/call_opt_test.go index 37dc556..b64f661 100644 --- a/call_opt_test.go +++ b/call_opt_test.go @@ -2,7 +2,6 @@ package jsonrpc2_test import ( "context" - "encoding/json" "fmt" "testing" @@ -108,18 +107,22 @@ func TestExtraField(t *testing.T) { t.Error(err) } } - var sessionId *json.RawMessage + var sessionID string for _, field := range req.ExtraFields { if field.Name != "sessionId" { continue } - sessionId = field.Value + var ok bool + sessionID, ok = field.Value.(string) + if !ok { + t.Errorf("\"sessionId\" is not a string: %v", field.Value) + } } - if sessionId == nil { + if sessionID == "" { replyWithError("sessionId must be set") return } - if string(*sessionId) != `"session"` { + if sessionID != "session" { replyWithError("sessionId has the wrong value") return } diff --git a/jsonrpc2.go b/jsonrpc2.go index 0fdb675..7e3bace 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -33,7 +33,7 @@ type JSONRPC2 interface { // RequestField is a top-level field that can be added to the JSON-RPC request. type RequestField struct { Name string - Value *json.RawMessage + Value interface{} } // Request represents a JSON-RPC request or @@ -83,46 +83,64 @@ func (r Request) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (r *Request) UnmarshalJSON(data []byte) error { - var r2 struct { - Method string `json:"method"` - Params *json.RawMessage `json:"params,omitempty"` - Meta *json.RawMessage `json:"meta,omitempty"` - ID *ID `json:"id"` - } - // This is used to get the extra fields, which are not type-safe. - r3 := make(map[string]*json.RawMessage) + r2 := make(map[string]interface{}) // Detect if the "params" field is JSON "null" or just not present // by seeing if the field gets overwritten to nil. - r2.Params = &json.RawMessage{} + emptyParams := &json.RawMessage{} + r2["params"] = emptyParams - if err := json.Unmarshal(data, &r2); err != nil { + decoder := json.NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + if err := decoder.Decode(&r2); err != nil { return err } - if err := json.Unmarshal(data, &r3); err != nil { - return err + var ok bool + r.Method, ok = r2["method"].(string) + if !ok { + return errors.New("missing method field") } - r.Method = r2.Method switch { - case r2.Params == nil: + case r2["params"] == nil: r.Params = &jsonNull - case len(*r2.Params) == 0: + case r2["params"] == emptyParams: r.Params = nil default: - r.Params = r2.Params + b, err := json.Marshal(r2["params"]) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + r.Params = (*json.RawMessage)(&b) } - r.Meta = r2.Meta - if r2.ID == nil { + meta, ok := r2["meta"] + if ok { + b, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal Meta: %w", err) + } + r.Meta = (*json.RawMessage)(&b) + } + switch rawID := r2["id"].(type) { + case nil: r.ID = ID{} r.Notif = true - } else { - r.ID = *r2.ID + case string: + r.ID = ID{Str: rawID, IsString: true} r.Notif = false + case json.Number: + id, err := rawID.Int64() + if err != nil { + return fmt.Errorf("failed to unmarshal ID: %w", err) + } + r.ID = ID{Num: uint64(id)} + r.Notif = false + default: + return fmt.Errorf("unexpected ID type: %T", rawID) } // Clear the extra fields before populating them again. r.ExtraFields = nil - for name, value := range r3 { + for name, value := range r2 { switch name { case "id", "jsonrpc", "meta", "method", "params": continue @@ -161,13 +179,9 @@ func (r *Request) SetMeta(v interface{}) error { // JSON representation of the request, as a way to add arbitrary extensions to // JSON RPC 2.0. If JSON marshaling fails, it returns an error. func (r *Request) SetExtraField(name string, v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } r.ExtraFields = append(r.ExtraFields, RequestField{ Name: name, - Value: (*json.RawMessage)(&b), + Value: v, }) return nil } diff --git a/object_test.go b/object_test.go index 686f4c9..cfa5b00 100644 --- a/object_test.go +++ b/object_test.go @@ -39,7 +39,6 @@ func TestAnyMessage(t *testing.T) { func TestRequest_MarshalUnmarshalJSON(t *testing.T) { null := json.RawMessage("null") obj := json.RawMessage(`{"foo":"bar"}`) - requestFieldValue := json.RawMessage(`"session"`) tests := []struct { data []byte want Request @@ -58,7 +57,7 @@ func TestRequest_MarshalUnmarshalJSON(t *testing.T) { }, { data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","sessionId":"session"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: nil, ExtraFields: []RequestField{{Name: "sessionId", Value: &requestFieldValue}}}, + want: Request{ID: ID{Num: 123}, Method: "m", Params: nil, ExtraFields: []RequestField{{Name: "sessionId", Value: "session"}}}, }, } for _, test := range tests {