diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 5a1af3a..d9f90a4 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -13,19 +13,16 @@ jobs: with: fetch-depth: 0 - - name: Set up Python 3.7.x + - name: Set up Python 3.9.x uses: actions/setup-python@v4 with: - python-version: "3.7.16" + python-version: "3.9.20" - name: Install dependencies run: pip install .[dev] - - name: Lint - run: make lint - - - name: Format - run: make format + - name: Checks + run: make check - name: Test run: make test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 65783a6..328ad1c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,12 +7,10 @@ on: # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: - release-latest: - permissions: - id-token: write # to verify the deployment originates from an appropriate source - contents: write # To allow pushing tags/etc. + id-token: write # to verify the deployment originates from an appropriate source + contents: write # To allow pushing tags/etc. # Specify runner + deployment step runs-on: ubuntu-22.04 @@ -59,7 +57,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.7.16" + python-version: "3.9.20" - name: Install dependencies run: pip install .[dev] - name: Build docs @@ -70,14 +68,13 @@ jobs: path: docs/_build/html release-docs: - # Add a dependency to the build job needs: build-docs # Grant GITHUB_TOKEN the permissions required to make a Pages deployment permissions: - pages: write # to deploy to Pages - id-token: write # to verify the deployment originates from an appropriate source + pages: write # to deploy to Pages + id-token: write # to verify the deployment originates from an appropriate source # Deploy to the github-pages environment environment: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..228f496 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.2 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 # Use the ref you want to point at + hooks: + - id: trailing-whitespace + - id: check-yaml + - id: pretty-format-json + args: + - "--autofix" + - "--indent=4" + - id: requirements-txt-fixer diff --git a/Makefile b/Makefile index 4e64dbf..330c719 100644 --- a/Makefile +++ b/Makefile @@ -5,16 +5,9 @@ FILES := $(shell git diff --name-only --diff-filter=AM $$(git merge-base origin/ test: pytest -vv -.PHONY: lint -lint: - ruff check . - -.PHONY: format -format: - ruff format . - -.PHONY: tidy -tidy: format lint +.PHONY: check +check: + pre-commit run -a # Removes the directory that contains bytecode cache files # that are automatically generated by python. 
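With the separate lint and format targets collapsed into a single check target that shells out to pre-commit, the implied local workflow (a sketch, not spelled out in the diff) is: install the dev extras with pip install .[dev] (which now pulls in pre-commit), optionally register the git hooks once with pre-commit install, then run make check (equivalent to pre-commit run -a) and make test before pushing. Note that pre-commit install is the standard pre-commit hook-registration command and is an assumption here; the Makefile itself only invokes pre-commit run -a.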
diff --git a/pyproject.toml b/pyproject.toml index e7dec71..1fd61a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,12 +38,13 @@ dependencies = [ dynamic = ["description", "version"] name = "syncsparkpy" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.9" [project.optional-dependencies] dev = [ "Sphinx==4.3.0", "deepdiff==6.3.0", + "pre-commit==4.0.1", "pytest-asyncio==0.21.0", "pytest-env==0.8.1", "pytest==7.2.0", @@ -73,7 +74,7 @@ pythonpath = ["."] [tool.ruff] exclude = ["artifacts/*"] line-length = 100 -target-version = "py37" +target-version = "py39" [tool.ruff.lint] ignore = ["E501"] preview = true @@ -91,12 +92,12 @@ extend-immutable-calls = [ "fastapi.Security", ] [tool.ruff.lint.mccabe] -max-complexity = 10 +max-complexity = 20 [tool.pyright] pythonPlatform = "All" -pythonVersion = "3.7" +pythonVersion = "3.9" reportUnnecessaryTypeIgnoreComment = "error" typeCheckingMode = "standard" useLibraryCodeForTypes = false diff --git a/sync/__init__.py b/sync/__init__.py index 171b60f..8e2d6c4 100644 --- a/sync/__init__.py +++ b/sync/__init__.py @@ -1,5 +1,5 @@ """Library for leveraging the power of Sync""" -__version__ = "1.11.7" +__version__ = "2.0.0" TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" diff --git a/sync/_databricks.py b/sync/_databricks.py index 19c93f4..c14d49d 100644 --- a/sync/_databricks.py +++ b/sync/_databricks.py @@ -8,10 +8,11 @@ import time import zipfile from collections import defaultdict +from collections.abc import Collection from datetime import datetime, timezone from pathlib import Path from time import sleep -from typing import Collection, Dict, List, Optional, Set, Tuple, Union +from typing import Optional, Union from urllib.parse import urlparse import boto3 as boto @@ -70,9 +71,9 @@ def get_cluster(cluster_id: str) -> Response[dict]: def create_submission_with_cluster_info( run_id: str, project_id: str, - cluster: Dict, - cluster_info: Dict, - cluster_activity_events: Dict, + cluster: dict, + cluster_info: dict, + cluster_activity_events: dict, plan_type: DatabricksPlanType, compute_type: DatabricksComputeType, skip_eventlog: bool = False, @@ -185,7 +186,7 @@ def create_submission_for_run( def _create_submission( cluster_id: str, - tasks: List[dict], + tasks: list[dict], plan_type: str, compute_type: str, project_id: str, @@ -215,12 +216,12 @@ def _create_submission( def _get_run_information( cluster_id: str, - tasks: List[dict], + tasks: list[dict], plan_type: str, compute_type: str, allow_failed_tasks: bool = False, allow_incomplete_cluster_report: bool = False, -) -> Response[Tuple[DatabricksClusterReport, bytes]]: +) -> Response[tuple[DatabricksClusterReport, bytes]]: if not allow_failed_tasks and any( task["state"].get("result_state") != "SUCCESS" for task in tasks ): @@ -243,7 +244,7 @@ def _get_run_information( return cluster_report_response -def _get_event_log_from_cluster(cluster: Dict, tasks: List[Dict]) -> Response[bytes]: +def _get_event_log_from_cluster(cluster: dict, tasks: list[dict]) -> Response[bytes]: spark_context_id = _get_run_spark_context_id(tasks) end_time = max(task["end_time"] for task in tasks) eventlog_response = _get_eventlog(cluster, spark_context_id.result, end_time) @@ -257,7 +258,7 @@ def _get_event_log_from_cluster(cluster: Dict, tasks: List[Dict]) -> Response[by def _maybe_get_event_log_from_cluster( - cluster: Dict, tasks: List[Dict], dbfs_eventlog_file_size: Union[int, None] + cluster: dict, tasks: list[dict], dbfs_eventlog_file_size: Union[int, None] ) -> Response[bytes]: spark_context_id = 
_get_run_spark_context_id(tasks) end_time = max(task["end_time"] for task in tasks) @@ -330,7 +331,7 @@ def get_cluster_report( def _get_cluster_report( cluster_id: str, - cluster_tasks: List[dict], + cluster_tasks: list[dict], plan_type: str, compute_type: str, allow_incomplete: bool, @@ -342,7 +343,7 @@ def _create_cluster_report( cluster: dict, cluster_info: dict, cluster_activity_events: dict, - tasks: List[dict], + tasks: list[dict], plan_type: DatabricksPlanType, compute_type: DatabricksComputeType, ) -> DatabricksClusterReport: @@ -366,7 +367,7 @@ def handle_successful_job_run( project_id: Union[str, None] = None, allow_incomplete_cluster_report: bool = False, exclude_tasks: Union[Collection[str], None] = None, -) -> Response[Dict[str, str]]: +) -> Response[dict[str, str]]: """Create's Sync project submissions for each eligible cluster in the run (see :py:func:`~record_run`) If project ID is provided only submit run data for the cluster tagged with it, or the only cluster if there is such. @@ -455,7 +456,7 @@ def record_run( project_id: Union[str, None] = None, allow_incomplete_cluster_report: bool = False, exclude_tasks: Union[Collection[str], None] = None, -) -> Response[Dict[str, str]]: +) -> Response[dict[str, str]]: """Create's Sync project submissions for each eligible cluster in the run. If project ID is provided only submit run data for the cluster tagged with it, or the only cluster if there is such. @@ -508,11 +509,11 @@ def record_run( def _record_project_clusters( - project_cluster_tasks: Dict[str, Tuple[str, List[dict]]], + project_cluster_tasks: dict[str, tuple[str, list[dict]]], plan_type: str, compute_type: str, allow_incomplete_cluster_report: bool, -) -> Dict[str, str]: +) -> dict[str, str]: """Creates project submissions/predictions and returns a map of project IDs to the new submissions/predictions IDs""" result_ids = {} for cluster_project_id, (cluster_id, tasks) in project_cluster_tasks.items(): @@ -792,7 +793,7 @@ def get_project_cluster_settings( return Response(result=cluster_template) -def run_job_object(job: dict) -> Response[Tuple[str, str]]: +def run_job_object(job: dict) -> Response[tuple[str, str]]: """Create a Databricks one-off run based on the job configuration. :param job: Databricks job object @@ -1134,7 +1135,7 @@ def _wait_for_cluster_termination( def _cluster_log_destination( cluster: dict, -) -> Union[Tuple[str, str, str, str], Tuple[None, None, None, None]]: +) -> Union[tuple[str, str, str, str], tuple[None, None, None, None]]: cluster_log_conf = cluster.get("cluster_log_conf", {}) s3_log_url = cluster_log_conf.get("s3", {}).get("destination") dbfs_log_url = cluster_log_conf.get("dbfs", {}).get("destination") @@ -1156,7 +1157,7 @@ def _cluster_log_destination( return None, None, None, None -def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]: +def _get_job_cluster(tasks: list[dict], job_clusters: list) -> Response[dict]: if len(tasks) == 1: return _get_task_cluster(tasks[0], job_clusters) @@ -1173,7 +1174,7 @@ def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]: def _get_project_job_clusters( job: dict, exclude_tasks: Union[Collection[str], None] = None, -) -> Dict[str, Tuple[Tuple[str], dict]]: +) -> dict[str, tuple[tuple[str], dict]]: """Returns a mapping of project IDs to cluster paths and clusters. Cluster paths are tuples that can be used to locate clusters in a job object, e.g. 
@@ -1215,7 +1216,7 @@ def _get_project_cluster_tasks( project_id: Optional[str] = None, cluster_path: Optional[str] = None, exclude_tasks: Union[Collection[str], None] = None, -) -> Dict[str, Tuple[str, List[dict]]]: +) -> dict[str, tuple[str, list[dict]]]: """Returns a mapping of project IDs to cluster-ID-tasks pairs""" project_cluster_tasks = _get_cluster_tasks(run, exclude_tasks) @@ -1265,7 +1266,7 @@ def _get_project_cluster_tasks( def _get_cluster_tasks( run: dict, exclude_tasks: Union[Collection[str], None] = None, -) -> Dict[str, Dict[str, Tuple[str, List[dict]]]]: +) -> dict[str, dict[str, tuple[str, list[dict]]]]: """Returns a mapping of project IDs to cluster paths to cluster IDs and tasks""" job_clusters = {c["job_cluster_key"]: c["new_cluster"] for c in run.get("job_clusters", [])} @@ -1311,7 +1312,7 @@ def _get_cluster_tasks( return result_cluster_project_tasks -def _get_run_spark_context_id(tasks: List[dict]) -> Response[str]: +def _get_run_spark_context_id(tasks: list[dict]) -> Response[str]: context_ids = { task["cluster_instance"]["spark_context_id"] for task in tasks if "cluster_instance" in task } @@ -1341,7 +1342,7 @@ def _get_task_cluster(task: dict, clusters: list) -> Response[dict]: return Response(result=cluster) -def _s3_contents_have_all_rollover_logs(contents: List[dict], run_end_time_seconds: float): +def _s3_contents_have_all_rollover_logs(contents: list[dict], run_end_time_seconds: float): final_rollover_log = contents and next( ( content @@ -1380,7 +1381,7 @@ def _dbfs_directory_has_all_rollover_logs(contents: dict, run_end_time_millis: f ) -def _dbfs_any_file_has_zero_size(dbfs_contents: Dict) -> bool: +def _dbfs_any_file_has_zero_size(dbfs_contents: dict) -> bool: any_zeros = any(file["file_size"] == 0 for file in dbfs_contents["files"]) if any_zeros: logger.info("One or more dbfs event log files has a file size of zero") @@ -1388,8 +1389,8 @@ def _dbfs_any_file_has_zero_size(dbfs_contents: Dict) -> bool: def _check_total_file_size_changed( - last_total_file_size: int, dbfs_contents: Dict -) -> Tuple[bool, int]: + last_total_file_size: int, dbfs_contents: dict +) -> tuple[bool, int]: new_total_file_size = sum([file.get("file_size", 0) for file in dbfs_contents.get("files", {})]) if new_total_file_size == last_total_file_size: return False, new_total_file_size @@ -1687,9 +1688,9 @@ def get_all_cluster_events(cluster_id: str): def _update_monitored_timelines( - running_instance_ids: Set[str], - active_timelines_by_id: Dict[str, dict], -) -> Tuple[Dict[str, dict], List[dict]]: + running_instance_ids: set[str], + active_timelines_by_id: dict[str, dict], +) -> tuple[dict[str, dict], list[dict]]: """ Shared monitoring method for both Azure and Databricks to reduce complexity. 
Compares the current running instances (keyed by id) to the running diff --git a/sync/api/projects.py b/sync/api/projects.py index 6209749..f6879c1 100644 --- a/sync/api/projects.py +++ b/sync/api/projects.py @@ -4,7 +4,7 @@ import json import logging from time import sleep -from typing import List, Optional, Union +from typing import Optional, Union from urllib.parse import urlparse import httpx @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def get_products() -> Response[List[str]]: +def get_products() -> Response[list[str]]: """Get supported platforms :return: list of platform names :rtype: Response[list[str]] @@ -221,7 +221,7 @@ def get_project_by_app_id(app_id: str) -> Response[dict]: return Response(error=ProjectError(message=f"No project found for '{app_id}'")) -def get_projects(app_id: Optional[str] = None) -> Response[List[dict]]: +def get_projects(app_id: Optional[str] = None) -> Response[list[dict]]: """Returns all projects authorized by the API key :param app_id: app ID to filter by, defaults to None @@ -552,7 +552,3 @@ def get_updated_cluster_definition( return Response(result=recommendation_cluster) else: return rec_response - - -# Old typo, to keep the API consistent we also define the one with wrong name: -get_updated_cluster_defintion = get_updated_cluster_definition diff --git a/sync/api/workspace.py b/sync/api/workspace.py index 2d6a3e4..f2c9cf0 100644 --- a/sync/api/workspace.py +++ b/sync/api/workspace.py @@ -1,5 +1,3 @@ -from typing import List - from sync.clients.sync import get_default_client from sync.models import CreateWorkspaceConfig, Response, UpdateWorkspaceConfig @@ -36,7 +34,7 @@ def get_workspace_config(workspace_id: str) -> Response[dict]: return Response(**get_default_client().get_workspace_config(workspace_id)) -def get_workspace_configs() -> Response[List[dict]]: +def get_workspace_configs() -> Response[list[dict]]: return Response(**get_default_client().get_workspace_configs()) diff --git a/sync/awsdatabricks.py b/sync/awsdatabricks.py index 229ea78..aa07884 100644 --- a/sync/awsdatabricks.py +++ b/sync/awsdatabricks.py @@ -1,8 +1,9 @@ import json import logging +from collections.abc import Generator from pathlib import Path from time import sleep -from typing import Dict, Generator, List, Optional, Tuple +from typing import Optional from urllib.parse import urlparse import boto3 as boto @@ -187,7 +188,7 @@ def get_access_report(log_url: Optional[str] = None) -> AccessReport: def _get_cluster_report( cluster_id: str, - cluster_tasks: List[dict], + cluster_tasks: list[dict], plan_type: str, compute_type: str, allow_incomplete: bool, @@ -242,7 +243,7 @@ def _create_cluster_report( cluster: dict, cluster_info: dict, cluster_activity_events: dict, - tasks: List[dict], + tasks: list[dict], plan_type: DatabricksPlanType, compute_type: DatabricksComputeType, ) -> AWSDatabricksClusterReport: @@ -268,7 +269,7 @@ def _create_cluster_report( setattr(sync._databricks, "__claim", __name__) -def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]: +def _load_aws_cluster_info(cluster: dict) -> tuple[Response[dict], Response[dict]]: cluster_info = None cluster_id = None cluster_log_dest = _cluster_log_destination(cluster) @@ -311,7 +312,7 @@ def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict return cluster_info, cluster_id -def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict], Response[dict]]: +def _get_aws_cluster_info(cluster: dict) -> tuple[Response[dict], Response[dict], 
Response[dict]]: aws_region_name = DB_CONFIG.aws_region_name cluster_info, cluster_id = _load_aws_cluster_info(cluster) @@ -351,9 +352,9 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id): def save_cluster_report( cluster_id: str, - instance_timelines: List[dict], - cluster_log_destination: Optional[Tuple[str, ...]] = None, - cluster_report_destination_override: Optional[Dict[str, str]] = None, + instance_timelines: list[dict], + cluster_log_destination: Optional[tuple[str, ...]] = None, + cluster_report_destination_override: Optional[dict[str, str]] = None, write_function=None, ) -> bool: cluster = get_default_client().get_cluster(cluster_id) @@ -570,7 +571,7 @@ def write_file(body: bytes): return write_file -def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[dict]: +def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> list[dict]: filters = [ {"Name": "tag:Vendor", "Values": ["Databricks"]}, {"Name": "tag:ClusterId", "Values": [cluster_id]}, @@ -594,8 +595,8 @@ def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> Li def _get_ebs_volumes_for_instances( - instances: List[dict], ec2_client: "botocore.client.ec2" -) -> List[dict]: + instances: list[dict], ec2_client: "botocore.client.ec2" +) -> list[dict]: """Get all ebs volumes associated with a list of instance reservations""" def get_chunk(instance_ids: list[str], chunk_size: int) -> Generator[list[str], None, None]: diff --git a/sync/azuredatabricks.py b/sync/azuredatabricks.py index 029ddff..676f07a 100644 --- a/sync/azuredatabricks.py +++ b/sync/azuredatabricks.py @@ -4,7 +4,7 @@ import sys from pathlib import Path from time import sleep -from typing import Dict, List, Optional, Type, TypeVar, Union +from typing import Optional, TypeVar, Union from urllib.parse import urlparse from azure.common.credentials import get_cli_profile @@ -195,7 +195,7 @@ def get_access_report(log_url: Optional[str] = None) -> AccessReport: def _get_cluster_report( cluster_id: str, - cluster_tasks: List[dict], + cluster_tasks: list[dict], plan_type: str, compute_type: str, allow_incomplete: bool, @@ -233,7 +233,7 @@ def _create_cluster_report( cluster: dict, cluster_info: dict, cluster_activity_events: dict, - tasks: List[dict], + tasks: list[dict], plan_type: DatabricksPlanType, compute_type: DatabricksComputeType, ) -> AzureDatabricksClusterReport: @@ -492,7 +492,7 @@ def set_azure_client_credentials( _azure_credential = azure_credential -def _get_azure_client(azure_client_class: Type[AzureClient]) -> AzureClient: +def _get_azure_client(azure_client_class: type[AzureClient]) -> AzureClient: global _azure_subscription_id if not _azure_subscription_id: _azure_subscription_id = _get_azure_subscription_id() @@ -516,7 +516,7 @@ def _get_azure_subscription_id(): def _get_running_vms_by_id( compute: AzureClient, resource_group_name: Optional[str], cluster_id: str -) -> Dict[str, dict]: +) -> dict[str, dict]: if resource_group_name: vms = compute.virtual_machines.list(resource_group_name=resource_group_name) else: diff --git a/sync/cli/_databricks.py b/sync/cli/_databricks.py index fef4165..060d18a 100644 --- a/sync/cli/_databricks.py +++ b/sync/cli/_databricks.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Tuple +from typing import Optional import click @@ -56,7 +56,7 @@ def create_submission( compute: DatabricksComputeType, project: dict, allow_incomplete: bool = False, - exclude_task: Optional[Tuple[str, ...]] = None, + exclude_task: 
Optional[tuple[str, ...]] = None, ): """Create a submission for a job run""" if platform is Platform.AWS_DATABRICKS: @@ -159,7 +159,7 @@ def get_cluster_report( compute: DatabricksComputeType, project: Optional[dict] = None, allow_incomplete: bool = False, - exclude_task: Optional[Tuple[str, ...]] = None, + exclude_task: Optional[tuple[str, ...]] = None, ): """Get a cluster report""" if platform is Platform.AWS_DATABRICKS: diff --git a/sync/clients/__init__.py b/sync/clients/__init__.py index c1a7b96..335b1fb 100644 --- a/sync/clients/__init__.py +++ b/sync/clients/__init__.py @@ -1,5 +1,6 @@ import json -from typing import ClassVar, Set, Tuple, Union +import sys +from typing import ClassVar, Union import httpx from tenacity import ( @@ -13,11 +14,14 @@ from sync import __version__ from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds -USER_AGENT = f"Sync Library/{__version__} (syncsparkpy)" +# inclue python version +USER_AGENT = ( + f"Sync Library/{__version__} (syncsparkpy) Python/{'.'.join(map(str, sys.version_info[:3]))}" +) DATABRICKS_USER_AGENT = "sync-gradient" -def encode_json(obj: dict) -> Tuple[dict, str]: +def encode_json(obj: dict) -> tuple[dict, str]: # "%Y-%m-%dT%H:%M:%SZ" json_obj = json.dumps(obj, cls=DateTimeEncoderNaiveUTCDropMicroseconds) @@ -33,7 +37,7 @@ class RetryableHTTPClient: Smaller wrapper around httpx.Client/AsyncClient to contain retrying logic that httpx does not offer natively """ - _DEFAULT_RETRYABLE_STATUS_CODES: ClassVar[Set[httpx.codes]] = { + _DEFAULT_RETRYABLE_STATUS_CODES: ClassVar[set[httpx.codes]] = { httpx.codes.REQUEST_TIMEOUT, httpx.codes.TOO_EARLY, httpx.codes.TOO_MANY_REQUESTS, diff --git a/sync/clients/cache.py b/sync/clients/cache.py index 3dc644d..90d812e 100644 --- a/sync/clients/cache.py +++ b/sync/clients/cache.py @@ -2,7 +2,7 @@ import logging from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Callable, Optional, Tuple, Type, Union +from typing import Callable, Optional, Union from platformdirs import user_cache_dir @@ -45,7 +45,7 @@ def set_cached_token(self, access_token: str, expires_at_utc: datetime) -> None: def _set_cached_token(self) -> None: raise NotImplementedError - def _get_cached_token(self) -> Optional[Tuple[str, datetime]]: + def _get_cached_token(self) -> Optional[tuple[str, datetime]]: raise NotImplementedError @@ -55,7 +55,7 @@ def __init__(self): super().__init__() - def _get_cached_token(self) -> Optional[Tuple[str, datetime]]: + def _get_cached_token(self) -> Optional[tuple[str, datetime]]: # Cache is optional, we can fail to read it and not worry if self._cache_file.exists(): try: @@ -87,7 +87,7 @@ def _set_cached_token(self) -> None: # Putting this here instead of config.py because circular imports and typing. -ACCESS_TOKEN_CACHE_CLS_TYPE = Union[Type[CachedToken], Callable[[], CachedToken]] +ACCESS_TOKEN_CACHE_CLS_TYPE = Union[type[CachedToken], Callable[[], CachedToken]] _access_token_cache_cls: ACCESS_TOKEN_CACHE_CLS_TYPE = ( FileCachedToken # Default to local file caching. 
) diff --git a/sync/clients/databricks.py b/sync/clients/databricks.py index c04d8f3..a592594 100644 --- a/sync/clients/databricks.py +++ b/sync/clients/databricks.py @@ -1,6 +1,7 @@ import logging import os -from typing import Generator, Union +from collections.abc import Generator +from typing import Union import httpx from packaging.version import Version diff --git a/sync/clients/sync.py b/sync/clients/sync.py index e496c3c..76b600d 100644 --- a/sync/clients/sync.py +++ b/sync/clients/sync.py @@ -1,7 +1,8 @@ import asyncio import logging import threading -from typing import AsyncGenerator, Generator, Optional +from collections.abc import AsyncGenerator, Generator +from typing import Optional import dateutil.parser import httpx diff --git a/sync/config.py b/sync/config.py index a652295..0247227 100644 --- a/sync/config.py +++ b/sync/config.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from typing import Any, Callable, Dict +from typing import Any, Callable from urllib.parse import urlparse import boto3 as boto @@ -15,8 +15,8 @@ DATABRICKS_CONFIG_FILE = "databrickscfg" -def json_config_settings_source(path: str) -> Callable[[BaseSettings], Dict[str, Any]]: - def source(settings: BaseSettings) -> Dict[str, Any]: +def json_config_settings_source(path: str) -> Callable[[BaseSettings], dict[str, Any]]: + def source(settings: BaseSettings) -> dict[str, Any]: config_path = _get_config_dir().joinpath(path) if config_path.exists(): with open(config_path) as fobj: diff --git a/sync/models.py b/sync/models.py index e3f2019..79bcd89 100644 --- a/sync/models.py +++ b/sync/models.py @@ -6,7 +6,7 @@ import json from dataclasses import dataclass from enum import Enum, unique -from typing import Callable, Dict, Generic, List, Optional, TypeVar, Union +from typing import Callable, Generic, Optional, TypeVar, Union from botocore.exceptions import ClientError from pydantic import BaseModel, Field, root_validator, validator @@ -55,7 +55,7 @@ class AccessReportLine: message: Union[str, None] -class AccessReport(List[AccessReportLine]): +class AccessReport(list[AccessReportLine]): def __str__(self): return "\n".join(f"{line.name}\n {line.status}: {line.message}" for line in self) @@ -125,13 +125,13 @@ class DatabricksClusterReport(BaseModel): compute_type: DatabricksComputeType cluster: dict cluster_events: dict - tasks: List[dict] - instances: Union[List[dict], None] - instance_timelines: Union[List[dict], None] + tasks: list[dict] + instances: Union[list[dict], None] + instance_timelines: Union[list[dict], None] class AWSDatabricksClusterReport(DatabricksClusterReport): - volumes: Union[List[dict], None] + volumes: Union[list[dict], None] class AzureDatabricksClusterReport(DatabricksClusterReport): @@ -188,28 +188,28 @@ class DBFSClusterLogConfiguration(BaseModel): class AWSProjectConfiguration(BaseModel): node_type_id: str driver_node_type: str - custom_tags: Dict + custom_tags: dict cluster_log_conf: Union[S3ClusterLogConfiguration, DBFSClusterLogConfiguration] cluster_name: str num_workers: int spark_version: str runtime_engine: str - autoscale: Dict - spark_conf: Dict - aws_attributes: Dict - spark_env_vars: Dict + autoscale: dict + spark_conf: dict + aws_attributes: dict + spark_env_vars: dict class AzureProjectConfiguration(BaseModel): node_type_id: str driver_node_type: str cluster_log_conf: DBFSClusterLogConfiguration - custom_tags: Dict + custom_tags: dict num_workers: int - spark_conf: Dict + spark_conf: dict spark_version: str runtime_engine: str - azure_attributes: Dict + 
azure_attributes: dict class AwsRegionEnum(str, Enum): diff --git a/sync/utils/json.py b/sync/utils/json.py index 37e11df..8d03d0e 100644 --- a/sync/utils/json.py +++ b/sync/utils/json.py @@ -1,6 +1,6 @@ import datetime from json import JSONEncoder -from typing import Any, Dict, TypeVar +from typing import Any, TypeVar class DefaultDateTimeEncoder(JSONEncoder): @@ -41,8 +41,8 @@ def default(self, obj): def deep_update( - mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any] -) -> Dict[KeyType, Any]: + mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] +) -> dict[KeyType, Any]: updated_mapping = mapping.copy() for updating_mapping in updating_mappings: for k, v in updating_mapping.items(): diff --git a/tests/test_files/aws_cluster.json b/tests/test_files/aws_cluster.json index 1b0012f..50da7d5 100644 --- a/tests/test_files/aws_cluster.json +++ b/tests/test_files/aws_cluster.json @@ -1,52 +1,52 @@ { - "cluster_id": "1234-567890-reef123", - "spark_context_id": 4020997813441462000, - "cluster_name": "my-cluster", - "spark_version": "13.3.x-scala2.12", - "aws_attributes": { - "zone_id": "us-west-2c", - "first_on_demand": 1, - "availability": "SPOT_WITH_FALLBACK", - "spot_bid_price_percent": 100, - "ebs_volume_count": 0 - }, - "node_type_id": "i3.xlarge", - "driver_node_type_id": "i3.xlarge", - "autotermination_minutes": 120, - "enable_elastic_disk": false, - "disk_spec": { - "disk_count": 0 - }, - "cluster_source": "UI", - "enable_local_disk_encryption": false, - "instance_source": { - "node_type_id": "i3.xlarge" - }, - "driver_instance_source": { - "node_type_id": "i3.xlarge" - }, - "state": "TERMINATED", - "state_message": "Inactive cluster terminated (inactive for 120 minutes).", - "start_time": 1618263108824, - "terminated_time": 1619746525713, - "last_state_loss_time": 1619739324740, - "num_workers": 30, - "default_tags": { - "Vendor": "Databricks", - "Creator": "someone@example.com", - "ClusterName": "my-cluster", - "ClusterId": "1234-567890-reef123" - }, - "creator_user_name": "someone@example.com", - "termination_reason": { - "code": "INACTIVITY", - "parameters": { - "inactivity_duration_min": "120" + "autotermination_minutes": 120, + "aws_attributes": { + "availability": "SPOT_WITH_FALLBACK", + "ebs_volume_count": 0, + "first_on_demand": 1, + "spot_bid_price_percent": 100, + "zone_id": "us-west-2c" }, - "type": "SUCCESS" - }, - "init_scripts_safe_mode": false, - "spec": { - "spark_version": "13.3.x-scala2.12" - } -} \ No newline at end of file + "cluster_id": "1234-567890-reef123", + "cluster_name": "my-cluster", + "cluster_source": "UI", + "creator_user_name": "someone@example.com", + "default_tags": { + "ClusterId": "1234-567890-reef123", + "ClusterName": "my-cluster", + "Creator": "someone@example.com", + "Vendor": "Databricks" + }, + "disk_spec": { + "disk_count": 0 + }, + "driver_instance_source": { + "node_type_id": "i3.xlarge" + }, + "driver_node_type_id": "i3.xlarge", + "enable_elastic_disk": false, + "enable_local_disk_encryption": false, + "init_scripts_safe_mode": false, + "instance_source": { + "node_type_id": "i3.xlarge" + }, + "last_state_loss_time": 1619739324740, + "node_type_id": "i3.xlarge", + "num_workers": 30, + "spark_context_id": 4020997813441462000, + "spark_version": "13.3.x-scala2.12", + "spec": { + "spark_version": "13.3.x-scala2.12" + }, + "start_time": 1618263108824, + "state": "TERMINATED", + "state_message": "Inactive cluster terminated (inactive for 120 minutes).", + "terminated_time": 1619746525713, + "termination_reason": 
{ + "code": "INACTIVITY", + "parameters": { + "inactivity_duration_min": "120" + }, + "type": "SUCCESS" + } +} diff --git a/tests/test_files/aws_recommendation.json b/tests/test_files/aws_recommendation.json index 6087fe2..7fe0fd6 100644 --- a/tests/test_files/aws_recommendation.json +++ b/tests/test_files/aws_recommendation.json @@ -1,55 +1,55 @@ { - "result": [ - { - "created_at": "2024-02-29T23:11:58.559Z", - "updated_at": "2024-02-29T23:11:58.559Z", - "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "state": "string", - "error": "string", - "recommendation": { - "metrics": { - "spark_duration_minutes": 0, - "spark_cost_requested_usd": 0, - "spark_cost_lower_usd": 0, - "spark_cost_midpoint_usd": 0, - "spark_cost_upper_usd": 0 - }, - "configuration": { - "node_type_id": "i6.xlarge", - "driver_node_type_id": "i6.xlarge", - "custom_tags": { - "sync:project-id": "b9bd7136-7699-4603-9040-c6dc4c914e43", - "sync:run-id": "e96401da-f64d-4ed0-8ded-db1317f40248", - "sync:recommendation-id": "e029a220-c6a5-49fd-b7ed-7ea046366741", - "sync:tenant-id": "352176a7-b605-4cc2-b3b2-ee591715b6b4" - }, - "num_workers": 20, - "spark_conf": { - "spark.databricks.isv.product": "sync-gradient" - }, - "spark_version": "13.3.x-scala2.12", - "runtime_engine": "PHOTON", - "aws_attributes": { - "first_on_demand": 1, - "availability": "SPOT_WITH_FALLBACK", - "spot_bid_price_percent": 100 - } - }, - "prediction_params": { - "sla_minutes": 0, - "force_ondemand_workers": false, - "force_ondemand_basis": false, - "fix_worker_family": true, - "fix_driver_type": true, - "fix_scaling_type": false + "result": [ + { + "context": { + "current_learning_iteration": 0, + "latest_submission_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", + "phase": "LEARNING", + "total_learning_iterations": 0 + }, + "created_at": "2024-02-29T23:11:58.559Z", + "error": "string", + "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", + "recommendation": { + "configuration": { + "aws_attributes": { + "availability": "SPOT_WITH_FALLBACK", + "first_on_demand": 1, + "spot_bid_price_percent": 100 + }, + "custom_tags": { + "sync:project-id": "b9bd7136-7699-4603-9040-c6dc4c914e43", + "sync:recommendation-id": "e029a220-c6a5-49fd-b7ed-7ea046366741", + "sync:run-id": "e96401da-f64d-4ed0-8ded-db1317f40248", + "sync:tenant-id": "352176a7-b605-4cc2-b3b2-ee591715b6b4" + }, + "driver_node_type_id": "i6.xlarge", + "node_type_id": "i6.xlarge", + "num_workers": 20, + "runtime_engine": "PHOTON", + "spark_conf": { + "spark.databricks.isv.product": "sync-gradient" + }, + "spark_version": "13.3.x-scala2.12" + }, + "metrics": { + "spark_cost_lower_usd": 0, + "spark_cost_midpoint_usd": 0, + "spark_cost_requested_usd": 0, + "spark_cost_upper_usd": 0, + "spark_duration_minutes": 0 + }, + "prediction_params": { + "fix_driver_type": true, + "fix_scaling_type": false, + "fix_worker_family": true, + "force_ondemand_basis": false, + "force_ondemand_workers": false, + "sla_minutes": 0 + } + }, + "state": "string", + "updated_at": "2024-02-29T23:11:58.559Z" } - }, - "context": { - "latest_submission_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "phase": "LEARNING", - "current_learning_iteration": 0, - "total_learning_iterations": 0 - } - } - ] -} \ No newline at end of file + ] +} diff --git a/tests/test_files/azure_cluster.json b/tests/test_files/azure_cluster.json index 3ef7977..de9d7db 100644 --- a/tests/test_files/azure_cluster.json +++ b/tests/test_files/azure_cluster.json @@ -1,50 +1,50 @@ { - "cluster_id": "1114-202840-mu1ql9xp", - "spark_context_id": 8637481617925571639, - 
"cluster_name": "my-cluster", - "spark_version": "13.3.x-scala2.12", - "azure_attributes": { - "first_on_demand": 6, - "availability": "SPOT_WITH_FALLBACK_AZURE", - "spot_bid_max_price": 100.0 - }, - "node_type_id": "Standard_DS5_v2", - "driver_node_type_id": "Standard_DS5_v2", - "autotermination_minutes": 120, - "enable_elastic_disk": false, - "disk_spec": { - "disk_count": 0 - }, - "cluster_source": "UI", - "enable_local_disk_encryption": false, - "instance_source": { - "node_type_id": "Standard_DS5_v2" - }, - "driver_instance_source": { - "node_type_id": "Standard_DS5_v2" - }, - "state": "TERMINATED", - "state_message": "Inactive cluster terminated (inactive for 120 minutes).", - "start_time": 1618263108824, - "terminated_time": 1619746525713, - "last_state_loss_time": 1619739324740, - "num_workers": 30, - "default_tags": { - "Vendor": "Databricks", - "Creator": "someone@example.com", - "ClusterName": "my-cluster", - "ClusterId": "1234-567890-reef123" - }, - "creator_user_name": "someone@example.com", - "termination_reason": { - "code": "INACTIVITY", - "parameters": { - "inactivity_duration_min": "120" + "autotermination_minutes": 120, + "azure_attributes": { + "availability": "SPOT_WITH_FALLBACK_AZURE", + "first_on_demand": 6, + "spot_bid_max_price": 100.0 }, - "type": "SUCCESS" - }, - "init_scripts_safe_mode": false, - "spec": { - "spark_version": "13.3.x-scala2.12" - } -} \ No newline at end of file + "cluster_id": "1114-202840-mu1ql9xp", + "cluster_name": "my-cluster", + "cluster_source": "UI", + "creator_user_name": "someone@example.com", + "default_tags": { + "ClusterId": "1234-567890-reef123", + "ClusterName": "my-cluster", + "Creator": "someone@example.com", + "Vendor": "Databricks" + }, + "disk_spec": { + "disk_count": 0 + }, + "driver_instance_source": { + "node_type_id": "Standard_DS5_v2" + }, + "driver_node_type_id": "Standard_DS5_v2", + "enable_elastic_disk": false, + "enable_local_disk_encryption": false, + "init_scripts_safe_mode": false, + "instance_source": { + "node_type_id": "Standard_DS5_v2" + }, + "last_state_loss_time": 1619739324740, + "node_type_id": "Standard_DS5_v2", + "num_workers": 30, + "spark_context_id": 8637481617925571639, + "spark_version": "13.3.x-scala2.12", + "spec": { + "spark_version": "13.3.x-scala2.12" + }, + "start_time": 1618263108824, + "state": "TERMINATED", + "state_message": "Inactive cluster terminated (inactive for 120 minutes).", + "terminated_time": 1619746525713, + "termination_reason": { + "code": "INACTIVITY", + "parameters": { + "inactivity_duration_min": "120" + }, + "type": "SUCCESS" + } +} diff --git a/tests/test_files/azure_recommendation.json b/tests/test_files/azure_recommendation.json index 8218219..32d39fd 100644 --- a/tests/test_files/azure_recommendation.json +++ b/tests/test_files/azure_recommendation.json @@ -1,47 +1,47 @@ { - "result": [ - { - "created_at": "2024-02-14T21:26:37Z", - "updated_at": "2024-02-14T21:29:38Z", - "id": "6024acdd-fd13-4bf1-82f5-44f1ab7008f2", - "state": "SUCCESS", - "recommendation": { - "configuration": { - "node_type_id": "Standard_D4s_v3", - "driver_node_type_id": "Standard_D4s_v3", - "custom_tags": { - "sync:project-id": "769c3443-afd7-45ff-a72a-27bf4296b80e", - "sync:run-id": "d3f8db6c-df4b-430a-a511-a1e9c95d1ad0", - "sync:recommendation-id": "6024acdd-fd13-4bf1-82f5-44f1ab7008f2", - "sync:tenant-id": "290d381e-8eb4-4d6a-80d4-453d82897ecc" - }, - "num_workers": 5, - "spark_conf": { - "spark.databricks.isv.product": "sync-gradient" - }, - "spark_version": "13.3.x-scala2.12", - 
"runtime_engine": "STANDARD", - "azure_attributes": { - "availability": "SPOT_WITH_FALLBACK_AZURE", - "first_on_demand": 7, - "spot_bid_max_price": 100.0 - } - }, - "prediction_params": { - "sla_minutes": 60, - "force_ondemand_workers": false, - "force_ondemand_basis": false, - "fix_worker_family": true, - "fix_driver_type": true, - "fix_scaling_type": true + "result": [ + { + "context": { + "current_learning_iteration": 2, + "latest_submission_id": "8626c4f9-e57f-4563-9981-f21b936828b0", + "phase": "LEARNING", + "total_learning_iterations": 3 + }, + "created_at": "2024-02-14T21:26:37Z", + "id": "6024acdd-fd13-4bf1-82f5-44f1ab7008f2", + "recommendation": { + "configuration": { + "azure_attributes": { + "availability": "SPOT_WITH_FALLBACK_AZURE", + "first_on_demand": 7, + "spot_bid_max_price": 100.0 + }, + "custom_tags": { + "sync:project-id": "769c3443-afd7-45ff-a72a-27bf4296b80e", + "sync:recommendation-id": "6024acdd-fd13-4bf1-82f5-44f1ab7008f2", + "sync:run-id": "d3f8db6c-df4b-430a-a511-a1e9c95d1ad0", + "sync:tenant-id": "290d381e-8eb4-4d6a-80d4-453d82897ecc" + }, + "driver_node_type_id": "Standard_D4s_v3", + "node_type_id": "Standard_D4s_v3", + "num_workers": 5, + "runtime_engine": "STANDARD", + "spark_conf": { + "spark.databricks.isv.product": "sync-gradient" + }, + "spark_version": "13.3.x-scala2.12" + }, + "prediction_params": { + "fix_driver_type": true, + "fix_scaling_type": true, + "fix_worker_family": true, + "force_ondemand_basis": false, + "force_ondemand_workers": false, + "sla_minutes": 60 + } + }, + "state": "SUCCESS", + "updated_at": "2024-02-14T21:29:38Z" } - }, - "context": { - "latest_submission_id": "8626c4f9-e57f-4563-9981-f21b936828b0", - "phase": "LEARNING", - "current_learning_iteration": 2, - "total_learning_iterations": 3 - } - } - ] -} \ No newline at end of file + ] +}