Skip to content
11 changes: 8 additions & 3 deletions cmd/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"net/http"

"github.com/spf13/cobra"

Expand All @@ -15,9 +14,15 @@ func lock(group, id, url *string) *cobra.Command {
Use: "recursive-lock",
Short: "Try to reserve (lock) a slot for rebooting",
RunE: func(cmd *cobra.Command, args []string) error {
httpClient := http.DefaultClient
if err := checkID(id); err != nil {
return fmt.Errorf("checking ID: %w", err)
}

c, err := client.New(*url, *group, *id, httpClient)
c, err := client.New(&client.Config{
ID: *id,
Group: *group,
URL: *url,
})
if err != nil {
return fmt.Errorf("building the client: %w", err)
}
Expand Down
32 changes: 32 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
package cmd

import (
"fmt"
"io/ioutil"

"github.com/spf13/cobra"
)

Expand All @@ -20,3 +23,32 @@ func Command() *cobra.Command {

return cli
}

// machineID is a helper to return unique ID
// of the machine.
func machineID() (string, error) {
id, err := ioutil.ReadFile("/etc/machine-id")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this path configurable would make the more testable, but I guess it would have to be done via e.g. --id-from-file flag or something. Also, we don't have any tests in place for this, so would be more effort.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get your point, however we should not forget that machineID() is a helper to provide a default value for the --id flag. In case one wants to use a different ID he can just use --id my-id or even --id $(cat /tmp/my-id-file).

If we really want to test this helper, we can still rely on fs abstraction with afero.FS but it becomes a bit overkill IMHO. :)

if err != nil {
return "", fmt.Errorf("reading machine ID from file: %w", err)
}

return string(id), nil
}

// checkID asserts that the ID is not nil, if it's the case
// it uses `machineID` to generate a default one.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not generate anything.

Suggested change
// it uses `machineID` to generate a default one.
// it uses `machineID` to set a default one.

func checkID(id *string) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this function should have check in the name and only return error. Right now it strangely looks like a validation function, which is not really the case, as it may actually mutate the given id.

Maybe we can make it return (string, err) and rename to e.g. getID?

// the ID is set and it's not empty.
if id != nil && *id != "" {
return nil
}

i, err := machineID()
if err != nil {
return fmt.Errorf("getting default machine ID: %w", err)
}

*id = i

return nil
}
11 changes: 8 additions & 3 deletions cmd/unlock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"net/http"

"github.com/spf13/cobra"

Expand All @@ -15,9 +14,15 @@ func unlock(group, id, url *string) *cobra.Command {
Use: "unlock-if-held",
Short: "Try to release (unlock) a slot that it was previously holding",
RunE: func(cmd *cobra.Command, args []string) error {
httpClient := http.DefaultClient
if err := checkID(id); err != nil {
return fmt.Errorf("checking ID: %w", err)
}

c, err := client.New(*url, *group, *id, httpClient)
c, err := client.New(&client.Config{
ID: *id,
Group: *group,
URL: *url,
})
if err != nil {
return fmt.Errorf("building the client: %w", err)
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/client/authentication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package client

import (
"context"
"fmt"
"net/http"
)

type basicAuthRoundTripper struct {
username string
password string
rt http.RoundTripper
}

// RoundTrip is required to implement RoundTripper interface.
// We check if an authorization token is already set, if not we set it.
// We return the initial RoundTripper to chain it in the whole process.
func (b *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) != 0 {
resp, err := b.rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("inner round trip error: %w", err)
}

return resp, nil
}

req = req.Clone(context.TODO())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
req = req.Clone(context.TODO())
req = req.Clone(req.Context())

Otherwise using round tripper swallows the context control on the request:

package client_test

import (
  "context"
  "errors"
  "net/http"
  "testing"
  "time"

  "github.com/flatcar-linux/fleetlock/pkg/client"
)

func Test_Cancelling_context_for_request_performed_with_http_client_with_basic_auth_round_tripper_cancels_the_request(t *testing.T) {
  httpClient := http.Client{
    Transport: client.NewBasicAuthRoundTripper("foo", "bar", nil),
  }

  requestTimeout := time.Second

  ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
  t.Cleanup(cancel)

  req, err := http.NewRequestWithContext(ctx, "GET", "http://10.255.255.1", nil)
  if err != nil {
    t.Fatal(err)
  }

  errCh := make(chan error, 1)
  go func() {
    _, err = httpClient.Do(req)
    errCh <- err
  }()

  testDeadline := time.NewTimer(2 * requestTimeout)
  select {
  case <-testDeadline.C:
    t.Fatalf("Expected request to return before the deadline")
  case err := <-errCh:
    if err != nil && !errors.Is(err, context.DeadlineExceeded) {
      t.Fatal(err)
    }
  }
}

req.SetBasicAuth(b.username, b.password)

resp, err := b.rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("inner round trip error: %w", err)
}

return resp, nil
}

// NewBasicAuthRoundTripper returns a basicAuthRoundTripper with username and password.
func NewBasicAuthRoundTripper(username, password string, rt http.RoundTripper) http.RoundTripper {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can let rt to be nil and use net/http.Transport{} as default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like round tripper should have it's own set of tests, separate from the client, to make the effort required for testing it smaller.

if rt == nil {
rt = &http.Transport{}
}

return &basicAuthRoundTripper{
username: username,
password: password,
rt: rt,
}
}
34 changes: 26 additions & 8 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,35 @@ type Client struct {
}

// New builds a FleetLock client.
func New(baseServerURL, group, id string, c HTTPClient) (*Client, error) {
if _, err := url.ParseRequestURI(baseServerURL); err != nil {
func New(cfg *Config) (*Client, error) {
fleetlock := &Client{
baseServerURL: cfg.URL,
http: cfg.HTTP,
group: cfg.Group,
id: cfg.ID,
}

if fleetlock.id == "" {
return nil, fmt.Errorf("ID is required")
}

if fleetlock.baseServerURL == "" {
return nil, fmt.Errorf("URL is required")
}

if _, err := url.ParseRequestURI(fleetlock.baseServerURL); err != nil {
return nil, fmt.Errorf("parsing URL: %w", err)
}

return &Client{
baseServerURL: baseServerURL,
http: c,
group: group,
id: id,
}, nil
if fleetlock.group == "" {
fleetlock.group = "default"
}

if fleetlock.http == nil {
fleetlock.http = http.DefaultClient
}

return fleetlock, nil
}

// RecursiveLock tries to reserve (lock) a slot for rebooting.
Expand Down
Loading