diff --git a/software/control/ndviewer_light b/software/control/ndviewer_light index fbff5e405..46575caea 160000 --- a/software/control/ndviewer_light +++ b/software/control/ndviewer_light @@ -1 +1 @@ -Subproject commit fbff5e405654b6e5f59eaf74cbbab85a4780432f +Subproject commit 46575caead5f51352444b5017904e811c27b1d6c diff --git a/software/squid/filter_wheel_controller/cephla.py b/software/squid/filter_wheel_controller/cephla.py index 480a7ed63..3696c5c8f 100644 --- a/software/squid/filter_wheel_controller/cephla.py +++ b/software/squid/filter_wheel_controller/cephla.py @@ -1,3 +1,4 @@ +import threading import time from typing import List, Dict, Optional, Union @@ -41,6 +42,11 @@ def __init__( raise Exception("Error, microcontroller is needed by the SquidFilterWheel") self.microcontroller = microcontroller + # Serializes all position-affecting operations (move, home, step) so that + # the read-position → send-move → wait → update-position sequence is atomic. + # Without this, concurrent calls from the GUI thread and acquisition worker + # can read stale positions and issue wrong relative moves. + self._lock = threading.Lock() # Convert single config to dict format for uniform handling if isinstance(configs, SquidFilterWheelConfig): @@ -91,7 +97,7 @@ def _configure_wheel(self, wheel_id: int, config: SquidFilterWheelConfig): self.microcontroller.turn_on_stage_pid(axis, ENABLE_PID_W) def _move_wheel(self, wheel_id: int, delta: float): - """Move a specific wheel by delta distance. + """Move a specific wheel by delta distance. Caller must hold self._lock. Args: wheel_id: The ID of the wheel to move. @@ -113,6 +119,8 @@ def _move_wheel(self, wheel_id: int, delta: float): def _move_to_position(self, wheel_id: int, target_pos: int): """Move wheel to target position with automatic re-home on failure. + Caller must hold self._lock. + If the movement times out (e.g., motor stall), this method will: 1. Log a warning 2. Re-home the wheel to re-synchronize position tracking @@ -140,9 +148,15 @@ def _move_to_position(self, wheel_id: int, target_pos: int): self.microcontroller.wait_till_operation_is_completed() self._positions[wheel_id] = target_pos except TimeoutError: - _log.warning(f"Filter wheel {wheel_id} movement timed out. " f"Re-homing to re-sync position tracking...") - # Re-home to re-synchronize position tracking - self._home_wheel(wheel_id) + _log.warning(f"Filter wheel {wheel_id} movement timed out. Re-homing to re-sync position tracking...") + try: + self._home_wheel(wheel_id) + except Exception as rehome_err: + _log.error( + f"Filter wheel {wheel_id} re-home also failed: {rehome_err}. " + f"Position tracking unreliable. Hardware may need attention." + ) + raise # Retry the movement (position is now at min_index after homing) current_pos = self._positions[wheel_id] @@ -154,16 +168,14 @@ def _move_to_position(self, wheel_id: int, target_pos: int): _log.info(f"Filter wheel {wheel_id} recovery successful, now at position {target_pos}") except TimeoutError: _log.error( - f"Filter wheel {wheel_id} movement failed even after re-home. " f"Hardware may need attention." + f"Filter wheel {wheel_id} movement failed even after re-home. " + f"Tracked position ({self._positions[wheel_id]}) may not reflect " + f"actual physical position. Hardware may need attention." ) raise def _home_wheel(self, wheel_id: int): - """Home a specific wheel. - - Args: - wheel_id: The ID of the wheel to home. - """ + """Home a specific wheel. Caller must hold self._lock.""" config = self._configs[wheel_id] motor_slot = config.motor_slot_index @@ -174,14 +186,26 @@ def _home_wheel(self, wheel_id: int): else: raise ValueError(f"Unsupported motor_slot_index: {motor_slot}") - # Wait for homing to complete (needs longer timeout) - self.microcontroller.wait_till_operation_is_completed(15) + try: + self.microcontroller.wait_till_operation_is_completed(15) + except Exception as e: + # Physical position is unknown — reset tracking to min_index as + # best guess so subsequent moves don't use a stale value. + self._positions[wheel_id] = config.min_index + _log.error(f"Filter wheel {wheel_id} homing failed: {e}. Position tracking reset to {config.min_index}.") + raise - # Move to offset position self._move_wheel(wheel_id, config.offset) - self.microcontroller.wait_till_operation_is_completed() + try: + self.microcontroller.wait_till_operation_is_completed() + except Exception as e: + # Homed but offset move failed — at physical zero, not offset. + self._positions[wheel_id] = config.min_index + _log.error( + f"Filter wheel {wheel_id} offset move failed after homing: {e}. Position reset to {config.min_index}." + ) + raise - # Reset position tracking self._positions[wheel_id] = config.min_index def initialize(self, filter_wheel_indices: List[int]): @@ -225,14 +249,14 @@ def home(self, index: Optional[int] = None): Args: index: Specific wheel index to home. If None, homes all configured wheels. """ - if index is not None: - if index not in self._configs: - raise ValueError(f"Filter wheel index {index} not found") - self._home_wheel(index) - else: - # Home all wheels - for wheel_id in self._configs.keys(): - self._home_wheel(wheel_id) + with self._lock: + if index is not None: + if index not in self._configs: + raise ValueError(f"Filter wheel index {index} not found") + self._home_wheel(index) + else: + for wheel_id in self._configs.keys(): + self._home_wheel(wheel_id) def _step_position(self, wheel_id: int, direction: int): """Move position by one step in the given direction. @@ -244,11 +268,12 @@ def _step_position(self, wheel_id: int, direction: int): if wheel_id not in self._configs: raise ValueError(f"Filter wheel index {wheel_id} not found") - config = self._configs[wheel_id] - current_pos = self._positions[wheel_id] - new_pos = current_pos + direction - - if config.min_index <= new_pos <= config.max_index: + config = self._configs[wheel_id] # _configs is immutable after __init__ + with self._lock: + current_pos = self._positions[wheel_id] + new_pos = current_pos + direction + if not (config.min_index <= new_pos <= config.max_index): + return self._move_to_position(wheel_id, new_pos) def next_position(self, wheel_id: int = 1): @@ -274,15 +299,16 @@ def set_filter_wheel_position(self, positions: Dict[int, int]): positions: Dict mapping wheel_id -> target position. Position values are 1-indexed (typically 1-8). """ - for wheel_id, pos in positions.items(): - if wheel_id not in self._configs: - raise ValueError(f"Filter wheel index {wheel_id} not found") + with self._lock: + for wheel_id, pos in positions.items(): + if wheel_id not in self._configs: + raise ValueError(f"Filter wheel index {wheel_id} not found") - config = self._configs[wheel_id] - if pos not in range(config.min_index, config.max_index + 1): - raise ValueError(f"Filter wheel {wheel_id} position {pos} is out of range") + config = self._configs[wheel_id] + if pos not in range(config.min_index, config.max_index + 1): + raise ValueError(f"Filter wheel {wheel_id} position {pos} is out of range") - self._move_to_position(wheel_id, pos) + self._move_to_position(wheel_id, pos) def get_filter_wheel_position(self) -> Dict[int, int]: """Get current positions of all configured wheels. @@ -290,7 +316,8 @@ def get_filter_wheel_position(self) -> Dict[int, int]: Returns: Dict mapping wheel_id -> current position. """ - return dict(self._positions) + with self._lock: + return dict(self._positions) def set_delay_offset_ms(self, delay_offset_ms: float): """Set delay offset (not used by SQUID filter wheel).""" @@ -311,8 +338,3 @@ def get_delay_ms(self) -> Optional[float]: def close(self): """Close the filter wheel controller (no-op for SQUID).""" pass - - # Backward compatibility methods - def move_w(self, delta: float): - """Move the first wheel by delta. For backward compatibility.""" - self._move_wheel(1, delta) diff --git a/software/tests/squid/test_filter_wheel.py b/software/tests/squid/test_filter_wheel.py index 617961f53..36b256f4d 100644 --- a/software/tests/squid/test_filter_wheel.py +++ b/software/tests/squid/test_filter_wheel.py @@ -1,9 +1,12 @@ +import threading +import time from unittest.mock import MagicMock, patch import pytest import squid.config import squid.filter_wheel_controller.utils +from control._def import SCREW_PITCH_W_MM from squid.config import FilterWheelConfig, FilterWheelControllerVariant, SquidFilterWheelConfig from squid.filter_wheel_controller.cephla import SquidFilterWheel @@ -118,3 +121,126 @@ def test_normal_init_configures_encoder_pid(self, mock_microcontroller, squid_co mock_microcontroller.set_pid_arguments.assert_called_once() mock_microcontroller.configure_stage_pid.assert_called_once() mock_microcontroller.turn_on_stage_pid.assert_called_once() + + +class TestSquidFilterWheelThreadSafety: + """Tests that concurrent filter wheel operations don't corrupt position tracking.""" + + @pytest.fixture + def wheel(self): + mcu = MagicMock() + # Simulate real MCU latency so the critical section has a wide enough + # window for threads to actually interleave without the lock. + mcu.wait_till_operation_is_completed.side_effect = lambda *a, **kw: time.sleep(0.02) + config = SquidFilterWheelConfig( + max_index=8, min_index=1, offset=0.008, motor_slot_index=3, transitions_per_revolution=4000 + ) + return SquidFilterWheel(mcu, config, skip_init=True) + + def test_concurrent_moves_serialize(self, wheel): + """Two threads calling set_filter_wheel_position must not corrupt tracking. + + Without the lock, both threads read current_pos=1, compute their deltas + relative to 1, and issue overlapping moves that leave the physical wheel + at the wrong position. With the lock the second thread sees the updated + position from the first and computes the correct delta relative to the + actual current position. + """ + move_deltas = [] + original_move_wheel = wheel._move_wheel + + def recording_move_wheel(wid, delta): + move_deltas.append(delta) + original_move_wheel(wid, delta) + + wheel._move_wheel = recording_move_wheel + + barrier = threading.Barrier(2) + errors = [] + + def move_to(pos): + try: + barrier.wait(timeout=2) + wheel.set_filter_wheel_position({1: pos}) + except Exception as e: + errors.append(e) + + t1 = threading.Thread(target=move_to, args=(5,)) + t2 = threading.Thread(target=move_to, args=(3,)) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert not t1.is_alive(), "Thread 1 did not finish (possible deadlock)" + assert not t2.is_alive(), "Thread 2 did not finish (possible deadlock)" + assert not errors, f"Threads raised: {errors}" + # Final tracked position must equal the last physical move target + final_pos = wheel.get_filter_wheel_position()[1] + assert final_pos in (3, 5), f"Position tracking corrupted: {final_pos}" + + # With the lock, the second move should compute its delta from the first + # move's result, not from the original position 1. So the sum of deltas + # must equal (final_pos - 1) * step_size, regardless of execution order. + config = wheel._configs[1] + step_size = SCREW_PITCH_W_MM / (config.max_index - config.min_index + 1) + expected_total_delta = (final_pos - 1) * step_size + actual_total_delta = sum(move_deltas) + assert abs(actual_total_delta - expected_total_delta) < 1e-9, ( + f"Delta mismatch: moves summed to {actual_total_delta}, " + f"but position {final_pos} requires {expected_total_delta}" + ) + + def test_home_during_move_serializes(self, wheel): + """home() must not run concurrently with a move.""" + wheel._positions[1] = 4 + call_order = [] + + original_home_w = wheel.microcontroller.home_w + original_move_w = wheel.microcontroller.move_w_usteps + + def tracked_home_w(*a, **kw): + call_order.append("home_start") + original_home_w(*a, **kw) + call_order.append("home_end") + + def tracked_move_w(usteps): + call_order.append("move_start") + original_move_w(usteps) + call_order.append("move_end") + + wheel.microcontroller.home_w = tracked_home_w + wheel.microcontroller.move_w_usteps = tracked_move_w + + barrier = threading.Barrier(2) + + def do_home(): + barrier.wait(timeout=2) + wheel.home(1) + + def do_move(): + barrier.wait(timeout=2) + wheel.set_filter_wheel_position({1: 6}) + + t1 = threading.Thread(target=do_home) + t2 = threading.Thread(target=do_move) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + assert not t1.is_alive(), "Home thread did not finish (possible deadlock)" + assert not t2.is_alive(), "Move thread did not finish (possible deadlock)" + + # Verify operations did not interleave. + # home() calls home_w (home_start/home_end) then _move_wheel for offset (move_start/move_end). + # set_filter_wheel_position calls _move_wheel (move_start/move_end). + # With the lock, the two valid orderings are: + # home first: [home_start, home_end, move_start, move_end, move_start, move_end] + # move first: [move_start, move_end, home_start, home_end, move_start, move_end] + assert call_order in ( + ["home_start", "home_end", "move_start", "move_end", "move_start", "move_end"], + ["move_start", "move_end", "home_start", "home_end", "move_start", "move_end"], + ), f"Operations interleaved: {call_order}" + # home-first -> move sets final pos 6; move-first -> home resets to 1 + assert wheel.get_filter_wheel_position()[1] in (1, 6)