diff --git a/src/batdetect2/utils/visualize.py b/src/batdetect2/utils/visualize.py index b7f889a6..963ea6f4 100644 --- a/src/batdetect2/utils/visualize.py +++ b/src/batdetect2/utils/visualize.py @@ -52,9 +52,9 @@ def __init__( self.spec_slices = spec_slices self.call_info = call_info # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True) - self.labels = np.zeros(len(call_info), dtype=np.int) + self.labels = np.zeros(len(call_info), dtype=int) self.annotated = np.zeros( - self.labels.shape[0], dtype=np.int + self.labels.shape[0], dtype=int ) # can populate this with 1's where we have labels self.labels_cols = [ colors[self.labels[ii]] for ii in range(len(self.labels)) diff --git a/tests/test_utils/test_visualize.py b/tests/test_utils/test_visualize.py new file mode 100644 index 00000000..de695495 --- /dev/null +++ b/tests/test_utils/test_visualize.py @@ -0,0 +1,22 @@ +import numpy as np + +from batdetect2.utils.visualize import InteractivePlotter + + +def test_interactive_plotter_init_builds_integer_label_arrays(): + feats_ds = np.zeros((2, 2)) + spec_slices = [np.zeros((4, 6)), np.zeros((4, 8))] + call_info = [{"class": "a"}, {"class": "b"}] + + plotter = InteractivePlotter( + feats_ds=feats_ds, + feats=feats_ds, + spec_slices=spec_slices, + call_info=call_info, + freq_lims=[0, 1], + allow_training=False, + ) + + assert plotter.labels.shape == (2,) + assert np.issubdtype(plotter.labels.dtype, np.integer) + assert np.issubdtype(plotter.annotated.dtype, np.integer)