From 189511bb23e9604433fcde8280e5979336905a25 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sun, 30 Mar 2025 22:01:44 +0200 Subject: [PATCH] Add ExecCommandFunc option, which allows for overwriting exec.Command There are some reasons, for example to control environment variables being passed to the process, that the calling code might want more control over how the process is created. The new option function allows overwriting the `exec.Command` in order to do so. --- iptables/iptables.go | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/iptables/iptables.go b/iptables/iptables.go index b058995..28b48f6 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -82,6 +82,7 @@ type IPTables struct { v3 int mode string // the underlying iptables operating mode, e.g. nf_tables timeout int // time to wait for the iptables lock, default waits forever + execCommandFunc func(string, ...string) *exec.Cmd } // Stat represents a structured statistic entry. @@ -118,6 +119,15 @@ func Path(path string) option { } } +// ExecCommandFunc allows for overriding the [exec.Command] used when spawning +// iptables sub-processes. Stdout and Stderr should be nil on the returned +// [exec.Cmd]. +func ExecCommandFunc(fn func(name string, arg ...string) *exec.Cmd) option { + return func(ipt *IPTables) { + ipt.execCommandFunc = fn + } +} + // New creates a new IPTables configured with the options passed as parameters. // Supported parameters are: // @@ -133,9 +143,10 @@ func Path(path string) option { func New(opts ...option) (*IPTables, error) { ipt := &IPTables{ - proto: ProtocolIPv4, - timeout: 0, - path: "", + proto: ProtocolIPv4, + timeout: 0, + path: "", + execCommandFunc: exec.Command, } for _, opt := range opts { @@ -155,7 +166,7 @@ func New(opts ...option) (*IPTables, error) { } ipt.path = path - vstring, err := getIptablesVersionString(path) + vstring, err := getIptablesVersionString(ipt.execCommandFunc, path) if err != nil { return nil, fmt.Errorf("could not get iptables version: %v", err) } @@ -563,7 +574,7 @@ func (ipt *IPTables) run(args ...string) error { // runWithOutput runs an iptables command with the given arguments, // writing any stdout output to the given writer func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { - args = append([]string{ipt.path}, args...) + args = append([]string(nil), args...) // copy input args if ipt.hasWait { args = append(args, "--wait") if ipt.timeout != 0 && ipt.waitSupportSecond { @@ -585,17 +596,14 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { } var stderr bytes.Buffer - cmd := exec.Cmd{ - Path: ipt.path, - Args: args, - Stdout: stdout, - Stderr: &stderr, - } + cmd := ipt.execCommandFunc(ipt.path, args...) + cmd.Stdout = stdout + cmd.Stderr = &stderr if err := cmd.Run(); err != nil { switch e := err.(type) { case *exec.ExitError: - return &Error{*e, cmd, stderr.String(), nil} + return &Error{*e, *cmd, stderr.String(), nil} default: return err } @@ -651,8 +659,11 @@ func extractIptablesVersion(str string) (int, int, int, string, error) { } // Runs "iptables --version" to get the version string -func getIptablesVersionString(path string) (string, error) { - cmd := exec.Command(path, "--version") +func getIptablesVersionString( + execCommandFunc func(string, ...string) *exec.Cmd, + path string, +) (string, error) { + cmd := execCommandFunc(path, "--version") var out bytes.Buffer cmd.Stdout = &out err := cmd.Run()