diff --git a/src/py/mat3ra/api_client/client.py b/src/py/mat3ra/api_client/client.py index b9c58c0..00914aa 100644 --- a/src/py/mat3ra/api_client/client.py +++ b/src/py/mat3ra/api_client/client.py @@ -1,7 +1,11 @@ -from typing import Any, Optional, Tuple +import os +import re +from typing import Any, List, Optional, Tuple +import requests from pydantic import BaseModel, ConfigDict +from .constants import ACCESS_TOKEN_ENV_VAR, _build_base_url from .endpoints.bank_materials import BankMaterialEndpoints from .endpoints.bank_workflows import BankWorkflowEndpoints from .endpoints.jobs import JobEndpoints @@ -26,8 +30,15 @@ class APIClient(BaseModel): def model_post_init(self, __context: Any) -> None: self.my_account = Account(client=self) self.account = self.my_account + self._my_organization: Optional[Account] = None self._init_endpoints(self.timeout_seconds) + @property + def my_organization(self) -> Optional[Account]: + if self._my_organization is None: + self._my_organization = self.get_default_organization() + return self._my_organization + @classmethod def env(cls) -> APIEnv: return APIEnv.from_env() @@ -37,39 +48,17 @@ def auth_env(cls) -> AuthEnv: return AuthEnv.from_env() def _init_endpoints(self, timeout_seconds: int) -> None: - kwargs = {"timeout": timeout_seconds, "auth": self.auth} - account_id = self.auth.account_id or "" - auth_token = self.auth.auth_token or "" - self._init_core_endpoints(kwargs, account_id, auth_token) - self._init_bank_endpoints(kwargs, account_id, auth_token) - - def _init_core_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None: - self.materials = MaterialEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.workflows = WorkflowEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.jobs = JobEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.projects = ProjectEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.properties = PropertiesEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.metaproperties = MetaPropertiesEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) + base_args = (self.host, self.port, self.auth.account_id or "", self.auth.auth_token or "") + base_kwargs = {"version": self.version, "secure": self.secure, "timeout": timeout_seconds, "auth": self.auth} - def _init_bank_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None: - self.bank_materials = BankMaterialEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.bank_workflows = BankWorkflowEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) + self.materials = MaterialEndpoints(*base_args, **base_kwargs) + self.workflows = WorkflowEndpoints(*base_args, **base_kwargs) + self.jobs = JobEndpoints(*base_args, **base_kwargs) + self.projects = ProjectEndpoints(*base_args, **base_kwargs) + self.properties = PropertiesEndpoints(*base_args, **base_kwargs) + self.metaproperties = MetaPropertiesEndpoints(*base_args, **base_kwargs) + self.bank_materials = BankMaterialEndpoints(*base_args, **base_kwargs) + self.bank_workflows = BankWorkflowEndpoints(*base_args, **base_kwargs) @staticmethod def _resolve_config( @@ -122,9 +111,73 @@ def authenticate( auth_token: Optional[str] = None, timeout_seconds: int = 60, ) -> "APIClient": - host_value, port_value, version_value, secure_value = cls._resolve_config(host, port, version, secure, - cls.env()) + host_value, port_value, version_value, secure_value = cls._resolve_config( + host, port, version, secure, cls.env() + ) auth = cls._auth_from_env(access_token=access_token, account_id=account_id, auth_token=auth_token) cls._validate_auth(auth) - return cls(host=host_value, port=port_value, version=version_value, secure=secure_value, auth=auth, - timeout_seconds=timeout_seconds) + return cls( + host=host_value, + port=port_value, + version=version_value, + secure=secure_value, + auth=auth, + timeout_seconds=timeout_seconds, + ) + + def _fetch_data(self) -> dict: + access_token = self.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR) + if not access_token: + raise ValueError("Access token is required to fetch user data") + + url = _build_base_url(self.host, self.port, self.secure, "/api/v1/users/me") + response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30) + response.raise_for_status() + return response.json()["data"] + + def _fetch_user_accounts(self) -> List[dict]: + return self._fetch_data().get("accounts", []) + + def list_accounts(self) -> List[dict]: + accounts = self._fetch_user_accounts() + return [ + { + "_id": account["entity"]["_id"], + "name": account["entity"].get("name", ""), + "type": account["entity"].get("type", "personal"), + "isDefault": account.get("isDefault", False), + } + for account in accounts + ] + + def get_account(self, name: Optional[str] = None, index: Optional[int] = None) -> Account: + """Get account by name (partial regex match) or index from the list of user accounts.""" + if name is None and index is None: + raise ValueError("Either 'name' or 'index' must be provided") + + accounts = self._fetch_user_accounts() + + if index is not None: + return Account(client=self, entity_cache=accounts[index]["entity"]) + + pattern = re.compile(name, re.IGNORECASE) + matches = [account for account in accounts if pattern.search(account["entity"].get("name", ""))] + + if not matches: + raise ValueError(f"No account found matching '{name}'") + if len(matches) > 1: + names = [acc["entity"].get("name", "") for acc in matches] + raise ValueError(f"Multiple accounts match '{name}': {names}") + + return Account(client=self, entity_cache=matches[0]["entity"]) + + def get_default_organization(self) -> Optional[Account]: + accounts = self._fetch_user_accounts() + organizations = [account for account in accounts if + account["entity"].get("type") in ("organization", "enterprise")] + + if not organizations: + return None + + default_org = next((org for org in organizations if org.get("isDefault")), organizations[0]) + return Account(client=self, entity_cache=default_org["entity"]) diff --git a/src/py/mat3ra/api_client/endpoints/jobs.py b/src/py/mat3ra/api_client/endpoints/jobs.py index e9e84dd..0001695 100644 --- a/src/py/mat3ra/api_client/endpoints/jobs.py +++ b/src/py/mat3ra/api_client/endpoints/jobs.py @@ -54,7 +54,8 @@ def terminate(self, id_): """ self.request("POST", "/".join((self.name, id_, "submit")), headers=self.headers) - def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, is_multi_material=False): + def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, + is_multi_material=False): """ Returns a job config based on the given parameters. diff --git a/src/py/mat3ra/api_client/models.py b/src/py/mat3ra/api_client/models.py index 290c6c9..e834ed8 100644 --- a/src/py/mat3ra/api_client/models.py +++ b/src/py/mat3ra/api_client/models.py @@ -1,10 +1,9 @@ import os from typing import Any, Optional -import requests from pydantic import BaseModel, ConfigDict, Field -from .constants import ACCESS_TOKEN_ENV_VAR, ACCOUNT_ID_ENV_VAR, AUTH_TOKEN_ENV_VAR, _build_base_url +from .constants import ACCESS_TOKEN_ENV_VAR, ACCOUNT_ID_ENV_VAR, AUTH_TOKEN_ENV_VAR class AuthContext(BaseModel): @@ -38,30 +37,47 @@ class Account(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) client: Any = Field(exclude=True, repr=False) - id_cache: Optional[str] = None + entity_cache: Optional[dict] = None @property def id(self) -> str: - if self.id_cache: - return self.id_cache - self.id_cache = self._resolve_account_id() - return self.id_cache + if not self.entity_cache: + self._get_entity() + return self.entity_cache["_id"] - def _resolve_account_id(self) -> str: + @property + def name(self) -> str: + if not self.entity_cache: + self._get_entity() + return self.entity_cache.get("name", "") + + def _get_entity(self) -> None: + account_id, accounts = self._get_account_id_and_accounts() + self.entity_cache = self._find_account_entity(account_id, accounts) + + def _get_account_id_and_accounts(self) -> tuple[str, Optional[list]]: account_id = self.client.auth.account_id or os.environ.get(ACCOUNT_ID_ENV_VAR) + if account_id: - self.client.auth.account_id = account_id - return account_id - - access_token = self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR) - if not access_token: + return account_id, None + + if not (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)): raise ValueError("ACCOUNT_ID is not set and no OIDC access token is available.") - - url = _build_base_url(self.client.host, self.client.port, self.client.secure, "/api/v1/users/me") - response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30) - response.raise_for_status() - account_id = response.json()["data"]["user"]["entity"]["defaultAccountId"] + + data = self.client._fetch_data() + account_id = data["user"]["entity"]["defaultAccountId"] os.environ[ACCOUNT_ID_ENV_VAR] = account_id self.client.auth.account_id = account_id - return account_id + return account_id, data.get("accounts", []) + + def _find_account_entity(self, account_id: str, accounts: Optional[list]) -> dict: + if accounts is None and (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)): + accounts = self.client._fetch_user_accounts() + + if accounts: + for account in accounts: + if account["entity"]["_id"] == account_id: + return account["entity"] + + return {"_id": account_id} diff --git a/tests/py/unit/test_client.py b/tests/py/unit/test_client.py index 59306ad..b19aa41 100644 --- a/tests/py/unit/test_client.py +++ b/tests/py/unit/test_client.py @@ -17,6 +17,28 @@ ME_ACCOUNT_ID = "my-account-id" USERS_ME_RESPONSE = {"data": {"user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}}}} +ACCOUNTS_RESPONSE = { + "data": { + "user": { + "entity": {"defaultAccountId": ME_ACCOUNT_ID} + }, + "accounts": [ + { + "entity": {"_id": "user-acc-1", "name": "John Doe", "type": "personal"}, + "isDefault": True, + }, + { + "entity": {"_id": "org-acc-1", "name": "Acme Corp", "type": "enterprise"}, + "isDefault": True, + }, + { + "entity": {"_id": "org-acc-2", "name": "Beta Industries", "type": "organization"}, + "isDefault": False, + }, + ], + } +} + class APIClientUnitTest(EndpointBaseUnitTest): def _base_env(self): @@ -27,9 +49,9 @@ def _base_env(self): "API_SECURE": API_SECURE_FALSE, } - def _mock_users_me(self, mock_get): + def _mock_users_me(self, mock_get, response=None): mock_resp = mock.Mock() - mock_resp.json.return_value = USERS_ME_RESPONSE + mock_resp.json.return_value = response or USERS_ME_RESPONSE mock_resp.raise_for_status.return_value = None mock_get.return_value = mock_resp @@ -61,8 +83,16 @@ def test_my_account_id_uses_existing_account_id(self, mock_get): @mock.patch("requests.get") def test_my_account_id_fetches_and_caches(self, mock_get): env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + response_with_account = { + "data": { + "user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}}, + "accounts": [ + {"entity": {"_id": ME_ACCOUNT_ID, "name": "Test User", "type": "personal"}, "isDefault": True} + ], + } + } with mock.patch.dict("os.environ", env, clear=True): - self._mock_users_me(mock_get) + self._mock_users_me(mock_get, response_with_account) client = APIClient.authenticate() self.assertEqual(client.my_account.id, ME_ACCOUNT_ID) self.assertEqual(client.my_account.id, ME_ACCOUNT_ID) @@ -71,3 +101,44 @@ def test_my_account_id_fetches_and_caches(self, mock_get): self.assertEqual(mock_get.call_args[1]["headers"]["Authorization"], f"Bearer {OIDC_ACCESS_TOKEN}") self.assertEqual(mock_get.call_args[1]["timeout"], 30) self.assertEqual(os.environ.get("ACCOUNT_ID"), ME_ACCOUNT_ID) + + @mock.patch("requests.get") + def test_list_accounts(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + accounts = client.list_accounts() + + self.assertEqual(len(accounts), 3) + self.assertEqual(accounts[0]["_id"], "user-acc-1") + self.assertEqual(accounts[0]["name"], "John Doe") + self.assertEqual(accounts[0]["type"], "personal") + self.assertTrue(accounts[0]["isDefault"]) + self.assertEqual(accounts[1]["_id"], "org-acc-1") + self.assertEqual(accounts[1]["name"], "Acme Corp") + self.assertEqual(accounts[1]["type"], "enterprise") + + @mock.patch("requests.get") + def test_get_account(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + + account = client.get_account(index=1) + self.assertEqual(account.id, "org-acc-1") + self.assertEqual(account.name, "Acme Corp") + + self.assertEqual(client.get_account(name="Acme").id, "org-acc-1") + self.assertEqual(client.get_account(name="Beta.*").id, "org-acc-2") + + @mock.patch("requests.get") + def test_my_organization(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + org = client.my_organization + self.assertEqual(org.id, "org-acc-1") + self.assertEqual(org.name, "Acme Corp")