diff --git a/bajor/batch/predictions.py b/bajor/batch/predictions.py index e7100b0..1007a34 100644 --- a/bajor/batch/predictions.py +++ b/bajor/batch/predictions.py @@ -1,7 +1,11 @@ # training job specific functions import logging, os, sys -from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_prediction_script_path, +) if os.getenv('DEBUG'): import pdb @@ -11,7 +15,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log -from bajor.models.job import Options +from bajor.models.job import JobOptions # Zoobot Azure Batch predictions pool ID predictions_pool_id = os.getenv('POOL_ID', 'predictions_0') @@ -28,21 +32,21 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(predictions_pool_id) # schedule a training job -def schedule_job(job_id:str, manifest_url:str, options:Options=Options()): - checkpoint_target = get_checkpoint_target(options.workflow_name) - +def schedule_job(job_id:str, manifest_url:str, options:JobOptions=JobOptions()): submitted_job_id = create_batch_job( - job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, checkpoint_target=checkpoint_target) + job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, options=options) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=options.run_opts) + job_id=job_id, options=options) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): +def create_batch_job(job_id, manifest_url, pool_id, options: JobOptions=JobOptions()): log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url)) + checkpoint_target = resolve_checkpoint_target(options) + log.debug(f'BatchJobManager, create_job, job_id: {job_id}') job = batchmodels.JobAddParameter( id=job_id, @@ -73,7 +77,7 @@ def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CH # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv(checkpoint_target, 'zoobot.ckpt')), + value=checkpoint_target), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', @@ -149,7 +153,7 @@ def job_logs_path(job_id, task_id, suffix): return f'{job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, run_opts=''): +def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('PREDICTIONS_STORAGE_CONTAINER', 'predictions')) @@ -186,10 +190,10 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): tasks = [] # ZOOBOT command for catalogue predictions! # see jobPreparation task for code setup - prediction_code_path = os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py') + prediction_code_path = resolve_prediction_script_path(options) setup_hugging_face_cache_env_var = f'HF_HOME={huggingface_dir}' # TODO: perhaps we can add the output file extension as a job env param that can be modified by job runtime params - escaped_opts = run_opts.replace('"','\\"') + escaped_opts = options.run_opts.replace('"','\\"') prediction_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{prediction_code_path} {escaped_opts} --checkpoint-path $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/$ZOOBOT_CHECKPOINT_TARGET --catalog-url $MANIFEST_URL --save-path $AZ_BATCH_NODE_MOUNTS_DIR/$PREDICTIONS_CONTAINER_MOUNT_DIR/$PREDICTIONS_JOB_RESULTS_DIR/predictions.json' # redirect the stdout to stderr for logging command = f'/bin/bash -c \"set -ex; {setup_hugging_face_cache_env_var}; python {prediction_cmd}\"' @@ -200,7 +204,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): id=str(task_id), command_line=command, container_settings=batchmodels.TaskContainerSettings( - image_name=os.getenv('CONTAINER_IMAGE_NAME'), + image_name=resolve_container_image_name(options), working_directory='taskWorkingDirectory', container_run_options='--ipc=host' ), @@ -239,4 +243,3 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', stream = sys.stdout ) - diff --git a/bajor/batch/runtime_config.py b/bajor/batch/runtime_config.py new file mode 100644 index 0000000..ce49631 --- /dev/null +++ b/bajor/batch/runtime_config.py @@ -0,0 +1,53 @@ +import os +from urllib.parse import urlparse + +from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.models.job import JobOptions + + +MODELS_CONTAINER_MOUNT_DIR_ENV = 'MODELS_CONTAINER_MOUNT_DIR' +MODELS_CONTAINER_NAME = 'models' +def _models_mount_root() -> str: + return f'$AZ_BATCH_NODE_MOUNTS_DIR/${MODELS_CONTAINER_MOUNT_DIR_ENV}' + + +def _relative_checkpoint_path(checkpoint_ref: str) -> str: + if checkpoint_ref.startswith('http://') or checkpoint_ref.startswith('https://'): + checkpoint_path = urlparse(checkpoint_ref).path.lstrip('/') + if checkpoint_path.startswith(f'{MODELS_CONTAINER_NAME}/'): + return checkpoint_path[len(f'{MODELS_CONTAINER_NAME}/'):] + return checkpoint_path + + return checkpoint_ref.lstrip('/') + + +def resolve_checkpoint_target(options: JobOptions) -> str: + if options.pretrained_checkpoint_url: + return _relative_checkpoint_path(options.pretrained_checkpoint_url) + + checkpoint_target = get_checkpoint_target(options.workflow_name) + return os.getenv(checkpoint_target, 'zoobot.ckpt') + + +def resolve_pretrained_checkpoint_path(options: JobOptions) -> str: + if options.pretrained_checkpoint_url: + return f'{_models_mount_root()}/{_relative_checkpoint_path(options.pretrained_checkpoint_url)}' + + checkpoint_file = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt') + return f'{_models_mount_root()}/{checkpoint_file}' + + +def resolve_training_script_path(options: JobOptions) -> str: + return options.training_script_path or os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py') + + +def resolve_prediction_script_path(options: JobOptions) -> str: + return options.prediction_script_path or os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py') + + +def resolve_promote_script_path(options: JobOptions) -> str: + return options.promote_script_path or os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh') + + +def resolve_container_image_name(options: JobOptions) -> str: + return options.container_image_name or os.getenv('CONTAINER_IMAGE_NAME') diff --git a/bajor/batch/train_finetuning.py b/bajor/batch/train_finetuning.py index 4a696d4..0128185 100644 --- a/bajor/batch/train_finetuning.py +++ b/bajor/batch/train_finetuning.py @@ -1,7 +1,13 @@ # training job specific functions import logging, os, sys -from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_pretrained_checkpoint_path, + resolve_promote_script_path, + resolve_training_script_path, +) if os.getenv('DEBUG'): import pdb @@ -11,7 +17,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log -from bajor.models.job import Options +from bajor.models.job import JobOptions # Zoobot Azure Batch training pool ID training_pool_id = os.getenv('POOL_ID', 'training_1') @@ -29,21 +35,21 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(training_pool_id) # schedule a training job -def schedule_job(job_id: str, manifest_path:str, options: Options=Options()): - checkpoint_target = get_checkpoint_target(options.workflow_name) - +def schedule_job(job_id: str, manifest_path:str, options: JobOptions=JobOptions()): submitted_job_id = create_batch_job( - job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, checkpoint_target=checkpoint_target) + job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, options=options) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=options.run_opts) + job_id=job_id, options=options) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): +def create_batch_job(job_id, manifest_container_path, pool_id, options: JobOptions=JobOptions()): log.debug('server_job, create_batch_job, using manifest from path: {}'.format( manifest_container_path)) + checkpoint_target = resolve_checkpoint_target(options) + log.debug(f'BatchJobManager, create_job, job_id: {job_id}') job = batchmodels.JobAddParameter( id=job_id, @@ -84,7 +90,7 @@ def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv(checkpoint_target, 'zoobot.ckpt')), + value=checkpoint_target), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', @@ -176,7 +182,7 @@ def training_job_logs_path(job_id, task_id, suffix): return f'{training_job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, run_opts=''): +def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('TRAINING_STORAGE_CONTAINER', 'training')) @@ -220,14 +226,13 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): # train_cmd file path is copied from blob storage into this runtime container # so this location is relative to the container paths and can be modified at runtime # see jobPreparation task for code setup - train_code_path = os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py') - # checkpoint file is the base model for finetuning (transfer learning) - checkpoint_file = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt') + train_code_path = resolve_training_script_path(options) + checkpoint_path = resolve_pretrained_checkpoint_path(options) # setup the training cmd - escaped_opts = run_opts.replace('"','\\"') - train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/{checkpoint_file} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/' + escaped_opts = options.run_opts.replace('"','\\"') + train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint {checkpoint_path} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/' # and a way to promote the resulting model artifact for use in prediction systems - promote_model_code_path = os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh') + promote_model_code_path = resolve_promote_script_path(options) # redirect the stdout to stderr for logging promote_checkpoint_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{promote_model_code_path} $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR 2>&1' # ensure pytorch has the correct kernel cach path (this enables CUDA JIT - https://pytorch.org/docs/stable/notes/cuda.html#just-in-time-compilation) @@ -247,7 +252,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): id=str(task_id), command_line=command, container_settings=batchmodels.TaskContainerSettings( - image_name=os.getenv('CONTAINER_IMAGE_NAME'), + image_name=resolve_container_image_name(options), working_directory='taskWorkingDirectory', container_run_options='--ipc=host' ), @@ -286,4 +291,3 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', stream = sys.stdout ) - diff --git a/bajor/models/job.py b/bajor/models/job.py index 96cb85d..10743b6 100644 --- a/bajor/models/job.py +++ b/bajor/models/job.py @@ -1,16 +1,21 @@ from pydantic import BaseModel, HttpUrl -from typing import Optional, Dict +from typing import Optional -class Options(BaseModel): +class JobOptions(BaseModel): run_opts: str = "" workflow_name: str = 'cosmic_dawn' + container_image_name: Optional[str] = None + training_script_path: Optional[str] = None + prediction_script_path: Optional[str] = None + promote_script_path: Optional[str] = None + pretrained_checkpoint_url: Optional[str] = None class TrainingJob(BaseModel): manifest_path: str id: Optional[str] = None status: Optional[str] = None - opts: Options = Options() + opts: JobOptions = JobOptions() # remove the leading / from the manifest url # as it's added via the blob storage paths in schedule_job @@ -22,4 +27,4 @@ class PredictionJob(BaseModel): manifest_url: HttpUrl id: Optional[str] = None status: Optional[str] = None - opts: Options = Options() + opts: JobOptions = JobOptions() diff --git a/tests/batch/test_runtime_config.py b/tests/batch/test_runtime_config.py new file mode 100644 index 0000000..2c9f5ec --- /dev/null +++ b/tests/batch/test_runtime_config.py @@ -0,0 +1,31 @@ +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_pretrained_checkpoint_path, +) +from bajor.models.job import JobOptions + + +def test_explicit_checkpoint_path_overrides_legacy_checkpoint_target(): + options = JobOptions( + workflow_name='euclid', + pretrained_checkpoint_url='jwst/custom.ckpt' + ) + + assert resolve_checkpoint_target(options) == 'jwst/custom.ckpt' + assert resolve_pretrained_checkpoint_path(options) == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/jwst/custom.ckpt' + + +def test_blob_url_checkpoint_ref_is_normalized_to_relative_models_path(): + options = JobOptions( + pretrained_checkpoint_url='https://kadeactivelearning.blob.core.windows.net/models/staging-euclid-zoobot.ckpt' + ) + + assert resolve_checkpoint_target(options) == 'staging-euclid-zoobot.ckpt' + assert resolve_pretrained_checkpoint_path(options) == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/staging-euclid-zoobot.ckpt' + + +def test_explicit_container_image_name_overrides_env(): + options = JobOptions(container_image_name='zoobot.azurecr.io/pytorch:custom-jwst') + + assert resolve_container_image_name(options) == 'zoobot.azurecr.io/pytorch:custom-jwst' diff --git a/tests/batch/test_training.py b/tests/batch/test_training.py index 64f1f15..7bcaa5d 100644 --- a/tests/batch/test_training.py +++ b/tests/batch/test_training.py @@ -1,8 +1,10 @@ import bajor.batch.train_from_scratch as train_from_scratch import bajor.batch.train_finetuning as train_finetuning +import bajor.batch.predictions as predictions from bajor.batch.jobs import active_jobs_running import uuid, os from unittest import mock +from bajor.models.job import JobOptions fake_job_id = str(uuid.uuid4()) test_pool = 'pool' @@ -40,9 +42,9 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job): train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv') mock_create_batch_job.assert_called_once_with( - job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', checkpoint_target= 'ZOOBOT_CHECKPOINT_TARGET') + job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', options=JobOptions()) mock_create_job_tasks.assert_called_once_with( - job_id=fake_job_id, run_opts='') + job_id=fake_job_id, options=JobOptions()) @mock.patch('bajor.batch.train_finetuning.create_batch_job') @@ -58,3 +60,22 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): result_dict = train_finetuning.schedule_job( submitted_job_id, 'fake-manifest-path.csv') assert(result_dict) == expected_result_dict + + +@mock.patch('bajor.batch.predictions.create_batch_job') +@mock.patch('bajor.batch.predictions.create_job_tasks') +def test_prediction_schedule_job_uses_options(mock_create_job_tasks, mock_create_batch_job): + options = JobOptions( + prediction_script_path='predict_catalog_with_model.py', + pretrained_checkpoint_url='custom.ckpt' + ) + + predictions.schedule_job(fake_job_id, 'https://manifest-host/predictions.json', options) + + mock_create_batch_job.assert_called_once_with( + job_id=fake_job_id, + manifest_url='https://manifest-host/predictions.json', + pool_id='predictions_0', + options=options + ) + mock_create_job_tasks.assert_called_once_with(job_id=fake_job_id, options=options) diff --git a/tests/test_api.py b/tests/test_api.py index c4d8073..031f165 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -77,4 +77,16 @@ def test_batch_scheduling_code_is_called(mocked_client): assert response.status_code == 201 assert response.json() == { - 'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'opts': {'run_opts': run_opts, 'workflow_name': 'cosmic_dawn'}, 'status': {"status": "started", "message": "Job submitted successfully"}} + 'manifest_path': 'test_manifest_file_path.csv', + 'id': submitted_job_id, + 'opts': { + 'run_opts': run_opts, + 'workflow_name': 'cosmic_dawn', + 'container_image_name': None, + 'training_script_path': None, + 'prediction_script_path': None, + 'promote_script_path': None, + 'pretrained_checkpoint_url': None + }, + 'status': {"status": "started", "message": "Job submitted successfully"} + }