diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..2390d8c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - "*" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..29413b6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI +on: + pull_request: {} + push: + branches: + - master + +permissions: + contents: read + +jobs: + test: + strategy: + fail-fast: false + matrix: + go: + - 1.16 + name: Go ${{ matrix.go }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go }} + id: go + - name: Get dependencies + run: go get -t -v ./... + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@v0.2.2 + - name: Lint + run: staticcheck -checks=all ./... + - name: Test + run: go test -v -race ./... diff --git a/.github/workflows/lsif.yml b/.github/workflows/lsif.yml deleted file mode 100644 index 83d4bfd..0000000 --- a/.github/workflows/lsif.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: LSIF -on: - - push -jobs: - lsif-go: - runs-on: ubuntu-latest - container: sourcegraph/lsif-go - steps: - - uses: actions/checkout@v1 - - name: Generate LSIF data - run: lsif-go - - name: Upload LSIF data - run: src lsif upload -github-token=${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/scip.yml b/.github/workflows/scip.yml new file mode 100644 index 0000000..c0dc33b --- /dev/null +++ b/.github/workflows/scip.yml @@ -0,0 +1,20 @@ +name: SCIP +'on': + - push +permissions: + contents: read +jobs: + scip-go: + runs-on: ubuntu-latest + container: sourcegraph/scip-go + steps: + - uses: actions/checkout@v6 + - name: Get src-cli + run: curl -L https://sourcegraph.com/.api/src-cli/src_linux_amd64 -o /usr/local/bin/src; + chmod +x /usr/local/bin/src + - name: Set directory to safe for git + run: git config --global --add safe.directory $GITHUB_WORKSPACE + - name: Generate SCIP data + run: scip-go + - name: Upload SCIP data + run: src code-intel upload -github-token=${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index d2406ab..df4f081 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,8 @@ Package jsonrpc2 provides a [Go](https://golang.org) implementation of [JSON-RPC 2.0](http://www.jsonrpc.org/specification). -This package is **experimental** until further notice. - -[**Open the code in Sourcegraph**](https://sourcegraph.com/github.com/sourcegraph/jsonrpc2) +* [Documentation](https://pkg.go.dev/github.com/sourcegraph/jsonrpc2) +* [Open the code in Sourcegraph](https://sourcegraph.com/github.com/sourcegraph/jsonrpc2) ## Known issues diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..d889a1f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Supported Versions + +Security updates are applied only to the latest release. + +## Reporting a Vulnerability + +If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to evaluate and fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. + +Please disclose it privately via email to security@sourcegraph.com. We will work with you to understand and resolve the issue promptly. + diff --git a/call_opt.go b/call_opt.go index 73fe9c2..851baa5 100644 --- a/call_opt.go +++ b/call_opt.go @@ -19,6 +19,15 @@ func Meta(meta interface{}) CallOption { }) } +// ExtraField returns a call option which attaches the given name/value pair to +// the JSON-RPC 2.0 request. This can be used to add arbitrary extensions to +// JSON RPC 2.0. +func ExtraField(name string, value interface{}) CallOption { + return callOptionFunc(func(r *Request) error { + return r.SetExtraField(name, value) + }) +} + // PickID returns a call option which sets the ID on a request. Care must be // taken to ensure there are no conflicts with any previously picked ID, nor // with the default sequence ID. diff --git a/call_opt_test.go b/call_opt_test.go index 82b05ca..b64f661 100644 --- a/call_opt_test.go +++ b/call_opt_test.go @@ -90,3 +90,53 @@ func TestStringID(t *testing.T) { t.Fatal(err) } } + +func TestExtraField(t *testing.T) { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a, b := inMemoryPeerConns() + defer a.Close() + defer b.Close() + + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + replyWithError := func(msg string) { + respErr := &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidRequest, Message: msg} + if err := conn.ReplyWithError(ctx, req.ID, respErr); err != nil { + t.Error(err) + } + } + var sessionID string + for _, field := range req.ExtraFields { + if field.Name != "sessionId" { + continue + } + var ok bool + sessionID, ok = field.Value.(string) + if !ok { + t.Errorf("\"sessionId\" is not a string: %v", field.Value) + } + } + if sessionID == "" { + replyWithError("sessionId must be set") + return + } + if sessionID != "session" { + replyWithError("sessionId has the wrong value") + return + } + if err := conn.Reply(ctx, req.ID, "ok"); err != nil { + t.Error(err) + } + }) + connA := jsonrpc2.NewConn(ctx, jsonrpc2.NewBufferedStream(a, jsonrpc2.VSCodeObjectCodec{}), handler) + connB := jsonrpc2.NewConn(ctx, jsonrpc2.NewBufferedStream(b, jsonrpc2.VSCodeObjectCodec{}), noopHandler{}) + defer connA.Close() + defer connB.Close() + + var res string + if err := connB.Call(ctx, "f", nil, &res, jsonrpc2.ExtraField("sessionId", "session")); err != nil { + t.Fatal(err) + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..35c6279 --- /dev/null +++ b/conn.go @@ -0,0 +1,479 @@ +package jsonrpc2 + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "log" + "os" + "strconv" + "sync" +) + +// Conn is a JSON-RPC client/server connection. The JSON-RPC protocol +// is symmetric, so a Conn runs on both ends of a client-server +// connection. +type Conn struct { + stream ObjectStream + + h Handler + + mu sync.Mutex + closed bool + seq uint64 + pending map[ID]*call + + sending sync.Mutex + + cancelCtx context.CancelFunc + disconnect chan struct{} + + logger Logger + + // Set by ConnOpt funcs. + onRecv []func(*Request, *Response) + onSend []func(*Request, *Response) +} + +var _ JSONRPC2 = (*Conn)(nil) + +// NewConn creates a new JSON-RPC client/server connection using the +// given ReadWriteCloser (typically a TCP connection or stdio). The +// JSON-RPC protocol is symmetric, so a Conn runs on both ends of a +// client-server connection. +// +// NewConn consumes stream, so you should call Close on the returned +// Conn not on the given stream or its underlying connection. +// +// Conn is closed when the given context's Done channel is closed. +func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOpt) *Conn { + + ctx, cancel := context.WithCancel(ctx) + + c := &Conn{ + stream: stream, + h: h, + pending: map[ID]*call{}, + cancelCtx: cancel, + disconnect: make(chan struct{}), + logger: log.New(os.Stderr, "", log.LstdFlags), + } + for _, opt := range opts { + if opt == nil { + continue + } + opt(c) + } + go c.readMessages(ctx) + + go func() { + <-ctx.Done() + c.close(nil) + }() + + return c +} + +// Close closes the JSON-RPC connection. The connection may not be +// used after it has been closed. +func (c *Conn) Close() error { + return c.close(nil) +} + +// Call initiates a JSON-RPC call using the specified method and params, and +// waits for the response. If the response is successful, its result is stored +// in result (a pointer to a value that can be JSON-unmarshaled into); +// otherwise, a non-nil error is returned. See DispatchCall for more details. +func (c *Conn) Call(ctx context.Context, method string, params, result interface{}, opts ...CallOption) error { + call, err := c.DispatchCall(ctx, method, params, opts...) + if err != nil { + return err + } + return call.Wait(ctx, result) +} + +// DisconnectNotify returns a channel that is closed when the +// underlying connection is disconnected. +func (c *Conn) DisconnectNotify() <-chan struct{} { + return c.disconnect +} + +// DispatchCall dispatches a JSON-RPC call using the specified method and +// params, and returns a call proxy or an error. Call Wait() on the returned +// proxy to receive the response. Only use this function if you need to do work +// after dispatching the request, otherwise use Call. +// +// The params member is omitted from the JSON-RPC request if the given params is +// nil. Use json.RawMessage("null") to send a JSON-RPC request with its params +// member set to null. +func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { + req := &Request{Method: method} + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(req); err != nil { + return Waiter{}, err + } + } + if params != nil { + if err := req.SetParams(params); err != nil { + return Waiter{}, err + } + } + call, err := c.send(ctx, &anyMessage{request: req}, true) + if err != nil { + return Waiter{}, err + } + return Waiter{call: call}, nil +} + +// Notify is like Call, but it returns when the notification request is sent +// (without waiting for a response, because JSON-RPC notifications do not have +// responses). +// +// The params member is omitted from the JSON-RPC request if the given params is +// nil. Use json.RawMessage("null") to send a JSON-RPC request with its params +// member set to null. +func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { + req := &Request{Method: method, Notif: true} + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(req); err != nil { + return err + } + } + if params != nil { + if err := req.SetParams(params); err != nil { + return err + } + } + _, err := c.send(ctx, &anyMessage{request: req}, false) + return err +} + +// Reply sends a successful response with a result. +func (c *Conn) Reply(ctx context.Context, id ID, result interface{}) error { + resp := &Response{ID: id} + if err := resp.SetResult(result); err != nil { + return err + } + _, err := c.send(ctx, &anyMessage{response: resp}, false) + return err +} + +// ReplyWithError sends a response with an error. +func (c *Conn) ReplyWithError(ctx context.Context, id ID, respErr *Error) error { + _, err := c.send(ctx, &anyMessage{response: &Response{ID: id, Error: respErr}}, false) + return err +} + +// SendResponse sends resp to the peer. It is lower level than (*Conn).Reply. +func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { + _, err := c.send(ctx, &anyMessage{response: resp}, false) + return err +} + +func (c *Conn) close(cause error) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return ErrClosed + } + + for _, call := range c.pending { + close(call.done) + } + + if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF { + c.logger.Printf("jsonrpc2: protocol error: %v\n", cause) + } + + close(c.disconnect) + c.cancelCtx() + c.closed = true + return c.stream.Close() +} + +func (c *Conn) readMessages(ctx context.Context) { + for { + var m anyMessage + err := c.stream.ReadObject(&m) + if err != nil { + c.close(err) + return + } + + switch { + // TODO: handle the case where both request and response are nil. + + case m.request != nil: + for _, onRecv := range c.onRecv { + onRecv(m.request, nil) + } + c.h.Handle(ctx, c, m.request) + + case m.response != nil: + resp := m.response + id := resp.ID + c.mu.Lock() + call := c.pending[id] + delete(c.pending, id) + c.mu.Unlock() + + var req *Request + if call != nil { + call.response = resp + req = call.request + } + + for _, onRecv := range c.onRecv { + onRecv(req, resp) + } + + if call == nil { + c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) + continue + } + + var err error + if resp.Error != nil { + err = resp.Error + } + + call.done <- err + close(call.done) + } + } +} + +func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { + c.sending.Lock() + defer c.sending.Unlock() + + // double check the error isn't due to being closed while sending. + defer func() { + if err != nil { + c.mu.Lock() + if c.closed { + err = ErrClosed + } + c.mu.Unlock() + } + }() + + // m.request.ID could be changed, so we store a copy to correctly + // clean up pending + var id ID + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, ErrClosed + } + + // Assign a default id if not set + if m.request != nil && wait { + cc = &call{request: m.request, seq: c.seq, done: make(chan error, 1)} + + isIDUnset := len(m.request.ID.Str) == 0 && m.request.ID.Num == 0 + if isIDUnset { + if m.request.ID.IsString { + m.request.ID.Str = strconv.FormatUint(c.seq, 10) + } else { + m.request.ID.Num = c.seq + } + } + c.seq++ + } + c.mu.Unlock() + + if len(c.onSend) > 0 { + var ( + req *Request + resp *Response + ) + switch { + case m.request != nil: + req = m.request + case m.response != nil: + resp = m.response + } + for _, onSend := range c.onSend { + onSend(req, resp) + } + } + + // Store requests so we can later associate them with incoming + // responses. + if m.request != nil && wait { + c.mu.Lock() + id = m.request.ID + c.pending[id] = cc + c.mu.Unlock() + } + + // From here on, if we fail to send this, then we need to remove + // this from the pending map so we don't block on it or pile up + // pending entries for unsent messages. + defer func() { + if err != nil { + if cc != nil { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + } + } + }() + + if err := c.stream.WriteObject(m); err != nil { + return nil, err + } + return cc, nil +} + +// Waiter proxies an ongoing JSON-RPC call. +type Waiter struct { + *call +} + +// Wait for the result of an ongoing JSON-RPC call. If the response +// is successful, its result is stored in result (a pointer to a +// value that can be JSON-unmarshaled into); otherwise, a non-nil +// error is returned. +func (w Waiter) Wait(ctx context.Context, result interface{}) error { + select { + case <-ctx.Done(): + return ctx.Err() + + case err, ok := <-w.call.done: + if !ok { + return ErrClosed + } + if err != nil || result == nil { + return err + } + if w.call.response.Result == nil { + w.call.response.Result = &jsonNull + } + return json.Unmarshal(*w.call.response.Result, result) + } +} + +// call represents a JSON-RPC call over its entire lifecycle. +type call struct { + request *Request + response *Response + seq uint64 // the seq of the request + done chan error +} + +// anyMessage represents either a JSON Request or Response. +type anyMessage struct { + request *Request + response *Response +} + +func (m anyMessage) MarshalJSON() ([]byte, error) { + var v interface{} + switch { + case m.request != nil && m.response == nil: + v = m.request + case m.request == nil && m.response != nil: + v = m.response + } + if v != nil { + return json.Marshal(v) + } + return nil, errors.New("jsonrpc2: message must have exactly one of the request or response fields set") +} + +func (m *anyMessage) UnmarshalJSON(data []byte) error { + // The presence of these fields distinguishes between the 2 + // message types. + type msg struct { + ID interface{} `json:"id"` + Method *string `json:"method"` + Result anyValueWithExplicitNull `json:"result"` + Error interface{} `json:"error"` + } + + var isRequest, isResponse bool + checkType := func(m *msg) error { + mIsRequest := m.Method != nil + mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil + if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { + return errors.New("jsonrpc2: unable to determine message type (request or response)") + } + if (mIsRequest && isResponse) || (mIsResponse && isRequest) { + return errors.New("jsonrpc2: batch message type mismatch (must be all requests or all responses)") + } + isRequest = mIsRequest + isResponse = mIsResponse + return nil + } + + if isArray := len(data) > 0 && data[0] == '['; isArray { + var msgs []msg + if err := json.Unmarshal(data, &msgs); err != nil { + return err + } + if len(msgs) == 0 { + return errors.New("jsonrpc2: invalid empty batch") + } + for i := range msgs { + if err := checkType(&msgs[i]); err != nil { + return err + } + } + } else { + var m msg + if err := json.Unmarshal(data, &m); err != nil { + return err + } + if err := checkType(&m); err != nil { + return err + } + } + + var v interface{} + switch { + case isRequest && !isResponse: + v = &m.request + case !isRequest && isResponse: + v = &m.response + } + if err := json.Unmarshal(data, v); err != nil { + return err + } + if !isRequest && isResponse && m.response.Error == nil && m.response.Result == nil { + m.response.Result = &jsonNull + } + return nil +} + +// anyValueWithExplicitNull is used to distinguish {} from +// {"result":null} by anyMessage's JSON unmarshaler. +type anyValueWithExplicitNull struct { + null bool // JSON "null" + value interface{} +} + +func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if string(data) == "null" { + *v = anyValueWithExplicitNull{null: true} + return nil + } + *v = anyValueWithExplicitNull{} + return json.Unmarshal(data, &v.value) +} diff --git a/conn_opt.go b/conn_opt.go index 3779d7f..8a29f80 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -43,6 +43,20 @@ func LogMessages(logger Logger) ConnOpt { OnRecv(func(req *Request, resp *Response) { switch { + case resp != nil: + method := "(no matching request)" + if req != nil { + method = req.Method + } + switch { + case resp.Result != nil: + result, _ := json.Marshal(resp.Result) + logger.Printf("jsonrpc2: --> result #%s: %s: %s\n", resp.ID, method, result) + case resp.Error != nil: + err, _ := json.Marshal(resp.Error) + logger.Printf("jsonrpc2: --> error #%s: %s: %s\n", resp.ID, method, err) + } + case req != nil: mu.Lock() reqMethods[req.ID] = req.Method @@ -54,34 +68,10 @@ func LogMessages(logger Logger) ConnOpt { } else { logger.Printf("jsonrpc2: --> request #%s: %s: %s\n", req.ID, req.Method, params) } - - case resp != nil: - var method string - if req != nil { - method = req.Method - } else { - method = "(no matching request)" - } - switch { - case resp.Result != nil: - result, _ := json.Marshal(resp.Result) - logger.Printf("jsonrpc2: --> result #%s: %s: %s\n", resp.ID, method, result) - case resp.Error != nil: - err, _ := json.Marshal(resp.Error) - logger.Printf("jsonrpc2: --> error #%s: %s: %s\n", resp.ID, method, err) - } } })(c) OnSend(func(req *Request, resp *Response) { switch { - case req != nil: - params, _ := json.Marshal(req.Params) - if req.Notif { - logger.Printf("jsonrpc2: <-- notif: %s: %s\n", req.Method, params) - } else { - logger.Printf("jsonrpc2: <-- request #%s: %s: %s\n", req.ID, req.Method, params) - } - case resp != nil: mu.Lock() method := reqMethods[resp.ID] @@ -98,7 +88,22 @@ func LogMessages(logger Logger) ConnOpt { err, _ := json.Marshal(resp.Error) logger.Printf("jsonrpc2: <-- error #%s: %s: %s\n", resp.ID, method, err) } + + case req != nil: + params, _ := json.Marshal(req.Params) + if req.Notif { + logger.Printf("jsonrpc2: <-- notif: %s: %s\n", req.Method, params) + } else { + logger.Printf("jsonrpc2: <-- request #%s: %s: %s\n", req.ID, req.Method, params) + } } })(c) } } + +// SetLogger sets the logger for the connection. +func SetLogger(logger Logger) ConnOpt { + return func(c *Conn) { + c.logger = logger + } +} diff --git a/conn_opt_test.go b/conn_opt_test.go new file mode 100644 index 0000000..97f59e4 --- /dev/null +++ b/conn_opt_test.go @@ -0,0 +1,130 @@ +package jsonrpc2_test + +import ( + "bufio" + "context" + "io" + "log" + "net" + "testing" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestSetLogger(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rd, wr := io.Pipe() + defer rd.Close() + defer wr.Close() + + buf := bufio.NewReader(rd) + logger := log.New(wr, "", log.Lmsgprefix) + + a, b := net.Pipe() + connA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(a, jsonrpc2.VSCodeObjectCodec{}), + noopHandler{}, + jsonrpc2.SetLogger(logger), + ) + connB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(b, jsonrpc2.VSCodeObjectCodec{}), + noopHandler{}, + ) + defer connA.Close() + defer connB.Close() + + // Write a response with no corresponding request. + if err := connB.Reply(ctx, jsonrpc2.ID{Num: 0}, nil); err != nil { + t.Fatal(err) + } + + want := "jsonrpc2: ignoring response #0 with no corresponding request\n" + got, err := buf.ReadString('\n') + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +type dummyHandler struct { + t *testing.T +} + +func (h *dummyHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + if !req.Notif { + err := conn.Reply(ctx, req.ID, nil) + if err != nil { + h.t.Error(err) + return + } + } +} + +func TestLogMessages(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rd, wr := io.Pipe() + defer rd.Close() + defer wr.Close() + + buf := bufio.NewReader(rd) + logger := log.New(wr, "", log.Lmsgprefix) + + a, b := net.Pipe() + connA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(a, jsonrpc2.VSCodeObjectCodec{}), + &dummyHandler{t}, + jsonrpc2.LogMessages(logger), + ) + connB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(b, jsonrpc2.VSCodeObjectCodec{}), + &dummyHandler{t}, + ) + defer connA.Close() + defer connB.Close() + + go func() { + if err := connA.Call(ctx, "method1", nil, nil); err != nil { + t.Error(err) + return + } + if err := connB.Call(ctx, "method2", nil, nil); err != nil { + t.Error(err) + return + } + if err := connA.Notify(ctx, "notification1", nil); err != nil { + t.Error(err) + return + } + if err := connB.Notify(ctx, "notification2", nil); err != nil { + t.Error(err) + return + } + }() + + for i, want := range []string{ + "jsonrpc2: <-- request #0: method1: null\n", + "jsonrpc2: --> result #0: method1: null\n", + "jsonrpc2: --> request #0: method2: null\n", + "jsonrpc2: <-- result #0: method2: null\n", + "jsonrpc2: <-- notif: notification1: null\n", + "jsonrpc2: --> notif: notification2: null\n", + } { + got, err := buf.ReadString('\n') + if err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("message %v: got %q, want %q", i, got, want) + } + } +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..5d2a7e4 --- /dev/null +++ b/conn_test.go @@ -0,0 +1,304 @@ +package jsonrpc2_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net" + "sync" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestConn(t *testing.T) { + + t.Run("closes when context is done", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + connA, connB := Pipe(ctx, noopHandler{}, noopHandler{}) + defer connA.Close() + defer connB.Close() + + cancel() + <-connA.DisconnectNotify() + + got := connA.Close() + want := jsonrpc2.ErrClosed + if got != want { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("cancels context when closed", func(t *testing.T) { + ctxCanceled := make(chan struct{}) + + handler := handlerFunc(func(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) { + // Block until the context is canceled. + <-ctx.Done() + close(ctxCanceled) + }) + + connA, connB := Pipe(context.Background(), noopHandler{}, jsonrpc2.AsyncHandler(handler)) + defer connA.Close() + defer connB.Close() + + // Send a notification from connA to connB to trigger connB's handler + // function. + if err := connA.Notify(context.Background(), "foo", nil, nil); err != nil { + t.Fatal(err) + } + + // Disconnect connA from connB. + if err := connA.Close(); err != nil { + t.Fatal(err) + } + + select { + case <-ctxCanceled: + // Test passed, the handler's context was canceled. + case <-time.After(time.Second): + t.Fatal("context not canceled") + } + }) +} + +var paramsTests = []struct { + sendParams interface{} + wantParams *json.RawMessage +}{ + { + sendParams: nil, + wantParams: nil, + }, + { + sendParams: jsonNull, + wantParams: &jsonNull, + }, + { + sendParams: false, + wantParams: rawJSONMessage("false"), + }, + { + sendParams: 0, + wantParams: rawJSONMessage("0"), + }, + { + sendParams: "", + wantParams: rawJSONMessage(`""`), + }, + { + sendParams: rawJSONMessage(`{"foo":"bar"}`), + wantParams: rawJSONMessage(`{"foo":"bar"}`), + }, +} + +func TestConn_DispatchCall(t *testing.T) { + for _, test := range paramsTests { + t.Run(fmt.Sprintf("%s", test.sendParams), func(t *testing.T) { + testParams(t, test.wantParams, func(c *jsonrpc2.Conn) error { + _, err := c.DispatchCall(context.Background(), "f", test.sendParams) + return err + }) + }) + } +} + +func TestConn_Notify(t *testing.T) { + for _, test := range paramsTests { + t.Run(fmt.Sprintf("%s", test.sendParams), func(t *testing.T) { + testParams(t, test.wantParams, func(c *jsonrpc2.Conn) error { + return c.Notify(context.Background(), "f", test.sendParams) + }) + }) + } +} + +func TestConn_DisconnectNotify(t *testing.T) { + + t.Run("EOF", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + // By closing connA, connB receives io.EOF + if err := connA.Close(); err != nil { + t.Error(err) + } + assertDisconnect(t, c, connB) + }) + + t.Run("Close", func(t *testing.T) { + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + if err := c.Close(); err != nil { + t.Error(err) + } + assertDisconnect(t, c, connB) + }) + + t.Run("Close async", func(t *testing.T) { + done := make(chan struct{}) + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + go func() { + if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) + } + close(done) + }() + assertDisconnect(t, c, connB) + <-done + }) + + t.Run("protocol error", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn( + context.Background(), + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + // Suppress log message. This connection receives an invalid JSON + // message that causes an error to be written to the logger. We + // don't want this expected error to appear in os.Stderr though when + // running tests in verbose mode or when other tests fail. + jsonrpc2.SetLogger(log.New(io.Discard, "", 0)), + ) + connA.Write([]byte("invalid json")) + assertDisconnect(t, c, connB) + }) +} + +func TestConn_Close(t *testing.T) { + cases := []struct { + name string + run func(*testing.T, context.Context, *jsonrpc2.Conn) + }{{ + name: "during Call", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := conn.Call(ctx, "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := conn.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) + } + <-done + }, + }, { + name: "during Wait", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + call, err := conn.DispatchCall(ctx, "m", nil, nil) + if err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if err := call.Wait(ctx, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }, { + name: "during Dispatch", + run: func(t *testing.T, ctx context.Context, conn *jsonrpc2.Conn) { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if _, err := conn.DispatchCall(ctx, "m", nil, nil); err != jsonrpc2.ErrClosed { + t.Fatal(err) + } + }, + }} + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + tc.run(t, ctx, nodeB) + + assertDisconnect(t, nodeB, connB) + }) + } +} + +func testParams(t *testing.T, want *json.RawMessage, fn func(c *jsonrpc2.Conn) error) { + wg := &sync.WaitGroup{} + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, r *jsonrpc2.Request) { + assertRawJSONMessage(t, r.Params, want) + wg.Done() + }) + + connA, connB := Pipe(context.Background(), noopHandler{}, handler) + defer connA.Close() + defer connB.Close() + + wg.Add(1) + if err := fn(connA); err != nil { + t.Error(err) + } + wg.Wait() +} + +func assertDisconnect(t *testing.T, c *jsonrpc2.Conn, conn io.Writer) { + select { + case <-c.DisconnectNotify(): + case <-time.After(200 * time.Millisecond): + t.Error("no disconnect notification") + return + } + // Assert that conn is closed by trying to write to it. + _, got := conn.Write(nil) + want := io.ErrClosedPipe + if got != want { + t.Errorf("got %s, want %s", got, want) + } +} + +func assertRawJSONMessage(t *testing.T, got *json.RawMessage, want *json.RawMessage) { + // Assert pointers. + if got == nil || want == nil { + if got != want { + t.Errorf("pointer: got %s, want %s", got, want) + } + return + } + { + // If pointers are not nil, then assert values. + got := string(*got) + want := string(*want) + if got != want { + t.Errorf("value: got %q, want %q", got, want) + } + } +} + +// Pipe returns two jsonrpc2.Conn, connected via a synchronous, in-memory, full +// duplex network connection. +func Pipe(ctx context.Context, handlerA, handlerB jsonrpc2.Handler) (connA *jsonrpc2.Conn, connB *jsonrpc2.Conn) { + a, b := net.Pipe() + connA = jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(a), handlerA) + connB = jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(b), handlerB) + return connA, connB +} diff --git a/example_params_test.go b/example_params_test.go new file mode 100644 index 0000000..9b2b75a --- /dev/null +++ b/example_params_test.go @@ -0,0 +1,78 @@ +package jsonrpc2_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os" + + "github.com/sourcegraph/jsonrpc2" +) + +// Send a JSON-RPC notification with its params member omitted. +func ExampleConn_Notify_paramsOmitted() { + ctx := context.Background() + + connA, connB := net.Pipe() + defer connA.Close() + defer connB.Close() + + rpcConn := jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(connA), nil) + + // Send the JSON-RPC notification. + go func() { + // Set params to nil. + if err := rpcConn.Notify(ctx, "foo", nil); err != nil { + fmt.Fprintln(os.Stderr, "notify:", err) + } + }() + + // Read the raw JSON-RPC notification on connB. + // + // Reading the raw JSON-RPC request is for the purpose of this example only. + // Use a jsonrpc2.Handler to read parsed requests. + buf := make([]byte, 64) + n, err := connB.Read(buf) + if err != nil { + fmt.Fprintln(os.Stderr, "read:", err) + } + + fmt.Printf("%s\n", buf[:n]) + + // Output: {"jsonrpc":"2.0","method":"foo"} +} + +// Send a JSON-RPC notification with its params member set to null. +func ExampleConn_Notify_nullParams() { + ctx := context.Background() + + connA, connB := net.Pipe() + defer connA.Close() + defer connB.Close() + + rpcConn := jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(connA), nil) + + // Send the JSON-RPC notification. + go func() { + // Set params to the JSON null value. + params := json.RawMessage("null") + if err := rpcConn.Notify(ctx, "foo", params); err != nil { + fmt.Fprintln(os.Stderr, "notify:", err) + } + }() + + // Read the raw JSON-RPC notification on connB. + // + // Reading the raw JSON-RPC request is for the purpose of this example only. + // Use a jsonrpc2.Handler to read parsed requests. + buf := make([]byte, 64) + n, err := connB.Read(buf) + if err != nil { + fmt.Fprintln(os.Stderr, "read:", err) + } + + fmt.Printf("%s\n", buf[:n]) + + // Output: {"jsonrpc":"2.0","method":"foo","params":null} +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..00f4a57 --- /dev/null +++ b/example_test.go @@ -0,0 +1,64 @@ +package jsonrpc2_test + +import ( + "context" + "fmt" + "log" + "net" + "os" + + "github.com/sourcegraph/jsonrpc2" +) + +func Example() { + ctx := context.Background() + + // Create an in-memory network connection. This connection is used below to + // transport the JSON-RPC messages. However, any io.ReadWriteCloser may be + // used to send/receive JSON-RPC messages. + connA, connB := net.Pipe() + + // The following JSON-RPC connection is both a client and a server. It can + // send requests as well as receive requests. The incoming requests are + // handled by myHandler. + jsonrpcConnA := jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(connA), &myHandler{}) + defer jsonrpcConnA.Close() + + // The following JSON-RPC connection has no handler, meaning that it is + // configured to only be a client. It can send requests and receive the + // responses to those requests, but it will ignore any incoming requests. + jsonrpcConnB := jsonrpc2.NewConn(ctx, jsonrpc2.NewPlainObjectStream(connB), nil) + defer jsonrpcConnB.Close() + + // Send a request from jsonrpcConnB to jsonrpcConnA. The result of a + // successful call is stored in the result variable. + var result string + if err := jsonrpcConnB.Call(ctx, "sayHello", nil, &result); err != nil { + fmt.Fprintln(os.Stderr, err) + return + } + + fmt.Println(result) + + // Output: hello world +} + +// myHandler is the jsonrpc2.Handler used by jsonrpcConnA. +type myHandler struct{} + +// Handle implements the jsonrpc2.Handler interface. +func (h *myHandler) Handle(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) { + switch r.Method { + case "sayHello": + if err := c.Reply(ctx, r.ID, "hello world"); err != nil { + log.Println(err) + return + } + default: + err := &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound, Message: "Method not found"} + if err := c.ReplyWithError(ctx, r.ID, err); err != nil { + log.Println(err) + return + } + } +} diff --git a/handler_with_error.go b/handler_with_error.go index 2bd5c1d..d727237 100644 --- a/handler_with_error.go +++ b/handler_with_error.go @@ -30,20 +30,16 @@ func (h *HandlerWithErrorConfigurer) Handle(ctx context.Context, conn *Conn, req if err == nil { err = resp.SetResult(result) } - if err != nil { - if e, ok := err.(*Error); ok { - resp.Error = e - } else { - resp.Error = &Error{Message: err.Error()} - } + + if e, ok := err.(*Error); ok { + resp.Error = e + } else if err != nil { + resp.Error = &Error{Message: err.Error()} } - if !req.Notif { - if err := conn.SendResponse(ctx, resp); err != nil { - if err != ErrClosed || !h.suppressErrClosed { - conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err) - } - } + err = conn.SendResponse(ctx, resp) + if err != nil && (err != ErrClosed || !h.suppressErrClosed) { + conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err) } } diff --git a/internal_test.go b/internal_test.go new file mode 100644 index 0000000..990fb24 --- /dev/null +++ b/internal_test.go @@ -0,0 +1,35 @@ +package jsonrpc2 + +import ( + "encoding/json" + "testing" +) + +func TestAnyMessage(t *testing.T) { + tests := map[string]struct { + request, response, invalid bool + }{ + // Single messages + `{}`: {invalid: true}, + `{"foo":"bar"}`: {invalid: true}, + `{"method":"m"}`: {request: true}, + `{"result":123}`: {response: true}, + `{"result":null}`: {response: true}, + `{"error":{"code":456,"message":"m"}}`: {response: true}, + } + for s, want := range tests { + var m anyMessage + if err := json.Unmarshal([]byte(s), &m); err != nil { + if !want.invalid { + t.Errorf("%s: error: %v", s, err) + } + continue + } + if (m.request != nil) != want.request { + t.Errorf("%s: got request %v, want %v", s, m.request != nil, want.request) + } + if (m.response != nil) != want.response { + t.Errorf("%s: got response %v, want %v", s, m.response != nil, want.response) + } + } +} diff --git a/jsonrpc2.go b/jsonrpc2.go index 005b65c..7d3e132 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -3,16 +3,11 @@ package jsonrpc2 import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" - "log" - "os" "strconv" - "sync" ) // JSONRPC2 describes an interface for issuing requests that speak the @@ -30,176 +25,14 @@ type JSONRPC2 interface { Close() error } -// Request represents a JSON-RPC request or -// notification. See -// http://www.jsonrpc.org/specification#request_object and -// http://www.jsonrpc.org/specification#notification. -type Request struct { - Method string `json:"method"` - Params *json.RawMessage `json:"params,omitempty"` - ID ID `json:"id"` - Notif bool `json:"-"` - - // Meta optionally provides metadata to include in the request. - // - // NOTE: It is not part of spec. However, it is useful for propogating - // tracing context, etc. - Meta *json.RawMessage `json:"meta,omitempty"` -} - -// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" -// property. -func (r Request) MarshalJSON() ([]byte, error) { - r2 := struct { - Method string `json:"method"` - Params *json.RawMessage `json:"params,omitempty"` - ID *ID `json:"id,omitempty"` - Meta *json.RawMessage `json:"meta,omitempty"` - JSONRPC string `json:"jsonrpc"` - }{ - Method: r.Method, - Params: r.Params, - Meta: r.Meta, - JSONRPC: "2.0", - } - if !r.Notif { - r2.ID = &r.ID - } - return json.Marshal(r2) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (r *Request) UnmarshalJSON(data []byte) error { - var r2 struct { - Method string `json:"method"` - Params *json.RawMessage `json:"params,omitempty"` - Meta *json.RawMessage `json:"meta,omitempty"` - ID *ID `json:"id"` - } - - // Detect if the "params" field is JSON "null" or just not present - // by seeing if the field gets overwritten to nil. - r2.Params = &json.RawMessage{} - - if err := json.Unmarshal(data, &r2); err != nil { - return err - } - r.Method = r2.Method - switch { - case r2.Params == nil: - r.Params = &jsonNull - case len(*r2.Params) == 0: - r.Params = nil - default: - r.Params = r2.Params - } - r.Meta = r2.Meta - if r2.ID == nil { - r.ID = ID{} - r.Notif = true - } else { - r.ID = *r2.ID - r.Notif = false - } - return nil -} - -// SetParams sets r.Params to the JSON representation of v. If JSON -// marshaling fails, it returns an error. -func (r *Request) SetParams(v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Params = (*json.RawMessage)(&b) - return nil -} - -// SetMeta sets r.Meta to the JSON representation of v. If JSON -// marshaling fails, it returns an error. -func (r *Request) SetMeta(v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Meta = (*json.RawMessage)(&b) - return nil -} - -// Response represents a JSON-RPC response. See -// http://www.jsonrpc.org/specification#response_object. -type Response struct { - ID ID `json:"id"` - Result *json.RawMessage `json:"result,omitempty"` - Error *Error `json:"error,omitempty"` - - // Meta optionally provides metadata to include in the response. - // - // NOTE: It is not part of spec. However, it is useful for propogating - // tracing context, etc. - Meta *json.RawMessage `json:"meta,omitempty"` - - // SPEC NOTE: The spec says "If there was an error in detecting - // the id in the Request object (e.g. Parse error/Invalid - // Request), it MUST be Null." If we made the ID field nullable, - // then we'd have to make it a pointer type. For simplicity, we're - // ignoring the case where there was an error in detecting the ID - // in the Request object. -} - -// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" -// property. -func (r Response) MarshalJSON() ([]byte, error) { - if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { - return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") - } - type tmpType Response // avoid infinite MarshalJSON recursion - b, err := json.Marshal(tmpType(r)) - if err != nil { - return nil, err - } - b = append(b[:len(b)-1], []byte(`,"jsonrpc":"2.0"}`)...) - return b, nil -} - -// UnmarshalJSON implements json.Unmarshaler. -func (r *Response) UnmarshalJSON(data []byte) error { - type tmpType Response - - // Detect if the "result" field is JSON "null" or just not present - // by seeing if the field gets overwritten to nil. - *r = Response{Result: &json.RawMessage{}} - - if err := json.Unmarshal(data, (*tmpType)(r)); err != nil { - return err - } - if r.Result == nil { // JSON "null" - r.Result = &jsonNull - } else if len(*r.Result) == 0 { - r.Result = nil - } - return nil -} - -// SetResult sets r.Result to the JSON representation of v. If JSON -// marshaling fails, it returns an error. -func (r *Response) SetResult(v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Result = (*json.RawMessage)(&b) - return nil -} - // Error represents a JSON-RPC response error. type Error struct { Code int64 `json:"code"` Message string `json:"message"` - Data *json.RawMessage `json:"data"` + Data *json.RawMessage `json:"data,omitempty"` } -// SetError sets e.Error to the JSON representation of v. If JSON +// SetError sets e.Data to the JSON encoding of v. If JSON // marshaling fails, it panics. func (e *Error) SetError(v interface{}) { b, err := json.Marshal(v) @@ -226,10 +59,10 @@ const ( // Handler handles JSON-RPC requests and notifications. type Handler interface { - // Handle is called to handle a request. No other requests are handled - // until it returns. If you do not require strict ordering behavior - // of received RPCs, it is suggested to wrap your handler in - // AsyncHandler. + // Handle is called to handle a request. No other requests are handled until + // it returns. If you do not require strict ordering behavior of received + // RPCs, it is suggested to wrap your handler in AsyncHandler. The context + // is automatically canceled when the connection closes. Handle(context.Context, *Conn, *Request) } @@ -279,455 +112,8 @@ func (id *ID) UnmarshalJSON(data []byte) error { return nil } -// Conn is a JSON-RPC client/server connection. The JSON-RPC protocol -// is symmetric, so a Conn runs on both ends of a client-server -// connection. -type Conn struct { - stream ObjectStream - - h Handler - - mu sync.Mutex - shutdown bool - closing bool - seq uint64 - pending map[ID]*call - - sending sync.Mutex - - disconnect chan struct{} - - logger Logger - - // Set by ConnOpt funcs. - onRecv []func(*Request, *Response) - onSend []func(*Request, *Response) -} - -var _ JSONRPC2 = (*Conn)(nil) - // ErrClosed indicates that the JSON-RPC connection is closed (or in // the process of closing). var ErrClosed = errors.New("jsonrpc2: connection is closed") -// NewConn creates a new JSON-RPC client/server connection using the -// given ReadWriteCloser (typically a TCP connection or stdio). The -// JSON-RPC protocol is symmetric, so a Conn runs on both ends of a -// client-server connection. -// -// NewClient consumes conn, so you should call Close on the returned -// client not on the given conn. -func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOpt) *Conn { - c := &Conn{ - stream: stream, - h: h, - pending: map[ID]*call{}, - disconnect: make(chan struct{}), - logger: log.New(os.Stderr, "", log.LstdFlags), - } - for _, opt := range opts { - if opt == nil { - continue - } - opt(c) - } - go c.readMessages(ctx) - return c -} - -// Close closes the JSON-RPC connection. The connection may not be -// used after it has been closed. -func (c *Conn) Close() error { - c.mu.Lock() - if c.shutdown || c.closing { - c.mu.Unlock() - return ErrClosed - } - c.closing = true - c.mu.Unlock() - return c.stream.Close() -} - -func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { - c.sending.Lock() - defer c.sending.Unlock() - - // m.request.ID could be changed, so we store a copy to correctly - // clean up pending - var id ID - - c.mu.Lock() - if c.shutdown || c.closing { - c.mu.Unlock() - return nil, ErrClosed - } - - // Store requests so we can later associate them with incoming - // responses. - if m.request != nil && wait { - cc = &call{request: m.request, seq: c.seq, done: make(chan error, 1)} - - isIDUnset := len(m.request.ID.Str) == 0 && m.request.ID.Num == 0 - if isIDUnset { - if m.request.ID.IsString { - m.request.ID.Str = strconv.FormatUint(c.seq, 10) - } else { - m.request.ID.Num = c.seq - } - } - id = m.request.ID - c.pending[id] = cc - c.seq++ - } - c.mu.Unlock() - - if len(c.onSend) > 0 { - var ( - req *Request - resp *Response - ) - switch { - case m.request != nil: - req = m.request - case m.response != nil: - resp = m.response - } - for _, onSend := range c.onSend { - onSend(req, resp) - } - } - - // From here on, if we fail to send this, then we need to remove - // this from the pending map so we don't block on it or pile up - // pending entries for unsent messages. - defer func() { - if err != nil { - if cc != nil { - c.mu.Lock() - delete(c.pending, id) - c.mu.Unlock() - } - } - }() - - if err := c.stream.WriteObject(m); err != nil { - return nil, err - } - return cc, nil -} - -// Call initiates a JSON-RPC call using the specified method and -// params, and waits for the response. If the response is successful, -// its result is stored in result (a pointer to a value that can be -// JSON-unmarshaled into); otherwise, a non-nil error is returned. -func (c *Conn) Call(ctx context.Context, method string, params, result interface{}, opts ...CallOption) error { - call, err := c.DispatchCall(ctx, method, params, opts...) - if err != nil { - return err - } - return call.Wait(ctx, result) -} - -// DispatchCall dispatches a JSON-RPC call using the specified method -// and params, and returns a call proxy or an error. Call Wait() -// on the returned proxy to receive the response. Only use this -// function if you need to do work after dispatching the request, -// otherwise use Call. -func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { - req := &Request{Method: method} - if err := req.SetParams(params); err != nil { - return Waiter{}, err - } - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.apply(req); err != nil { - return Waiter{}, err - } - } - call, err := c.send(ctx, &anyMessage{request: req}, true) - if err != nil { - return Waiter{}, err - } - return Waiter{call: call}, nil -} - -// Waiter proxies an ongoing JSON-RPC call. -type Waiter struct { - *call -} - -// Wait for the result of an ongoing JSON-RPC call. If the response -// is successful, its result is stored in result (a pointer to a -// value that can be JSON-unmarshaled into); otherwise, a non-nil -// error is returned. -func (w Waiter) Wait(ctx context.Context, result interface{}) error { - select { - case err, ok := <-w.call.done: - if !ok { - err = ErrClosed - } - if err != nil { - return err - } - if result != nil { - if w.call.response.Result == nil { - w.call.response.Result = &jsonNull - } - if err := json.Unmarshal(*w.call.response.Result, result); err != nil { - return err - } - } - return nil - - case <-ctx.Done(): - return ctx.Err() - } -} - var jsonNull = json.RawMessage("null") - -// Notify is like Call, but it returns when the notification request -// is sent (without waiting for a response, because JSON-RPC -// notifications do not have responses). -func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { - req := &Request{Method: method, Notif: true} - if err := req.SetParams(params); err != nil { - return err - } - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.apply(req); err != nil { - return err - } - } - _, err := c.send(ctx, &anyMessage{request: req}, false) - return err -} - -// Reply sends a successful response with a result. -func (c *Conn) Reply(ctx context.Context, id ID, result interface{}) error { - resp := &Response{ID: id} - if err := resp.SetResult(result); err != nil { - return err - } - _, err := c.send(ctx, &anyMessage{response: resp}, false) - return err -} - -// ReplyWithError sends a response with an error. -func (c *Conn) ReplyWithError(ctx context.Context, id ID, respErr *Error) error { - _, err := c.send(ctx, &anyMessage{response: &Response{ID: id, Error: respErr}}, false) - return err -} - -// SendResponse sends resp to the peer. It is lower level than (*Conn).Reply. -func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { - _, err := c.send(ctx, &anyMessage{response: resp}, false) - return err -} - -// DisconnectNotify returns a channel that is closed when the -// underlying connection is disconnected. -func (c *Conn) DisconnectNotify() <-chan struct{} { - return c.disconnect -} - -func (c *Conn) readMessages(ctx context.Context) { - var err error - for err == nil { - var m anyMessage - err = c.stream.ReadObject(&m) - if err != nil { - break - } - - switch { - case m.request != nil: - for _, onRecv := range c.onRecv { - onRecv(m.request, nil) - } - c.h.Handle(ctx, c, m.request) - - case m.response != nil: - resp := m.response - if resp != nil { - id := resp.ID - c.mu.Lock() - call := c.pending[id] - delete(c.pending, id) - c.mu.Unlock() - - if call != nil { - call.response = resp - } - - if len(c.onRecv) > 0 { - var req *Request - if call != nil { - req = call.request - } - for _, onRecv := range c.onRecv { - onRecv(req, resp) - } - } - - switch { - case call == nil: - c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) - - case resp.Error != nil: - call.done <- resp.Error - close(call.done) - - default: - call.done <- nil - close(call.done) - } - } - } - } - - c.sending.Lock() - c.mu.Lock() - c.shutdown = true - closing := c.closing - if err == io.EOF { - if closing { - err = ErrClosed - } else { - err = io.ErrUnexpectedEOF - } - } - for _, call := range c.pending { - call.done <- err - close(call.done) - } - c.mu.Unlock() - c.sending.Unlock() - if err != io.ErrUnexpectedEOF && !closing { - c.logger.Printf("jsonrpc2: protocol error: %v\n", err) - } - close(c.disconnect) -} - -// call represents a JSON-RPC call over its entire lifecycle. -type call struct { - request *Request - response *Response - seq uint64 // the seq of the request - done chan error -} - -// anyMessage represents either a JSON Request or Response. -type anyMessage struct { - request *Request - response *Response -} - -func (m anyMessage) MarshalJSON() ([]byte, error) { - var v interface{} - switch { - case m.request != nil && m.response == nil: - v = m.request - case m.request == nil && m.response != nil: - v = m.response - } - if v != nil { - return json.Marshal(v) - } - return nil, errors.New("jsonrpc2: message must have exactly one of the request or response fields set") -} - -func (m *anyMessage) UnmarshalJSON(data []byte) error { - // The presence of these fields distinguishes between the 2 - // message types. - type msg struct { - ID interface{} `json:"id"` - Method *string `json:"method"` - Result anyValueWithExplicitNull `json:"result"` - Error interface{} `json:"error"` - } - - var isRequest, isResponse bool - checkType := func(m *msg) error { - mIsRequest := m.Method != nil - mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil - if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { - return errors.New("jsonrpc2: unable to determine message type (request or response)") - } - if (mIsRequest && isResponse) || (mIsResponse && isRequest) { - return errors.New("jsonrpc2: batch message type mismatch (must be all requests or all responses)") - } - isRequest = mIsRequest - isResponse = mIsResponse - return nil - } - - if isArray := len(data) > 0 && data[0] == '['; isArray { - var msgs []msg - if err := json.Unmarshal(data, &msgs); err != nil { - return err - } - if len(msgs) == 0 { - return errors.New("jsonrpc2: invalid empty batch") - } - for i := range msgs { - if err := checkType(&msg{ - ID: msgs[i].ID, - Method: msgs[i].Method, - Result: msgs[i].Result, - Error: msgs[i].Error, - }); err != nil { - return err - } - } - } else { - var m msg - if err := json.Unmarshal(data, &m); err != nil { - return err - } - if err := checkType(&m); err != nil { - return err - } - } - - var v interface{} - switch { - case isRequest && !isResponse: - v = &m.request - case !isRequest && isResponse: - v = &m.response - } - if err := json.Unmarshal(data, v); err != nil { - return err - } - if !isRequest && isResponse && m.response.Error == nil && m.response.Result == nil { - m.response.Result = &jsonNull - } - return nil -} - -// anyValueWithExplicitNull is used to distinguish {} from -// {"result":null} by anyMessage's JSON unmarshaler. -type anyValueWithExplicitNull struct { - null bool // JSON "null" - value interface{} -} - -func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { - return json.Marshal(v.value) -} - -func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { - data = bytes.TrimSpace(data) - if string(data) == "null" { - *v = anyValueWithExplicitNull{null: true} - return nil - } - *v = anyValueWithExplicitNull{} - return json.Unmarshal(data, &v.value) -} diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index c319a1b..8d7968f 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -1,7 +1,6 @@ package jsonrpc2_test import ( - "bytes" "context" "encoding/json" "fmt" @@ -19,62 +18,61 @@ import ( websocketjsonrpc2 "github.com/sourcegraph/jsonrpc2/websocket" ) -func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { - b, err := json.Marshal(&jsonrpc2.Request{}) - if err != nil { - t.Fatal(err) +func TestError_MarshalJSON(t *testing.T) { + tests := []struct { + name string + setError func(err *jsonrpc2.Error) + want string + }{ + { + name: "Data == nil", + want: `{"code":-32603,"message":"Internal error"}`, + }, + { + name: "Error.SetError(nil)", + setError: func(err *jsonrpc2.Error) { + err.SetError(nil) + }, + want: `{"code":-32603,"message":"Internal error","data":null}`, + }, + { + name: "Error.SetError(0)", + setError: func(err *jsonrpc2.Error) { + err.SetError(0) + }, + want: `{"code":-32603,"message":"Internal error","data":0}`, + }, + { + name: `Error.SetError("")`, + setError: func(err *jsonrpc2.Error) { + err.SetError("") + }, + want: `{"code":-32603,"message":"Internal error","data":""}`, + }, + { + name: `Error.SetError(false)`, + setError: func(err *jsonrpc2.Error) { + err.SetError(false) + }, + want: `{"code":-32603,"message":"Internal error","data":false}`, + }, } - if want := `{"method":"","id":0,"jsonrpc":"2.0"}`; string(b) != want { - t.Errorf("got %q, want %q", b, want) - } -} -func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { - null := json.RawMessage("null") - b, err := json.Marshal(&jsonrpc2.Response{Result: &null}) - if err != nil { - t.Fatal(err) - } - if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { - t.Errorf("got %q, want %q", b, want) - } -} - -func TestResponseMarshalJSON_Notif(t *testing.T) { - tests := map[*jsonrpc2.Request]bool{ - {ID: jsonrpc2.ID{Num: 0}}: true, - {ID: jsonrpc2.ID{Num: 1}}: true, - {ID: jsonrpc2.ID{Str: "", IsString: true}}: true, - {ID: jsonrpc2.ID{Str: "a", IsString: true}}: true, - {Notif: true}: false, - } - for r, wantIDKey := range tests { - b, err := json.Marshal(r) + for _, test := range tests { + e := &jsonrpc2.Error{ + Code: jsonrpc2.CodeInternalError, + Message: "Internal error", + } + if test.setError != nil { + test.setError(e) + } + b, err := json.Marshal(e) if err != nil { - t.Fatal(err) + t.Error(err) } - hasIDKey := bytes.Contains(b, []byte(`"id"`)) - if hasIDKey != wantIDKey { - t.Errorf("got %s, want contain id key: %v", b, wantIDKey) - } - } -} - -func TestResponseUnmarshalJSON_Notif(t *testing.T) { - tests := map[string]bool{ - `{"method":"f","id":0}`: false, - `{"method":"f","id":1}`: false, - `{"method":"f","id":"a"}`: false, - `{"method":"f","id":""}`: false, - `{"method":"f"}`: true, - } - for s, want := range tests { - var r jsonrpc2.Request - if err := json.Unmarshal([]byte(s), &r); err != nil { - t.Fatal(err) - } - if r.Notif != want { - t.Errorf("%s: got %v, want %v", s, r.Notif, want) + got := string(b) + if got != test.want { + t.Fatalf("%s: got %q, want %q", test.name, got, test.want) } } } @@ -112,44 +110,67 @@ func (h *testHandlerB) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jso h.t.Errorf("testHandlerB got unexpected request %+v", req) } +type streamMaker func(conn io.ReadWriteCloser) jsonrpc2.ObjectStream + +func testClientServerForCodec(t *testing.T, streamMaker streamMaker) { + ctx := context.Background() + done := make(chan struct{}) + + lis, err := net.Listen("tcp", "127.0.0.1:0") // any available address + if err != nil { + t.Fatal("Listen:", err) + } + defer func() { + if lis == nil { + return // already closed + } + if err = lis.Close(); err != nil { + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + } + }() + + ha := testHandlerA{t: t} + go func() { + if err = serve(ctx, lis, &ha, streamMaker); err != nil { + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Error(err) + } + } + close(done) + }() + + conn, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatal("Dial:", err) + } + testClientServer(ctx, t, streamMaker(conn)) + + lis.Close() + <-done // ensure Serve's error return (if any) is caught by this test +} + func TestClientServer(t *testing.T) { - t.Run("tcp", func(t *testing.T) { - ctx := context.Background() - done := make(chan struct{}) - - lis, err := net.Listen("tcp", "127.0.0.1:0") // any available address - if err != nil { - t.Fatal("Listen:", err) - } - defer func() { - if lis == nil { - return // already closed - } - if err = lis.Close(); err != nil { - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Fatal(err) - } - } - }() - - ha := testHandlerA{t: t} - go func() { - if err = serve(ctx, lis, &ha); err != nil { - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Error(err) - } - } - close(done) - }() - - conn, err := net.Dial("tcp", lis.Addr().String()) - if err != nil { - t.Fatal("Dial:", err) - } - testClientServer(ctx, t, jsonrpc2.NewBufferedStream(conn, jsonrpc2.VarintObjectCodec{})) - - lis.Close() - <-done // ensure Serve's error return (if any) is caught by this test + t.Run("tcp-varint-object-codec", func(t *testing.T) { + testClientServerForCodec(t, func(conn io.ReadWriteCloser) jsonrpc2.ObjectStream { + return jsonrpc2.NewBufferedStream(conn, jsonrpc2.VarintObjectCodec{}) + }) + }) + t.Run("tcp-vscode-object-codec", func(t *testing.T) { + testClientServerForCodec(t, func(conn io.ReadWriteCloser) jsonrpc2.ObjectStream { + return jsonrpc2.NewBufferedStream(conn, jsonrpc2.VSCodeObjectCodec{}) + }) + }) + t.Run("tcp-plain-object-codec", func(t *testing.T) { + testClientServerForCodec(t, func(conn io.ReadWriteCloser) jsonrpc2.ObjectStream { + return jsonrpc2.NewBufferedStream(conn, jsonrpc2.PlainObjectCodec{}) + }) + }) + t.Run("tcp-plain-object-stream", func(t *testing.T) { + testClientServerForCodec(t, func(conn io.ReadWriteCloser) jsonrpc2.ObjectStream { + return jsonrpc2.NewPlainObjectStream(conn) + }) }) t.Run("websocket", func(t *testing.T) { ctx := context.Background() @@ -291,88 +312,19 @@ type noopHandler struct{} func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {} -type readWriteCloser struct { - read, write func(p []byte) (n int, err error) -} - -func (x readWriteCloser) Read(p []byte) (n int, err error) { - return x.read(p) -} - -func (x readWriteCloser) Write(p []byte) (n int, err error) { - return x.write(p) -} - -func (readWriteCloser) Close() error { return nil } - -func eof(p []byte) (n int, err error) { - return 0, io.EOF -} - -func TestConn_DisconnectNotify_EOF(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } -} - -func TestConn_DisconnectNotify_Close(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - if err := c.Close(); err != nil { - t.Error(err) - } - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } -} - -func TestConn_DisconnectNotify_Close_async(t *testing.T) { - done := make(chan struct{}) - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil) - go func() { - if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - close(done) - }() - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } - <-done -} - -func TestConn_Close_waitingForResponse(t *testing.T) { - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), noopHandler{}) - done := make(chan struct{}) - go func() { - if err := c.Call(context.Background(), "m", nil, nil); err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) - } - close(done) - }() - if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } - <-done -} - -func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, opts ...jsonrpc2.ConnOpt) error { +func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMaker streamMaker, opts ...jsonrpc2.ConnOpt) error { for { conn, err := lis.Accept() if err != nil { return err } - jsonrpc2.NewConn(ctx, jsonrpc2.NewBufferedStream(conn, jsonrpc2.VarintObjectCodec{}), h, opts...) + jsonrpc2.NewConn(ctx, streamMaker(conn), h, opts...) } } + +func rawJSONMessage(v string) *json.RawMessage { + b := []byte(v) + return (*json.RawMessage)(&b) +} + +var jsonNull = json.RawMessage("null") diff --git a/object_test.go b/object_test.go deleted file mode 100644 index 2430e3b..0000000 --- a/object_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package jsonrpc2 - -import ( - "bytes" - "encoding/json" - "reflect" - "testing" -) - -func TestAnyMessage(t *testing.T) { - tests := map[string]struct { - request, response, invalid bool - }{ - // Single messages - `{}`: {invalid: true}, - `{"foo":"bar"}`: {invalid: true}, - `{"method":"m"}`: {request: true}, - `{"result":123}`: {response: true}, - `{"result":null}`: {response: true}, - `{"error":{"code":456,"message":"m"}}`: {response: true}, - } - for s, want := range tests { - var m anyMessage - if err := json.Unmarshal([]byte(s), &m); err != nil { - if !want.invalid { - t.Errorf("%s: error: %v", s, err) - } - continue - } - if (m.request != nil) != want.request { - t.Errorf("%s: got request %v, want %v", s, m.request != nil, want.request) - } - if (m.response != nil) != want.response { - t.Errorf("%s: got response %v, want %v", s, m.response != nil, want.response) - } - } -} - -func TestRequest_MarshalUnmarshalJSON(t *testing.T) { - null := json.RawMessage("null") - obj := json.RawMessage(`{"foo":"bar"}`) - tests := []struct { - data []byte - want Request - }{ - { - data: []byte(`{"method":"m","params":{"foo":"bar"},"id":123,"jsonrpc":"2.0"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: &obj}, - }, - { - data: []byte(`{"method":"m","params":null,"id":123,"jsonrpc":"2.0"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: &null}, - }, - { - data: []byte(`{"method":"m","id":123,"jsonrpc":"2.0"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: nil}, - }, - } - for _, test := range tests { - var got Request - if err := json.Unmarshal(test.data, &got); err != nil { - t.Error(err) - continue - } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) - continue - } - data, err := json.Marshal(got) - if err != nil { - t.Error(err) - continue - } - if !bytes.Equal(data, test.data) { - t.Errorf("got JSON %q, want %q", data, test.data) - } - } -} - -func TestResponse_MarshalUnmarshalJSON(t *testing.T) { - null := json.RawMessage("null") - obj := json.RawMessage(`{"foo":"bar"}`) - tests := []struct { - data []byte - want Response - error bool - }{ - { - data: []byte(`{"id":123,"result":{"foo":"bar"},"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: &obj}, - }, - { - data: []byte(`{"id":123,"result":null,"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: &null}, - }, - { - data: []byte(`{"id":123,"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: nil}, - error: true, // either result or error field must be set - }, - } - for _, test := range tests { - var got Response - if err := json.Unmarshal(test.data, &got); err != nil { - t.Error(err) - continue - } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) - continue - } - data, err := json.Marshal(got) - if err != nil { - if test.error { - continue - } - t.Error(err) - continue - } - if test.error { - t.Errorf("%q: expected error", test.data) - continue - } - if !bytes.Equal(data, test.data) { - t.Errorf("got JSON %q, want %q", data, test.data) - } - } -} diff --git a/request.go b/request.go new file mode 100644 index 0000000..b9cdde0 --- /dev/null +++ b/request.go @@ -0,0 +1,178 @@ +package jsonrpc2 + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" +) + +// Request represents a JSON-RPC request or +// notification. See +// http://www.jsonrpc.org/specification#request_object and +// http://www.jsonrpc.org/specification#notification. +type Request struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params,omitempty"` + ID ID `json:"id"` + Notif bool `json:"-"` + + // Meta optionally provides metadata to include in the request. + // + // NOTE: It is not part of spec. However, it is useful for propagating + // tracing context, etc. + Meta *json.RawMessage `json:"meta,omitempty"` + + // ExtraFields optionally adds fields to the root of the JSON-RPC request. + // + // NOTE: It is not part of the spec, but there are other protocols based on + // JSON-RPC 2 that require it. + ExtraFields []RequestField `json:"-"` +} + +// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" +// property. +func (r Request) MarshalJSON() ([]byte, error) { + r2 := map[string]interface{}{ + "jsonrpc": "2.0", + "method": r.Method, + } + for _, field := range r.ExtraFields { + r2[field.Name] = field.Value + } + if !r.Notif { + r2["id"] = &r.ID + } + if r.Params != nil { + r2["params"] = r.Params + } + if r.Meta != nil { + r2["meta"] = r.Meta + } + return json.Marshal(r2) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (r *Request) UnmarshalJSON(data []byte) error { + r2 := make(map[string]interface{}) + pop := func(key string) interface{} { + defer delete(r2, key) + return r2[key] + } + + // Detect if the "params" or "meta" fields are JSON "null" or just not + // present by seeing if the field gets overwritten to nil. + emptyParams := &json.RawMessage{} + r2["params"] = emptyParams + emptyMeta := &json.RawMessage{} + r2["meta"] = emptyMeta + + decoder := json.NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + if err := decoder.Decode(&r2); err != nil { + return err + } + + var ok bool + r.Method, ok = pop("method").(string) + if !ok { + return errors.New("missing method field") + } + switch params := pop("params"); params { + case nil: + r.Params = &jsonNull + case emptyParams: + r.Params = nil + default: + b, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + r.Params = (*json.RawMessage)(&b) + } + switch meta := pop("meta"); meta { + case nil: + r.Meta = &jsonNull + case emptyMeta: + r.Meta = nil + default: + b, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal Meta: %w", err) + } + r.Meta = (*json.RawMessage)(&b) + } + switch rawID := pop("id").(type) { + case nil: + r.ID = ID{} + r.Notif = true + case string: + r.ID = ID{Str: rawID, IsString: true} + r.Notif = false + case json.Number: + id, err := rawID.Int64() + if err != nil { + return fmt.Errorf("failed to unmarshal ID: %w", err) + } + r.ID = ID{Num: uint64(id)} + r.Notif = false + default: + return fmt.Errorf("unexpected ID type: %T", rawID) + } + + // The jsonrpc field should not be added to ExtraFields. + delete(r2, "jsonrpc") + + // Clear the extra fields before populating them again. + r.ExtraFields = nil + for name, value := range r2 { + r.ExtraFields = append(r.ExtraFields, RequestField{ + Name: name, + Value: value, + }) + } + return nil +} + +// SetParams sets r.Params to the JSON encoding of v. If JSON +// marshaling fails, it returns an error. +func (r *Request) SetParams(v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Params = (*json.RawMessage)(&b) + return nil +} + +// SetMeta sets r.Meta to the JSON encoding of v. If JSON +// marshaling fails, it returns an error. +func (r *Request) SetMeta(v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Meta = (*json.RawMessage)(&b) + return nil +} + +// SetExtraField adds an entry to r.ExtraFields, so that it is added to the +// JSON encoding of the request, as a way to add arbitrary extensions to +// JSON RPC 2.0. If JSON marshaling fails, it returns an error. +func (r *Request) SetExtraField(name string, v interface{}) error { + switch name { + case "id", "jsonrpc", "meta", "method", "params": + return fmt.Errorf("invalid extra field %q", name) + } + r.ExtraFields = append(r.ExtraFields, RequestField{ + Name: name, + Value: v, + }) + return nil +} + +// RequestField is a top-level field that can be added to the JSON-RPC request. +type RequestField struct { + Name string + Value interface{} +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..0e7a7f4 --- /dev/null +++ b/request_test.go @@ -0,0 +1,64 @@ +package jsonrpc2_test + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { + b, err := json.Marshal(&jsonrpc2.Request{}) + if err != nil { + t.Fatal(err) + } + if want := `{"id":0,"jsonrpc":"2.0","method":""}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) + } +} + +func TestRequest_MarshalUnmarshalJSON(t *testing.T) { + obj := json.RawMessage(`{"foo":"bar"}`) + tests := []struct { + data []byte + want jsonrpc2.Request + }{ + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":{"foo":"bar"}}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: &obj}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":null}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: &jsonNull}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m"}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: nil}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","sessionId":"session"}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: nil, ExtraFields: []jsonrpc2.RequestField{{Name: "sessionId", Value: "session"}}}, + }, + } + for _, test := range tests { + var got jsonrpc2.Request + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) + if err != nil { + t.Error(err) + continue + } + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) + } + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..c9a0bfe --- /dev/null +++ b/response.go @@ -0,0 +1,72 @@ +package jsonrpc2 + +import ( + "encoding/json" + "errors" +) + +// Response represents a JSON-RPC response. See +// http://www.jsonrpc.org/specification#response_object. +type Response struct { + ID ID `json:"id"` + Result *json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + + // Meta optionally provides metadata to include in the response. + // + // NOTE: It is not part of spec. However, it is useful for propagating + // tracing context, etc. + Meta *json.RawMessage `json:"meta,omitempty"` + + // SPEC NOTE: The spec says "If there was an error in detecting + // the id in the Request object (e.g. Parse error/Invalid + // Request), it MUST be Null." If we made the ID field nullable, + // then we'd have to make it a pointer type. For simplicity, we're + // ignoring the case where there was an error in detecting the ID + // in the Request object. +} + +// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" +// property. +func (r Response) MarshalJSON() ([]byte, error) { + if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { + return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") + } + type tmpType Response // avoid infinite MarshalJSON recursion + b, err := json.Marshal(tmpType(r)) + if err != nil { + return nil, err + } + b = append(b[:len(b)-1], []byte(`,"jsonrpc":"2.0"}`)...) + return b, nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (r *Response) UnmarshalJSON(data []byte) error { + type tmpType Response + + // Detect if the "result" field is JSON "null" or just not present + // by seeing if the field gets overwritten to nil. + *r = Response{Result: &json.RawMessage{}} + + if err := json.Unmarshal(data, (*tmpType)(r)); err != nil { + return err + } + if r.Result == nil { // JSON "null" + r.Result = &jsonNull + } else if len(*r.Result) == 0 { + r.Result = nil + } + return nil +} + +// SetResult sets r.Result to the JSON representation of v. If JSON +// marshaling fails, it returns an error. +func (r *Response) SetResult(v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Result = (*json.RawMessage)(&b) + return nil +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..4819c0e --- /dev/null +++ b/response_test.go @@ -0,0 +1,108 @@ +package jsonrpc2_test + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { + b, err := json.Marshal(&jsonrpc2.Response{Result: &jsonNull}) + if err != nil { + t.Fatal(err) + } + if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) + } +} + +func TestResponseMarshalJSON_Notif(t *testing.T) { + tests := map[*jsonrpc2.Request]bool{ + {ID: jsonrpc2.ID{Num: 0}}: true, + {ID: jsonrpc2.ID{Num: 1}}: true, + {ID: jsonrpc2.ID{Str: "", IsString: true}}: true, + {ID: jsonrpc2.ID{Str: "a", IsString: true}}: true, + {Notif: true}: false, + } + for r, wantIDKey := range tests { + b, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + hasIDKey := bytes.Contains(b, []byte(`"id"`)) + if hasIDKey != wantIDKey { + t.Errorf("got %s, want contain id key: %v", b, wantIDKey) + } + } +} + +func TestResponseUnmarshalJSON_Notif(t *testing.T) { + tests := map[string]bool{ + `{"method":"f","id":0}`: false, + `{"method":"f","id":1}`: false, + `{"method":"f","id":"a"}`: false, + `{"method":"f","id":""}`: false, + `{"method":"f"}`: true, + } + for s, want := range tests { + var r jsonrpc2.Request + if err := json.Unmarshal([]byte(s), &r); err != nil { + t.Fatal(err) + } + if r.Notif != want { + t.Errorf("%s: got %v, want %v", s, r.Notif, want) + } + } +} + +func TestResponse_MarshalUnmarshalJSON(t *testing.T) { + obj := json.RawMessage(`{"foo":"bar"}`) + tests := []struct { + data []byte + want jsonrpc2.Response + error bool + }{ + { + data: []byte(`{"id":123,"result":{"foo":"bar"},"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: &obj}, + }, + { + data: []byte(`{"id":123,"result":null,"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: &jsonNull}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: nil}, + error: true, // either result or error field must be set + }, + } + for _, test := range tests { + var got jsonrpc2.Response + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) + if err != nil { + if test.error { + continue + } + t.Error(err) + continue + } + if test.error { + t.Errorf("%q: expected error", test.data) + continue + } + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) + } + } +} diff --git a/stream.go b/stream.go index e7a9025..ff24d0f 100644 --- a/stream.go +++ b/stream.go @@ -40,6 +40,12 @@ type bufferedObjectStream struct { // objectStream is used to produce the bytes to write to the stream // for the JSON-RPC 2.0 objects. func NewBufferedStream(conn io.ReadWriteCloser, codec ObjectCodec) ObjectStream { + switch v := codec.(type) { + case PlainObjectCodec: + v.decoder = json.NewDecoder(conn) + v.encoder = json.NewEncoder(conn) + codec = v + } return &bufferedObjectStream{ conn: conn, w: bufio.NewWriter(conn), @@ -68,7 +74,7 @@ func (t *bufferedObjectStream) Close() error { return t.conn.Close() } -// An ObjectCodec specifies how to encoed and decode a JSON-RPC 2.0 +// An ObjectCodec specifies how to encode and decode a JSON-RPC 2.0 // object in a stream. type ObjectCodec interface { // WriteObject writes a JSON-RPC 2.0 object to the stream. @@ -164,14 +170,57 @@ func (VSCodeObjectCodec) ReadObject(stream *bufio.Reader, v interface{}) error { } // PlainObjectCodec reads/writes plain JSON-RPC 2.0 objects without a header. -type PlainObjectCodec struct{} +// +// Deprecated: use NewPlainObjectStream +type PlainObjectCodec struct { + decoder *json.Decoder + encoder *json.Encoder +} // WriteObject implements ObjectCodec. -func (PlainObjectCodec) WriteObject(stream io.Writer, v interface{}) error { +func (c PlainObjectCodec) WriteObject(stream io.Writer, v interface{}) error { + if c.encoder != nil { + return c.encoder.Encode(v) + } return json.NewEncoder(stream).Encode(v) } // ReadObject implements ObjectCodec. -func (PlainObjectCodec) ReadObject(stream *bufio.Reader, v interface{}) error { +func (c PlainObjectCodec) ReadObject(stream *bufio.Reader, v interface{}) error { + if c.decoder != nil { + return c.decoder.Decode(v) + } return json.NewDecoder(stream).Decode(v) } + +// plainObjectStream reads/writes plain JSON-RPC 2.0 objects without a header. +type plainObjectStream struct { + conn io.Closer + decoder *json.Decoder + encoder *json.Encoder +} + +// NewPlainObjectStream creates a buffered stream from a network +// connection (or other similar interface). The underlying +// objectStream produces plain JSON-RPC 2.0 objects without a header. +func NewPlainObjectStream(conn io.ReadWriteCloser) ObjectStream { + return &plainObjectStream{ + conn: conn, + encoder: json.NewEncoder(conn), + decoder: json.NewDecoder(conn), + } +} + +func (os *plainObjectStream) ReadObject(v interface{}) error { + return os.decoder.Decode(v) +} + +// WriteObject serializes a value to JSON and writes it to a stream. +// Not thread-safe, a user must synchronize writes in a multithreaded environment. +func (os *plainObjectStream) WriteObject(v interface{}) error { + return os.encoder.Encode(v) +} + +func (os *plainObjectStream) Close() error { + return os.conn.Close() +}