diff --git a/ocpi/core/push.py b/ocpi/core/push.py index 5a8bfd1..fc3d81e 100644 --- a/ocpi/core/push.py +++ b/ocpi/core/push.py @@ -16,10 +16,34 @@ from ocpi.modules.versions.v_2_2_1.enums import InterfaceRole +_VERSION_PREFERENCE = ["2.3.0", "2.2.1", "2.1.1"] + + +def _pick_version_details_url( + versions_list: list[dict], requested: VersionNumber +) -> str | None: + """Pick the best version details URL from an OCPI /versions response. + + Tries the requested version first, then falls back to the highest + mutually supported version. + """ + by_version = {v["version"]: v["url"] for v in versions_list if "version" in v} + + if requested.value in by_version: + return by_version[requested.value] + + for v in _VERSION_PREFERENCE: + if v in by_version: + return by_version[v] + + return None + + def client_url(module_id: ModuleID, object_id: str, base_url: str) -> str: if module_id == ModuleID.cdrs: return base_url - return f"{base_url}{settings.COUNTRY_CODE}/{settings.PARTY_ID}/{object_id}" + base = base_url.rstrip("/") + return f"{base}/{settings.COUNTRY_CODE}/{settings.PARTY_ID}/{object_id}" def client_method(module_id: ModuleID) -> str: @@ -62,7 +86,7 @@ async def send_push_request( base_url = "" for endpoint in endpoints: if ( - version.value.startswith("2.2") + not (version.value.startswith("2.0") or version.value.startswith("2.1")) and endpoint["identifier"] == module_id and endpoint["role"] == InterfaceRole.receiver ) or (version.value.startswith("2.1") and endpoint["identifier"] == module_id): @@ -92,7 +116,7 @@ async def push_object( # get client endpoints if version.value.startswith("2.1") or version.value.startswith("2.0"): token = receiver.auth_token - else: + else: # 2.2.x and 2.3.x use base64-encoded tokens token = encode_string_base64(receiver.auth_token) client_auth_token = f"Token {token}" @@ -106,7 +130,33 @@ async def push_object( headers={"authorization": client_auth_token}, ) logger.info(f"Response status_code - `{response.status_code}`") - endpoints = response.json()["data"]["endpoints"] + response_data = response.json()["data"] + + # If response is a versions list, negotiate version and + # fetch the details URL for the best mutual version. + if isinstance(response_data, list): + details_url = _pick_version_details_url( + response_data, version + ) + if not details_url: + raise ValueError( + f"No mutual OCPI version found. " + f"Requested {version.value}, receiver supports: " + f"{[v.get('version') for v in response_data]}" + ) + logger.info( + f"Resolved version details URL: {details_url}" + ) + response = await client.get( + details_url, + headers={"authorization": client_auth_token}, + ) + logger.info( + f"Version details response: {response.status_code}" + ) + response_data = response.json()["data"] + + endpoints = response_data["endpoints"] logger.debug(f"Endpoints response data - `{endpoints}`") # get object data diff --git a/tests/test_push.py b/tests/test_push.py index f9ae0c4..378d48c 100644 --- a/tests/test_push.py +++ b/tests/test_push.py @@ -12,6 +12,7 @@ ) from tests.test_modules.utils import ( ENCODED_AUTH_TOKEN, + ENCODED_AUTH_TOKEN_V_2_3_0, ClientAuthenticator, ) @@ -488,3 +489,87 @@ def test_push_token_module_uses_emsp_role(): crud.get.assert_awaited_once() call_kwargs = crud.get.call_args assert call_kwargs[0][1] == enums.RoleEnum.emsp + + +def test_push_v_2_3_0_uses_receiver_role_and_base64_token(): + """Push for v2.3.0 matches endpoints by RECEIVER role and base64-encodes the token.""" + crud = AsyncMock() + adapter = MagicMock() + crud.get.return_value = LOCATIONS[0] + adapter.location_adapter.return_value.model_dump.return_value = LOCATIONS[0] + + app = get_application( + version_numbers=[VersionNumber.v_2_3_0], + roles=[enums.RoleEnum.cpo], + crud=crud, + adapter=adapter, + authenticator=ClientAuthenticator, + modules=[], + http_push=True, + ) + + client = TestClient(app) + push_data = schemas.Push( + module_id=enums.ModuleID.locations, + object_id="loc-1", + receivers=[ + schemas.Receiver( + endpoints_url="http://example.com/versions", auth_token="token" + ), + ], + ).model_dump() + + with patch("ocpi.core.push.httpx.AsyncClient") as mock_client: + mock_endpoints_response = MagicMock() + mock_endpoints_response.status_code = 200 + mock_endpoints_response.json.return_value = { + "data": { + "endpoints": [ + # SENDER endpoint — should be ignored + { + "identifier": enums.ModuleID.locations, + "role": "SENDER", + "url": "http://example.com/sender/locations/", + }, + # RECEIVER endpoint — should be picked up + { + "identifier": enums.ModuleID.locations, + "role": "RECEIVER", + "url": "http://example.com/locations/", + }, + ] + } + } + + mock_push_response = MagicMock() + mock_push_response.status_code = 200 + mock_push_response.json.return_value = {"status_code": 1000} + + mock_client.return_value.__aenter__.return_value.get = AsyncMock( + return_value=mock_endpoints_response + ) + mock_client.return_value.__aenter__.return_value.send = AsyncMock( + return_value=mock_push_response + ) + build_request = MagicMock() + mock_client.return_value.__aenter__.return_value.build_request = build_request + + response = client.post( + "/push/2.3.0", + json=push_data, + headers={"Authorization": f"Token {ENCODED_AUTH_TOKEN_V_2_3_0}"}, + ) + + assert response.status_code == 200 + + # Token must be base64-encoded for 2.3.0 + call_args = build_request.call_args + auth_header = call_args[1]["headers"]["Authorization"] + assert auth_header.startswith("Token ") + # The raw "token" string encoded in base64 is "dG9rZW4=" + assert auth_header == "Token dG9rZW4=" + + # Must use the RECEIVER url, not the SENDER url + url_arg = call_args[0][1] + assert "sender" not in url_arg + assert "locations" in url_arg