Skip to content

Commit 45b3317

Browse files
authored
Yearly update
Updated some test files Improved some type hints Increase the min version of some dependencies
1 parent d9f749d commit 45b3317

13 files changed

Lines changed: 67 additions & 57 deletions

File tree

pyhdtoolkit/utils/decorators.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@
1313
import inspect
1414
import traceback
1515
import warnings
16-
from typing import TYPE_CHECKING
16+
from typing import TYPE_CHECKING, ParamSpec, TypeVar
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import Callable
2020

2121

22+
P = ParamSpec("P") # for params
23+
R = TypeVar("R") # for returns
24+
2225
# ----- Utility deprecation decorator ----- #
2326

2427

25-
def deprecated(message: str = "") -> Callable:
28+
def deprecated(message: str = "") -> Callable[[Callable[P, R]], Callable[P, R]]:
2629
"""
2730
Decorator to mark a function as deprecated. It will result in an
2831
informative `DeprecationWarning` being issued with the provided
@@ -49,22 +52,23 @@ def old_function():
4952
return "I am old!"
5053
"""
5154

52-
def decorator_wrapper(func):
55+
def decorator_wrapper(func: Callable[P, R]) -> Callable[P, R]:
56+
last_call_sources: set[str] = set()
57+
5358
@functools.wraps(func)
54-
def function_wrapper(*args, **kwargs):
59+
def function_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
5560
current_call_source = "|".join(traceback.format_stack(inspect.currentframe()))
56-
if current_call_source not in function_wrapper.last_call_source:
61+
62+
if current_call_source not in last_call_sources:
5763
warnings.warn(
58-
f"Function {func.__name__} is now deprecated and will be removed in a future release! {message}",
64+
f"Function {func.__name__} is now deprecated and will be removed in a future release! {message}", # ty:ignore[unresolved-attribute]
5965
category=DeprecationWarning,
6066
stacklevel=2,
6167
)
62-
function_wrapper.last_call_source.add(current_call_source)
68+
last_call_sources.add(current_call_source)
6369

6470
return func(*args, **kwargs)
6571

66-
function_wrapper.last_call_source = set()
67-
6872
return function_wrapper
6973

7074
return decorator_wrapper
@@ -73,7 +77,9 @@ def function_wrapper(*args, **kwargs):
7377
# ----- Utility JIT Compilation decorator ----- #
7478

7579

76-
def maybe_jit(func: Callable, **kwargs) -> Callable:
80+
# We type hint to specify we return a function with the same
81+
# signature as the input function.
82+
def maybe_jit(func: Callable[P, R], **kwargs) -> Callable[P, R]:
7783
"""
7884
.. versionadded:: 1.7.0
7985

pyproject.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ dependencies = [
5858
"numpy >= 2.0",
5959
"pandas >= 2.0",
6060
"matplotlib >=3.7",
61-
"scipy >= 1.6",
61+
"scipy >= 1.10",
6262
"tfs-pandas >= 3.8",
6363
"loguru < 1.0",
6464
"cpymad >= 1.16",
@@ -71,22 +71,22 @@ dependencies = [
7171
[project.optional-dependencies]
7272
test = [
7373
"pytest >= 8.0",
74-
"pytest-cov >= 5.0",
74+
"pytest-cov >= 6.0",
7575
"pytest-xdist >= 3.0",
7676
"numba >= 0.60.0",
7777
"flaky >= 3.5",
78-
"pytest-randomly >= 3.3",
78+
"pytest-randomly >= 3.10",
7979
"coverage[toml] >= 7.0",
8080
"pytest-mpl >= 0.14",
8181
]
8282
dev = [
83-
"ruff >= 0.5",
83+
"ruff >= 0.12",
8484
]
8585
docs = [
8686
"joblib >= 1.0",
87-
"Sphinx >= 7.0",
88-
"sphinx-rtd-theme >= 2.0",
89-
"sphinx-issues >= 4.0",
87+
"Sphinx >= 8.0",
88+
"sphinx-rtd-theme >= 3.0",
89+
"sphinx-issues >= 5.0",
9090
"sphinx_copybutton < 1.0",
9191
"sphinxcontrib-bibtex >= 2.4",
9292
"sphinx-design >= 0.6",
-88 Bytes
Binary file not shown.

tests/test_cpymadtools/test_lhc.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def test_rigidity_knob_fails_on_invalid_ir(_non_matched_lhc_madx, caplog):
487487
def test_rigidity_knob_fails_on_invalid_side(caplog, _non_matched_lhc_madx):
488488
madx = _non_matched_lhc_madx
489489

490-
with pytest.raises(ValueError, match="Invalid value for parameter 'side'."):
490+
with pytest.raises(ValueError, match=r"Invalid value for parameter 'side'."):
491491
apply_lhc_rigidity_waist_shift_knob(madx, 1, 1, "invalid")
492492

493493
for record in caplog.records:
@@ -712,10 +712,10 @@ def test_get_bpms_coupling_rdts(_non_matched_lhc_madx, _reference_twiss_rdts):
712712

713713
twiss_with_rdts = get_lhc_bpms_twiss_and_rdts(madx)
714714
# We separate the complex components to compare to the reference
715-
twiss_with_rdts["F1001R"] = twiss_with_rdts.F1001.apply(np.real)
716-
twiss_with_rdts["F1001I"] = twiss_with_rdts.F1001.apply(np.imag)
717-
twiss_with_rdts["F1010R"] = twiss_with_rdts.F1010.apply(np.real)
718-
twiss_with_rdts["F1010I"] = twiss_with_rdts.F1010.apply(np.imag)
715+
twiss_with_rdts["F1001R"] = twiss_with_rdts.F1001.apply(np.real) # ty:ignore[unresolved-attribute]
716+
twiss_with_rdts["F1001I"] = twiss_with_rdts.F1001.apply(np.imag) # ty:ignore[unresolved-attribute]
717+
twiss_with_rdts["F1010R"] = twiss_with_rdts.F1010.apply(np.real) # ty:ignore[unresolved-attribute]
718+
twiss_with_rdts["F1010I"] = twiss_with_rdts.F1010.apply(np.imag) # ty:ignore[unresolved-attribute]
719719
twiss_with_rdts = twiss_with_rdts.drop(columns=["F1001", "F1010"]).set_index("NAME")
720720
# Only care to compare the coupling RDTs columns
721721
twiss_with_rdts = twiss_with_rdts.loc[:, ["F1001R", "F1001I", "F1010R", "F1010I"]]
@@ -727,8 +727,8 @@ def test_get_bpms_coupling_rdts(_non_matched_lhc_madx, _reference_twiss_rdts):
727727
def test_k_modulation(_non_matched_lhc_madx, _reference_kmodulation):
728728
madx = _non_matched_lhc_madx
729729
results = do_kmodulation(madx)
730-
assert all(var == 0 for var in results.ERRTUNEX)
731-
assert all(var == 0 for var in results.ERRTUNEY)
730+
assert np.all(results.ERRTUNEX.to_numpy() == 0) # ty:ignore[unresolved-attribute]
731+
assert np.all(results.ERRTUNEY.to_numpy() == 0) # ty:ignore[unresolved-attribute]
732732

733733
reference = tfs.read(_reference_kmodulation)
734734
assert_frame_equal(results.convert_dtypes(), reference.convert_dtypes()) # avoid dtype comparison error on 0 cols
@@ -841,7 +841,7 @@ def test_lhc_run3_setup_context_manager_raises_on_wrong_b4_conditions():
841841
@pytest.mark.skipif(not (TESTS_DIR.parent / "acc-models-lhc").is_dir(), reason="acc-models-lhc not found")
842842
def test_lhc_run3_setup_context_manager_raises_on_wrong_run_value():
843843
with pytest.raises( # noqa: SIM117
844-
NotImplementedError, match="This setup is only possible for Run 2 and Run 3 configurations."
844+
NotImplementedError, match=r"This setup is only possible for Run 2 and Run 3 configurations."
845845
): # using b4 with beam1 setup crashes
846846
with LHCSetup(run=1, opticsfile="R2022a_A30cmC30cmA10mL200cm.madx") as madx: # noqa: F841
847847
pass

tests/test_plotting/test_aperture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,5 @@ def test_plot_physical_apertures_ir5_collision_vertical(_collision_aperture_tole
8787
def test_plot_physical_apertures_raises_on_wrong_plane():
8888
madx = Madx(stdout=False)
8989

90-
with pytest.raises(ValueError, match="Invalid 'plane' argument."):
90+
with pytest.raises(ValueError, match=r"Invalid 'plane' argument."):
9191
plot_physical_apertures(madx, plane="invalid")

tests/test_plotting/test_envelope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
def test_plot_enveloppe_raises_on_wrong_plane():
1919
madx = Madx(stdout=False)
2020

21-
with pytest.raises(ValueError, match="Invalid 'plane' argument."):
21+
with pytest.raises(ValueError, match=r"Invalid 'plane' argument."):
2222
plot_beam_envelope(madx, "lhcb1", plane="invalid")
2323

2424

tests/test_plotting/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ def test_confidence_ellipse_fails_on_mismatched_dimensions():
4040

4141

4242
def test_default_sbs_coupling_label_raises_on_wrong_component():
43-
with pytest.raises(ValueError, match="Invalid component for coupling RDT."):
43+
with pytest.raises(ValueError, match=r"Invalid component for coupling RDT."):
4444
_determine_default_sbs_coupling_ylabel(rdt="f1001", component="NONEXISTANT")

tests/test_plotting/test_phasespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_plot_courant_snyder_phase_space_wrong_plane_input():
8787
match_cas3(madx)
8888
x_coords_stable, px_coords_stable = np.array([]), np.array([]) # no need for tracking
8989

90-
with pytest.raises(ValueError, match="Invalid 'plane' argument."):
90+
with pytest.raises(ValueError, match=r"Invalid 'plane' argument."):
9191
plot_courant_snyder_phase_space(madx, x_coords_stable, px_coords_stable, plane="invalid_plane")
9292

9393

@@ -97,7 +97,7 @@ def test_plot_courant_snyder_phase_space_colored_wrong_plane_input():
9797
madx.input(BASE_LATTICE)
9898
match_cas3(madx)
9999
x_coords_stable, px_coords_stable = np.array([]), np.array([]) # no need for tracking
100-
with pytest.raises(ValueError, match="Invalid 'plane' argument."):
100+
with pytest.raises(ValueError, match=r"Invalid 'plane' argument."):
101101
plot_courant_snyder_phase_space_colored(madx, x_coords_stable, px_coords_stable, plane="invalid_plane")
102102

103103

tests/test_plotting/test_plotting_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_coupling_ylabel(f1001, f1010, abs_, real, imag):
5454

5555
@pytest.mark.parametrize("rdt", ["invalid", "F1111", "nope"])
5656
def test_coupling_ylabel_raises_on_invalid_rdt(rdt):
57-
with pytest.raises(ValueError, match="Invalid RDT for coupling plot."):
57+
with pytest.raises(ValueError, match=r"Invalid RDT for coupling plot."):
5858
_determine_default_sbs_coupling_ylabel(rdt, "abs")
5959

6060

@@ -69,7 +69,7 @@ def test_phase_ylabel(plane):
6969

7070
@pytest.mark.parametrize("plane", ["a", "Fb1", "nope", "not a plane"])
7171
def test_phase_ylabel_raises_on_invalid_plane(plane):
72-
with pytest.raises(ValueError, match="Invalid plane for phase plot."):
72+
with pytest.raises(ValueError, match=r"Invalid plane for phase plot."):
7373
_determine_default_sbs_phase_ylabel(plane)
7474

7575

tests/test_plotting/test_sbs_phase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_plot_both_beams(sbs_phasex, sbs_phasey, sbs_model_b2):
4242

4343
@pytest.mark.parametrize("wrongplane", ["not", "accepted", "incorrect", ""])
4444
def test_plot_phase_segment_raises_on_wrong_plane(wrongplane, sbs_phasex, sbs_model_b2):
45-
with pytest.raises(ValueError, match="Invalid 'plane' argument."):
45+
with pytest.raises(ValueError, match=r"Invalid 'plane' argument."):
4646
plot_phase_segment(segment_df=sbs_phasex, model_df=sbs_model_b2, plane=wrongplane)
4747

4848

0 commit comments

Comments
 (0)