-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_loader.py
More file actions
268 lines (215 loc) · 9.73 KB
/
dataset_loader.py
File metadata and controls
268 lines (215 loc) · 9.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import requests
import os
import zipfile
import mne
import numpy as np
import pandas as pd
from pathlib import Path
import warnings

# NOTE(review): `os` and `pandas` are not used anywhere in this file's
# visible code -- confirm against the rest of the project before removing.
# NOTE(review): this blanket-suppresses EVERY warning process-wide (including
# MNE's data-quality warnings), not just this module's -- consider scoping
# with warnings.catch_warnings() or a category filter instead.
warnings.filterwarnings('ignore')
class DryadSpeechLoader:
    """Downloads and loads the Dryad Speech EEG dataset for training the
    brainwave AI model.

    The archive is fetched from a list of candidate Dryad URLs, cached under
    ``data_dir``, extracted, and exposed as (samples, channels) numpy arrays
    plus a metadata dictionary.
    """

    def __init__(self, data_dir="./dryad_data"):
        """Create the loader and ensure the cache directory exists.

        Args:
            data_dir: Directory used to cache the downloaded archive and
                its extracted contents.
        """
        self.data_dir = Path(data_dir)
        # parents=True so a nested data_dir (e.g. "cache/dryad") also works;
        # exist_ok keeps re-runs idempotent.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Candidate URLs for the dataset, tried in order until one works.
        self.dataset_urls = [
            "https://datadryad.org/stash/downloads/file_stream/485896",
            "https://dryad-assetstore-merritt-west.s3-us-west-2.amazonaws.com/ark%3A%2F13030%2Fm5w67n8q%7C1%7Cproducer%2FTilburg_et_al_2020_processed.zip",
            "https://datadryad.org/api/v2/datasets/doi%3A10.5061%2Fdryad.070jc/download"
        ]

    def download_dataset(self):
        """Download the Dryad Speech dataset archive.

        Streams each candidate URL to a temporary ".part" file and renames
        it to the final name only on success, so an interrupted download is
        never mistaken for a complete one on a later run (the original code
        wrote directly to the final path and would cache a truncated file).

        Returns:
            Path to the cached zip file, or None if every source failed.
        """
        zip_path = self.data_dir / "dryad_speech.zip"
        if zip_path.exists():
            print("Dataset already downloaded.")
            return zip_path
        print("Downloading Dryad Speech dataset...")
        part_path = zip_path.with_suffix(".zip.part")
        for i, url in enumerate(self.dataset_urls):
            try:
                print(f"Attempting download from source {i+1}/{len(self.dataset_urls)}...")
                response = requests.get(url, stream=True, timeout=30)
                response.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
                # Publish the file only once it is fully written.
                part_path.rename(zip_path)
                print(f"Dataset downloaded successfully to {zip_path}")
                return zip_path
            except Exception as e:
                print(f"Source {i+1} failed: {e}")
                # Drop any partial file so the next source starts clean.
                if part_path.exists():
                    part_path.unlink()
                continue
        print("All download sources failed. Dataset not available.")
        return None

    def extract_dataset(self, zip_path):
        """Extract the downloaded dataset archive.

        Args:
            zip_path: Path to the zip produced by download_dataset().

        Returns:
            Path to the extraction directory, or None on failure.
        """
        extract_path = self.data_dir / "extracted"
        if extract_path.exists():
            print("Dataset already extracted.")
            return extract_path
        print("Extracting dataset...")
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
            print(f"Dataset extracted to {extract_path}")
            return extract_path
        except Exception as e:
            print(f"Error extracting dataset: {e}")
            return None

    def load_eeg_files(self, extract_path):
        """Locate EEG recording files in the extracted dataset.

        Searches for BDF/EDF files first, then falls back to FIF/SET files.
        Results are sorted so the "first file" chosen downstream is
        deterministic across filesystems (rglob order is not guaranteed).

        Args:
            extract_path: Root directory to search recursively.

        Returns:
            List of Paths to EEG files (possibly empty).
        """
        eeg_files = sorted(extract_path.rglob("*.bdf")) + sorted(extract_path.rglob("*.edf"))
        if not eeg_files:
            # Try alternative file extensions
            eeg_files = sorted(extract_path.rglob("*.fif")) + sorted(extract_path.rglob("*.set"))
        print(f"Found {len(eeg_files)} EEG files")
        return eeg_files

    def load_experiment_data(self, experiment_type="audiobook"):
        """Download, extract, and load EEG data for a specific experiment.

        Args:
            experiment_type: Type of experiment to load
                - "audiobook": Audio-book listening experiment
                - "reversed": Reversed speech experiment
                - "n400": Sentence comprehension experiment
                - "cocktail": Cocktail party experiment
                - "multisensory": Audio-visual experiment
                NOTE(review): currently only recorded in the returned
                metadata -- the first EEG file found is loaded regardless
                of the requested experiment type. Confirm whether file
                selection should filter by experiment.

        Returns:
            (data, info): data is a (samples, channels) float array, info a
            metadata dict with sampling rate, channel names, experiment type
            and source file path; (None, None) on any failure.
        """
        # Download and extract dataset (both steps are cached on disk).
        zip_path = self.download_dataset()
        if not zip_path:
            return None, None
        extract_path = self.extract_dataset(zip_path)
        if not extract_path:
            return None, None
        # Find relevant files for the experiment
        eeg_files = self.load_eeg_files(extract_path)
        if not eeg_files:
            print("No EEG files found. Checking directory structure...")
            for item in extract_path.rglob("*"):
                if item.is_file():
                    print(f"Found file: {item}")
            return None, None
        # Load first available file as example
        try:
            print(f"Loading EEG file: {eeg_files[0]}")
            raw = mne.io.read_raw(eeg_files[0], preload=True, verbose=False)
            # Get data and info
            data_array = raw.get_data()  # Shape: (channels, samples)
            sfreq = raw.info['sfreq']
            ch_names = raw.ch_names
            print(f"Loaded EEG data: {data_array.shape[0]} channels, {data_array.shape[1]} samples")
            print(f"Sampling rate: {sfreq} Hz")
            print(f"Duration: {data_array.shape[1] / sfreq:.2f} seconds")
            return data_array.T, {  # Transpose to (samples, channels)
                'sampling_rate': sfreq,
                'channel_names': ch_names,
                'experiment_type': experiment_type,
                'file_path': str(eeg_files[0])
            }
        except Exception as e:
            print(f"Error loading EEG file: {e}")
            # Fall back to pyedflib for EDF files MNE cannot read.
            try:
                import pyedflib
                f = pyedflib.EdfReader(str(eeg_files[0]))
                try:
                    n_channels = f.signals_in_file
                    signal_labels = f.getSignalLabels()
                    sfreq = f.getSampleFrequency(0)
                    # Read all signals
                    signals = [f.readSignal(i) for i in range(n_channels)]
                finally:
                    # Close the reader even if a read raises (the original
                    # leaked the handle on mid-read errors).
                    f.close()
                data = np.array(signals).T  # Shape: (samples, channels)
                print(f"Loaded EEG data (pyedflib): {data.shape[1]} channels, {data.shape[0]} samples")
                print(f"Sampling rate: {sfreq} Hz")
                return data, {
                    'sampling_rate': sfreq,
                    'channel_names': signal_labels,
                    'experiment_type': experiment_type,
                    'file_path': str(eeg_files[0])
                }
            except Exception as e2:
                print(f"Error with alternative loading method: {e2}")
                return None, None

    def preprocess_for_speech(self, data, info, target_words=None):
        """Segment EEG data into fixed one-second epochs for speech analysis.

        Args:
            data: EEG data, shape (samples, channels).
            info: Metadata dict; must contain 'sampling_rate'.
            target_words: Optional list of per-epoch stimulus labels; epochs
                beyond its length get synthetic "epoch_<i>" labels.

        Returns:
            (epochs, labels, meta): epochs has shape
            (n_epochs, epoch_length, channels), labels is a list of length
            n_epochs, meta holds the segmentation parameters.

        Raises:
            ValueError: if the sampling rate is not positive (the original
                code crashed with ZeroDivisionError here).
        """
        # Basic preprocessing for speech EEG
        sampling_rate = info['sampling_rate']
        # One-second, non-overlapping epochs; any trailing remainder of the
        # recording shorter than one epoch is discarded.
        epoch_length = int(sampling_rate)  # 1 second epochs
        if epoch_length <= 0:
            raise ValueError(f"Invalid sampling rate: {sampling_rate}")
        n_epochs = data.shape[0] // epoch_length
        epochs = []
        labels = []
        for i in range(n_epochs):
            start_idx = i * epoch_length
            end_idx = start_idx + epoch_length
            epochs.append(data[start_idx:end_idx, :])
            # If target words are available, create labels
            if target_words and i < len(target_words):
                labels.append(target_words[i])
            else:
                labels.append(f"epoch_{i}")
        return np.array(epochs), labels, {
            'epoch_length': epoch_length,
            'sampling_rate': sampling_rate,
            'n_epochs': n_epochs
        }

    def get_speech_tokens(self, epochs, labels, tokenizer):
        """Convert speech EEG epochs to token sequences for training.

        Args:
            epochs: EEG epochs, shape (n_epochs, samples, channels).
            labels: Label for each epoch.
            tokenizer: Fitted tokenizer exposing tokenize(epoch) -> sequence.

        Returns:
            (all_tokens, token_labels): flat token list, and the source
            epoch's label repeated once per token.
        """
        all_tokens = []
        token_labels = []
        for epoch, label in zip(epochs, labels):
            # Tokenize each epoch
            tokens = tokenizer.tokenize(epoch)
            all_tokens.extend(tokens)
            # Every token inherits its source epoch's label.
            token_labels.extend([label] * len(tokens))
        return all_tokens, token_labels
def integrate_real_data():
    """Integrate real EEG data into the existing brainwave AI system.

    Walks the full pipeline (download -> extract -> locate files -> load)
    and returns either the loaded data with its metadata, or None paired
    with a status dictionary describing which stage failed.
    """
    manual_url = "https://datadryad.org/stash/dataset/doi:10.5061/dryad.070jc"
    loader = DryadSpeechLoader()

    # Stage 1: fetch the archive.
    archive = loader.download_dataset()
    if archive is None:
        print("Real EEG dataset download failed.")
        print("The Dryad Speech dataset requires manual download from:")
        print(manual_url)
        print("Please download and extract to 'dryad_data' folder to use authentic data.")
        return None, {"status": "download_required", "manual_url": manual_url}

    # Stage 2: unpack it.
    extracted = loader.extract_dataset(archive)
    if extracted is None:
        print("Dataset extraction failed.")
        return None, {"status": "extraction_failed"}

    # Stage 3: make sure at least one EEG recording is present.
    if not loader.load_eeg_files(extracted):
        print("No EEG files found in dataset.")
        return None, {"status": "no_files_found"}

    # Stage 4: load the audiobook experiment (good for speech comprehension).
    data, info = loader.load_experiment_data("audiobook")
    if data is None:
        print("Dataset found but could not load experiment data.")
        return None, {"status": "processing_failed"}

    print("Successfully loaded real EEG data!")
    print(f"Data shape: {data.shape}")
    print(f"Experiment info: {info}")
    return data, info