From 4188fa4438caf544b4a46e1dac4dc5706f3f60ca Mon Sep 17 00:00:00 2001 From: Ggicci Date: Sun, 23 Jan 2022 16:36:45 +0800 Subject: [PATCH] Adjust Handler interface and support middleware --- .gitignore | 21 +++++++++++++++++++ async.go | 6 ++---- call_opt_test.go | 8 ++++---- codec_test.go | 6 +++--- handler.go | 38 ++++++++++++++++++++++++++++++++++ handler_with_error.go | 14 +++++-------- jsonrpc2.go | 33 ++++++++++++++++++------------ jsonrpc2_test.go | 47 +++++++++++++++++++++++++++++++------------ 8 files changed, 127 insertions(+), 46 deletions(-) create mode 100644 .gitignore create mode 100644 handler.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b735ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work diff --git a/async.go b/async.go index bc8a370..acac81b 100644 --- a/async.go +++ b/async.go @@ -1,7 +1,5 @@ package jsonrpc2 -import "context" - // AsyncHandler wraps a Handler such that each request is handled in its own // goroutine. It is a convenience wrapper. func AsyncHandler(h Handler) Handler { @@ -12,6 +10,6 @@ type asyncHandler struct { Handler } -func (h asyncHandler) Handle(ctx context.Context, conn *Conn, req *Request) { - go h.Handler.Handle(ctx, conn, req) +func (h asyncHandler) Handle(conn *Conn, req *Request) { + go h.Handler.Handle(conn, req) } diff --git a/call_opt_test.go b/call_opt_test.go index b64f661..d08c92d 100644 --- a/call_opt_test.go +++ b/call_opt_test.go @@ -16,8 +16,8 @@ func TestPickID(t *testing.T) { defer a.Close() defer b.Close() - handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - if err := conn.Reply(ctx, req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { + handler := handlerFunc(func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + if err := conn.Reply(req.Context(), req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { t.Error(err) } }) @@ -61,7 +61,7 @@ func TestStringID(t *testing.T) { defer a.Close() defer b.Close() - handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler := handlerFunc(func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { replyWithError := func(msg string) { respErr := &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidRequest, Message: msg} if err := conn.ReplyWithError(ctx, req.ID, respErr); err != nil { @@ -100,7 +100,7 @@ func TestExtraField(t *testing.T) { defer a.Close() defer b.Close() - handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler := handlerFunc(func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { replyWithError := func(msg string) { respErr := &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidRequest, Message: msg} if err := conn.ReplyWithError(ctx, req.ID, respErr); err != nil { diff --git a/codec_test.go b/codec_test.go index 4dc7555..442cc83 100644 --- a/codec_test.go +++ b/codec_test.go @@ -56,13 +56,13 @@ func TestPlainObjectCodec(t *testing.T) { // echoHandler unmarshals the request's params object and echos the object // back as the response's result. - var echoHandler handlerFunc = func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + var echoHandler handlerFunc = func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { msg := &Message{} if err := json.Unmarshal(*req.Params, msg); err != nil { - conn.ReplyWithError(ctx, req.ID, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidRequest, Message: err.Error()}) + conn.ReplyWithError(req.Context(), req.ID, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidRequest, Message: err.Error()}) return } - conn.Reply(ctx, req.ID, msg) + conn.Reply(req.Context(), req.ID, msg) } connB := jsonrpc2.NewConn( context.Background(), diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..4c31f35 --- /dev/null +++ b/handler.go @@ -0,0 +1,38 @@ +package jsonrpc2 + +// Handler handles JSON-RPC requests and notifications. +type Handler interface { + // Handle is called to handle a request. No other requests are handled + // until it returns. If you do not require strict ordering behavior + // of received RPCs, it is suggested to wrap your handler in + // AsyncHandler. + Handle(*Conn, *Request) +} + +type HandlerFunc func(*Conn, *Request) + +func (f HandlerFunc) Handle(conn *Conn, req *Request) { + f(conn, req) +} + +type Middleware func(Handler) Handler + +type chain struct { + ms []Middleware +} + +func Chain(middleware ...Middleware) chain { + return chain{ms: append([]Middleware(nil), middleware...)} +} + +func (c chain) Then(h Handler) Handler { + if h == nil { + panic("nil handler") + } + + for i := range c.ms { + h = c.ms[len(c.ms)-1-i](h) + } + + return h +} diff --git a/handler_with_error.go b/handler_with_error.go index 2bd5c1d..6406424 100644 --- a/handler_with_error.go +++ b/handler_with_error.go @@ -1,24 +1,20 @@ package jsonrpc2 -import ( - "context" -) - // HandlerWithError implements Handler by calling the func for each // request and handling returned errors and results. -func HandlerWithError(handleFunc func(context.Context, *Conn, *Request) (result interface{}, err error)) *HandlerWithErrorConfigurer { +func HandlerWithError(handleFunc func(*Conn, *Request) (result interface{}, err error)) *HandlerWithErrorConfigurer { return &HandlerWithErrorConfigurer{handleFunc: handleFunc} } // HandlerWithErrorConfigurer is a handler created by HandlerWithError. type HandlerWithErrorConfigurer struct { - handleFunc func(context.Context, *Conn, *Request) (result interface{}, err error) + handleFunc func(*Conn, *Request) (result interface{}, err error) suppressErrClosed bool } // Handle implements Handler. -func (h *HandlerWithErrorConfigurer) Handle(ctx context.Context, conn *Conn, req *Request) { - result, err := h.handleFunc(ctx, conn, req) +func (h *HandlerWithErrorConfigurer) Handle(conn *Conn, req *Request) { + result, err := h.handleFunc(conn, req) if req.Notif { if err != nil { conn.logger.Printf("jsonrpc2 handler: notification %q handling error: %v\n", req.Method, err) @@ -39,7 +35,7 @@ func (h *HandlerWithErrorConfigurer) Handle(ctx context.Context, conn *Conn, req } if !req.Notif { - if err := conn.SendResponse(ctx, resp); err != nil { + if err := conn.SendResponse(req.Context(), resp); err != nil { if err != ErrClosed || !h.suppressErrClosed { conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err) } diff --git a/jsonrpc2.go b/jsonrpc2.go index 7815c84..3f71fb8 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -41,6 +41,8 @@ type RequestField struct { // http://www.jsonrpc.org/specification#request_object and // http://www.jsonrpc.org/specification#notification. type Request struct { + ctx context.Context + Method string `json:"method"` Params *json.RawMessage `json:"params,omitempty"` ID ID `json:"id"` @@ -59,6 +61,23 @@ type Request struct { ExtraFields []RequestField `json:"-"` } +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +func (r *Request) WithContext(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := new(Request) + *r2 = *r + r2.ctx = ctx + return r2 +} + // MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" // property. func (r Request) MarshalJSON() ([]byte, error) { @@ -294,15 +313,6 @@ const ( CodeInternalError = -32603 ) -// Handler handles JSON-RPC requests and notifications. -type Handler interface { - // Handle is called to handle a request. No other requests are handled - // until it returns. If you do not require strict ordering behavior - // of received RPCs, it is suggested to wrap your handler in - // AsyncHandler. - Handle(context.Context, *Conn, *Request) -} - // ID represents a JSON-RPC 2.0 request ID, which may be either a // string or number (or null, which is unsupported). type ID struct { @@ -384,9 +394,6 @@ var ErrClosed = errors.New("jsonrpc2: connection is closed") // given ReadWriteCloser (typically a TCP connection or stdio). The // JSON-RPC protocol is symmetric, so a Conn runs on both ends of a // client-server connection. -// -// NewClient consumes conn, so you should call Close on the returned -// client not on the given conn. func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOpt) *Conn { c := &Conn{ stream: stream, @@ -620,7 +627,7 @@ func (c *Conn) readMessages(ctx context.Context) { for _, onRecv := range c.onRecv { onRecv(m.request, nil) } - c.h.Handle(ctx, c, m.request) + c.h.Handle(c, m.request.WithContext(ctx)) case m.response != nil: resp := m.response diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index f9dd950..2d86a17 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -79,18 +79,31 @@ func TestResponseUnmarshalJSON_Notif(t *testing.T) { } } +func noInternalMethods(next jsonrpc2.Handler) jsonrpc2.Handler { + return jsonrpc2.HandlerFunc(func(conn *jsonrpc2.Conn, r *jsonrpc2.Request) { + if strings.HasPrefix(r.Method, "internal_") { + conn.ReplyWithError(r.Context(), r.ID, &jsonrpc2.Error{ + Code: jsonrpc2.CodeMethodNotFound, + Message: fmt.Sprintf("method %q not found", r.Method), + }) + return + } + next.Handle(conn, r) + }) +} + // testHandlerA is the "server" handler. type testHandlerA struct{ t *testing.T } -func (h *testHandlerA) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { +func (h *testHandlerA) Handle(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { if req.Notif { return // notification } - if err := conn.Reply(ctx, req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { + if err := conn.Reply(req.Context(), req.ID, fmt.Sprintf("hello, #%s: %s", req.ID, *req.Params)); err != nil { h.t.Error(err) } - if err := conn.Notify(ctx, "m", fmt.Sprintf("notif for #%s", req.ID)); err != nil { + if err := conn.Notify(req.Context(), "m", fmt.Sprintf("notif for #%s", req.ID)); err != nil { h.t.Error(err) } } @@ -102,7 +115,7 @@ type testHandlerB struct { got []string } -func (h *testHandlerB) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { +func (h *testHandlerB) Handle(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { if req.Notif { h.mu.Lock() defer h.mu.Unlock() @@ -132,9 +145,9 @@ func TestClientServer(t *testing.T) { } }() - ha := testHandlerA{t: t} + ha := jsonrpc2.Chain(noInternalMethods).Then(&testHandlerA{t: t}) go func() { - if err = serve(ctx, lis, &ha); err != nil { + if err = serve(ctx, lis, ha); err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { t.Error(err) } @@ -155,7 +168,7 @@ func TestClientServer(t *testing.T) { ctx := context.Background() done := make(chan struct{}) - ha := testHandlerA{t: t} + ha := jsonrpc2.Chain(noInternalMethods).Then(&testHandlerA{t: t}) upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) @@ -163,7 +176,7 @@ func TestClientServer(t *testing.T) { t.Fatal(err) } defer c.Close() - jc := jsonrpc2.NewConn(r.Context(), websocketjsonrpc2.NewObjectStream(c), &ha) + jc := jsonrpc2.NewConn(r.Context(), websocketjsonrpc2.NewObjectStream(c), ha) <-jc.DisconnectNotify() close(done) })) @@ -215,6 +228,14 @@ func testClientServer(ctx context.Context, t *testing.T, stream jsonrpc2.ObjectS t.Fatalf("out of order response. got %q, want %q", s, want) } } + + // The "internal_*" methods should not be exposed to the client. + err := cc.Call(ctx, "internal_listAdmins", nil, nil) + if err == nil { + t.Error("unexpected successful call to internal_listAdmins") + } else if rpcError, ok := err.(*jsonrpc2.Error); !ok || rpcError.Code != jsonrpc2.CodeMethodNotFound { + t.Errorf("got error %v, want CodeMethodNotFound", err) + } } func inMemoryPeerConns() (io.ReadWriteCloser, io.ReadWriteCloser) { @@ -237,10 +258,10 @@ func (c *pipeReadWriteCloser) Close() error { return err2 } -type handlerFunc func(context.Context, *jsonrpc2.Conn, *jsonrpc2.Request) +type handlerFunc func(*jsonrpc2.Conn, *jsonrpc2.Request) -func (h handlerFunc) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - h(ctx, conn, req) +func (h handlerFunc) Handle(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + h(conn, req) } func TestHandlerBlocking(t *testing.T) { @@ -257,7 +278,7 @@ func TestHandlerBlocking(t *testing.T) { wg sync.WaitGroup params []int ) - handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler := handlerFunc(func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { var i int _ = json.Unmarshal(*req.Params, &i) // don't need to synchronize access to ids since we should be blocking @@ -289,7 +310,7 @@ func TestHandlerBlocking(t *testing.T) { type noopHandler struct{} -func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {} +func (noopHandler) Handle(conn *jsonrpc2.Conn, req *jsonrpc2.Request) {} type readWriteCloser struct { read, write func(p []byte) (n int, err error)