diff --git a/api/client.go b/api/client.go index a426d3d..146b4aa 100644 --- a/api/client.go +++ b/api/client.go @@ -2,15 +2,15 @@ package api import ( "fmt" - "reflect" "sync" + "time" "github.com/eosswedenorg/thalos/api/message" ) type handler func([]byte) -// Client reads and decodes messages from a reader and provides callback functions. +// Client reads and decodes messages from a reader and posts thems to a go channel type Client struct { reader Reader decoder message.Decoder @@ -18,18 +18,26 @@ type Client struct { // waitgroup for worker threads. wg sync.WaitGroup - OnError func(error) - OnRollback func(message.RollbackMessage) - OnTransaction func(message.TransactionTrace) - OnAction func(message.ActionTrace) - OnHeartbeat func(message.HeartBeat) - OnTableDelta func(message.TableDelta) + // Channel for messages and errors + channel chan any } func NewClient(reader Reader, decoder message.Decoder) *Client { return &Client{ reader: reader, decoder: decoder, + channel: make(chan any), + } +} + +func (c *Client) Channel() <-chan any { + return c.channel +} + +func (c *Client) post(msg any) { + select { + case <-time.After(time.Second): + case c.channel <- msg: } } @@ -37,9 +45,7 @@ func (c *Client) worker(channel Channel, h handler) { for { payload, err := c.reader.Read(channel) if err != nil { - if c.OnError != nil { - c.OnError(err) - } + c.post(err) return } @@ -47,13 +53,11 @@ func (c *Client) worker(channel Channel, h handler) { } } -// Helper method to decode a message and call OnError on error. +// Helper method to decode a message and post and error on the channel if it fails. // Returns true if successfull. false otherwise func (c *Client) decode(payload []byte, msg any) bool { if err := c.decoder(payload, msg); err != nil { - if c.OnError != nil { - c.OnError(err) - } + c.post(err) return false } return true @@ -63,7 +67,7 @@ func (c *Client) decode(payload []byte, msg any) bool { func (c *Client) rollbackHandler(payload []byte) { var rb message.RollbackMessage if ok := c.decode(payload, &rb); ok { - c.OnRollback(rb) + c.post(rb) } } @@ -71,7 +75,7 @@ func (c *Client) rollbackHandler(payload []byte) { func (c *Client) transactionHandler(payload []byte) { var trans message.TransactionTrace if ok := c.decode(payload, &trans); ok { - c.OnTransaction(trans) + c.post(trans) } } @@ -79,7 +83,7 @@ func (c *Client) transactionHandler(payload []byte) { func (c *Client) actHandler(payload []byte) { var act message.ActionTrace if ok := c.decode(payload, &act); ok { - c.OnAction(act) + c.post(act) } } @@ -87,7 +91,7 @@ func (c *Client) actHandler(payload []byte) { func (c *Client) tableDeltaHandler(payload []byte) { td := message.TableDelta{} if ok := c.decode(payload, &td); ok { - c.OnTableDelta(td) + c.post(td) } } @@ -95,37 +99,33 @@ func (c *Client) tableDeltaHandler(payload []byte) { func (c *Client) hbHandler(payload []byte) { var hb message.HeartBeat if ok := c.decode(payload, &hb); ok { - c.OnHeartbeat(hb) + c.post(hb) } } func (c *Client) Subscribe(channel Channel) error { - handlers := map[string]struct { - handler handler - callback any - }{ - RollbackChannel.Type(): {c.rollbackHandler, c.OnRollback}, - TransactionChannel.Type(): {c.transactionHandler, c.OnTransaction}, - HeartbeatChannel.Type(): {c.hbHandler, c.OnHeartbeat}, - ActionChannel{}.Channel().Type(): {c.actHandler, c.OnAction}, - TableDeltaChannel{}.Channel().Type(): {c.tableDeltaHandler, c.OnTableDelta}, - } + var handler handler - h, ok := handlers[channel.Type()] - - if !ok { + switch channel.Type() { + case RollbackChannel.Type(): + handler = c.rollbackHandler + case TransactionChannel.Type(): + handler = c.transactionHandler + case HeartbeatChannel.Type(): + handler = c.hbHandler + case ActionChannel{}.Channel().Type(): + handler = c.actHandler + case TableDeltaChannel{}.Channel().Type(): + handler = c.tableDeltaHandler + default: return fmt.Errorf("invalid channel type. %s", channel.Type()) } - if h.callback == nil || reflect.ValueOf(h.callback).IsNil() { - return fmt.Errorf("please set an handler before calling Subscribe") - } - // Start a worker for this channel. c.wg.Add(1) go func() { defer c.wg.Done() - c.worker(channel, h.handler) + c.worker(channel, handler) }() return nil @@ -137,5 +137,9 @@ func (c *Client) Run() { } func (c *Client) Close() error { - return c.reader.Close() + err := c.reader.Close() + // Wait for all goroutines before closing channel. + c.wg.Wait() + close(c.channel) + return err } diff --git a/api/client_test.go b/api/client_test.go index 6fa9d8b..784eee6 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -33,12 +33,6 @@ func mockDecoder([]byte, any) error { return nil } -func mockHbHandler(message.HeartBeat) { -} - -func mockActionHandler(message.ActionTrace) { -} - func TestClient_Subscribe(t *testing.T) { tests := []struct { name string @@ -48,13 +42,12 @@ func TestClient_Subscribe(t *testing.T) { {"Channel", Channel{}, true}, {"ActionChannel", ActionChannel{}.Channel(), false}, {"HeartbeatChannel", HeartbeatChannel, false}, - {"TransactionChannel", TransactionChannel, true}, + {"TransactionChannel", TransactionChannel, false}, + {"InvalidChannel", Channel{"random_type"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := NewClient(&mockReader{}, mockDecoder) - c.OnHeartbeat = mockHbHandler - c.OnAction = mockActionHandler if err := c.Subscribe(tt.channel); (err != nil) != tt.wantErr { t.Errorf("Client.Subscribe() error = %v, wantErr %v", err, tt.wantErr) } @@ -62,18 +55,7 @@ func TestClient_Subscribe(t *testing.T) { } } -func TestClient_SubscribeWithNilHandler(t *testing.T) { - client := NewClient(nil, nil) - client.OnAction = mockActionHandler - client.OnHeartbeat = mockHbHandler - - err := client.Subscribe(TableDeltaChannel{Name: "name"}.Channel()) - - assert.Error(t, err) -} - func TestClient_ReadRollback(t *testing.T) { - called := false expected := message.RollbackMessage{ OldBlockNum: 1000, NewBlockNum: 50, @@ -86,15 +68,10 @@ func TestClient_ReadRollback(t *testing.T) { assert.NoError(t, err) client := NewClient(mockReader{bytes.NewReader(payload)}, codec.Decoder) - client.OnRollback = func(rb message.RollbackMessage) { - assert.Equal(t, rb, expected) - called = true - } err = client.Subscribe(RollbackChannel) assert.NoError(t, err) - client.Run() - - assert.True(t, called, "Rollback callback not called when it should have been") + actual := <-client.Channel() + assert.Equal(t, expected, actual) } diff --git a/cmd/tools/bench.go b/cmd/tools/bench.go index ae23a55..31149f9 100644 --- a/cmd/tools/bench.go +++ b/cmd/tools/bench.go @@ -76,46 +76,47 @@ var benchCmd = &cli.Command{ client := api.NewClient(sub, codec.Decoder) - client.OnAction = func(act message.ActionTrace) { - counter++ - } - // Subscribe to all actions if err = client.Subscribe(api.ActionChannel{}.Channel()); err != nil { return err } go func() { - t := time.Now() - sig := make(chan os.Signal, 1) - signal.Notify(sig, os.Interrupt) - - for { - select { - case <-sig: - fmt.Println("Got interrupt") - client.Close() - return - case now := <-time.After(interval): - elapsed := now.Sub(t) - t = now - - log.WithFields(log.Fields{ - "num_messages": counter, - "elapsed": elapsed, - "msg_per_sec": float64(counter) / elapsed.Seconds(), - "msg_per_ms": float64(counter) / float64(elapsed.Milliseconds()), - "msg_per_min": float64(counter) / elapsed.Minutes(), - }).Info("Benchmark results") - - counter = 0 + for t := range client.Channel() { + switch err := t.(type) { + case message.ActionTrace: + counter++ + case error: + log.WithError(err).Error("Error when reading stream") } } }() - // Read stuff. - client.Run() + t := time.Now() + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) - return nil + // Read stuff. + for { + select { + case <-sig: + fmt.Println("Got interrupt") + client.Close() + return nil + case now := <-time.After(interval): + elapsed := now.Sub(t) + t = now + + log.WithFields(log.Fields{ + "num_messages": counter, + "elapsed": elapsed, + "msg_per_sec": float64(counter) / elapsed.Seconds(), + "msg_per_ms": float64(counter) / float64(elapsed.Milliseconds()), + "msg_per_min": float64(counter) / elapsed.Minutes(), + }).Info("Benchmark results") + + counter = 0 + } + } }, } diff --git a/cmd/tools/validate.go b/cmd/tools/validate.go index 061e35d..e696da8 100644 --- a/cmd/tools/validate.go +++ b/cmd/tools/validate.go @@ -18,37 +18,6 @@ import ( log "github.com/sirupsen/logrus" ) -type Tester struct { - block_num uint32 - timeout time.Duration - timer *time.Ticker -} - -func NewTester(timeout time.Duration) *Tester { - return &Tester{ - block_num: 0, - timeout: timeout, - timer: time.NewTicker(timeout), - } -} - -func (t *Tester) OnAction(act message.ActionTrace) { - if t.block_num > 0 { - var diff int32 = int32(act.BlockNum - t.block_num) - if diff < 0 || diff > 1 { - log.WithFields(log.Fields{ - "current_block": t.block_num, - "block": act.BlockNum, - "diff": diff, - }).Warn("Invalid") - } - } - - t.block_num = act.BlockNum - - t.timer.Reset(t.timeout) -} - var validateCmd = &cli.Command{ Name: "validate", Usage: "Validate a thalos server by following action traces and makes sure that blocks arrive in order.", @@ -59,7 +28,6 @@ var validateCmd = &cli.Command{ chainIdFlag, }, Action: func(ctx *cli.Context) error { - tester := NewTester(time.Second * 5) status_duration := time.Second * 10 log.WithFields(log.Fields{ @@ -94,37 +62,53 @@ var validateCmd = &cli.Command{ } client := api.NewClient(sub, codec.Decoder) - client.OnAction = tester.OnAction // Subscribe to all actions if err = client.Subscribe(api.ActionChannel{}.Channel()); err != nil { return err } - go func() { - sig := make(chan os.Signal, 1) - signal.Notify(sig, os.Interrupt) + block_num := uint32(0) + timeout := time.Second * 5 + timer := time.NewTicker(timeout) - for { - select { - case <-sig: - fmt.Println("Got interrupt") - client.Close() - return - case <-tester.timer.C: - log.WithField("duration", tester.timeout). - Warn("Did not get any messages during the defined duration") - case <-time.After(status_duration): - log.WithFields(log.Fields{ - "current_block": tester.block_num, - }).Info("Status") + go func() { + for t := range client.Channel() { + switch msg := t.(type) { + case message.ActionTrace: + if block_num > 0 { + var diff int32 = int32(msg.BlockNum - block_num) + if diff < 0 || diff > 1 { + log.WithFields(log.Fields{ + "current_block": block_num, + "block": msg.BlockNum, + "diff": diff, + }).Warn("Invalid") + } + } + block_num = msg.BlockNum + timer.Reset(timeout) } } }() - // Read stuff. - client.Run() + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) - return nil + for { + select { + case <-sig: + fmt.Println("Got interrupt") + client.Close() + return nil + case <-timer.C: + log.WithField("duration", timeout). + Warn("Did not get any messages during the defined duration") + case <-time.After(status_duration): + log.WithFields(log.Fields{ + "current_block": block_num, + }).Info("Status") + } + } }, }