diff --git a/api/channel.go b/api/channel.go index 8cee878..a843f54 100644 --- a/api/channel.go +++ b/api/channel.go @@ -19,6 +19,13 @@ func (c Channel) String() string { return c.Format("/") } +func (c Channel) Type() string { + if len(c) > 0 { + return c[0] + } + return "unknown" +} + // Check if two channels are equal func (c Channel) Is(other Channel) bool { if len(c) != len(other) { diff --git a/api/channel_test.go b/api/channel_test.go index b80dd5e..58bc597 100644 --- a/api/channel_test.go +++ b/api/channel_test.go @@ -90,6 +90,26 @@ func TestChannel_String(t *testing.T) { } } +func TestChannel_Type(t *testing.T) { + tests := []struct { + name string + c Channel + want string + }{ + {"Empty", Channel{}, "unknown"}, + {"Heartbeat", HeartbeatChannel, "heartbeat"}, + {"Transaction", TransactionChannel, "transactions"}, + {"Actions", ActionChannel{}.Channel(), "actions"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.Type(); got != tt.want { + t.Errorf("Channel.Type() = %v, want %v", got, tt.want) + } + }) + } +} + func TestAction_Channel(t *testing.T) { tests := []struct { name string