Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions bajor/batch/predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# prediction job specific functions
import logging, os, sys

from bajor.batch.checkpoint_strategies import get_checkpoint_target
from bajor.batch.runtime_config import (
resolve_container_image_name,
resolve_checkpoint_target,
resolve_prediction_script_path,
)

if os.getenv('DEBUG'):
import pdb
Expand All @@ -11,7 +15,7 @@
from bajor.batch.client import azure_batch_client
import bajor.batch.jobs as batch_jobs
from bajor.log_config import log
from bajor.models.job import Options
from bajor.models.job import JobOptions

# Zoobot Azure Batch predictions pool ID
predictions_pool_id = os.getenv('POOL_ID', 'predictions_0')
Expand All @@ -28,21 +32,21 @@ def get_non_active_batch_job_list():
return batch_jobs.get_non_active_batch_job_list(predictions_pool_id)

# schedule a prediction job
def schedule_job(job_id:str, manifest_url:str, options:Options=Options()):
checkpoint_target = get_checkpoint_target(options.workflow_name)

def schedule_job(job_id:str, manifest_url:str, options:JobOptions=JobOptions()):
submitted_job_id = create_batch_job(
job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, checkpoint_target=checkpoint_target)
job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, options=options)
job_task_submission_status = create_job_tasks(
job_id=job_id, run_opts=options.run_opts)
job_id=job_id, options=options)

# return the submitted job_id and task submission status dict
return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status)


def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'):
def create_batch_job(job_id, manifest_url, pool_id, options: JobOptions=JobOptions()):
log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url))

checkpoint_target = resolve_checkpoint_target(options)

log.debug(f'BatchJobManager, create_job, job_id: {job_id}')
job = batchmodels.JobAddParameter(
id=job_id,
Expand Down Expand Up @@ -73,7 +77,7 @@ def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CH
# set the zoobot saved model checkpoint file path
batchmodels.EnvironmentSetting(
name='ZOOBOT_CHECKPOINT_TARGET',
value=os.getenv(checkpoint_target, 'zoobot.ckpt')),
value=checkpoint_target),
# setup error reporting service
batchmodels.EnvironmentSetting(
name='HONEYBADGER_API_KEY',
Expand Down Expand Up @@ -149,7 +153,7 @@ def job_logs_path(job_id, task_id, suffix):
return f'{job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt'


def create_job_tasks(job_id, task_id=1, run_opts=''):
def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()):
# for persisting stdout and stderr log files in container storage
container_sas_url = batch_jobs.storage_container_sas_url(
os.getenv('PREDICTIONS_STORAGE_CONTAINER', 'predictions'))
Expand Down Expand Up @@ -186,10 +190,10 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
tasks = []
# ZOOBOT command for catalogue predictions!
# see jobPreparation task for code setup
prediction_code_path = os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py')
prediction_code_path = resolve_prediction_script_path(options)
setup_hugging_face_cache_env_var = f'HF_HOME={huggingface_dir}'
# TODO: perhaps we can add the output file extension as a job env param that can be modified by job runtime params
escaped_opts = run_opts.replace('"','\\"')
escaped_opts = options.run_opts.replace('"','\\"')
prediction_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{prediction_code_path} {escaped_opts} --checkpoint-path $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/$ZOOBOT_CHECKPOINT_TARGET --catalog-url $MANIFEST_URL --save-path $AZ_BATCH_NODE_MOUNTS_DIR/$PREDICTIONS_CONTAINER_MOUNT_DIR/$PREDICTIONS_JOB_RESULTS_DIR/predictions.json'
# redirect the stdout to stderr for logging
command = f'/bin/bash -c \"set -ex; {setup_hugging_face_cache_env_var}; python {prediction_cmd}\"'
Expand All @@ -200,7 +204,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
id=str(task_id),
command_line=command,
container_settings=batchmodels.TaskContainerSettings(
image_name=os.getenv('CONTAINER_IMAGE_NAME'),
image_name=resolve_container_image_name(options),
working_directory='taskWorkingDirectory',
container_run_options='--ipc=host'
),
Expand Down Expand Up @@ -239,4 +243,3 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
stream = sys.stdout
)

53 changes: 53 additions & 0 deletions bajor/batch/runtime_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
from urllib.parse import urlparse

from bajor.batch.checkpoint_strategies import get_checkpoint_target
from bajor.models.job import JobOptions


# Name of the environment variable that generated paths reference as
# `$MODELS_CONTAINER_MOUNT_DIR`; it is emitted literally, never read here.
MODELS_CONTAINER_MOUNT_DIR_ENV = 'MODELS_CONTAINER_MOUNT_DIR'
# Leading path segment that is stripped from checkpoint blob URLs.
MODELS_CONTAINER_NAME = 'models'


def _models_mount_root() -> str:
    """Return the unexpanded path expression for the models mount directory.

    The `$`-prefixed names are returned as literal text
    (`$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR`); they are not
    looked up in this process's environment.
    """
    return f'$AZ_BATCH_NODE_MOUNTS_DIR/${MODELS_CONTAINER_MOUNT_DIR_ENV}'


def _relative_checkpoint_path(checkpoint_ref: str) -> str:
    """Normalize a checkpoint reference to a path without a leading slash.

    Args:
        checkpoint_ref: Either an `http://` / `https://` URL or a plain
            path such as `jwst/custom.ckpt`.

    Returns:
        For plain paths, the input with leading slashes removed. For URLs,
        the percent-decoded URL path (query string and fragment dropped)
        with leading slashes removed and, when present, the leading
        `models/` segment stripped.
    """
    # Function-scope import keeps this block self-contained; the module only
    # imports `urlparse` at file level.
    from urllib.parse import unquote

    if not checkpoint_ref.startswith(('http://', 'https://')):
        return checkpoint_ref.lstrip('/')

    # URL paths are percent-encoded (e.g. `%2B` for `+`), while the result is
    # joined onto a directory path by callers, so decode before returning.
    blob_path = unquote(urlparse(checkpoint_ref).path).lstrip('/')
    container_prefix = f'{MODELS_CONTAINER_NAME}/'
    if blob_path.startswith(container_prefix):
        return blob_path[len(container_prefix):]
    return blob_path


def resolve_checkpoint_target(options: JobOptions) -> str:
    """Pick the checkpoint target value for a job.

    An explicit `options.pretrained_checkpoint_url` takes precedence and is
    normalized by `_relative_checkpoint_path`. Otherwise the value is read
    from the environment variable whose name `get_checkpoint_target` returns
    for `options.workflow_name`, defaulting to `zoobot.ckpt` when unset.
    """
    explicit_ref = options.pretrained_checkpoint_url
    if not explicit_ref:
        # Legacy behaviour: the workflow name selects which env var to read.
        env_var_name = get_checkpoint_target(options.workflow_name)
        return os.getenv(env_var_name, 'zoobot.ckpt')

    return _relative_checkpoint_path(explicit_ref)


def resolve_pretrained_checkpoint_path(options: JobOptions) -> str:
    """Build the path expression of the base checkpoint under the models mount.

    Uses the normalized `options.pretrained_checkpoint_url` when it is set;
    otherwise falls back to the `ZOOBOT_FINETUNE_CHECKPOINT_FILE` environment
    variable (default `zoobot_pretrained_model.ckpt`). The returned string
    keeps the mount root's `$`-prefixed names unexpanded.
    """
    # NOTE(review): the training module interpolates this value into a command
    # string, and nothing here shell-escapes it - confirm inputs are validated.
    mount_root = _models_mount_root()
    if options.pretrained_checkpoint_url:
        relative_name = _relative_checkpoint_path(options.pretrained_checkpoint_url)
    else:
        relative_name = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt')
    return f'{mount_root}/{relative_name}'


def resolve_training_script_path(options: JobOptions) -> str:
    """Return the training script path for a job.

    A truthy `options.training_script_path` wins; otherwise the
    `ZOOBOT_FINETUNE_TRAIN_CMD` environment variable is used, defaulting to
    `train_model_finetune_on_catalog.py`.
    """
    override = options.training_script_path
    if override:
        return override
    return os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py')


def resolve_prediction_script_path(options: JobOptions) -> str:
    """Return the prediction script path for a job.

    A truthy `options.prediction_script_path` wins; otherwise the
    `ZOOBOT_PREDICTION_CMD` environment variable is used, defaulting to
    `predict_catalog_with_model.py`.
    """
    requested = options.prediction_script_path
    return requested if requested else os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py')


def resolve_promote_script_path(options: JobOptions) -> str:
    """Return the checkpoint-promotion script path for a job.

    A truthy `options.promote_script_path` wins; otherwise the
    `ZOOBOT_PROMOTE_CMD` environment variable is used, defaulting to
    `promote_best_checkpoint_to_model.sh`.
    """
    script_path = options.promote_script_path
    if not script_path:
        script_path = os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh')
    return script_path


def resolve_container_image_name(options: JobOptions) -> str:
    """Return the container image name for a job.

    A truthy `options.container_image_name` wins; otherwise the
    `CONTAINER_IMAGE_NAME` environment variable is used. No default is
    supplied, so the result is `None` when neither is set.
    """
    image_name = options.container_image_name
    if not image_name:
        image_name = os.getenv('CONTAINER_IMAGE_NAME')
    return image_name
40 changes: 22 additions & 18 deletions bajor/batch/train_finetuning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# training job specific functions
import logging, os, sys

from bajor.batch.checkpoint_strategies import get_checkpoint_target
from bajor.batch.runtime_config import (
resolve_container_image_name,
resolve_checkpoint_target,
resolve_pretrained_checkpoint_path,
resolve_promote_script_path,
resolve_training_script_path,
)

if os.getenv('DEBUG'):
import pdb
Expand All @@ -11,7 +17,7 @@
from bajor.batch.client import azure_batch_client
import bajor.batch.jobs as batch_jobs
from bajor.log_config import log
from bajor.models.job import Options
from bajor.models.job import JobOptions

# Zoobot Azure Batch training pool ID
training_pool_id = os.getenv('POOL_ID', 'training_1')
Expand All @@ -29,21 +35,21 @@ def get_non_active_batch_job_list():
return batch_jobs.get_non_active_batch_job_list(training_pool_id)

# schedule a training job
def schedule_job(job_id: str, manifest_path:str, options: Options=Options()):
checkpoint_target = get_checkpoint_target(options.workflow_name)

def schedule_job(job_id: str, manifest_path:str, options: JobOptions=JobOptions()):
submitted_job_id = create_batch_job(
job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, checkpoint_target=checkpoint_target)
job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, options=options)
job_task_submission_status = create_job_tasks(
job_id=job_id, run_opts=options.run_opts)
job_id=job_id, options=options)

# return the submitted job_id and task submission status dict
return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status)

def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'):
def create_batch_job(job_id, manifest_container_path, pool_id, options: JobOptions=JobOptions()):
log.debug('server_job, create_batch_job, using manifest from path: {}'.format(
manifest_container_path))

checkpoint_target = resolve_checkpoint_target(options)

log.debug(f'BatchJobManager, create_job, job_id: {job_id}')
job = batchmodels.JobAddParameter(
id=job_id,
Expand Down Expand Up @@ -84,7 +90,7 @@ def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target
# set the zoobot saved model checkpoint file path
batchmodels.EnvironmentSetting(
name='ZOOBOT_CHECKPOINT_TARGET',
value=os.getenv(checkpoint_target, 'zoobot.ckpt')),
value=checkpoint_target),
# setup error reporting service
batchmodels.EnvironmentSetting(
name='HONEYBADGER_API_KEY',
Expand Down Expand Up @@ -176,7 +182,7 @@ def training_job_logs_path(job_id, task_id, suffix):
return f'{training_job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt'


def create_job_tasks(job_id, task_id=1, run_opts=''):
def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()):
# for persisting stdout and stderr log files in container storage
container_sas_url = batch_jobs.storage_container_sas_url(
os.getenv('TRAINING_STORAGE_CONTAINER', 'training'))
Expand Down Expand Up @@ -220,14 +226,13 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
# train_cmd file path is copied from blob storage into this runtime container
# so this location is relative to the container paths and can be modified at runtime
# see jobPreparation task for code setup
train_code_path = os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py')
# checkpoint file is the base model for finetuning (transfer learning)
checkpoint_file = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt')
train_code_path = resolve_training_script_path(options)
checkpoint_path = resolve_pretrained_checkpoint_path(options)
# setup the training cmd
escaped_opts = run_opts.replace('"','\\"')
train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/{checkpoint_file} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/'
escaped_opts = options.run_opts.replace('"','\\"')
train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint {checkpoint_path} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/'
# and a way to promote the resulting model artifact for use in prediction systems
promote_model_code_path = os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh')
promote_model_code_path = resolve_promote_script_path(options)
# redirect the stdout to stderr for logging
promote_checkpoint_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{promote_model_code_path} $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR 2>&1'
# ensure pytorch has the correct kernel cache path (this enables CUDA JIT - https://pytorch.org/docs/stable/notes/cuda.html#just-in-time-compilation)
Expand All @@ -247,7 +252,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
id=str(task_id),
command_line=command,
container_settings=batchmodels.TaskContainerSettings(
image_name=os.getenv('CONTAINER_IMAGE_NAME'),
image_name=resolve_container_image_name(options),
working_directory='taskWorkingDirectory',
container_run_options='--ipc=host'
),
Expand Down Expand Up @@ -286,4 +291,3 @@ def create_job_tasks(job_id, task_id=1, run_opts=''):
format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
stream = sys.stdout
)

13 changes: 9 additions & 4 deletions bajor/models/job.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from pydantic import BaseModel, HttpUrl
from typing import Optional, Dict
from typing import Optional

class Options(BaseModel):
class JobOptions(BaseModel):
run_opts: str = ""
workflow_name: str = 'cosmic_dawn'
container_image_name: Optional[str] = None
training_script_path: Optional[str] = None
prediction_script_path: Optional[str] = None
promote_script_path: Optional[str] = None
pretrained_checkpoint_url: Optional[str] = None


class TrainingJob(BaseModel):
manifest_path: str
id: Optional[str] = None
status: Optional[str] = None
opts: Options = Options()
opts: JobOptions = JobOptions()

# remove the leading / from the manifest url
# as it's added via the blob storage paths in schedule_job
Expand All @@ -22,4 +27,4 @@ class PredictionJob(BaseModel):
manifest_url: HttpUrl
id: Optional[str] = None
status: Optional[str] = None
opts: Options = Options()
opts: JobOptions = JobOptions()
31 changes: 31 additions & 0 deletions tests/batch/test_runtime_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from bajor.batch.runtime_config import (
resolve_container_image_name,
resolve_checkpoint_target,
resolve_pretrained_checkpoint_path,
)
from bajor.models.job import JobOptions


def test_explicit_checkpoint_path_overrides_legacy_checkpoint_target():
    """An explicit relative checkpoint ref is used even when a workflow name is given."""
    job_options = JobOptions(workflow_name='euclid', pretrained_checkpoint_url='jwst/custom.ckpt')

    resolved_target = resolve_checkpoint_target(job_options)
    assert resolved_target == 'jwst/custom.ckpt'

    resolved_path = resolve_pretrained_checkpoint_path(job_options)
    assert resolved_path == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/jwst/custom.ckpt'


def test_blob_url_checkpoint_ref_is_normalized_to_relative_models_path():
    """A blob URL under the models container resolves to the bare blob name."""
    blob_url = 'https://kadeactivelearning.blob.core.windows.net/models/staging-euclid-zoobot.ckpt'
    job_options = JobOptions(pretrained_checkpoint_url=blob_url)

    resolved_target = resolve_checkpoint_target(job_options)
    assert resolved_target == 'staging-euclid-zoobot.ckpt'

    resolved_path = resolve_pretrained_checkpoint_path(job_options)
    assert resolved_path == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/staging-euclid-zoobot.ckpt'


def test_explicit_container_image_name_overrides_env():
    """An explicit container image name on the options is returned as-is."""
    job_options = JobOptions(container_image_name='zoobot.azurecr.io/pytorch:custom-jwst')

    resolved_image = resolve_container_image_name(job_options)

    assert resolved_image == 'zoobot.azurecr.io/pytorch:custom-jwst'
25 changes: 23 additions & 2 deletions tests/batch/test_training.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import bajor.batch.train_from_scratch as train_from_scratch
import bajor.batch.train_finetuning as train_finetuning
import bajor.batch.predictions as predictions
from bajor.batch.jobs import active_jobs_running
import uuid, os
from unittest import mock
from bajor.models.job import JobOptions

fake_job_id = str(uuid.uuid4())
test_pool = 'pool'
Expand Down Expand Up @@ -40,9 +42,9 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job):
def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job):
train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv')
mock_create_batch_job.assert_called_once_with(
job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', checkpoint_target= 'ZOOBOT_CHECKPOINT_TARGET')
job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', options=JobOptions())
mock_create_job_tasks.assert_called_once_with(
job_id=fake_job_id, run_opts='')
job_id=fake_job_id, options=JobOptions())


@mock.patch('bajor.batch.train_finetuning.create_batch_job')
Expand All @@ -58,3 +60,22 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job):
result_dict = train_finetuning.schedule_job(
submitted_job_id, 'fake-manifest-path.csv')
assert(result_dict) == expected_result_dict


@mock.patch('bajor.batch.predictions.create_batch_job')
@mock.patch('bajor.batch.predictions.create_job_tasks')
def test_prediction_schedule_job_uses_options(mock_create_job_tasks, mock_create_batch_job):
options = JobOptions(
prediction_script_path='predict_catalog_with_model.py',
pretrained_checkpoint_url='custom.ckpt'
)

predictions.schedule_job(fake_job_id, 'https://manifest-host/predictions.json', options)

mock_create_batch_job.assert_called_once_with(
job_id=fake_job_id,
manifest_url='https://manifest-host/predictions.json',
pool_id='predictions_0',
options=options
)
mock_create_job_tasks.assert_called_once_with(job_id=fake_job_id, options=options)
14 changes: 13 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,16 @@ def test_batch_scheduling_code_is_called(mocked_client):

assert response.status_code == 201
assert response.json() == {
'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'opts': {'run_opts': run_opts, 'workflow_name': 'cosmic_dawn'}, 'status': {"status": "started", "message": "Job submitted successfully"}}
'manifest_path': 'test_manifest_file_path.csv',
'id': submitted_job_id,
'opts': {
'run_opts': run_opts,
'workflow_name': 'cosmic_dawn',
'container_image_name': None,
'training_script_path': None,
'prediction_script_path': None,
'promote_script_path': None,
'pretrained_checkpoint_url': None
},
'status': {"status": "started", "message": "Job submitted successfully"}
}
Loading