Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion software/control/ndviewer_light
104 changes: 63 additions & 41 deletions software/squid/filter_wheel_controller/cephla.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
import time
from typing import List, Dict, Optional, Union

Expand Down Expand Up @@ -41,6 +42,11 @@ def __init__(
raise Exception("Error, microcontroller is needed by the SquidFilterWheel")

self.microcontroller = microcontroller
# Serializes all position-affecting operations (move, home, step) so that
# the read-position → send-move → wait → update-position sequence is atomic.
# Without this, concurrent calls from the GUI thread and acquisition worker
# can read stale positions and issue wrong relative moves.
self._lock = threading.Lock()

# Convert single config to dict format for uniform handling
if isinstance(configs, SquidFilterWheelConfig):
Expand Down Expand Up @@ -91,7 +97,7 @@ def _configure_wheel(self, wheel_id: int, config: SquidFilterWheelConfig):
self.microcontroller.turn_on_stage_pid(axis, ENABLE_PID_W)

def _move_wheel(self, wheel_id: int, delta: float):
"""Move a specific wheel by delta distance.
"""Move a specific wheel by delta distance. Caller must hold self._lock.

Args:
wheel_id: The ID of the wheel to move.
Expand All @@ -113,6 +119,8 @@ def _move_wheel(self, wheel_id: int, delta: float):
def _move_to_position(self, wheel_id: int, target_pos: int):
"""Move wheel to target position with automatic re-home on failure.

Caller must hold self._lock.

If the movement times out (e.g., motor stall), this method will:
1. Log a warning
2. Re-home the wheel to re-synchronize position tracking
Expand Down Expand Up @@ -140,9 +148,15 @@ def _move_to_position(self, wheel_id: int, target_pos: int):
self.microcontroller.wait_till_operation_is_completed()
self._positions[wheel_id] = target_pos
except TimeoutError:
_log.warning(f"Filter wheel {wheel_id} movement timed out. " f"Re-homing to re-sync position tracking...")
# Re-home to re-synchronize position tracking
self._home_wheel(wheel_id)
_log.warning(f"Filter wheel {wheel_id} movement timed out. Re-homing to re-sync position tracking...")
try:
self._home_wheel(wheel_id)
except Exception as rehome_err:
_log.error(
f"Filter wheel {wheel_id} re-home also failed: {rehome_err}. "
f"Position tracking unreliable. Hardware may need attention."
)
raise

# Retry the movement (position is now at min_index after homing)
current_pos = self._positions[wheel_id]
Expand All @@ -154,16 +168,14 @@ def _move_to_position(self, wheel_id: int, target_pos: int):
_log.info(f"Filter wheel {wheel_id} recovery successful, now at position {target_pos}")
except TimeoutError:
_log.error(
f"Filter wheel {wheel_id} movement failed even after re-home. " f"Hardware may need attention."
f"Filter wheel {wheel_id} movement failed even after re-home. "
f"Tracked position ({self._positions[wheel_id]}) may not reflect "
f"actual physical position. Hardware may need attention."
)
raise

def _home_wheel(self, wheel_id: int):
"""Home a specific wheel.

Args:
wheel_id: The ID of the wheel to home.
"""
"""Home a specific wheel. Caller must hold self._lock."""
config = self._configs[wheel_id]
motor_slot = config.motor_slot_index

Expand All @@ -174,14 +186,26 @@ def _home_wheel(self, wheel_id: int):
else:
raise ValueError(f"Unsupported motor_slot_index: {motor_slot}")

# Wait for homing to complete (needs longer timeout)
self.microcontroller.wait_till_operation_is_completed(15)
try:
self.microcontroller.wait_till_operation_is_completed(15)
except Exception as e:
# Physical position is unknown — reset tracking to min_index as
# best guess so subsequent moves don't use a stale value.
self._positions[wheel_id] = config.min_index
_log.error(f"Filter wheel {wheel_id} homing failed: {e}. Position tracking reset to {config.min_index}.")
raise

# Move to offset position
self._move_wheel(wheel_id, config.offset)
self.microcontroller.wait_till_operation_is_completed()
try:
self.microcontroller.wait_till_operation_is_completed()
except Exception as e:
# Homed but offset move failed — at physical zero, not offset.
self._positions[wheel_id] = config.min_index
_log.error(
f"Filter wheel {wheel_id} offset move failed after homing: {e}. Position reset to {config.min_index}."
)
raise

# Reset position tracking
self._positions[wheel_id] = config.min_index

def initialize(self, filter_wheel_indices: List[int]):
Expand Down Expand Up @@ -225,14 +249,14 @@ def home(self, index: Optional[int] = None):
Args:
index: Specific wheel index to home. If None, homes all configured wheels.
"""
if index is not None:
if index not in self._configs:
raise ValueError(f"Filter wheel index {index} not found")
self._home_wheel(index)
else:
# Home all wheels
for wheel_id in self._configs.keys():
self._home_wheel(wheel_id)
with self._lock:
if index is not None:
if index not in self._configs:
raise ValueError(f"Filter wheel index {index} not found")
self._home_wheel(index)
else:
for wheel_id in self._configs.keys():
self._home_wheel(wheel_id)

def _step_position(self, wheel_id: int, direction: int):
"""Move position by one step in the given direction.
Expand All @@ -244,11 +268,12 @@ def _step_position(self, wheel_id: int, direction: int):
if wheel_id not in self._configs:
raise ValueError(f"Filter wheel index {wheel_id} not found")

config = self._configs[wheel_id]
current_pos = self._positions[wheel_id]
new_pos = current_pos + direction

if config.min_index <= new_pos <= config.max_index:
config = self._configs[wheel_id] # _configs is immutable after __init__
with self._lock:
current_pos = self._positions[wheel_id]
new_pos = current_pos + direction
if not (config.min_index <= new_pos <= config.max_index):
return
self._move_to_position(wheel_id, new_pos)

def next_position(self, wheel_id: int = 1):
Expand All @@ -274,23 +299,25 @@ def set_filter_wheel_position(self, positions: Dict[int, int]):
positions: Dict mapping wheel_id -> target position.
Position values are 1-indexed (typically 1-8).
"""
for wheel_id, pos in positions.items():
if wheel_id not in self._configs:
raise ValueError(f"Filter wheel index {wheel_id} not found")
with self._lock:
for wheel_id, pos in positions.items():
if wheel_id not in self._configs:
raise ValueError(f"Filter wheel index {wheel_id} not found")

config = self._configs[wheel_id]
if pos not in range(config.min_index, config.max_index + 1):
raise ValueError(f"Filter wheel {wheel_id} position {pos} is out of range")
config = self._configs[wheel_id]
if pos not in range(config.min_index, config.max_index + 1):
raise ValueError(f"Filter wheel {wheel_id} position {pos} is out of range")

self._move_to_position(wheel_id, pos)
self._move_to_position(wheel_id, pos)

def get_filter_wheel_position(self) -> Dict[int, int]:
"""Get current positions of all configured wheels.

Returns:
Dict mapping wheel_id -> current position.
"""
return dict(self._positions)
with self._lock:
return dict(self._positions)

def set_delay_offset_ms(self, delay_offset_ms: float):
"""Set delay offset (not used by SQUID filter wheel)."""
Expand All @@ -311,8 +338,3 @@ def get_delay_ms(self) -> Optional[float]:
def close(self):
"""Close the filter wheel controller (no-op for SQUID)."""
pass
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR removes the move_w(self, delta: float) backward-compatibility method at the end of the class. That’s an API-breaking change for any existing callers relying on it. If this is still needed externally, consider keeping it as a thin wrapper (possibly marked deprecated) or provide a migration path in the same PR.

Suggested change
pass
pass
def move_w(self, delta: float, wheel_id: int = 1):
"""Deprecated backward-compatibility wrapper for moving the wheel.
This method exists for API compatibility with older code.
Prefer using :meth:`next_position`, :meth:`previous_position`,
or :meth:`set_filter_wheel_position` instead.
Args:
delta: Number of steps to move. Positive values move forward,
negative values move backward. Any fractional component is
discarded.
wheel_id: The wheel to move. Defaults to 1 for backward
compatibility with older single-wheel usage.
"""
# Log a deprecation warning to aid migration off this method.
_log.warning(
"SquidFilterWheel.move_w() is deprecated; use next_position(), "
"previous_position(), or set_filter_wheel_position() instead."
)
# No movement requested.
steps = int(delta)
if steps == 0:
return
direction = 1 if steps > 0 else -1
for _ in range(abs(steps)):
self._step_position(wheel_id, direction)

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Claude Code] Skipped - grep confirms zero callers of move_w() in the entire codebase. The suggested replacement also changes semantics (interprets delta as step count rather than mm distance). Adding a deprecated wrapper for dead code adds maintenance burden with no benefit.


# Backward compatibility methods
def move_w(self, delta: float):
"""Move the first wheel by delta. For backward compatibility."""
self._move_wheel(1, delta)
126 changes: 126 additions & 0 deletions software/tests/squid/test_filter_wheel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import threading
import time
from unittest.mock import MagicMock, patch

import pytest

import squid.config
import squid.filter_wheel_controller.utils
from control._def import SCREW_PITCH_W_MM
from squid.config import FilterWheelConfig, FilterWheelControllerVariant, SquidFilterWheelConfig
from squid.filter_wheel_controller.cephla import SquidFilterWheel

Expand Down Expand Up @@ -118,3 +121,126 @@ def test_normal_init_configures_encoder_pid(self, mock_microcontroller, squid_co
mock_microcontroller.set_pid_arguments.assert_called_once()
mock_microcontroller.configure_stage_pid.assert_called_once()
mock_microcontroller.turn_on_stage_pid.assert_called_once()


class TestSquidFilterWheelThreadSafety:
"""Tests that concurrent filter wheel operations don't corrupt position tracking."""

@pytest.fixture
def wheel(self):
mcu = MagicMock()
# Simulate real MCU latency so the critical section has a wide enough
# window for threads to actually interleave without the lock.
mcu.wait_till_operation_is_completed.side_effect = lambda *a, **kw: time.sleep(0.02)
config = SquidFilterWheelConfig(
max_index=8, min_index=1, offset=0.008, motor_slot_index=3, transitions_per_revolution=4000
)
return SquidFilterWheel(mcu, config, skip_init=True)

def test_concurrent_moves_serialize(self, wheel):
"""Two threads calling set_filter_wheel_position must not corrupt tracking.

Without the lock, both threads read current_pos=1, compute their deltas
relative to 1, and issue overlapping moves that leave the physical wheel
at the wrong position. With the lock the second thread sees the updated
position from the first and computes the correct delta relative to the
actual current position.
"""
move_deltas = []
original_move_wheel = wheel._move_wheel

def recording_move_wheel(wid, delta):
move_deltas.append(delta)
original_move_wheel(wid, delta)

wheel._move_wheel = recording_move_wheel

barrier = threading.Barrier(2)
errors = []

def move_to(pos):
try:
barrier.wait(timeout=2)
wheel.set_filter_wheel_position({1: pos})
except Exception as e:
errors.append(e)

t1 = threading.Thread(target=move_to, args=(5,))
t2 = threading.Thread(target=move_to, args=(3,))
t1.start()
t2.start()
t1.join(timeout=5)
t2.join(timeout=5)

assert not t1.is_alive(), "Thread 1 did not finish (possible deadlock)"
assert not t2.is_alive(), "Thread 2 did not finish (possible deadlock)"
assert not errors, f"Threads raised: {errors}"
# Final tracked position must equal the last physical move target
final_pos = wheel.get_filter_wheel_position()[1]
assert final_pos in (3, 5), f"Position tracking corrupted: {final_pos}"

# With the lock, the second move should compute its delta from the first
# move's result, not from the original position 1. So the sum of deltas
# must equal (final_pos - 1) * step_size, regardless of execution order.
config = wheel._configs[1]
step_size = SCREW_PITCH_W_MM / (config.max_index - config.min_index + 1)
expected_total_delta = (final_pos - 1) * step_size
actual_total_delta = sum(move_deltas)
assert abs(actual_total_delta - expected_total_delta) < 1e-9, (
f"Delta mismatch: moves summed to {actual_total_delta}, "
f"but position {final_pos} requires {expected_total_delta}"
)

def test_home_during_move_serializes(self, wheel):
"""home() must not run concurrently with a move."""
wheel._positions[1] = 4
call_order = []

original_home_w = wheel.microcontroller.home_w
original_move_w = wheel.microcontroller.move_w_usteps

def tracked_home_w(*a, **kw):
call_order.append("home_start")
original_home_w(*a, **kw)
call_order.append("home_end")

def tracked_move_w(usteps):
call_order.append("move_start")
original_move_w(usteps)
call_order.append("move_end")

wheel.microcontroller.home_w = tracked_home_w
wheel.microcontroller.move_w_usteps = tracked_move_w

barrier = threading.Barrier(2)

def do_home():
barrier.wait(timeout=2)
wheel.home(1)

def do_move():
barrier.wait(timeout=2)
wheel.set_filter_wheel_position({1: 6})

t1 = threading.Thread(target=do_home)
t2 = threading.Thread(target=do_move)
t1.start()
t2.start()
t1.join(timeout=10)
t2.join(timeout=10)

assert not t1.is_alive(), "Home thread did not finish (possible deadlock)"
assert not t2.is_alive(), "Move thread did not finish (possible deadlock)"

# Verify operations did not interleave.
# home() calls home_w (home_start/home_end) then _move_wheel for offset (move_start/move_end).
# set_filter_wheel_position calls _move_wheel (move_start/move_end).
# With the lock, the two valid orderings are:
# home first: [home_start, home_end, move_start, move_end, move_start, move_end]
# move first: [move_start, move_end, home_start, home_end, move_start, move_end]
assert call_order in (
["home_start", "home_end", "move_start", "move_end", "move_start", "move_end"],
["move_start", "move_end", "home_start", "home_end", "move_start", "move_end"],
), f"Operations interleaved: {call_order}"
# home-first -> move sets final pos 6; move-first -> home resets to 1
assert wheel.get_filter_wheel_position()[1] in (1, 6)
Loading