diff --git a/pkg/client/client.go b/pkg/client/client.go index 7568648..b276759 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -10,13 +10,6 @@ import ( "net/url" ) -// HTTPClient interface holds the required Post method -// to send FleetLock requests. -type HTTPClient interface { - // Do send a `body` payload to the URL. - Do(*http.Request) (*http.Response, error) -} - // Payload is the content to send // to the FleetLock server. type Payload struct { @@ -40,7 +33,7 @@ type Client struct { URL string group string id string - http HTTPClient + http *http.Client } func (c *Client) generateRequest(endpoint string) (*http.Request, error) { @@ -129,7 +122,7 @@ func (c *Client) UnlockIfHeld() error { } // New builds a Fleet-Lock client. -func New(URL, group, ID string, c HTTPClient) (*Client, error) { +func New(URL, group, ID string, c *http.Client) (*Client, error) { if _, err := url.ParseRequestURI(URL); err != nil { return nil, fmt.Errorf("parsing URL: %w", err) } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index ae48200..48787e1 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -11,15 +11,15 @@ import ( "github.com/flatcar-linux/fleetlock/pkg/client" ) -type httpClient struct { +type mockRoundTripper struct { resp *http.Response r *http.Request doErr error } -func (h *httpClient) Do(req *http.Request) (*http.Response, error) { - h.r = req - return h.resp, h.doErr +func (r *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + r.r = req + return r.resp, r.doErr } func TestBadURL(t *testing.T) { @@ -61,11 +61,12 @@ func TestRecursiveLock(t *testing.T) { expErr: errors.New("unexpected status code: 100"), }, { - expErr: errors.New("doing the request: connection refused"), + expErr: errors.New("doing the request: Post \"http://1.2.3.4/v1/pre-reboot\": connection refused"), doErr: errors.New("connection refused"), }, } { - h := &httpClient{ + h := http.DefaultClient + tr := &mockRoundTripper{ resp: &http.Response{ StatusCode: test.statusCode, Body: test.body, @@ -73,6 +74,8 @@ func TestRecursiveLock(t *testing.T) { doErr: test.doErr, } + h.Transport = tr + c, _ := client.New("http://1.2.3.4", "default", "1234", h) err := c.RecursiveLock() @@ -80,8 +83,8 @@ func TestRecursiveLock(t *testing.T) { t.Fatalf("should have %v for err, got: %v", test.expErr, err) } - if h.r.URL.String() != expURL { - t.Fatalf("should have %s for URL, got: %s", expURL, h.r.URL.String()) + if tr.r.URL.String() != expURL { + t.Fatalf("should have %s for URL, got: %s", expURL, tr.r.URL.String()) } } } @@ -114,11 +117,12 @@ func TestUnlockIfHeld(t *testing.T) { expErr: errors.New("unexpected status code: 100"), }, { - expErr: errors.New("doing the request: connection refused"), + expErr: errors.New("doing the request: Post \"http://1.2.3.4/v1/steady-state\": connection refused"), doErr: errors.New("connection refused"), }, } { - h := &httpClient{ + h := http.DefaultClient + tr := &mockRoundTripper{ resp: &http.Response{ StatusCode: test.statusCode, Body: test.body, @@ -126,6 +130,8 @@ func TestUnlockIfHeld(t *testing.T) { doErr: test.doErr, } + h.Transport = tr + c, _ := client.New("http://1.2.3.4", "default", "1234", h) err := c.UnlockIfHeld() @@ -133,8 +139,8 @@ func TestUnlockIfHeld(t *testing.T) { t.Fatalf("should have %v for err, got: %v", test.expErr, err) } - if h.r.URL.String() != expURL { - t.Fatalf("should have %s for URL, got: %s", expURL, h.r.URL.String()) + if tr.r.URL.String() != expURL { + t.Fatalf("should have %s for URL, got: %s", expURL, tr.r.URL.String()) } } }