From 4f27307c702b8af56c825e0f6cc27ef8946d0087 Mon Sep 17 00:00:00 2001 From: Henrik Hautakoski Date: Mon, 15 Jul 2024 23:02:29 +0200 Subject: [PATCH] internal/types/blacklist.go: add isWhitelist field --- internal/config/builder.go | 83 +++++++++++++++++++++++--------- internal/config/builder_test.go | 44 ++++++++++------- internal/types/blacklist.go | 40 ++++++++++----- internal/types/blacklist_test.go | 37 +++++++++++--- 4 files changed, 146 insertions(+), 58 deletions(-) diff --git a/internal/config/builder.go b/internal/config/builder.go index 026dd33..d273858 100644 --- a/internal/config/builder.go +++ b/internal/config/builder.go @@ -116,30 +116,9 @@ func (b *Builder) Build() (*Config, error) { mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), func(f reflect.Type, t reflect.Type, in interface{}) (interface{}, error) { - if t == reflect.TypeOf(types.Blacklist{}) && f.Kind() == reflect.Slice { - if v, ok := in.([]string); ok { - list := types.Blacklist{} - for _, i := range v { - var action string - parts := strings.SplitN(i, ":", 2) - - if len(parts) < 2 { - action = "*" - } else { - action = parts[1] - } - - list.Add(parts[0], action) - } - - if len(list) < 1 { - list = nil - } - return list, nil - } - return nil, fmt.Errorf("Must be a string slice") + if t == reflect.TypeOf(types.Blacklist{}) { + return decodeIntoBlacklist(in) } - return in, nil }, ) @@ -151,3 +130,61 @@ func (b *Builder) Build() (*Config, error) { return &conf, nil } + +// Decode a generic structure into types.Blacklist +func decodeIntoBlacklist(in any) (*types.Blacklist, error) { + switch v := in.(type) { + // Standard map structure. + case map[string]any: + return blacklistParseMap(v) + + // slice of "contract:action" pairs. Usually from CLI + case []string: + return blacklistParseSlice(v) + + // Sometimes we have a slice of interfaces. + // Need to convert it to a slice of strings. + case []any: + sv := make([]string, len(v)) + for i, j := range v { + sv[i] = j.(string) + } + return blacklistParseSlice(sv) + } + + return nil, fmt.Errorf("Must be a string slice") +} + +// Blacklist map parser +func blacklistParseMap(in map[string]any) (*types.Blacklist, error) { + list := &types.Blacklist{} + for k, v := range in { + switch v := v.(type) { + case []any: + for _, v := range v { + list.Add(k, v.(string)) + } + case any: + list.Add(k, v.(string)) + } + } + return list, nil +} + +// Blacklist slice parser +func blacklistParseSlice(in []string) (*types.Blacklist, error) { + list := &types.Blacklist{} + for _, i := range in { + var action string + parts := strings.SplitN(i, ":", 2) + + if len(parts) < 2 { + action = "*" + } else { + action = parts[1] + } + + list.Add(parts[0], action) + } + return list, nil +} diff --git a/internal/config/builder_test.go b/internal/config/builder_test.go index d226a30..bec8178 100644 --- a/internal/config/builder_test.go +++ b/internal/config/builder_test.go @@ -28,10 +28,10 @@ func TestBuilder(t *testing.T) { EndBlockNum: 23872222, IrreversibleOnly: true, MaxMessagesInFlight: 1337, - Blacklist: types.Blacklist{ + Blacklist: *types.NewBlacklist(map[string][]string{ "eosio": {"noop"}, "contract": {"skip1", "skip2"}, - }, + }), }, Telegram: TelegramConfig{ Id: "110201543:AAHdqTcvCH1vGWJxfSeofSAs0K5PALDsaw", @@ -64,7 +64,7 @@ ship: start_block_num: 23671836 end_block_num: 23872222 blacklist: - eosio: ["noop"] + eosio: noop contract: - skip1 - skip2 @@ -207,10 +207,10 @@ func TestBuilder_Flags(t *testing.T) { MaxMessagesInFlight: 98, IrreversibleOnly: true, Chain: "wax", - Blacklist: types.Blacklist{ + Blacklist: *types.NewBlacklist(map[string][]string{ "contract": {"action1", "action2"}, "contract2": {"action1"}, - }, + }), }, Telegram: TelegramConfig{ Id: "72983126312982618", @@ -229,20 +229,28 @@ func TestBuilder_Flags(t *testing.T) { require.Equal(t, &expected, cfg) } -func TestBuilder_BlacklistFlag(t *testing.T) { - flags := GetFlags() - - require.NoError(t, flags.Set("blacklist", "contract,contract:action2")) - - conf, err := NewBuilder(). - SetSource(bytes.NewReader([]byte(``))). - SetFlags(flags). - Build() - - expected := types.Blacklist{ - "contract": {"*", "action2"}, +func TestBuilder_BlacklistSlice(t *testing.T) { + expected := Config{ + Ship: ShipConfig{ + Blacklist: *types.NewBlacklist(map[string][]string{ + "contract": {"action"}, + "contract2": {"action2"}, + "contract3": {"*"}, + }), + }, } + builder := NewBuilder() + builder.SetSource(bytes.NewBuffer([]byte(` +ship: + blacklist: + - "contract:action" + - "contract2:action2" + - contract3 +`))) + + cfg, err := builder.Build() + require.NoError(t, err) - require.Equal(t, expected, conf.Ship.Blacklist) + require.Equal(t, &expected, cfg) } diff --git a/internal/types/blacklist.go b/internal/types/blacklist.go index 12ee19d..24419c8 100644 --- a/internal/types/blacklist.go +++ b/internal/types/blacklist.go @@ -1,27 +1,45 @@ package types -type Blacklist map[string][]string - -func (bl Blacklist) Empty() bool { - return len(bl) < 1 +type Blacklist struct { + table map[string][]string + isWhitelist bool } -func (bl Blacklist) Add(contract string, action string) { - if len(bl[contract]) < 1 { - bl[contract] = []string{} +func NewBlacklist(entries map[string][]string) *Blacklist { + return &Blacklist{ + table: entries, } - bl[contract] = append(bl[contract], action) +} + +func (bl *Blacklist) SetWhitelist(value bool) *Blacklist { + bl.isWhitelist = value + return bl +} + +func (bl Blacklist) Empty() bool { + return len(bl.table) < 1 +} + +func (bl *Blacklist) Add(contract string, action string) { + if bl.table == nil { + bl.table = map[string][]string{} + } + + if len(bl.table[contract]) < 1 { + bl.table[contract] = []string{} + } + bl.table[contract] = append(bl.table[contract], action) } func (bl Blacklist) IsAllowed(contract string, action string) bool { - if v, ok := bl[contract]; ok { + if v, ok := bl.table[contract]; ok { for _, act := range v { if act == action || act == "*" { - return false + return bl.isWhitelist == true } } } - return true + return bl.isWhitelist == false } func (bl Blacklist) IsDenied(contract string, action string) bool { diff --git a/internal/types/blacklist_test.go b/internal/types/blacklist_test.go index 1cbcbfe..631d329 100644 --- a/internal/types/blacklist_test.go +++ b/internal/types/blacklist_test.go @@ -7,7 +7,9 @@ import ( ) func TestBlacklist_Empty(t *testing.T) { - bl := Blacklist{} + bl := Blacklist{ + table: map[string][]string{}, + } require.True(t, bl.Empty()) @@ -17,14 +19,18 @@ func TestBlacklist_Empty(t *testing.T) { } func TestBlacklist_Add(t *testing.T) { - bl := Blacklist{} + bl := Blacklist{ + table: map[string][]string{}, + } bl.Add("contract", "action1") bl.Add("contract", "action2") bl.Add("contract2", "action1") expected := Blacklist{ - "contract": {"action1", "action2"}, - "contract2": {"action1"}, + table: map[string][]string{ + "contract": {"action1", "action2"}, + "contract2": {"action1"}, + }, } require.Equal(t, expected, bl) @@ -32,7 +38,9 @@ func TestBlacklist_Add(t *testing.T) { func TestBlacklist_IsAllowed(t *testing.T) { bl := Blacklist{ - "mycontract": {"myaction", "noop"}, + table: map[string][]string{ + "mycontract": {"myaction", "noop"}, + }, } require.False(t, bl.IsAllowed("mycontract", "myaction")) @@ -43,7 +51,9 @@ func TestBlacklist_IsAllowed(t *testing.T) { func TestBlacklist_IsAllowedWildcard(t *testing.T) { bl := Blacklist{ - "mycontract": {"*"}, + table: map[string][]string{ + "mycontract": {"*"}, + }, } require.False(t, bl.IsAllowed("mycontract", "myaction")) @@ -51,3 +61,18 @@ func TestBlacklist_IsAllowedWildcard(t *testing.T) { require.False(t, bl.IsAllowed("mycontract", "xxx")) require.True(t, bl.IsAllowed("xxx", "yyy")) } + +func TestBlacklist_Whitelist(t *testing.T) { + bl := Blacklist{ + table: map[string][]string{ + "mycontract": {"myaction", "noop"}, + }, + } + + bl.SetWhitelist(true) + + require.True(t, bl.IsAllowed("mycontract", "myaction")) + require.True(t, bl.IsAllowed("mycontract", "noop")) + require.False(t, bl.IsAllowed("mycontract", "xxx")) + require.False(t, bl.IsAllowed("xxx", "yyy")) +}