Skip to content

Commit e80db52

Browse files
authored
[mypyc] Make compilation order with multiple files consistent (#21419)
The `mypyc` plugin sets all modules inside a group as dependent on each other so they are part of a single SCC. When we compile the SCC, we compile its modules according to the order in `scc.mod_ids`. `mod_ids` is a `set`, so the order is non-deterministic. This can cause random failures when multiple files are compiled together and mypyc has a bug that makes it rely on data that depends on the compilation order. For example, suppose we are compiling two files, `a.py` and `b.py`, each with a single class. If an attribute of the `ClassIR` representing class `A` is not set during the preparation phase but is used when compiling class `B`, the result will depend on whether `a.py` was compiled before `b.py`. To fix this, make the order deterministic by sorting the modules by name first. Also sort the source files and groups passed to `mypycify` so that invoking `mypyc a.py b.py` and `mypyc b.py a.py` produces identical results.
1 parent e7846d9 commit e80db52

4 files changed

Lines changed: 85 additions & 5 deletions

File tree

mypyc/build.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,15 +485,19 @@ def construct_groups(
485485
else:
486486
groups = [(sources, None)]
487487

488-
# Generate missing names
488+
# Generate missing names.
489+
# Sort the modules to make the compilation results consistent regardless of
490+
# the source file order passed to mypycify.
489491
for i, (group, name) in enumerate(groups):
492+
group = sorted(group, key=lambda source: source.module)
490493
if use_shared_lib and not name:
491494
if group_name_override is not None:
492495
name = group_name_override
493496
else:
494497
name = group_name([source.module for source in group])
495498
groups[i] = (group, name)
496499

500+
groups = sorted(groups, key=lambda g: (g[1] or "", [s.module for s in g[0]]))
497501
return groups
498502

499503

mypyc/codegen/emitmodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def compile_modules_to_ir(
305305

306306
# Process the graph by SCC in topological order, like we do in mypy.build
307307
for scc in sorted_components(result.graph):
308-
scc_states = [result.graph[id] for id in scc.mod_ids]
308+
scc_states = [result.graph[id] for id in sorted(scc.mod_ids)]
309309
trees = [st.tree for st in scc_states if st.id in mapper.group_map and st.tree]
310310

311311
if not trees:
@@ -1441,7 +1441,7 @@ def _toposort_visit(name: str) -> None:
14411441
if decl.mark:
14421442
return
14431443

1444-
for child in decl.declaration.dependencies:
1444+
for child in sorted(decl.declaration.dependencies):
14451445
_toposort_visit(child)
14461446

14471447
result.append(decl.declaration)

mypyc/irbuild/builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,11 +1076,11 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
10761076
items = target_type.items
10771077
assert items, "This function does not support empty tuples"
10781078
# Tuple might have elements of different types.
1079-
rtypes = set(map(self.mapper.type_to_rtype, items))
1079+
rtypes = list(dict.fromkeys(self.mapper.type_to_rtype(item) for item in items))
10801080
if len(rtypes) == 1:
10811081
return rtypes.pop()
10821082
else:
1083-
return RUnion.make_simplified_union(list(rtypes))
1083+
return RUnion.make_simplified_union(rtypes)
10841084
assert False, target_type
10851085

10861086
def get_dict_base_type(self, expr: Expression) -> list[Instance]:

mypyc/test/test_emitmodule.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
import tempfile
4+
import unittest
5+
from pathlib import Path
6+
7+
import pytest
8+
9+
from mypy import build
10+
from mypy.options import Options
11+
from mypyc.build import construct_groups
12+
from mypyc.codegen import emitmodule
13+
from mypyc.errors import Errors
14+
from mypyc.irbuild.mapper import Mapper
15+
from mypyc.options import CompilerOptions
16+
17+
18+
class FakeSCC:
    """Minimal stand-in for an SCC object that exposes only ``mod_ids``.

    The module ids are stored as a list (in exactly the order given) so a
    test can control the iteration order seen by the code under test.
    """

    def __init__(self, mod_ids: list[str]) -> None:
        # Keep the caller's list as-is; no copying or reordering.
        self.mod_ids = mod_ids
21+
22+
23+
class TestEmitModule(unittest.TestCase):
    """Tests for ``mypyc.codegen.emitmodule``."""

    def test_compile_modules_to_ir_orders_scc_members_deterministically(self) -> None:
        """Compile a two-module import cycle with both SCC member orders.

        ``sorted_components`` is patched to return a single fake SCC whose
        ``mod_ids`` are listed as ``["a", "b"]`` and then ``["b", "a"]``.
        The recorded children of class ``a.A`` must be the same in both runs.
        """
        with tempfile.TemporaryDirectory() as tmp_dir, pytest.MonkeyPatch.context() as monkeypatch:
            tmp_path = Path(tmp_dir)

            # Two modules that import each other, each subclassing ``a.A``.
            a_py = tmp_path / "a.py"
            a_py.write_text("import b\n\nclass A: pass\nclass C(A): pass\n", encoding="utf-8")
            b_py = tmp_path / "b.py"
            b_py.write_text(
                "import a\n\nclass B(a.A): pass\nclass D(a.A): pass\n", encoding="utf-8"
            )
            sources = [
                build.BuildSource(str(a_py), "a", None),
                build.BuildSource(str(b_py), "b", None),
            ]

            options = Options()
            options.preserve_asts = True
            options.mypy_path = [str(tmp_path)]
            options.cache_dir = str(tmp_path / ".mypy_cache")
            for source in sources:
                per_module = options.per_module_options.setdefault(source.module, {})
                per_module["mypyc"] = True

            compiler_options = CompilerOptions(strict_traceback_checks=True)
            groups = construct_groups(
                sources, False, use_shared_lib=True, group_name_override=None
            )
            result = emitmodule.parse_and_typecheck(sources, options, compiler_options, groups)

            def child_names_for(order: list[str]) -> list[str]:
                """Run IR generation with the SCC members listed in ``order``."""

                def fake_sorted_components(graph):
                    return [FakeSCC(order)]

                monkeypatch.setattr(emitmodule, "sorted_components", fake_sorted_components)
                mapper = Mapper(group_map)
                errors = Errors(options)
                modules = emitmodule.compile_modules_to_ir(
                    result, mapper, compiler_options, errors
                )
                assert errors.num_errors == 0, errors.new_messages()

                classes = {}
                for module in modules.values():
                    for cl in module.classes:
                        classes[cl.fullname] = cl
                children = classes["a.A"].children
                assert children is not None
                return [child.fullname for child in children]

            try:
                # Map every compiled module to the name of the group it belongs to.
                group_map = {}
                for group, lib_name in groups:
                    for source in group:
                        group_map[source.module] = lib_name

                children_by_order = []
                for order in (["a", "b"], ["b", "a"]):
                    children_by_order.append(child_names_for(order))

                first, second = children_by_order
                assert second == first
            finally:
                # Always release the metadata store opened by parse_and_typecheck.
                result.manager.metastore.close()

0 commit comments

Comments
 (0)