diff --git a/src/phaseshift/bell_interferometer.py b/src/phaseshift/bell_interferometer.py index b25de47..1387538 100644 --- a/src/phaseshift/bell_interferometer.py +++ b/src/phaseshift/bell_interferometer.py @@ -31,6 +31,7 @@ It was introduced in [this PR by Jake Bulmer and Yuan Yao](https://github.com/XanaduAI/strawberryfields/pull/584#issue-894649549). """ +from io import StringIO from collections import defaultdict from dataclasses import dataclass, field @@ -444,3 +445,46 @@ def circuit_reconstruction(decomp: BellDecomp) -> NDArray[np.complex128]: U = phase_shifter(dim, mode, phi) @ U return U + + +def circuit_printer(decomp, width=30, rounding=3): + """Print the circuit corresponding to the decomposition using qpic syntax. + + Given a `BellDecomp` instance, this function prints the circuits in qpic format. + The `BellDecomp` instance can be obtained from the `bell_decomposition` function. + + Args: + decomp (BellDecomp): The Bell decomposition containing the parameters of the circuit. + + Returns: + ### NEED TO WRITE IT + + Raises: + TypeError: If the input is not a BellDecomp instance. + """ + output = StringIO() + dim = decomp.dim + # Apply the phase shifters at the input + for mode, phi in decomp.phi_input.items(): + print(f"{mode} G ${round(phi/np.pi, rounding)}$ width={width}", file=output) + + # Iterate through the layers of the circuit + for layer in range(dim): + # Apply the sMZIs in the layer + for mode in range(layer % 2, dim - 1, 2): + delta = round((decomp.delta[mode, layer] / np.pi) % 1, rounding) + sigma = round((decomp.sigma[mode, layer] / np.pi) % 1, rounding) + print( + f"{mode} {mode+1} G $\\genfrac{{}}{{}}{{0pt}}{{}}{{{delta}}}{{{sigma}}}$ width={width}", + file=output, + ) + if (dim - 1, layer) in decomp.phi_edge: + phi_bottom = round((decomp.phi_edge[dim - 1, layer] / np.pi) % 1, rounding) + print(f"{dim-1} G ${phi_bottom}$ width={width}", file=output) + # Apply the phase shifters at the output + for mode, phi in decomp.phi_output.items(): + phi = round((phi / np.pi) % 1, rounding) + print(f"{mode} G ${phi}$ width={width}", file=output) + captured_text = output.getvalue() + output.close() + return captured_text