Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion braintrust_migrate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(
self._http_client: httpx.AsyncClient | None = None
self._logger = logger.bind(org=org_name, url=str(org_config.url))
self._org_id: str | None = None
self._request_semaphore = asyncio.Semaphore(
int(self.migration_config.max_concurrent_requests)
)

async def __aenter__(self) -> "BraintrustClient":
"""Async context manager entry."""
Expand All @@ -91,9 +94,13 @@ async def connect(self) -> None:
self._logger.info("Connecting to Braintrust API")

# Create HTTP client for requests
max_concurrent_requests = int(self.migration_config.max_concurrent_requests)
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_connections=20, max_keepalive_connections=5),
limits=httpx.Limits(
max_connections=max_concurrent_requests,
max_keepalive_connections=max(5, max_concurrent_requests),
),
)

# Perform health check
Expand Down Expand Up @@ -208,6 +215,28 @@ async def raw_request(
) -> Any:
"""Perform a raw HTTP request against the Braintrust API.

This method is concurrency-limited by a per-client request semaphore.
"""
async with self._request_semaphore:
return await self._raw_request_inner(
method,
path,
params=params,
json=json,
timeout=timeout,
)

async def _raw_request_inner(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json: Any | None = None,
timeout: float | None = None,
) -> Any:
"""Perform a raw HTTP request against the Braintrust API.

This is useful when we need tight control over request/response behavior
(e.g. cursor-pagination for large logs) or want to avoid additional SDK
dependencies.
Expand Down
32 changes: 32 additions & 0 deletions braintrust_migrate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ class MigrationConfig(BaseModel):
max_concurrent: int = Field(
default=10, ge=1, le=50, description="Maximum number of concurrent operations"
)
max_concurrent_resources: int = Field(
default=5,
ge=1,
le=50,
description="Maximum number of resources migrated concurrently within a batch",
)
streaming_pipeline: bool = Field(
default=True,
description="Enable pipelined page prefetch during streaming event migrations",
)
max_concurrent_requests: int = Field(
default=20,
ge=1,
le=200,
description="Maximum number of concurrent in-flight HTTP requests per API client",
)
checkpoint_interval: int = Field(
default=50, ge=1, description="Write checkpoint every N successful operations"
)
Expand Down Expand Up @@ -324,6 +340,19 @@ def from_env(cls) -> "Config":
retry_attempts = int(os.getenv("MIGRATION_RETRY_ATTEMPTS", "3"))
retry_delay = float(os.getenv("MIGRATION_RETRY_DELAY", "1.0"))
max_concurrent = int(os.getenv("MIGRATION_MAX_CONCURRENT", "10"))
max_concurrent_resources = int(
os.getenv("MIGRATION_MAX_CONCURRENT_RESOURCES", "5")
)
streaming_pipeline = os.getenv("MIGRATION_STREAMING_PIPELINE", "true").lower() in {
"1",
"true",
"yes",
"y",
"on",
}
max_concurrent_requests = int(
os.getenv("MIGRATION_MAX_CONCURRENT_REQUESTS", "20")
)
checkpoint_interval = int(os.getenv("MIGRATION_CHECKPOINT_INTERVAL", "50"))

insert_max_request_bytes = int(
Expand Down Expand Up @@ -466,6 +495,9 @@ def _get_bool(specific_key: str, unified_key: str, default: str) -> bool:
retry_attempts=retry_attempts,
retry_delay=retry_delay,
max_concurrent=max_concurrent,
max_concurrent_resources=max_concurrent_resources,
streaming_pipeline=streaming_pipeline,
max_concurrent_requests=max_concurrent_requests,
checkpoint_interval=checkpoint_interval,
insert_max_request_bytes=insert_max_request_bytes,
insert_request_headroom_ratio=insert_request_headroom_ratio,
Expand Down
21 changes: 18 additions & 3 deletions braintrust_migrate/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,12 @@ async def _migrate_project(
)

# Perform migration
resource_results = await migrator.migrate_all(source_project_id)
resource_results = await migrator.migrate_all(
source_project_id,
max_concurrent=int(
self.config.migration.max_concurrent_resources
),
)

# Collect ID mappings from this migrator and add to global registry
new_mappings = {
Expand Down Expand Up @@ -1208,7 +1213,12 @@ async def _migrate_organization_resources(
)

# Perform migration (project_id=None for org-scoped)
resource_results = await migrator.migrate_all(None)
resource_results = await migrator.migrate_all(
None,
max_concurrent=int(
self.config.migration.max_concurrent_resources
),
)

# Collect ID mappings from this migrator and add to global registry
new_mappings = {
Expand Down Expand Up @@ -1312,7 +1322,12 @@ async def _migrate_post_project_global_resources(
shared_dependency_cache,
)

resource_results = await migrator.migrate_all(None)
resource_results = await migrator.migrate_all(
None,
max_concurrent=int(
self.config.migration.max_concurrent_resources
),
)
global_id_mappings.update(migrator.state.id_mapping)

post_results["resources"][resource_name] = resource_results
Expand Down
105 changes: 64 additions & 41 deletions braintrust_migrate/resources/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Abstract base class for Braintrust resource migrators."""

import asyncio
import inspect
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand Down Expand Up @@ -863,18 +865,19 @@ async def resolve_dependencies(

return resolved

async def migrate_batch(self, resources: list[T]) -> list[MigrationResult]:
async def migrate_batch(
self, resources: list[T], max_concurrent: int | None = None
) -> list[MigrationResult]:
"""Migrate a batch of resources.

Args:
resources: List of resources to migrate.
max_concurrent: Optional concurrency cap for this batch.

Returns:
List of migration results.
"""
results = []

for resource in resources:
async def _migrate_single(resource: T) -> MigrationResult:
source_id = None # Initialize source_id before try block
try:
source_id = self.get_resource_id(resource)
Expand All @@ -890,21 +893,18 @@ async def migrate_batch(self, resources: list[T]) -> list[MigrationResult]:
name=resource_name,
)

results.append(
MigrationResult(
success=True,
source_id=source_id,
dest_id=dest_id,
skipped=True,
metadata={
"name": resource_name,
"skip_reason": "already_migrated",
}
if resource_name
else {"skip_reason": "already_migrated"},
)
return MigrationResult(
success=True,
source_id=source_id,
dest_id=dest_id,
skipped=True,
metadata={
"name": resource_name,
"skip_reason": "already_migrated",
}
if resource_name
else {"skip_reason": "already_migrated"},
)
continue

# Check dependencies
dependencies = await self.get_dependencies(resource)
Expand All @@ -916,14 +916,11 @@ async def migrate_batch(self, resources: list[T]) -> list[MigrationResult]:
except ValueError as e:
# This should rarely happen now with strict=False, but keep as fallback
self.record_failure(source_id, str(e))
results.append(
MigrationResult(
success=False,
source_id=source_id,
error=str(e),
)
return MigrationResult(
success=False,
source_id=source_id,
error=str(e),
)
continue

# Perform migration
dest_id = await self.migrate_resource(resource)
Expand All @@ -938,35 +935,53 @@ async def migrate_batch(self, resources: list[T]) -> list[MigrationResult]:
)

self.record_success(source_id, dest_id, resource)
results.append(
MigrationResult(
success=True,
source_id=source_id,
dest_id=dest_id,
metadata={"name": resource_name} if resource_name else {},
)
return MigrationResult(
success=True,
source_id=source_id,
dest_id=dest_id,
metadata={"name": resource_name} if resource_name else {},
)

except Exception as e:
error_msg = f"Failed to migrate resource: {e}"
# Use source_id if available, otherwise use a fallback
resource_id = source_id if source_id else f"unknown_{id(resource)}"
self.record_failure(resource_id, error_msg)
results.append(
MigrationResult(
success=False,
source_id=resource_id,
error=error_msg,
)
return MigrationResult(
success=False,
source_id=resource_id,
error=error_msg,
)

return results
if not resources:
return []

effective_max = (
int(max_concurrent) if max_concurrent is not None else len(resources)
)
if effective_max <= 1:
results: list[MigrationResult] = []
for resource in resources:
results.append(await _migrate_single(resource))
return results

semaphore = asyncio.Semaphore(effective_max)

async def _run_with_limit(resource: T) -> MigrationResult:
async with semaphore:
return await _migrate_single(resource)

tasks = [asyncio.create_task(_run_with_limit(r)) for r in resources]
return await asyncio.gather(*tasks)

async def migrate_all(self, project_id: str | None = None) -> dict[str, Any]:
async def migrate_all(
self, project_id: str | None = None, max_concurrent: int | None = None
) -> dict[str, Any]:
"""Migrate all resources of this type.

Args:
project_id: Optional project ID to filter resources.
max_concurrent: Optional concurrency cap to pass to each batch.

Returns:
Summary of migration results.
Expand Down Expand Up @@ -1007,7 +1022,15 @@ async def migrate_all(self, project_id: str | None = None) -> dict[str, Any]:
batch_size=len(batch),
)

batch_results = await self.migrate_batch(batch)
supports_max_concurrent = (
"max_concurrent" in inspect.signature(self.migrate_batch).parameters
)
if max_concurrent is None or not supports_max_concurrent:
batch_results = await self.migrate_batch(batch)
else:
batch_results = await self.migrate_batch(
batch, max_concurrent=max_concurrent
)

# Aggregate results with details
for result in batch_results:
Expand Down
5 changes: 4 additions & 1 deletion braintrust_migrate/resources/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ def _dump_oversize_event_summary(
cursor=cursor,
)

async def migrate_all(self, project_id: str | None = None) -> dict[str, Any]:
async def migrate_all(
self, project_id: str | None = None, max_concurrent: int | None = None
) -> dict[str, Any]:
_ = max_concurrent
if not project_id:
return {
"resource_type": self.resource_name,
Expand Down
2 changes: 2 additions & 0 deletions braintrust_migrate/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ async def stream_btql_sorted_events(
incr_skipped_deleted: Callable[[int], None] | None,
incr_skipped_seen: Callable[[int], None] | None,
incr_attachments_copied: Callable[[int], None] | None,
pipeline: bool = False,
# Optional progress hooks
hooks: StreamHooks | None = None,
) -> None:
Expand All @@ -214,6 +215,7 @@ async def stream_btql_sorted_events(
- Advancing pagination key only after successful inserts
"""

_ = pipeline
page_num = 0
while True:
page_num += 1
Expand Down
Loading