Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 59 additions & 44 deletions p_kit/psl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ def __init__(self):
def register_instance(self, instance: Any):
self.instances.append(instance)

def register_submodule(self, submodule: Any):
"""Flatten a sub-module's instances into this context."""
for instance in submodule._context.instances:
if instance not in self.instances:
self.instances.append(instance)

def reset_indices(self):
"""Reset all global indices before synthesis."""
for instance in self.instances:
Expand All @@ -23,20 +29,24 @@ def assign_global_indices(self) -> int:
port_count = 0
for instance in self.instances:
for port in self.get_circuit_ports(instance):
if (
hasattr(port, "_connection_strategy")
and getattr(port, "_connection_strategy")
and port.global_index is None
):
# Let the strategy handle index assignment
port_count = port._connection_strategy.assign_global_index(
port, port._connected_port, port_count, self.port_to_global
)
strategy = getattr(port, "_connection_strategy", None)
if strategy:
if port.global_index is None:
port_count = strategy.assign_global_index(
port, port._connected_port, port_count, self.port_to_global
)
elif (
isinstance(strategy, NoCopyConnection)
and port._connected_port.global_index is None
):
# Port already indexed by a prior NoCopy; propagate index down the chain
target = port._connected_port
target.global_index = port.global_index
self.port_to_global[target] = port.global_index
elif port.global_index is None:
# Unconnected ports get their own index
port.global_index = port_count
self.port_to_global[port] = port_count
port_count += 1
port_count += port.width
return port_count

def get_circuit_ports(self, instance) -> List[Port]:
Expand All @@ -63,7 +73,6 @@ def synthesize(self, format: str = "sparse") -> Union[
self.reset_indices()
total_ports = self.assign_global_indices()

# Initialize data structures
J_global = {} if format == "sparse" else np.zeros((total_ports, total_ports))
h_global = {} if format == "sparse" else np.zeros((total_ports, 1))

Expand All @@ -75,45 +84,48 @@ def synthesize(self, format: str = "sparse") -> Union[
def _process_instance_matrices(self, J_global, h_global, format: str):
"""Process internal J and h matrices of instances."""
for instance in self.instances:
# Skip module instances that have no J/h (they are port containers only)
if not hasattr(instance, "J") or not hasattr(instance, "h"):
continue
circuit_ports = self.get_circuit_ports(instance)
for port1 in circuit_ports:
gi = port1.global_index
if gi is not None:
# Add bias
if instance.h is not None and port1.index < instance.h.shape[0]:
if port1.global_index is None:
continue
for bit in range(port1.width):
gi = port1.global_index + bit
local_idx = port1.index + bit
if instance.h is not None and local_idx < instance.h.shape[0]:
val = float(instance.h.flat[local_idx])
if format == "sparse":
if gi in h_global.keys():
h_global[gi] += float(instance.h[port1.index])
else:
h_global[gi] = float(instance.h[port1.index])
h_global[gi] = h_global.get(gi, 0.0) + val
else:
h_global[gi, 0] += instance.h[port1.index]

# Add couplings
h_global[gi, 0] += val
self._add_instance_couplings(
instance, port1, gi, circuit_ports, J_global, format
instance, local_idx, gi, circuit_ports, J_global, format
)

def _add_instance_couplings(
self, instance, port1, gi, circuit_ports, J_global, format
self, instance, local_idx, gi, circuit_ports, J_global, format
):
"""Add coupling terms from instance J matrix."""
"""Add coupling terms from instance J matrix for one source bit."""
for port2 in circuit_ports:
gj = port2.global_index
if (
gj is not None
and gi != gj
and instance.J is not None
and port1.index < instance.J.shape[0]
and port2.index < instance.J.shape[1]
):

weight = instance.J[port1.index, port2.index]
if weight != 0: # Only store non-zero weights
if port2.global_index is None:
continue
for bit2 in range(port2.width):
gj = port2.global_index + bit2
local_idx2 = port2.index + bit2
if gi == gj:
continue
if (
instance.J is None
or local_idx >= instance.J.shape[0]
or local_idx2 >= instance.J.shape[1]
):
continue
weight = instance.J[local_idx, local_idx2]
if weight != 0:
if format == "sparse":
if gi not in J_global:
J_global[gi] = {}
J_global[gi][gj] = float(weight)
J_global.setdefault(gi, {})[gj] = float(weight)
else:
J_global[gi, gj] = weight

Expand All @@ -128,7 +140,10 @@ def _process_connections(self, J_global, format: str):
):
other_port = port._connected_port
if other_port.global_index is not None:
# Let the strategy handle matrix updates
port._connection_strategy.synthesize_connection(
port.global_index, other_port.global_index, J_global
)
for bit in range(port.width):
port._connection_strategy.synthesize_connection(
port.global_index + bit,
other_port.global_index + bit,
J_global,
format,
)
47 changes: 37 additions & 10 deletions p_kit/psl/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ def module(cls: Type[T]) -> Type[T]:
and synthesizing their combined matrices. The synthesize method supports both
sparse and dense matrix formats.

Modules may declare Port attributes at class level to expose an external
interface. These ports can be connected to internal circuit ports and
participate in synthesis via NoCopyConnection (sharing global indices).

Modules may also contain other @module instances as attributes; their
registered instances are flattened into the parent context automatically
(recursive synthesis).

Args:
cls (Type[T]): Class to be decorated

Expand All @@ -24,34 +32,53 @@ def module(cls: Type[T]) -> Type[T]:
Example:
>>> @module
>>> class MyCircuit:
>>> out = Port("out")
>>> def __init__(self):
>>> self.gate1 = ANDGate()
>>> self.gate2 = ANDGate()
>>> self.out.connect(self.gate1.output, NoCopyConnection)
>>>
>>> # Create instance and synthesize matrices
>>> circuit = MyCircuit()
>>>
>>> # Get sparse representation (default)
>>> J_sparse, h_sparse = circuit.synthesize()
>>>
>>> # Get dense matrix representation
>>> J_dense, h_dense = circuit.synthesize(format='dense')
"""
context = ModuleContext()
# Detect Port attributes declared on the class (module interface ports)
port_attrs = {
name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Port)
}

original_new = cls.__new__
original_init = cls.__init__

def __new__(cls, *args, **kwargs):
instance = original_new(cls)
instance._context = context
def __new__(cls_ref, *args, **kwargs):
instance = original_new(cls_ref)
instance._context = ModuleContext() # per-instance context
return instance

def __init__(self, *args, **kwargs):
# Initialize module-level interface ports before original __init__ runs
# so that user code in __init__ can connect them to internal gates
if port_attrs:
idx = 0
for name, port in port_attrs.items():
new_port = Port(name=port.name, width=port.width)
new_port.circuit = self
new_port.index = idx
idx += port.width
setattr(self, name, new_port)
# Register the module itself so its interface ports participate
# in global index assignment before internal gate ports
self._context.register_instance(self)

original_init(self, *args, **kwargs)

for attr_name, attr_value in vars(self).items():
if attr_value is self:
continue
if hasattr(attr_value, "n_pbits"):
self._context.register_instance(attr_value)
elif hasattr(attr_value, "_context"):
# Sub-module: flatten its instances into this context
self._context.register_submodule(attr_value)

def synthesize(self, format: str = "sparse") -> Union[
Tuple[Dict[int, Dict[int, float]], Dict[int, float]],
Expand Down
10 changes: 7 additions & 3 deletions p_kit/psl/p_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ def __init__(self, n_pbits: int, ports: Dict[str, Any] = None):

def _initialize_ports(self, port_attrs: Dict[str, Any]) -> None:
"""Initialize ports from attributes dictionary."""
# Create port index mapping
port_indices = {name: idx for idx, name in enumerate(port_attrs.keys())}
# Build cumulative index mapping respecting port widths
port_indices = {}
idx = 0
for name, port in port_attrs.items():
port_indices[name] = idx
idx += port.width

# Set up each port
for name, port in port_attrs.items():
new_port = Port(name=port.name)
new_port = Port(name=port.name, width=port.width)
new_port.circuit = self
new_port.index = port_indices[name]
# Check ports name doesn't conflict with reserved attributes.
Expand Down
18 changes: 15 additions & 3 deletions p_kit/psl/port.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def assign_global_index(
port_count: int,
port_to_global: Dict["Port", int],
) -> int:
# Each port gets unique index
# Each port gets unique index; allocate width consecutive slots
source_port.global_index = port_count
port_to_global[source_port] = port_count
return port_count + 1
return port_count + source_port.width

@abstractmethod
def synthesize_connection_sparse(
Expand Down Expand Up @@ -58,7 +58,7 @@ def assign_global_index(self, source_port, target_port, port_count, port_to_glob
target_port.global_index = port_count
port_to_global[source_port] = port_count
port_to_global[target_port] = port_count
return port_count + 1
return port_count + source_port.width

def synthesize_connection_sparse(
self, source_idx: int, target_idx: int, J_global: Dict[int, Dict[int, float]]
Expand Down Expand Up @@ -152,8 +152,16 @@ class Port:
circuit: Any = None
index: int = None
global_index: int = None
width: int = 1
_connections: Dict = field(default_factory=dict)

@property
def global_indices(self) -> List[int]:
"""Returns all global indices covered by this port (one per bit)."""
if self.global_index is None:
return []
return list(range(self.global_index, self.global_index + self.width))

def __hash__(self):
"""
Generates a hash based on the combination of circuit reference and port name.
Expand Down Expand Up @@ -189,6 +197,10 @@ def connect(self, other_port: "Port", strategy: ConnectionStrategy):
raise ValueError("Both ports must be bound to circuits")
if self.index is None or other_port.index is None:
raise ValueError("Both ports must have assigned indices")
if self.width != other_port.width:
raise ValueError(
f"Cannot connect ports of different widths: {self.width} vs {other_port.width}"
)

if isinstance(strategy, type):
self._connection_strategy = strategy()
Expand Down
Loading
Loading