Trainable uses three specialized AI agents — EDA, Prep, and Train — that share a common execution architecture but differ in their system prompts, goals, and validation rules. Each agent is a Claude instance running via the Claude Agent SDK, equipped with a single MCP tool (execute_code) that runs Python in an isolated Modal sandbox.
All three agents share the same runtime pipeline:
```
┌──────────────┐      ┌──────────────┐      ┌──────────────┐      ┌──────────────┐
│   Frontend   │ SSE  │   FastAPI    │ SDK  │    Claude    │ MCP  │    Modal     │
│  (Next.js)   │◄─────│   Backend    │◄────►│    Agent     │─────►│   Sandbox    │
│              │      │              │      │              │      │   (Python)   │
└──────────────┘      └──────────────┘      └──────────────┘      └──────────────┘
```
- Trigger: User clicks "Start EDA/Prep/Train" or sends a follow-up message.
- Backend (`routers/sessions.py`): Validates prerequisites and launches `run_agent()` as a background async task (see the sketch below).
- Agent setup (`services/agent.py`):
  - Loads the previous stage's report as context (prep reads the EDA report; train reads the prep report plus metadata).
  - Builds a stage-specific system prompt with session/experiment IDs, user instructions, and previous context injected.
  - Creates a per-call MCP server with a bound `execute_code` tool handler (concurrency-safe — each run gets its own handler instance).
  - Initializes the Claude Agent SDK with `max_turns=30` and `bypassPermissions` mode.
- Agentic loop: Claude generates Python code → calls `execute_code` → receives stdout/stderr → decides what to do next. This repeats for up to 30 turns.
- Post-stage hooks: After the agent finishes, the backend automatically runs validation, S3 sync, and metadata extraction.
- State update: Session state transitions to `{stage}_done`.
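For illustration, a minimal sketch of the trigger-to-launch step, assuming FastAPI. `routers/sessions.py`, `run_agent()`, and the one-running-task-per-session dict are named in this document; the route shape and the `previous_stage_done()` helper are assumptions.

```python
# Hypothetical sketch of the launch path in routers/sessions.py.
import asyncio
from fastapi import APIRouter, HTTPException

router = APIRouter()
_running_tasks: dict[str, asyncio.Task] = {}  # one running agent per session

@router.post("/sessions/{session_id}/stages/{stage}/start")
async def start_stage(session_id: str, stage: str):
    # Validate prerequisites: each stage requires the previous one to be done.
    if not previous_stage_done(session_id, stage):  # assumed helper
        raise HTTPException(409, f"cannot start {stage}: previous stage incomplete")
    # Launch the agent as a background async task; the request returns immediately.
    _running_tasks[session_id] = asyncio.create_task(run_agent(session_id, stage))
    return {"state": f"{stage}_running"}
```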
The `execute_code` tool is the only tool available to all three agents. It:
- Accepts a `code` string parameter (Python source code).
- Auto-saves the code as a numbered `.py` script to the stage's `scripts/` directory on the Modal Volume.
- Executes the code in a Modal sandbox — an isolated container with Python 3.11 and pre-installed ML libraries.
- Mounts the shared Modal Volume at `/data`, giving access to:
  - `/data/datasets/{experiment_id}/` — raw uploaded files
  - `/data/sessions/{session_id}/{stage}/` — stage workspace for outputs
- Injects a `trainable` SDK module into every execution, providing `log()` and `configure_dashboard()` for live metrics streaming.
- Returns stdout/stderr and the exit code to the agent.
- Has a 10-minute timeout per execution.
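A hedged sketch of how a per-run `execute_code` handler can be bound as a closure, which is what makes concurrent runs safe. The auto-save numbering and 600-second timeout follow this document; `sandbox.run()` is an assumed stand-in for the actual Modal invocation.

```python
from pathlib import Path

def make_execute_code_handler(session_id: str, stage: str, sandbox):
    """Build one handler per run_agent() call; no shared mutable state."""
    scripts_dir = Path(f"/data/sessions/{session_id}/{stage}/scripts")
    scripts_dir.mkdir(parents=True, exist_ok=True)
    step = 0

    async def execute_code(code: str) -> str:
        nonlocal step
        step += 1
        # Auto-save the code as a numbered .py script on the Modal Volume.
        (scripts_dir / f"step_{step:02d}.py").write_text(code)
        # Run in the isolated sandbox with the 10-minute cap.
        result = await sandbox.run(code, timeout=600)  # assumed sandbox API
        # Return stdout/stderr and the exit code to the agent.
        return f"exit_code={result.exit_code}\n{result.stdout}\n{result.stderr}"

    return execute_code
```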
Every agent action is published to the frontend via Server-Sent Events (SSE):
| Event | When | Data |
|---|---|---|
| `state_change` | Stage starts/ends | `{state: "eda_running"}` |
| `agent_message` | Agent produces text | `{text: "..."}` |
| `tool_start` | Code execution begins | `{tool: "execute_code", input: {code: "..."}}` |
| `tool_end` | Code execution finishes | `{tool: "execute_code", output: "..."}` |
| `code_output` | Stdout chunk streamed | `{stream: "stdout", text: "..."}` |
| `file_created` | New file detected | `{path, name, stage}` |
| `report_ready` | `report.md` content | `{content: "...", stage}` |
| `files_ready` | All stage files listed | `{files: [...], stage}` |
| `metric` | Training metric logged | `{step, metrics, run}` |
| `chart_config` | Dashboard layout defined | `{charts: [...]}` |
| `validation_result` | Post-stage validation | `{passed, warnings, errors}` |
| `s3_sync_complete` | Artifacts uploaded to S3 | `{files_synced, s3_prefix}` |
| `agent_error` | Agent crashed | `{error: "..."}` |
| `agent_aborted` | User cancelled | `{reason, stage}` |
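To make the transport concrete, here is a minimal publisher sketch, assuming one `asyncio.Queue` per session and FastAPI's `StreamingResponse`; the real backend wiring may differ.

```python
import asyncio
import json
from fastapi.responses import StreamingResponse

_queues: dict[str, asyncio.Queue] = {}

def publish(session_id: str, event: str, data: dict) -> None:
    """Push an event (e.g. "tool_start") onto the session's queue."""
    _queues.setdefault(session_id, asyncio.Queue()).put_nowait((event, data))

def sse_stream(session_id: str) -> StreamingResponse:
    """Drain the queue as a text/event-stream response."""
    queue = _queues.setdefault(session_id, asyncio.Queue())

    async def gen():
        while True:
            event, data = await queue.get()
            yield f"event: {event}\ndata: {json.dumps(data)}\n\n"

    return StreamingResponse(gen(), media_type="text/event-stream")
```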
Purpose: Understand the dataset — its shape, quality, distributions, and relationships — before any transformations.
Entry state: `created` → transitions to `eda_running`
Exit state: `eda_done`
- Lists dataset files at `/data/datasets/{experiment_id}/`
- Loads and inspects the data (shape, dtypes, `.head()`, `.describe()`)
- Performs statistical profiling:
  - Missing values (count + percentage per column)
  - Duplicate rows
  - Numeric columns: mean, std, min, max, skewness, outlier count (IQR method)
  - Categorical columns: cardinality, top values, rare categories (<1% frequency)
- Identifies the likely target column and problem type (classification vs. regression)
- For classification targets: plots class distribution, reports balance ratio
- For regression targets: plots distribution, reports skewness
- Computes feature-target correlations (Pearson for numeric, chi-squared for categorical)
- Checks for data leakage signals (perfect predictors, ID-like columns, date leakage)
- Checks for multicollinearity (correlation heatmap, flags |r| > 0.9 pairs)
- Creates visualizations (saved as PNGs to `figures/`)
- Writes a comprehensive `report.md`
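As a flavor of the profiling above, a short illustrative snippet (not the agent's actual generated code; the agent writes dataset-specific scripts):

```python
import pandas as pd

def profile(df: pd.DataFrame) -> pd.DataFrame:
    """Per-column stats: missing %, IQR outliers for numerics, cardinality otherwise."""
    rows = []
    for col in df.columns:
        s = df[col]
        row = {"column": col, "dtype": str(s.dtype), "missing_pct": 100 * s.isna().mean()}
        if pd.api.types.is_numeric_dtype(s):
            q1, q3 = s.quantile([0.25, 0.75])
            iqr = q3 - q1
            # IQR method: values outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR] count as outliers.
            row["outliers"] = int(((s < q1 - 1.5 * iqr) | (s > q3 + 1.5 * iqr)).sum())
            row["skewness"] = s.skew()
        else:
            row["cardinality"] = s.nunique()
        rows.append(row)
    return pd.DataFrame(rows)
```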
| Path | Description |
|---|---|
| `/data/sessions/{id}/eda/report.md` | Markdown report with findings and recommendations |
| `/data/sessions/{id}/eda/figures/*.png` | Charts: distributions, correlations, heatmaps |
| `/data/sessions/{id}/eda/data/` | Summary CSVs, profiling outputs |
| `/data/sessions/{id}/eda/scripts/` | Auto-saved Python scripts (`step_01_*.py`, etc.) |
pandas, numpy, matplotlib, seaborn, scikit-learn, duckdb (for large datasets >1M rows), statsmodels
- Always check shape, dtypes, missing, duplicates
- Report per-column statistics for both numeric and categorical features
- Identify and flag data leakage risks
- Recommend target column and problem type in the report
- Use DuckDB for aggregation queries on large datasets
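For example, the DuckDB guideline lets the agent aggregate a large file without loading it into pandas (the file path and columns here are hypothetical):

```python
import duckdb

con = duckdb.connect()
# Push the aggregation down to DuckDB; only the small result comes back.
summary = con.execute("""
    SELECT target, COUNT(*) AS n, AVG(amount) AS mean_amount
    FROM read_csv_auto('/data/datasets/my_experiment/transactions.csv')
    GROUP BY target
""").df()  # .df() returns a pandas DataFrame
```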
Purpose: Clean, transform, and split the data into train/val/test sets ready for model training.
Entry state: `eda_done` → transitions to `prep_running`
Exit state: `prep_done`
- Reads the EDA report from the previous stage for context
- Loads the raw dataset
- Identifies target column and problem type
- Splits into train/val/test FIRST (70/15/15, stratified for classification, random_state=42)
- Handles missing values (fit imputer on train only, transform all splits)
- Encodes categoricals:
  - One-hot for low cardinality (<10 unique values)
  - Target/ordinal encoding for high cardinality
  - Always fit on the train set only
- Engineers features if beneficial (interactions, polynomial, binning)
- Removes duplicates from train only (never from val/test)
- Scales/normalizes numeric features (fit on train only)
- Validates: no nulls in output, consistent dtypes, same columns across splits
- Saves processed data, metadata, and fitted pipeline
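A condensed sketch of the split-first, fit-on-train-only discipline (`df` is the raw dataset loaded earlier; the column lists and transformer choices are illustrative):

```python
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# df: raw dataset loaded earlier (pandas DataFrame with a "target" column)
num_cols, cat_cols = ["age", "amount"], ["country"]  # hypothetical features
features = num_cols + cat_cols

# Split FIRST: 70/15/15, stratified for classification, random_state=42.
train, rest = train_test_split(df, test_size=0.30, stratify=df["target"], random_state=42)
val, test = train_test_split(rest, test_size=0.50, stratify=rest["target"], random_state=42)

pre = ColumnTransformer([
    ("num", make_pipeline(SimpleImputer(strategy="median"), StandardScaler()), num_cols),
    ("cat", make_pipeline(SimpleImputer(strategy="most_frequent"),
                          OneHotEncoder(handle_unknown="ignore")), cat_cols),
])
X_train = pre.fit_transform(train[features])              # fit on train only
X_val, X_test = pre.transform(val[features]), pre.transform(test[features])
```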
| Path | Description |
|---|---|
| `/data/sessions/{id}/prep/data/train.parquet` | Training set |
| `/data/sessions/{id}/prep/data/val.parquet` | Validation set |
| `/data/sessions/{id}/prep/data/test.parquet` | Test set |
| `/data/sessions/{id}/prep/data/metadata.json` | Target column, features, splits, transforms |
| `/data/sessions/{id}/prep/data/prep_pipeline.pkl` | Fitted sklearn Pipeline/ColumnTransformer |
| `/data/sessions/{id}/prep/report.md` | Decisions and statistics report |
| `/data/sessions/{id}/prep/figures/` | Distribution/transform visualizations |
| `/data/sessions/{id}/prep/scripts/` | Auto-saved Python scripts |
```json
{
  "target_column": "...",
  "problem_type": "classification|regression",
  "features": ["..."],
  "categorical_features": ["..."],
  "numeric_features": ["..."],
  "n_classes": 3,
  "class_distribution": {"A": 100, "B": 50, "C": 30},
  "splits": {
    "train": {"rows": 700},
    "val": {"rows": 150},
    "test": {"rows": 150}
  },
  "transforms": {},
  "random_seed": 42,
  "original_shape": [1000, 15],
  "duplicates_removed": 5,
  "outliers_removed": 0
}
```

pandas, numpy, scikit-learn, pyarrow, duckdb, imbalanced-learn, category_encoders, pandera, statsmodels
This agent has the strictest rules of all three:
- Split BEFORE transforms: train/val/test split happens before any learned transformations
- Fit on train only: All transformers (scalers, encoders, imputers) are fitted on the training set, then applied to val/test
- Never use target statistics from the full dataset
- Save the fitted pipeline for reproducibility downstream
- Verify schema consistency across all three splits after transforms
After the prep agent completes:
- Validator (`services/validator.py`): Checks that the train/val/test parquet files exist, feature columns are consistent, and metadata.json is valid
- S3 sync (`services/s3_sync.py`): Uploads all prep artifacts to S3
- Metadata extractor (`services/metadata_extractor.py`): Reads the processed data and stores column metadata in the `ProcessedDatasetMeta` DB table
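A hedged sketch of what the prep validator's checks could look like; `services/validator.py` is named above, but this body, and the return shape mirroring the `validation_result` event, are assumptions:

```python
import json
from pathlib import Path
import pandas as pd

def validate_prep(session_id: str) -> dict:
    base = Path(f"/data/sessions/{session_id}/prep/data")
    errors: list[str] = []
    splits = {}
    for name in ("train", "val", "test"):
        path = base / f"{name}.parquet"
        if path.exists():
            splits[name] = pd.read_parquet(path)
        else:
            errors.append(f"missing {path.name}")
    # Feature columns must be identical across all three splits.
    if len(splits) == 3 and len({tuple(df.columns) for df in splits.values()}) != 1:
        errors.append("feature columns differ across splits")
    try:
        json.loads((base / "metadata.json").read_text())
    except (OSError, ValueError):
        errors.append("metadata.json missing or invalid")
    return {"passed": not errors, "warnings": [], "errors": errors}
```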
Purpose: Train, tune, and evaluate ML models on the prepared data, producing a final model with explainability analysis.
Entry state: `prep_done` → transitions to `train_running`
Exit state: `train_done`
- Reads the prep report AND `metadata.json` for structured context (target column, problem type, class distribution)
- Loads prepared data (train.parquet, val.parquet, test.parquet)
- Trains at least 2 different models (e.g., LogisticRegression + RandomForest, or XGBoost + LightGBM)
- Handles class imbalance if needed:
  - Tries `class_weight='balanced'` first
  - Then SMOTE (on the train set only, via imblearn)
  - Compares balanced vs. imbalanced training
- Tunes the most promising model with Optuna (30-50 trials) or sklearn cross-validation
- Evaluates all models on the validation set for comparison
- Runs test-set evaluation exactly once, on the final selected model
- Computes SHAP feature importance for the best model
- Generates a confusion matrix (classification) or residual plot (regression)
- Saves the model, metadata, and report
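An illustrative Optuna loop consistent with these steps, tuning on train and scoring on validation only (`X_train`/`y_train`/`X_val`/`y_val` come from the prepared splits; the XGBoost search space is an assumption):

```python
import optuna
from sklearn.metrics import f1_score
from xgboost import XGBClassifier

def objective(trial: optuna.Trial) -> float:
    model = XGBClassifier(
        max_depth=trial.suggest_int("max_depth", 3, 10),
        learning_rate=trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
        n_estimators=trial.suggest_int("n_estimators", 100, 600),
        random_state=42,
    )
    model.fit(X_train, y_train)
    # Score on the validation set only; the test set stays untouched until the end.
    return f1_score(y_val, model.predict(X_val), average="macro")

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)  # 30-50 trials per the steps above
```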
The train agent is the only agent that uses the `trainable` SDK for real-time metrics streaming:
```python
from trainable import log, configure_dashboard

# Step 1: Define the chart layout (once, before training)
configure_dashboard([
    {"title": "Loss", "metrics": ["train_loss", "val_loss"], "type": "line"},
    {"title": "Accuracy", "metrics": ["val_accuracy", "val_f1"], "type": "line"},
])

# Step 2: Log metrics every iteration
log(step=epoch, metrics={"train_loss": 0.5, "val_loss": 0.6}, run="xgboost")
```

How it works under the hood:

- `log()` and `configure_dashboard()` print JSON to stdout
- The sandbox streams stdout chunks in real time
- `services/metrics.py` parses the JSON lines, persists them to the `Metric` DB table, and publishes SSE events
- The frontend's `MetricsTab` component renders live Recharts line charts
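A hedged sketch of that parsing step, reusing the `publish()` helper sketched in the SSE section; the JSON wire format and `save_metric()` are assumptions:

```python
import json

def handle_stdout_line(session_id: str, line: str) -> None:
    try:
        msg = json.loads(line)
    except ValueError:
        # Not a trainable SDK line: forward it as a plain stdout chunk.
        publish(session_id, "code_output", {"stream": "stdout", "text": line})
        return
    if msg.get("type") == "metric":
        save_metric(session_id, msg)  # assumed: persist to the Metric DB table
        publish(session_id, "metric",
                {"step": msg["step"], "metrics": msg["metrics"], "run": msg.get("run")})
    elif msg.get("type") == "chart_config":
        publish(session_id, "chart_config", {"charts": msg["charts"]})
```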
| Path | Description |
|---|---|
| `/data/sessions/{id}/train/models/model.pkl` | Best model (joblib serialized) |
| `/data/sessions/{id}/train/data/metadata.json` | Model info, test metrics, feature importance |
| `/data/sessions/{id}/train/report.md` | Full training report |
| `/data/sessions/{id}/train/figures/` | SHAP plots, confusion matrix, learning curves |
| `/data/sessions/{id}/train/scripts/` | Auto-saved Python scripts |
```json
{
  "best_model": "XGBoost",
  "best_model_params": {"max_depth": 6, "learning_rate": 0.1},
  "models_evaluated": ["LogisticRegression", "RandomForest", "XGBoost"],
  "test_metrics": {"accuracy": 0.92, "f1": 0.89, "roc_auc": 0.95},
  "feature_importance": {"feature_1": 0.25, "feature_2": 0.18},
  "class_imbalance_strategy": "class_weight=balanced",
  "tuning_method": "optuna",
  "tuning_trials": 50,
  "random_seed": 42
}
```

scikit-learn, xgboost, lightgbm, optuna, imbalanced-learn, shap, matplotlib, pandas, numpy, statsmodels, torch (optional), tensorflow (optional)
- Never train on val/test data
- Never use test metrics to choose between models or tune hyperparameters
- Test set evaluation happens exactly once on the final selected model
- Cross-validation happens only on the training set
- SMOTE is applied only to the training set, never val/test
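One way to honor the SMOTE rule is an imblearn `Pipeline`, which applies samplers during `fit` (and inside CV training folds) but never at predict time (`X_train`, `y_train`, `X_val` come from the prepared splits):

```python
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier

clf = Pipeline([
    ("smote", SMOTE(random_state=42)),            # resamples only at fit time
    ("model", RandomForestClassifier(random_state=42)),
])
clf.fit(X_train, y_train)        # SMOTE oversamples the training data here
val_preds = clf.predict(X_val)   # validation data passes through unresampled
```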
After the train agent completes:
- Validator (`services/validator.py`): Checks that model.pkl exists, metrics were logged, and metadata.json is valid
- S3 sync (`services/s3_sync.py`): Uploads all train artifacts to S3
```
created
  │
  ├─► eda_running ──► eda_done
  │                     │
  │                     ├─► prep_running ──► prep_done
  │                     │                       │
  │                     │                       ├─► train_running ──► train_done
  │                     │                       │
  │                     │                       └─► failed / cancelled
  │                     │
  │                     └─► failed / cancelled
  │
  └─► failed / cancelled
```
Each stage requires the previous stage to be complete. The backend validates this before launching an agent.
- One agent runs per session at a time (tracked in the `_running_tasks` dict).
- If a user sends a follow-up message while an agent is running, the current agent is silently aborted and a new one launches with the conversation history.
- Abort is implemented via `asyncio.Task.cancel()` with a 5-second grace period.
- Each `run_agent()` call creates its own MCP server and tool handler — no shared mutable state between concurrent sessions.
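A minimal sketch of the abort path under these constraints; `asyncio.Task.cancel()`, the 5-second grace period, and `_running_tasks` come from this document; the rest is assumed:

```python
import asyncio

async def abort_agent(session_id: str) -> None:
    task = _running_tasks.pop(session_id, None)
    if task is None or task.done():
        return
    task.cancel()
    try:
        # Give the agent up to 5 seconds to unwind cleanly.
        await asyncio.wait_for(task, timeout=5)
    except (asyncio.CancelledError, asyncio.TimeoutError):
        pass  # cancelled as expected, or the grace period expired
```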
When a user sends a message with `run_agent: true`:
- Any running agent for that session is silently cancelled (no abort SSE events).
- Conversation history is loaded from the DB.
- A new agent launches with the user's message as the prompt and prior messages appended to the system prompt.
- The agent continues working in the context of the current stage.
All three agents have access to the same sandbox environment:
| Category | Libraries |
|---|---|
| Data | pandas, numpy, pyarrow, openpyxl, duckdb |
| Visualization | matplotlib, seaborn |
| ML (classical) | scikit-learn, xgboost, lightgbm |
| ML (deep learning) | torch, torchvision, torchaudio, tensorflow-cpu |
| Tuning | optuna |
| Imbalance | imbalanced-learn |
| Encoding | category_encoders |
| Validation | pandera |
| Explainability | shap |
| Statistics | statsmodels |
- Python version: 3.11
- Base image: Debian Slim
- Timeout: 10 minutes per execution (600 seconds)
- GPU: Optional, passed via the `gpu` parameter (e.g., `"T4"`, `"A10G"`)
- Volume: Modal Volume `trainable-data` mounted at `/data`
- Stdout: Unbuffered (`python -u`) for real-time streaming