Skip to content

Commit c725195

Browse files
Copilotintel352
andcommitted
Add OAuth2 client_credentials support to step.http_call
Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 460960b commit c725195

3 files changed

Lines changed: 699 additions & 37 deletions

File tree

cmd/wfctl/type_registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ func KnownStepTypes() map[string]StepTypeInfo {
548548
"step.http_call": {
549549
Type: "step.http_call",
550550
Plugin: "pipelinesteps",
551-
ConfigKeys: []string{"url", "method", "headers", "body", "timeout"},
551+
ConfigKeys: []string{"url", "method", "headers", "body", "timeout", "auth"},
552552
},
553553
"step.request_parse": {
554554
Type: "step.request_parse",

module/pipeline_step_http_call.go

Lines changed: 262 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,74 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
11+
"strings"
12+
"sync"
1013
"time"
1114

1215
"github.com/CrisisTextLine/modular"
1316
)
1417

18+
// oauthConfig holds OAuth2 client_credentials configuration.
19+
type oauthConfig struct {
20+
tokenURL string
21+
clientID string
22+
clientSecret string
23+
scopes []string
24+
}
25+
26+
// tokenCache holds a cached OAuth2 access token and its expiry.
27+
type tokenCache struct {
28+
mu sync.Mutex
29+
accessToken string
30+
expiry time.Time
31+
}
32+
33+
// get returns the cached token if it is still valid, or empty string.
34+
func (c *tokenCache) get() string {
35+
c.mu.Lock()
36+
defer c.mu.Unlock()
37+
if c.accessToken != "" && time.Now().Before(c.expiry) {
38+
return c.accessToken
39+
}
40+
return ""
41+
}
42+
43+
// set stores a token with the given TTL.
44+
func (c *tokenCache) set(token string, ttl time.Duration) {
45+
c.mu.Lock()
46+
defer c.mu.Unlock()
47+
c.accessToken = token
48+
c.expiry = time.Now().Add(ttl)
49+
}
50+
51+
// invalidate clears the cached token.
52+
func (c *tokenCache) invalidate() {
53+
c.mu.Lock()
54+
defer c.mu.Unlock()
55+
c.accessToken = ""
56+
c.expiry = time.Time{}
57+
}
58+
1559
// HTTPCallStep makes an HTTP request as a pipeline step.
1660
type HTTPCallStep struct {
17-
name string
18-
url string
19-
method string
20-
headers map[string]string
21-
body map[string]any
22-
timeout time.Duration
23-
tmpl *TemplateEngine
61+
name string
62+
url string
63+
method string
64+
headers map[string]string
65+
body map[string]any
66+
timeout time.Duration
67+
tmpl *TemplateEngine
68+
auth *oauthConfig
69+
tokenCache *tokenCache
70+
httpClient *http.Client
2471
}
2572

2673
// NewHTTPCallStepFactory returns a StepFactory that creates HTTPCallStep instances.
2774
func NewHTTPCallStepFactory() StepFactory {
2875
return func(name string, config map[string]any, _ modular.Application) (PipelineStep, error) {
29-
url, _ := config["url"].(string)
30-
if url == "" {
76+
rawURL, _ := config["url"].(string)
77+
if rawURL == "" {
3178
return nil, fmt.Errorf("http_call step %q: 'url' is required", name)
3279
}
3380

@@ -37,11 +84,12 @@ func NewHTTPCallStepFactory() StepFactory {
3784
}
3885

3986
step := &HTTPCallStep{
40-
name: name,
41-
url: url,
42-
method: method,
43-
timeout: 30 * time.Second,
44-
tmpl: NewTemplateEngine(),
87+
name: name,
88+
url: rawURL,
89+
method: method,
90+
timeout: 30 * time.Second,
91+
tmpl: NewTemplateEngine(),
92+
httpClient: http.DefaultClient,
4593
}
4694

4795
if headers, ok := config["headers"].(map[string]any); ok {
@@ -63,25 +111,121 @@ func NewHTTPCallStepFactory() StepFactory {
63111
}
64112
}
65113

114+
if authCfg, ok := config["auth"].(map[string]any); ok {
115+
authType, _ := authCfg["type"].(string)
116+
if authType == "oauth2_client_credentials" {
117+
tokenURL, _ := authCfg["token_url"].(string)
118+
if tokenURL == "" {
119+
return nil, fmt.Errorf("http_call step %q: auth.token_url is required for oauth2_client_credentials", name)
120+
}
121+
clientID, _ := authCfg["client_id"].(string)
122+
if clientID == "" {
123+
return nil, fmt.Errorf("http_call step %q: auth.client_id is required for oauth2_client_credentials", name)
124+
}
125+
clientSecret, _ := authCfg["client_secret"].(string)
126+
if clientSecret == "" {
127+
return nil, fmt.Errorf("http_call step %q: auth.client_secret is required for oauth2_client_credentials", name)
128+
}
129+
130+
var scopes []string
131+
if raw, ok := authCfg["scopes"]; ok {
132+
switch v := raw.(type) {
133+
case []string:
134+
scopes = v
135+
case []any:
136+
for _, s := range v {
137+
if str, ok := s.(string); ok {
138+
scopes = append(scopes, str)
139+
}
140+
}
141+
}
142+
}
143+
144+
step.auth = &oauthConfig{
145+
tokenURL: tokenURL,
146+
clientID: clientID,
147+
clientSecret: clientSecret,
148+
scopes: scopes,
149+
}
150+
step.tokenCache = &tokenCache{}
151+
}
152+
}
153+
66154
return step, nil
67155
}
68156
}
69157

70158
// Name returns the step name.
71159
func (s *HTTPCallStep) Name() string { return s.name }
72160

73-
// Execute performs the HTTP request and returns the response.
74-
func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) {
75-
ctx, cancel := context.WithTimeout(ctx, s.timeout)
76-
defer cancel()
161+
// fetchToken obtains a new OAuth2 access token using client_credentials grant.
162+
func (s *HTTPCallStep) fetchToken(ctx context.Context) (string, error) {
163+
params := url.Values{
164+
"grant_type": {"client_credentials"},
165+
"client_id": {s.auth.clientID},
166+
"client_secret": {s.auth.clientSecret},
167+
}
168+
if len(s.auth.scopes) > 0 {
169+
params.Set("scope", strings.Join(s.auth.scopes, " "))
170+
}
77171

78-
// Resolve URL template
79-
resolvedURL, err := s.tmpl.Resolve(s.url, pc)
172+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.auth.tokenURL,
173+
strings.NewReader(params.Encode()))
80174
if err != nil {
81-
return nil, fmt.Errorf("http_call step %q: failed to resolve url: %w", s.name, err)
175+
return "", fmt.Errorf("http_call step %q: failed to create token request: %w", s.name, err)
176+
}
177+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
178+
179+
resp, err := s.httpClient.Do(req)
180+
if err != nil {
181+
return "", fmt.Errorf("http_call step %q: token request failed: %w", s.name, err)
182+
}
183+
defer resp.Body.Close()
184+
185+
body, err := io.ReadAll(resp.Body)
186+
if err != nil {
187+
return "", fmt.Errorf("http_call step %q: failed to read token response: %w", s.name, err)
82188
}
83189

84-
var bodyReader io.Reader
190+
if resp.StatusCode != http.StatusOK {
191+
return "", fmt.Errorf("http_call step %q: token endpoint returned HTTP %d: %s", s.name, resp.StatusCode, string(body))
192+
}
193+
194+
var tokenResp struct {
195+
AccessToken string `json:"access_token"` //nolint:gosec // G117: parsing OAuth2 token response, not a secret exposure
196+
ExpiresIn float64 `json:"expires_in"`
197+
TokenType string `json:"token_type"`
198+
}
199+
if err := json.Unmarshal(body, &tokenResp); err != nil {
200+
return "", fmt.Errorf("http_call step %q: failed to parse token response: %w", s.name, err)
201+
}
202+
if tokenResp.AccessToken == "" {
203+
return "", fmt.Errorf("http_call step %q: token response missing access_token", s.name)
204+
}
205+
206+
ttl := time.Duration(tokenResp.ExpiresIn) * time.Second
207+
if ttl <= 0 {
208+
ttl = 3600 * time.Second
209+
}
210+
// Subtract a small buffer to avoid using a token that is about to expire
211+
if ttl > 10*time.Second {
212+
ttl -= 10 * time.Second
213+
}
214+
s.tokenCache.set(tokenResp.AccessToken, ttl)
215+
216+
return tokenResp.AccessToken, nil
217+
}
218+
219+
// getToken returns a valid OAuth2 token, fetching one if the cache is empty or expired.
220+
func (s *HTTPCallStep) getToken(ctx context.Context) (string, error) {
221+
if token := s.tokenCache.get(); token != "" {
222+
return token, nil
223+
}
224+
return s.fetchToken(ctx)
225+
}
226+
227+
// buildBodyReader constructs the request body reader from the step configuration.
228+
func (s *HTTPCallStep) buildBodyReader(pc *PipelineContext) (io.Reader, error) {
85229
if s.body != nil {
86230
resolvedBody, resolveErr := s.tmpl.ResolveMap(s.body, pc)
87231
if resolveErr != nil {
@@ -91,15 +235,20 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
91235
if marshalErr != nil {
92236
return nil, fmt.Errorf("http_call step %q: failed to marshal body: %w", s.name, marshalErr)
93237
}
94-
bodyReader = bytes.NewReader(data)
95-
} else if s.method != "GET" && s.method != "HEAD" {
238+
return bytes.NewReader(data), nil
239+
}
240+
if s.method != "GET" && s.method != "HEAD" {
96241
data, marshalErr := json.Marshal(pc.Current)
97242
if marshalErr != nil {
98243
return nil, fmt.Errorf("http_call step %q: failed to marshal current data: %w", s.name, marshalErr)
99244
}
100-
bodyReader = bytes.NewReader(data)
245+
return bytes.NewReader(data), nil
101246
}
247+
return nil, nil
248+
}
102249

250+
// buildRequest constructs the HTTP request with resolved headers and optional bearer token.
251+
func (s *HTTPCallStep) buildRequest(ctx context.Context, resolvedURL string, bodyReader io.Reader, pc *PipelineContext, bearerToken string) (*http.Request, error) {
103252
req, err := http.NewRequestWithContext(ctx, s.method, resolvedURL, bodyReader)
104253
if err != nil {
105254
return nil, fmt.Errorf("http_call step %q: failed to create request: %w", s.name, err)
@@ -116,19 +265,15 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
116265
req.Header.Set(k, resolved)
117266
}
118267
}
119-
120-
resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: SSRF via taint analysis
121-
if err != nil {
122-
return nil, fmt.Errorf("http_call step %q: request failed: %w", s.name, err)
268+
if bearerToken != "" {
269+
req.Header.Set("Authorization", "Bearer "+bearerToken)
123270
}
124-
defer resp.Body.Close()
125271

126-
respBody, err := io.ReadAll(resp.Body)
127-
if err != nil {
128-
return nil, fmt.Errorf("http_call step %q: failed to read response: %w", s.name, err)
129-
}
272+
return req, nil
273+
}
130274

131-
// Build response headers map
275+
// parseResponse converts an HTTP response into a StepResult output map.
276+
func parseHTTPResponse(resp *http.Response, respBody []byte) map[string]any {
132277
respHeaders := make(map[string]any, len(resp.Header))
133278
for k, v := range resp.Header {
134279
if len(v) == 1 {
@@ -148,14 +293,95 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
148293
"headers": respHeaders,
149294
}
150295

151-
// Try to parse response as JSON
152296
var jsonResp any
153297
if json.Unmarshal(respBody, &jsonResp) == nil {
154298
output["body"] = jsonResp
155299
} else {
156300
output["body"] = string(respBody)
157301
}
158302

303+
return output
304+
}
305+
306+
// Execute performs the HTTP request and returns the response.
307+
func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) {
308+
ctx, cancel := context.WithTimeout(ctx, s.timeout)
309+
defer cancel()
310+
311+
// Resolve URL template
312+
resolvedURL, err := s.tmpl.Resolve(s.url, pc)
313+
if err != nil {
314+
return nil, fmt.Errorf("http_call step %q: failed to resolve url: %w", s.name, err)
315+
}
316+
317+
bodyReader, err := s.buildBodyReader(pc)
318+
if err != nil {
319+
return nil, err
320+
}
321+
322+
// Obtain OAuth2 bearer token if auth is configured
323+
var bearerToken string
324+
if s.auth != nil {
325+
bearerToken, err = s.getToken(ctx)
326+
if err != nil {
327+
return nil, err
328+
}
329+
}
330+
331+
req, err := s.buildRequest(ctx, resolvedURL, bodyReader, pc, bearerToken)
332+
if err != nil {
333+
return nil, err
334+
}
335+
336+
resp, err := s.httpClient.Do(req) //nolint:gosec // G107: URL is user-configured
337+
if err != nil {
338+
return nil, fmt.Errorf("http_call step %q: request failed: %w", s.name, err)
339+
}
340+
defer resp.Body.Close()
341+
342+
respBody, err := io.ReadAll(resp.Body)
343+
if err != nil {
344+
return nil, fmt.Errorf("http_call step %q: failed to read response: %w", s.name, err)
345+
}
346+
347+
// On 401, invalidate token cache and retry once with a fresh token
348+
if resp.StatusCode == http.StatusUnauthorized && s.auth != nil {
349+
s.tokenCache.invalidate()
350+
351+
newToken, tokenErr := s.fetchToken(ctx)
352+
if tokenErr != nil {
353+
return nil, tokenErr
354+
}
355+
356+
retryBody, buildErr := s.buildBodyReader(pc)
357+
if buildErr != nil {
358+
return nil, buildErr
359+
}
360+
retryReq, buildErr := s.buildRequest(ctx, resolvedURL, retryBody, pc, newToken)
361+
if buildErr != nil {
362+
return nil, buildErr
363+
}
364+
365+
retryResp, doErr := s.httpClient.Do(retryReq) //nolint:gosec // G107: URL is user-configured
366+
if doErr != nil {
367+
return nil, fmt.Errorf("http_call step %q: retry request failed: %w", s.name, doErr)
368+
}
369+
defer retryResp.Body.Close()
370+
371+
respBody, err = io.ReadAll(retryResp.Body)
372+
if err != nil {
373+
return nil, fmt.Errorf("http_call step %q: failed to read retry response: %w", s.name, err)
374+
}
375+
376+
output := parseHTTPResponse(retryResp, respBody)
377+
if retryResp.StatusCode >= 400 {
378+
return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, retryResp.StatusCode, string(respBody))
379+
}
380+
return &StepResult{Output: output}, nil
381+
}
382+
383+
output := parseHTTPResponse(resp, respBody)
384+
159385
if resp.StatusCode >= 400 {
160386
return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, resp.StatusCode, string(respBody))
161387
}

0 commit comments

Comments
 (0)