From 0ecf849df1ec79b4e9531179e82c73037a8487c0 Mon Sep 17 00:00:00 2001 From: Gregoire CATTAN Date: Sat, 23 May 2026 22:01:23 +0200 Subject: [PATCH] Recursive P-Circuit Synthesis --- p_kit/psl/context.py | 103 ++++++++++++++----------- p_kit/psl/decorators.py | 47 +++++++++--- p_kit/psl/p_circuit.py | 10 ++- p_kit/psl/port.py | 18 ++++- tests/test_psl.py | 163 +++++++++++++++++++++++++++++++++++++++- 5 files changed, 280 insertions(+), 61 deletions(-) diff --git a/p_kit/psl/context.py b/p_kit/psl/context.py index 6ca3b27..cc13f82 100644 --- a/p_kit/psl/context.py +++ b/p_kit/psl/context.py @@ -11,6 +11,12 @@ def __init__(self): def register_instance(self, instance: Any): self.instances.append(instance) + def register_submodule(self, submodule: Any): + """Flatten a sub-module's instances into this context.""" + for instance in submodule._context.instances: + if instance not in self.instances: + self.instances.append(instance) + def reset_indices(self): """Reset all global indices before synthesis.""" for instance in self.instances: @@ -23,20 +29,24 @@ def assign_global_indices(self) -> int: port_count = 0 for instance in self.instances: for port in self.get_circuit_ports(instance): - if ( - hasattr(port, "_connection_strategy") - and getattr(port, "_connection_strategy") - and port.global_index is None - ): - # Let the strategy handle index assignment - port_count = port._connection_strategy.assign_global_index( - port, port._connected_port, port_count, self.port_to_global - ) + strategy = getattr(port, "_connection_strategy", None) + if strategy: + if port.global_index is None: + port_count = strategy.assign_global_index( + port, port._connected_port, port_count, self.port_to_global + ) + elif ( + isinstance(strategy, NoCopyConnection) + and port._connected_port.global_index is None + ): + # Port already indexed by a prior NoCopy; propagate index down the chain + target = port._connected_port + target.global_index = port.global_index + self.port_to_global[target] = port.global_index elif port.global_index is None: - # Unconnected ports get their own index port.global_index = port_count self.port_to_global[port] = port_count - port_count += 1 + port_count += port.width return port_count def get_circuit_ports(self, instance) -> List[Port]: @@ -63,7 +73,6 @@ def synthesize(self, format: str = "sparse") -> Union[ self.reset_indices() total_ports = self.assign_global_indices() - # Initialize data structures J_global = {} if format == "sparse" else np.zeros((total_ports, total_ports)) h_global = {} if format == "sparse" else np.zeros((total_ports, 1)) @@ -75,45 +84,48 @@ def synthesize(self, format: str = "sparse") -> Union[ def _process_instance_matrices(self, J_global, h_global, format: str): """Process internal J and h matrices of instances.""" for instance in self.instances: + # Skip module instances that have no J/h (they are port containers only) + if not hasattr(instance, "J") or not hasattr(instance, "h"): + continue circuit_ports = self.get_circuit_ports(instance) for port1 in circuit_ports: - gi = port1.global_index - if gi is not None: - # Add bias - if instance.h is not None and port1.index < instance.h.shape[0]: + if port1.global_index is None: + continue + for bit in range(port1.width): + gi = port1.global_index + bit + local_idx = port1.index + bit + if instance.h is not None and local_idx < instance.h.shape[0]: + val = float(instance.h.flat[local_idx]) if format == "sparse": - if gi in h_global.keys(): - h_global[gi] += float(instance.h[port1.index]) - else: - h_global[gi] = float(instance.h[port1.index]) + h_global[gi] = h_global.get(gi, 0.0) + val else: - h_global[gi, 0] += instance.h[port1.index] - - # Add couplings + h_global[gi, 0] += val self._add_instance_couplings( - instance, port1, gi, circuit_ports, J_global, format + instance, local_idx, gi, circuit_ports, J_global, format ) def _add_instance_couplings( - self, instance, port1, gi, circuit_ports, J_global, format + self, instance, local_idx, gi, circuit_ports, J_global, format ): - """Add coupling terms from instance J matrix.""" + """Add coupling terms from instance J matrix for one source bit.""" for port2 in circuit_ports: - gj = port2.global_index - if ( - gj is not None - and gi != gj - and instance.J is not None - and port1.index < instance.J.shape[0] - and port2.index < instance.J.shape[1] - ): - - weight = instance.J[port1.index, port2.index] - if weight != 0: # Only store non-zero weights + if port2.global_index is None: + continue + for bit2 in range(port2.width): + gj = port2.global_index + bit2 + local_idx2 = port2.index + bit2 + if gi == gj: + continue + if ( + instance.J is None + or local_idx >= instance.J.shape[0] + or local_idx2 >= instance.J.shape[1] + ): + continue + weight = instance.J[local_idx, local_idx2] + if weight != 0: if format == "sparse": - if gi not in J_global: - J_global[gi] = {} - J_global[gi][gj] = float(weight) + J_global.setdefault(gi, {})[gj] = float(weight) else: J_global[gi, gj] = weight @@ -128,7 +140,10 @@ def _process_connections(self, J_global, format: str): ): other_port = port._connected_port if other_port.global_index is not None: - # Let the strategy handle matrix updates - port._connection_strategy.synthesize_connection( - port.global_index, other_port.global_index, J_global - ) + for bit in range(port.width): + port._connection_strategy.synthesize_connection( + port.global_index + bit, + other_port.global_index + bit, + J_global, + format, + ) diff --git a/p_kit/psl/decorators.py b/p_kit/psl/decorators.py index 3ddea91..fed8970 100644 --- a/p_kit/psl/decorators.py +++ b/p_kit/psl/decorators.py @@ -15,6 +15,14 @@ def module(cls: Type[T]) -> Type[T]: and synthesizing their combined matrices. The synthesize method supports both sparse and dense matrix formats. + Modules may declare Port attributes at class level to expose an external + interface. These ports can be connected to internal circuit ports and + participate in synthesis via NoCopyConnection (sharing global indices). + + Modules may also contain other @module instances as attributes; their + registered instances are flattened into the parent context automatically + (recursive synthesis). + Args: cls (Type[T]): Class to be decorated @@ -24,34 +32,53 @@ def module(cls: Type[T]) -> Type[T]: Example: >>> @module >>> class MyCircuit: + >>> out = Port("out") >>> def __init__(self): >>> self.gate1 = ANDGate() - >>> self.gate2 = ANDGate() + >>> self.out.connect(self.gate1.output, NoCopyConnection) >>> - >>> # Create instance and synthesize matrices >>> circuit = MyCircuit() - >>> - >>> # Get sparse representation (default) >>> J_sparse, h_sparse = circuit.synthesize() - >>> - >>> # Get dense matrix representation >>> J_dense, h_dense = circuit.synthesize(format='dense') """ - context = ModuleContext() + # Detect Port attributes declared on the class (module interface ports) + port_attrs = { + name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Port) + } original_new = cls.__new__ original_init = cls.__init__ - def __new__(cls, *args, **kwargs): - instance = original_new(cls) - instance._context = context + def __new__(cls_ref, *args, **kwargs): + instance = original_new(cls_ref) + instance._context = ModuleContext() # per-instance context return instance def __init__(self, *args, **kwargs): + # Initialize module-level interface ports before original __init__ runs + # so that user code in __init__ can connect them to internal gates + if port_attrs: + idx = 0 + for name, port in port_attrs.items(): + new_port = Port(name=port.name, width=port.width) + new_port.circuit = self + new_port.index = idx + idx += port.width + setattr(self, name, new_port) + # Register the module itself so its interface ports participate + # in global index assignment before internal gate ports + self._context.register_instance(self) + original_init(self, *args, **kwargs) + for attr_name, attr_value in vars(self).items(): + if attr_value is self: + continue if hasattr(attr_value, "n_pbits"): self._context.register_instance(attr_value) + elif hasattr(attr_value, "_context"): + # Sub-module: flatten its instances into this context + self._context.register_submodule(attr_value) def synthesize(self, format: str = "sparse") -> Union[ Tuple[Dict[int, Dict[int, float]], Dict[int, float]], diff --git a/p_kit/psl/p_circuit.py b/p_kit/psl/p_circuit.py index 4b50783..1ad2c04 100644 --- a/p_kit/psl/p_circuit.py +++ b/p_kit/psl/p_circuit.py @@ -36,12 +36,16 @@ def __init__(self, n_pbits: int, ports: Dict[str, Any] = None): def _initialize_ports(self, port_attrs: Dict[str, Any]) -> None: """Initialize ports from attributes dictionary.""" - # Create port index mapping - port_indices = {name: idx for idx, name in enumerate(port_attrs.keys())} + # Build cumulative index mapping respecting port widths + port_indices = {} + idx = 0 + for name, port in port_attrs.items(): + port_indices[name] = idx + idx += port.width # Set up each port for name, port in port_attrs.items(): - new_port = Port(name=port.name) + new_port = Port(name=port.name, width=port.width) new_port.circuit = self new_port.index = port_indices[name] # Check ports name doesn't conflict with reserved attributes. diff --git a/p_kit/psl/port.py b/p_kit/psl/port.py index 0653bbd..5cb0130 100644 --- a/p_kit/psl/port.py +++ b/p_kit/psl/port.py @@ -17,10 +17,10 @@ def assign_global_index( port_count: int, port_to_global: Dict["Port", int], ) -> int: - # Each port gets unique index + # Each port gets unique index; allocate width consecutive slots source_port.global_index = port_count port_to_global[source_port] = port_count - return port_count + 1 + return port_count + source_port.width @abstractmethod def synthesize_connection_sparse( @@ -58,7 +58,7 @@ def assign_global_index(self, source_port, target_port, port_count, port_to_glob target_port.global_index = port_count port_to_global[source_port] = port_count port_to_global[target_port] = port_count - return port_count + 1 + return port_count + source_port.width def synthesize_connection_sparse( self, source_idx: int, target_idx: int, J_global: Dict[int, Dict[int, float]] @@ -152,8 +152,16 @@ class Port: circuit: Any = None index: int = None global_index: int = None + width: int = 1 _connections: Dict = field(default_factory=dict) + @property + def global_indices(self) -> List[int]: + """Returns all global indices covered by this port (one per bit).""" + if self.global_index is None: + return [] + return list(range(self.global_index, self.global_index + self.width)) + def __hash__(self): """ Generates a hash based on the combination of circuit reference and port name. @@ -189,6 +197,10 @@ def connect(self, other_port: "Port", strategy: ConnectionStrategy): raise ValueError("Both ports must be bound to circuits") if self.index is None or other_port.index is None: raise ValueError("Both ports must have assigned indices") + if self.width != other_port.width: + raise ValueError( + f"Cannot connect ports of different widths: {self.width} vs {other_port.width}" + ) if isinstance(strategy, type): self._connection_strategy = strategy() diff --git a/tests/test_psl.py b/tests/test_psl.py index e4334d8..3f346b9 100644 --- a/tests/test_psl.py +++ b/tests/test_psl.py @@ -1,7 +1,7 @@ import pytest import numpy as np from p_kit import psl -from p_kit.psl.gates import ANDGate +from p_kit.psl.gates import ANDGate, ORGate @pytest.fixture @@ -217,3 +217,164 @@ def test_invalid_connection_raises_error(): port1 = psl.Port("test1") port2 = psl.Port("test2") port1.connect(port2, psl.NoCopyConnection) + + +# ── Feature: Ports with width ───────────────────────────────────────────────── + +@psl.pcircuit(n_pbits=5) +class WideBusGate: + """A gate with a 4-bit input bus and a 1-bit output.""" + data_in = psl.Port("data_in", width=4) + flag = psl.Port("flag", width=1) + + J = np.zeros((5, 5)) + h = np.zeros((5, 1)) + + +def test_port_width_field(): + gate = WideBusGate() + assert gate.data_in.width == 4 + assert gate.flag.width == 1 + + +def test_port_global_indices_after_synthesis(): + @psl.module + class WideCircuit: + def __init__(self): + self.gate = WideBusGate() + + wc = WideCircuit() + wc.synthesize(format="dense") + gate = wc.gate + assert len(gate.data_in.global_indices) == 4 + gi = gate.data_in.global_indices + assert gi == list(range(gi[0], gi[0] + 4)) + + +def test_wide_port_connection_width_mismatch_raises(): + gate1 = ANDGate() + gate2 = ANDGate() + # output (width=1) vs input1 (width=1) — OK, but let's try width mismatch via Port directly + with pytest.raises(ValueError, match="different widths"): + @psl.module + class Mismatch: + def __init__(self): + self.g1 = WideBusGate() + self.g2 = ANDGate() + # data_in (width=4) vs input1 (width=1) + self.g1.data_in.connect(self.g2.input1, psl.NoCopyConnection) + Mismatch() + + +def test_wide_port_index_offset(): + """data_in starts at local index 0, flag at local index 4.""" + gate = WideBusGate() + assert gate.data_in.index == 0 + assert gate.flag.index == 4 + + +# ── Feature: Modules with Ports ─────────────────────────────────────────────── + +@psl.module +class AndWrapper: + """Module that exposes the AND gate's output as a named interface port.""" + result = psl.Port("result") + + def __init__(self): + self.gate = ANDGate() + self.result.connect(self.gate.output, psl.NoCopyConnection) + + +def test_module_port_shares_global_index(): + aw = AndWrapper() + aw.synthesize(format="dense") + assert aw.result.global_index == aw.gate.output.global_index + + +def test_module_port_in_synthesis_shape(): + aw = AndWrapper() + J, h = aw.synthesize(format="dense") + # AND gate has 3 p-bits; result shares output's index → still 3 unique indices + assert J.shape == (3, 3) + assert h.shape == (3, 1) + + +def test_module_with_port_connected_to_another_gate(): + # Connect via internal gate port; the NoCopy propagation carries the shared + # index through the chain: result(0) == gate.output(0) == gate2.input1(0) + @psl.module + class Chain: + def __init__(self): + self.wrapper = AndWrapper() + self.gate2 = ANDGate() + self.wrapper.gate.output.connect(self.gate2.input1, psl.NoCopyConnection) + + chain = Chain() + J, h = chain.synthesize(format="dense") + # 3 (AND in wrapper) + 3 (gate2) - 1 (shared output/input1) = 5 + assert J.shape == (5, 5) + # All three ports collapse to the same global index + assert chain.wrapper.result.global_index == chain.wrapper.gate.output.global_index + assert chain.wrapper.gate.output.global_index == chain.gate2.input1.global_index + + +# ── Feature: Recursive Synthesis ────────────────────────────────────────────── + +@psl.module +class TwoAnds: + """Two AND gates chained; output of first feeds input of second.""" + def __init__(self): + self.and1 = ANDGate() + self.and2 = ANDGate() + self.and1.output.connect(self.and2.input1, psl.NoCopyConnection) + + +def test_recursive_synthesis_flat(): + """Sub-module instances are flattened into the parent context.""" + @psl.module + class FourAnds: + def __init__(self): + self.pair1 = TwoAnds() + self.pair2 = TwoAnds() + + fa = FourAnds() + J, h = fa.synthesize(format="dense") + # pair1: 3+3-1 = 5 unique pbits; pair2: 5 unique pbits; no cross-connection → 10 total + assert J.shape == (10, 10) + assert h.shape == (10, 1) + + +def test_recursive_synthesis_preserves_couplings(): + """Internal couplings of sub-modules are present in the global J matrix.""" + @psl.module + class Nested: + def __init__(self): + self.sub = TwoAnds() + + n = Nested() + J_nested, _ = n.synthesize(format="dense") + + standalone = TwoAnds() + J_standalone, _ = standalone.synthesize(format="dense") + + assert J_nested.shape == J_standalone.shape + assert np.allclose(J_nested, J_standalone) + + +def test_recursive_synthesis_sparse_dense_equivalence(): + @psl.module + class Nested: + def __init__(self): + self.sub = TwoAnds() + + n = Nested() + J_sparse, h_sparse = n.synthesize(format="sparse") + J_dense, h_dense = n.synthesize(format="dense") + + size = J_dense.shape[0] + J_from_sparse = np.zeros((size, size)) + for i, row in J_sparse.items(): + for j, w in row.items(): + J_from_sparse[i, j] = w + + assert np.allclose(J_dense, J_from_sparse)