1
0
Fork 0

line ending fix

This commit is contained in:
Henrik Hautakoski 2025-10-12 23:52:30 +02:00
parent a0e4de3d19
commit 0c347312bd
26 changed files with 1053 additions and 1046 deletions

2
.gitignore vendored
View file

@ -1,2 +1,2 @@
config.yml config.yml
build/ build/

View file

@ -1,22 +1,22 @@
image: golang:1.19 image: golang:1.19
stages: stages:
- test - test
- build - build
unit-test: unit-test:
stage: test stage: test
script: script:
- go test -v ./... - go test -v ./...
compile: compile:
stage: build stage: build
script: script:
- mkdir -p build - mkdir -p build
- GOOS=linux GOARCH=amd64 go build -o build/dnsupdater-linux-amd64 cmd/dnsupdater/main.go - GOOS=linux GOARCH=amd64 go build -o build/dnsupdater-linux-amd64 cmd/dnsupdater/main.go
- GOOS=linux GOARCH=mips GOMIPS=softfloat go build -o build/dnsupdater-linux-mips cmd/dnsupdater/main.go - GOOS=linux GOARCH=mips GOMIPS=softfloat go build -o build/dnsupdater-linux-mips cmd/dnsupdater/main.go
- GOOS=linux GOARCH=mipsle GOMIPS=softfloat go build -o build/dnsupdater-linux-mipsle cmd/dnsupdater/main.go - GOOS=linux GOARCH=mipsle GOMIPS=softfloat go build -o build/dnsupdater-linux-mipsle cmd/dnsupdater/main.go
artifacts: artifacts:
paths: paths:
- build - build

View file

@ -1,11 +1,11 @@
GO=go GO=go
GOLDFLAGS=-v -s -w GOLDFLAGS=-v -s -w
GOBUILDFLAGS=-v -p $(shell nproc) -ldflags="$(GOLDFLAGS)" GOBUILDFLAGS=-v -p $(shell nproc) -ldflags="$(GOLDFLAGS)"
.PHONY: build test .PHONY: build test
build : build :
$(GO) build $(GOBUILDFLAGS) -o build/dnsupdater cmd/dnsupdater/main.go $(GO) build $(GOBUILDFLAGS) -o build/dnsupdater cmd/dnsupdater/main.go
test : test :
$(GO) test -v ./... $(GO) test -v ./...

View file

@ -1,61 +1,61 @@
package app package app
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net"
"time" "time"
"dnsupdater/provider/manager" "dnsupdater/provider/manager"
"dnsupdater/ip" "dnsupdater/ip"
"dnsupdater/ip/resolver" "dnsupdater/ip/resolver"
) )
// Constant name for the virtual WAN interface // Constant name for the virtual WAN interface
const WAN_IFACE = "wan" const WAN_IFACE = "wan"
type App struct { type App struct {
cache *ip.Cache cache *ip.Cache
cacheDefaultCallback ip.CacheDefaultCallback cacheDefaultCallback ip.CacheDefaultCallback
// Updater manager // Updater manager
ProviderManager *manager.Manager ProviderManager *manager.Manager
} }
func makeCacheCallback(service resolver.Service) ip.CacheDefaultCallback { func makeCacheCallback(service resolver.Service) ip.CacheDefaultCallback {
return func(name string) (net.IP, error) { return func(name string) (net.IP, error) {
if name == WAN_IFACE { if name == WAN_IFACE {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
return service.Lookup(ctx) return service.Lookup(ctx)
} }
return ip.GetInterfaceIP(name) return ip.GetInterfaceIP(name)
} }
} }
func NewApp(config *Config) (*App, error) { func NewApp(config *Config) (*App, error) {
providerMgr := manager.New() providerMgr := manager.New()
// providerMgr.Register("digitalocean", digitalocean.New(config.Services.DigitalOcean.Token)) // providerMgr.Register("digitalocean", digitalocean.New(config.Services.DigitalOcean.Token))
err := providerMgr.RegisterFromConfig(config.Providers) err := providerMgr.RegisterFromConfig(config.Providers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
service := resolver.Get(config.Services.IPLookup) service := resolver.Get(config.Services.IPLookup)
if service == nil { if service == nil {
return nil, fmt.Errorf("Failed to load lookup service: %s", config.Services.IPLookup) return nil, fmt.Errorf("Failed to load lookup service: %s", config.Services.IPLookup)
} }
return &App{ return &App{
ProviderManager: providerMgr, ProviderManager: providerMgr,
cache: ip.NewCache(), cache: ip.NewCache(),
cacheDefaultCallback: makeCacheCallback(service), cacheDefaultCallback: makeCacheCallback(service),
}, nil }, nil
} }
func (a App) GetIP(iface_name string) (net.IP, error) { func (a App) GetIP(iface_name string) (net.IP, error) {
return a.cache.GetWithDefault(iface_name, a.cacheDefaultCallback) return a.cache.GetWithDefault(iface_name, a.cacheDefaultCallback)
} }

View file

@ -1,47 +1,47 @@
package app package app
import ( import (
"os" "os"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type ( type (
DomainRecords map[string]string DomainRecords map[string]string
Domain map[string]DomainRecords Domain map[string]DomainRecords
) )
type DigitalOceanService struct { type DigitalOceanService struct {
Token string `yaml:"token"` Token string `yaml:"token"`
Domains map[string]DomainRecords `yaml:"domains"` Domains map[string]DomainRecords `yaml:"domains"`
} }
type Providers struct { type Providers struct {
Token string `yaml:"token"` Token string `yaml:"token"`
Domains map[string]DomainRecords `yaml:"domains"` Domains map[string]DomainRecords `yaml:"domains"`
} }
type Services struct { type Services struct {
IPLookup string `yaml:"IPLookup"` IPLookup string `yaml:"IPLookup"`
// DigitalOcean DigitalOceanService `yaml:"digitalocean"` // DigitalOcean DigitalOceanService `yaml:"digitalocean"`
} }
type Config struct { type Config struct {
Services Services `yaml:"services"` Services Services `yaml:"services"`
Providers map[string]map[string]interface{} Providers map[string]map[string]interface{}
Updates map[string]Domain Updates map[string]Domain
} }
func LoadConfig(filename string) (*Config, error) { func LoadConfig(filename string) (*Config, error) {
cfg := Config{ cfg := Config{
Services: Services{ Services: Services{
IPLookup: "ipecho", IPLookup: "ipecho",
}, },
} }
data, err := os.ReadFile(filename) data, err := os.ReadFile(filename)
if err == nil { if err == nil {
err = yaml.Unmarshal(data, &cfg) err = yaml.Unmarshal(data, &cfg)
} }
return &cfg, err return &cfg, err
} }

View file

@ -1,68 +1,68 @@
package main package main
import ( import (
"flag" "flag"
"os" "os"
"time" "time"
"dnsupdater/app" "dnsupdater/app"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func main() { func main() {
configFile := flag.String("config", "./config.yml", "configuration file") configFile := flag.String("config", "./config.yml", "configuration file")
flag.Parse() flag.Parse()
log.Logger = log.Output(zerolog.ConsoleWriter{ log.Logger = log.Output(zerolog.ConsoleWriter{
Out: os.Stderr, Out: os.Stderr,
TimeFormat: time.RFC3339, TimeFormat: time.RFC3339,
}) })
config, err := app.LoadConfig(*configFile) config, err := app.LoadConfig(*configFile)
if err != nil { if err != nil {
log.Fatal().Err(err).Str("file", *configFile).Msg("Failed to load config") log.Fatal().Err(err).Str("file", *configFile).Msg("Failed to load config")
} }
app, err := app.NewApp(config) app, err := app.NewApp(config)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to initialize application") log.Fatal().Err(err).Msg("Failed to initialize application")
} }
for service_name, domains := range config.Updates { for service_name, domains := range config.Updates {
log.Info().Str("service", service_name).Msg("Begin update for service") log.Info().Str("service", service_name).Msg("Begin update for service")
// Get service // Get service
service := app.ProviderManager.Get(service_name) service := app.ProviderManager.Get(service_name)
for domain, records := range domains { for domain, records := range domains {
for name, data := range records { for name, data := range records {
logger := log.With(). logger := log.With().
Str("service", service_name). Str("service", service_name).
Str("domain", domain). Str("domain", domain).
Str("record", name). Str("record", name).
Str("interface", data). Str("interface", data).
Logger() Logger()
ip, err := app.GetIP(data) ip, err := app.GetIP(data)
if err != nil { if err != nil {
logger.Error().Err(err).Msg("Failed to fetch ip") logger.Error().Err(err).Msg("Failed to fetch ip")
continue continue
} }
logger = logger.With().IPAddr("ip", ip).Logger() logger = logger.With().IPAddr("ip", ip).Logger()
err = service.Update(domain, name, ip) err = service.Update(domain, name, ip)
if err != nil { if err != nil {
logger.Error().Err(err).Msg("Failed to update record") logger.Error().Err(err).Msg("Failed to update record")
} else { } else {
logger.Info().Msg("Record updated") logger.Info().Msg("Record updated")
} }
} }
} }
} }
} }

View file

@ -1,18 +1,18 @@
services: services:
IPLookup: ipecho IPLookup: ipecho
providers: providers:
digitalocean: digitalocean:
token: xxxx token: xxxx
updates: updates:
digitalocean: digitalocean:
domain1.com: domain1.com:
www: wan www: wan
box: 10.140.14.2 box: 10.140.14.2
domain2.com: domain2.com:
www: wan www: wan
mail: wan mail: wan
static: 84.24.254.21 static: 84.24.254.21

7
go.mod
View file

@ -12,13 +12,14 @@ require (
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/net v0.9.0 // indirect golang.org/x/net v0.9.0 // indirect
golang.org/x/oauth2 v0.7.0 // indirect golang.org/x/oauth2 v0.7.0 // indirect
golang.org/x/sys v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect

16
go.sum
View file

@ -11,13 +11,17 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -40,8 +44,10 @@ golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=

View file

@ -1,27 +1,27 @@
package http package http
import ( import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
) )
// Perform a HTTP Get request. // Perform a HTTP Get request.
func Get(ctx context.Context, url string, headers http.Header) (*http.Response, error) { func Get(ctx context.Context, url string, headers http.Header) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header = headers req.Header = headers
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP Response: %s", resp.Status) return nil, fmt.Errorf("HTTP Response: %s", resp.Status)
} }
return resp, nil return resp, nil
} }

View file

@ -1,43 +1,43 @@
package ip package ip
import ( import (
"errors" "errors"
"net" "net"
) )
type CacheDefaultCallback func(name string) (net.IP, error) type CacheDefaultCallback func(name string) (net.IP, error)
type Cache struct { type Cache struct {
items map[string]net.IP items map[string]net.IP
} }
func NewCache() *Cache { func NewCache() *Cache {
return &Cache{ return &Cache{
items: make(map[string]net.IP), items: make(map[string]net.IP),
} }
} }
func (c Cache) Get(name string) (net.IP, error) { func (c Cache) Get(name string) (net.IP, error) {
// Return cached entry. // Return cached entry.
if cached, ok := c.items[name]; ok { if cached, ok := c.items[name]; ok {
return cached, nil return cached, nil
} }
return nil, errors.New("key did not exist") return nil, errors.New("key did not exist")
} }
func (c Cache) GetWithDefault(name string, callback CacheDefaultCallback) (net.IP, error) { func (c Cache) GetWithDefault(name string, callback CacheDefaultCallback) (net.IP, error) {
// Return cached entry. // Return cached entry.
if cached, ok := c.items[name]; ok { if cached, ok := c.items[name]; ok {
return cached, nil return cached, nil
} }
ip, err := callback(name) ip, err := callback(name)
if err == nil { if err == nil {
c.Set(name, ip) c.Set(name, ip)
} }
return ip, err return ip, err
} }
func (c *Cache) Set(name string, ip net.IP) { func (c *Cache) Set(name string, ip net.IP) {
c.items[name] = ip c.items[name] = ip
} }

View file

@ -1,76 +1,76 @@
package ip package ip
import ( import (
"errors" "errors"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func defaultCallback(t *testing.T, expected_name string, ip net.IP, err error) CacheDefaultCallback { func defaultCallback(t *testing.T, expected_name string, ip net.IP, err error) CacheDefaultCallback {
return func(name string) (net.IP, error) { return func(name string) (net.IP, error) {
assert.Equal(t, expected_name, name) assert.Equal(t, expected_name, name)
return ip, err return ip, err
} }
} }
func dontCallDefaultCallback(t *testing.T) CacheDefaultCallback { func dontCallDefaultCallback(t *testing.T) CacheDefaultCallback {
return func(name string) (net.IP, error) { return func(name string) (net.IP, error) {
t.Error("Should not have been called") t.Error("Should not have been called")
return nil, nil return nil, nil
} }
} }
func TestCache_Get(t *testing.T) { func TestCache_Get(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
c *Cache c *Cache
iface string iface string
want net.IP want net.IP
wantErr bool wantErr bool
}{ }{
{"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false}, {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, "eth0", net.IPv4(10, 4, 0, 1), false},
{"Did not exist in cache", &Cache{items: map[string]net.IP{}}, "eth0", nil, true}, {"Did not exist in cache", &Cache{items: map[string]net.IP{}}, "eth0", nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.c.Get(tt.iface) got, err := tt.c.Get(tt.iface)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Cache.Get() = %v, want %v", got, tt.want) t.Errorf("Cache.Get() = %v, want %v", got, tt.want)
} }
}) })
} }
} }
func TestCache_GetWithDefault(t *testing.T) { func TestCache_GetWithDefault(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
c *Cache c *Cache
def CacheDefaultCallback def CacheDefaultCallback
iface string iface string
want net.IP want net.IP
wantErr bool wantErr bool
}{ }{
{"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, dontCallDefaultCallback(t), "eth0", net.IPv4(10, 4, 0, 1), false}, {"Exists in cache", &Cache{items: map[string]net.IP{"eth0": net.IPv4(10, 4, 0, 1)}}, dontCallDefaultCallback(t), "eth0", net.IPv4(10, 4, 0, 1), false},
{"Did not exists in cache", NewCache(), defaultCallback(t, "eth1", net.IPv4(192, 172, 44, 25), nil), "eth1", net.IPv4(192, 172, 44, 25), false}, {"Did not exists in cache", NewCache(), defaultCallback(t, "eth1", net.IPv4(192, 172, 44, 25), nil), "eth1", net.IPv4(192, 172, 44, 25), false},
{"Callback returns error", NewCache(), defaultCallback(t, "eth1", nil, errors.New("some error")), "eth1", nil, true}, {"Callback returns error", NewCache(), defaultCallback(t, "eth1", nil, errors.New("some error")), "eth1", nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.c.GetWithDefault(tt.iface, tt.def) got, err := tt.c.GetWithDefault(tt.iface, tt.def)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Cache.Get() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Cache.Get() = %v, want %v", got, tt.want) t.Errorf("Cache.Get() = %v, want %v", got, tt.want)
} }
}) })
} }
} }

View file

@ -1,46 +1,46 @@
package ip package ip
import ( import (
"errors" "errors"
"net" "net"
) )
func GetInterfaceIP(iface_name string) (net.IP, error) { func GetInterfaceIP(iface_name string) (net.IP, error) {
ip := net.IP{} ip := net.IP{}
iface, err := net.InterfaceByName(iface_name) iface, err := net.InterfaceByName(iface_name)
if err != nil { if err != nil {
return ip, err return ip, err
} }
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
return ip, err return ip, err
} }
return GetPublicIp(addrs) return GetPublicIp(addrs)
} }
func GetPublicIp(list []net.Addr) (net.IP, error) { func GetPublicIp(list []net.Addr) (net.IP, error) {
for _, addr := range list { for _, addr := range list {
ip, err := AddrToIP(addr) ip, err := AddrToIP(addr)
if err == nil && !ip.IsPrivate() { if err == nil && !ip.IsPrivate() {
return ip, nil return ip, nil
} }
} }
return nil, errors.New("no public ip found on interface") return nil, errors.New("no public ip found on interface")
} }
func AddrToIP(addr net.Addr) (net.IP, error) { func AddrToIP(addr net.Addr) (net.IP, error) {
switch v := addr.(type) { switch v := addr.(type) {
case *net.IPNet: case *net.IPNet:
return v.IP, nil return v.IP, nil
case *net.IPAddr: case *net.IPAddr:
return v.IP, nil return v.IP, nil
case *net.UDPAddr: case *net.UDPAddr:
return v.IP, nil return v.IP, nil
case *net.TCPAddr: case *net.TCPAddr:
return v.IP, nil return v.IP, nil
} }
return nil, errors.New("could not find ip") return nil, errors.New("could not find ip")
} }

View file

@ -1,69 +1,69 @@
package ip package ip
import ( import (
"net" "net"
"reflect" "reflect"
"testing" "testing"
) )
func TestGetPublicIp(t *testing.T) { func TestGetPublicIp(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
list []string list []string
want string want string
wantErr bool wantErr bool
}{ }{
{"empty", []string{}, "", true}, {"empty", []string{}, "", true},
{"find", []string{"99.140.96.132"}, "99.140.96.132", false}, {"find", []string{"99.140.96.132"}, "99.140.96.132", false},
{"findfirst", []string{"23.114.115.197", "251.78.128.148"}, "23.114.115.197", false}, {"findfirst", []string{"23.114.115.197", "251.78.128.148"}, "23.114.115.197", false},
{"dontfindprivate", []string{"192.168.0.22", "88.12.32.44"}, "88.12.32.44", false}, {"dontfindprivate", []string{"192.168.0.22", "88.12.32.44"}, "88.12.32.44", false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
list := []net.Addr{} list := []net.Addr{}
for _, item := range tt.list { for _, item := range tt.list {
list = append(list, &net.IPAddr{IP: net.ParseIP(item)}) list = append(list, &net.IPAddr{IP: net.ParseIP(item)})
} }
want := net.ParseIP(tt.want) want := net.ParseIP(tt.want)
got, err := GetPublicIp(list) got, err := GetPublicIp(list)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GetPublicIp() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GetPublicIp() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("GetPublicIp() = %v, want %v", got, want) t.Errorf("GetPublicIp() = %v, want %v", got, want)
} }
}) })
} }
} }
func TestAddrToIP(t *testing.T) { func TestAddrToIP(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
addr net.Addr addr net.Addr
want net.IP want net.IP
wantErr bool wantErr bool
}{ }{
{"IPNet", &net.IPNet{IP: net.IPv4(177, 171, 44, 1)}, net.IPv4(177, 171, 44, 1), false}, {"IPNet", &net.IPNet{IP: net.IPv4(177, 171, 44, 1)}, net.IPv4(177, 171, 44, 1), false},
{"IPAddr", &net.IPAddr{IP: net.IPv4(240, 23, 119, 171)}, net.IPv4(240, 23, 119, 171), false}, {"IPAddr", &net.IPAddr{IP: net.IPv4(240, 23, 119, 171)}, net.IPv4(240, 23, 119, 171), false},
{"TCPAddr", &net.TCPAddr{IP: net.IPv4(139, 231, 35, 221)}, net.IPv4(139, 231, 35, 221), false}, {"TCPAddr", &net.TCPAddr{IP: net.IPv4(139, 231, 35, 221)}, net.IPv4(139, 231, 35, 221), false},
{"UDPAddr", &net.UDPAddr{IP: net.IPv4(167, 147, 140, 119)}, net.IPv4(167, 147, 140, 119), false}, {"UDPAddr", &net.UDPAddr{IP: net.IPv4(167, 147, 140, 119)}, net.IPv4(167, 147, 140, 119), false},
{"UnixAddr", &net.UnixAddr{}, nil, true}, {"UnixAddr", &net.UnixAddr{}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := AddrToIP(tt.addr) got, err := AddrToIP(tt.addr)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("AddrToIP() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AddrToIP() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AddrToIP() = %v, want %v", got, tt.want) t.Errorf("AddrToIP() = %v, want %v", got, tt.want)
} }
}) })
} }
} }

View file

@ -1,15 +1,15 @@
package internal package internal
import "net" import "net"
func ParseIP(s string) (net.IP, error) { func ParseIP(s string) (net.IP, error) {
var err error = nil var err error = nil
ip := net.ParseIP(s) ip := net.ParseIP(s)
if ip == nil { if ip == nil {
err = &net.ParseError{ err = &net.ParseError{
Type: "IP address", Type: "IP address",
Text: s, Text: s,
} }
} }
return ip, err return ip, err
} }

View file

@ -1,35 +1,35 @@
package internal package internal
import ( import (
"net" "net"
"reflect" "reflect"
"testing" "testing"
) )
func TestParseIP(t *testing.T) { func TestParseIP(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
input string input string
want net.IP want net.IP
wantErr bool wantErr bool
}{ }{
{"localhost", "127.0.0.1", net.IPv4(127, 0, 0, 1), false}, {"localhost", "127.0.0.1", net.IPv4(127, 0, 0, 1), false},
{"Private#1", "10.4.0.11", net.IPv4(10, 4, 0, 11), false}, {"Private#1", "10.4.0.11", net.IPv4(10, 4, 0, 11), false},
{"Private#2", "192.168.1.12", net.IPv4(192, 168, 1, 12), false}, {"Private#2", "192.168.1.12", net.IPv4(192, 168, 1, 12), false},
{"Public#1", "82.249.10.254", net.IPv4(82, 249, 10, 254), false}, {"Public#1", "82.249.10.254", net.IPv4(82, 249, 10, 254), false},
{"Public#2", "57.167.50.222", net.IPv4(57, 167, 50, 222), false}, {"Public#2", "57.167.50.222", net.IPv4(57, 167, 50, 222), false},
{"Invalid", "xx", nil, true}, {"Invalid", "xx", nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := ParseIP(tt.input) got, err := ParseIP(tt.input)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ParseIP() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ParseIP() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseIP() = %v, want %v", got, tt.want) t.Errorf("ParseIP() = %v, want %v", got, tt.want)
} }
}) })
} }
} }

View file

@ -1,46 +1,46 @@
package http package http
import ( import (
"context" "context"
"io" "io"
"net" "net"
"net/http" "net/http"
"strings" "strings"
httputils "dnsupdater/http" httputils "dnsupdater/http"
"dnsupdater/ip/internal" "dnsupdater/ip/internal"
) )
type Decoder func(io.Reader) ([]byte, error) type Decoder func(io.Reader) ([]byte, error)
type Service struct { type Service struct {
ServiceName string ServiceName string
Url string Url string
Headers http.Header Headers http.Header
Decoder Decoder Decoder Decoder
} }
func (s Service) Name() string { func (s Service) Name() string {
return s.ServiceName return s.ServiceName
} }
func (s Service) Lookup(ctx context.Context) (net.IP, error) { func (s Service) Lookup(ctx context.Context) (net.IP, error) {
resp, err := httputils.Get(ctx, s.Url, s.Headers) resp, err := httputils.Get(ctx, s.Url, s.Headers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.Decoder == nil { if s.Decoder == nil {
s.Decoder = io.ReadAll s.Decoder = io.ReadAll
} }
body, err := s.Decoder(resp.Body) body, err := s.Decoder(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Trim spaces and stuff. // Trim spaces and stuff.
ip_str := strings.TrimSpace(string(body)) ip_str := strings.TrimSpace(string(body))
return internal.ParseIP(ip_str) return internal.ParseIP(ip_str)
} }

View file

@ -1,85 +1,85 @@
package http package http
import ( import (
"context" "context"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestService_Name(t *testing.T) { func TestService_Name(t *testing.T) {
s := Service{ServiceName: "my_service"} s := Service{ServiceName: "my_service"}
assert.Equal(t, "my_service", s.Name()) assert.Equal(t, "my_service", s.Name())
} }
func TestService_Lookup(t *testing.T) { func TestService_Lookup(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("255.240.85.2")) _, err := w.Write([]byte("255.240.85.2"))
assert.NoError(t, err) assert.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
s := Service{Url: server.URL} s := Service{Url: server.URL}
ip, err := s.Lookup(context.Background()) ip, err := s.Lookup(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, net.IPv4(255, 240, 85, 2), ip) assert.Equal(t, net.IPv4(255, 240, 85, 2), ip)
} }
func TestService_Lookup_WithHeaders(t *testing.T) { func TestService_Lookup_WithHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
_, err := w.Write([]byte("125.74.233.13")) _, err := w.Write([]byte("125.74.233.13"))
assert.NoError(t, err) assert.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
s := Service{ s := Service{
Url: server.URL, Url: server.URL,
Headers: http.Header{ Headers: http.Header{
"Content-Type": []string{"application/json"}, "Content-Type": []string{"application/json"},
}, },
} }
ip, err := s.Lookup(context.Background()) ip, err := s.Lookup(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, net.IPv4(125, 74, 233, 13), ip) assert.Equal(t, net.IPv4(125, 74, 233, 13), ip)
} }
func TestService_Lookup_HTTPError(t *testing.T) { func TestService_Lookup_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
})) }))
defer server.Close() defer server.Close()
s := Service{ s := Service{
Url: server.URL, Url: server.URL,
} }
ip, err := s.Lookup(context.Background()) ip, err := s.Lookup(context.Background())
assert.EqualError(t, err, "HTTP Response: 404 Not Found") assert.EqualError(t, err, "HTTP Response: 404 Not Found")
assert.Nil(t, ip) assert.Nil(t, ip)
} }
func TestService_Lookup_ParseError(t *testing.T) { func TestService_Lookup_ParseError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("random_string")) _, err := w.Write([]byte("random_string"))
assert.NoError(t, err) assert.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
s := Service{ s := Service{
Url: server.URL, Url: server.URL,
} }
ip, err := s.Lookup(context.Background()) ip, err := s.Lookup(context.Background())
assert.EqualError(t, err, "invalid IP address: random_string") assert.EqualError(t, err, "invalid IP address: random_string")
assert.Nil(t, ip) assert.Nil(t, ip)
} }

View file

@ -1,20 +1,20 @@
package resolver package resolver
import ( import (
"encoding/json" "encoding/json"
"io" "io"
) )
func JsonipDecoder(r io.Reader) ([]byte, error) { func JsonipDecoder(r io.Reader) ([]byte, error) {
var v struct { var v struct {
Ip string `json:"ip"` Ip string `json:"ip"`
Location string `json:"geo-ip"` Location string `json:"geo-ip"`
Help string `json:"API Help"` Help string `json:"API Help"`
} }
var val []byte var val []byte
err := json.NewDecoder(r).Decode(&v) err := json.NewDecoder(r).Decode(&v)
if err == nil { if err == nil {
val = []byte(v.Ip) val = []byte(v.Ip)
} }
return val, err return val, err
} }

View file

@ -1,53 +1,53 @@
package resolver package resolver
import ( import (
"net/http" "net/http"
httpres "dnsupdater/ip/resolver/http" httpres "dnsupdater/ip/resolver/http"
) )
var services []Service var services []Service
func Provide(service Service) { func Provide(service Service) {
services = append(services, service) services = append(services, service)
} }
func Get(name string) Service { func Get(name string) Service {
for _, service := range services { for _, service := range services {
if service.Name() == name { if service.Name() == name {
return service return service
} }
} }
return nil return nil
} }
func init() { func init() {
Provide(&httpres.Service{ Provide(&httpres.Service{
ServiceName: "jsonip", ServiceName: "jsonip",
Url: "https://jsonip.com", Url: "https://jsonip.com",
Decoder: JsonipDecoder, Decoder: JsonipDecoder,
}) })
Provide(&httpres.Service{ Provide(&httpres.Service{
ServiceName: "ifconfig.me", ServiceName: "ifconfig.me",
Url: "https://ifconfig.me/ip", Url: "https://ifconfig.me/ip",
}) })
Provide(&httpres.Service{ Provide(&httpres.Service{
ServiceName: "ip.me", ServiceName: "ip.me",
Url: "https://ip.me", Url: "https://ip.me",
Headers: http.Header{ Headers: http.Header{
"User-Agent": []string{"curl"}, "User-Agent": []string{"curl"},
}, },
}) })
Provide(&httpres.Service{ Provide(&httpres.Service{
ServiceName: "ipecho", ServiceName: "ipecho",
Url: "http://ipecho.net/plain", Url: "http://ipecho.net/plain",
}) })
Provide(&httpres.Service{ Provide(&httpres.Service{
ServiceName: "icanhazip", ServiceName: "icanhazip",
Url: "https://icanhazip.com", Url: "https://icanhazip.com",
}) })
} }

View file

@ -1,19 +1,19 @@
package mock package mock
import ( import (
"context" "context"
"net" "net"
) )
type Service struct { type Service struct {
IP net.IP IP net.IP
Error error Error error
} }
func (s Service) Name() string { func (s Service) Name() string {
return "mock" return "mock"
} }
func (s Service) Lookup(ctx context.Context) (net.IP, error) { func (s Service) Lookup(ctx context.Context) (net.IP, error) {
return s.IP, s.Error return s.IP, s.Error
} }

View file

@ -1,15 +1,15 @@
package resolver package resolver
import ( import (
"context" "context"
"net" "net"
) )
// Interface that IP Lookup Services must implement. // Interface that IP Lookup Services must implement.
type Service interface { type Service interface {
// Get the name of the serivce // Get the name of the serivce
Name() string Name() string
// Lookup the public ip. // Lookup the public ip.
Lookup(ctx context.Context) (net.IP, error) Lookup(ctx context.Context) (net.IP, error)
} }

View file

@ -1,105 +1,105 @@
package digitalocean package digitalocean
import ( import (
"context" "context"
"errors" "errors"
"testing" "testing"
"github.com/digitalocean/godo" "github.com/digitalocean/godo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type mock struct { type mock struct {
t *testing.T t *testing.T
records_by_type map[string][]godo.DomainRecord records_by_type map[string][]godo.DomainRecord
edit_record_request *godo.DomainRecordEditRequest edit_record_request *godo.DomainRecordEditRequest
edit_record_error error edit_record_error error
} }
func (m mock) List(context.Context, *godo.ListOptions) ([]godo.Domain, *godo.Response, error) { func (m mock) List(context.Context, *godo.ListOptions) ([]godo.Domain, *godo.Response, error) {
m.t.Error("List called when it should not have been") m.t.Error("List called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) Get(context.Context, string) (*godo.Domain, *godo.Response, error) { func (m mock) Get(context.Context, string) (*godo.Domain, *godo.Response, error) {
m.t.Error("Get called when it should not have been") m.t.Error("Get called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) Create(context.Context, *godo.DomainCreateRequest) (*godo.Domain, *godo.Response, error) { func (m mock) Create(context.Context, *godo.DomainCreateRequest) (*godo.Domain, *godo.Response, error) {
m.t.Error("Create called when it should not have been") m.t.Error("Create called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) Delete(context.Context, string) (*godo.Response, error) { func (m mock) Delete(context.Context, string) (*godo.Response, error) {
m.t.Error("Delete called when it should not have been") m.t.Error("Delete called when it should not have been")
return nil, nil return nil, nil
} }
func (m mock) Records(context.Context, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { func (m mock) Records(context.Context, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) {
m.t.Error("Records called when it should not have been") m.t.Error("Records called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) RecordsByType(_ context.Context, name string, t string, opt *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { func (m mock) RecordsByType(_ context.Context, name string, t string, opt *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) {
var err error var err error
// Only care about "A" records // Only care about "A" records
assert.Equal(m.t, "A", t) assert.Equal(m.t, "A", t)
r, ok := m.records_by_type[name] r, ok := m.records_by_type[name]
if !ok { if !ok {
err = errors.New("Record not found") err = errors.New("Record not found")
} }
return r, nil, err return r, nil, err
} }
func (m mock) RecordsByName(context.Context, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { func (m mock) RecordsByName(context.Context, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) {
m.t.Error("RecordsByName called when it should not have been") m.t.Error("RecordsByName called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) RecordsByTypeAndName(context.Context, string, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) { func (m mock) RecordsByTypeAndName(context.Context, string, string, string, *godo.ListOptions) ([]godo.DomainRecord, *godo.Response, error) {
m.t.Error("RecordsByTypeAndName called when it should not have been") m.t.Error("RecordsByTypeAndName called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) Record(context.Context, string, int) (*godo.DomainRecord, *godo.Response, error) { func (m mock) Record(context.Context, string, int) (*godo.DomainRecord, *godo.Response, error) {
m.t.Error("Record called when it should not have been") m.t.Error("Record called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }
func (m mock) DeleteRecord(context.Context, string, int) (*godo.Response, error) { func (m mock) DeleteRecord(context.Context, string, int) (*godo.Response, error) {
m.t.Error("DeleteRecord called when it should not have been") m.t.Error("DeleteRecord called when it should not have been")
return nil, nil return nil, nil
} }
func (m mock) EditRecord(_ context.Context, domain string, id int, req *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { func (m mock) EditRecord(_ context.Context, domain string, id int, req *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) {
if m.edit_record_request == nil { if m.edit_record_request == nil {
m.t.Error("EditRecord called with empty request") m.t.Error("EditRecord called with empty request")
} }
assert.Equal(m.t, m.edit_record_request, req) assert.Equal(m.t, m.edit_record_request, req)
record := godo.DomainRecord{ record := godo.DomainRecord{
ID: id, ID: id,
Type: req.Type, Type: req.Type,
Name: req.Name, Name: req.Name,
Data: req.Data, Data: req.Data,
Priority: req.Priority, Priority: req.Priority,
Port: req.Port, Port: req.Port,
TTL: req.TTL, TTL: req.TTL,
Weight: req.Weight, Weight: req.Weight,
Flags: req.Flags, Flags: req.Flags,
Tag: req.Tag, Tag: req.Tag,
} }
return &record, nil, m.edit_record_error return &record, nil, m.edit_record_error
} }
func (m mock) CreateRecord(context.Context, string, *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) { func (m mock) CreateRecord(context.Context, string, *godo.DomainRecordEditRequest) (*godo.DomainRecord, *godo.Response, error) {
m.t.Error("CreateRecord called when it should not have been") m.t.Error("CreateRecord called when it should not have been")
return nil, nil, nil return nil, nil, nil
} }

View file

@ -1,98 +1,98 @@
package digitalocean package digitalocean
import ( import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"dnsupdater/provider" "dnsupdater/provider"
"github.com/digitalocean/godo" "github.com/digitalocean/godo"
) )
type Provider struct { type Provider struct {
service godo.DomainsService service godo.DomainsService
cache map[string][]godo.DomainRecord cache map[string][]godo.DomainRecord
} }
func New(token string) Provider { func New(token string) Provider {
return Provider{ return Provider{
service: godo.NewFromToken(token).Domains, service: godo.NewFromToken(token).Domains,
cache: make(map[string][]godo.DomainRecord), cache: make(map[string][]godo.DomainRecord),
} }
} }
func Factory(args map[string]interface{}) (provider.Provider, error) { func Factory(args map[string]interface{}) (provider.Provider, error) {
t, ok := args["token"] t, ok := args["token"]
if !ok { if !ok {
return nil, errors.New("did not find token") return nil, errors.New("did not find token")
} }
token, ok := t.(string) token, ok := t.(string)
if !ok { if !ok {
return nil, errors.New("token must be a string") return nil, errors.New("token must be a string")
} }
return New(token), nil return New(token), nil
} }
func (d *Provider) fetch(domain string) ([]godo.DomainRecord, error) { func (d *Provider) fetch(domain string) ([]godo.DomainRecord, error) {
domains, ok := d.cache[domain] domains, ok := d.cache[domain]
if !ok { if !ok {
var err error var err error
options := &godo.ListOptions{ options := &godo.ListOptions{
PerPage: 50, PerPage: 50,
} }
domains, _, err = d.service.RecordsByType(context.Background(), domain, "A", options) domains, _, err = d.service.RecordsByType(context.Background(), domain, "A", options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.cache[domain] = domains d.cache[domain] = domains
} }
return domains, nil return domains, nil
} }
func (d *Provider) find(domain string, record string) (*godo.DomainRecord, error) { func (d *Provider) find(domain string, record string) (*godo.DomainRecord, error) {
records, err := d.fetch(domain) records, err := d.fetch(domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, r := range records { for _, r := range records {
if r.Name == record { if r.Name == record {
return &r, nil return &r, nil
} }
} }
return nil, fmt.Errorf("could not find record %s", record) return nil, fmt.Errorf("could not find record %s", record)
} }
func (d Provider) Update(domain string, record string, ip net.IP) error { func (d Provider) Update(domain string, record string, ip net.IP) error {
r, err := d.find(domain, record) r, err := d.find(domain, record)
if err != nil { if err != nil {
return err return err
} }
if r.Data != ip.String() { if r.Data != ip.String() {
// Update // Update
req := godo.DomainRecordEditRequest{ req := godo.DomainRecordEditRequest{
// Type: r.Type, // Type: r.Type,
// Name: r.Name, // Name: r.Name,
Data: ip.String(), Data: ip.String(),
// Priority: r.Priority, // Priority: r.Priority,
// Port: r.Port, // Port: r.Port,
// TTL: r.TTL, // TTL: r.TTL,
// Weight: r.Weight, // Weight: r.Weight,
// Flags: r.Flags, // Flags: r.Flags,
// Tag: r.Tag, // Tag: r.Tag,
} }
_, _, err := d.service.EditRecord(context.Background(), domain, r.ID, &req) _, _, err := d.service.EditRecord(context.Background(), domain, r.ID, &req)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }

View file

@ -1,11 +1,11 @@
package provider package provider
import ( import (
"net" "net"
) )
type Provider interface { type Provider interface {
Update(domain string, record string, ip net.IP) error Update(domain string, record string, ip net.IP) error
} }
type ProviderFactory func(map[string]interface{}) (Provider, error) type ProviderFactory func(map[string]interface{}) (Provider, error)

View file

@ -1,48 +1,48 @@
package manager package manager
import ( import (
"fmt" "fmt"
"dnsupdater/provider" "dnsupdater/provider"
"dnsupdater/provider/digitalocean" "dnsupdater/provider/digitalocean"
) )
var factories = map[string]provider.ProviderFactory{ var factories = map[string]provider.ProviderFactory{
"digitalocean": digitalocean.Factory, "digitalocean": digitalocean.Factory,
} }
type Manager struct { type Manager struct {
services map[string]provider.Provider services map[string]provider.Provider
} }
func New() *Manager { func New() *Manager {
return &Manager{ return &Manager{
services: make(map[string]provider.Provider), services: make(map[string]provider.Provider),
} }
} }
func (m Manager) Get(name string) provider.Provider { func (m Manager) Get(name string) provider.Provider {
if service, ok := m.services[name]; ok { if service, ok := m.services[name]; ok {
return service return service
} }
return nil return nil
} }
func (m Manager) RegisterFromConfig(providers map[string]map[string]interface{}) error { func (m Manager) RegisterFromConfig(providers map[string]map[string]interface{}) error {
for name, args := range providers { for name, args := range providers {
if factory, ok := factories[name]; ok { if factory, ok := factories[name]; ok {
provider, err := factory(args) provider, err := factory(args)
if err != nil { if err != nil {
return fmt.Errorf("could not create provider '%s': %v", name, err) return fmt.Errorf("could not create provider '%s': %v", name, err)
} }
m.Register(name, provider) m.Register(name, provider)
} }
} }
return nil return nil
} }
func (m Manager) Register(name string, provider provider.Provider) { func (m Manager) Register(name string, provider provider.Provider) {
m.services[name] = provider m.services[name] = provider
} }