diff --git a/ip/resolver/basic_http/service.go b/ip/resolver/basic_http/service.go index af03405..a3a3d07 100644 --- a/ip/resolver/basic_http/service.go +++ b/ip/resolver/basic_http/service.go @@ -1,6 +1,7 @@ package basic_http import ( + "context" "fmt" "io" "net" @@ -18,8 +19,8 @@ func (s Service) Name() string { return s.ServiceName } -func (s Service) Lookup() (net.IP, error) { - req, err := http.NewRequest("GET", s.Url, nil) +func (s Service) Lookup(ctx context.Context) (net.IP, error) { + req, err := http.NewRequestWithContext(ctx, "GET", s.Url, nil) if err != nil { return nil, err } diff --git a/ip/resolver/basic_http/service_test.go b/ip/resolver/basic_http/service_test.go index 5611e8f..d191f5d 100644 --- a/ip/resolver/basic_http/service_test.go +++ b/ip/resolver/basic_http/service_test.go @@ -1,6 +1,7 @@ package basic_http import ( + "context" "net" "net/http" "net/http/httptest" @@ -24,7 +25,7 @@ func TestService_Lookup(t *testing.T) { s := Service{Url: server.URL} - ip, err := s.Lookup() + ip, err := s.Lookup(context.Background()) assert.NoError(t, err) assert.Equal(t, net.IPv4(255, 240, 85, 2), ip) @@ -46,7 +47,7 @@ func TestService_Lookup_WithHeaders(t *testing.T) { }, } - ip, err := s.Lookup() + ip, err := s.Lookup(context.Background()) assert.NoError(t, err) assert.Equal(t, net.IPv4(125, 74, 233, 13), ip) @@ -62,7 +63,7 @@ func TestService_Lookup_HTTPError(t *testing.T) { Url: server.URL, } - ip, err := s.Lookup() + ip, err := s.Lookup(context.Background()) assert.EqualError(t, err, "HTTP Response: 404 Not Found") assert.Nil(t, ip) } @@ -78,7 +79,7 @@ func TestService_Lookup_ParseError(t *testing.T) { Url: server.URL, } - ip, err := s.Lookup() + ip, err := s.Lookup(context.Background()) assert.EqualError(t, err, "Failed to parse ip: random_string") assert.Nil(t, ip) } diff --git a/ip/resolver/jsonip/service.go b/ip/resolver/jsonip/service.go index 960ec01..6af1e16 100644 --- a/ip/resolver/jsonip/service.go +++ b/ip/resolver/jsonip/service.go @@ -1,6 +1,7 @@ package jsonip import ( + "context" "encoding/json" "net" "net/http" @@ -20,8 +21,13 @@ func (s Service) Name() string { return "jsonip" } -func (s Service) Lookup() (net.IP, error) { - resp, err := http.DefaultClient.Get(s.url) +func (s Service) Lookup(ctx context.Context) (net.IP, error) { + req, err := http.NewRequestWithContext(ctx, "GET", s.url, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/ip/resolver/jsonip/service_test.go b/ip/resolver/jsonip/service_test.go index 1087b3e..ed7e419 100644 --- a/ip/resolver/jsonip/service_test.go +++ b/ip/resolver/jsonip/service_test.go @@ -1,6 +1,7 @@ package jsonip import ( + "context" "net" "net/http" "net/http/httptest" @@ -24,8 +25,21 @@ func TestService_Lookup(t *testing.T) { s := Service{url: server.URL} - ip, err := s.Lookup() + ip, err := s.Lookup(context.Background()) assert.NoError(t, err) assert.Equal(t, net.IPv4(211, 46, 32, 214), ip) } + +func TestService_Lookup_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + })) + defer server.Close() + + s := Service{url: server.URL} + + ip, err := s.Lookup(context.Background()) + assert.EqualError(t, err, "HTTP Response: 404 Not Found") + assert.Nil(t, ip) +} diff --git a/ip/resolver/mock/service.go b/ip/resolver/mock/service.go index b27f0c6..fe72248 100644 --- a/ip/resolver/mock/service.go +++ b/ip/resolver/mock/service.go @@ -1,6 +1,9 @@ package mock -import "net" +import ( + "context" + "net" +) type Service struct { IP net.IP @@ -11,6 +14,6 @@ func (s Service) Name() string { return "mock" } -func (s Service) Lookup() (net.IP, error) { +func (s Service) Lookup(ctx context.Context) (net.IP, error) { return s.IP, s.Error } diff --git a/ip/resolver/service.go b/ip/resolver/service.go index 60951d4..4852363 100644 --- a/ip/resolver/service.go +++ b/ip/resolver/service.go @@ -1,7 +1,9 @@ package lookup import ( + "context" "net" + "time" "dnsupdater/ip" ) @@ -9,7 +11,7 @@ import ( type Service interface { Name() string - Lookup() (net.IP, error) + Lookup(ctx context.Context) (net.IP, error) } const WAN_IFACE = "wan" @@ -17,7 +19,9 @@ const WAN_IFACE = "wan" func LookupWrapper(service Service) ip.NetInterfaceIPResolver { return func(iface_name string) (net.IP, error) { if iface_name == WAN_IFACE { - return service.Lookup() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return service.Lookup(ctx) } return ip.GetInterfaceIP(iface_name) }