diff --git a/conn_opt.go b/conn_opt.go index 423cf80..a83ccc3 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -43,18 +43,6 @@ func LogMessages(logger Logger) ConnOpt { OnRecv(func(req *Request, resp *Response) { switch { - case req != nil: - mu.Lock() - reqMethods[req.ID] = req.Method - mu.Unlock() - - params, _ := json.Marshal(req.Params) - if req.Notif { - logger.Printf("jsonrpc2: --> notif: %s: %s\n", req.Method, params) - } else { - logger.Printf("jsonrpc2: --> request #%s: %s: %s\n", req.ID, req.Method, params) - } - case resp != nil: var method string if req != nil { @@ -70,18 +58,22 @@ func LogMessages(logger Logger) ConnOpt { err, _ := json.Marshal(resp.Error) logger.Printf("jsonrpc2: --> error #%s: %s: %s\n", resp.ID, method, err) } + + case req != nil: + mu.Lock() + reqMethods[req.ID] = req.Method + mu.Unlock() + + params, _ := json.Marshal(req.Params) + if req.Notif { + logger.Printf("jsonrpc2: --> notif: %s: %s\n", req.Method, params) + } else { + logger.Printf("jsonrpc2: --> request #%s: %s: %s\n", req.ID, req.Method, params) + } } })(c) OnSend(func(req *Request, resp *Response) { switch { - case req != nil: - params, _ := json.Marshal(req.Params) - if req.Notif { - logger.Printf("jsonrpc2: <-- notif: %s: %s\n", req.Method, params) - } else { - logger.Printf("jsonrpc2: <-- request #%s: %s: %s\n", req.ID, req.Method, params) - } - case resp != nil: mu.Lock() method := reqMethods[resp.ID] @@ -98,6 +90,14 @@ func LogMessages(logger Logger) ConnOpt { err, _ := json.Marshal(resp.Error) logger.Printf("jsonrpc2: <-- error #%s: %s: %s\n", resp.ID, method, err) } + + case req != nil: + params, _ := json.Marshal(req.Params) + if req.Notif { + logger.Printf("jsonrpc2: <-- notif: %s: %s\n", req.Method, params) + } else { + logger.Printf("jsonrpc2: <-- request #%s: %s: %s\n", req.ID, req.Method, params) + } } })(c) } diff --git a/conn_opt_test.go b/conn_opt_test.go index df53a1a..97f59e4 100644 --- a/conn_opt_test.go +++ b/conn_opt_test.go @@ -51,3 +51,80 @@ func TestSetLogger(t *testing.T) { t.Fatalf("got %q, want %q", got, want) } } + +type dummyHandler struct { + t *testing.T +} + +func (h *dummyHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + if !req.Notif { + err := conn.Reply(ctx, req.ID, nil) + if err != nil { + h.t.Error(err) + return + } + } +} + +func TestLogMessages(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rd, wr := io.Pipe() + defer rd.Close() + defer wr.Close() + + buf := bufio.NewReader(rd) + logger := log.New(wr, "", log.Lmsgprefix) + + a, b := net.Pipe() + connA := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(a, jsonrpc2.VSCodeObjectCodec{}), + &dummyHandler{t}, + jsonrpc2.LogMessages(logger), + ) + connB := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(b, jsonrpc2.VSCodeObjectCodec{}), + &dummyHandler{t}, + ) + defer connA.Close() + defer connB.Close() + + go func() { + if err := connA.Call(ctx, "method1", nil, nil); err != nil { + t.Error(err) + return + } + if err := connB.Call(ctx, "method2", nil, nil); err != nil { + t.Error(err) + return + } + if err := connA.Notify(ctx, "notification1", nil); err != nil { + t.Error(err) + return + } + if err := connB.Notify(ctx, "notification2", nil); err != nil { + t.Error(err) + return + } + }() + + for i, want := range []string{ + "jsonrpc2: <-- request #0: method1: null\n", + "jsonrpc2: --> result #0: method1: null\n", + "jsonrpc2: --> request #0: method2: null\n", + "jsonrpc2: <-- result #0: method2: null\n", + "jsonrpc2: <-- notif: notification1: null\n", + "jsonrpc2: --> notif: notification2: null\n", + } { + got, err := buf.ReadString('\n') + if err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("message %v: got %q, want %q", i, got, want) + } + } +}