diff --git a/README.md b/README.md index a5a045a..3a61a1b 100644 --- a/README.md +++ b/README.md @@ -123,8 +123,8 @@ All options can be set via environment variables or CLI flags. CLI flags take pr |---------------------|----------|---------|-------------| | `MIGRATION_RESOURCES` | `--resources`, `-r` | `all` | Comma-separated list of resources to migrate. Options: `all`, `ai_secrets`, `roles`, `groups`, `datasets`, `project_tags`, `span_iframes`, `functions`, `prompts`, `project_scores`, `experiments`, `logs`, `views` | | `MIGRATION_PROJECTS` | `--projects`, `-p` | *(all projects)* | Comma-separated list of project names to migrate | -| `MIGRATION_CREATED_AFTER` | `--created-after` | *(none)* | Only migrate data created on or after this date (**inclusive**: `>=`). Format: `YYYY-MM-DD` or ISO-8601 | -| `MIGRATION_CREATED_BEFORE` | `--created-before` | *(none)* | Only migrate data created before this date (**exclusive**: `<`). Format: `YYYY-MM-DD` or ISO-8601 | +| `MIGRATION_CREATED_AFTER` | `--created-after` | *(none)* | Only applies to resources that support created-time filtering. Currently this affects project logs event streaming and experiment listing. Migrates items with `created >=` this value (**inclusive**). Format: `YYYY-MM-DD` or ISO-8601 | +| `MIGRATION_CREATED_BEFORE` | `--created-before` | *(none)* | Only applies to resources that support created-time filtering. Currently this affects project logs event streaming and experiment listing. Migrates items with `created <` this value (**exclusive**). Format: `YYYY-MM-DD` or ISO-8601 | #### Logging @@ -146,7 +146,7 @@ All options can be set via environment variables or CLI flags. CLI flags take pr | `MIGRATION_BATCH_SIZE` | — | `100` | Number of resources to process per batch | | `MIGRATION_RETRY_ATTEMPTS` | — | `3` | Number of retry attempts for failed operations (0 = no retries) | | `MIGRATION_RETRY_DELAY` | — | `1.0` | Initial retry delay in seconds (exponential backoff) | -| `MIGRATION_MAX_CONCURRENT` | — | `10` | Maximum concurrent **projects** (bounded project-level parallelism) | +| `MIGRATION_MAX_CONCURRENT` | — | `1` | Maximum concurrent **projects** (bounded project-level parallelism) | | `MIGRATION_CHECKPOINT_INTERVAL` | — | `50` | Write checkpoint every N successful operations | #### Parallelization Tuning @@ -160,6 +160,7 @@ All options can be set via environment variables or CLI flags. CLI flags take pr #### Streaming Migration (Logs, Experiments, Datasets) These settings control BTQL-based streaming for high-volume resources. +`MIGRATION_CREATED_AFTER` and `MIGRATION_CREATED_BEFORE` are not universal streaming filters: they currently affect project logs event migration and experiment selection, but not dataset event migration. | Environment Variable | CLI Flag | Default | Description | |---------------------|----------|---------|-------------| @@ -403,45 +404,48 @@ On resume: skips 1-30 (done), resumes experiment 31 from saved `_pagination_key` ## Parallelization -The migration tool uses **three levels of parallelization** that work together to reduce migration time. All levels are enabled by default and are controlled by the three env vars described in [Parallelization Tuning](#parallelization-tuning). +The migration tool currently uses **two active levels of concurrency** plus pipelined event streaming. The env vars in [Parallelization Tuning](#parallelization-tuning) still matter, but the within-project resource-type DAG concurrency described below has not been implemented yet. ### How It Works ``` ┌─────────────────────────────────────────────────────────────┐ -│ Level 1: Inter-Project (MIGRATION_MAX_CONCURRENT=10) │ -│ Projects migrate concurrently (up to 10 at a time) │ +│ Level 1: Inter-Project (MIGRATION_MAX_CONCURRENT=1) │ +│ Projects migrate concurrently (up to 1 at a time by default) │ │ │ │ ┌─────────────────────────────────────────────────────────┐ │ -│ │ Level 2: Inter-Resource-Type (DAG scheduler) │ │ -│ │ Within each project, independent resource types │ │ -│ │ (e.g. Datasets + Tags + Iframes) run concurrently. │ │ -│ │ Dependent types wait for prerequisites to finish. │ │ +│ │ Level 2: Intra-Resource-Type │ │ +│ │ Within a single resource type, some migrators can │ │ +│ │ process multiple items concurrently │ │ +│ │ (MIGRATION_MAX_CONCURRENT_RESOURCES). │ │ │ │ │ │ -│ │ ┌─────────────────────────────────────────────────────┐ │ │ -│ │ │ Level 3: Intra-Resource-Type │ │ │ -│ │ │ Within each type, individual items migrate │ │ │ -│ │ │ concurrently (MIGRATION_MAX_CONCURRENT_RESOURCES) │ │ │ -│ │ │ │ │ │ -│ │ │ For streaming resources (logs/datasets/exps): │ │ │ -│ │ │ - Multiple event streams run in parallel │ │ │ -│ │ │ - Each stream prefetches the next page while │ │ │ -│ │ │ inserting the current (STREAMING_PIPELINE) │ │ │ -│ │ └─────────────────────────────────────────────────────┘ │ │ +│ │ Resource types themselves still run sequentially │ │ +│ │ within each project. │ │ +│ │ │ │ +│ │ For streaming resources (logs/datasets/exps): │ │ +│ │ - Each stream can prefetch the next page while │ │ +│ │ inserting the current one │ │ +│ │ - Dataset/experiment streams are grouped for fetch │ │ +│ │ efficiency, not scheduled as independent parallel │ │ +│ │ DAG tasks within a project │ │ │ └─────────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────┘ ``` ### Level Details -**Inter-Project Concurrency** (`MIGRATION_MAX_CONCURRENT`, default 10) -Multiple projects migrate at the same time. Each project runs its own DAG-scheduled resource pipeline independently. This was the only parallelism the tool had before the current changes. - -**Inter-Resource-Type Concurrency** (automatic, DAG-based) -Within each project, resource types are organized in a dependency graph. Types whose dependencies are satisfied start immediately instead of waiting in a fixed sequence. For example, Datasets, Project Tags, and Span Iframes all start at the same time because they have no mutual dependencies. +**Inter-Project Concurrency** (`MIGRATION_MAX_CONCURRENT`, default 1) +Multiple projects migrate at the same time. Each project runs independently, bounded by this semaphore. **Intra-Resource-Type Concurrency** (`MIGRATION_MAX_CONCURRENT_RESOURCES`, default 5) -Within a single resource type, multiple individual items migrate concurrently. For example, if a project has 50 functions, up to 5 are created at once instead of one-by-one. For streaming resources (datasets, experiments), this also controls how many event streams run in parallel -- e.g. events for 5 datasets copy simultaneously. +Within a single resource type, migrators that use the base batch executor can process multiple individual items concurrently. For example, if a project has 50 functions, up to 5 may be created at once instead of one-by-one. + +Resource types themselves are still processed sequentially within a project in the current implementation. + +For streaming resources: +- `logs` is a single stream per project. +- `datasets` and `experiments` group multiple ids into one BTQL fetch stream for efficiency. +- `MIGRATION_MAX_CONCURRENT_RESOURCES` does not currently create multiple independent resource-type DAG tasks within a single project. **Pipelined Event Streaming** (`MIGRATION_STREAMING_PIPELINE`, default true) For each individual event stream (logs, dataset records, experiment events), the next BTQL page is prefetched from the source while the current page's batches are being inserted into the destination. This overlaps source reads with destination writes, reducing idle time for large migrations. @@ -458,7 +462,7 @@ State mutations (ID mappings, checkpoint files) are protected by `asyncio.Lock` |----------|---------------------| | **Small migration** (<5 projects, <100 resources) | Defaults work well. No tuning needed. | | **Many small projects** (50+ projects, small data) | Increase `MIGRATION_MAX_CONCURRENT=20` for more project-level parallelism. | -| **Few projects with many resources** (e.g. 500 experiments in one project) | Increase `MIGRATION_MAX_CONCURRENT_RESOURCES=10` for more intra-type parallelism. | +| **Few projects with many resources** (e.g. 500 experiments in one project) | `MIGRATION_MAX_CONCURRENT_RESOURCES` helps only for migrators that support per-item fanout. Streaming resources within one project still run mostly as a single grouped stream. | | **Large event streams** (TB-scale logs) | Defaults are good. Pipeline is on by default. Consider increasing `MIGRATION_MAX_CONCURRENT_REQUESTS=40` if the API can handle it. | | **Rate-limited API** (frequent 429s) | *Decrease* `MIGRATION_MAX_CONCURRENT_RESOURCES=2` and `MIGRATION_MAX_CONCURRENT_REQUESTS=10`. The tool handles 429s with backoff, but fewer concurrent requests reduces throttling. | | **Debugging or sequential run** | Set `MIGRATION_MAX_CONCURRENT_RESOURCES=1` and `MIGRATION_STREAMING_PIPELINE=false` for deterministic, sequential execution. | diff --git a/braintrust_migrate/cli.py b/braintrust_migrate/cli.py index aeb054b..fbf1c24 100644 --- a/braintrust_migrate/cli.py +++ b/braintrust_migrate/cli.py @@ -559,72 +559,27 @@ async def _run_migration(config: Config, *, resume_run_dir: Path | None) -> None ) try: - # Import here to avoid circular dependencies - from braintrust_migrate.client import create_client_pair - progress.update( migration_task, description="🔗 Connecting to organizations...", ) - async with create_client_pair( - config.source, - config.destination, - config.migration, - ) as (source_client, dest_client): - progress.update( - migration_task, - description="🔍 Discovering projects...", - ) - - # Discover projects first to set up progress tracking - projects = await orchestrator._discover_projects( - source_client, dest_client - ) - num_projects = len(projects) - - # Set up project-based progress tracking - # We'll show organization migration as "prep", then project 1/N, 2/N, etc. - total_projects = max( - num_projects, 1 - ) # At least 1 to avoid division by zero + progress.update(migration_task, description="🚀 Starting migration...") - progress.update( - migration_task, - total=total_projects, - completed=0, - description=f"📋 Found {num_projects} projects to migrate", - ) - - # Show project list for user awareness - if num_projects > 0: - project_names = [ - p["name"] for p in projects[:PROJECTS_PREVIEW_LIMIT] - ] - if num_projects > PROJECTS_PREVIEW_LIMIT: - project_names.append( - f"... and {num_projects - PROJECTS_PREVIEW_LIMIT} more" - ) - console.print(f"[blue]Projects:[/blue] {', '.join(project_names)}") - - progress.update(migration_task, description="🚀 Starting migration...") - - # Run migration with project-based progress tracking - results = await _run_migration_with_progress( - orchestrator, - progress, - migration_task, - num_projects, - checkpoint_dir=checkpoint_dir, - is_resuming=is_resuming, - ) + results = await _run_migration_with_progress( + orchestrator, + progress, + migration_task, + checkpoint_dir=checkpoint_dir, + ) - # Final update - progress.update( - migration_task, - completed=total_projects, - description="✅ Migration completed", - ) + total_projects = max(results["summary"]["total_projects"], 1) + progress.update( + migration_task, + completed=total_projects, + total=total_projects, + description="✅ Migration completed", + ) # Display results _display_results(results) @@ -661,10 +616,8 @@ async def _run_migration_with_progress( orchestrator, progress, migration_task, - total_projects, *, checkpoint_dir: Path, - is_resuming: bool, ): """Run migration with detailed progress updates. @@ -672,328 +625,204 @@ async def _run_migration_with_progress( orchestrator: MigrationOrchestrator instance progress: Rich progress instance migration_task: Progress task ID - total_projects: Total number of projects to migrate Returns: Migration results """ - from datetime import datetime - - import structlog - - from braintrust_migrate.client import create_client_pair - - # Track progress during migration - start_time = datetime.now() - logger = structlog.get_logger(__name__) - projects_completed = 0 # Track streaming throughput from progress hooks (per-project/resource). stream_totals: dict[tuple[str, str], dict[str, float]] = {} # Use the resolved run checkpoint dir (either new timestamp dir or resume dir) orchestrator.config.ensure_checkpoint_dir() checkpoint_dir.mkdir(parents=True, exist_ok=True) + def _projects_discovered(projects: list[dict[str, Any]]) -> None: + num_projects = len(projects) + progress.update( + migration_task, + total=max(num_projects, 1), + completed=0, + description=f"📋 Found {num_projects} projects to migrate", + ) + if num_projects > 0: + project_names = [p["name"] for p in projects[:PROJECTS_PREVIEW_LIMIT]] + if num_projects > PROJECTS_PREVIEW_LIMIT: + project_names.append( + f"... and {num_projects - PROJECTS_PREVIEW_LIMIT} more" + ) + console.print(f"[blue]Projects:[/blue] {', '.join(project_names)}") - total_results = { - "start_time": start_time.isoformat(), - "checkpoint_dir": str(checkpoint_dir), - "organization_resources": {}, - "projects": {}, - "summary": { - "total_projects": 0, - "total_resources": 0, - "migrated_resources": 0, - "skipped_resources": 0, - "failed_resources": 0, - "errors": [], - }, - } - - async with create_client_pair( - orchestrator.config.source, - orchestrator.config.destination, - orchestrator.config.migration, - ) as (source_client, dest_client): - # Discover projects - projects = await orchestrator._discover_projects(source_client, dest_client) - total_results["summary"]["total_projects"] = len(projects) - - # Create global ID mapping registry - global_id_mappings = {} - for project in projects: - global_id_mappings[project["source_id"]] = project["dest_id"] - - # STEP 1: Migrate organization-scoped resources first (doesn't count toward project progress) + def _organization_start() -> None: progress.update( migration_task, description="🏢 Migrating organization resources...", - completed=projects_completed, + completed=0, ) - try: - org_results = await orchestrator._migrate_organization_resources( - source_client, - dest_client, - checkpoint_dir, - global_id_mappings, - ) - total_results["organization_resources"] = org_results - - # Don't increment project counter for org resources, just update description - progress.update( - migration_task, - description="✅ Organization resources migrated", - completed=projects_completed, - ) - - # Aggregate organization results - summary = total_results["summary"] - summary["total_resources"] += org_results.get("total_resources", 0) - summary["migrated_resources"] += org_results.get("migrated_resources", 0) - summary["skipped_resources"] += org_results.get("skipped_resources", 0) - summary["failed_resources"] += org_results.get("failed_resources", 0) - summary["errors"].extend(org_results.get("errors", [])) - - except Exception as e: - progress.update( - migration_task, - description="❌ Organization migration failed", - completed=projects_completed, - ) - logger.error("Organization resource migration failed", error=str(e)) - total_results["summary"]["errors"].append( - {"type": "org_error", "error": str(e)} - ) - - # STEP 2: Migrate project-scoped resources (1 per project) - for i, project in enumerate(projects): - project_name = project["name"] - - progress.update( - migration_task, - description=f"📁 Migrating project {i + 1} of {total_projects}: {project_name}", - completed=projects_completed, - ) + def _organization_complete(_org_results: dict[str, Any]) -> None: + progress.update( + migration_task, + description="✅ Organization resources migrated", + completed=0, + ) - # Print project header for visibility - console.print( - f"\n[bold blue]📁 {project_name}[/bold blue] ({i + 1}/{total_projects})" - ) + def _project_start(project: dict[str, Any], index: int, total: int) -> None: + progress.update( + migration_task, + description=f"📁 Running project {index} of {total}: {project['name']}", + ) + console.print(f"\n[bold blue]📁 {project['name']}[/bold blue] ({index}/{total})") + + def _progress_factory_factory(project: dict[str, Any]): + project_name = project["name"] + stream_task_ids: dict[str, TaskID] = {} + + def _stream_progress_factory( + resource_name: str, + *, + _project_name: str = project_name, + _progress: Progress = progress, + _stream_task_ids: dict[str, TaskID] = stream_task_ids, + ): + label_map = { + "logs": "🧾 Logs", + "experiments": "🧪 Experiment events", + "datasets": "📚 Dataset events", + } + label = label_map.get(resource_name, resource_name) - try: - # Create per-project streaming progress tasks lazily. - stream_task_ids: dict[str, TaskID] = {} - - def _stream_progress_factory( - resource_name: str, - *, - _project_name: str = project_name, - _progress: Progress = progress, - _stream_task_ids: dict[str, TaskID] = stream_task_ids, - ): - label_map = { - "logs": "🧾 Logs", - "experiments": "🧪 Experiment events", - "datasets": "📚 Dataset events", - } - label = label_map.get(resource_name, resource_name) - - if resource_name not in _stream_task_ids: - _stream_task_ids[resource_name] = _progress.add_task( - f"{label} ({_project_name}): starting…", - total=None, - ) - task_id = _stream_task_ids[resource_name] - - def hook(update: dict[str, Any], *, _label: str = label) -> None: - # Common fields - phase = update.get("phase") - fetched = update.get("fetched_total") - inserted = update.get("inserted_total") - inserted_bytes = update.get("inserted_bytes_total") - _ = update.get("skipped_seen_total") - _ = update.get("skipped_deleted_total") - page_num = update.get("page_num") - inserted_last = update.get("inserted_last") - inserted_bytes_last = update.get("inserted_bytes_last") - insert_seconds = update.get("insert_seconds") - pending_buffered_rows = update.get("pending_buffered_rows") - pending_buffered_bytes = update.get("pending_buffered_bytes") - - gb_part = "" - if isinstance(inserted_bytes, int): - gb = inserted_bytes / 1_000_000_000 - gb_part = f" gb={gb:.3f}" - stream_totals[ - (_project_name, str(update.get("resource"))) - ] = { - "inserted_rows": float(inserted or 0), - "inserted_gb": gb, - } - - batch_rate_part = "" - if ( - isinstance(inserted_last, int) - and isinstance(inserted_bytes_last, int) - and isinstance(insert_seconds, int | float) - and insert_seconds > 0 - ): - rps = inserted_last / float(insert_seconds) - gbps = (inserted_bytes_last / 1_000_000_000) / float( - insert_seconds - ) - batch_rate_part = f" rps={rps:.0f} gbps={gbps:.3f}" - - # Per-resource context - page_part = f" page={page_num}" if page_num is not None else "" - pending_part = "" - if ( - isinstance(pending_buffered_rows, int) - and pending_buffered_rows > 0 - ): - pending_part = f" buffered={pending_buffered_rows}" - if isinstance(pending_buffered_bytes, int): - pending_gb = pending_buffered_bytes / 1_000_000_000 - pending_part += f" pending_gb={pending_gb:.3f}" - - if update.get("resource") == "experiment_events": - desc = ( - f"{_label} ({_project_name}):{page_part}" - f" fetched={fetched} inserted={inserted}" - f"{gb_part}{pending_part}{batch_rate_part}" - ) - elif update.get("resource") == "dataset_events": - desc = ( - f"{_label} ({_project_name}):{page_part}" - f" fetched={fetched} inserted={inserted}" - f"{gb_part}{pending_part}{batch_rate_part}" - ) - else: - # logs - desc = ( - f"{_label} ({_project_name}):{page_part}" - f" fetched={fetched} inserted={inserted}" - f"{gb_part}{pending_part}{batch_rate_part}" - ) - - if phase == "done": - _progress.update( - task_id, description=desc, total=1, completed=1 - ) - else: - _progress.update(task_id, description=desc) - - return hook - - # Create callback for real-time resource feedback - def _resource_callback( - resource_name: str, - results: dict[str, Any], - *, - _project_name: str = project_name, - ) -> None: - """Print real-time feedback when each resource type completes.""" - total = results.get("total", 0) - migrated = results.get("migrated", 0) - skipped = results.get("skipped", 0) - failed = results.get("failed", 0) - - # Emoji labels for resource types - label_map = { - "datasets": "📚 datasets", - "experiments": "🧪 experiments", - "logs": "🧾 logs", - "prompts": "💬 prompts", - "functions": "⚙️ functions", - "project_tags": "🏷️ project_tags", - "project_scores": "📊 project_scores", - "views": "👁️ views", - "span_iframes": "🖼️ span_iframes", - } - label = label_map.get(resource_name, f" {resource_name}") - - # Build status parts - if total == 0: - status = "[dim]none found[/dim]" - else: - parts = [] - if migrated > 0: - parts.append(f"[green]{migrated} migrated[/green]") - if skipped > 0: - parts.append(f"[yellow]{skipped} skipped[/yellow]") - if failed > 0: - parts.append(f"[red]{failed} failed[/red]") - status = ", ".join(parts) if parts else "[dim]0[/dim]" - - console.print(f" {label}: {status}") - - project_results = await orchestrator._migrate_project( - project, - source_client, - dest_client, - checkpoint_dir, - global_id_mappings, - progress_factory=_stream_progress_factory, - resource_callback=_resource_callback, + if resource_name not in _stream_task_ids: + _stream_task_ids[resource_name] = _progress.add_task( + f"{label} ({_project_name}): starting…", + total=None, ) + task_id = _stream_task_ids[resource_name] + + def hook(update: dict[str, Any], *, _label: str = label) -> None: + phase = update.get("phase") + fetched = update.get("fetched_total") + inserted = update.get("inserted_total") + inserted_bytes = update.get("inserted_bytes_total") + page_num = update.get("page_num") + inserted_last = update.get("inserted_last") + inserted_bytes_last = update.get("inserted_bytes_last") + insert_seconds = update.get("insert_seconds") + pending_buffered_rows = update.get("pending_buffered_rows") + pending_buffered_bytes = update.get("pending_buffered_bytes") + + gb_part = "" + if isinstance(inserted_bytes, int): + gb = inserted_bytes / 1_000_000_000 + gb_part = f" gb={gb:.3f}" + stream_totals[(_project_name, str(update.get("resource")))] = { + "inserted_rows": float(inserted or 0), + "inserted_gb": gb, + } - total_results["projects"][project_name] = project_results - - # Aggregate project results - summary = total_results["summary"] - summary["total_resources"] += project_results.get("total_resources", 0) - summary["migrated_resources"] += project_results.get( - "migrated_resources", 0 - ) - summary["skipped_resources"] += project_results.get( - "skipped_resources", 0 - ) - summary["failed_resources"] += project_results.get( - "failed_resources", 0 + batch_rate_part = "" + if ( + isinstance(inserted_last, int) + and isinstance(inserted_bytes_last, int) + and isinstance(insert_seconds, int | float) + and insert_seconds > 0 + ): + rps = inserted_last / float(insert_seconds) + gbps = (inserted_bytes_last / 1_000_000_000) / float( + insert_seconds + ) + batch_rate_part = f" rps={rps:.0f} gbps={gbps:.3f}" + + page_part = f" page={page_num}" if page_num is not None else "" + pending_part = "" + if isinstance(pending_buffered_rows, int) and pending_buffered_rows > 0: + pending_part = f" buffered={pending_buffered_rows}" + if isinstance(pending_buffered_bytes, int): + pending_gb = pending_buffered_bytes / 1_000_000_000 + pending_part += f" pending_gb={pending_gb:.3f}" + + desc = ( + f"{_label} ({_project_name}):{page_part}" + f" fetched={fetched} inserted={inserted}" + f"{gb_part}{pending_part}{batch_rate_part}" ) - summary["errors"].extend(project_results.get("errors", [])) - - # Complete this project - projects_completed += 1 - progress.update( - migration_task, - description=f"✅ Completed project {projects_completed} of {total_projects}: {project_name}", - completed=projects_completed, - ) + if phase == "done": + _progress.update(task_id, description=desc, total=1, completed=1) + else: + _progress.update(task_id, description=desc) + + return hook + + return _stream_progress_factory + + def _resource_callback_factory(project: dict[str, Any]): + project_name = project["name"] + + def _resource_callback( + resource_name: str, + results: dict[str, Any], + *, + _project_name: str = project_name, + ) -> None: + total = results.get("total", 0) + migrated = results.get("migrated", 0) + skipped = results.get("skipped", 0) + failed = results.get("failed", 0) + + label_map = { + "datasets": "📚 datasets", + "experiments": "🧪 experiments", + "logs": "🧾 logs", + "prompts": "💬 prompts", + "functions": "⚙️ functions", + "project_tags": "🏷️ project_tags", + "project_scores": "📊 project_scores", + "views": "👁️ views", + "span_iframes": "🖼️ span_iframes", + } + label = label_map.get(resource_name, f" {resource_name}") - except Exception as e: - # Still increment project counter even on failure - projects_completed += 1 - logger.error( - "Project migration failed", project=project_name, error=str(e) - ) - total_results["summary"]["errors"].append( - { - "type": "project_error", - "project": project_name, - "error": str(e), - } - ) - progress.update( - migration_task, - description=f"❌ Failed project {projects_completed} of {total_projects}: {project_name}", - completed=projects_completed, - ) + if total == 0: + status = "[dim]none found[/dim]" + else: + parts = [] + if migrated > 0: + parts.append(f"[green]{migrated} migrated[/green]") + if skipped > 0: + parts.append(f"[yellow]{skipped} skipped[/yellow]") + if failed > 0: + parts.append(f"[red]{failed} failed[/red]") + status = ", ".join(parts) if parts else "[dim]0[/dim]" + + console.print(f" {label}: {status}") + + return _resource_callback + + def _project_complete( + project: dict[str, Any], + result: dict[str, Any] | Exception, + completed: int, + total: int, + ) -> None: + status = "✅ Completed" if not isinstance(result, Exception) else "❌ Failed" + progress.update( + migration_task, + description=f"{status} project {completed} of {total}: {project['name']}", + completed=completed, + ) - # Finalize results - end_time = datetime.now() - duration = (end_time - start_time).total_seconds() - - total_results.update( - { - "end_time": end_time.isoformat(), - "duration_seconds": duration, - "success": total_results["summary"]["failed_resources"] == 0 - and len(total_results["summary"]["errors"]) == 0, - } + total_results = await orchestrator.migrate_all( + checkpoint_dir=checkpoint_dir, + on_projects_discovered=_projects_discovered, + on_organization_start=_organization_start, + on_organization_complete=_organization_complete, + on_project_start=_project_start, + on_project_complete=_project_complete, + progress_factory_factory=_progress_factory_factory, + resource_callback_factory=_resource_callback_factory, ) + duration = float(total_results.get("duration_seconds", 0.0)) + # Add a lightweight throughput summary based on streaming progress hooks. inserted_rows_total = 0.0 inserted_gb_total = 0.0 @@ -1007,10 +836,6 @@ def _resource_callback( "gb_per_sec": (inserted_gb_total / duration) if duration > 0 else None, } - # Generate detailed migration report - report_path = orchestrator._generate_migration_report(total_results, checkpoint_dir) - total_results["report_path"] = str(report_path) - return total_results diff --git a/braintrust_migrate/config.py b/braintrust_migrate/config.py index 42570ff..ada99fc 100644 --- a/braintrust_migrate/config.py +++ b/braintrust_migrate/config.py @@ -95,7 +95,7 @@ class MigrationConfig(BaseModel): description="Initial delay between retries in seconds", ) max_concurrent: int = Field( - default=10, ge=1, le=50, description="Maximum number of concurrent operations" + default=1, ge=1, le=50, description="Maximum number of concurrent operations" ) max_concurrent_resources: int = Field( default=5, @@ -340,7 +340,7 @@ def from_env(cls) -> "Config": batch_size = int(os.getenv("MIGRATION_BATCH_SIZE", "100")) retry_attempts = int(os.getenv("MIGRATION_RETRY_ATTEMPTS", "3")) retry_delay = float(os.getenv("MIGRATION_RETRY_DELAY", "1.0")) - max_concurrent = int(os.getenv("MIGRATION_MAX_CONCURRENT", "10")) + max_concurrent = int(os.getenv("MIGRATION_MAX_CONCURRENT", "1")) max_concurrent_resources = int( os.getenv("MIGRATION_MAX_CONCURRENT_RESOURCES", "5") ) diff --git a/braintrust_migrate/orchestration.py b/braintrust_migrate/orchestration.py index aae97e7..57b2aa4 100644 --- a/braintrust_migrate/orchestration.py +++ b/braintrust_migrate/orchestration.py @@ -33,6 +33,14 @@ MAX_SUMMARY_ITEMS_PER_REASON: int = 10 T = TypeVar("T") +ProjectHook = Callable[[dict[str, Any], int, int], None] +ProjectResultHook = Callable[[dict[str, Any], dict[str, Any] | Exception, int, int], None] +ProjectsDiscoveredHook = Callable[[list[dict[str, Any]]], None] +OrganizationResultHook = Callable[[dict[str, Any]], None] +ProgressFactory = Callable[[str], Callable[[dict[str, Any]], None]] +ProgressFactoryFactory = Callable[[dict[str, Any]], ProgressFactory | None] +ResourceCallback = Callable[[str, dict[str, Any]], None] +ResourceCallbackFactory = Callable[[dict[str, Any]], ResourceCallback | None] # Project-scoped resource dependency graph used by DAG scheduler unit tests. # This graph is intentionally kept as module-level compatibility surface. @@ -187,7 +195,18 @@ def __init__(self, config: Config) -> None: self.config = config self._logger = logger.bind(orchestrator=True) - async def migrate_all(self) -> dict[str, Any]: + async def migrate_all( + self, + *, + checkpoint_dir: Path | None = None, + on_projects_discovered: ProjectsDiscoveredHook | None = None, + on_organization_start: Callable[[], None] | None = None, + on_organization_complete: OrganizationResultHook | None = None, + on_project_start: ProjectHook | None = None, + on_project_complete: ProjectResultHook | None = None, + progress_factory_factory: ProgressFactoryFactory | None = None, + resource_callback_factory: ResourceCallbackFactory | None = None, + ) -> dict[str, Any]: """Migrate all resources from source to destination organization. Returns: @@ -196,9 +215,10 @@ async def migrate_all(self) -> dict[str, Any]: start_time = datetime.now() self._logger.info("Starting complete migration") - # Create timestamped checkpoint directory - timestamp = start_time.strftime("%Y%m%d_%H%M%S") - checkpoint_dir = self.config.ensure_checkpoint_dir() / timestamp + # Create timestamped checkpoint directory unless one was provided. + if checkpoint_dir is None: + timestamp = start_time.strftime("%Y%m%d_%H%M%S") + checkpoint_dir = self.config.ensure_checkpoint_dir() / timestamp total_results = { "start_time": start_time.isoformat(), @@ -224,6 +244,8 @@ async def migrate_all(self) -> dict[str, Any]: # Discover projects projects = await self._discover_projects(source_client, dest_client) total_results["summary"]["total_projects"] = len(projects) + if on_projects_discovered is not None: + on_projects_discovered(projects) self._logger.info(f"Discovered {len(projects)} projects to migrate") @@ -242,12 +264,16 @@ async def migrate_all(self) -> dict[str, Any]: # STEP 1: Migrate organization-scoped resources once self._logger.info("Migrating organization-scoped resources") + if on_organization_start is not None: + on_organization_start() org_results = await self._migrate_organization_resources( source_client, dest_client, checkpoint_dir, global_id_mappings, ) + if on_organization_complete is not None: + on_organization_complete(org_results) total_results["organization_resources"] = org_results @@ -270,15 +296,51 @@ async def migrate_all(self) -> dict[str, Any]: total_projects=len(projects), ) + completed_projects = 0 + + async def _run_project_with_hooks( + project: dict[str, Any], index: int + ) -> dict[str, Any]: + nonlocal completed_projects + if on_project_start is not None: + on_project_start(project, index, len(projects)) + + try: + project_results = await self._migrate_project( + project, + source_client, + dest_client, + checkpoint_dir, + global_id_mappings, + progress_factory=( + progress_factory_factory(project) + if progress_factory_factory is not None + else None + ), + resource_callback=( + resource_callback_factory(project) + if resource_callback_factory is not None + else None + ), + ) + except Exception as e: + completed_projects += 1 + if on_project_complete is not None: + on_project_complete( + project, e, completed_projects, len(projects) + ) + raise + + completed_projects += 1 + if on_project_complete is not None: + on_project_complete( + project, project_results, completed_projects, len(projects) + ) + return project_results + project_coros = [ - self._migrate_project( - project, - source_client, - dest_client, - checkpoint_dir, - global_id_mappings, # shared mappings across all projects - ) - for project in projects + _run_project_with_hooks(project, index) + for index, project in enumerate(projects, start=1) ] project_results_list = await _gather_with_concurrency( project_coros, max_concurrent=max_concurrent diff --git a/tests/unit/test_cli_progress_path.py b/tests/unit/test_cli_progress_path.py new file mode 100644 index 0000000..50893ac --- /dev/null +++ b/tests/unit/test_cli_progress_path.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from rich.progress import Progress + +from braintrust_migrate.cli import _run_migration_with_progress +from braintrust_migrate.config import Config, MigrationConfig + + +class _StubOrchestrator: + def __init__(self, config: Config) -> None: + self.config = config + self.migrate_all_calls: list[dict] = [] + + async def migrate_all(self, **kwargs): + self.migrate_all_calls.append(kwargs) + + projects = [ + {"source_id": "src-1", "dest_id": "dst-1", "name": "Project A"}, + {"source_id": "src-2", "dest_id": "dst-2", "name": "Project B"}, + ] + kwargs["on_projects_discovered"](projects) + kwargs["on_organization_start"]() + kwargs["on_organization_complete"]({}) + kwargs["on_project_start"](projects[0], 1, 2) + progress_factory = kwargs["progress_factory_factory"](projects[0]) + resource_callback = kwargs["resource_callback_factory"](projects[0]) + progress_hook = progress_factory("logs") + progress_hook( + { + "phase": "done", + "resource": "logs", + "fetched_total": 10, + "inserted_total": 8, + "inserted_bytes_total": 2_000_000_000, + } + ) + resource_callback( + "logs", + {"total": 1, "migrated": 1, "skipped": 0, "failed": 0, "errors": []}, + ) + kwargs["on_project_complete"]( + projects[0], + { + "project_id": "dst-1", + "project_name": "Project A", + "resources": {}, + "total_resources": 1, + "migrated_resources": 1, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + }, + 1, + 2, + ) + return { + "summary": { + "total_projects": 2, + "total_resources": 1, + "migrated_resources": 1, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + }, + "projects": { + "Project A": { + "project_id": "dst-1", + "project_name": "Project A", + "resources": {}, + "total_resources": 1, + "migrated_resources": 1, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + } + }, + "duration_seconds": 2.0, + "success": True, + } + + +@pytest.mark.asyncio +async def test_run_migration_with_progress_delegates_to_orchestrator( + tmp_path: Path, +) -> None: + config = Config( + source={"api_key": "src", "url": "https://api.braintrust.dev"}, + destination={"api_key": "dst", "url": "https://api.braintrust.dev"}, + migration=MigrationConfig(), + state_dir=tmp_path, + resources=["all"], + ) + orchestrator = _StubOrchestrator(config) + + with Progress() as progress: + migration_task = progress.add_task("migrate", total=1) + results = await _run_migration_with_progress( + orchestrator, + progress, + migration_task, + checkpoint_dir=tmp_path / "run", + ) + + assert len(orchestrator.migrate_all_calls) == 1 + call = orchestrator.migrate_all_calls[0] + assert call["checkpoint_dir"] == tmp_path / "run" + assert callable(call["on_projects_discovered"]) + assert callable(call["on_project_start"]) + assert callable(call["on_project_complete"]) + assert callable(call["progress_factory_factory"]) + assert callable(call["resource_callback_factory"]) + assert results["throughput"]["inserted_rows_total"] == 8 + assert results["throughput"]["inserted_gb_total"] == 2.0 + assert results["throughput"]["rows_per_sec"] == 4.0 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index f077845..b28bea5 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -9,7 +9,7 @@ # Test constants DEFAULT_BATCH_SIZE = 100 DEFAULT_RETRY_ATTEMPTS = 3 -DEFAULT_MAX_CONCURRENT = 10 +DEFAULT_MAX_CONCURRENT = 1 DEFAULT_CHECKPOINT_INTERVAL = 50 TEST_BATCH_SIZE = 50 TEST_RETRY_ATTEMPTS = 5 diff --git a/tests/unit/test_orchestrator_migrate_all.py b/tests/unit/test_orchestrator_migrate_all.py new file mode 100644 index 0000000..2c3e3a8 --- /dev/null +++ b/tests/unit/test_orchestrator_migrate_all.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path +from unittest.mock import AsyncMock, Mock + +import pytest + +from braintrust_migrate.config import Config, MigrationConfig +from braintrust_migrate.orchestration import MigrationOrchestrator + + +def _make_config(tmp_path: Path, *, max_concurrent: int = 2) -> Config: + return Config( + source={"api_key": "src", "url": "https://api.braintrust.dev"}, + destination={"api_key": "dst", "url": "https://api.braintrust.dev"}, + migration=MigrationConfig(max_concurrent=max_concurrent), + state_dir=tmp_path, + resources=["all"], + ) + + +@pytest.mark.asyncio +async def test_migrate_all_runs_projects_concurrently_and_emits_hooks( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + orchestrator = MigrationOrchestrator(_make_config(tmp_path, max_concurrent=2)) + checkpoint_dir = tmp_path / "run" + report_path = checkpoint_dir / "migration_report.json" + projects = [ + {"source_id": "src-1", "dest_id": "dst-1", "name": "Project 1"}, + {"source_id": "src-2", "dest_id": "dst-2", "name": "Project 2"}, + {"source_id": "src-3", "dest_id": "dst-3", "name": "Project 3"}, + ] + + source_client = Mock() + dest_client = Mock() + + @asynccontextmanager + async def mock_create_client_pair(_source_cfg, _dest_cfg, _migration_cfg): + yield source_client, dest_client + + monkeypatch.setattr( + "braintrust_migrate.orchestration.create_client_pair", mock_create_client_pair + ) + monkeypatch.setattr(orchestrator, "_discover_projects", AsyncMock(return_value=projects)) + monkeypatch.setattr( + orchestrator, + "_migrate_organization_resources", + AsyncMock( + return_value={ + "resources": {}, + "total_resources": 0, + "migrated_resources": 0, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + } + ), + ) + monkeypatch.setattr( + orchestrator, + "_migrate_post_project_global_resources", + AsyncMock( + return_value={ + "resources": {}, + "total_resources": 0, + "migrated_resources": 0, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + } + ), + ) + monkeypatch.setattr(orchestrator, "_generate_migration_report", lambda *_: report_path) + + lock = asyncio.Lock() + in_flight = 0 + max_seen = 0 + started: list[tuple[str, int, int]] = [] + completed: list[tuple[str, int, int]] = [] + discovered: list[str] = [] + + async def fake_migrate_project( + project: dict[str, str], + _source_client: Mock, + _dest_client: Mock, + _checkpoint_dir: Path, + _global_id_mappings: dict[str, str], + progress_factory=None, + resource_callback=None, + ) -> dict[str, object]: + nonlocal in_flight, max_seen + assert callable(progress_factory) + assert callable(resource_callback) + async with lock: + in_flight += 1 + max_seen = max(max_seen, in_flight) + await asyncio.sleep(0.02) + resource_callback( + "datasets", + {"total": 1, "migrated": 1, "skipped": 0, "failed": 0, "errors": []}, + ) + async with lock: + in_flight -= 1 + return { + "project_id": project["dest_id"], + "project_name": project["name"], + "resources": {}, + "total_resources": 1, + "migrated_resources": 1, + "skipped_resources": 0, + "failed_resources": 0, + "errors": [], + } + + monkeypatch.setattr(orchestrator, "_migrate_project", fake_migrate_project) + + results = await orchestrator.migrate_all( + checkpoint_dir=checkpoint_dir, + on_projects_discovered=lambda ps: discovered.extend([p["name"] for p in ps]), + on_project_start=lambda project, index, total: started.append( + (project["name"], index, total) + ), + on_project_complete=lambda project, _result, done, total: completed.append( + (project["name"], done, total) + ), + progress_factory_factory=lambda _project: lambda _resource_name: lambda _update: None, + resource_callback_factory=lambda _project: lambda _resource_name, _results: None, + ) + + assert discovered == [p["name"] for p in projects] + assert len(started) == 3 + assert len(completed) == 3 + assert max_seen == 2 + assert results["summary"]["total_projects"] == 3 + assert results["summary"]["migrated_resources"] == 3 + assert results["report_path"] == str(report_path) diff --git a/tests/unit/test_parallelization_config.py b/tests/unit/test_parallelization_config.py index 0e87164..640fbd2 100644 --- a/tests/unit/test_parallelization_config.py +++ b/tests/unit/test_parallelization_config.py @@ -8,6 +8,10 @@ class TestParallelizationConfigDefaults: """Verify defaults for new parallelization config fields.""" + def test_max_concurrent_default(self): + config = MigrationConfig() + assert config.max_concurrent == 1 + def test_max_concurrent_resources_default(self): config = MigrationConfig() assert config.max_concurrent_resources == 5 @@ -65,6 +69,14 @@ def test_streaming_pipeline_bool(self): class TestParallelizationConfigFromEnv: """Verify env var parsing for parallelization fields.""" + def test_max_concurrent_from_env(self, monkeypatch): + monkeypatch.setenv("BT_SOURCE_API_KEY", "src") + monkeypatch.setenv("BT_DEST_API_KEY", "dst") + monkeypatch.setenv("MIGRATION_MAX_CONCURRENT", "4") + + config = Config.from_env() + assert config.migration.max_concurrent == 4 + def test_max_concurrent_resources_from_env(self, monkeypatch): monkeypatch.setenv("BT_SOURCE_API_KEY", "src") monkeypatch.setenv("BT_DEST_API_KEY", "dst") @@ -101,11 +113,13 @@ def test_defaults_when_env_not_set(self, monkeypatch): monkeypatch.setenv("BT_SOURCE_API_KEY", "src") monkeypatch.setenv("BT_DEST_API_KEY", "dst") # Ensure the parallelization env vars are NOT set + monkeypatch.delenv("MIGRATION_MAX_CONCURRENT", raising=False) monkeypatch.delenv("MIGRATION_MAX_CONCURRENT_RESOURCES", raising=False) monkeypatch.delenv("MIGRATION_STREAMING_PIPELINE", raising=False) monkeypatch.delenv("MIGRATION_MAX_CONCURRENT_REQUESTS", raising=False) config = Config.from_env() + assert config.migration.max_concurrent == 1 assert config.migration.max_concurrent_resources == 5 assert config.migration.streaming_pipeline is True assert config.migration.max_concurrent_requests == 20