diff --git a/api/client.go b/api/client.go index de21dc5..023499f 100644 --- a/api/client.go +++ b/api/client.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "sync" "github.com/eosswedenorg/thalos/api/message" @@ -63,13 +64,16 @@ func (c *Client) hbHandler(payload []byte) { c.OnHeartbeat(hb) } -func (c *Client) Subscribe(channel Channel) { +func (c *Client) Subscribe(channel Channel) error { var handler handler - if HeartbeatChannel.Is(channel) { + switch t := channel.Type(); t { + case HeartbeatChannel.Type(): handler = c.hbHandler - } else { + case ActionChannel{}.Channel().Type(): handler = c.actHandler + default: + return fmt.Errorf("invalid channel type. %s", t) } // Start a worker for this channel. @@ -78,6 +82,8 @@ func (c *Client) Subscribe(channel Channel) { defer c.wg.Done() c.worker(channel, handler) }() + + return nil } func (c *Client) Run() { diff --git a/api/client_test.go b/api/client_test.go new file mode 100644 index 0000000..2a1061b --- /dev/null +++ b/api/client_test.go @@ -0,0 +1,26 @@ +package api + +import ( + "testing" +) + +func TestClient_Subscribe(t *testing.T) { + tests := []struct { + name string + channel Channel + wantErr bool + }{ + {"Channel", Channel{}, true}, + {"ActionChannel", ActionChannel{}.Channel(), false}, + {"HeartbeatChannel", HeartbeatChannel, false}, + {"TransactionChannel", TransactionChannel, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Client{} + if err := c.Subscribe(tt.channel); (err != nil) != tt.wantErr { + t.Errorf("Client.Subscribe() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}