1
0
Fork 0
mirror of https://github.com/sourcegraph/jsonrpc2.git synced 2026-07-04 16:23:41 +02:00

allow multiple OnRecv and OnSend

This commit is contained in:
Keegan Carruthers-Smith 2018-05-01 19:02:17 +01:00
parent 2ed59d3304
commit a3d86c792f
2 changed files with 20 additions and 12 deletions

View file

@ -13,13 +13,13 @@ type ConnOpt func(*Conn)
// OnRecv causes all requests received on conn to invoke f(req, nil) // OnRecv causes all requests received on conn to invoke f(req, nil)
// and all responses to invoke f(req, resp), // and all responses to invoke f(req, resp),
func OnRecv(f func(*Request, *Response)) ConnOpt { func OnRecv(f func(*Request, *Response)) ConnOpt {
return func(c *Conn) { c.onRecv = f } return func(c *Conn) { c.onRecv = append(c.onRecv, f) }
} }
// OnSend causes all requests sent on conn to invoke f(req, nil) and // OnSend causes all requests sent on conn to invoke f(req, nil) and
// all responses to invoke f(nil, resp), // all responses to invoke f(nil, resp),
func OnSend(f func(*Request, *Response)) ConnOpt { func OnSend(f func(*Request, *Response)) ConnOpt {
return func(c *Conn) { c.onSend = f } return func(c *Conn) { c.onSend = append(c.onSend, f) }
} }
// LogMessages causes all messages sent and received on conn to be // LogMessages causes all messages sent and received on conn to be

View file

@ -298,8 +298,8 @@ type Conn struct {
disconnect chan struct{} disconnect chan struct{}
// Set by ConnOpt funcs. // Set by ConnOpt funcs.
onRecv func(*Request, *Response) onRecv []func(*Request, *Response)
onSend func(*Request, *Response) onSend []func(*Request, *Response)
} }
var _ JSONRPC2 = (*Conn)(nil) var _ JSONRPC2 = (*Conn)(nil)
@ -370,12 +370,19 @@ func (c *Conn) send(ctx context.Context, m *anyMessage, wait bool) (cc *call, er
} }
c.mu.Unlock() c.mu.Unlock()
if c.onSend != nil { if len(c.onSend) > 0 {
var (
req *Request
resp *Response
)
switch { switch {
case m.request != nil: case m.request != nil:
c.onSend(m.request, nil) req = m.request
case m.response != nil: case m.response != nil:
c.onSend(nil, m.response) resp = m.response
}
for _, onSend := range c.onSend {
onSend(req, resp)
} }
} }
@ -498,8 +505,8 @@ func (c *Conn) readMessages(ctx context.Context) {
switch { switch {
case m.request != nil: case m.request != nil:
if c.onRecv != nil { for _, onRecv := range c.onRecv {
c.onRecv(m.request, nil) onRecv(m.request, nil)
} }
c.h.Handle(ctx, c, m.request) c.h.Handle(ctx, c, m.request)
@ -516,13 +523,14 @@ func (c *Conn) readMessages(ctx context.Context) {
call.response = resp call.response = resp
} }
if c.onRecv != nil { if len(c.onRecv) > 0 {
var req *Request var req *Request
if call != nil { if call != nil {
req = call.request req = call.request
} }
c.onRecv(req, resp) for _, onRecv := range c.onRecv {
onRecv(req, resp)
}
} }
switch { switch {