From 8bdcb5982392b51691ed383ea37e5c67572d2965 Mon Sep 17 00:00:00 2001 From: Kyle Ellrott Date: Mon, 9 Mar 2026 22:54:06 -0700 Subject: [PATCH 1/4] Adding methods to capture inter-step communication --- cmd/run/main.go | 8 +- cmd/run/run.go | 26 +++++- playbook/execute.go | 220 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 238 insertions(+), 16 deletions(-) diff --git a/cmd/run/main.go b/cmd/run/main.go index 8b3f016..775297b 100644 --- a/cmd/run/main.go +++ b/cmd/run/main.go @@ -14,6 +14,8 @@ var outDir string = "" var paramsFile string = "" var verbose bool = false var cmdParams map[string]string +var debugOutputDir string = "" +var debugSampleLimit int = 10 // Cmd is the declaration of the command line var Cmd = &cobra.Command{ @@ -46,11 +48,11 @@ var Cmd = &cobra.Command{ } pb := playbook.Playbook{} playbook.ParseBytes(yaml, "./playbook.yaml", &pb) - if err := Execute(pb, "./", "./", outDir, params); err != nil { + if err := Execute(pb, "./", "./", outDir, params, debugOutputDir, debugSampleLimit); err != nil { return err } } else { - if err := ExecuteFile(playFile, "./", outDir, params); err != nil { + if err := ExecuteFile(playFile, "./", outDir, params, debugOutputDir, debugSampleLimit); err != nil { return err } } @@ -65,4 +67,6 @@ func init() { flags.BoolVarP(&verbose, "verbose", "v", verbose, "Verbose logging") flags.StringToStringVarP(&cmdParams, "param", "p", cmdParams, "Parameter variable") flags.StringVarP(¶msFile, "params-file", "f", paramsFile, "Parameter file") + flags.StringVarP(&debugOutputDir, "debug-output-dir", "d", "", "Directory for debug capture files (default: ./debug-capture)") + flags.IntVarP(&debugSampleLimit, "debug-sample-limit", "l", 10, "Max records to capture per step (0 = unlimited)") } diff --git a/cmd/run/run.go b/cmd/run/run.go index 42651b2..e3e5e19 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -9,7 +9,7 @@ import ( "github.com/bmeg/sifter/task" ) -func ExecuteFile(playFile string, workDir string, outDir string, inputs map[string]string) error { +func ExecuteFile(playFile string, workDir string, outDir string, inputs map[string]string, debugDir string, debugLimit int) error { logger.Info("Starting", "playFile", playFile) pb := playbook.Playbook{} if err := playbook.ParseFile(playFile, &pb); err != nil { @@ -19,10 +19,10 @@ func ExecuteFile(playFile string, workDir string, outDir string, inputs map[stri a, _ := filepath.Abs(playFile) baseDir := filepath.Dir(a) logger.Debug("parsed file", "baseDir", baseDir, "playbook", pb) - return Execute(pb, baseDir, workDir, outDir, inputs) + return Execute(pb, baseDir, workDir, outDir, inputs, debugDir, debugLimit) } -func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string, params map[string]string) error { +func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string, params map[string]string, debugDir string, debugLimit int) error { if outDir == "" { outDir = pb.GetDefaultOutDir() @@ -32,6 +32,24 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string os.MkdirAll(outDir, 0777) } + // Setup debug capture directory if enabled + // Enable if: user explicitly set dir, OR user changed limit from default + enableDebug := debugDir != "" || (debugLimit != 10) + if enableDebug { + if debugDir == "" { + debugDir = filepath.Join(workDir, "debug-capture") + } else if !filepath.IsAbs(debugDir) { + debugDir = filepath.Join(workDir, debugDir) + } + if _, err := os.Stat(debugDir); os.IsNotExist(err) { + if err := os.MkdirAll(debugDir, 0777); err != nil { + logger.Error("Failed to create debug directory", "error", err) + return err + } + } + logger.Info("Debug capture enabled", "dir", debugDir, "limit", debugLimit) + } + nInputs, err := pb.PrepConfig(params, workDir) if err != nil { return err @@ -39,6 +57,6 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string logger.Debug("Running", "outDir", outDir) t := task.NewTask(pb.Name, baseDir, workDir, outDir, nInputs) - err = pb.Execute(t) + err = pb.Execute(t, debugDir, debugLimit) return err } diff --git a/playbook/execute.go b/playbook/execute.go index ed126ab..42efb5f 100644 --- a/playbook/execute.go +++ b/playbook/execute.go @@ -1,9 +1,14 @@ package playbook import ( + "encoding/json" "fmt" + "os" "path/filepath" "strings" + "sync" + "sync/atomic" + "time" "github.com/bmeg/flame" "github.com/bmeg/sifter/logger" @@ -119,7 +124,52 @@ type joinStruct struct { proc transform.JoinProcessor } -func (pb *Playbook) Execute(task task.RuntimeTask) error { +// stepCaptureState tracks debug capture state for a single step +type stepCaptureState struct { + pipelineName string + stepIndex int + stepType string + count uint64 + limit int + file *os.File + mu sync.Mutex +} + +// captureRecord writes a debug record to the capture file +func (s *stepCaptureState) captureRecord(record map[string]any) { + if s.limit > 0 { + currentCount := atomic.LoadUint64(&s.count) + if currentCount >= uint64(s.limit) { + return + } + } + + recordNum := atomic.AddUint64(&s.count, 1) + + envelope := map[string]any{ + "pipeline": s.pipelineName, + "step_index": s.stepIndex, + "step_type": s.stepType, + "record_num": recordNum, + "timestamp": time.Now().UTC().Format(time.RFC3339), + "data": record, + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.file != nil { + data, err := json.Marshal(envelope) + if err == nil { + s.file.Write(data) + s.file.Write([]byte("\n")) + } else { + logger.Error("Failed to marshal debug record", "error", err) + } + } +} + +func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit int) error { logger.Debug("Running playbook") logger.Debug("Inputs", "config", task.GetConfig()) @@ -148,6 +198,43 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { procs := []transform.Processor{} joins := []joinStruct{} + captureFiles := []*os.File{} // Track all open capture files for cleanup + + // Helper function to sanitize filename components + sanitizeFilename := func(s string) string { + s = strings.ReplaceAll(s, "*", "") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, "\\", "_") + return s + } + + // Helper function to create capture state for a step + createCaptureState := func(pipelineName string, stepIndex int, stepType string) *stepCaptureState { + if debugDir == "" { + return nil + } + + filename := fmt.Sprintf("%s.step%d.%s.ndjson", pipelineName, stepIndex, sanitizeFilename(stepType)) + filepath := filepath.Join(debugDir, filename) + + file, err := os.Create(filepath) + if err != nil { + logger.Error("Failed to create debug capture file", "path", filepath, "error", err) + return nil + } + + captureFiles = append(captureFiles, file) + logger.Debug("Created debug capture file", "path", filepath) + + return &stepCaptureState{ + pipelineName: pipelineName, + stepIndex: stepIndex, + stepType: stepType, + count: 0, + limit: debugLimit, + file: file, + } + } for k, v := range pb.Pipelines { var lastStep flame.Emitter[map[string]any] @@ -163,7 +250,22 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { if mProcess, ok := b.(transform.NodeProcessor); ok { logger.Debug("PipelineSetup", "name", k, "step", i, "processor", fmt.Sprintf("%T", mProcess)) - c := flame.AddFlatMapper(wf, mProcess.Process) + + // Create capture state for this step + captureState := createCaptureState(k, i, fmt.Sprintf("%T", mProcess)) + + // Wrap the process function if capture is enabled + var processFunc func(map[string]any) []map[string]any + if captureState != nil { + processFunc = func(record map[string]any) []map[string]any { + captureState.captureRecord(record) + return mProcess.Process(record) + } + } else { + processFunc = mProcess.Process + } + + c := flame.AddFlatMapper(wf, processFunc) if lastStep != nil { c.Connect(lastStep) } @@ -178,12 +280,27 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } } else if mProcess, ok := b.(transform.MapProcessor); ok { logger.Debug("Pipeline Pool", "name", k, "step", i, "processor", b) + + // Create capture state for this step + captureState := createCaptureState(k, i, fmt.Sprintf("%T", mProcess)) + + // Wrap the process function if capture is enabled + var processFunc func(map[string]any) map[string]any + if captureState != nil { + processFunc = func(record map[string]any) map[string]any { + captureState.captureRecord(record) + return mProcess.Process(record) + } + } else { + processFunc = mProcess.Process + } + var c flame.Node[map[string]any, map[string]any] if mProcess.PoolReady() { logger.Debug("Starting pool worker") - c = flame.AddMapperPool(wf, mProcess.Process, 4) // TODO: config pool count + c = flame.AddMapperPool(wf, processFunc, 4) // TODO: config pool count } else { - c = flame.AddMapper(wf, mProcess.Process) + c = flame.AddMapper(wf, processFunc) } if lastStep != nil { c.Connect(lastStep) @@ -199,12 +316,27 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } } else if mProcess, ok := b.(transform.FlatMapProcessor); ok { logger.Debug("Pipeline flatmap", "name", k, "step", i, "processor", b) + + // Create capture state for this step + captureState := createCaptureState(k, i, fmt.Sprintf("%T", mProcess)) + + // Wrap the process function if capture is enabled + var processFunc func(map[string]any) []map[string]any + if captureState != nil { + processFunc = func(record map[string]any) []map[string]any { + captureState.captureRecord(record) + return mProcess.Process(record) + } + } else { + processFunc = mProcess.Process + } + var c flame.Node[map[string]any, map[string]any] if mProcess.PoolReady() { // log.Printf("Starting pool worker") - c = flame.AddFlatMapperPool(wf, mProcess.Process, 4) // TODO: config pool count + c = flame.AddFlatMapperPool(wf, processFunc, 4) // TODO: config pool count } else { - c = flame.AddFlatMapper(wf, mProcess.Process) + c = flame.AddFlatMapper(wf, processFunc) } if lastStep != nil { c.Connect(lastStep) @@ -220,6 +352,8 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } } else if mProcess, ok := b.(transform.StreamProcessor); ok { logger.Info("Pipeline stream %s step %d: %T", k, i, b) + // Note: StreamProcessor uses channels, not suitable for simple record capture + // Would need to wrap the entire channel processing, which is complex c := flame.AddStreamer(wf, mProcess.Process) if c != nil { if lastStep != nil { @@ -235,6 +369,8 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } } else if jProcess, ok := b.(transform.JoinProcessor); ok { logger.Debug("Pipeline Join Step") + // Note: JoinProcessor uses channels, not suitable for simple record capture + // Would need to wrap the entire channel processing, which is complex c := flame.AddJoin(wf, jProcess.Process) if c != nil { if lastStep != nil { @@ -253,9 +389,38 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } } else if rProcess, ok := b.(transform.ReduceProcessor); ok { logger.Debug("Pipeline reduce %s step %d: %T", k, i, b) + + // Create capture state for this step (pre-reduce) + captureStateInput := createCaptureState(k, i, fmt.Sprintf("%T-input", rProcess)) + captureStateOutput := createCaptureState(k, i, fmt.Sprintf("%T-output", rProcess)) + wrap := reduceWrapper{rProcess} - k := flame.AddMapper(wf, wrap.addKeyValue) - r := flame.AddReduceKey(wf, rProcess.Reduce, rProcess.GetInit()) + + // Wrap addKeyValue if capturing input + var addKeyValueFunc func(map[string]any) flame.KeyValue[string, map[string]any] + if captureStateInput != nil { + addKeyValueFunc = func(x map[string]any) flame.KeyValue[string, map[string]any] { + captureStateInput.captureRecord(x) + return wrap.addKeyValue(x) + } + } else { + addKeyValueFunc = wrap.addKeyValue + } + + // Wrap reduce function if capturing output + var reduceFunc func(string, map[string]any, map[string]any) map[string]any + if captureStateOutput != nil { + reduceFunc = func(key string, acc map[string]any, val map[string]any) map[string]any { + result := rProcess.Reduce(key, acc, val) + captureStateOutput.captureRecord(result) + return result + } + } else { + reduceFunc = rProcess.Reduce + } + + k := flame.AddMapper(wf, addKeyValueFunc) + r := flame.AddReduceKey(wf, reduceFunc, rProcess.GetInit()) c := flame.AddFlatMapper(wf, wrap.removeKeyValue) if lastStep != nil { k.Connect(lastStep) @@ -269,9 +434,37 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { } else if rProcess, ok := b.(transform.AccumulateProcessor); ok { logger.Debug("Pipeline accumulate %s step %d: %T", k, i, b) + // Create capture state for this step + captureStateInput := createCaptureState(k, i, fmt.Sprintf("%T-input", rProcess)) + captureStateOutput := createCaptureState(k, i, fmt.Sprintf("%T-output", rProcess)) + wrap := accumulateWrapper{rProcess} - k := flame.AddMapper(wf, wrap.addKeyValue) - r := flame.AddAccumulate(wf, rProcess.Accumulate) + + // Wrap addKeyValue if capturing input + var addKeyValueFunc func(map[string]any) flame.KeyValue[string, map[string]any] + if captureStateInput != nil { + addKeyValueFunc = func(x map[string]any) flame.KeyValue[string, map[string]any] { + captureStateInput.captureRecord(x) + return wrap.addKeyValue(x) + } + } else { + addKeyValueFunc = wrap.addKeyValue + } + + // Wrap accumulate function if capturing output + var accumulateFunc func(string, []map[string]any) map[string]any + if captureStateOutput != nil { + accumulateFunc = func(key string, vals []map[string]any) map[string]any { + result := rProcess.Accumulate(key, vals) + captureStateOutput.captureRecord(result) + return result + } + } else { + accumulateFunc = rProcess.Accumulate + } + + k := flame.AddMapper(wf, addKeyValueFunc) + r := flame.AddAccumulate(wf, accumulateFunc) c := flame.AddFlatMapper(wf, wrap.removeKeyValue) if lastStep != nil { k.Connect(lastStep) @@ -376,6 +569,13 @@ func (pb *Playbook) Execute(task task.RuntimeTask) error { outputs[k].Close() } + // Close all debug capture files + for _, f := range captureFiles { + if f != nil { + f.Close() + } + } + task.Close() return nil } From 40fac5f10b3b07d5c7ca6575a102ecb1a5580b38 Mon Sep 17 00:00:00 2001 From: Kyle Ellrott Date: Tue, 10 Mar 2026 11:49:57 -0700 Subject: [PATCH 2/4] Updating capture output --- README.md | 14 ++++++-- cmd/run/main.go | 13 +++---- extractors/interface.go | 14 ++++++++ playbook/execute.go | 75 ++++++++++++++++++++++++++--------------- 4 files changed, 80 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 1fae78f..32733be 100644 --- a/README.md +++ b/README.md @@ -32,15 +32,23 @@ class: sifter name: census_2010 params: - census: ../data/census_2010_byzip.json - date: "2010-01-01" - schema: ../covid19_datadictionary/gdcdictionary/schemas/ + census: + type: file + default: ../data/census_2010_byzip.json + date: + default: "2010-01-01" + schema: + type: path + default: ../covid19_datadictionary/gdcdictionary/schemas/ inputs: censusData: jsonLoad: input: "{{params.census}}" +outputs: + + pipelines: transform: - from: censusData diff --git a/cmd/run/main.go b/cmd/run/main.go index 775297b..cbab677 100644 --- a/cmd/run/main.go +++ b/cmd/run/main.go @@ -14,8 +14,8 @@ var outDir string = "" var paramsFile string = "" var verbose bool = false var cmdParams map[string]string -var debugOutputDir string = "" -var debugSampleLimit int = 10 +var captureDir string = "" +var captureLimit int = 10 // Cmd is the declaration of the command line var Cmd = &cobra.Command{ @@ -48,11 +48,11 @@ var Cmd = &cobra.Command{ } pb := playbook.Playbook{} playbook.ParseBytes(yaml, "./playbook.yaml", &pb) - if err := Execute(pb, "./", "./", outDir, params, debugOutputDir, debugSampleLimit); err != nil { + if err := Execute(pb, "./", "./", outDir, params, captureDir, captureLimit); err != nil { return err } } else { - if err := ExecuteFile(playFile, "./", outDir, params, debugOutputDir, debugSampleLimit); err != nil { + if err := ExecuteFile(playFile, "./", outDir, params, captureDir, captureLimit); err != nil { return err } } @@ -67,6 +67,7 @@ func init() { flags.BoolVarP(&verbose, "verbose", "v", verbose, "Verbose logging") flags.StringToStringVarP(&cmdParams, "param", "p", cmdParams, "Parameter variable") flags.StringVarP(¶msFile, "params-file", "f", paramsFile, "Parameter file") - flags.StringVarP(&debugOutputDir, "debug-output-dir", "d", "", "Directory for debug capture files (default: ./debug-capture)") - flags.IntVarP(&debugSampleLimit, "debug-sample-limit", "l", 10, "Max records to capture per step (0 = unlimited)") + flags.StringVarP(&captureDir, "capture-dir", "d", "", "Directory for capture files (default: None)") + flags.IntVarP(&captureLimit, "capture-limit", "l", 10, "Max records to capture per step (0 = unlimited)") + flags.StringVarP(&outDir, "output", "o", outDir, "Output directory for playbook results (default: current directory or value specified in playbook)") } diff --git a/extractors/interface.go b/extractors/interface.go index 9655dc8..bc99ae0 100644 --- a/extractors/interface.go +++ b/extractors/interface.go @@ -44,6 +44,20 @@ func (ex *Extractor) Start(t task.RuntimeTask) (chan map[string]interface{}, err return nil, fmt.Errorf(("Extractor not defined")) } +func (ex *Extractor) GetType() reflect.Type { + v := reflect.ValueOf(ex).Elem() + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + x := f.Interface() + if _, ok := x.(Source); ok { + if !f.IsNil() { + return f.Type() + } + } + } + return nil +} + func (ex *Extractor) GetRequiredParams() []config.ParamRequest { out := []config.ParamRequest{} v := reflect.ValueOf(ex).Elem() diff --git a/playbook/execute.go b/playbook/execute.go index 42efb5f..03a4956 100644 --- a/playbook/execute.go +++ b/playbook/execute.go @@ -169,7 +169,7 @@ func (s *stepCaptureState) captureRecord(record map[string]any) { } } -func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit int) error { +func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLimit int) error { logger.Debug("Running playbook") logger.Debug("Inputs", "config", task.GetConfig()) @@ -180,22 +180,6 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i } task.SetName(pb.Name) - outNodes := map[string]flame.Emitter[map[string]any]{} - inNodes := map[string]flame.Receiver[map[string]any]{} - outputs := map[string]OutputProcessor{} - - for n, v := range pb.Inputs { - logger.Debug("Setting Up", "name", n) - s, err := v.Start(task) - if err == nil { - c := flame.AddSourceChan(wf, s) - outNodes[n] = c - } else { - logger.Error("Source error", "error", err) - return err - } - } - procs := []transform.Processor{} joins := []joinStruct{} captureFiles := []*os.File{} // Track all open capture files for cleanup @@ -205,17 +189,19 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i s = strings.ReplaceAll(s, "*", "") s = strings.ReplaceAll(s, "/", "_") s = strings.ReplaceAll(s, "\\", "_") + s = strings.ReplaceAll(s, "transform.", "") // Remove package prefix for readability + s = strings.ReplaceAll(s, "extractors.", "") // Remove package prefix for readability return s } // Helper function to create capture state for a step createCaptureState := func(pipelineName string, stepIndex int, stepType string) *stepCaptureState { - if debugDir == "" { + if captureDir == "" { return nil } - filename := fmt.Sprintf("%s.step%d.%s.ndjson", pipelineName, stepIndex, sanitizeFilename(stepType)) - filepath := filepath.Join(debugDir, filename) + filename := fmt.Sprintf("%s.%d.%s.ndjson", pipelineName, stepIndex, sanitizeFilename(stepType)) + filepath := filepath.Join(captureDir, filename) file, err := os.Create(filepath) if err != nil { @@ -231,15 +217,43 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i stepIndex: stepIndex, stepType: stepType, count: 0, - limit: debugLimit, + limit: captureLimit, file: file, } } + outNodes := map[string]flame.Emitter[map[string]any]{} + inNodes := map[string]flame.Receiver[map[string]any]{} + outputs := map[string]OutputProcessor{} + + for n, v := range pb.Inputs { + logger.Debug("Setting Up", "name", n) + s, err := v.Start(task) + if err == nil { + sourceNode := flame.AddSourceChan(wf, s) + + captureState := createCaptureState(n, 0, v.GetType().String()) + if captureState != nil { + captureMapper := flame.AddMapper(wf, func(record map[string]any) map[string]any { + captureState.captureRecord(record) + return record + }) + captureMapper.Connect(sourceNode) + outNodes[n] = captureMapper + } else { + outNodes[n] = sourceNode + } + } else { + logger.Error("Source error", "error", err) + return err + } + } + for k, v := range pb.Pipelines { var lastStep flame.Emitter[map[string]any] var firstStep flame.Receiver[map[string]any] for i, s := range v { + b, err := s.Init(task) if err != nil { logger.Error("Pipeline error", "name", k, "error", err) @@ -258,8 +272,11 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i var processFunc func(map[string]any) []map[string]any if captureState != nil { processFunc = func(record map[string]any) []map[string]any { - captureState.captureRecord(record) - return mProcess.Process(record) + out := mProcess.Process(record) + for _, r := range out { + captureState.captureRecord(r) + } + return out } } else { processFunc = mProcess.Process @@ -288,8 +305,9 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i var processFunc func(map[string]any) map[string]any if captureState != nil { processFunc = func(record map[string]any) map[string]any { - captureState.captureRecord(record) - return mProcess.Process(record) + out := mProcess.Process(record) + captureState.captureRecord(out) + return out } } else { processFunc = mProcess.Process @@ -324,8 +342,11 @@ func (pb *Playbook) Execute(task task.RuntimeTask, debugDir string, debugLimit i var processFunc func(map[string]any) []map[string]any if captureState != nil { processFunc = func(record map[string]any) []map[string]any { - captureState.captureRecord(record) - return mProcess.Process(record) + out := mProcess.Process(record) + for _, r := range out { + captureState.captureRecord(r) + } + return out } } else { processFunc = mProcess.Process From 7d7b33664789dfa9e359fc6a0ecc13e3200cce5d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:07:41 +0000 Subject: [PATCH 3/4] Initial plan From a9a8d84434852074c9c27f8127810230f117e33e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:15:21 +0000 Subject: [PATCH 4/4] Address code review comments: sanitize filenames, fix racy limit, handle write errors, defer cleanup, improve Stat handling, add Execute wrapper, add capture test Co-authored-by: kellrott <113868+kellrott@users.noreply.github.com> --- cmd/run/run.go | 17 +++++++-- playbook/execute.go | 77 +++++++++++++++++++++++++++++---------- test/command_line_test.go | 69 +++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 23 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index e3e5e19..0c491e0 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -1,6 +1,7 @@ package run import ( + "fmt" "os" "path/filepath" @@ -41,11 +42,19 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string } else if !filepath.IsAbs(debugDir) { debugDir = filepath.Join(workDir, debugDir) } - if _, err := os.Stat(debugDir); os.IsNotExist(err) { - if err := os.MkdirAll(debugDir, 0777); err != nil { - logger.Error("Failed to create debug directory", "error", err) + if info, err := os.Stat(debugDir); err != nil { + if os.IsNotExist(err) { + if mkErr := os.MkdirAll(debugDir, 0777); mkErr != nil { + logger.Error("Failed to create debug directory", "error", mkErr) + return mkErr + } + } else { + logger.Error("Failed to access debug directory", "path", debugDir, "error", err) return err } + } else if !info.IsDir() { + logger.Error("Debug path exists but is not a directory", "path", debugDir) + return fmt.Errorf("debug path %s exists but is not a directory", debugDir) } logger.Info("Debug capture enabled", "dir", debugDir, "limit", debugLimit) } @@ -57,6 +66,6 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string logger.Debug("Running", "outDir", outDir) t := task.NewTask(pb.Name, baseDir, workDir, outDir, nInputs) - err = pb.Execute(t, debugDir, debugLimit) + err = pb.ExecuteWithCapture(t, debugDir, debugLimit) return err } diff --git a/playbook/execute.go b/playbook/execute.go index 03a4956..518084b 100644 --- a/playbook/execute.go +++ b/playbook/execute.go @@ -137,14 +137,22 @@ type stepCaptureState struct { // captureRecord writes a debug record to the capture file func (s *stepCaptureState) captureRecord(record map[string]any) { - if s.limit > 0 { + var recordNum uint64 + + for { currentCount := atomic.LoadUint64(&s.count) - if currentCount >= uint64(s.limit) { + + // Enforce limit strictly under concurrency + if s.limit > 0 && currentCount >= uint64(s.limit) { return } - } - recordNum := atomic.AddUint64(&s.count, 1) + next := currentCount + 1 + if atomic.CompareAndSwapUint64(&s.count, currentCount, next) { + recordNum = next + break + } + } envelope := map[string]any{ "pipeline": s.pipelineName, @@ -161,15 +169,28 @@ func (s *stepCaptureState) captureRecord(record map[string]any) { if s.file != nil { data, err := json.Marshal(envelope) if err == nil { - s.file.Write(data) - s.file.Write([]byte("\n")) + if _, writeErr := s.file.Write(data); writeErr != nil { + logger.Error("Failed to write debug record data", "error", writeErr) + return + } + if _, writeErr := s.file.Write([]byte("\n")); writeErr != nil { + logger.Error("Failed to write debug record newline", "error", writeErr) + return + } } else { logger.Error("Failed to marshal debug record", "error", err) } } } -func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLimit int) error { +// Execute runs the playbook without debug capture. +// This maintains the original public API signature for backward compatibility. +func (pb *Playbook) Execute(task task.RuntimeTask) error { + return pb.ExecuteWithCapture(task, "", 0) +} + +// ExecuteWithCapture runs the playbook with optional debug capture configuration. +func (pb *Playbook) ExecuteWithCapture(task task.RuntimeTask, captureDir string, captureLimit int) error { logger.Debug("Running playbook") logger.Debug("Inputs", "config", task.GetConfig()) @@ -183,6 +204,13 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim procs := []transform.Processor{} joins := []joinStruct{} captureFiles := []*os.File{} // Track all open capture files for cleanup + defer func() { + for _, f := range captureFiles { + if f != nil { + _ = f.Close() + } + } + }() // Helper function to sanitize filename components sanitizeFilename := func(s string) string { @@ -194,23 +222,41 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim return s } + // Helper function to sanitize pipeline names used in filenames + sanitizePipelineName := func(s string) string { + // Use only the last path element to avoid directory traversal + s = filepath.Base(s) + + // Treat empty, current-dir, parent-dir, or bare separator as invalid and use a default + if s == "" || s == "." || s == ".." || s == string(os.PathSeparator) { + s = "pipeline" + } + + // Replace any remaining path separators with underscores + s = strings.ReplaceAll(s, string(os.PathSeparator), "_") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, "\\", "_") + + return s + } + // Helper function to create capture state for a step createCaptureState := func(pipelineName string, stepIndex int, stepType string) *stepCaptureState { if captureDir == "" { return nil } - filename := fmt.Sprintf("%s.%d.%s.ndjson", pipelineName, stepIndex, sanitizeFilename(stepType)) - filepath := filepath.Join(captureDir, filename) + filename := fmt.Sprintf("%s.%d.%s.ndjson", sanitizePipelineName(pipelineName), stepIndex, sanitizeFilename(stepType)) + filePath := filepath.Join(captureDir, filename) - file, err := os.Create(filepath) + file, err := os.Create(filePath) if err != nil { - logger.Error("Failed to create debug capture file", "path", filepath, "error", err) + logger.Error("Failed to create debug capture file", "path", filePath, "error", err) return nil } captureFiles = append(captureFiles, file) - logger.Debug("Created debug capture file", "path", filepath) + logger.Debug("Created debug capture file", "path", filePath) return &stepCaptureState{ pipelineName: pipelineName, @@ -590,13 +636,6 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim outputs[k].Close() } - // Close all debug capture files - for _, f := range captureFiles { - if f != nil { - f.Close() - } - } - task.Close() return nil } diff --git a/test/command_line_test.go b/test/command_line_test.go index 119244c..5dcc0f2 100644 --- a/test/command_line_test.go +++ b/test/command_line_test.go @@ -1,6 +1,7 @@ package test import ( + "bufio" "bytes" "compress/gzip" "fmt" @@ -101,3 +102,71 @@ func TestCommandLines(t *testing.T) { } } + +// TestCaptureMode verifies that --capture-dir and --capture-limit flags create +// NDJSON capture files and respect the record limit. +func TestCaptureMode(t *testing.T) { + captureDir, err := os.MkdirTemp("", "sifter-capture-*") + if err != nil { + t.Fatalf("Failed to create temp capture dir: %s", err) + } + defer os.RemoveAll(captureDir) + + playbook := "examples/gene-table/gene-table.yaml" + limit := 3 + + cmd := exec.Command("../sifter", "run", + "--capture-dir", captureDir, + "--capture-limit", fmt.Sprintf("%d", limit), + playbook, + ) + t.Logf("Running: %s with capture-dir=%s capture-limit=%d", playbook, captureDir, limit) + if err := cmd.Run(); err != nil { + t.Fatalf("Failed running %s: %s", playbook, err) + } + + // Check that at least one .ndjson file was created + entries, err := os.ReadDir(captureDir) + if err != nil { + t.Fatalf("Failed to read capture dir: %s", err) + } + if len(entries) == 0 { + t.Errorf("Expected capture NDJSON files in %s, but directory is empty", captureDir) + return + } + + // Verify each capture file has at most `limit` records + for _, entry := range entries { + if !strings.HasSuffix(entry.Name(), ".ndjson") { + continue + } + filePath := filepath.Join(captureDir, entry.Name()) + f, err := os.Open(filePath) + if err != nil { + t.Errorf("Failed to open capture file %s: %s", filePath, err) + continue + } + + lineCount := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + if scanner.Text() != "" { + lineCount++ + } + } + scanErr := scanner.Err() + f.Close() + + if scanErr != nil { + t.Errorf("Error reading capture file %s: %s", filePath, scanErr) + } + if lineCount > limit { + t.Errorf("Capture file %s has %d records, expected at most %d", entry.Name(), lineCount, limit) + } + t.Logf("Capture file %s has %d records (limit=%d)", entry.Name(), lineCount, limit) + } + + // Clean up the playbook output + outputDir := filepath.Join(filepath.Dir(playbook), "output") + os.RemoveAll(outputDir) +}