Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions ocpi/core/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,34 @@
from ocpi.modules.versions.v_2_2_1.enums import InterfaceRole


_VERSION_PREFERENCE = ["2.3.0", "2.2.1", "2.1.1"]


def _pick_version_details_url(
versions_list: list[dict], requested: VersionNumber
) -> str | None:
"""Pick the best version details URL from an OCPI /versions response.

Tries the requested version first, then falls back to the highest
mutually supported version.
"""
by_version = {v["version"]: v["url"] for v in versions_list if "version" in v}

if requested.value in by_version:
return by_version[requested.value]

for v in _VERSION_PREFERENCE:
if v in by_version:
return by_version[v]

return None


def client_url(module_id: ModuleID, object_id: str, base_url: str) -> str:
if module_id == ModuleID.cdrs:
return base_url
return f"{base_url}{settings.COUNTRY_CODE}/{settings.PARTY_ID}/{object_id}"
base = base_url.rstrip("/")
return f"{base}/{settings.COUNTRY_CODE}/{settings.PARTY_ID}/{object_id}"


def client_method(module_id: ModuleID) -> str:
Expand Down Expand Up @@ -62,7 +86,7 @@ async def send_push_request(
base_url = ""
for endpoint in endpoints:
if (
version.value.startswith("2.2")
not (version.value.startswith("2.0") or version.value.startswith("2.1"))
and endpoint["identifier"] == module_id
and endpoint["role"] == InterfaceRole.receiver
) or (version.value.startswith("2.1") and endpoint["identifier"] == module_id):
Expand Down Expand Up @@ -92,7 +116,7 @@ async def push_object(
# get client endpoints
if version.value.startswith("2.1") or version.value.startswith("2.0"):
token = receiver.auth_token
else:
else: # 2.2.x and 2.3.x use base64-encoded tokens
token = encode_string_base64(receiver.auth_token)

client_auth_token = f"Token {token}"
Expand All @@ -106,7 +130,33 @@ async def push_object(
headers={"authorization": client_auth_token},
)
logger.info(f"Response status_code - `{response.status_code}`")
endpoints = response.json()["data"]["endpoints"]
response_data = response.json()["data"]

# If response is a versions list, negotiate version and
# fetch the details URL for the best mutual version.
if isinstance(response_data, list):
details_url = _pick_version_details_url(
response_data, version
)
if not details_url:
raise ValueError(
f"No mutual OCPI version found. "
f"Requested {version.value}, receiver supports: "
f"{[v.get('version') for v in response_data]}"
)
logger.info(
f"Resolved version details URL: {details_url}"
)
response = await client.get(
details_url,
headers={"authorization": client_auth_token},
)
logger.info(
f"Version details response: {response.status_code}"
)
response_data = response.json()["data"]

endpoints = response_data["endpoints"]
logger.debug(f"Endpoints response data - `{endpoints}`")

# get object data
Expand Down
85 changes: 85 additions & 0 deletions tests/test_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from tests.test_modules.utils import (
ENCODED_AUTH_TOKEN,
ENCODED_AUTH_TOKEN_V_2_3_0,
ClientAuthenticator,
)

Expand Down Expand Up @@ -488,3 +489,87 @@ def test_push_token_module_uses_emsp_role():
crud.get.assert_awaited_once()
call_kwargs = crud.get.call_args
assert call_kwargs[0][1] == enums.RoleEnum.emsp


def test_push_v_2_3_0_uses_receiver_role_and_base64_token():
"""Push for v2.3.0 matches endpoints by RECEIVER role and base64-encodes the token."""
crud = AsyncMock()
adapter = MagicMock()
crud.get.return_value = LOCATIONS[0]
adapter.location_adapter.return_value.model_dump.return_value = LOCATIONS[0]

app = get_application(
version_numbers=[VersionNumber.v_2_3_0],
roles=[enums.RoleEnum.cpo],
crud=crud,
adapter=adapter,
authenticator=ClientAuthenticator,
modules=[],
http_push=True,
)

client = TestClient(app)
push_data = schemas.Push(
module_id=enums.ModuleID.locations,
object_id="loc-1",
receivers=[
schemas.Receiver(
endpoints_url="http://example.com/versions", auth_token="token"
),
],
).model_dump()

with patch("ocpi.core.push.httpx.AsyncClient") as mock_client:
mock_endpoints_response = MagicMock()
mock_endpoints_response.status_code = 200
mock_endpoints_response.json.return_value = {
"data": {
"endpoints": [
# SENDER endpoint — should be ignored
{
"identifier": enums.ModuleID.locations,
"role": "SENDER",
"url": "http://example.com/sender/locations/",
},
# RECEIVER endpoint — should be picked up
{
"identifier": enums.ModuleID.locations,
"role": "RECEIVER",
"url": "http://example.com/locations/",
},
]
}
}

mock_push_response = MagicMock()
mock_push_response.status_code = 200
mock_push_response.json.return_value = {"status_code": 1000}

mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_endpoints_response
)
mock_client.return_value.__aenter__.return_value.send = AsyncMock(
return_value=mock_push_response
)
build_request = MagicMock()
mock_client.return_value.__aenter__.return_value.build_request = build_request

response = client.post(
"/push/2.3.0",
json=push_data,
headers={"Authorization": f"Token {ENCODED_AUTH_TOKEN_V_2_3_0}"},
)

assert response.status_code == 200

# Token must be base64-encoded for 2.3.0
call_args = build_request.call_args
auth_header = call_args[1]["headers"]["Authorization"]
assert auth_header.startswith("Token ")
# The raw "token" string encoded in base64 is "dG9rZW4="
assert auth_header == "Token dG9rZW4="

# Must use the RECEIVER url, not the SENDER url
url_arg = call_args[0][1]
assert "sender" not in url_arg
assert "locations" in url_arg