Skip to content

Commit 6fc5c54

Browse files
tode-rlcursoragent
andcommitted
Fix pyright issues in Runloop rollout and related tests.
Use the correct blueprint build_context type and add narrowing assertions in new test coverage. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 1b21d81 commit 6fc5c54

4 files changed

Lines changed: 24 additions & 12 deletions

File tree

examples/runloop_remote_rollout/create_blueprint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77

88
from runloop_api_client import RunloopSDK
9-
from runloop_api_client.types.blueprint_build_parameters import BuildContext
109

1110

1211
DEFAULT_DOCKERFILE = """\
@@ -60,7 +59,7 @@ def main() -> None:
6059
blueprint = runloop.blueprint.create(
6160
name=args.name,
6261
dockerfile=DEFAULT_DOCKERFILE,
63-
build_context=BuildContext(type="object", object_id=build_context.id),
62+
build_context={"type": "object", "object_id": build_context.id},
6463
)
6564

6665
print(f"export RUNLOOP_BLUEPRINT_ID={blueprint.id}")

tests/adapters/test_fireworks_tracing_logprobs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_attaches_completion_logprobs_and_message_logprobs(self):
7171

7272
assistant = row.messages[-1]
7373
assert assistant.role == "assistant"
74+
assert assistant.logprobs is not None
7475
content = assistant.logprobs["content"]
7576
assert len(content) == len(extra["completion_logprobs"])
7677
assert content[0]["token_id"] == 10
@@ -83,10 +84,13 @@ def test_omits_token_id_keys_when_any_missing(self):
8384
assert row is not None
8485

8586
extra = row.execution_metadata.extra
87+
assert extra is not None
8688
assert "completion_logprobs" in extra
8789
assert "completion_token_ids" not in extra
8890

89-
content = row.messages[-1].logprobs["content"]
91+
logprobs = row.messages[-1].logprobs
92+
assert logprobs is not None
93+
content = logprobs["content"]
9094
assert len(content) == 2
9195
assert all("token_id" not in entry for entry in content)
9296
assert content[0]["logprob"] == pytest.approx(-0.1)

tests/adapters/test_r3_deserializer.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,9 @@ def test_all_mode_uint8(self):
143143
assert metadata["replayed_token_count"] == total_tokens
144144

145145
for i in range(total_tokens):
146-
assert matrices[i] is not None
147-
decoded = base64.b64decode(matrices[i])
146+
matrix = matrices[i]
147+
assert matrix is not None
148+
decoded = base64.b64decode(matrix)
148149
assert decoded == matrices_raw[i]
149150

150151
def test_suffix_mode(self):
@@ -180,8 +181,9 @@ def test_suffix_mode(self):
180181
# Positions from start_token to start_token+replayed should have data
181182
for i in range(replayed):
182183
pos = start_token + i
183-
assert matrices[pos] is not None
184-
decoded = base64.b64decode(matrices[pos])
184+
matrix = matrices[pos]
185+
assert matrix is not None
186+
decoded = base64.b64decode(matrix)
185187
assert decoded == matrices_raw[i]
186188

187189
def test_bitmap_mode(self):
@@ -220,9 +222,10 @@ def test_bitmap_mode(self):
220222

221223
for i in range(total_tokens):
222224
if i in replayed_positions:
223-
assert matrices[i] is not None
225+
matrix = matrices[i]
226+
assert matrix is not None
224227
idx = replayed_positions.index(i)
225-
decoded = base64.b64decode(matrices[i])
228+
decoded = base64.b64decode(matrix)
226229
assert decoded == matrices_raw[idx]
227230
else:
228231
assert matrices[i] is None
@@ -249,7 +252,9 @@ def test_uint16_dtype(self):
249252
assert metadata["routing_dtype"] == "uint16"
250253
assert len(matrices) == total_tokens
251254
for i in range(total_tokens):
252-
decoded = base64.b64decode(matrices[i])
255+
matrix = matrices[i]
256+
assert matrix is not None
257+
decoded = base64.b64decode(matrix)
253258
assert decoded == matrices_raw[i]
254259

255260
def test_zero_replayed_tokens(self):
@@ -302,6 +307,7 @@ def test_high_compression_ratio_payload(self):
302307
assert len(matrices) == total_tokens
303308
assert metadata["replayed_token_count"] == total_tokens
304309
for m in matrices:
310+
assert m is not None
305311
assert base64.b64decode(m) == b"\x00" * matrix_elem_size
306312

307313

tests/pytest/test_runloop_rollout_processor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from types import SimpleNamespace
33
import urllib.error
4+
from email.message import Message
45

56
import pytest
67

@@ -273,7 +274,9 @@ def read(self, size):
273274
def _urlopen(request, timeout):
274275
calls.append((request.full_url, timeout))
275276
if len(calls) == 1:
276-
raise urllib.error.HTTPError(request.full_url, 503, "Service Unavailable", hdrs=None, fp=None)
277+
raise urllib.error.HTTPError(
278+
request.full_url, 503, "Service Unavailable", hdrs=Message(), fp=None
279+
)
277280
return ReadyResponse()
278281

279282
monkeypatch.setattr(runloop_rollout_processor_module.urllib.request, "urlopen", _urlopen)
@@ -296,7 +299,7 @@ def test_startup_wait_accepts_non_5xx_http_errors(monkeypatch):
296299

297300
def _urlopen(request, timeout):
298301
calls.append((request.full_url, timeout))
299-
raise urllib.error.HTTPError(request.full_url, 404, "Not Found", hdrs=None, fp=None)
302+
raise urllib.error.HTTPError(request.full_url, 404, "Not Found", hdrs=Message(), fp=None)
300303

301304
monkeypatch.setattr(runloop_rollout_processor_module.urllib.request, "urlopen", _urlopen)
302305

0 commit comments

Comments
 (0)