-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbrainwave_processor.py
More file actions
193 lines (155 loc) · 6.65 KB
/
brainwave_processor.py
File metadata and controls
193 lines (155 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import numpy as np
from scipy import signal
from scipy.signal import butter, filtfilt, iirnotch
import warnings
# NOTE(review): a bare filterwarnings('ignore') with no category/module filter
# silences *every* warning for the whole process as a side effect of importing
# this module -- including NumPy/SciPy numerical RuntimeWarnings (e.g. division
# by zero) and DeprecationWarnings. Consider scoping suppression with
# warnings.catch_warnings() / np.errstate() where it is actually needed.
warnings.filterwarnings('ignore')
class BrainwaveProcessor:
    """
    Process raw EEG/brainwave data: filtering, artifact removal, normalization,
    spectral band-power features and channel connectivity.

    Input is array-like laid out as ``(samples, channels)``; a 1-D array is
    treated as a single channel. The filtering, artifact-removal and
    normalization methods return a new array with the same shape as their
    input. Non-floating input (e.g. integer ADC counts) is promoted to
    ``float64`` before any arithmetic so results are never truncated.
    """

    def __init__(self):
        # Sampling rate (Hz) of the most recent ``process_eeg`` call.
        self.sampling_rate = 256  # Default sampling rate in Hz

    @staticmethod
    def _as_float_array(data):
        """Return ``data`` as an ndarray with a floating dtype (non-float -> float64)."""
        arr = np.asarray(data)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    @staticmethod
    def _as_channel_matrix(arr):
        """Return ``arr`` reshaped to ``(samples, channels)``; 1-D input becomes one channel."""
        return arr.reshape(arr.shape[0], -1)

    def process_eeg(self, data, sampling_rate=256, apply_bandpass=True,
                    low_freq=1.0, high_freq=50.0, apply_notch=True,
                    notch_freq=50, remove_artifacts=True, artifact_threshold=100,
                    normalize=True):
        """
        Run the complete EEG processing pipeline.

        Enabled steps run in this order:
        bandpass -> notch -> amplitude-artifact interpolation -> z-score.

        Args:
            data: Raw EEG data, array-like of shape (samples, channels) or (samples,)
            sampling_rate: Sampling frequency in Hz (also stored on ``self``)
            apply_bandpass: Whether to apply the bandpass filter
            low_freq: Low cutoff in Hz; a falsy value skips the bandpass step
            high_freq: High cutoff in Hz; a falsy value skips the bandpass step.
                A cutoff at or above Nyquist turns the step into a highpass.
            apply_notch: Whether to apply the notch filter
            notch_freq: Line-noise frequency in Hz (typically 50 or 60). The step
                is skipped when this is falsy or not below the Nyquist
                frequency, since no notch can be designed there.
            remove_artifacts: Whether to interpolate over large-amplitude samples
            artifact_threshold: Absolute amplitude threshold for artifact
                detection, in the data's units (e.g. μV); falsy skips the step
            normalize: Whether to z-score each channel

        Returns:
            Processed EEG data as a new ndarray with the same shape as ``data``.
        """
        # np.array copies, so the caller's data is never modified (and lists work).
        processed_data = np.array(data)
        self.sampling_rate = sampling_rate
        nyquist = sampling_rate / 2.0

        # Apply bandpass filter
        if apply_bandpass and low_freq and high_freq:
            processed_data = self.bandpass_filter(
                processed_data, low_freq, high_freq, sampling_rate
            )

        # Apply notch filter; e.g. a 50 Hz notch is impossible at a 100 Hz sampling rate.
        if apply_notch and notch_freq and notch_freq < nyquist:
            processed_data = self.notch_filter(
                processed_data, notch_freq, sampling_rate
            )

        # Remove artifacts
        if remove_artifacts and artifact_threshold:
            processed_data = self.remove_artifacts(
                processed_data, artifact_threshold
            )

        # Normalize data
        if normalize:
            processed_data = self.normalize_data(processed_data)

        return processed_data

    def bandpass_filter(self, data, low_freq, high_freq, sampling_rate, order=4):
        """
        Apply a zero-phase Butterworth bandpass filter along the sample axis.

        The filter is designed as second-order sections and applied forward and
        backward with ``sosfiltfilt``. This avoids the numerical problems of the
        (b, a) form at low normalized cutoffs.

        Cutoffs outside (0, Nyquist) degrade gracefully instead of raising:
        ``low_freq <= 0`` yields a lowpass, ``high_freq >= Nyquist`` yields a
        highpass, and both together return an unfiltered float copy.

        Args:
            data: Array-like of shape (samples, channels) or (samples,)
            low_freq: Low cutoff frequency in Hz
            high_freq: High cutoff frequency in Hz
            sampling_rate: Sampling frequency in Hz
            order: Butterworth prototype order (default 4)

        Returns:
            Filtered float ndarray with the same shape as ``data``.

        Raises:
            ValueError: if ``low_freq >= high_freq``, if ``low_freq`` is not below
                Nyquist, or (from SciPy) if the signal is too short for the
                filter's edge padding.
        """
        arr = self._as_float_array(data)
        nyquist = sampling_rate / 2.0
        low = low_freq / nyquist
        high = high_freq / nyquist

        if low >= high:
            raise ValueError(
                f"low_freq ({low_freq} Hz) must be below high_freq ({high_freq} Hz)"
            )
        if low >= 1:
            raise ValueError(
                f"low_freq ({low_freq} Hz) must be below the Nyquist frequency "
                f"({nyquist} Hz)"
            )

        has_low_edge = low > 0
        has_high_edge = high < 1
        if has_low_edge and has_high_edge:
            sos = butter(order, [low, high], btype='band', output='sos')
        elif has_high_edge:
            sos = butter(order, high, btype='low', output='sos')
        elif has_low_edge:
            sos = butter(order, low, btype='high', output='sos')
        else:
            # Band covers the whole representable spectrum: nothing to remove.
            return arr.copy()

        # axis=0 filters every channel at once and also handles 1-D input.
        return signal.sosfiltfilt(sos, arr, axis=0)

    def notch_filter(self, data, notch_freq, sampling_rate, quality_factor=30):
        """
        Apply a zero-phase IIR notch filter to remove line noise (50/60 Hz).

        Args:
            data: Array-like of shape (samples, channels) or (samples,)
            notch_freq: Frequency to remove, in Hz
            sampling_rate: Sampling frequency in Hz
            quality_factor: Notch quality factor Q (higher means a narrower notch)

        Returns:
            Filtered float ndarray with the same shape as ``data``.

        Raises:
            ValueError: if ``notch_freq`` is not strictly between 0 and Nyquist.
        """
        arr = self._as_float_array(data)
        nyquist = sampling_rate / 2.0
        if not 0 < notch_freq < nyquist:
            raise ValueError(
                f"notch_freq ({notch_freq} Hz) must lie strictly between 0 and "
                f"the Nyquist frequency ({nyquist} Hz)"
            )

        # iirnotch expects the frequency normalized to Nyquist (0 < w0 < 1).
        b, a = iirnotch(notch_freq / nyquist, quality_factor)
        return filtfilt(b, a, arr, axis=0)

    def remove_artifacts(self, data, threshold):
        """
        Replace samples whose absolute amplitude exceeds ``threshold``.

        Flagged samples are linearly interpolated from the channel's remaining
        (unflagged) samples. A channel with fewer than two clean samples is left
        unchanged.

        Args:
            data: Array-like of shape (samples, channels) or (samples,)
            threshold: Absolute amplitude limit, in the data's units

        Returns:
            Cleaned float ndarray with the same shape as ``data``.
        """
        arr = self._as_float_array(data)
        cleaned = self._as_channel_matrix(arr).copy()
        artifact_mask = np.abs(cleaned) > threshold
        sample_idx = np.arange(cleaned.shape[0])

        for ch in range(cleaned.shape[1]):
            bad = artifact_mask[:, ch]
            if not bad.any():
                continue
            good = ~bad
            if np.count_nonzero(good) > 1:
                cleaned[bad, ch] = np.interp(
                    sample_idx[bad], sample_idx[good], cleaned[good, ch]
                )

        return cleaned.reshape(arr.shape)

    def normalize_data(self, data):
        """
        Z-score each channel: subtract its mean and divide by its standard deviation.

        A channel with zero (or undefined) standard deviation is only mean-centred.

        Args:
            data: Array-like of shape (samples, channels) or (samples,)

        Returns:
            Normalized float ndarray with the same shape as ``data``.
        """
        arr = self._as_float_array(data)
        mean_val = arr.mean(axis=0)
        std_val = arr.std(axis=0)
        # Divide by 1 where std is not positive, i.e. leave those channels centred.
        safe_std = np.where(std_val > 0, std_val, 1.0)
        return (arr - mean_val) / safe_std

    def extract_frequency_features(self, data, sampling_rate=256):
        """
        Compute absolute band power per channel from a Welch PSD estimate.

        Band power is the trapezoidal integral of the PSD over the band's
        frequency bins. Band edges are inclusive, so adjacent bands share their
        boundary bin, but the trapezoid integrals tile without overlapping area.

        Args:
            data: Array-like of shape (samples, channels) or (samples,)
            sampling_rate: Sampling frequency in Hz

        Returns:
            Dict mapping ``'<band>_power'`` (delta, theta, alpha, beta, gamma)
            to a list holding one power value per channel.
        """
        bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 100)
        }

        channels = self._as_channel_matrix(self._as_float_array(data))
        # Segments of up to 256 samples; at least 1 so very short input doesn't raise.
        nperseg = max(1, min(256, channels.shape[0] // 4))
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        integrate = getattr(np, 'trapezoid', None) or np.trapz

        features = {f'{band_name}_power': [] for band_name in bands}
        for ch in range(channels.shape[1]):
            freqs, psd = signal.welch(channels[:, ch], sampling_rate, nperseg=nperseg)
            for band_name, (low, high) in bands.items():
                band_mask = (freqs >= low) & (freqs <= high)
                features[f'{band_name}_power'].append(
                    integrate(psd[band_mask], freqs[band_mask])
                )

        return features

    def compute_connectivity(self, data):
        """
        Compute a correlation-based connectivity matrix between channels.

        Entry (i, j) is the absolute Pearson correlation between channels i and
        j. The diagonal is set to 0. Pairs involving a constant channel have an
        undefined correlation and come out as NaN.

        Args:
            data: Array-like of shape (samples, channels) or (samples,)

        Returns:
            Symmetric float ndarray of shape (channels, channels).
        """
        channels = self._as_channel_matrix(self._as_float_array(data))
        num_channels = channels.shape[1]
        if num_channels < 2:
            # np.corrcoef collapses to a scalar for one variable; nothing to connect.
            return np.zeros((num_channels, num_channels))

        # One call computes all pairs; constant channels produce 0/0 -> NaN.
        with np.errstate(divide='ignore', invalid='ignore'):
            correlation = np.corrcoef(channels, rowvar=False)

        connectivity_matrix = np.abs(correlation)
        np.fill_diagonal(connectivity_matrix, 0.0)
        return connectivity_matrix