Skip to content
127 changes: 90 additions & 37 deletions src/py/mat3ra/api_client/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Optional, Tuple
import os
import re
from typing import Any, List, Optional, Tuple

import requests
from pydantic import BaseModel, ConfigDict

from .constants import ACCESS_TOKEN_ENV_VAR, _build_base_url
from .endpoints.bank_materials import BankMaterialEndpoints
from .endpoints.bank_workflows import BankWorkflowEndpoints
from .endpoints.jobs import JobEndpoints
Expand All @@ -26,8 +30,15 @@ class APIClient(BaseModel):
def model_post_init(self, __context: Any) -> None:
self.my_account = Account(client=self)
self.account = self.my_account
self._my_organization: Optional[Account] = None
self._init_endpoints(self.timeout_seconds)

@property
def my_organization(self) -> Optional[Account]:
if self._my_organization is None:
self._my_organization = self.get_default_organization()
return self._my_organization

@classmethod
def env(cls) -> APIEnv:
return APIEnv.from_env()
Expand All @@ -37,39 +48,17 @@ def auth_env(cls) -> AuthEnv:
return AuthEnv.from_env()

def _init_endpoints(self, timeout_seconds: int) -> None:
kwargs = {"timeout": timeout_seconds, "auth": self.auth}
account_id = self.auth.account_id or ""
auth_token = self.auth.auth_token or ""
self._init_core_endpoints(kwargs, account_id, auth_token)
self._init_bank_endpoints(kwargs, account_id, auth_token)

def _init_core_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None:
self.materials = MaterialEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.workflows = WorkflowEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.jobs = JobEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.projects = ProjectEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.properties = PropertiesEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.metaproperties = MetaPropertiesEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
base_args = (self.host, self.port, self.auth.account_id or "", self.auth.auth_token or "")
base_kwargs = {"version": self.version, "secure": self.secure, "timeout": timeout_seconds, "auth": self.auth}

def _init_bank_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None:
self.bank_materials = BankMaterialEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.bank_workflows = BankWorkflowEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.materials = MaterialEndpoints(*base_args, **base_kwargs)
self.workflows = WorkflowEndpoints(*base_args, **base_kwargs)
self.jobs = JobEndpoints(*base_args, **base_kwargs)
self.projects = ProjectEndpoints(*base_args, **base_kwargs)
self.properties = PropertiesEndpoints(*base_args, **base_kwargs)
self.metaproperties = MetaPropertiesEndpoints(*base_args, **base_kwargs)
self.bank_materials = BankMaterialEndpoints(*base_args, **base_kwargs)
self.bank_workflows = BankWorkflowEndpoints(*base_args, **base_kwargs)

@staticmethod
def _resolve_config(
Expand Down Expand Up @@ -122,9 +111,73 @@ def authenticate(
auth_token: Optional[str] = None,
timeout_seconds: int = 60,
) -> "APIClient":
host_value, port_value, version_value, secure_value = cls._resolve_config(host, port, version, secure,
cls.env())
host_value, port_value, version_value, secure_value = cls._resolve_config(
host, port, version, secure, cls.env()
)
auth = cls._auth_from_env(access_token=access_token, account_id=account_id, auth_token=auth_token)
cls._validate_auth(auth)
return cls(host=host_value, port=port_value, version=version_value, secure=secure_value, auth=auth,
timeout_seconds=timeout_seconds)
return cls(
host=host_value,
port=port_value,
version=version_value,
secure=secure_value,
auth=auth,
timeout_seconds=timeout_seconds,
)

def _fetch_data(self) -> dict:
access_token = self.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)
if not access_token:
raise ValueError("Access token is required to fetch user data")

url = _build_base_url(self.host, self.port, self.secure, "/api/v1/users/me")
response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30)
response.raise_for_status()
return response.json()["data"]

def _fetch_user_accounts(self) -> List[dict]:
return self._fetch_data().get("accounts", [])

def list_accounts(self) -> List[dict]:
accounts = self._fetch_user_accounts()
return [
{
"_id": account["entity"]["_id"],
"name": account["entity"].get("name", ""),
"type": account["entity"].get("type", "personal"),
"isDefault": account.get("isDefault", False),
}
for account in accounts
]

def get_account(self, name: Optional[str] = None, index: Optional[int] = None) -> Account:
"""Get account by name (partial regex match) or index from the list of user accounts."""
if name is None and index is None:
raise ValueError("Either 'name' or 'index' must be provided")

accounts = self._fetch_user_accounts()

if index is not None:
return Account(client=self, entity_cache=accounts[index]["entity"])

pattern = re.compile(name, re.IGNORECASE)
matches = [account for account in accounts if pattern.search(account["entity"].get("name", ""))]

if not matches:
raise ValueError(f"No account found matching '{name}'")
if len(matches) > 1:
names = [acc["entity"].get("name", "") for acc in matches]
raise ValueError(f"Multiple accounts match '{name}': {names}")

return Account(client=self, entity_cache=matches[0]["entity"])

def get_default_organization(self) -> Optional[Account]:
accounts = self._fetch_user_accounts()
organizations = [account for account in accounts if
account["entity"].get("type") in ("organization", "enterprise")]

if not organizations:
return None

default_org = next((org for org in organizations if org.get("isDefault")), organizations[0])
return Account(client=self, entity_cache=default_org["entity"])
3 changes: 2 additions & 1 deletion src/py/mat3ra/api_client/endpoints/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def terminate(self, id_):
"""
self.request("POST", "/".join((self.name, id_, "submit")), headers=self.headers)

def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, is_multi_material=False):
def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None,
is_multi_material=False):
"""
Returns a job config based on the given parameters.

Expand Down
54 changes: 35 additions & 19 deletions src/py/mat3ra/api_client/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from typing import Any, Optional

import requests
from pydantic import BaseModel, ConfigDict, Field

from .constants import ACCESS_TOKEN_ENV_VAR, ACCOUNT_ID_ENV_VAR, AUTH_TOKEN_ENV_VAR, _build_base_url
from .constants import ACCESS_TOKEN_ENV_VAR, ACCOUNT_ID_ENV_VAR, AUTH_TOKEN_ENV_VAR


class AuthContext(BaseModel):
Expand Down Expand Up @@ -38,30 +37,47 @@ class Account(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

client: Any = Field(exclude=True, repr=False)
id_cache: Optional[str] = None
entity_cache: Optional[dict] = None

@property
def id(self) -> str:
if self.id_cache:
return self.id_cache
self.id_cache = self._resolve_account_id()
return self.id_cache
if not self.entity_cache:
self._get_entity()
return self.entity_cache["_id"]

def _resolve_account_id(self) -> str:
@property
def name(self) -> str:
if not self.entity_cache:
self._get_entity()
return self.entity_cache.get("name", "")

def _get_entity(self) -> None:
account_id, accounts = self._get_account_id_and_accounts()
self.entity_cache = self._find_account_entity(account_id, accounts)

def _get_account_id_and_accounts(self) -> tuple[str, Optional[list]]:
account_id = self.client.auth.account_id or os.environ.get(ACCOUNT_ID_ENV_VAR)

if account_id:
self.client.auth.account_id = account_id
return account_id

access_token = self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)
if not access_token:
return account_id, None

if not (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)):
raise ValueError("ACCOUNT_ID is not set and no OIDC access token is available.")

url = _build_base_url(self.client.host, self.client.port, self.client.secure, "/api/v1/users/me")
response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30)
response.raise_for_status()
account_id = response.json()["data"]["user"]["entity"]["defaultAccountId"]

data = self.client._fetch_data()
account_id = data["user"]["entity"]["defaultAccountId"]
os.environ[ACCOUNT_ID_ENV_VAR] = account_id
self.client.auth.account_id = account_id
return account_id
return account_id, data.get("accounts", [])

def _find_account_entity(self, account_id: str, accounts: Optional[list]) -> dict:
if accounts is None and (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)):
accounts = self.client._fetch_user_accounts()

if accounts:
for account in accounts:
if account["entity"]["_id"] == account_id:
return account["entity"]

return {"_id": account_id}

77 changes: 74 additions & 3 deletions tests/py/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@
ME_ACCOUNT_ID = "my-account-id"
USERS_ME_RESPONSE = {"data": {"user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}}}}

ACCOUNTS_RESPONSE = {
"data": {
"user": {
"entity": {"defaultAccountId": ME_ACCOUNT_ID}
},
"accounts": [
{
"entity": {"_id": "user-acc-1", "name": "John Doe", "type": "personal"},
"isDefault": True,
},
{
"entity": {"_id": "org-acc-1", "name": "Acme Corp", "type": "enterprise"},
"isDefault": True,
},
{
"entity": {"_id": "org-acc-2", "name": "Beta Industries", "type": "organization"},
"isDefault": False,
},
],
}
}


class APIClientUnitTest(EndpointBaseUnitTest):
def _base_env(self):
Expand All @@ -27,9 +49,9 @@ def _base_env(self):
"API_SECURE": API_SECURE_FALSE,
}

def _mock_users_me(self, mock_get):
def _mock_users_me(self, mock_get, response=None):
mock_resp = mock.Mock()
mock_resp.json.return_value = USERS_ME_RESPONSE
mock_resp.json.return_value = response or USERS_ME_RESPONSE
mock_resp.raise_for_status.return_value = None
mock_get.return_value = mock_resp

Expand Down Expand Up @@ -61,8 +83,16 @@ def test_my_account_id_uses_existing_account_id(self, mock_get):
@mock.patch("requests.get")
def test_my_account_id_fetches_and_caches(self, mock_get):
env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN}
response_with_account = {
"data": {
"user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}},
"accounts": [
{"entity": {"_id": ME_ACCOUNT_ID, "name": "Test User", "type": "personal"}, "isDefault": True}
],
}
}
with mock.patch.dict("os.environ", env, clear=True):
self._mock_users_me(mock_get)
self._mock_users_me(mock_get, response_with_account)
client = APIClient.authenticate()
self.assertEqual(client.my_account.id, ME_ACCOUNT_ID)
self.assertEqual(client.my_account.id, ME_ACCOUNT_ID)
Expand All @@ -71,3 +101,44 @@ def test_my_account_id_fetches_and_caches(self, mock_get):
self.assertEqual(mock_get.call_args[1]["headers"]["Authorization"], f"Bearer {OIDC_ACCESS_TOKEN}")
self.assertEqual(mock_get.call_args[1]["timeout"], 30)
self.assertEqual(os.environ.get("ACCOUNT_ID"), ME_ACCOUNT_ID)

@mock.patch("requests.get")
def test_list_accounts(self, mock_get):
env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN}
with mock.patch.dict("os.environ", env, clear=True):
self._mock_users_me(mock_get, ACCOUNTS_RESPONSE)
client = APIClient.authenticate()
accounts = client.list_accounts()

self.assertEqual(len(accounts), 3)
self.assertEqual(accounts[0]["_id"], "user-acc-1")
self.assertEqual(accounts[0]["name"], "John Doe")
self.assertEqual(accounts[0]["type"], "personal")
self.assertTrue(accounts[0]["isDefault"])
self.assertEqual(accounts[1]["_id"], "org-acc-1")
self.assertEqual(accounts[1]["name"], "Acme Corp")
self.assertEqual(accounts[1]["type"], "enterprise")

@mock.patch("requests.get")
def test_get_account(self, mock_get):
env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN}
with mock.patch.dict("os.environ", env, clear=True):
self._mock_users_me(mock_get, ACCOUNTS_RESPONSE)
client = APIClient.authenticate()

account = client.get_account(index=1)
self.assertEqual(account.id, "org-acc-1")
self.assertEqual(account.name, "Acme Corp")

self.assertEqual(client.get_account(name="Acme").id, "org-acc-1")
self.assertEqual(client.get_account(name="Beta.*").id, "org-acc-2")

@mock.patch("requests.get")
def test_my_organization(self, mock_get):
env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN}
with mock.patch.dict("os.environ", env, clear=True):
self._mock_users_me(mock_get, ACCOUNTS_RESPONSE)
client = APIClient.authenticate()
org = client.my_organization
self.assertEqual(org.id, "org-acc-1")
self.assertEqual(org.name, "Acme Corp")
Loading