From d64b26bcc1f4f37c76f3cd490ee0321820c1b31d Mon Sep 17 00:00:00 2001 From: tooyosi Date: Tue, 7 Apr 2026 18:16:48 +0100 Subject: [PATCH 1/2] Add runtime config for model-specific jobs --- bajor/batch/predictions.py | 26 ++++++++------- bajor/batch/runtime_config.py | 53 ++++++++++++++++++++++++++++++ bajor/batch/train_finetuning.py | 35 +++++++++++--------- bajor/models/job.py | 7 +++- tests/batch/test_runtime_config.py | 31 +++++++++++++++++ tests/batch/test_training.py | 25 ++++++++++++-- tests/test_api.py | 14 +++++++- 7 files changed, 161 insertions(+), 30 deletions(-) create mode 100644 bajor/batch/runtime_config.py create mode 100644 tests/batch/test_runtime_config.py diff --git a/bajor/batch/predictions.py b/bajor/batch/predictions.py index e7100b0..45ff456 100644 --- a/bajor/batch/predictions.py +++ b/bajor/batch/predictions.py @@ -1,7 +1,11 @@ # training job specific functions import logging, os, sys -from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_prediction_script_path, +) if os.getenv('DEBUG'): import pdb @@ -29,20 +33,20 @@ def get_non_active_batch_job_list(): # schedule a training job def schedule_job(job_id:str, manifest_url:str, options:Options=Options()): - checkpoint_target = get_checkpoint_target(options.workflow_name) - submitted_job_id = create_batch_job( - job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, checkpoint_target=checkpoint_target) + job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, options=options) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=options.run_opts) + job_id=job_id, options=options) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): +def create_batch_job(job_id, manifest_url, pool_id, options: Options=Options()): log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url)) + checkpoint_target = resolve_checkpoint_target(options) + log.debug(f'BatchJobManager, create_job, job_id: {job_id}') job = batchmodels.JobAddParameter( id=job_id, @@ -73,7 +77,7 @@ def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CH # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv(checkpoint_target, 'zoobot.ckpt')), + value=checkpoint_target), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', @@ -149,7 +153,7 @@ def job_logs_path(job_id, task_id, suffix): return f'{job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, run_opts=''): +def create_job_tasks(job_id, task_id=1, options: Options=Options()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('PREDICTIONS_STORAGE_CONTAINER', 'predictions')) @@ -186,10 +190,10 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): tasks = [] # ZOOBOT command for catalogue predictions! # see jobPreparation task for code setup - prediction_code_path = os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py') + prediction_code_path = resolve_prediction_script_path(options) setup_hugging_face_cache_env_var = f'HF_HOME={huggingface_dir}' # TODO: perhaps we can add the output file extension as a job env param that can be modified by job runtime params - escaped_opts = run_opts.replace('"','\\"') + escaped_opts = options.run_opts.replace('"','\\"') prediction_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{prediction_code_path} {escaped_opts} --checkpoint-path $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/$ZOOBOT_CHECKPOINT_TARGET --catalog-url $MANIFEST_URL --save-path $AZ_BATCH_NODE_MOUNTS_DIR/$PREDICTIONS_CONTAINER_MOUNT_DIR/$PREDICTIONS_JOB_RESULTS_DIR/predictions.json' # redirect the stdout to stderr for logging command = f'/bin/bash -c \"set -ex; {setup_hugging_face_cache_env_var}; python {prediction_cmd}\"' @@ -200,7 +204,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): id=str(task_id), command_line=command, container_settings=batchmodels.TaskContainerSettings( - image_name=os.getenv('CONTAINER_IMAGE_NAME'), + image_name=resolve_container_image_name(options), working_directory='taskWorkingDirectory', container_run_options='--ipc=host' ), diff --git a/bajor/batch/runtime_config.py b/bajor/batch/runtime_config.py new file mode 100644 index 0000000..5a9209c --- /dev/null +++ b/bajor/batch/runtime_config.py @@ -0,0 +1,53 @@ +import os +from urllib.parse import urlparse + +from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.models.job import Options + + +MODELS_CONTAINER_MOUNT_DIR_ENV = 'MODELS_CONTAINER_MOUNT_DIR' +MODELS_CONTAINER_NAME = 'models' +def _models_mount_root() -> str: + return f'$AZ_BATCH_NODE_MOUNTS_DIR/${MODELS_CONTAINER_MOUNT_DIR_ENV}' + + +def _relative_checkpoint_path(checkpoint_ref: str) -> str: + if checkpoint_ref.startswith('http://') or checkpoint_ref.startswith('https://'): + checkpoint_path = urlparse(checkpoint_ref).path.lstrip('/') + if checkpoint_path.startswith(f'{MODELS_CONTAINER_NAME}/'): + return checkpoint_path[len(f'{MODELS_CONTAINER_NAME}/'):] + return checkpoint_path + + return checkpoint_ref.lstrip('/') + + +def resolve_checkpoint_target(options: Options) -> str: + if options.pretrained_checkpoint_url: + return _relative_checkpoint_path(options.pretrained_checkpoint_url) + + checkpoint_target = get_checkpoint_target(options.workflow_name) + return os.getenv(checkpoint_target, 'zoobot.ckpt') + + +def resolve_pretrained_checkpoint_path(options: Options) -> str: + if options.pretrained_checkpoint_url: + return f'{_models_mount_root()}/{_relative_checkpoint_path(options.pretrained_checkpoint_url)}' + + checkpoint_file = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt') + return f'{_models_mount_root()}/{checkpoint_file}' + + +def resolve_training_script_path(options: Options) -> str: + return options.training_script_path or os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py') + + +def resolve_prediction_script_path(options: Options) -> str: + return options.prediction_script_path or os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py') + + +def resolve_promote_script_path(options: Options) -> str: + return options.promote_script_path or os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh') + + +def resolve_container_image_name(options: Options) -> str: + return options.container_image_name or os.getenv('CONTAINER_IMAGE_NAME') diff --git a/bajor/batch/train_finetuning.py b/bajor/batch/train_finetuning.py index 4a696d4..c3269a2 100644 --- a/bajor/batch/train_finetuning.py +++ b/bajor/batch/train_finetuning.py @@ -1,7 +1,13 @@ # training job specific functions import logging, os, sys -from bajor.batch.checkpoint_strategies import get_checkpoint_target +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_pretrained_checkpoint_path, + resolve_promote_script_path, + resolve_training_script_path, +) if os.getenv('DEBUG'): import pdb @@ -30,20 +36,20 @@ def get_non_active_batch_job_list(): # schedule a training job def schedule_job(job_id: str, manifest_path:str, options: Options=Options()): - checkpoint_target = get_checkpoint_target(options.workflow_name) - submitted_job_id = create_batch_job( - job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, checkpoint_target=checkpoint_target) + job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, options=options) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=options.run_opts) + job_id=job_id, options=options) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): +def create_batch_job(job_id, manifest_container_path, pool_id, options: Options=Options()): log.debug('server_job, create_batch_job, using manifest from path: {}'.format( manifest_container_path)) + checkpoint_target = resolve_checkpoint_target(options) + log.debug(f'BatchJobManager, create_job, job_id: {job_id}') job = batchmodels.JobAddParameter( id=job_id, @@ -84,7 +90,7 @@ def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv(checkpoint_target, 'zoobot.ckpt')), + value=checkpoint_target), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', @@ -176,7 +182,7 @@ def training_job_logs_path(job_id, task_id, suffix): return f'{training_job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, run_opts=''): +def create_job_tasks(job_id, task_id=1, options: Options=Options()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('TRAINING_STORAGE_CONTAINER', 'training')) @@ -220,14 +226,13 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): # train_cmd file path is copied from blob storage into this runtime container # so this location is relative to the container paths and can be modified at runtime # see jobPreparation task for code setup - train_code_path = os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py') - # checkpoint file is the base model for finetuning (transfer learning) - checkpoint_file = os.getenv('ZOOBOT_FINETUNE_CHECKPOINT_FILE', 'zoobot_pretrained_model.ckpt') + train_code_path = resolve_training_script_path(options) + checkpoint_path = resolve_pretrained_checkpoint_path(options) # setup the training cmd - escaped_opts = run_opts.replace('"','\\"') - train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint $AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/{checkpoint_file} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/' + escaped_opts = options.run_opts.replace('"','\\"') + train_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{train_code_path} {escaped_opts} --checkpoint {checkpoint_path} --catalog $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$MANIFEST_PATH --save-dir $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR/' # and a way to promote the resulting model artifact for use in prediction systems - promote_model_code_path = os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh') + promote_model_code_path = resolve_promote_script_path(options) # redirect the stdout to stderr for logging promote_checkpoint_cmd = f'$AZ_BATCH_NODE_SHARED_DIR/{promote_model_code_path} $AZ_BATCH_NODE_MOUNTS_DIR/$TRAINING_CONTAINER_MOUNT_DIR/$TRAINING_JOB_RESULTS_DIR 2>&1' # ensure pytorch has the correct kernel cach path (this enables CUDA JIT - https://pytorch.org/docs/stable/notes/cuda.html#just-in-time-compilation) @@ -247,7 +252,7 @@ def create_job_tasks(job_id, task_id=1, run_opts=''): id=str(task_id), command_line=command, container_settings=batchmodels.TaskContainerSettings( - image_name=os.getenv('CONTAINER_IMAGE_NAME'), + image_name=resolve_container_image_name(options), working_directory='taskWorkingDirectory', container_run_options='--ipc=host' ), diff --git a/bajor/models/job.py b/bajor/models/job.py index 96cb85d..9d1abd0 100644 --- a/bajor/models/job.py +++ b/bajor/models/job.py @@ -1,9 +1,14 @@ from pydantic import BaseModel, HttpUrl -from typing import Optional, Dict +from typing import Optional class Options(BaseModel): run_opts: str = "" workflow_name: str = 'cosmic_dawn' + container_image_name: Optional[str] = None + training_script_path: Optional[str] = None + prediction_script_path: Optional[str] = None + promote_script_path: Optional[str] = None + pretrained_checkpoint_url: Optional[str] = None class TrainingJob(BaseModel): diff --git a/tests/batch/test_runtime_config.py b/tests/batch/test_runtime_config.py new file mode 100644 index 0000000..c897a1e --- /dev/null +++ b/tests/batch/test_runtime_config.py @@ -0,0 +1,31 @@ +from bajor.batch.runtime_config import ( + resolve_container_image_name, + resolve_checkpoint_target, + resolve_pretrained_checkpoint_path, +) +from bajor.models.job import Options + + +def test_explicit_checkpoint_path_overrides_legacy_checkpoint_target(): + options = Options( + workflow_name='euclid', + pretrained_checkpoint_url='jwst/custom.ckpt' + ) + + assert resolve_checkpoint_target(options) == 'jwst/custom.ckpt' + assert resolve_pretrained_checkpoint_path(options) == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/jwst/custom.ckpt' + + +def test_blob_url_checkpoint_ref_is_normalized_to_relative_models_path(): + options = Options( + pretrained_checkpoint_url='https://kadeactivelearning.blob.core.windows.net/models/staging-euclid-zoobot.ckpt' + ) + + assert resolve_checkpoint_target(options) == 'staging-euclid-zoobot.ckpt' + assert resolve_pretrained_checkpoint_path(options) == '$AZ_BATCH_NODE_MOUNTS_DIR/$MODELS_CONTAINER_MOUNT_DIR/staging-euclid-zoobot.ckpt' + + +def test_explicit_container_image_name_overrides_env(): + options = Options(container_image_name='zoobot.azurecr.io/pytorch:custom-jwst') + + assert resolve_container_image_name(options) == 'zoobot.azurecr.io/pytorch:custom-jwst' diff --git a/tests/batch/test_training.py b/tests/batch/test_training.py index 64f1f15..3cd39b6 100644 --- a/tests/batch/test_training.py +++ b/tests/batch/test_training.py @@ -1,8 +1,10 @@ import bajor.batch.train_from_scratch as train_from_scratch import bajor.batch.train_finetuning as train_finetuning +import bajor.batch.predictions as predictions from bajor.batch.jobs import active_jobs_running import uuid, os from unittest import mock +from bajor.models.job import Options fake_job_id = str(uuid.uuid4()) test_pool = 'pool' @@ -40,9 +42,9 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job): train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv') mock_create_batch_job.assert_called_once_with( - job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', checkpoint_target= 'ZOOBOT_CHECKPOINT_TARGET') + job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', options=Options()) mock_create_job_tasks.assert_called_once_with( - job_id=fake_job_id, run_opts='') + job_id=fake_job_id, options=Options()) @mock.patch('bajor.batch.train_finetuning.create_batch_job') @@ -58,3 +60,22 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): result_dict = train_finetuning.schedule_job( submitted_job_id, 'fake-manifest-path.csv') assert(result_dict) == expected_result_dict + + +@mock.patch('bajor.batch.predictions.create_batch_job') +@mock.patch('bajor.batch.predictions.create_job_tasks') +def test_prediction_schedule_job_uses_options(mock_create_job_tasks, mock_create_batch_job): + options = Options( + prediction_script_path='predict_catalog_with_model.py', + pretrained_checkpoint_url='custom.ckpt' + ) + + predictions.schedule_job(fake_job_id, 'https://manifest-host/predictions.json', options) + + mock_create_batch_job.assert_called_once_with( + job_id=fake_job_id, + manifest_url='https://manifest-host/predictions.json', + pool_id='predictions_0', + options=options + ) + mock_create_job_tasks.assert_called_once_with(job_id=fake_job_id, options=options) diff --git a/tests/test_api.py b/tests/test_api.py index c4d8073..031f165 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -77,4 +77,16 @@ def test_batch_scheduling_code_is_called(mocked_client): assert response.status_code == 201 assert response.json() == { - 'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'opts': {'run_opts': run_opts, 'workflow_name': 'cosmic_dawn'}, 'status': {"status": "started", "message": "Job submitted successfully"}} + 'manifest_path': 'test_manifest_file_path.csv', + 'id': submitted_job_id, + 'opts': { + 'run_opts': run_opts, + 'workflow_name': 'cosmic_dawn', + 'container_image_name': None, + 'training_script_path': None, + 'prediction_script_path': None, + 'promote_script_path': None, + 'pretrained_checkpoint_url': None + }, + 'status': {"status": "started", "message": "Job submitted successfully"} + } From 213f5b7ad89745533419608262e7705a3b5f8866 Mon Sep 17 00:00:00 2001 From: tooyosi Date: Wed, 15 Apr 2026 18:27:59 +0100 Subject: [PATCH 2/2] rename Options to JobOptions --- bajor/batch/predictions.py | 9 ++++----- bajor/batch/runtime_config.py | 14 +++++++------- bajor/batch/train_finetuning.py | 9 ++++----- bajor/models/job.py | 6 +++--- tests/batch/test_runtime_config.py | 8 ++++---- tests/batch/test_training.py | 8 ++++---- 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/bajor/batch/predictions.py b/bajor/batch/predictions.py index 45ff456..1007a34 100644 --- a/bajor/batch/predictions.py +++ b/bajor/batch/predictions.py @@ -15,7 +15,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log -from bajor.models.job import Options +from bajor.models.job import JobOptions # Zoobot Azure Batch predictions pool ID predictions_pool_id = os.getenv('POOL_ID', 'predictions_0') @@ -32,7 +32,7 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(predictions_pool_id) # schedule a training job -def schedule_job(job_id:str, manifest_url:str, options:Options=Options()): +def schedule_job(job_id:str, manifest_url:str, options:JobOptions=JobOptions()): submitted_job_id = create_batch_job( job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, options=options) job_task_submission_status = create_job_tasks( @@ -42,7 +42,7 @@ def schedule_job(job_id:str, manifest_url:str, options:Options=Options()): return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_url, pool_id, options: Options=Options()): +def create_batch_job(job_id, manifest_url, pool_id, options: JobOptions=JobOptions()): log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url)) checkpoint_target = resolve_checkpoint_target(options) @@ -153,7 +153,7 @@ def job_logs_path(job_id, task_id, suffix): return f'{job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, options: Options=Options()): +def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('PREDICTIONS_STORAGE_CONTAINER', 'predictions')) @@ -243,4 +243,3 @@ def create_job_tasks(job_id, task_id=1, options: Options=Options()): format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', stream = sys.stdout ) - diff --git a/bajor/batch/runtime_config.py b/bajor/batch/runtime_config.py index 5a9209c..ce49631 100644 --- a/bajor/batch/runtime_config.py +++ b/bajor/batch/runtime_config.py @@ -2,7 +2,7 @@ from urllib.parse import urlparse from bajor.batch.checkpoint_strategies import get_checkpoint_target -from bajor.models.job import Options +from bajor.models.job import JobOptions MODELS_CONTAINER_MOUNT_DIR_ENV = 'MODELS_CONTAINER_MOUNT_DIR' @@ -21,7 +21,7 @@ def _relative_checkpoint_path(checkpoint_ref: str) -> str: return checkpoint_ref.lstrip('/') -def resolve_checkpoint_target(options: Options) -> str: +def resolve_checkpoint_target(options: JobOptions) -> str: if options.pretrained_checkpoint_url: return _relative_checkpoint_path(options.pretrained_checkpoint_url) @@ -29,7 +29,7 @@ def resolve_checkpoint_target(options: Options) -> str: return os.getenv(checkpoint_target, 'zoobot.ckpt') -def resolve_pretrained_checkpoint_path(options: Options) -> str: +def resolve_pretrained_checkpoint_path(options: JobOptions) -> str: if options.pretrained_checkpoint_url: return f'{_models_mount_root()}/{_relative_checkpoint_path(options.pretrained_checkpoint_url)}' @@ -37,17 +37,17 @@ def resolve_pretrained_checkpoint_path(options: Options) -> str: return f'{_models_mount_root()}/{checkpoint_file}' -def resolve_training_script_path(options: Options) -> str: +def resolve_training_script_path(options: JobOptions) -> str: return options.training_script_path or os.getenv('ZOOBOT_FINETUNE_TRAIN_CMD', 'train_model_finetune_on_catalog.py') -def resolve_prediction_script_path(options: Options) -> str: +def resolve_prediction_script_path(options: JobOptions) -> str: return options.prediction_script_path or os.getenv('ZOOBOT_PREDICTION_CMD', 'predict_catalog_with_model.py') -def resolve_promote_script_path(options: Options) -> str: +def resolve_promote_script_path(options: JobOptions) -> str: return options.promote_script_path or os.getenv('ZOOBOT_PROMOTE_CMD', 'promote_best_checkpoint_to_model.sh') -def resolve_container_image_name(options: Options) -> str: +def resolve_container_image_name(options: JobOptions) -> str: return options.container_image_name or os.getenv('CONTAINER_IMAGE_NAME') diff --git a/bajor/batch/train_finetuning.py b/bajor/batch/train_finetuning.py index c3269a2..0128185 100644 --- a/bajor/batch/train_finetuning.py +++ b/bajor/batch/train_finetuning.py @@ -17,7 +17,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log -from bajor.models.job import Options +from bajor.models.job import JobOptions # Zoobot Azure Batch training pool ID training_pool_id = os.getenv('POOL_ID', 'training_1') @@ -35,7 +35,7 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(training_pool_id) # schedule a training job -def schedule_job(job_id: str, manifest_path:str, options: Options=Options()): +def schedule_job(job_id: str, manifest_path:str, options: JobOptions=JobOptions()): submitted_job_id = create_batch_job( job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, options=options) job_task_submission_status = create_job_tasks( @@ -44,7 +44,7 @@ def schedule_job(job_id: str, manifest_path:str, options: Options=Options()): # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_container_path, pool_id, options: Options=Options()): +def create_batch_job(job_id, manifest_container_path, pool_id, options: JobOptions=JobOptions()): log.debug('server_job, create_batch_job, using manifest from path: {}'.format( manifest_container_path)) @@ -182,7 +182,7 @@ def training_job_logs_path(job_id, task_id, suffix): return f'{training_job_dir(job_id)}/task_logs/job_{job_id}_task_{task_id}_{suffix}.txt' -def create_job_tasks(job_id, task_id=1, options: Options=Options()): +def create_job_tasks(job_id, task_id=1, options: JobOptions=JobOptions()): # for persisting stdout and stderr log files in container storage container_sas_url = batch_jobs.storage_container_sas_url( os.getenv('TRAINING_STORAGE_CONTAINER', 'training')) @@ -291,4 +291,3 @@ def create_job_tasks(job_id, task_id=1, options: Options=Options()): format = '[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', stream = sys.stdout ) - diff --git a/bajor/models/job.py b/bajor/models/job.py index 9d1abd0..10743b6 100644 --- a/bajor/models/job.py +++ b/bajor/models/job.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, HttpUrl from typing import Optional -class Options(BaseModel): +class JobOptions(BaseModel): run_opts: str = "" workflow_name: str = 'cosmic_dawn' container_image_name: Optional[str] = None @@ -15,7 +15,7 @@ class TrainingJob(BaseModel): manifest_path: str id: Optional[str] = None status: Optional[str] = None - opts: Options = Options() + opts: JobOptions = JobOptions() # remove the leading / from the manifest url # as it's added via the blob storage paths in schedule_job @@ -27,4 +27,4 @@ class PredictionJob(BaseModel): manifest_url: HttpUrl id: Optional[str] = None status: Optional[str] = None - opts: Options = Options() + opts: JobOptions = JobOptions() diff --git a/tests/batch/test_runtime_config.py b/tests/batch/test_runtime_config.py index c897a1e..2c9f5ec 100644 --- a/tests/batch/test_runtime_config.py +++ b/tests/batch/test_runtime_config.py @@ -3,11 +3,11 @@ resolve_checkpoint_target, resolve_pretrained_checkpoint_path, ) -from bajor.models.job import Options +from bajor.models.job import JobOptions def test_explicit_checkpoint_path_overrides_legacy_checkpoint_target(): - options = Options( + options = JobOptions( workflow_name='euclid', pretrained_checkpoint_url='jwst/custom.ckpt' ) @@ -17,7 +17,7 @@ def test_explicit_checkpoint_path_overrides_legacy_checkpoint_target(): def test_blob_url_checkpoint_ref_is_normalized_to_relative_models_path(): - options = Options( + options = JobOptions( pretrained_checkpoint_url='https://kadeactivelearning.blob.core.windows.net/models/staging-euclid-zoobot.ckpt' ) @@ -26,6 +26,6 @@ def test_blob_url_checkpoint_ref_is_normalized_to_relative_models_path(): def test_explicit_container_image_name_overrides_env(): - options = Options(container_image_name='zoobot.azurecr.io/pytorch:custom-jwst') + options = JobOptions(container_image_name='zoobot.azurecr.io/pytorch:custom-jwst') assert resolve_container_image_name(options) == 'zoobot.azurecr.io/pytorch:custom-jwst' diff --git a/tests/batch/test_training.py b/tests/batch/test_training.py index 3cd39b6..7bcaa5d 100644 --- a/tests/batch/test_training.py +++ b/tests/batch/test_training.py @@ -4,7 +4,7 @@ from bajor.batch.jobs import active_jobs_running import uuid, os from unittest import mock -from bajor.models.job import Options +from bajor.models.job import JobOptions fake_job_id = str(uuid.uuid4()) test_pool = 'pool' @@ -42,9 +42,9 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job): train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv') mock_create_batch_job.assert_called_once_with( - job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', options=Options()) + job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', options=JobOptions()) mock_create_job_tasks.assert_called_once_with( - job_id=fake_job_id, options=Options()) + job_id=fake_job_id, options=JobOptions()) @mock.patch('bajor.batch.train_finetuning.create_batch_job') @@ -65,7 +65,7 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): @mock.patch('bajor.batch.predictions.create_batch_job') @mock.patch('bajor.batch.predictions.create_job_tasks') def test_prediction_schedule_job_uses_options(mock_create_job_tasks, mock_create_batch_job): - options = Options( + options = JobOptions( prediction_script_path='predict_catalog_with_model.py', pretrained_checkpoint_url='custom.ckpt' )