From a6c98a3209c5508d08e65109cfa3150cb63ca9e9 Mon Sep 17 00:00:00 2001 From: Henrik Hautakoski Date: Thu, 7 Dec 2023 20:40:49 +0100 Subject: [PATCH] ip/cache.go: skip storing a NetInterfaceIPResolver in struct, Define a CacheDefaultCallback type and have it passed to the new GetWithDefault function instead. --- app/app.go | 28 ++++++++++++++++++++++------ ip/cache.go | 21 +++++++++++++++------ ip/cache_test.go | 41 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 74 insertions(+), 16 deletions(-) diff --git a/app/app.go b/app/app.go index 9e07e87..b17634e 100644 --- a/app/app.go +++ b/app/app.go @@ -1,8 +1,10 @@ package app import ( + "context" "fmt" "net" + "time" "dnsupdater/provider/manager" @@ -11,12 +13,25 @@ import ( ) type App struct { - iplookup ip.NetInterfaceIPResolver + cache *ip.Cache + + cacheDefaultCallback ip.CacheDefaultCallback // Updater manager ProviderManager *manager.Manager } +func makeCacheCallback(service resolver.Service) ip.CacheDefaultCallback { + return func(name string) (net.IP, error) { + if name == resolver.WAN_IFACE { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return service.Lookup(ctx) + } + return ip.GetInterfaceIP(name) + } +} + func NewApp(config *Config) (*App, error) { providerMgr := manager.New() // providerMgr.Register("digitalocean", digitalocean.New(config.Services.DigitalOcean.Token)) @@ -25,18 +40,19 @@ func NewApp(config *Config) (*App, error) { return nil, err } - l := resolver.Get(config.Services.IPLookup) + service := resolver.Get(config.Services.IPLookup) - if l == nil { + if service == nil { return nil, fmt.Errorf("Failed to load lookup service: %s", config.Services.IPLookup) } return &App{ - ProviderManager: providerMgr, - iplookup: ip.NewCache(ip.LookupWrapper(l)).Get, + ProviderManager: providerMgr, + cache: ip.NewCache(), + cacheDefaultCallback: makeCacheCallback(service), }, nil } func (a App) GetIP(iface_name string) (net.IP, error) { - return a.iplookup(iface_name) + return a.cache.GetWithDefault(iface_name, a.cacheDefaultCallback) } diff --git a/ip/cache.go b/ip/cache.go index 4fa130b..16e1ea8 100644 --- a/ip/cache.go +++ b/ip/cache.go @@ -1,18 +1,19 @@ package ip import ( + "errors" "net" ) +type CacheDefaultCallback func(name string) (net.IP, error) + type Cache struct { - resolver NetInterfaceIPResolver - items map[string]net.IP + items map[string]net.IP } -func NewCache(resolver NetInterfaceIPResolver) *Cache { +func NewCache() *Cache { return &Cache{ - resolver: resolver, - items: make(map[string]net.IP), + items: make(map[string]net.IP), } } @@ -21,8 +22,16 @@ func (c Cache) Get(name string) (net.IP, error) { if cached, ok := c.items[name]; ok { return cached, nil } + return nil, errors.New("key did not exist") +} - ip, err := c.resolver(name) +func (c Cache) GetWithDefault(name string, callback CacheDefaultCallback) (net.IP, error) { + // Return cached entry. + if cached, ok := c.items[name]; ok { + return cached, nil + } + + ip, err := callback(name) if err == nil { c.Set(name, ip) } diff --git a/ip/cache_test.go b/ip/cache_test.go index a66e949..4bececa 100644 --- a/ip/cache_test.go +++ b/ip/cache_test.go @@ -9,13 +9,20 @@ import ( "github.com/stretchr/testify/assert" ) -func mockResolver(t *testing.T, expected_name string, ip net.IP, err error) NetInterfaceIPResolver { +func defaultCallback(t *testing.T, expected_name string, ip net.IP, err error) CacheDefaultCallback { return func(name string) (net.IP, error) { assert.Equal(t, expected_name, name) return ip, err } } +func dontCallDefaultCallback(t *testing.T) CacheDefaultCallback { + return func(name string) (net.IP, error) { + t.Error("Should not have been called") + return nil, nil + } +} + func TestCache_Get(t *testing.T) { tests := []struct { name string @@ -24,9 +31,8 @@ func TestCache_Get(t *testing.T) { want net.IP wantErr bool }{ - {"FromCache", &Cache{resolver: nil, items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false}, - {"FromResolver", NewCache(mockResolver(t, "eth1", net.IPv4(192, 172, 44, 25), nil)), "eth1", net.IPv4(192, 172, 44, 25), false}, - {"NoInterface", NewCache(mockResolver(t, "eth2", nil, errors.New("Invalid interface"))), "eth2", nil, true}, + {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false}, + {"Did not exist in cache", &Cache{items: map[string]net.IP{}}, "eth0", nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -41,3 +47,30 @@ func TestCache_Get(t *testing.T) { }) } } + +func TestCache_GetWithDefault(t *testing.T) { + tests := []struct { + name string + c *Cache + def CacheDefaultCallback + iface string + want net.IP + wantErr bool + }{ + {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, dontCallDefaultCallback(t), "eth0", net.IPv4(10, 4, 0, 1), false}, + {"Did not exists in cache", NewCache(), defaultCallback(t, "eth1", net.IPv4(192, 172, 44, 25), nil), "eth1", net.IPv4(192, 172, 44, 25), false}, + {"Callback returns error", NewCache(), defaultCallback(t, "eth1", nil, errors.New("some error")), "eth1", nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.c.GetWithDefault(tt.iface, tt.def) + if (err != nil) != tt.wantErr { + t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Cache.Get() = %v, want %v", got, tt.want) + } + }) + } +}