diff --git a/.gitignore b/.gitignore index a2acf60..82305e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -config.yml -build/ \ No newline at end of file +config.yml +build/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 793653a..e1c6caa 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,22 +1,22 @@ - -image: golang:1.23 - -stages: - - test - - build - -unit-test: - stage: test - script: - - go test -v ./... - -compile: - stage: build - script: - - mkdir -p build - - GOOS=linux GOARCH=amd64 go build -o build/dnsupdater-linux-amd64 cmd/dnsupdater/main.go - - GOOS=linux GOARCH=mips GOMIPS=softfloat go build -o build/dnsupdater-linux-mips cmd/dnsupdater/main.go - - GOOS=linux GOARCH=mipsle GOMIPS=softfloat go build -o build/dnsupdater-linux-mipsle cmd/dnsupdater/main.go - artifacts: - paths: - - build + +image: golang:1.23 + +stages: + - test + - build + +unit-test: + stage: test + script: + - go test -v ./... + +compile: + stage: build + script: + - mkdir -p build + - GOOS=linux GOARCH=amd64 go build -o build/dnsupdater-linux-amd64 cmd/dnsupdater/main.go + - GOOS=linux GOARCH=mips GOMIPS=softfloat go build -o build/dnsupdater-linux-mips cmd/dnsupdater/main.go + - GOOS=linux GOARCH=mipsle GOMIPS=softfloat go build -o build/dnsupdater-linux-mipsle cmd/dnsupdater/main.go + artifacts: + paths: + - build diff --git a/Makefile b/Makefile index bbd67a9..4014040 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ -GO=go -VERSION=$(shell git describe --always --tags --dirty --match="v*") -GOLDFLAGS=-v -s -w -X main.version="$(VERSION)" -GOBUILDFLAGS=-v -p $(shell nproc) -ldflags="$(GOLDFLAGS)" - -.PHONY: build test - -build : - $(GO) build $(GOBUILDFLAGS) -o build/dnsupdater cmd/dnsupdater/main.go - -test : - $(GO) test -v ./... +GO=go +VERSION=$(shell git describe --always --tags --dirty --match="v*") +GOLDFLAGS=-v -s -w -X main.version="$(VERSION)" +GOBUILDFLAGS=-v -p $(shell nproc) -ldflags="$(GOLDFLAGS)" + +.PHONY: build test + +build : + $(GO) build $(GOBUILDFLAGS) -o build/dnsupdater cmd/dnsupdater/main.go + +test : + $(GO) test -v ./... diff --git a/app/app.go b/app/app.go index 41a9cb5..fef9a75 100644 --- a/app/app.go +++ b/app/app.go @@ -1,59 +1,59 @@ -package app - -import ( - "context" - "fmt" - "net" - "time" - - dnsservice "dnsupdater/dns/service" - "dnsupdater/ip" - "dnsupdater/ip/resolver" -) - -// WAN_IFACE Name for the virtual WAN interface -const WAN_IFACE = "wan" - -type App struct { - cache *ip.Cache - - cacheDefaultCallback ip.CacheDefaultCallback - - // DNS service manager - DnsServiceMgr *dnsservice.Manager -} - -func makeCacheCallback(service resolver.Service) ip.CacheDefaultCallback { - return func(name string) (net.IP, error) { - if name == WAN_IFACE { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - return service.Lookup(ctx) - } - return ip.GetInterfaceIP(name) - } -} - -func NewApp(config *Config) (*App, error) { - dnsServiceMgr := dnsservice.NewManager() - err := dnsServiceMgr.RegisterFromConfig(config.Providers) - if err != nil { - return nil, err - } - - ipService := resolver.Get(config.Services.IPLookup) - - if ipService == nil { - return nil, fmt.Errorf("failed to load lookup service: %s", config.Services.IPLookup) - } - - return &App{ - DnsServiceMgr: dnsServiceMgr, - cache: ip.NewCache(), - cacheDefaultCallback: makeCacheCallback(ipService), - }, nil -} - -func (a App) GetIP(iface_name string) (net.IP, error) { - return a.cache.GetWithDefault(iface_name, a.cacheDefaultCallback) -} +package app + +import ( + "context" + "fmt" + "net" + "time" + + dnsservice "dnsupdater/dns/service" + "dnsupdater/ip" + "dnsupdater/ip/resolver" +) + +// WAN_IFACE Name for the virtual WAN interface +const WAN_IFACE = "wan" + +type App struct { + cache *ip.Cache + + cacheDefaultCallback ip.CacheDefaultCallback + + // DNS service manager + DnsServiceMgr *dnsservice.Manager +} + +func makeCacheCallback(service resolver.Service) ip.CacheDefaultCallback { + return func(name string) (net.IP, error) { + if name == WAN_IFACE { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return service.Lookup(ctx) + } + return ip.GetInterfaceIP(name) + } +} + +func NewApp(config *Config) (*App, error) { + dnsServiceMgr := dnsservice.NewManager() + err := dnsServiceMgr.RegisterFromConfig(config.Providers) + if err != nil { + return nil, err + } + + ipService := resolver.Get(config.Services.IPLookup) + + if ipService == nil { + return nil, fmt.Errorf("failed to load lookup service: %s", config.Services.IPLookup) + } + + return &App{ + DnsServiceMgr: dnsServiceMgr, + cache: ip.NewCache(), + cacheDefaultCallback: makeCacheCallback(ipService), + }, nil +} + +func (a App) GetIP(iface_name string) (net.IP, error) { + return a.cache.GetWithDefault(iface_name, a.cacheDefaultCallback) +} diff --git a/app/config.go b/app/config.go index 3d058de..0dd4e53 100644 --- a/app/config.go +++ b/app/config.go @@ -1,46 +1,46 @@ -package app - -import ( - "os" - - "gopkg.in/yaml.v3" -) - -type ( - DomainRecords map[string]string - Domain map[string]DomainRecords -) - -type DigitalOceanService struct { - Token string `yaml:"token"` - Domains map[string]DomainRecords `yaml:"domains"` -} - -type Providers struct { - Token string `yaml:"token"` - Domains map[string]DomainRecords `yaml:"domains"` -} - -type Services struct { - IPLookup string `yaml:"IPLookup"` -} - -type Config struct { - Services Services `yaml:"services"` - Providers map[string]map[string]any - Updates map[string]Domain -} - -func LoadConfig(filename string) (*Config, error) { - cfg := Config{ - Services: Services{ - IPLookup: "ipecho", - }, - } - - data, err := os.ReadFile(filename) - if err == nil { - err = yaml.Unmarshal(data, &cfg) - } - return &cfg, err -} +package app + +import ( + "os" + + "gopkg.in/yaml.v3" +) + +type ( + DomainRecords map[string]string + Domain map[string]DomainRecords +) + +type DigitalOceanService struct { + Token string `yaml:"token"` + Domains map[string]DomainRecords `yaml:"domains"` +} + +type Providers struct { + Token string `yaml:"token"` + Domains map[string]DomainRecords `yaml:"domains"` +} + +type Services struct { + IPLookup string `yaml:"IPLookup"` +} + +type Config struct { + Services Services `yaml:"services"` + Providers map[string]map[string]any + Updates map[string]Domain +} + +func LoadConfig(filename string) (*Config, error) { + cfg := Config{ + Services: Services{ + IPLookup: "ipecho", + }, + } + + data, err := os.ReadFile(filename) + if err == nil { + err = yaml.Unmarshal(data, &cfg) + } + return &cfg, err +} diff --git a/cmd/dnsupdater/main.go b/cmd/dnsupdater/main.go index 008a915..71fe33a 100644 --- a/cmd/dnsupdater/main.go +++ b/cmd/dnsupdater/main.go @@ -1,84 +1,84 @@ -package main - -import ( - "flag" - "fmt" - "os" - "time" - - App "dnsupdater/app" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -var version string = "(unknown)" - -func main() { - configFile := flag.String("config", "./config.yml", "configuration file") - versionFlag := flag.Bool("v", false, "Prints the version") - - flag.Parse() - - if *versionFlag { - fmt.Println(version) - os.Exit(0) - } - - log.Logger = log.Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: time.RFC3339, - }) - - config, err := App.LoadConfig(*configFile) - if err != nil { - log.Fatal().Err(err).Str("file", *configFile).Msg("Failed to load config") - } - - app, err := App.NewApp(config) - if err != nil { - log.Fatal().Err(err).Msg("Failed to initialize application") - } - - for service_name, domains := range config.Updates { - - // Get DNS Service - dnsService := app.DnsServiceMgr.Get(service_name) - - if dnsService == nil { - log.Warn().Str("service", service_name).Msg("Invalid DNS service") - continue - } - - log.Info().Str("service", service_name).Msg("Begin update for service") - - updater := App.NewUpdater(dnsService) - - for domain, records := range domains { - for name, data := range records { - - logger := log.With(). - Str("service", service_name). - Str("domain", domain). - Str("record", name). - Str("interface", data). - Logger() - - ip, err := app.GetIP(data) - if err != nil { - logger.Error().Err(err).Msg("Failed to fetch ip") - continue - } - - logger = logger.With().IPAddr("ip", ip).Logger() - - err = updater.Update(domain, name, ip) - if err != nil { - logger.Error().Err(err).Msg("Failed to update record") - } else { - logger.Info().Msg("Record updated") - } - } - } - } -} +package main + +import ( + "flag" + "fmt" + "os" + "time" + + App "dnsupdater/app" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +var version string = "(unknown)" + +func main() { + configFile := flag.String("config", "./config.yml", "configuration file") + versionFlag := flag.Bool("v", false, "Prints the version") + + flag.Parse() + + if *versionFlag { + fmt.Println(version) + os.Exit(0) + } + + log.Logger = log.Output(zerolog.ConsoleWriter{ + Out: os.Stderr, + TimeFormat: time.RFC3339, + }) + + config, err := App.LoadConfig(*configFile) + if err != nil { + log.Fatal().Err(err).Str("file", *configFile).Msg("Failed to load config") + } + + app, err := App.NewApp(config) + if err != nil { + log.Fatal().Err(err).Msg("Failed to initialize application") + } + + for service_name, domains := range config.Updates { + + // Get DNS Service + dnsService := app.DnsServiceMgr.Get(service_name) + + if dnsService == nil { + log.Warn().Str("service", service_name).Msg("Invalid DNS service") + continue + } + + log.Info().Str("service", service_name).Msg("Begin update for service") + + updater := App.NewUpdater(dnsService) + + for domain, records := range domains { + for name, data := range records { + + logger := log.With(). + Str("service", service_name). + Str("domain", domain). + Str("record", name). + Str("interface", data). + Logger() + + ip, err := app.GetIP(data) + if err != nil { + logger.Error().Err(err).Msg("Failed to fetch ip") + continue + } + + logger = logger.With().IPAddr("ip", ip).Logger() + + err = updater.Update(domain, name, ip) + if err != nil { + logger.Error().Err(err).Msg("Failed to update record") + } else { + logger.Info().Msg("Record updated") + } + } + } + } +} diff --git a/config.example.yml b/config.example.yml index 7471468..c4220de 100644 --- a/config.example.yml +++ b/config.example.yml @@ -1,26 +1,26 @@ - -services: - IPLookup: ipecho - -providers: - digitalocean: - token: xxxx - vultr: - token: xxxx - -updates: - digitalocean: - domain1.com: - www: wan - box: 10.140.14.2 - domain2.com: - www: wan - mail: wan - static: 84.24.254.21 - vultr: - example1.com: - www: wan - example2.com: - www: wan - ftp: 88.212.99.90 - + +services: + IPLookup: ipecho + +providers: + digitalocean: + token: xxxx + vultr: + token: xxxx + +updates: + digitalocean: + domain1.com: + www: wan + box: 10.140.14.2 + domain2.com: + www: wan + mail: wan + static: 84.24.254.21 + vultr: + example1.com: + www: wan + example2.com: + www: wan + ftp: 88.212.99.90 + diff --git a/dns/service/digitalocean/mock_test.go b/dns/service/digitalocean/mock_test.go index 2f64a24..b9ecb86 100644 --- a/dns/service/digitalocean/mock_test.go +++ b/dns/service/digitalocean/mock_test.go @@ -1,109 +1,109 @@ -package digitalocean - -import ( - "context" - "errors" - "testing" - - "github.com/digitalocean/godo" - "github.com/stretchr/testify/assert" -) - -type mock struct { - t *testing.T - - records_by_type map[string][]godo.DomainRecord - - edit_record_request *godo.DomainRecordEditRequest - edit_record_error error -} - -func (m mock) List(context.Context, *godo.ListOptions) ([]godo.Domain, *godo.Response, error) { - m.t.Error("List called when it should not have been") - return nil, nil, nil -} - -func (m mock) Get(context.Context, string) (*godo.Domain, *godo.Response, error) { - m.t.Error("Get called when it should not have been") - return nil, nil, nil -} - -func (m mock) Create(context.Context, *godo.DomainCreateRequest) (*godo.Domain, *godo.Response, error) { - m.t.Error("Create called when it should not have been") - return nil, nil, nil -} - -func (m mock) Delete(context.Context, string) (*godo.Response, error) { - m.t.Error("Delete called when it should not have been") - return nil, nil -} - -func (m mock) Records(context.Context, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { - m.t.Error("Records called when it should not have been") - return nil, nil, nil -} - -func (m mock) RecordsByType(_ context.Context, name string, t string, opt *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { - var err error - - // Only care about "A" records - assert.Equal(m.t, "A", t) - - r, ok := m.records_by_type[name] - if !ok { - err = errors.New("Record not found") - } - return r, nil, err -} - -func (m mock) RecordsByName(context.Context, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { - m.t.Error("RecordsByName called when it should not have been") - return nil, nil, nil -} - -func (m mock) RecordsByTypeAndName(context.Context, string, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { - m.t.Error("RecordsByTypeAndName called when it should not have been") - return nil, nil, nil -} - -func (m mock) Record(context.Context, string, int) (*godo.DomainRecord, *godo.Response, error) { - m.t.Error("Record called when it should not have been") - return nil, nil, nil -} - -func (m mock) DeleteRecord(context.Context, string, int) (*godo.Response, error) { - m.t.Error("DeleteRecord called when it should not have been") - return nil, nil -} - -func (m mock) EditRecord(_ context.Context, domain string, id int, req *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { - if m.edit_record_request == nil { - m.t.Error("EditRecord called with empty request") - } - - if m.edit_record_error != nil { - return nil, nil, m.edit_record_error - } - - assert.Equal(m.t, m.edit_record_request, req) - - record := godo.DomainRecord{ - ID: id, - Type: req.Type, - Name: req.Name, - Data: req.Data, - Priority: req.Priority, - Port: req.Port, - TTL: req.TTL, - Weight: req.Weight, - Flags: req.Flags, - Tag: req.Tag, - } - - return &record, nil, nil -} - -func (m mock) CreateRecord(context.Context, string, *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { - m.t.Error("CreateRecord called when it should not have been") - return nil, nil, nil -} +package digitalocean + +import ( + "context" + "errors" + "testing" + + "github.com/digitalocean/godo" + "github.com/stretchr/testify/assert" +) + +type mock struct { + t *testing.T + + records_by_type map[string][]godo.DomainRecord + + edit_record_request *godo.DomainRecordEditRequest + edit_record_error error +} + +func (m mock) List(context.Context, *godo.ListOptions) ([]godo.Domain, *godo.Response, error) { + m.t.Error("List called when it should not have been") + return nil, nil, nil +} + +func (m mock) Get(context.Context, string) (*godo.Domain, *godo.Response, error) { + m.t.Error("Get called when it should not have been") + return nil, nil, nil +} + +func (m mock) Create(context.Context, *godo.DomainCreateRequest) (*godo.Domain, *godo.Response, error) { + m.t.Error("Create called when it should not have been") + return nil, nil, nil +} + +func (m mock) Delete(context.Context, string) (*godo.Response, error) { + m.t.Error("Delete called when it should not have been") + return nil, nil +} + +func (m mock) Records(context.Context, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { + m.t.Error("Records called when it should not have been") + return nil, nil, nil +} + +func (m mock) RecordsByType(_ context.Context, name string, t string, opt *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { + var err error + + // Only care about "A" records + assert.Equal(m.t, "A", t) + + r, ok := m.records_by_type[name] + if !ok { + err = errors.New("Record not found") + } + return r, nil, err +} + +func (m mock) RecordsByName(context.Context, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { + m.t.Error("RecordsByName called when it should not have been") + return nil, nil, nil +} + +func (m mock) RecordsByTypeAndName(context.Context, string, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { + m.t.Error("RecordsByTypeAndName called when it should not have been") + return nil, nil, nil +} + +func (m mock) Record(context.Context, string, int) (*godo.DomainRecord, *godo.Response, error) { + m.t.Error("Record called when it should not have been") + return nil, nil, nil +} + +func (m mock) DeleteRecord(context.Context, string, int) (*godo.Response, error) { + m.t.Error("DeleteRecord called when it should not have been") + return nil, nil +} + +func (m mock) EditRecord(_ context.Context, domain string, id int, req *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { + if m.edit_record_request == nil { + m.t.Error("EditRecord called with empty request") + } + + if m.edit_record_error != nil { + return nil, nil, m.edit_record_error + } + + assert.Equal(m.t, m.edit_record_request, req) + + record := godo.DomainRecord{ + ID: id, + Type: req.Type, + Name: req.Name, + Data: req.Data, + Priority: req.Priority, + Port: req.Port, + TTL: req.TTL, + Weight: req.Weight, + Flags: req.Flags, + Tag: req.Tag, + } + + return &record, nil, nil +} + +func (m mock) CreateRecord(context.Context, string, *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { + m.t.Error("CreateRecord called when it should not have been") + return nil, nil, nil +} diff --git a/dns/service/digitalocean/service.go b/dns/service/digitalocean/service.go index e0414be..4ea1efd 100644 --- a/dns/service/digitalocean/service.go +++ b/dns/service/digitalocean/service.go @@ -1,67 +1,67 @@ -package digitalocean - -import ( - "context" - "errors" - "net" - "strconv" - - "dnsupdater/dns" - - "github.com/digitalocean/godo" -) - -type Service struct { - api godo.DomainsService -} - -func New(token string) Service { - return Service{ - api: godo.NewFromToken(token).Domains, - } -} - -func Factory(args map[string]any) (any, error) { - t, ok := args["token"] - if !ok { - return nil, errors.New("did not find token") - } - - token, ok := t.(string) - if !ok { - return nil, errors.New("token must be a string") - } - - return New(token), nil -} - -func (d Service) List(domain_name string) (dns.RecordList, error) { - fetchedRecords, _, err := d.api.RecordsByType(context.Background(), domain_name, "A", &godo.ListOptions{ - PerPage: 50, - }) - if err != nil { - return nil, err - } - - records := dns.RecordList{} - for _, rec := range fetchedRecords { - records.Add(dns.Record{ - Id: strconv.Itoa(rec.ID), - Name: rec.Name, - Ip: net.ParseIP(rec.Data), - }) - } - return records, nil -} - -func (d Service) Update(domain, recordID, ip string) error { - id, err := strconv.Atoi(recordID) - if err != nil { - return err - } - - _, _, err = d.api.EditRecord(context.Background(), domain, id, &godo.DomainRecordEditRequest{ - Data: ip, - }) - return err -} +package digitalocean + +import ( + "context" + "errors" + "net" + "strconv" + + "dnsupdater/dns" + + "github.com/digitalocean/godo" +) + +type Service struct { + api godo.DomainsService +} + +func New(token string) Service { + return Service{ + api: godo.NewFromToken(token).Domains, + } +} + +func Factory(args map[string]any) (any, error) { + t, ok := args["token"] + if !ok { + return nil, errors.New("did not find token") + } + + token, ok := t.(string) + if !ok { + return nil, errors.New("token must be a string") + } + + return New(token), nil +} + +func (d Service) List(domain_name string) (dns.RecordList, error) { + fetchedRecords, _, err := d.api.RecordsByType(context.Background(), domain_name, "A", &godo.ListOptions{ + PerPage: 50, + }) + if err != nil { + return nil, err + } + + records := dns.RecordList{} + for _, rec := range fetchedRecords { + records.Add(dns.Record{ + Id: strconv.Itoa(rec.ID), + Name: rec.Name, + Ip: net.ParseIP(rec.Data), + }) + } + return records, nil +} + +func (d Service) Update(domain, recordID, ip string) error { + id, err := strconv.Atoi(recordID) + if err != nil { + return err + } + + _, _, err = d.api.EditRecord(context.Background(), domain, id, &godo.DomainRecordEditRequest{ + Data: ip, + }) + return err +} diff --git a/dns/service/manager.go b/dns/service/manager.go index acbf8f4..48fd49d 100644 --- a/dns/service/manager.go +++ b/dns/service/manager.go @@ -1,49 +1,49 @@ -package service - -import ( - "fmt" - - "dnsupdater/dns/service/digitalocean" - "dnsupdater/dns/service/vultr" -) - -var factories = map[string]Factory{ - "digitalocean": digitalocean.Factory, - "vultr": vultr.Factory, -} - -type Manager struct { - services map[string]Service -} - -func NewManager() *Manager { - return &Manager{ - services: make(map[string]Service), - } -} - -func (m Manager) Get(name string) Service { - if service, ok := m.services[name]; ok { - return service - } - return nil -} - -func (m Manager) RegisterFromConfig(providers map[string]map[string]any) error { - for name, args := range providers { - if factory, ok := factories[name]; ok { - - provider, err := factory(args) - if err != nil { - return fmt.Errorf("could not create provider '%s': %v", name, err) - } - - m.Register(name, provider.(Service)) - } - } - return nil -} - -func (m Manager) Register(name string, provider Service) { - m.services[name] = provider -} +package service + +import ( + "fmt" + + "dnsupdater/dns/service/digitalocean" + "dnsupdater/dns/service/vultr" +) + +var factories = map[string]Factory{ + "digitalocean": digitalocean.Factory, + "vultr": vultr.Factory, +} + +type Manager struct { + services map[string]Service +} + +func NewManager() *Manager { + return &Manager{ + services: make(map[string]Service), + } +} + +func (m Manager) Get(name string) Service { + if service, ok := m.services[name]; ok { + return service + } + return nil +} + +func (m Manager) RegisterFromConfig(providers map[string]map[string]any) error { + for name, args := range providers { + if factory, ok := factories[name]; ok { + + provider, err := factory(args) + if err != nil { + return fmt.Errorf("could not create provider '%s': %v", name, err) + } + + m.Register(name, provider.(Service)) + } + } + return nil +} + +func (m Manager) Register(name string, provider Service) { + m.services[name] = provider +} diff --git a/dns/service/service.go b/dns/service/service.go index e76a8e5..405a96d 100644 --- a/dns/service/service.go +++ b/dns/service/service.go @@ -1,10 +1,10 @@ -package service - -import "dnsupdater/dns" - -type Service interface { - List(domain string) (dns.RecordList, error) - Update(domain, recordID, ip string) error -} - -type Factory func(map[string]any) (any, error) +package service + +import "dnsupdater/dns" + +type Service interface { + List(domain string) (dns.RecordList, error) + Update(domain, recordID, ip string) error +} + +type Factory func(map[string]any) (any, error) diff --git a/http/get.go b/http/get.go index 3da6058..ef16922 100644 --- a/http/get.go +++ b/http/get.go @@ -1,27 +1,27 @@ -package http - -import ( - "context" - "fmt" - "net/http" -) - -// Perform a HTTP Get request. -func Get(ctx context.Context, url string, headers http.Header) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, err - } - - req.Header = headers - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("HTTP Response: %s", resp.Status) - } - return resp, nil -} +package http + +import ( + "context" + "fmt" + "net/http" +) + +// Perform a HTTP Get request. +func Get(ctx context.Context, url string, headers http.Header) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header = headers + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("HTTP Response: %s", resp.Status) + } + return resp, nil +} diff --git a/ip/cache.go b/ip/cache.go index fc5b5f6..16e1ea8 100644 --- a/ip/cache.go +++ b/ip/cache.go @@ -1,43 +1,43 @@ -package ip - -import ( - "errors" - "net" -) - -type CacheDefaultCallback func(name string) (net.IP, error) - -type Cache struct { - items map[string]net.IP -} - -func NewCache() *Cache { - return &Cache{ - items: make(map[string]net.IP), - } -} - -func (c Cache) Get(name string) (net.IP, error) { - // Return cached entry. - if cached, ok := c.items[name]; ok { - return cached, nil - } - return nil, errors.New("key did not exist") -} - -func (c Cache) GetWithDefault(name string, callback CacheDefaultCallback) (net.IP, error) { - // Return cached entry. - if cached, ok := c.items[name]; ok { - return cached, nil - } - - ip, err := callback(name) - if err == nil { - c.Set(name, ip) - } - return ip, err -} - -func (c *Cache) Set(name string, ip net.IP) { - c.items[name] = ip -} +package ip + +import ( + "errors" + "net" +) + +type CacheDefaultCallback func(name string) (net.IP, error) + +type Cache struct { + items map[string]net.IP +} + +func NewCache() *Cache { + return &Cache{ + items: make(map[string]net.IP), + } +} + +func (c Cache) Get(name string) (net.IP, error) { + // Return cached entry. + if cached, ok := c.items[name]; ok { + return cached, nil + } + return nil, errors.New("key did not exist") +} + +func (c Cache) GetWithDefault(name string, callback CacheDefaultCallback) (net.IP, error) { + // Return cached entry. + if cached, ok := c.items[name]; ok { + return cached, nil + } + + ip, err := callback(name) + if err == nil { + c.Set(name, ip) + } + return ip, err +} + +func (c *Cache) Set(name string, ip net.IP) { + c.items[name] = ip +} diff --git a/ip/cache_test.go b/ip/cache_test.go index a9bd157..4bececa 100644 --- a/ip/cache_test.go +++ b/ip/cache_test.go @@ -1,76 +1,76 @@ -package ip - -import ( - "errors" - "net" - "reflect" - "testing" - - "github.com/stretchr/testify/assert" -) - -func defaultCallback(t *testing.T, expected_name string, ip net.IP, err error) CacheDefaultCallback { - return func(name string) (net.IP, error) { - assert.Equal(t, expected_name, name) - return ip, err - } -} - -func dontCallDefaultCallback(t *testing.T) CacheDefaultCallback { - return func(name string) (net.IP, error) { - t.Error("Should not have been called") - return nil, nil - } -} - -func TestCache_Get(t *testing.T) { - tests := []struct { - name string - c *Cache - iface string - want net.IP - wantErr bool - }{ - {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false}, - {"Did not exist in cache", &Cache{items: map[string]net.IP{}}, "eth0", nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.c.Get(tt.iface) - if (err != nil) != tt.wantErr { - t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Cache.Get() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestCache_GetWithDefault(t *testing.T) { - tests := []struct { - name string - c *Cache - def CacheDefaultCallback - iface string - want net.IP - wantErr bool - }{ - {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, dontCallDefaultCallback(t), "eth0", net.IPv4(10, 4, 0, 1), false}, - {"Did not exists in cache", NewCache(), defaultCallback(t, "eth1", net.IPv4(192, 172, 44, 25), nil), "eth1", net.IPv4(192, 172, 44, 25), false}, - {"Callback returns error", NewCache(), defaultCallback(t, "eth1", nil, errors.New("some error")), "eth1", nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.c.GetWithDefault(tt.iface, tt.def) - if (err != nil) != tt.wantErr { - t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Cache.Get() = %v, want %v", got, tt.want) - } - }) - } -} +package ip + +import ( + "errors" + "net" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func defaultCallback(t *testing.T, expected_name string, ip net.IP, err error) CacheDefaultCallback { + return func(name string) (net.IP, error) { + assert.Equal(t, expected_name, name) + return ip, err + } +} + +func dontCallDefaultCallback(t *testing.T) CacheDefaultCallback { + return func(name string) (net.IP, error) { + t.Error("Should not have been called") + return nil, nil + } +} + +func TestCache_Get(t *testing.T) { + tests := []struct { + name string + c *Cache + iface string + want net.IP + wantErr bool + }{ + {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false}, + {"Did not exist in cache", &Cache{items: map[string]net.IP{}}, "eth0", nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.c.Get(tt.iface) + if (err != nil) != tt.wantErr { + t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Cache.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCache_GetWithDefault(t *testing.T) { + tests := []struct { + name string + c *Cache + def CacheDefaultCallback + iface string + want net.IP + wantErr bool + }{ + {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, dontCallDefaultCallback(t), "eth0", net.IPv4(10, 4, 0, 1), false}, + {"Did not exists in cache", NewCache(), defaultCallback(t, "eth1", net.IPv4(192, 172, 44, 25), nil), "eth1", net.IPv4(192, 172, 44, 25), false}, + {"Callback returns error", NewCache(), defaultCallback(t, "eth1", nil, errors.New("some error")), "eth1", nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.c.GetWithDefault(tt.iface, tt.def) + if (err != nil) != tt.wantErr { + t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Cache.Get() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ip/interface.go b/ip/interface.go index d12332a..dd32b17 100644 --- a/ip/interface.go +++ b/ip/interface.go @@ -1,46 +1,46 @@ -package ip - -import ( - "errors" - "net" -) - -func GetInterfaceIP(iface_name string) (net.IP, error) { - ip := net.IP{} - iface, err := net.InterfaceByName(iface_name) - if err != nil { - return ip, err - } - - addrs, err := iface.Addrs() - if err != nil { - return ip, err - } - - return GetPublicIp(addrs) -} - -func GetPublicIp(list []net.Addr) (net.IP, error) { - for _, addr := range list { - ip, err := AddrToIP(addr) - if err == nil && !ip.IsPrivate() { - return ip, nil - } - } - - return nil, errors.New("no public ip found on interface") -} - -func AddrToIP(addr net.Addr) (net.IP, error) { - switch v := addr.(type) { - case *net.IPNet: - return v.IP, nil - case *net.IPAddr: - return v.IP, nil - case *net.UDPAddr: - return v.IP, nil - case *net.TCPAddr: - return v.IP, nil - } - return nil, errors.New("could not find ip") -} +package ip + +import ( + "errors" + "net" +) + +func GetInterfaceIP(iface_name string) (net.IP, error) { + ip := net.IP{} + iface, err := net.InterfaceByName(iface_name) + if err != nil { + return ip, err + } + + addrs, err := iface.Addrs() + if err != nil { + return ip, err + } + + return GetPublicIp(addrs) +} + +func GetPublicIp(list []net.Addr) (net.IP, error) { + for _, addr := range list { + ip, err := AddrToIP(addr) + if err == nil && !ip.IsPrivate() { + return ip, nil + } + } + + return nil, errors.New("no public ip found on interface") +} + +func AddrToIP(addr net.Addr) (net.IP, error) { + switch v := addr.(type) { + case *net.IPNet: + return v.IP, nil + case *net.IPAddr: + return v.IP, nil + case *net.UDPAddr: + return v.IP, nil + case *net.TCPAddr: + return v.IP, nil + } + return nil, errors.New("could not find ip") +} diff --git a/ip/interface_test.go b/ip/interface_test.go index 771656b..436eb96 100644 --- a/ip/interface_test.go +++ b/ip/interface_test.go @@ -1,69 +1,69 @@ -package ip - -import ( - "net" - "reflect" - "testing" -) - -func TestGetPublicIp(t *testing.T) { - tests := []struct { - name string - list []string - want string - wantErr bool - }{ - {"empty", []string{}, "", true}, - {"find", []string{"99.140.96.132"}, "99.140.96.132", false}, - {"findfirst", []string{"23.114.115.197", "251.78.128.148"}, "23.114.115.197", false}, - {"dontfindprivate", []string{"192.168.0.22", "88.12.32.44"}, "88.12.32.44", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - list := []net.Addr{} - for _, item := range tt.list { - list = append(list, &net.IPAddr{IP: net.ParseIP(item)}) - } - - want := net.ParseIP(tt.want) - - got, err := GetPublicIp(list) - if (err != nil) != tt.wantErr { - t.Errorf("GetPublicIp() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, want) { - t.Errorf("GetPublicIp() = %v, want %v", got, want) - } - }) - } -} - -func TestAddrToIP(t *testing.T) { - tests := []struct { - name string - addr net.Addr - want net.IP - wantErr bool - }{ - {"IPNet", &net.IPNet{IP: net.IPv4(177, 171, 44, 1)}, net.IPv4(177, 171, 44, 1), false}, - {"IPAddr", &net.IPAddr{IP: net.IPv4(240, 23, 119, 171)}, net.IPv4(240, 23, 119, 171), false}, - {"TCPAddr", &net.TCPAddr{IP: net.IPv4(139, 231, 35, 221)}, net.IPv4(139, 231, 35, 221), false}, - {"UDPAddr", &net.UDPAddr{IP: net.IPv4(167, 147, 140, 119)}, net.IPv4(167, 147, 140, 119), false}, - {"UnixAddr", &net.UnixAddr{}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := AddrToIP(tt.addr) - - if (err != nil) != tt.wantErr { - t.Errorf("AddrToIP() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddrToIP() = %v, want %v", got, tt.want) - } - }) - } -} +package ip + +import ( + "net" + "reflect" + "testing" +) + +func TestGetPublicIp(t *testing.T) { + tests := []struct { + name string + list []string + want string + wantErr bool + }{ + {"empty", []string{}, "", true}, + {"find", []string{"99.140.96.132"}, "99.140.96.132", false}, + {"findfirst", []string{"23.114.115.197", "251.78.128.148"}, "23.114.115.197", false}, + {"dontfindprivate", []string{"192.168.0.22", "88.12.32.44"}, "88.12.32.44", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := []net.Addr{} + for _, item := range tt.list { + list = append(list, &net.IPAddr{IP: net.ParseIP(item)}) + } + + want := net.ParseIP(tt.want) + + got, err := GetPublicIp(list) + if (err != nil) != tt.wantErr { + t.Errorf("GetPublicIp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, want) { + t.Errorf("GetPublicIp() = %v, want %v", got, want) + } + }) + } +} + +func TestAddrToIP(t *testing.T) { + tests := []struct { + name string + addr net.Addr + want net.IP + wantErr bool + }{ + {"IPNet", &net.IPNet{IP: net.IPv4(177, 171, 44, 1)}, net.IPv4(177, 171, 44, 1), false}, + {"IPAddr", &net.IPAddr{IP: net.IPv4(240, 23, 119, 171)}, net.IPv4(240, 23, 119, 171), false}, + {"TCPAddr", &net.TCPAddr{IP: net.IPv4(139, 231, 35, 221)}, net.IPv4(139, 231, 35, 221), false}, + {"UDPAddr", &net.UDPAddr{IP: net.IPv4(167, 147, 140, 119)}, net.IPv4(167, 147, 140, 119), false}, + {"UnixAddr", &net.UnixAddr{}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := AddrToIP(tt.addr) + + if (err != nil) != tt.wantErr { + t.Errorf("AddrToIP() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddrToIP() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ip/internal/ip.go b/ip/internal/ip.go index 1d7f0c5..45ab9be 100644 --- a/ip/internal/ip.go +++ b/ip/internal/ip.go @@ -1,15 +1,15 @@ -package internal - -import "net" - -func ParseIP(s string) (net.IP, error) { - var err error = nil - ip := net.ParseIP(s) - if ip == nil { - err = &net.ParseError{ - Type: "IP address", - Text: s, - } - } - return ip, err -} +package internal + +import "net" + +func ParseIP(s string) (net.IP, error) { + var err error = nil + ip := net.ParseIP(s) + if ip == nil { + err = &net.ParseError{ + Type: "IP address", + Text: s, + } + } + return ip, err +} diff --git a/ip/internal/ip_test.go b/ip/internal/ip_test.go index a65cb72..29d0c2e 100644 --- a/ip/internal/ip_test.go +++ b/ip/internal/ip_test.go @@ -1,35 +1,35 @@ -package internal - -import ( - "net" - "reflect" - "testing" -) - -func TestParseIP(t *testing.T) { - tests := []struct { - name string - input string - want net.IP - wantErr bool - }{ - {"localhost", "127.0.0.1", net.IPv4(127, 0, 0, 1), false}, - {"Private#1", "10.4.0.11", net.IPv4(10, 4, 0, 11), false}, - {"Private#2", "192.168.1.12", net.IPv4(192, 168, 1, 12), false}, - {"Public#1", "82.249.10.254", net.IPv4(82, 249, 10, 254), false}, - {"Public#2", "57.167.50.222", net.IPv4(57, 167, 50, 222), false}, - {"Invalid", "xx", nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ParseIP(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ParseIP() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ParseIP() = %v, want %v", got, tt.want) - } - }) - } -} +package internal + +import ( + "net" + "reflect" + "testing" +) + +func TestParseIP(t *testing.T) { + tests := []struct { + name string + input string + want net.IP + wantErr bool + }{ + {"localhost", "127.0.0.1", net.IPv4(127, 0, 0, 1), false}, + {"Private#1", "10.4.0.11", net.IPv4(10, 4, 0, 11), false}, + {"Private#2", "192.168.1.12", net.IPv4(192, 168, 1, 12), false}, + {"Public#1", "82.249.10.254", net.IPv4(82, 249, 10, 254), false}, + {"Public#2", "57.167.50.222", net.IPv4(57, 167, 50, 222), false}, + {"Invalid", "xx", nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseIP(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseIP() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseIP() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ip/resolver/decoder/jsonip.go b/ip/resolver/decoder/jsonip.go index 5b6f053..0513416 100644 --- a/ip/resolver/decoder/jsonip.go +++ b/ip/resolver/decoder/jsonip.go @@ -1,16 +1,16 @@ -package decoder - -import ( - "encoding/json" - "io" -) - -func Jsonip(r io.Reader) (string, error) { - var v struct { - Ip string `json:"ip"` - Location string `json:"geo-ip"` - Help string `json:"API Help"` - } - err := json.NewDecoder(r).Decode(&v) - return v.Ip, err -} +package decoder + +import ( + "encoding/json" + "io" +) + +func Jsonip(r io.Reader) (string, error) { + var v struct { + Ip string `json:"ip"` + Location string `json:"geo-ip"` + Help string `json:"API Help"` + } + err := json.NewDecoder(r).Decode(&v) + return v.Ip, err +} diff --git a/ip/resolver/http/service.go b/ip/resolver/http/service.go index cdb52bb..12b817a 100644 --- a/ip/resolver/http/service.go +++ b/ip/resolver/http/service.go @@ -1,42 +1,42 @@ -package http - -import ( - "context" - "io" - "net" - "net/http" - - httputils "dnsupdater/http" - "dnsupdater/ip/internal" - "dnsupdater/ip/resolver/decoder" -) - -type Decoder func(io.Reader) (string, error) - -type Service struct { - ServiceName string - Url string - Headers http.Header - Decoder Decoder -} - -func (s Service) Name() string { - return s.ServiceName -} - -func (s Service) Lookup(ctx context.Context) (net.IP, error) { - resp, err := httputils.Get(ctx, s.Url, s.Headers) - if err != nil { - return nil, err - } - - if s.Decoder == nil { - s.Decoder = decoder.Text - } - - body, err := s.Decoder(resp.Body) - if err != nil && err != io.EOF { - return nil, err - } - return internal.ParseIP(string(body)) -} +package http + +import ( + "context" + "io" + "net" + "net/http" + + httputils "dnsupdater/http" + "dnsupdater/ip/internal" + "dnsupdater/ip/resolver/decoder" +) + +type Decoder func(io.Reader) (string, error) + +type Service struct { + ServiceName string + Url string + Headers http.Header + Decoder Decoder +} + +func (s Service) Name() string { + return s.ServiceName +} + +func (s Service) Lookup(ctx context.Context) (net.IP, error) { + resp, err := httputils.Get(ctx, s.Url, s.Headers) + if err != nil { + return nil, err + } + + if s.Decoder == nil { + s.Decoder = decoder.Text + } + + body, err := s.Decoder(resp.Body) + if err != nil && err != io.EOF { + return nil, err + } + return internal.ParseIP(string(body)) +} diff --git a/ip/resolver/http/service_test.go b/ip/resolver/http/service_test.go index d09ce61..3fd98c7 100644 --- a/ip/resolver/http/service_test.go +++ b/ip/resolver/http/service_test.go @@ -1,85 +1,85 @@ -package http - -import ( - "context" - "net" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestService_Name(t *testing.T) { - s := Service{ServiceName: "my_service"} - - assert.Equal(t, "my_service", s.Name()) -} - -func TestService_Lookup(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("255.240.85.2")) - assert.NoError(t, err) - })) - defer server.Close() - - s := Service{Url: server.URL} - - ip, err := s.Lookup(context.Background()) - assert.NoError(t, err) - - assert.Equal(t, net.IPv4(255, 240, 85, 2), ip) -} - -func TestService_Lookup_WithHeaders(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - - _, err := w.Write([]byte("125.74.233.13")) - assert.NoError(t, err) - })) - defer server.Close() - - s := Service{ - Url: server.URL, - Headers: http.Header{ - "Content-Type": []string{"application/json"}, - }, - } - - ip, err := s.Lookup(context.Background()) - assert.NoError(t, err) - - assert.Equal(t, net.IPv4(125, 74, 233, 13), ip) -} - -func TestService_Lookup_HTTPError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - })) - defer server.Close() - - s := Service{ - Url: server.URL, - } - - ip, err := s.Lookup(context.Background()) - assert.EqualError(t, err, "HTTP Response: 404 Not Found") - assert.Nil(t, ip) -} - -func TestService_Lookup_ParseError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("random_string")) - assert.NoError(t, err) - })) - defer server.Close() - - s := Service{ - Url: server.URL, - } - - ip, err := s.Lookup(context.Background()) - assert.EqualError(t, err, "invalid IP address: random_string") - assert.Nil(t, ip) -} +package http + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestService_Name(t *testing.T) { + s := Service{ServiceName: "my_service"} + + assert.Equal(t, "my_service", s.Name()) +} + +func TestService_Lookup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("255.240.85.2")) + assert.NoError(t, err) + })) + defer server.Close() + + s := Service{Url: server.URL} + + ip, err := s.Lookup(context.Background()) + assert.NoError(t, err) + + assert.Equal(t, net.IPv4(255, 240, 85, 2), ip) +} + +func TestService_Lookup_WithHeaders(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + _, err := w.Write([]byte("125.74.233.13")) + assert.NoError(t, err) + })) + defer server.Close() + + s := Service{ + Url: server.URL, + Headers: http.Header{ + "Content-Type": []string{"application/json"}, + }, + } + + ip, err := s.Lookup(context.Background()) + assert.NoError(t, err) + + assert.Equal(t, net.IPv4(125, 74, 233, 13), ip) +} + +func TestService_Lookup_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + })) + defer server.Close() + + s := Service{ + Url: server.URL, + } + + ip, err := s.Lookup(context.Background()) + assert.EqualError(t, err, "HTTP Response: 404 Not Found") + assert.Nil(t, ip) +} + +func TestService_Lookup_ParseError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("random_string")) + assert.NoError(t, err) + })) + defer server.Close() + + s := Service{ + Url: server.URL, + } + + ip, err := s.Lookup(context.Background()) + assert.EqualError(t, err, "invalid IP address: random_string") + assert.Nil(t, ip) +} diff --git a/ip/resolver/manager.go b/ip/resolver/manager.go index 6ab2217..b4fcbd4 100644 --- a/ip/resolver/manager.go +++ b/ip/resolver/manager.go @@ -1,70 +1,70 @@ -package resolver - -import ( - "net/http" - - "dnsupdater/ip/resolver/decoder" - httpres "dnsupdater/ip/resolver/http" -) - -var services []Service - -func Provide(service Service) { - services = append(services, service) -} - -func Get(name string) Service { - for _, service := range services { - if service.Name() == name { - return service - } - } - return nil -} - -func init() { - Provide(&httpres.Service{ - ServiceName: "jsonip", - Url: "https://jsonip.com", - Decoder: decoder.Jsonip, - }) - - Provide(&httpres.Service{ - ServiceName: "ifconfig.me", - Url: "https://ifconfig.me/ip", - }) - - Provide(&httpres.Service{ - ServiceName: "ip.me", - Url: "https://ip.me", - Headers: http.Header{ - "User-Agent": []string{"curl"}, - }, - }) - - Provide(&httpres.Service{ - ServiceName: "ipecho", - Url: "http://ipecho.net/plain", - }) - - Provide(&httpres.Service{ - ServiceName: "icanhazip", - Url: "https://icanhazip.com", - }) - - Provide(&httpres.Service{ - ServiceName: "ipify", - Url: "https://api.ipify.org", - }) - - Provide(&httpres.Service{ - ServiceName: "myip", - Url: "https://api.myip.com", - Decoder: decoder.MyIP, - }) - - Provide(&httpres.Service{ - ServiceName: "my-ip", - Url: "https://api.my-ip.io/v2/ip.txt", - }) -} +package resolver + +import ( + "net/http" + + "dnsupdater/ip/resolver/decoder" + httpres "dnsupdater/ip/resolver/http" +) + +var services []Service + +func Provide(service Service) { + services = append(services, service) +} + +func Get(name string) Service { + for _, service := range services { + if service.Name() == name { + return service + } + } + return nil +} + +func init() { + Provide(&httpres.Service{ + ServiceName: "jsonip", + Url: "https://jsonip.com", + Decoder: decoder.Jsonip, + }) + + Provide(&httpres.Service{ + ServiceName: "ifconfig.me", + Url: "https://ifconfig.me/ip", + }) + + Provide(&httpres.Service{ + ServiceName: "ip.me", + Url: "https://ip.me", + Headers: http.Header{ + "User-Agent": []string{"curl"}, + }, + }) + + Provide(&httpres.Service{ + ServiceName: "ipecho", + Url: "http://ipecho.net/plain", + }) + + Provide(&httpres.Service{ + ServiceName: "icanhazip", + Url: "https://icanhazip.com", + }) + + Provide(&httpres.Service{ + ServiceName: "ipify", + Url: "https://api.ipify.org", + }) + + Provide(&httpres.Service{ + ServiceName: "myip", + Url: "https://api.myip.com", + Decoder: decoder.MyIP, + }) + + Provide(&httpres.Service{ + ServiceName: "my-ip", + Url: "https://api.my-ip.io/v2/ip.txt", + }) +} diff --git a/ip/resolver/mock/service.go b/ip/resolver/mock/service.go index 464336e..fe72248 100644 --- a/ip/resolver/mock/service.go +++ b/ip/resolver/mock/service.go @@ -1,19 +1,19 @@ -package mock - -import ( - "context" - "net" -) - -type Service struct { - IP net.IP - Error error -} - -func (s Service) Name() string { - return "mock" -} - -func (s Service) Lookup(ctx context.Context) (net.IP, error) { - return s.IP, s.Error -} +package mock + +import ( + "context" + "net" +) + +type Service struct { + IP net.IP + Error error +} + +func (s Service) Name() string { + return "mock" +} + +func (s Service) Lookup(ctx context.Context) (net.IP, error) { + return s.IP, s.Error +} diff --git a/ip/resolver/resolver.go b/ip/resolver/resolver.go index c01d167..094418f 100644 --- a/ip/resolver/resolver.go +++ b/ip/resolver/resolver.go @@ -1,15 +1,15 @@ -package resolver - -import ( - "context" - "net" -) - -// Interface that IP Lookup Services must implement. -type Service interface { - // Get the name of the serivce - Name() string - - // Lookup the public ip. - Lookup(ctx context.Context) (net.IP, error) -} +package resolver + +import ( + "context" + "net" +) + +// Interface that IP Lookup Services must implement. +type Service interface { + // Get the name of the serivce + Name() string + + // Lookup the public ip. + Lookup(ctx context.Context) (net.IP, error) +}