Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ipfx/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def get_test_epoch(i,hz):
return None

if len(di_idx) == 1:
raise Exception("Cannot detect and end to the test pulse")
print("Cannot detect and end to the test pulse")
return None

start_pulse_idx = di_idx[0] + 1 # shift by one to compensate for diff()
end_pulse_idx = di_idx[1]
Expand Down
14 changes: 14 additions & 0 deletions ipfx/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ def process(self, t, v, i):
# Spike list and thresholds have been refined - now find other features
upstrokes = spkd.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt)
troughs = spkd.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)
not_nan = np.logical_not(np.logical_or(np.isnan(troughs), np.isnan(peaks)))
not_nan_idx = np.argwhere(not_nan)
valid_peak_tr_pair = not_nan_idx[peaks[not_nan_idx] < troughs[not_nan_idx]]

troughs = troughs[valid_peak_tr_pair]
peaks = peaks[valid_peak_tr_pair]
upstrokes = upstrokes[valid_peak_tr_pair]
clipped = clipped[valid_peak_tr_pair]
thresholds = thresholds[valid_peak_tr_pair]

if not thresholds.size:
# Save time if no spikes detected
return DataFrame()

downstrokes = spkd.find_downstroke_indexes(v, t, peaks, troughs, clipped, dvdt=dvdt)
trough_details, clipped = spkf.analyze_trough_details(v, t, thresholds, peaks, clipped, self.end,
dvdt=dvdt)
Expand Down
14 changes: 7 additions & 7 deletions ipfx/plot_qc_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def plot_single_ap_values(data_set, sweep_numbers, lims_features, sweep_features
else:
rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["sweep_number"]
rheo_spike = get_spikes(sweep_features, rheo_sn)[0]
voltages = [ rheo_spike[f] for f in voltage_features]
times = [ rheo_spike[f] for f in time_features]
voltages = [rheo_spike[f] for f in voltage_features if not np.isnan(rheo_spike[f])]
times = [rheo_spike[f] for f in time_features if not np.isnan(rheo_spike[f])]

plt.figure(figs[0].number)
plt.scatter(range(len(voltages)), voltages, color='gray')
Expand Down Expand Up @@ -123,13 +123,13 @@ def plot_single_ap_values(data_set, sweep_numbers, lims_features, sweep_features
if nspikes:
if type_name != "long_square" and nspikes:

voltages = [spikes[0][f] for f in voltage_features]
times = [spikes[0][f] for f in time_features]
voltages = [spikes[0][f] for f in voltage_features if not np.isnan(spikes[0][f])]
times = [spikes[0][f] for f in time_features if not np.isnan(spikes[0][f])]
else:
rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["sweep_number"]
rheo_spike = get_spikes(sweep_features, rheo_sn)[0]
voltages = [rheo_spike[f] for f in voltage_features]
times = [rheo_spike[f] for f in time_features]
voltages = [rheo_spike[f] for f in voltage_features if not np.isnan(rheo_spike[f])]
times = [rheo_spike[f] for f in time_features if not np.isnan(rheo_spike[f])]

plt.scatter(times, voltages, color='red', zorder=20)

Expand All @@ -144,7 +144,7 @@ def plot_single_ap_values(data_set, sweep_numbers, lims_features, sweep_features

if type_name == "ramp":
if nspikes:
plt.xlim(spikes[0]["threshold_t"] - 0.002, spikes[0]["fast_trough_t"] + 0.01)
plt.xlim(times[0] - 0.002, times[-2] + 0.01)
elif type_name == "short_square":
plt.xlim(stim_start_shifted - 0.002, stim_start_shifted + stim_dur + 0.01)
elif type_name == "long_square":
Expand Down
4 changes: 4 additions & 0 deletions ipfx/spike_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_index

# Find the peak in a window twice the size of our allowed window
spike = spike_indexes[i]
if t[spike] + 2 * max_interval > t[-1]:
drop_spikes.append(i)
continue

t_0 = tsu.find_time_index(t, t[spike] + 2 * max_interval)
new_peak = np.argmax(v[spike:t_0]) + spike

Expand Down
4 changes: 3 additions & 1 deletion ipfx/spike_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def find_widths(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None)
# Some spikes in burst may have deep trough but short height, so can't use same
# definition for width

width_levels[width_levels < v[spike_indexes]] = thresh_to_peak_levels[width_levels < v[spike_indexes]]
used_indexes = np.argwhere(use_indexes)
change_indexes = used_indexes[width_levels[use_indexes] < v[spike_indexes[use_indexes]]]
width_levels[change_indexes] = thresh_to_peak_levels[change_indexes]

width_starts = np.zeros_like(trough_indexes) * np.nan
width_starts[use_indexes] = np.array([pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0] if
Expand Down
9 changes: 8 additions & 1 deletion ipfx/stim_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ def get_stim_characteristics(i, t, test_pulse=True):

di = np.diff(i)
di_idx = np.flatnonzero(di) # != 0
start_idx_idx = 2 if test_pulse else 0 # skip the first up/down (test pulse) if present

if test_pulse:
if len(di_idx) % 2 == 1: # skip the truncated test pulse
start_idx_idx = 1
else:
start_idx_idx = 2
else:
start_idx_idx = 0

if len(di_idx[start_idx_idx:]) == 0: # if no stimulus is found
return None, None, 0.0, None, None
Expand Down
9 changes: 5 additions & 4 deletions ipfx/subthresh_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def voltage_deflection(t, v, i, start, end, deflect_type=None):
if deflect_type is None:
if i is not None:
halfway_index = tsu.find_time_index(t, (end - start) / 2. + start)
if i[halfway_index] >= 0:
if (i[halfway_index] - i[start_index - 1]) >= 0:
deflect_type = "max"
else:
deflect_type = "min"
Expand Down Expand Up @@ -156,9 +156,10 @@ def input_resistance(t_set, i_set, v_set, start, end, baseline_interval=0.1):
v_vals = []
i_vals = []
for t, i, v, in zip(t_set, i_set, v_set):
v_peak, min_index = voltage_deflection(t, v, i, start, end, 'min')
v_vals.append(v_peak)
i_vals.append(i[min_index])
v_stim_avg = tsu.average_voltage(v, t, start, end)
i_stim_avg = tsu.average_voltage(i, t, start, end)
v_vals.append(v_stim_avg)
i_vals.append(i_stim_avg)

v = np.array(v_vals)
i = np.array(i_vals)
Expand Down
25 changes: 14 additions & 11 deletions ipfx/time_series_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,25 @@ def calculate_dvdt(v, t, filter=None):
dvdt : numpy array of time-derivative of voltage (V/s = mV/ms)
"""

if has_fixed_dt(t) and filter:
delta_t = t[1] - t[0]
sample_freq = 1. / delta_t
filt_coeff = (filter * 1e3) / (sample_freq / 2.) # filter kHz -> Hz, then get fraction of Nyquist frequency
if filt_coeff < 0 or filt_coeff >= 1:
raise ValueError("bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(filt_coeff, sample_freq / 1e3, filter))
b, a = signal.bessel(4, filt_coeff, "low")
v_filt = signal.filtfilt(b, a, v, axis=0)
dv = np.diff(v_filt)
else:
try:
if has_fixed_dt(t) and filter:
delta_t = t[1] - t[0]
sample_freq = 1. / delta_t
filt_coeff = (filter * 1e3) / (sample_freq / 2.) # filter kHz -> Hz, then get fraction of Nyquist frequency
if filt_coeff < 0 or filt_coeff >= 1:
raise ValueError("bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(filt_coeff, sample_freq / 1e3, filter))
b, a = signal.bessel(4, filt_coeff, "low")
v_filt = signal.filtfilt(b, a, v, axis=0)
dv = np.diff(v_filt)
else:
dv = np.diff(v)
except ValueError:
dv = np.diff(v)

dt = np.diff(t)
dvdt = 1e-3 * dv / dt # in V/s = mV/ms

# some data sources, such as neuron, occasionally report
# some data sources, such as neuron, occasionally report
# duplicate timestamps, so we require that dt is not 0
return dvdt[np.fabs(dt) > sys.float_info.epsilon]

Expand Down
6 changes: 6 additions & 0 deletions tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def test_get_experiment_epoch(i, sampling_rate, expt_epoch):
None
),

# truncated test pulse
(
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
4,
None
),
]
)
def test_get_test_epoch(i, sampling_rate, test_epoch):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_spike_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,19 @@ def test_check_spikes_and_peaks():
assert np.allclose(new_peaks, peaks[1:])


def test_check_spikes_and_peaks_close_to_end():
t = np.arange(0, 2000) * 5e-6
v = np.zeros_like(t)
spikes = np.array([500])
peaks = np.array([1510])
upstrokes = np.array([700])
dvdt = np.ones_like(t)

new_spikes, new_peaks, new_upstrokes, clipped = spkd.check_thresholds_and_peaks(v, t, spikes, peaks, upstrokes, dvdt=dvdt)
assert np.allclose(new_spikes, spikes[:-1])
assert np.allclose(new_peaks, peaks[1:])


def test_thresholds(spike_test_pair):
data = spike_test_pair
t = data[:, 0]
Expand Down
15 changes: 14 additions & 1 deletion tests/test_spike_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_width_calculation(spike_test_pair):

expected_widths = np.array([0.000545, 0.000585])
assert np.allclose(
spkf.find_widths(v, t, spikes, peaks, troughs),
spkf.find_widths(v, t, spikes, peaks, troughs),
expected_widths
)

Expand All @@ -64,3 +64,16 @@ def test_width_calculation_too_many_troughs(spike_test_pair):

with pytest.raises(er.FeatureError):
spkf.find_widths(v, t, spikes, peaks, troughs)


def test_width_calculation_missing_troughs(spike_test_pair):
data = spike_test_pair
t = data[:, 0]
v = data[:, 1]
spikes = np.array([725, 3382])
peaks = np.array([812, 3478])
troughs = np.array([1089, np.nan])

found_width = spkf.find_widths(v, t, spikes, peaks, troughs)
assert np.isclose(found_width[0], 0.000545)
assert np.isnan(found_width[1])
5 changes: 5 additions & 0 deletions tests/test_stim_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
[0, 1, 1, 0],
False,
(1, 2, 1, 1, 2)
),
(
[1, 1, 0, 0, 0, 2, 2, 2, 2, 0, 0],
True,
(5, 4, 2, 5, 8)
)
]

Expand Down
34 changes: 34 additions & 0 deletions tests/test_subthresh_features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipfx.subthresh_features as subf
import numpy as np
import pytest


def test_input_resistance():
Expand Down Expand Up @@ -76,3 +77,36 @@ def test_time_constant_noise_acceptance():

tau = subf.time_constant(t, v, i, start=start, end=end)
assert np.isclose(actual_tau, tau, rtol=1e-3)


test_params = [
(
[-5, -5, -5, -2, 0, 1, 1, 1, -2 -5, -5],
[0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
3,
8,
(1, 5)
),
(
[-5, -5, -5, -7, -10, -12, -12, -12, -7 - 5, -5],
[0, 0, 0, -1, -1, -1, -1, -1, -1, 0, 0],
3,
8,
(-12, 5)
),
(
[-70, -70, -70, -50, -20, -30, -30, -30, -50, -70, -70],
[-70, -70, -70, -20, -20, -20, -20, -20, -20, -70, -70],
3,
8,
(-20, 4)
),
]


@pytest.mark.parametrize('v, i, start, end, deflection_result', test_params)
def test_voltage_deflection(v, i, start, end, deflection_result):
t = np.arange(0, 10)
deflection_v, deflection_idx = subf.voltage_deflection(t, v, i, start, end)
assert deflection_v == deflection_result[0]
assert deflection_idx == deflection_result[1]
8 changes: 8 additions & 0 deletions tests/test_time_series_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def test_dvdt_no_filter():
assert np.allclose(tsu.calculate_dvdt(v, t), np.diff(v) / np.diff(t))


def test_dvdt_with_filter_small_samplingrate():
t = np.array([0, 1, 2, 3]) / 10000.0
v = np.array([1, 1, 1, 1])

assert np.allclose(tsu.calculate_dvdt(v, t, 10), np.diff(v) / np.diff(t))


def test_fixed_dt():
t = [0, 1, 2, 3]
assert tsu.has_fixed_dt(t)
Expand All @@ -27,6 +34,7 @@ def test_fixed_dt():
t[0] -= 3.
assert not tsu.has_fixed_dt(t)


def test_flatnotnan():
a = [1, 10, 12, 17, 13, 4, 8, np.nan, np.nan]

Expand Down