diff --git a/.flake8 b/.flake8 index 7f73516..b57c737 100644 --- a/.flake8 +++ b/.flake8 @@ -25,6 +25,7 @@ ignore = E731, E713, E714, + E722, E741, F403, F405, diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 70511d2..259b3ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: rev: v2.2.6 hooks: - id: codespell - args: [--toml, pyproject.toml] + args: [--toml, pyproject.toml, -L, whis] additional_dependencies: [tomli] - repo: https://github.com/jumanjihouse/pre-commit-hook-yamlfmt diff --git a/pydfc/__init__.py b/pydfc/__init__.py index b62481f..206f254 100644 --- a/pydfc/__init__.py +++ b/pydfc/__init__.py @@ -31,4 +31,6 @@ "dfc_methods", "dfc_utils", "comparison", + "task_utils", + "simul_utils", ] diff --git a/pydfc/data_loader.py b/pydfc/data_loader.py index 59307ac..66ec2e8 100644 --- a/pydfc/data_loader.py +++ b/pydfc/data_loader.py @@ -527,25 +527,18 @@ def multi_nifti2timeseries( def load_TS( data_root, file_name, - SESSIONs, subj_id2load=None, task=None, + session=None, run=None, ): """ load a TIME_SERIES object from a .npy file - if SESSIONs is a list, it will load all the sessions, - if it is a string, it will load that session if subj_id2load is None, it will load all the subjects file_name: name of the file to load - format example: {subj_id}_{task}_{run}_time-series.npy + format example: {subj_id}_{session}_{task}_{run}_time-series.npy (keep the {} for the variables) """ - # check if SESSIONs is a list or a string - flag = False - if type(SESSIONs) is str: - SESSIONs = [SESSIONs] - flag = True if subj_id2load is None: SUBJECTS = find_subj_list(data_root) @@ -553,37 +546,50 @@ def load_TS( assert "sub-" in subj_id2load, "subj_id2load must start with 'sub-'" SUBJECTS = [subj_id2load] - TS = {} - for session in SESSIONs: - TS[session] = None - for subj in SUBJECTS: - subj_fldr = subj - # make the file_name - TS_file = deepcopy(file_name) - if "{subj_id}" in file_name: - TS_file = TS_file.replace("{subj_id}", subj) - if "{task}" in file_name: - assert task is not None, "task must be provided" - TS_file = TS_file.replace("{task}", task) - if "{run}" in file_name: - assert run is not None, "run must be provided" - TS_file = TS_file.replace("{run}", run) - - try: + TS = None + for subj in SUBJECTS: + subj_fldr = subj + # make the file_name + TS_file = deepcopy(file_name) + if "{subj_id}" in file_name: + TS_file = TS_file.replace("{subj_id}", subj) + if "{task}" in file_name: + assert task is not None, "task must be provided" + TS_file = TS_file.replace("{task}", task) + if "{session}" in file_name: + assert session is not None, "session must be provided" + TS_file = TS_file.replace("{session}", session) + if "{run}" in file_name: + assert run is not None, "run must be provided" + TS_file = TS_file.replace("{run}", run) + + try: + if session is None: time_series = np.load( f"{data_root}/{subj_fldr}/{TS_file}", allow_pickle="True" ).item() - except FileNotFoundError: - print(f"File {TS_file} not found for {subj}") - continue - - if TS[session] is None: - TS[session] = time_series else: - TS[session].concat_ts(time_series) + time_series = np.load( + f"{data_root}/{subj_fldr}/{session}/{TS_file}", + allow_pickle="True", + ).item() + except FileNotFoundError: + print(f"File {TS_file} not found for {subj}") + continue + + if TS is None: + TS = time_series + else: + try: + TS.concat_ts(time_series) + except AssertionError as e: + # print the error message + print(f"Error in concatenating time series for {subj}: {e}") + # raise error with a message and stop the program + raise Exception( + f"Fs of subj {subj} TS is {time_series.Fs} while the group Fs is {TS.Fs}" + ) - if flag: - return TS[SESSIONs[0]] return TS diff --git a/pydfc/ml_utils.py b/pydfc/ml_utils.py new file mode 100644 index 0000000..f183ae8 --- /dev/null +++ b/pydfc/ml_utils.py @@ -0,0 +1,2235 @@ +# -*- coding: utf-8 -*- +""" +Functions to facilitate applying ML algorithms to dFC. + +Created on Aug 8 2024 +@author: Mohammad Torabi +""" +import os +import warnings + +import numpy as np +from scipy.spatial import procrustes +from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.cluster import KMeans +from sklearn.cross_decomposition import PLSRegression +from sklearn.decomposition import PCA +from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.manifold import SpectralEmbedding +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + balanced_accuracy_score, + confusion_matrix, + f1_score, + mean_squared_error, + precision_score, + r2_score, + recall_score, + silhouette_score, +) +from sklearn.model_selection import ( + GridSearchCV, + GroupKFold, + StratifiedGroupKFold, + StratifiedKFold, +) +from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors, kneighbors_graph +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC, SVR +from sklearn.utils import shuffle + +from .dfc_utils import dFC_mat2vec, dFC_vec2mat, rank_norm +from .task_utils import ( + calc_relative_task_on, + calc_rest_duration, + calc_task_duration, + calc_transition_freq, + extract_task_presence, +) + +################################# Feature Loading Functions #################################### + + +def find_available_subjects(dFC_root, task, run=None, session=None, dFC_id=None): + """ + Find the subjects that have dFC results for the given task and dFC_id (method). + + If run is specified, the dFC results for that run will be used. + Otherwise, the subjects that have dFC results at least for one run will returned. + + If session is specified, the dFC results for that session will be used. + Otherwise, it is considered that the dataset does not have session information. + Note that not specifying session will cause error if the dataset has session information. + """ + SUBJECTS = list() + ALL_SUBJ_FOLDERS = os.listdir(f"{dFC_root}/") + ALL_SUBJ_FOLDERS = [folder for folder in ALL_SUBJ_FOLDERS if "sub-" in folder] + ALL_SUBJ_FOLDERS.sort() + for subj_folder in ALL_SUBJ_FOLDERS: + if session is None: + ALL_DFC_FILES = os.listdir(f"{dFC_root}/{subj_folder}/") + else: + if not os.path.exists(f"{dFC_root}/{subj_folder}/{session}/"): + continue + ALL_DFC_FILES = os.listdir(f"{dFC_root}/{subj_folder}/{session}/") + ALL_DFC_FILES = [ + dFC_file for dFC_file in ALL_DFC_FILES if f"_{task}_" in dFC_file + ] + if dFC_id is not None: + ALL_DFC_FILES = [ + dFC_file for dFC_file in ALL_DFC_FILES if f"_{dFC_id}.npy" in dFC_file + ] + if run is not None: + ALL_DFC_FILES = [ + dFC_file for dFC_file in ALL_DFC_FILES if f"_{run}_" in dFC_file + ] + if session is not None: + ALL_DFC_FILES = [ + dFC_file for dFC_file in ALL_DFC_FILES if f"_{session}_" in dFC_file + ] + ALL_DFC_FILES.sort() + if len(ALL_DFC_FILES) > 0: + SUBJECTS.append(subj_folder) + return SUBJECTS + + +def load_dFC(dFC_root, subj, task, dFC_id, run=None, session=None): + """ + Load the dFC results for a given subject, task, dFC_id, run and session. + """ + if session is None: + if run is None: + dFC = np.load( + f"{dFC_root}/{subj}/dFC_{task}_{dFC_id}.npy", allow_pickle="TRUE" + ).item() + else: + dFC = np.load( + f"{dFC_root}/{subj}/dFC_{task}_{run}_{dFC_id}.npy", allow_pickle="TRUE" + ).item() + else: + if run is None: + dFC = np.load( + f"{dFC_root}/{subj}/{session}/dFC_{session}_{task}_{dFC_id}.npy", + allow_pickle="TRUE", + ).item() + else: + dFC = np.load( + f"{dFC_root}/{subj}/{session}/dFC_{session}_{task}_{run}_{dFC_id}.npy", + allow_pickle="TRUE", + ).item() + + return dFC + + +def load_task_data(roi_root, subj, task, run=None, session=None): + """ + Load the task data for a given subject, task and run. + """ + if session is None: + if run is None: + task_data = np.load( + f"{roi_root}/{subj}/{subj}_{task}_task-data.npy", allow_pickle="TRUE" + ).item() + else: + task_data = np.load( + f"{roi_root}/{subj}/{subj}_{task}_{run}_task-data.npy", + allow_pickle="TRUE", + ).item() + else: + if run is None: + task_data = np.load( + f"{roi_root}/{subj}/{session}/{subj}_{session}_{task}_task-data.npy", + allow_pickle="TRUE", + ).item() + else: + task_data = np.load( + f"{roi_root}/{subj}/{session}/{subj}_{session}_{task}_{run}_task-data.npy", + allow_pickle="TRUE", + ).item() + + return task_data + + +################################# Feature Extraction Functions #################################### + + +def extract_task_features(TASKS, RUNS, session, roi_root, dFC_root, no_hrf=False): + """ + Extract task features from the event data. + + if no_hrf is True, the task presence will be binarized without convolving with HRF. + Therefore the task features will be extracted based on the event labels and + without the effect of HRF. + """ + task_features = { + "task": list(), + "run": list(), + "relative_task_on": list(), + "avg_task_duration": list(), + "var_task_duration": list(), + "avg_rest_duration": list(), + "var_rest_duration": list(), + "num_of_transitions": list(), + "relative_transition_freq": list(), + } + for task_id, task in enumerate(TASKS): + + if task == "task-restingstate": + continue + + for run in RUNS[task]: + + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, task=task, run=run, session=session + ) + + for subj in SUBJECTS: + # event data + task_data = load_task_data( + roi_root=roi_root, subj=subj, task=task, run=run, session=session + ) + Fs_task = task_data["Fs_task"] + TR_task = 1 / Fs_task + + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=TR_task, + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + no_hrf=no_hrf, + ) + task_presence = task_presence[indices] + + relative_task_on = calc_relative_task_on(task_presence) + # task duration + avg_task_duration, var_task_duration = calc_task_duration( + task_presence, task_data["TR_mri"] + ) + # rest duration + avg_rest_duration, var_rest_duration = calc_rest_duration( + task_presence, task_data["TR_mri"] + ) + # freq of transitions + num_of_transitions, relative_transition_freq = calc_transition_freq( + task_presence + ) + + task_features["task"].append(task) + task_features["run"].append(run) + task_features["relative_task_on"].append(relative_task_on) + task_features["avg_task_duration"].append(avg_task_duration) + task_features["var_task_duration"].append(var_task_duration) + task_features["avg_rest_duration"].append(avg_rest_duration) + task_features["var_rest_duration"].append(var_rest_duration) + task_features["num_of_transitions"].append(num_of_transitions) + task_features["relative_transition_freq"].append(relative_transition_freq) + + return task_features + + +def dFC_feature_extraction_subj_lvl( + dFC, + task_data, + dynamic_pred="no", + normalize_dFC=False, + FCS_proba_for_SB=True, +): + """ + Extract features and target for task presence classification + for a single subject. + dynamic_pred: "no", "past", "past_and_future" + + FCS_proba_for_SB: if True, use FCS_proba as features for state-based dFC. + If False, use dFC_vecs (dFC matrix as features). + """ + # dFC features + # for state-based dFC, we use the FCS_proba as features + # for state-free dFC, we use the dFC matrix as features + if dFC.measure.is_state_based and FCS_proba_for_SB: + # state-based dFC + dFC_vecs = dFC.FCS_proba # shape: (n_time, n_states) + TR_array = dFC.TR_array + + assert dFC_vecs.shape[0] == len( + TR_array + ), "dFC_vecs and TR_array have different number of samples." + assert ( + dFC_vecs.shape[1] == dFC.measure.params["n_states"] + ), "dFC_vecs and n_states are not consistent." + else: + dFC_mat = dFC.get_dFC_mat() + TR_array = dFC.TR_array + if normalize_dFC: + dFC_mat = rank_norm(dFC_mat, global_norm=False) + dFC_vecs = dFC_mat2vec(dFC_mat) + + # event data + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=1 / task_data["Fs_task"], + TR_mri=task_data["TR_mri"], + TR_array=TR_array, + binary=True, + binarizing_method="GMM", + ) + + # features = dFC_vecs + # target = task_presence.ravel() + + # use absolute task presence + features = dFC_vecs[indices, :] + target = task_presence.ravel()[indices] + + assert ( + features.shape[0] == target.shape[0] + ), "Features and target have different number of samples." + + if dynamic_pred == "past": + # concat current TR and two TR before of features to predict the current TR of target + # ignore the edge case of the first two TRs + features = np.concatenate( + (features, np.roll(features, 1, axis=0), np.roll(features, 2, axis=0)), axis=1 + ) + features = features[2:, :] + target = target[2:] + elif dynamic_pred == "past_and_future": + # concat current TR and two TR before and after of features to predict the current TR of target + # ignore the edge case of the first and last two TRs + features = np.concatenate( + ( + features, + np.roll(features, 1, axis=0), + np.roll(features, 2, axis=0), + np.roll(features, -1, axis=0), + np.roll(features, -2, axis=0), + ), + axis=1, + ) + features = features[2:-2, :] + target = target[2:-2] + + features = features.astype(np.float32, copy=False) + target = target.astype(np.int8, copy=False) # labels smaller & faster + return features, target + + +def dFC_feature_extraction( + task, + train_subjects, + test_subjects, + dFC_id, + roi_root, + dFC_root, + run=None, + session=None, + dynamic_pred="no", + normalize_dFC=False, + FCS_proba_for_SB=True, +): + """ + Extract features and target for task presence classification + for all subjects. + if run is specified, dFC results for that run will be used. + + if FCS_proba_for_SB is True, use FCS_proba as features for state-based dFC. + If False, use dFC_vecs (dFC matrix as features). + """ + dFC_measure_name = None + measure_is_state_based = None + X_train = None + y_train = None + subj_label_train = list() + for subj in train_subjects: + + dFC = load_dFC( + dFC_root=dFC_root, + subj=subj, + task=task, + dFC_id=dFC_id, + run=run, + session=session, + ) + task_data = load_task_data( + roi_root=roi_root, subj=subj, task=task, run=run, session=session + ) + + X_subj, y_subj = dFC_feature_extraction_subj_lvl( + dFC=dFC, + task_data=task_data, + dynamic_pred=dynamic_pred, + normalize_dFC=normalize_dFC, + FCS_proba_for_SB=FCS_proba_for_SB, + ) + + # to make computations faster + X_subj = X_subj.astype(np.float32, copy=False) + y_subj = y_subj.astype(np.int8, copy=False) + + subj_label_train.extend([subj for i in range(X_subj.shape[0])]) + if X_train is None and y_train is None: + X_train = X_subj + y_train = y_subj + else: + X_train = np.concatenate((X_train, X_subj), axis=0) + y_train = np.concatenate((y_train, y_subj), axis=0) + + dFC_measure_name_new = dFC.measure.measure_name + measure_is_state_based_new = dFC.measure.is_state_based + if dFC_measure_name is None: + dFC_measure_name = dFC_measure_name_new + measure_is_state_based = measure_is_state_based_new + else: + assert ( + dFC_measure_name == dFC_measure_name_new + ), "dFC measure is not consistent." + assert ( + measure_is_state_based == measure_is_state_based_new + ), "dFC measure is not consistent." + + X_test = None + y_test = None + subj_label_test = list() + for subj in test_subjects: + dFC = load_dFC( + dFC_root=dFC_root, + subj=subj, + task=task, + dFC_id=dFC_id, + run=run, + session=session, + ) + task_data = load_task_data( + roi_root=roi_root, subj=subj, task=task, run=run, session=session + ) + + X_subj, y_subj = dFC_feature_extraction_subj_lvl( + dFC=dFC, + task_data=task_data, + dynamic_pred=dynamic_pred, + normalize_dFC=normalize_dFC, + FCS_proba_for_SB=FCS_proba_for_SB, + ) + + # to make computations faster + X_subj = X_subj.astype(np.float32, copy=False) + y_subj = y_subj.astype(np.int8, copy=False) + + subj_label_test.extend([subj for i in range(X_subj.shape[0])]) + if X_test is None and y_test is None: + X_test = X_subj.astype(np.float32, copy=False) + y_test = y_subj.astype(np.int8, copy=False) + else: + X_test = np.concatenate((X_test, X_subj), axis=0) + y_test = np.concatenate((y_test, y_subj), axis=0) + + dFC_measure_name_new = dFC.measure.measure_name + measure_is_state_based_new = dFC.measure.is_state_based + if dFC_measure_name is None: + dFC_measure_name = dFC_measure_name_new + measure_is_state_based = measure_is_state_based_new + else: + assert ( + dFC_measure_name == dFC_measure_name_new + ), "dFC measure is not consistent." + assert ( + measure_is_state_based == measure_is_state_based_new + ), "dFC measure is not consistent." + + # print(X_train.shape, X_test.shape, y_train.shape, y_test.shape) + subj_label_train = np.array(subj_label_train) + subj_label_test = np.array(subj_label_test) + + return ( + X_train, + X_test, + y_train, + y_test, + subj_label_train, + subj_label_test, + dFC_measure_name, + measure_is_state_based, + ) + + +################################# Feature Embedding Functions #################################### + + +def precheck_for_procruste(X_best, X_subj): + """ + Check if the two matrices have the same number of rows. if not, make them the same. + """ + # for the procrustes transformation, the number of samples should be the same + if X_subj.shape[0] > X_best.shape[0]: + # add zero rows to the embedding of the best subject + X_best_new = np.concatenate( + ( + X_best, + np.zeros( + ( + X_subj.shape[0] - X_best.shape[0], + X_best.shape[1], + ) + ), + ), + axis=0, + ) + elif X_subj.shape[0] < X_best.shape[0]: + # remove extra rows from the embedding of the best subject + X_best_new = X_best[: X_subj.shape[0], :] + else: + X_best_new = X_best + + X_best_new = X_best_new.copy() + + return X_best_new + + +def generalized_procrustes(X_embed_dict, max_iter=1000, tol=1e-6): + """ + Generalized Procrustes Analysis + + X_embed_dict: dict + dict of scans and their embeddings + + returns the mean X across scans to be used as the reference for procrustes transformation + """ + # initial step + # not all scans have the same number of samples + # find the max number of samples among all scans + max_samples = 0 + for scan in X_embed_dict: + if X_embed_dict[scan].shape[0] > max_samples: + max_samples = X_embed_dict[scan].shape[0] + + # find the mean embedding of all scan to use as the reference for procrustes transformation + X_list = [] + for scan in X_embed_dict: + X_scan_embed = X_embed_dict[scan] + # add zero rows to the embedding of the scan with less samples + if X_scan_embed.shape[0] < max_samples: + X_scan_embed_new = np.concatenate( + ( + X_scan_embed, + np.zeros( + ( + max_samples - X_scan_embed.shape[0], + X_scan_embed.shape[1], + ) + ), + ), + axis=0, + ) + else: + X_scan_embed_new = X_scan_embed + X_list.append(X_scan_embed_new) + + # now iteratively find the mean X for transform + for _ in range(10): + + try: + # initialize Procrustes distance + current_distance = 0 + + num_X = len(X_list) + + # initialize a mean X by randomly selecting + # one of the Xs using np.random.choice + mean_X = X_list[np.random.choice(num_X)] + + # create array for new Xs, add + new_Xs = np.zeros(np.array(X_list).shape) + + counter = 0 + flag = False + while True: + counter += 1 + if counter > max_iter: + # if the algorithm does not converge, break the cycle + # to avoid infinite loop + flag = True + break + + # add the mean X as first element of array + new_Xs[0] = mean_X + + # superimpose all shapes to current mean + for i in range(1, num_X): + _, new_X, _ = procrustes(mean_X, X_list[i]) + new_Xs[i] = new_X + + # calculate new mean + new_mean = np.mean(new_Xs, axis=0) + + _, _, new_distance = procrustes(new_mean, mean_X) + + # if the distance did not change, break the cycle + if np.abs(new_distance - current_distance) < tol: + break + + # align the new_mean to old mean + _, new_mean, _ = procrustes(mean_X, new_mean) + + # update mean and distance + mean_X = new_mean + current_distance = new_distance + + if not flag: + # if the algorithm converged, return the mean X + return mean_X + except: + continue + + raise RuntimeError("Generalized Procrustes Analysis did not converge.") + + +def twonn(X, discard_ratio=0.1, n_neighbors=30, eps=1e-12, metric="euclidean"): + """ + TWO-NN intrinsic dimension estimator. + + Parameters + ---------- + X : (n_samples, n_features) + discard_ratio : float in [0,1) + Fraction of largest mu values to discard (tail trimming). + n_neighbors : int + Number of neighbors to query (must be >= 3 ideally, and <= n_samples-1). + eps : float + Numerical tolerance for filtering mu values. + metric : str + Distance metric for NearestNeighbors. + + Returns + ------- + d : float + Estimated intrinsic dimension. + """ + X = np.asarray(X) + n = X.shape[0] + if n < 5: + raise ValueError("TWO-NN needs more samples (n >= 5 is a practical minimum).") + + k = int(min(max(n_neighbors, 3), n - 1)) # at least 3, at most n-1 + + nn = NearestNeighbors(n_neighbors=k, metric=metric) + nn.fit(X) + distances, _ = nn.kneighbors(X, return_distance=True) + + mu = np.full(n, np.nan, dtype=float) + + for i in range(n): + # distances[i, 0] is typically 0 (self). Find first two *positive* distances + pos = distances[i][distances[i] > eps] + if pos.size >= 2: + r1, r2 = pos[0], pos[1] + mu[i] = r2 / r1 + + mu = mu[np.isfinite(mu)] + mu = mu[mu > 1.0 + eps] # avoid log(1)=0 edge cases + + if mu.size < 5: + raise ValueError( + "Too few valid mu values after filtering; check duplicates / ties / eps." + ) + + # discard upper tail (largest mu) + mu.sort() + keep = int(np.floor((1.0 - discard_ratio) * mu.size)) + keep = max(5, keep) # don't keep too few + mu = mu[:keep] + + N = mu.size + # plotting positions; i/(N+1) is common and avoids CDF=1 exactly + F = np.arange(1, N + 1) / (N + 1.0) + + x = np.log(mu).reshape(-1, 1) + y = (-np.log(1.0 - F)).reshape(-1, 1) + + lr = LinearRegression(fit_intercept=False) + lr.fit(x, y) + return float(lr.coef_[0, 0]) + + +def SI_ID( + X, + y, + search_range=range(2, 50, 5), + n_neighbors_LE=125, + LE_embedding_method="embed+procrustes", + measure_is_state_based=False, +): + """ + Find the intrinsic dimension of the data based on the silhouette score. + """ + + SI_score = {} + for n_components in search_range: + if n_components > X.shape[1]: + # if the number of components is larger than the number of features, break + break + try: + X_train_embed, _ = embed_dFC_features( + train_subjects=["subj"], + test_subjects=[], + X_train=X, + X_test=None, + y_train=y, + y_test=None, + subj_label_train=np.array(["subj"] * len(y)), + subj_label_test=None, + embedding="LE", + n_components=n_components, + n_neighbors_LE=n_neighbors_LE, + LE_embedding_method=LE_embedding_method, + measure_is_state_based=measure_is_state_based, + ) + except Exception as e: + warnings.warn( + f"Error in SI_ID embedding with n_components={n_components}: {e}. Skipping this n_components." + ) + continue + + SI_score[n_components] = silhouette_score(X_train_embed, y) + + # find the intrinsic dimension based on the silhouette score + intrinsic_dim = max(SI_score, key=SI_score.get) + + return intrinsic_dim + + +import numpy as np +from sklearn.neighbors import NearestNeighbors + + +def localpca_intrinsic_dim( + X, + k=20, + method="explained_var", # "explained_var" or "eigengap" + var_threshold=0.9, # used for explained_var + max_dim=None, # cap returned dim (optional) + center=True, + metric="euclidean", + random_state=0, + agg="median", # "median", "mean", "trimmed_mean" + trim=0.1, # used if agg="trimmed_mean" + eps=1e-12, +): + """ + Local PCA intrinsic dimension estimation. + + Parameters + ---------- + X : (n_samples, n_features) + k : int + Neighborhood size (kNN). Must be < n_samples. + method : str + "explained_var": choose smallest d achieving cumulative variance >= var_threshold + "eigengap": choose d maximizing eigenvalue ratio lambda_d / lambda_{d+1} + var_threshold : float + Threshold for explained_var method. + max_dim : int or None + Max dimension to consider/return; defaults to min(n_features, k-1). + center : bool + Whether to mean-center each neighborhood before PCA. + metric : str + Metric for kNN graph. + agg : str + Aggregation across points: "median", "mean", "trimmed_mean" + trim : float + Trimming fraction for trimmed_mean. + eps : float + Numerical stability. + + Returns + ------- + d_global : float + Aggregated intrinsic dimension estimate. + d_local : (n_samples,) int + Local dimension estimates. + """ + X = np.asarray(X, dtype=float) + n, D = X.shape + if n < 5: + raise ValueError("Need more samples for localPCA ID.") + if k >= n: + raise ValueError(f"k must be < n_samples (got k={k}, n={n}).") + + # Choose max_dim limit + max_possible = min(D, k - 1) # local covariance rank limited by k-1 if centered + if max_dim is None: + max_dim = max_possible + else: + max_dim = int(min(max_dim, max_possible)) + max_dim = max(1, max_dim) + + # kNN indices (exclude self by requesting k+1 and dropping first) + nn = NearestNeighbors(n_neighbors=k + 1, metric=metric) + nn.fit(X) + _, idx = nn.kneighbors(X, return_distance=True) + nbrs = idx[:, 1:] # (n, k) + + d_local = np.zeros(n, dtype=int) + + for i in range(n): + Xi = X[nbrs[i]] # (k, D) + if center: + Xi = Xi - Xi.mean(axis=0, keepdims=True) + + # PCA via SVD of neighborhood matrix + # Xi = U S Vt ; singular values S relate to eigenvalues of covariance + # covariance eigenvalues proportional to (S^2) / (k-1) + # we can work directly with S^2 + try: + # full_matrices=False keeps it fast + _, S, _ = np.linalg.svd(Xi, full_matrices=False) + except np.linalg.LinAlgError: + d_local[i] = 1 + continue + + lam = S**2 # proportional to variance along PCs + if lam.size == 0: + d_local[i] = 1 + continue + + lam = lam[: max_dim + 1] # for eigengap need d and d+1 + lam = np.maximum(lam, eps) + + if method == "explained_var": + lam_use = lam[:max_dim] + cum = np.cumsum(lam_use) + total = cum[-1] + if total <= eps: + d_local[i] = 1 + else: + frac = cum / total + d_local[i] = int(np.searchsorted(frac, var_threshold) + 1) + + elif method == "eigengap": + # need ratios up to max_dim-1: lam[d-1]/lam[d] + lam_use = lam[: max_dim + 1] # ensures lam[d] exists + if lam_use.size < 2: + d_local[i] = 1 + else: + ratios = lam_use[:-1] / lam_use[1:] + # pick d that maximizes ratio, d in [1..max_dim] + d_local[i] = int(np.argmax(ratios) + 1) + else: + raise ValueError(f"Unknown method: {method}") + + # aggregate + if agg == "median": + d_global = float(np.median(d_local)) + elif agg == "mean": + d_global = float(np.mean(d_local)) + elif agg == "trimmed_mean": + d_sorted = np.sort(d_local) + m = len(d_sorted) + lo = int(np.floor(trim * m)) + hi = int(np.ceil((1 - trim) * m)) + hi = max(hi, lo + 1) + d_global = float(np.mean(d_sorted[lo:hi])) + else: + raise ValueError(f"Unknown agg: {agg}") + + return d_global, d_local + + +def find_intrinsic_dim( + X, + y, + subj_label, + subjects, + method="SI", + n_neighbors_LE=125, + search_range_SI=range(2, 50, 5), + LE_embedding_method="embed+procrustes", + measure_is_state_based=False, +): + """ + Find the number of components to use for embedding the data using LE. + Find the average intrinsic dimension across all subjects. + + method: "SI" or "twonn" or "localpca" + + Returns: + intrinsic_dim: number of components to use for embedding + """ + if method == "SI": + intrinsic_dim_all = list() + for subject in subjects: + X_subj = X[subj_label == subject, :] + y_subj = y[subj_label == subject] + try: + # some subjects may not have enough samples to estimate the intrinsic dimension + subj_estim_ID = SI_ID( + X_subj, + y_subj, + search_range=search_range_SI, + n_neighbors_LE=n_neighbors_LE, + LE_embedding_method=LE_embedding_method, + measure_is_state_based=measure_is_state_based, + ) + intrinsic_dim_all.append(subj_estim_ID) + except Exception as e: + warnings.warn( + f"Error in SI_ID for subject {subject}: {e}. Skipping this subject." + ) + continue + intrinsic_dim = int(np.mean(intrinsic_dim_all)) + elif method == "twonn": + intrinsic_dim_all = list() + for subject in subjects: + X_subj = X[subj_label == subject, :] + intrinsic_dim_all.append( + twonn(X_subj, discard_ratio=0.1, metric="correlation") + ) + intrinsic_dim = int(np.median(intrinsic_dim_all)) + elif method == "localpca": + intrinsic_dim_all = list() + for subject in subjects: + X_subj = X[subj_label == subject, :] + intrinsic_dim_diff_k = list() + # seatryrch 0.2 * X_subj.shape[0] and 0.3 * X_subj.shape[0] for k + for k in range( + max(5, int(0.1 * X_subj.shape[0])), # not letting go below 5 + int(0.3 * X_subj.shape[0]), + 5, + ): + if k == 1: + warnings.warn( + f"Warning: k=1 is not valid for localpca_intrinsic_dim. Skipping k=1 for subject {subject}." + ) + continue + try: + d_global, _ = localpca_intrinsic_dim( + X_subj, + k=k, + method="explained_var", + var_threshold=0.9, + center=True, + metric="correlation", + agg="median", + ) + if np.isfinite(d_global) and d_global >= 1: + intrinsic_dim_diff_k.append(d_global) + except Exception as e: + warnings.warn( + f"Error in localpca_intrinsic_dim for subject {subject} with k={k}: {e}." + ) + continue + if len(intrinsic_dim_diff_k) == 0: + warnings.warn( + f"No valid intrinsic dimensions found for subject {subject}." + ) + continue + intrinsic_dim_all.append(int(np.mean(intrinsic_dim_diff_k))) + if len(intrinsic_dim_all) == 0: + raise ValueError("No valid intrinsic dimensions found for any subject.") + intrinsic_dim = int(np.median(intrinsic_dim_all)) + return intrinsic_dim + + +def LE_transform(X, n_components, n_neighbors, distance_metric="euclidean"): + """ + Apply Laplacian Eigenmaps (LE) to transform data into a lower dimensional space. + + if n_neighbors >= n_samples, n_neighbors will be changed to the lower limit n_neighbors + """ + n_neighbors_upper = int(X.shape[0] / 8) + + if n_neighbors > n_neighbors_upper: + n_neighbors_to_be_used = n_neighbors_upper + # raise a warning + warnings.warn( + f"n_neighbors is larger than the limit. n_neighbors is set to {n_neighbors_to_be_used}." + ) + else: + n_neighbors_to_be_used = n_neighbors + + affinity_matrix = kneighbors_graph( + X, + n_neighbors=n_neighbors_to_be_used, + mode="connectivity", + include_self=False, + metric=distance_metric, + ) + + # Symmetrize + affinity_matrix = affinity_matrix.maximum(affinity_matrix.T) + + LE = SpectralEmbedding( + n_components=n_components, + affinity="precomputed", + n_neighbors=n_neighbors_to_be_used, + eigen_solver="lobpcg", + ) + X_embed = LE.fit_transform(X=affinity_matrix) + return X_embed + + +def LE_transform_dFC(X, n_components, n_neighbors, distance_metric="euclidean"): + """ + Transform dFC features into a lower dimensional space using Laplacian Eigenmaps (LE). + This function takes care of the case where the dFC samples are not unique, + specifically for state-based dFC features. + """ + unique_samples = np.unique(X, axis=0) + # if there are repeated samples, we need to apply LE on the unique samples + if unique_samples.shape[0] < X.shape[0] // 2: + n_neighbors_LE = int(3 / 5 * unique_samples.shape[0]) + unique_samples_embedded = LE_transform( + X=unique_samples, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric=distance_metric, + ) + + # for each entry in X, put the corresponding entry in unique_samples_embedded + # in the corresponding position in X_embedded + X_embedded = np.zeros((X.shape[0], unique_samples_embedded.shape[1])) + for i, sample in enumerate(unique_samples): + idx = np.where((X == sample).all(axis=1))[0] + if len(idx) > 0: + X_embedded[idx] = unique_samples_embedded[i] + else: + # if all samples are unique, we can apply LE directly on the data + X_embedded = LE_transform( + X=X, + n_components=n_components, + n_neighbors=n_neighbors, + distance_metric=distance_metric, + ) + + return X_embedded + + +def LE_embed_procustes( + X_train, + X_test, + y_train, + y_test, + subj_label_train, + subj_label_test, + train_subjects, + test_subjects, + n_components=30, + n_neighbors_LE=125, + procruste_method="best_SI", +): + procrustes_limit = int(np.sqrt(2 * X_train.shape[0])) + if n_components > procrustes_limit: + warnings.warn( + f"n_components ({n_components}) is larger than the limit for procrustes method ({procrustes_limit}). Setting n_components to {procrustes_limit}." + ) + n_components = procrustes_limit - 1 + if procruste_method == "best_SI": + # first embed the dFC features of each subject into a lower dimensional space using LE separately + embed_dict = {} + for subject in train_subjects: + # assert the samples of the same subject are contiguous + assert np.all( + np.diff(np.where(subj_label_train == subject)[0]) == 1 + ), f"Indices of {subject} are not consecutive" + X_subj = X_train[subj_label_train == subject, :] + y_subj = y_train[subj_label_train == subject] + X_subj_embed = LE_transform_dFC( + X=X_subj, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric="correlation", + ) + SI = silhouette_score(X_subj_embed, y_subj) + embed_dict[subject] = {"X_subj_embed": X_subj_embed, "SI": SI} + + # find the best transformation based on the SI score + best_SI = -1 + best_subject = None + for subject in embed_dict: + if embed_dict[subject]["SI"] > best_SI: + best_SI = embed_dict[subject]["SI"] + best_subject = subject + + # apply procrustes transformation to align the embeddings of different subjects + # use the embeddings of the subject with the highest SI score as the reference + X_train_embed = None + for subject in train_subjects: + X_subj_embed = embed_dict[subject]["X_subj_embed"] + # procrustes transformation + if subject == best_subject: + X_subj_embed_transformed = X_subj_embed + else: + # for the procrustes transformation, the number of samples should be the same + X_best_subj_embed = precheck_for_procruste( + embed_dict[best_subject]["X_subj_embed"], X_subj_embed + ) + _, X_subj_embed_transformed, _ = procrustes( + X_best_subj_embed, X_subj_embed + ) + if X_train_embed is None: + X_train_embed = X_subj_embed_transformed + else: + X_train_embed = np.concatenate( + (X_train_embed, X_subj_embed_transformed), axis=0 + ) + + # apply the same transformation to the test set + X_test_embed = None + for subject in test_subjects: + # assert the samples of the same subject are contiguous + assert np.all( + np.diff(np.where(subj_label_test == subject)[0]) == 1 + ), f"Indices of {subject} are not consecutive" + X_subj = X_test[subj_label_test == subject, :] + X_subj_embed = LE_transform_dFC( + X=X_subj, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric="correlation", + ) + # procrustes transformation + # for the procrustes transformation, the number of samples should be the same + X_best_subj_embed = precheck_for_procruste( + embed_dict[best_subject]["X_subj_embed"], X_subj_embed + ) + _, X_subj_embed_transformed, _ = procrustes(X_best_subj_embed, X_subj_embed) + if X_test_embed is None: + X_test_embed = X_subj_embed_transformed + else: + X_test_embed = np.concatenate( + (X_test_embed, X_subj_embed_transformed), axis=0 + ) + + elif procruste_method == "generalized": + # in this method we use generalized procrustes analysis to align the embeddings of different subjects + # first embed the dFC features of each subject into a lower dimensional space using LE separately + embed_dict = {} + for subject in train_subjects: + # assert the samples of the same subject are contiguous + assert np.all( + np.diff(np.where(subj_label_train == subject)[0]) == 1 + ), f"Indices of {subject} are not consecutive" + X_subj = X_train[subj_label_train == subject, :] + X_subj_embed = LE_transform_dFC( + X=X_subj, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric="correlation", + ) + embed_dict[subject] = X_subj_embed + + mean_X_train = generalized_procrustes(embed_dict) + + X_train_embed = None + for subject in train_subjects: + X_subj_embed = embed_dict[subject] + mean_X_train_new_size = precheck_for_procruste(mean_X_train, X_subj_embed) + _, X_subj_embed_transformed, _ = procrustes( + mean_X_train_new_size, X_subj_embed + ) + if X_train_embed is None: + X_train_embed = X_subj_embed_transformed + else: + X_train_embed = np.concatenate( + (X_train_embed, X_subj_embed_transformed), axis=0 + ) + + X_test_embed = None + for subject in test_subjects: + X_subj = X_test[subj_label_test == subject, :] + X_subj_embed = LE_transform_dFC( + X=X_subj, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric="correlation", + ) + mean_X_train_new_size = precheck_for_procruste(mean_X_train, X_subj_embed) + _, X_subj_embed_transformed, _ = procrustes( + mean_X_train_new_size, X_subj_embed + ) + if X_test_embed is None: + X_test_embed = X_subj_embed_transformed + else: + X_test_embed = np.concatenate( + (X_test_embed, X_subj_embed_transformed), axis=0 + ) + + return X_train_embed, X_test_embed + + +def rows_look_redundant(X, sample=100): + n = X.shape[0] + if n > sample: + idx = np.random.choice(n, sample, replace=False) + Xs = X[idx] + else: + Xs = X + # Hash rows quickly + h = np.apply_along_axis(lambda r: hash(r.tobytes()), 1, Xs) + # If more than, say, 50% duplicates -> likely state-based + return (len(h) - len(set(h))) / len(h) > 0.5 + + +class PLSEmbedder(BaseEstimator, TransformerMixin): + """ + Supervised dimensionality reduction using PLSRegression. + Returns X scores (latent components) for downstream models. + + Notes: + - Works for binary y (0/1) and also continuous y (regression-style PLS). + - For classification, y should typically be 0/1 or {-1,1}. + """ + + def __init__(self, n_components=10, scale=False): + self.n_components = n_components + self.scale = scale + + def fit(self, X, y): + X = np.asarray(X) + y = np.asarray(y).ravel().reshape(-1, 1) + + if X.shape[0] != y.shape[0]: + raise ValueError(f"X has {X.shape[0]} rows but y has {y.shape[0]}.") + + # optional internal scaling (usually OFF if pipeline already scales) + if self.scale: + self.scaler_ = StandardScaler(with_mean=True, with_std=True) + Xs = self.scaler_.fit_transform(X) + else: + self.scaler_ = None + Xs = X + + # safety: cap n_components for this fold + nmax = min(Xs.shape[0] - 1, Xs.shape[1]) + ncomp = int(self.n_components) + if ncomp > nmax: + raise ValueError( + f"n_components={ncomp} is too large for fold with " + f"n_samples={Xs.shape[0]}, n_features={Xs.shape[1]} (max {nmax})." + ) + + self.model_ = PLSRegression(n_components=ncomp, scale=False) + self.model_.fit(Xs, y) + return self + + def transform(self, X): + if not hasattr(self, "model_"): + raise RuntimeError("PLSEmbedder is not fitted yet.") + + X = np.asarray(X) + Xs = self.scaler_.transform(X) if self.scaler_ is not None else X + + # Out-of-sample scores + Z = Xs @ self.model_.x_rotations_ + return Z.astype(np.float32, copy=False) + + +def subject_center(X, subj_labels, mode="zscore"): + Xc = np.zeros_like(X) + for subj in np.unique(subj_labels): + idx = subj_labels == subj + if mode == "demean": + Xc[idx] = X[idx] - X[idx].mean(axis=0, keepdims=True) + elif mode == "zscore": + mu = X[idx].mean(axis=0, keepdims=True) + sd = X[idx].std(axis=0, keepdims=True) + 1e-6 + Xc[idx] = (X[idx] - mu) / sd + return Xc + + +def select_num_components_binary_groupcv( + X, + y, + groups, + embedding_method="PLS", + n_list=(2, 5, 10, 15, 20), + cv=3, + random_state=0, +): + """ + Select number of PLS/PCA components using subject-aware CV. + + Parameters + ---------- + X : array (n_samples, n_features) + y : array (n_samples,) binary labels + groups : array (n_samples,) subject IDs + embedding_method : "PLS" or "PCA" + n_list : iterable of candidate n_components + cv : number of folds + random_state : int + + Returns + ------- + best_n : int + Selected number of PLS/PCA components + best_score : float + Mean CV balanced accuracy + """ + + X = np.asarray(X) + y = np.asarray(y).ravel() + groups = np.asarray(groups) + + cv_splitter = StratifiedGroupKFold( + n_splits=cv, shuffle=True, random_state=random_state + ) + + best_n, best_score = None, -np.inf + + for n in n_list: + fold_scores = [] + + for tr, va in cv_splitter.split(X, y, groups): + # ---- embedding (trained ONLY on train fold subjects) + if embedding_method == "PCA": + emb = PCA(n_components=n, svd_solver="full", whiten=False) + Ztr = emb.fit_transform(X[tr]) + Zva = emb.transform(X[va]) + elif embedding_method == "PLS": + emb = PLSEmbedder(n_components=n, scale=True) + Ztr = emb.fit_transform(X[tr], y[tr]) + Zva = emb.transform(X[va]) + + # ---- classifier in latent space + clf = make_pipeline( + StandardScaler(), + SVC(kernel="rbf", C=1.0, gamma="scale"), + ) + clf.fit(Ztr, y[tr]) + pred = clf.predict(Zva) + + fold_scores.append(balanced_accuracy_score(y[va], pred)) + + mean_score = float(np.mean(fold_scores)) + + if mean_score > best_score: + best_score = mean_score + best_n = n + + return best_n, best_score + + +def select_num_components_continuous_groupcv( + X, + y, + groups, + embedding_method="PLS", + n_list=(2, 5, 10, 15, 20), + cv=3, + score="r2", # "r2" or "neg_mse" +): + """ + Select number of PLS/PCA components using subject-aware CV for a CONTINUOUS target. + + Parameters + ---------- + X : array (n_samples, n_features) + y : array (n_samples,) continuous target + groups : array (n_samples,) subject IDs + embedding_method : "PLS" or "PCA" + n_list : iterable of candidate n_components + cv : number of folds + score : "r2" or "neg_mse" + + Returns + ------- + best_n : int + Selected number of PLS/PCA components + best_score : float + Mean CV score (higher is better) + - R² if score="r2" + - negative MSE if score="neg_mse" + """ + + X = np.asarray(X) + y = np.asarray(y).ravel() # regression target must be 1D for SVR + groups = np.asarray(groups) + + if score not in ("r2", "neg_mse"): + raise ValueError("score must be 'r2' or 'neg_mse'.") + + cv_splitter = GroupKFold(n_splits=cv) + + best_n, best_score = None, -np.inf + + for n in n_list: + fold_scores = [] + + for tr, va in cv_splitter.split(X, y, groups): + # ---- embedding (trained ONLY on train fold subjects) + if embedding_method == "PCA": + emb = PCA(n_components=n, svd_solver="full", whiten=False) + Ztr = emb.fit_transform(X[tr]) + Zva = emb.transform(X[va]) + elif embedding_method == "PLS": + emb = PLSEmbedder(n_components=n, scale=True) + # PLSRegression expects y 2D + Ztr = emb.fit_transform(X[tr], y[tr].reshape(-1, 1)) + Zva = emb.transform(X[va]) + # ---- regressor in latent space + reg = make_pipeline( + StandardScaler(), + SVR(kernel="rbf", C=1.0, gamma="scale"), + ) + reg.fit(Ztr, y[tr]) + pred = reg.predict(Zva) + + if score == "r2": + fold_scores.append(r2_score(y[va], pred)) + else: + fold_scores.append(-mean_squared_error(y[va], pred)) + + mean_score = float(np.mean(fold_scores)) + if mean_score > best_score: + best_score, best_n = mean_score, n + + return best_n, best_score + + +def embed_dFC_features( + train_subjects, + test_subjects, + X_train, + X_test, + y_train, + y_test, + subj_label_train, + subj_label_test, + embedding="PCA", + n_components="auto", + n_neighbors_LE=125, + LE_embedding_method="embed+procrustes", + measure_is_state_based=False, + y_continuous=False, +): + """ + Embed the dFC features into a lower dimensional space using PCA, or PLS. For PLS, it assumes that the samples of the same subject are contiguous. + + for LE, first the LE is applied on each subj separately and then the procrustes transformation is applied to align the embeddings of different subjects. + All the subjects are transformed into the space of the subject with the highest silhouette score. + + LE_embedding_method: "concat+embed" or "embed+procrustes" + if the dFC features are not unique (state-based), "embed+procrustes" will not work. So this function will switch to "concat+embed" method. + """ + # make a copy of the data + X_train = X_train.copy() + if X_test is not None: + X_test = X_test.copy() + + # preprocess the data by standardizing it + if embedding in ("PCA", "PLS"): + # center the data by subject before PLS to remove subject effects + X_train_c = subject_center(X_train, subj_label_train, mode="zscore") + if X_test is not None: + X_test_c = subject_center(X_test, subj_label_test, mode="zscore") + else: + X_test_c = None + scaler = StandardScaler(with_mean=True, with_std=True) + X_train_preproc = scaler.fit_transform(X_train_c) + if X_test is not None: + X_test_preproc = scaler.transform(X_test_c) + else: + X_test_preproc = None + + if n_components == "auto": + if y_continuous: + best_n, _ = select_num_components_continuous_groupcv( + X=X_train_preproc, + y=y_train, + groups=subj_label_train, + embedding_method=embedding, + n_list=[ + 2, + 3, + 4, + 5, + 10, + 15, + 20, + 25, + 30, + 40, + 50, + ], # you can adjust this range based on your data + cv=5, # more stable + score="r2", + ) + else: + best_n, _ = select_num_components_binary_groupcv( + X=X_train_preproc, + y=y_train, + groups=subj_label_train, + embedding_method=embedding, + n_list=[ + 2, + 3, + 4, + 5, + 10, + 15, + 20, + 25, + 30, + 40, + 50, + ], # you can adjust this range based on your data + cv=5, # more stable + ) + n_components = best_n + + if embedding == "PCA": + pca = PCA(n_components=n_components, svd_solver="full", whiten=False) + pca.fit(X_train_preproc) + X_train_embed = pca.transform(X_train_preproc) + if X_test is not None: + X_test_embed = pca.transform(X_test_preproc) + else: + X_test_embed = None + elif embedding == "PLS": + pls = PLSEmbedder(n_components=n_components, scale=True) + # fit on train set + X_train_embed = pls.fit_transform(X_train_preproc, y_train) + # only transform test set + if X_test is not None: + X_test_embed = pls.transform(X_test_preproc) + else: + X_test_embed = None + elif embedding == "LE": + # if the dFC features are not unique (state-based), set the LE_embedding_method to "concat+embed" + if measure_is_state_based: + if LE_embedding_method == "embed+procrustes": + warnings.warn( + "The dFC features are not unique (state-based). Switching to 'concat+embed' method." + ) + LE_embedding_method = "concat+embed" + # if n_components is not specified, find the intrinsic dimension of the data using training set and based on the silhouette score + if n_components == "auto": + if LE_embedding_method == "embed+procrustes": + # find the list of time lengths across subjects + n_time_across_subj = [ + np.sum(subj_label_train == subj) for subj in train_subjects + ] + # find the minimum time length across subjects + min_time_length = min(n_time_across_subj) + # set the search range based on the minimum time length + procrustes_limit = int(np.sqrt(2 * min_time_length)) + if procrustes_limit < 50 and procrustes_limit > 10: + search_range_SI = range(2, procrustes_limit, 2) + elif procrustes_limit <= 10: + search_range_SI = range(2, procrustes_limit) + else: + search_range_SI = range(2, 50, 5) + else: + if X_train.shape[0] < 7: + search_range_SI = range(2, X_train.shape[1] + 1) + elif X_train.shape[1] < 24: + search_range_SI = range(2, X_train.shape[1] + 1, 2) + else: + search_range_SI = range(2, 50, 5) + n_components = find_intrinsic_dim( + X=X_train, + y=y_train, + subj_label=subj_label_train, + subjects=train_subjects, + method="localpca", + n_neighbors_LE=n_neighbors_LE, + search_range_SI=search_range_SI, + LE_embedding_method=LE_embedding_method, + measure_is_state_based=measure_is_state_based, + ) + + if LE_embedding_method == "embed+procrustes": + X_train_embed, X_test_embed = LE_embed_procustes( + X_train=X_train, + X_test=X_test, + y_train=y_train, + y_test=y_test, + subj_label_train=subj_label_train, + subj_label_test=subj_label_test, + train_subjects=train_subjects, + test_subjects=test_subjects, + n_components=n_components, + n_neighbors_LE=n_neighbors_LE, + procruste_method="generalized", + ) + elif LE_embedding_method == "concat+embed": + # since SpectralEmbedding does not have transform method, we need to fit the LE on the whole data + # but note that this method is used mostly for state-based dFC features, and in this case the + # samples are the same across subjects, so we can concatenate the training and test sets + # and then apply LE on the concatenated data + if X_test is not None: + X_concat = np.concatenate((X_train, X_test), axis=0) + else: + X_concat = X_train + X_concat_embed = LE_transform_dFC( + X=X_concat, + n_components=n_components, + n_neighbors=n_neighbors_LE, + distance_metric="correlation", + ) + X_train_embed = X_concat_embed[: X_train.shape[0], :] + if X_test is not None: + X_test_embed = X_concat_embed[X_train.shape[0] :, :] + else: + X_test_embed = None + else: + raise ValueError(f"Unknown embedding method: {embedding}") + + # to make computation faster, we can return the embeddings as float32 + X_train_embed = X_train_embed.astype(np.float32, copy=False) + if X_test_embed is not None: + X_test_embed = X_test_embed.astype(np.float32, copy=False) + return X_train_embed, X_test_embed + + +################################# Classification Framework Functions #################################### + + +def get_classification_results( + X_train, + X_test, + y_train, + y_test, + classifier_model=None, +): + """ + Get classification results for a given classifier. + This function fits the classifier, predicts the labels for train and test sets, + and calculates the balanced accuracy score, recall, precision, and f1 for both sets. + + cloning ensures that the classifier is not fitted and the original classifier remains unchanged. + """ + classifier_model = clone(classifier_model) + classifier_model.fit(X_train, y_train) + y_train_pred = classifier_model.predict(X_train) + y_test_pred = classifier_model.predict(X_test) + + RESULT = { + "model": classifier_model, + "train": { + "balanced accuracy": balanced_accuracy_score(y_train, y_train_pred), + "recall": recall_score(y_train, y_train_pred), + "precision": precision_score(y_train, y_train_pred), + "f1": f1_score(y_train, y_train_pred), + }, + "test": { + "balanced accuracy": balanced_accuracy_score(y_test, y_test_pred), + "recall": recall_score(y_test, y_test_pred), + "precision": precision_score(y_test, y_test_pred), + "f1": f1_score(y_test, y_test_pred), + }, + } + return RESULT + + +def logistic_regression_classify( + X_train, + y_train, + X_test, + y_test, + subj_label_train=None, + embedding_method=None, +): + + if embedding_method == "PCA": + emb = PCA(whiten=False, svd_solver="full", random_state=0) + elif embedding_method == "PLS": + emb = PLSEmbedder(scale=False) # IMPORTANT: avoid double scaling + elif embedding_method is None: + emb = None + else: + raise ValueError("embedding_method must be 'PCA' or 'PLS'.") + + if emb is not None: + # Grid (keep small!) + param_grid = { + "emb__n_components": [5, 10, 20, 30, 50, 100], + "lr__C": [0.001, 0.01, 0.1, 1, 10, 100], + } + else: + param_grid = {"lr__C": [0.001, 0.01, 0.1, 1, 10, 100]} + + steps = [("scaler", StandardScaler())] + + if emb is not None: + steps.append(("emb", emb)) + + steps.append( + ("lr", LogisticRegression(penalty="l1", solver="saga", max_iter=2000, tol=1e-3)) + ) + + pipe = Pipeline(steps) + + # CV splitter + if subj_label_train is None: + Xs, ys = shuffle(X_train, y_train, random_state=0) + cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0) + fit_kwargs = {} + else: + Xs, ys, gs = shuffle(X_train, y_train, subj_label_train, random_state=0) + cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=0) + fit_kwargs = {"groups": gs} + + # GridSearch on training subjects only + gscv = GridSearchCV(pipe, param_grid, cv=cv, n_jobs=1, scoring="balanced_accuracy") + gscv.fit(Xs, ys, **fit_kwargs) + + # Evaluate with best estimator (already refit on full training set by default) + model = gscv.best_estimator_ + + RESULT = get_classification_results( + X_train=X_train, + X_test=X_test, + y_train=y_train, + y_test=y_test, + classifier_model=model, + ) + RESULT["best_params"] = gscv.best_params_ + return RESULT + + +def SVM_classify( + X_train, + y_train, + X_test, + y_test, + subj_label_train=None, + embedding_method=None, +): + if embedding_method == "PCA": + emb = PCA(whiten=False, svd_solver="full", random_state=0) + elif embedding_method == "PLS": + emb = PLSEmbedder(scale=False) # IMPORTANT: avoid double scaling + elif embedding_method is None: + emb = None + else: + raise ValueError("embedding_method must be 'PCA' or 'PLS'.") + + if emb is not None: + # Grid (keep small!) + param_grid = { + "emb__n_components": [5, 10, 20, 30, 50, 100], + "svc__C": [0.1, 1, 10], + "svc__gamma": ["scale", 0.01, 0.1], + } + else: + param_grid = { + "svc__C": [0.1, 1, 10], + "svc__gamma": ["scale", 0.01, 0.1], + } + + steps = [("scaler", StandardScaler())] + + if emb is not None: + steps.append(("emb", emb)) + + steps.append(("svc", SVC(kernel="rbf"))) + + pipe = Pipeline(steps) + + # CV splitter + if subj_label_train is None: + Xs, ys = shuffle(X_train, y_train, random_state=0) + cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0) + fit_kwargs = {} + else: + Xs, ys, gs = shuffle(X_train, y_train, subj_label_train, random_state=0) + cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=0) + fit_kwargs = {"groups": gs} + + # GridSearch on training subjects only + gscv = GridSearchCV(pipe, param_grid, cv=cv, n_jobs=1, scoring="balanced_accuracy") + gscv.fit(Xs, ys, **fit_kwargs) + + # Evaluate with best estimator (already refit on full training set by default) + model = gscv.best_estimator_ + + RESULT = get_classification_results( + X_train=X_train, + X_test=X_test, + y_train=y_train, + y_test=y_test, + classifier_model=model, + ) + RESULT["best_params"] = gscv.best_params_ + return RESULT + + +def group_permutation(y, groups, permute_groups=True): + """ + Permute the labels while keeping the group structure intact. + This is useful for permutation tests where we want to keep the group structure. + Also permute the order of groups if permute_groups is True. + If permute_groups is False, the labels within each group are permuted but the order of groups is not changed. + This function assumes that all samples in a group have the same label. + """ + # make sure groups is a numpy array + groups = np.array(groups, copy=True) + y = np.copy(y) + + unique_groups = np.unique(groups) + + # Step 1: Create a mapping from groups to labels + group_to_label = {group: y[groups == group] for group in unique_groups} + + # Step 2: Permute each group labels + group_to_permuted_label = {} + for group in unique_groups: + group_to_permuted_label[group] = np.random.permutation(group_to_label[group]) + + # Step 3: Reconstruct permuted y based on groups + # also shuffle the order of groups if permute_groups is True + if permute_groups: + unique_groups_permuted = np.random.permutation(unique_groups) + else: + unique_groups_permuted = unique_groups + y_permuted = list() + for group in unique_groups_permuted: + # For each group, append the permuted label to y_permuted + y_permuted.extend(group_to_permuted_label[group]) + # Convert to numpy array + y_permuted = np.array(y_permuted) + + assert ( + y_permuted.shape == y.shape + ), f"Permuted labels shape {y_permuted.shape} does not match original labels shape {y.shape}" + + return y_permuted + + +def get_permutation_scores( + X_train, + y_train, + X_test, + y_test, + classifier_model, + groups_train=None, + n_permutations=100, +): + """ + Get permutation scores for a given classifier and data. + cloning ensures that the classifier is not previously fitted. + """ + # first get the true balanced accuracy scores from original data + classifier_original = clone(classifier_model) + classifier_original.fit(X_train, y_train) + y_train_pred = classifier_original.predict(X_train) + y_test_pred = classifier_original.predict(X_test) + + # next calculate the balanced accuracy scores for permuted data + permutation_train_scores = [] + permutation_test_scores = [] + for _ in range(n_permutations): + if groups_train is not None: + # permute the labels while keeping the group structure intact + y_train_permuted = group_permutation(y_train, groups_train) + else: + y_train_permuted = np.random.permutation(y_train) + model_permuted = clone(classifier_model) + model_permuted.fit(X_train, y_train_permuted) + + y_train_permuted_pred = model_permuted.predict(X_train) + y_test_permuted_pred = model_permuted.predict(X_test) + permutation_train_scores.append( + balanced_accuracy_score(y_train_permuted, y_train_permuted_pred) + ) + permutation_test_scores.append( + balanced_accuracy_score(y_test, y_test_permuted_pred) + ) + p_value_train = ( + np.sum( + np.array(permutation_train_scores) + >= balanced_accuracy_score(y_train, y_train_pred) + ) + + 1 + ) / (len(permutation_train_scores) + 1) + p_value_test = ( + np.sum( + np.array(permutation_test_scores) + >= balanced_accuracy_score(y_test, y_test_pred) + ) + + 1 + ) / (len(permutation_test_scores) + 1) + + return permutation_train_scores, permutation_test_scores, p_value_train, p_value_test + + +def softmax(x, tau=1.0, axis=1): + z = (x - np.max(x, axis=axis, keepdims=True)) / float(tau) + np.exp(z, out=z) + z_sum = np.sum(z, axis=axis, keepdims=True) + z /= z_sum + return z + + +def clip_and_renorm(P, eps=1e-6, axis=1): + P = np.asarray(P, float) + P = np.clip(P, eps, None) + P /= P.sum(axis=axis, keepdims=True) + return P + + +# ---- log-ratio transforms ---- +def clr_transform(P, eps=1e-6): + """Centered log-ratio: log(p) - mean(log(p)) row-wise.""" + P = clip_and_renorm(P, eps=eps) + L = np.log(P) + return L - L.mean(axis=1, keepdims=True) # each row sums to 0 + + +def ilr_transform(P, eps=1e-6): + """Pivot ILR using an orthonormal basis; returns (n, K-1).""" + P = clip_and_renorm(P, eps=eps) + L = np.log(P) + clr = L - L.mean(axis=1, keepdims=True) + K = P.shape[1] + V = np.zeros((K, K - 1)) + # Pivot coordinates basis (orthonormal in Aitchison geometry) + for j in range(1, K): + V[:j, j - 1] = 1 / j + V[j, j - 1] = -1 + V[:, j - 1] *= np.sqrt(j / (j + 1)) + return clr @ V # (n, K-1) + + +def process_SB_features(X, measure_name): + """ + Process state-based features for a given measure. + + The process involves applying a softmax function followed by an ILR transform. + This is to ensure that the features are properly normalized and transformed for subsequent analysis. + + State-based feature vectors are compositional (non-negative and sum-to-one). We therefore analyze + them in the Aitchison geometry and apply the isometric log-ratio (ILR) transformation (K−1 coordinates). + The output has K−1 dimensions. + """ + tau = 1.0 # temperature; 0.5–2.0 is typical + + X_transformed = None + if measure_name in ["CAP", "Clustering"]: + X_transformed = softmax(-X, tau=tau) + # 2) ILR transform + X_transformed = ilr_transform(X_transformed) + elif measure_name in ["ContinuousHMM", "DiscreteHMM", "Windowless"]: + X_transformed = ilr_transform(X) + return X_transformed + + +def get_classification_scores( + target, + pred, +): + """ + Get classification scores for a given target and predicted labels. + Returns a dictionary with these metrics: + - accuracy + - balanced accuracy + - recall + - precision + - f1 score + - fp, fn, tp, tn + - average precision + """ + tn, fp, fn, tp = confusion_matrix(target, pred).ravel() + scores = { + "accuracy": accuracy_score(target, pred), + "balanced accuracy": balanced_accuracy_score(target, pred), + "recall": recall_score(target, pred), + "precision": precision_score(target, pred), + "f1": f1_score(target, pred), + "fp": fp, + "fn": fn, + "tp": tp, + "tn": tn, + "average precision": average_precision_score(target, pred), + } + return scores + + +def task_presence_classification( + task, + dFC_id, + roi_root, + dFC_root, + run=None, + session=None, + dynamic_pred="no", + normalize_dFC=False, + train_test_ratio=0.8, +): + """ + perform task presence classification using logistic regression, SVM, KNN, Random Forest, Gradient Boosting + for a given task and dFC method and run. + """ + if run is None: + print(f"=============== {task} ===============") + else: + print(f"=============== {task} {run} ===============") + + if task == "task-restingstate": + return + + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, task=task, run=run, session=session, dFC_id=dFC_id + ) + + # randomly select train_test_ratio of the subjects for training + # and rest for testing using numpy.random.choice + train_subjects = np.random.choice( + SUBJECTS, int(train_test_ratio * len(SUBJECTS)), replace=False + ) + test_subjects = np.setdiff1d(SUBJECTS, train_subjects) + print( + f"Number of train subjects: {len(train_subjects)} and test subjects: {len(test_subjects)}" + ) + + ( + X_train, + X_test, + y_train, + y_test, + subj_label_train, + subj_label_test, + measure_name, + measure_is_state_based, + ) = dFC_feature_extraction( + task=task, + train_subjects=train_subjects, + test_subjects=test_subjects, + dFC_id=dFC_id, + roi_root=roi_root, + dFC_root=dFC_root, + run=run, + session=session, + dynamic_pred=dynamic_pred, + normalize_dFC=normalize_dFC, + FCS_proba_for_SB=True, # for state-based dFC features, we use FCS_proba + ) + + if measure_is_state_based: + X_train = process_SB_features(X=X_train, measure_name=measure_name) + X_test = process_SB_features(X=X_test, measure_name=measure_name) + + # center the data by subject before embedding to remove subject effects + # separately for train and test sets to avoid data leakage + # for both state-based and state-free methods + X_train = subject_center(X_train, subj_label_train, mode="demean") + X_test = subject_center(X_test, subj_label_test, mode="demean") + + ML_scores = { + "group_lvl": { + "task": list(), + "run": list(), + "dFC method": list(), + "embedding": list(), + "group": list(), + "SI": list(), + }, + "subj_lvl": { + "subj_id": list(), + "group": list(), + "SI": list(), + "task": list(), + "run": list(), + "dFC method": list(), + "embedding": list(), + }, + } + + EMBEDDINGS = ["PCA", "PLS"] + check_count = len(EMBEDDINGS) + num_excluded_subjects = 0 + for embedding in EMBEDDINGS: + if measure_is_state_based: + embedding_to_use = None + else: + embedding_to_use = embedding + + # check if both classes are present in train and test sets + if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2: + print( + f"Only one class present in train or test sets for {embedding}. Skipping..." + ) + check_count -= 1 + continue + + # task presence classification + + print("task presence classification ...") + + # logistic regression + log_reg_RESULT = logistic_regression_classify( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + subj_label_train=subj_label_train, + embedding_method=embedding_to_use, + ) + + # SVM + SVM_RESULT = SVM_classify( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + subj_label_train=subj_label_train, + embedding_method=embedding_to_use, + ) + + ML_models = {"Logistic regression": log_reg_RESULT, "SVM": SVM_RESULT} + + # Silhouette score + # SI does not need to be separated for train and test sets + # we will use the same SI for both train and test sets + # using all samples from train and test sets + # use the embedding and scaler trained in SVM_RESULT["model"] + # so the results are comparable to the classification scores + scaler = SVM_RESULT["model"].named_steps["scaler"] + embedding_model = SVM_RESULT["model"].named_steps.get("emb", None) + if embedding_model is not None: + X_train_embedded = embedding_model.transform(scaler.transform(X_train)) + X_test_embedded = embedding_model.transform(scaler.transform(X_test)) + else: + X_train_embedded = scaler.transform(X_train) + X_test_embedded = scaler.transform(X_test) + + X_combined = np.concatenate((X_train_embedded, X_test_embedded), axis=0) + y_combined = np.concatenate((y_train, y_test), axis=0) + + SI = { + "train": silhouette_score(X_combined, y_combined), + "test": silhouette_score(X_combined, y_combined), + } + + # # permutation tests + # permutation_scores = { + # "train": {}, + # "test": {}, + # } + # for model_name in ML_models: + # ( + # permutation_train_scores, + # permutation_test_scores, + # p_value_train, + # p_value_test, + # ) = get_permutation_scores( + # X_train=X_train_embedded, + # y_train=y_train, + # X_test=X_test_embedded, + # y_test=y_test, + # classifier_model=ML_models[model_name]["model"], + # groups_train=subj_label_train, + # n_permutations=100, + # ) + # permutation_scores["train"][ + # f"{model_name} permutation p_value" + # ] = p_value_train + # permutation_scores["train"][f"{model_name} permutation score mean"] = np.mean( + # permutation_train_scores + # ) + # permutation_scores["train"][f"{model_name} permutation score std"] = np.std( + # permutation_train_scores + # ) + # permutation_scores["test"][f"{model_name} permutation p_value"] = p_value_test + # permutation_scores["test"][f"{model_name} permutation score mean"] = np.mean( + # permutation_test_scores + # ) + # permutation_scores["test"][f"{model_name} permutation score std"] = np.std( + # permutation_test_scores + # ) + + # group level scores + for group in ["train", "test"]: + + ML_scores["group_lvl"]["group"].append(group) + ML_scores["group_lvl"]["embedding"].append(embedding) + ML_scores["group_lvl"]["task"].append(task) + ML_scores["group_lvl"]["run"].append(run) + ML_scores["group_lvl"]["dFC method"].append(measure_name) + # SI + ML_scores["group_lvl"]["SI"].append(SI[group]) + + for model_name in ML_models: + # accuracy score + for metric in ML_models[model_name][group]: + if not f"{model_name} {metric}" in ML_scores["group_lvl"]: + ML_scores["group_lvl"][f"{model_name} {metric}"] = list() + ML_scores["group_lvl"][f"{model_name} {metric}"].append( + ML_models[model_name][group][metric] + ) + + # # permutation test results + # for key in permutation_scores[group]: + # if not key in ML_scores["group_lvl"]: + # ML_scores["group_lvl"][key] = list() + # ML_scores["group_lvl"][key].append(permutation_scores[group][key]) + + # subject level scores + for subj in SUBJECTS: + if subj in train_subjects: + subj_group = "train" + features = X_train[subj_label_train == subj, :] + target = y_train[subj_label_train == subj] + elif subj in test_subjects: + subj_group = "test" + features = X_test[subj_label_test == subj, :] + target = y_test[subj_label_test == subj] + # check if only one class is present, skip the subject + if len(np.unique(target)) < 2: + num_excluded_subjects += 1 + continue + ML_scores["subj_lvl"]["group"].append(subj_group) + ML_scores["subj_lvl"]["subj_id"].append(subj) + + # Silhouette score + ML_scores["subj_lvl"]["SI"].append(silhouette_score(features, target)) + # measure pred score using different metrics on each subj + for model_name in ML_models: + model = ML_models[model_name]["model"] + pred = model.predict(features) + scores = get_classification_scores(target=target, pred=pred) + + for metric in scores: + if not f"{model_name} {metric}" in ML_scores["subj_lvl"]: + ML_scores["subj_lvl"][f"{model_name} {metric}"] = list() + ML_scores["subj_lvl"][f"{model_name} {metric}"].append(scores[metric]) + + ML_scores["subj_lvl"]["task"].append(task) + ML_scores["subj_lvl"]["run"].append(run) + ML_scores["subj_lvl"]["dFC method"].append(measure_name) + ML_scores["subj_lvl"]["embedding"].append(embedding) + + # sanity check of the ML_scores + L = None + for key in ML_scores["group_lvl"]: + if L is None: + L = len(ML_scores["group_lvl"][key]) + else: + assert ( + len(ML_scores["group_lvl"][key]) == L + ), f"Length of {key} is not equal to others." + + # L is supposed to be equal to 3 embeddings (PCA, PLS, and LE) * 2 groups (train and test) + assert ( + L == check_count * 2 + ), f"Length of group_lvl is not equal to {check_count * 2}, but {L}." + + L = None + for key in ML_scores["subj_lvl"]: + if L is None: + L = len(ML_scores["subj_lvl"][key]) + else: + assert ( + len(ML_scores["subj_lvl"][key]) == L + ), f"Length of {key} is not equal to others." + + # L is supposed to be equal to number of subjects * 3 embeddings (PCA, PLS, and LE) + assert ( + L == len(SUBJECTS) * check_count - num_excluded_subjects + ), f"Length of subj_lvl is not equal to {len(SUBJECTS) * check_count - num_excluded_subjects}, but {L}." + + return ML_scores diff --git a/pydfc/report_util.py b/pydfc/report_util.py new file mode 100644 index 0000000..82f8d82 --- /dev/null +++ b/pydfc/report_util.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +""" +Functions to facilitate reporting. + +Created on Feb 5 2025 +@author: Mohammad Torabi +""" + +import os + +import matplotlib.pyplot as plt +import seaborn as sns + +################################# Parameters #################################### + +fig_dpi = 120 +fig_bbox_inches = "tight" +fig_pad = 0.1 +show_title = True +save_fig_format = "png" # pdf, png, + +########## Plotting Classification Results Functions ########## + + +def plot_classification_metrics( + dataframe, ML_algorithm, pred_metric, title, suffix, output_dir +): + """ + This function plots these metrics: + - accuracy + - balanced accuracy + - precision + - recall + - f1 score (f1) + - true positive (tp) + - true negative (tn) + - false positive (fp) + - false negative (fn) + - average precision + """ + + plt.figure(figsize=(10, 5)) + + g = sns.pointplot( + data=dataframe, + x="dFC method", + y=f"{ML_algorithm} {pred_metric}", + hue="group", + hue_order=["train", "test"], + errorbar="sd", + linestyle="none", + dodge=True, + capsize=0.1, + ) + plt.xlabel(g.get_xlabel(), fontweight="bold") + plt.ylabel(g.get_ylabel(), fontweight="bold") + plt.xticks(fontweight="bold") + plt.yticks(fontweight="bold") + if pred_metric == "balanced accuracy": + # add a horizontal line at 0.5 corresponding to chance level + g.axhline(0.5, color="r", linestyle="--") + if not pred_metric in ["fp", "fn", "tp", "tn"]: + # set the y-axis upper limit to 1, but not set the lower limit + g.set(ylim=(None, 1)) + if show_title: + g.set_title(title, fontdict={"fontsize": 10, "fontweight": "bold"}) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + pred_metric_no_space = pred_metric.replace(" ", "_") + plt.savefig( + f"{output_dir}/classification_{pred_metric_no_space}_{suffix}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + + +def plot_clustering_metrics(dataframe, metric, title, suffix, output_dir): + """ + This function plots these metrics: + - SI + """ + + plt.figure(figsize=(10, 5)) + + g = sns.pointplot( + data=dataframe, + x="dFC method", + y=f"{metric}", + hue="group", + hue_order=["train", "test"], + errorbar="sd", + linestyle="none", + dodge=True, + capsize=0.1, + ) + plt.xlabel(g.get_xlabel(), fontweight="bold") + plt.ylabel(g.get_ylabel(), fontweight="bold") + plt.xticks(fontweight="bold") + plt.yticks(fontweight="bold") + + # set the y-axis upper limit to 1, but not set the lower limit + g.set(ylim=(None, 1)) + + if show_title: + g.set_title(title, fontdict={"fontsize": 10, "fontweight": "bold"}) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + metric_no_space = metric.replace(" ", "_") + plt.savefig( + f"{output_dir}/clustering_{metric_no_space}_{suffix}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() diff --git a/pydfc/simul_utils.py b/pydfc/simul_utils.py new file mode 100644 index 0000000..bf6d9a4 --- /dev/null +++ b/pydfc/simul_utils.py @@ -0,0 +1,447 @@ +# -*- coding: utf-8 -*- +""" +Functions to facilitate dFC simulation. + +Created on April 25 2024 +@author: Mohammad Torabi +""" + +import numpy as np +from scipy import signal +from tvb.simulator.lab import * + +from pydfc import TIME_SERIES, task_utils + +################################# Simulation Functions #################################### + + +class CustomStimuli(patterns.StimuliRegion): + def __init__( + self, stimulus_timing, region_weighting, connectivity, amplitude=1.0, **kwargs + ): + """ + Parameters: + - stimulus_timing: array of 0s and 1s (or amplitudes) over time + - target_nodes: list or array of node indices where to apply the stimulus + - amplitude: default amplitude (multiplied by stimulus_timing value) + """ + super().__init__(**kwargs) + self.stimulus_timing = np.array(stimulus_timing) + self.amplitude = amplitude + self.current_idx = 0 + self.weight = region_weighting + self.connectivity = connectivity # Required by TVB, even if not used + # Required by TVB, even if not used + self.temporal = equations.PulseTrain() + self.spatial = equations.DiscreteEquation() + + def __call__(self, temporal_indices, spatial_indices=None): + + # if temporal_indices is not a single integer, raise an error + if not isinstance(temporal_indices, (int, np.integer)): + raise ValueError( + "CustomStimuli expects a single integer for temporal_indices." + ) + # time is milliseconds + n_nodes = self.weight.shape[0] + stim = np.zeros(n_nodes) + + # Determine which index in the stimulus array corresponds to current time + self.current_idx = temporal_indices + + if self.current_idx < len(self.stimulus_timing): + stim_value = self.stimulus_timing[self.current_idx] * self.amplitude + else: + stim_value = 0 # stimulus ends when array is exhausted + + stim = np.multiply(self.weight, stim_value) + self.stimulus = stim + return self.stimulus + + def set_state(self, state): + self.state = state + + def configure_time(self, t): + pass + + +def create_random_stimulus_weights(stimulated_regions_list, n_regions=76): + """ + Create random stimulus weights for the stimulated regions. + """ + rand_weighting = [ + np.random.normal(loc=2.0 ** (-1 * (2 + i)), scale=0.1 * (2.0**-2)) + for i in range(len(stimulated_regions_list)) + ] + + # configure stimulus spatial pattern + weighting = np.zeros((n_regions,)) + weighting[stimulated_regions_list] = rand_weighting + + return weighting + + +def simulate_task_BOLD( + stimulus_timing, + sim_length, + BOLD_period, + TAVG_period, + num_stimulated_regions=5, + global_conn_coupling_coef=0.0126, + D=0.001, + conn_speed=1.0, + dt=0.5, + drop_initial_time=False, +): + """ + Simulate BOLD signal for a task. + + Parameters + ---------- + stimulus_timing : array-like, optional + The stimulus timing array, which should contain 0s and 1s (or amplitudes) over time. + sim_length : float + The length of the simulation in seconds. + BOLD_period : float + The BOLD period in milliseconds. + TAVG_period : float + The TAVG period in milliseconds. + num_stimulated_regions : int, optional + The number of stimulated regions. The default is 5. + if num_stimulated_regions is 5, the stimulated regions are: + [0, 7, 13, 33, 42] + if num_stimulated_regions is 16, the stimulated regions are: + regions = list(range(0, 76, 5)) + if num_stimulated_regions is 26, the stimulated regions are: + regions = list(range(0, 76, 3)) + else, the stimulated regions are randomly selected. + """ + # randomize some parameters for each subjects + global_conn_coupling = np.random.normal(loc=global_conn_coupling_coef, scale=0.0075) + conn_speed_rand = np.random.normal(loc=conn_speed, scale=0.1 * conn_speed) + ################################# Initialize Simulation #################################### + conn = connectivity.Connectivity.from_file() + conn.speed = np.array([conn_speed_rand]) + conn.configure() + # randomize the structural connectivity + # Additive Gaussian noise (e.g. 10% of weight magnitude) + noise_level = 0.1 # 10% + conn.weights += np.random.normal( + loc=0, + scale=noise_level * np.std(conn.weights[conn.weights > 0]), + size=conn.weights.shape, + ) + # Remove negative weights if any + conn.weights = np.clip(conn.weights, 0, None) + # reconfigure the connectivity + conn.configure() + + # configure stimulus spatial pattern + if num_stimulated_regions == 5: + stimulated_regions_list = [0, 7, 13, 33, 42] + elif num_stimulated_regions == 16: + stimulated_regions_list = list(range(0, 76, 5)) + elif num_stimulated_regions == 26: + stimulated_regions_list = list(range(0, 76, 3)) + else: + stimulated_regions_list = np.random.choice( + np.arange(76), num_stimulated_regions, replace=False + ) + stimulated_regions_list = list(stimulated_regions_list) + weighting = create_random_stimulus_weights( + stimulated_regions_list=stimulated_regions_list, n_regions=76 + ) + + # check if stimulus_timing is only containing 0s and 1s + if not np.all(np.isin(stimulus_timing, [0, 1])): + raise ValueError("stimulus_timing should only contain 0s and 1s.") + + stimulus = CustomStimuli( + stimulus_timing=stimulus_timing, + region_weighting=weighting, + connectivity=conn, + amplitude=1.0, + ) + + ################################# Run Simulation #################################### + + # set the global coupling strength + # you can switch between deterministic (without noise) and stochastic integration (with noise) + sim = simulator.Simulator( + model=models.Generic2dOscillator(a=np.array([0.5])), + connectivity=conn, + coupling=coupling.Linear(a=np.array([global_conn_coupling])), + # integrator=integrators.HeunDeterministic(dt=dt), + integrator=integrators.HeunStochastic( + dt=dt, noise=noise.Additive(nsig=np.array([D])) + ), + monitors=( + monitors.TemporalAverage(period=TAVG_period), + monitors.Bold(period=BOLD_period, hrf_kernel=equations.MixtureOfGammas()), + monitors.ProgressLogger(period=10e3), + ), + stimulus=stimulus, + simulation_length=sim_length, + ).configure() + + (tavg_time, tavg_data), (bold_time, bold_data), _ = sim.run() + + if drop_initial_time: + # truncate the first 10 seconds of the simulation + # to avoid transient effects + truncate_time = 10e3 # in m sec + bold_truncate_idx = int(truncate_time / BOLD_period) + bold_time = bold_time[bold_truncate_idx:] + bold_data = bold_data[bold_truncate_idx:] + tavg_truncate_idx = int(truncate_time / TAVG_period) + tavg_time = tavg_time[tavg_truncate_idx:] + tavg_data = tavg_data[tavg_truncate_idx:] + + centres_locs = conn.centres + region_labels = list(conn.region_labels) + TR_mri = BOLD_period * 1e-3 # in seconds + + bold_data = bold_data[:, 0, :, 0] + # change time_series.shape to (roi, time) + bold_data = bold_data.T + + TAVG_data = tavg_data[:, 0, :, 0] + # change time_series.shape to (roi, time) + TAVG_data = TAVG_data.T + + return ( + bold_data, + bold_time, + region_labels, + centres_locs, + TR_mri, + TAVG_data, + tavg_time, + TAVG_period, + ) + + +def create_simul_task_info( + TR_mri, + task, + onset, + task_duration, + task_block_duration, + sim_length, + oversampling=50, +): + """ + Create a dictionary containing the task data for simulation. + + Parameters + ---------- + TR_mri : float + The repetition time of the MRI in seconds. + task : str + The task name. + onset : float + The onset time of the task. + task_duration : float + The duration of the task. + task_block_duration : float + The duration of the task block. + sim_length : float + The length of the simulation. + in milliseconds + oversampling : int, optional + The oversampling factor. The default is 50. + generate more samples per TR than the func data to have a + better event_labels time resolution + """ + ####################### EXTRACT TASK LABELS ####################### + events = [] + + # using onset, task_duration, task_block_duration to create the events + events.append(["onset", "duration", "trial_type"]) + t = onset + while t < (sim_length * 1e-3): + events.append([t, task_duration, "task"]) + t += task_block_duration + events = np.array(events) + + # find the number of time points in the MRI data + # sim_length is in milliseconds + num_time_mri = int((sim_length * 1e-3) / TR_mri) + + event_labels, Fs_task, event_types = task_utils.events_time_to_labels( + events=events, + TR_mri=TR_mri, + num_time_mri=num_time_mri, + oversampling=oversampling, + return_0_1=False, + ) + # fill task labels with 0 (rest) and 1 (task's index, here only 1 task is used) + task_labels = np.multiply(event_labels != 0, 1) + ################################# SAVE ################################# + # save the ROI time series and task data + task_data = { + "task": task, + "task_labels": task_labels, + "event_labels": event_labels, + "event_types": event_types, + "events": events, + "Fs_task": Fs_task, + "TR_mri": TR_mri, + "num_time_mri": num_time_mri, + } + + return task_data + + +def event_labels_to_stimulus_timing( + event_labels, + Fs_task, + dt, +): + """ + Convert event labels to stimulus timing. + Parameters + ---------- + event_labels : array-like + The event labels, which should contain 0s (rest) and event ids over time. + Fs_task : float + The sampling frequency of the task data in Hz. + dt : float + The simulation time step in milliseconds. + """ + # make sure the timings are only 0s and 1s + stimulus_timing = np.multiply(event_labels != 0, 1) + + # make sure task_data sampling frequency is equal to simulation time step + L_old = len(stimulus_timing) + L_new = int((L_old * 1e3) / (Fs_task * dt)) + stimulus_timing = signal.resample(stimulus_timing, L_new) + # binarize the stimulus timing + # because of the resampling, the values might not be exactly 0 or 1 + stimulus_timing = np.where(stimulus_timing > 0.5, 1, 0) + + return stimulus_timing + + +def simulate_task_BOLD_TS( + subj_id, + task_data, + TAVG_period, + num_stimulated_regions=5, + global_conn_coupling_coef=0.0126, + D=0.001, + conn_speed=1.0, + dt=0.5, + drop_initial_time=False, +): + """ + Simulate BOLD signal for a task and return a TIME_SERIES object. + """ + task = task_data["task"] + BOLD_period = task_data["TR_mri"] * 1e3 # convert to milliseconds + sim_length = task_data["num_time_mri"] * BOLD_period # in milliseconds + stimulus_timing = event_labels_to_stimulus_timing( + event_labels=task_data["event_labels"], + Fs_task=task_data["Fs_task"], + dt=dt, + ) + + bold_data, bold_time, region_labels, centres_locs, TR_mri, _, _, _ = ( + simulate_task_BOLD( + stimulus_timing=stimulus_timing, + sim_length=sim_length, + BOLD_period=BOLD_period, + TAVG_period=TAVG_period, + num_stimulated_regions=num_stimulated_regions, + global_conn_coupling_coef=global_conn_coupling_coef, + D=D, + conn_speed=conn_speed, + dt=dt, + drop_initial_time=drop_initial_time, + ) + ) + time_series = TIME_SERIES( + data=bold_data, + subj_id=subj_id, + Fs=1 / TR_mri, + locs=centres_locs, + node_labels=region_labels, + TS_name=f"BOLD_{subj_id}_{task}", + session_name=task, + ) + + return time_series + + +def simulate_task_data(subj_id, task_info): + """ + Simulate task-based BOLD signal for a subject. + + Parameters + ---------- + subj_id : str + The subject ID. + task_info : dict + A dictionary containing the task information below: + - task_name: str + The name of the task. + - task_data: str + Path to a dictionary containing the task parameters + if task_data is not provided, onset_time, task_duration, task_block_duration, + sim_length, will be used to create the task data. + - onset_time: float + The onset time of the task in seconds. + - task_duration: float + The duration of the task in seconds. + - task_block_duration: float + The duration of the task block in seconds. + - sim_length: float + The length of the simulation in milliseconds. + - BOLD_period: float + The BOLD period in milliseconds. + - TAVG_period: float + The TAVG period in milliseconds. + - num_stimulated_regions: int + The number of stimulated regions. + - global_conn_coupling_coef: float + The global connectivity coupling coefficient. + - D: float + The noise parameter. + - conn_speed: float + The connectivity speed. + - dt: float + The simulation time step in milliseconds. + """ + if "task_data" in task_info: + # task_info["task_data"] is a path to a dictionary with {subj_id} as a placeholder + if "{subj_id}" in task_info["task_data"]: + task_data_path = task_info["task_data"].replace("{subj_id}", subj_id) + else: + task_data_path = task_info["task_data"] + task_data = np.load(task_data_path, allow_pickle="TRUE").item() + else: + task_data = create_simul_task_info( + TR_mri=task_info["BOLD_period"] * 1e-3, # convert to seconds + task=task_info["task_name"], + onset=task_info["onset_time"], + task_duration=task_info["task_duration"], + task_block_duration=task_info["task_block_duration"], + sim_length=task_info["sim_length"], + ) + + time_series = simulate_task_BOLD_TS( + subj_id=subj_id, + task_data=task_data, + TAVG_period=task_info["TAVG_period"], + num_stimulated_regions=task_info["num_stimulated_regions"], + global_conn_coupling_coef=task_info["global_conn_coupling_coef"], + D=task_info["D"], + conn_speed=task_info["conn_speed"], + dt=task_info["dt"], + ) + + # make sure task_data["num_time_mri"] is equal to the number of time points in the time series + if task_data["num_time_mri"] != time_series.n_time: + task_data["num_time_mri"] = time_series.n_time + + return time_series, task_data diff --git a/pydfc/task_utils.py b/pydfc/task_utils.py index daaf95e..249a489 100644 --- a/pydfc/task_utils.py +++ b/pydfc/task_utils.py @@ -6,10 +6,14 @@ @author: Mohammad Torabi """ +import warnings + import matplotlib.pyplot as plt import numpy as np from nilearn import glm from scipy import signal +from sklearn.mixture import GaussianMixture +from statsmodels.tsa.stattools import acf from .dfc_utils import TR_intersection, rank_norm, visualize_conn_mat @@ -17,23 +21,52 @@ def events_time_to_labels( - events, TR_mri, num_time_mri, event_types=[], oversampling=50, return_0_1=False + events, + TR_mri, + num_time_mri, + event_types=None, + oversampling=50, + trial_type_label="trial_type", + rest_labels=["rest", "Rest"], + return_0_1=False, ): """ - event_types is a list of event types to be considered. If None, 0 and 1s will be returned. + event_types is a list of event types to be considered. If None, it will found based on events. Assigns the longest event in each TR to that TR (in the interval from last TR to current TR). It assumes that the first time point is TR0 which corresponds to [0 sec, TR sec] interval. oversampling: number of samples per TR_mri to improve the time resolution of tasks + + if trial_type_label is None, we use event type "unknown" as the trial type """ + + # find which column is the "onset" in the first row + onset_idx = np.where(events[0, :] == "onset")[0][0] + duration_idx = np.where(events[0, :] == "duration")[0][0] + if trial_type_label is not None: + trial_type_idx = np.where(events[0, :] == trial_type_label)[0][0] + assert ( - events[0, 0] == "onset" - ), "The first column of the events file should be the onset!" - assert ( - events[0, 1] == "duration" - ), "The second column of the events file should be the duration!" + events[0, onset_idx] == "onset" + ), "Something went wrong with the events file! The onset column was not found!" assert ( - events[0, 2] == "trial_type" - ), "The third column of the events file should be the trial type!" + events[0, duration_idx] == "duration" + ), "Something went wrong with the events file! The duration column was not found!" + if trial_type_label is not None: + assert ( + events[0, trial_type_idx] == trial_type_label + ), "Something went wrong with the events file! The trial_type column was not found!" + + if event_types is None: + if trial_type_label is None: + event_types = ["unknown"] + else: + event_types = list(np.unique(events[1:, trial_type_idx])) + # remove all the rest labels + for rest_label in rest_labels: + if rest_label in event_types: + event_types.remove(rest_label) + # add the rest label to the beginning for consistency + event_types = ["rest"] + event_types Fs = float(1 / TR_mri) * oversampling num_time_task = int(num_time_mri * oversampling) @@ -43,27 +76,35 @@ def events_time_to_labels( if i == 0: continue - if events[i, 2] in event_types: - start_time = float(events[i, 0]) - end_time = float(events[i, 0]) + float(events[i, 1]) + if trial_type_label is None: + trial_type = "unknown" + else: + trial_type = events[i, trial_type_idx] + + if trial_type in event_types: + # the only rest label that is left in event types is "rest" but we don't want to consider it + if trial_type == "rest": + continue + start_time = float(events[i, onset_idx]) + end_time = float(events[i, onset_idx]) + float(events[i, duration_idx]) start_timepoint = int(np.rint(start_time * Fs)) end_timepoint = int(np.rint(end_time * Fs)) - event_labels[start_timepoint:end_timepoint] = event_types.index(events[i, 2]) + event_labels[start_timepoint:end_timepoint] = event_types.index(trial_type) if return_0_1: event_labels = np.multiply(event_labels != 0, 1) - return event_labels, Fs + return event_labels, Fs, event_types ################################# Visualization Functions #################################### -def plot_task_dFC(task_labels, dFC_lst, event_types, Fs_mri, TR_step=12): +def plot_task_dFC(task_presence, dFC_lst, Fs_mri, TR_step=12): """ - task_labels: numpy array of shape (num_time_task, num_event_types) containing the event or task labels - this function assumes that the task data has the same Fs as the dFC data, i.e. MRI data - and that the time points of the task data are aligned with the time points of the dFC data + task_presence: numpy array containing the task presence in the time points of the dFC data + this function assumes that the task presence has the same Fs as the dFC data, i.e. MRI data + and that the time points of the task presence are aligned with the time points of the dFC data """ conn_mat_size = 20 scale_task_plot = 20 @@ -73,12 +114,8 @@ def plot_task_dFC(task_labels, dFC_lst, event_types, Fs_mri, TR_step=12): ax = plt.gca() - time = np.arange(0, task_labels.shape[0]) / Fs_mri - for i in range(0, task_labels.shape[1]): - ax.plot( - time, task_labels[:, i] * scale_task_plot, label=event_types[i], linewidth=4 - ) - plt.legend() + time = np.arange(0, task_presence.shape[0]) / Fs_mri + ax.plot(time, task_presence * scale_task_plot, linewidth=4) plt.xlabel("Time (s)") comman_TRs = TR_intersection(dFC_lst) @@ -112,24 +149,47 @@ def plot_task_dFC(task_labels, dFC_lst, event_types, Fs_mri, TR_step=12): plt.show() -################################# PCA Functions #################################### +################################# Stat Functions #################################### -# def BOLD +def cohen_d_bold(X, y): + """ + Compute Cohen's d per ROI between task and rest. + + Parameters + ---------- + X : ndarray, shape (n_timepoints, n_ROIs) + BOLD signals. + y : ndarray, shape (n_timepoints,) + Task labels: 0 = rest, 1 = task. + + Returns + ------- + d_values : ndarray, shape (n_ROIs,) + Cohen's d per ROI. + """ + task_idx = y == 1 + rest_idx = y == 0 -################################# Prediction Functions #################################### + X_task = X[task_idx, :] + X_rest = X[rest_idx, :] -from sklearn.linear_model import LinearRegression + mean_task = X_task.mean(axis=0) + mean_rest = X_rest.mean(axis=0) + std_task = X_task.std(axis=0, ddof=1) + std_rest = X_rest.std(axis=0, ddof=1) -def linear_reg(X, y): - """ - X = (n_samples, n_features) - y = (n_samples, n_targets) - """ - reg = LinearRegression().fit(X, y) - print(reg.score(X, y)) - return reg.predict(X) + n_task = X_task.shape[0] + n_rest = X_rest.shape[0] + + pooled_std = np.sqrt( + ((n_task - 1) * std_task**2 + (n_rest - 1) * std_rest**2) / (n_task + n_rest - 2) + ) + + d_values = (mean_task - mean_rest) / pooled_std + + return d_values ################################# Validation Functions #################################### @@ -157,7 +217,7 @@ def event_conv_hrf(event_signal, TR_mri, TR_task): return events_hrf -def event_labels_conv_hrf(event_labels, TR_mri, TR_task): +def event_labels_conv_hrf(event_labels, TR_mri, TR_task, no_hrf=False): """ event_labels: event labels including 0 and event ids at the time each event happens TR_mri: TR of MRI @@ -167,6 +227,10 @@ def event_labels_conv_hrf(event_labels, TR_mri, TR_task): return: event labels convolved with HRF for each event type the convolved event labels have the same length as the event_labels event type i convolved with HRF is in events_hrf[:, i-1] + + events_hrf[:, 0] is the resting state + + if no_hrf is True, the event labels are not convolved with HRF """ event_labels = np.array(event_labels) @@ -181,7 +245,10 @@ def event_labels_conv_hrf(event_labels, TR_mri, TR_task): event_signal = np.zeros(L) event_signal[event_labels[:, 0] == event_id] = 1.0 - events_hrf[:, i] = event_conv_hrf(event_signal, TR_mri, TR_task) + if no_hrf: + events_hrf[:, i] = event_signal + else: + events_hrf[:, i] = event_conv_hrf(event_signal, TR_mri, TR_task) # the time points that are not in any event are considered as resting state events_hrf[np.sum(events_hrf[:, 1:], axis=1) == 0.0, 0] = 1.0 @@ -201,8 +268,9 @@ def downsample_events_hrf(events_hrf, TR_mri, TR_task, method="uniform"): the shape of events_hrf is (num_time_task, num_event_types) or (num_time_task,) the shape of the downsampled events_hrf is (num_time_mri, num_event_types) """ + flag = False if len(events_hrf.shape) == 1: - flag = 1 + flag = True events_hrf = np.expand_dims(events_hrf, axis=1) events_hrf_ds = [] for i in range(events_hrf.shape[1]): @@ -224,23 +292,160 @@ def downsample_events_hrf(events_hrf, TR_mri, TR_task, method="uniform"): return events_hrf_ds -def extract_task_presence(event_labels, TR_task, TR_array, TR_mri, binary=True): +def shifted_binarizing( + event_labels_all_task_hrf, + task_presence_ratio=0.5, + step=0.001, +): + # find threshold such that the after binarization of event_labels_all_task_hrf, + # the ratio of 1 to 0 is equal to task_presence_ratio + for threshold in np.arange(0, np.max(event_labels_all_task_hrf), step): + # binarize the event_labels_all_task_hrf + event_labels_all_task_hrf_binarized = np.where( + event_labels_all_task_hrf > threshold, 1, 0 + ) + # find the ratio of 1 to 0 in event_labels_all_task_hrf_binarized + new_ratio = np.mean(event_labels_all_task_hrf_binarized) + if new_ratio <= task_presence_ratio: + break + return threshold + + +def GMM_binarizing( + event_labels_all_task_hrf, + threshold=None, + downsample=True, + TR_mri=None, + TR_task=None, + TR_array=None, +): + """_summary_ + + Parameters + ---------- + event_labels_all_task_hrf : _type_ + _description_ + threshold : float, optional + _description_, by default 0.01 + downsample : bool, optional + _description_, by default True + TR_mri : _type_, optional + _description_, by default None + TR_task : _type_, optional + _description_, by default None + TR_array : _type_, optional + _description_, by default None + + Returns + ------- + task_presence : array + _description_ + indices : array + _description_ + ----------- + in order to get the task presence, use task_presence[indices] + """ + if threshold is None: + thresholds_list = [0.01, 0.1, 0.2, 0.3, 0.4] + else: + thresholds_list = [threshold] + event_labels_all_task_hrf = event_labels_all_task_hrf.copy() + event_labels_all_task_hrf_reshaped = event_labels_all_task_hrf.reshape(-1, 1) + # normal the signal to [0, 1] + event_labels_all_task_hrf_reshaped = ( + event_labels_all_task_hrf_reshaped - np.min(event_labels_all_task_hrf_reshaped) + ) / ( + np.max(event_labels_all_task_hrf_reshaped) + - np.min(event_labels_all_task_hrf_reshaped) + ) + # Fit GMM + gmm = GaussianMixture( + n_components=2, means_init=np.array([[0.0], [1.0]]), n_init=5 + ).fit(event_labels_all_task_hrf_reshaped) + means = gmm.means_.flatten() + # if the lower mean is larger than 0.25 or the higher mean is smaller than 0.75, we need to use 3 components + # first find the lower and higher mean + lower_mean = np.min(means) + higher_mean = np.max(means) + if lower_mean > 0.25 or higher_mean < 0.75: + # Fit GMM with 3 components + gmm = GaussianMixture( + n_components=3, means_init=np.array([[0.0], [0.5], [1.0]]), n_init=5 + ).fit(event_labels_all_task_hrf_reshaped) + # downsample to MRI TR + if downsample: + event_labels_all_task_hrf_reshaped = downsample_events_hrf( + event_labels_all_task_hrf_reshaped, TR_mri, TR_task + ) + # some dFC measures (window-based) have a different TR than the task data + if TR_array is not None: + event_labels_all_task_hrf_reshaped = event_labels_all_task_hrf_reshaped[TR_array] + # now predict on vs. off for the downsampled time points + probs = gmm.predict_proba(event_labels_all_task_hrf_reshaped) + # Identify which component corresponds to "on" (higher mean) + # Each component has a mean, and in this case: + # The "off" state should have a lower mean (closer to baseline). + # The "on" state should have a higher mean (HRF-convolved signal is elevated during task). + means = gmm.means_.flatten() + on_component = np.argmax(means) + off_component = np.argmin(means) + # if len(means) == 3: + # # set the mid component to the one that is in between the on and off components + # mid_component = np.argsort(means)[1] + # Get probability of being in the "on" and "off" state + p_on = probs[:, on_component] + p_off = probs[:, off_component] + # if len(means) == 3: + # p_mid = probs[:, mid_component] + # Create a binarized signal with transition points discarded + for threshold_ in thresholds_list: + # try different thresholds + # lower thresholds may result in only one class being present + indices = np.where((p_off >= (1 - threshold_)) | (p_on >= (1 - threshold_)))[0] + task_presence = np.where(p_on >= (1 - threshold_), 1, 0) + + # check that both classes are non-empty + unique_labels = np.unique(task_presence[indices]) + if len(unique_labels) == 2: + break + + if threshold_ == 0.4: + warnings.warn( + f"Even with threshold={threshold_}, only one class present in confident samples." + ) + + return task_presence, indices + + +def extract_task_presence( + event_labels, + TR_task, + TR_mri, + TR_array=None, + binary=True, + binarizing_method="GMM", + no_hrf=False, +): """ event_labels: event labels including 0 and event ids at the time each event happens TR_task: TR of task - TR_array: the time points of the dFC data + TR_array: the time points of the dFC data, optional TR_mri: TR of MRI This function extracts the task presence from the event labels and returns it in the same time points as the dFC data It also downsamples the task presence to the time points of the dFC data if binary is True, the task presence is binarized using the mean of the task presence + binarizing_method: 'median' or 'mean' or 'shift' or 'GMM' + if binarizing_method is 'shift', the task presence is binarized such that the ratio of 1 to 0 is equal to the task presence ratio + + if no_hrf is True, the task presence is not convolved with HRF """ # event_labels_all_task is all conditions together, rest vs. task times event_labels_all_task = np.multiply(event_labels != 0, 1) event_labels_all_task_hrf = event_labels_conv_hrf( - event_labels=event_labels_all_task, TR_mri=TR_mri, TR_task=TR_task + event_labels=event_labels_all_task, TR_mri=TR_mri, TR_task=TR_task, no_hrf=no_hrf ) # keep the task signal of events_hrf_0_1_ds @@ -253,15 +458,405 @@ def extract_task_presence(event_labels, TR_task, TR_array, TR_mri, binary=True): event_labels_all_task_hrf = event_labels_all_task_hrf[:, 1] if binary: - task_presence = np.where( - event_labels_all_task_hrf > np.mean(event_labels_all_task_hrf), 1, 0 - ) + if binarizing_method == "median": + threshold = np.median(event_labels_all_task_hrf) + task_presence = np.where(event_labels_all_task_hrf > threshold, 1, 0) + task_presence = downsample_events_hrf(task_presence, TR_mri, TR_task) + # some dFC measures (window-based) have a different TR than the task data + if TR_array is not None: + task_presence = task_presence[TR_array] + indices = np.arange(task_presence.shape[0]) + elif binarizing_method == "mean": + threshold = np.mean(event_labels_all_task_hrf) + task_presence = np.where(event_labels_all_task_hrf > threshold, 1, 0) + task_presence = downsample_events_hrf(task_presence, TR_mri, TR_task) + # some dFC measures (window-based) have a different TR than the task data + if TR_array is not None: + task_presence = task_presence[TR_array] + indices = np.arange(task_presence.shape[0]) + elif binarizing_method == "shift": + task_presence_ratio = np.mean(event_labels_all_task) + threshold = shifted_binarizing( + event_labels_all_task_hrf=event_labels_all_task_hrf, + task_presence_ratio=task_presence_ratio, + ) + task_presence = np.where(event_labels_all_task_hrf > threshold, 1, 0) + task_presence = downsample_events_hrf(task_presence, TR_mri, TR_task) + # some dFC measures (window-based) have a different TR than the task data + if TR_array is not None: + task_presence = task_presence[TR_array] + indices = np.arange(task_presence.shape[0]) + elif binarizing_method == "GMM": + task_presence, indices = GMM_binarizing( + event_labels_all_task_hrf=event_labels_all_task_hrf, + threshold=None, + downsample=True, + TR_mri=TR_mri, + TR_task=TR_task, + TR_array=TR_array, + ) + else: + raise ValueError( + "binarizing_method should be 'median', 'mean', 'shift', or 'GMM'" + ) else: task_presence = event_labels_all_task_hrf + task_presence = downsample_events_hrf(task_presence, TR_mri, TR_task) + # some dFC measures (window-based) have a different TR than the task data + if TR_array is not None: + task_presence = task_presence[TR_array] + indices = np.arange(task_presence.shape[0]) - task_presence = downsample_events_hrf(task_presence, TR_mri, TR_task) + return task_presence, indices - # some dFC measures (window-based) have a different TR than the task data - task_presence = task_presence[TR_array] - return task_presence +################################# Task Design Features #################################### + + +def calc_relative_task_on(task_presence): + """ + task_presence: 0, 1 array + return: relative_task_on + """ + return np.sum(task_presence) / len(task_presence) + + +def calc_task_duration(task_presence, TR_mri): + """ + task_presence: 0, 1 array + return: list of task_durations + """ + task_durations = list() + start = None + for i in range(1, len(task_presence)): + if task_presence[i] == 1 and task_presence[i - 1] == 0: + start = i + if ( + (task_presence[i] == 0) + and (task_presence[i - 1] == 1) + and (start is not None) + ): + end = i + task_durations.append((end - start) * TR_mri) + start = None + task_durations = np.array(task_durations) + return task_durations + + +def calc_rest_duration(task_presence, TR_mri): + """ + task_presence: 0, 1 array + return: list of rest_durations + """ + rest_durations = list() + if task_presence[0] == 0: + start = 0 + for i in range(1, len(task_presence)): + if task_presence[i] == 0 and task_presence[i - 1] == 1: + start = i + if task_presence[i] == 1 and task_presence[i - 1] == 0: + end = i + rest_durations.append((end - start) * TR_mri) + start = None + if task_presence[-1] == 0: + end = len(task_presence) + if not start is None: + rest_durations.append((end - start) * TR_mri) + rest_durations = np.array(rest_durations) + return rest_durations + + +def calc_transition_freq(task_presence): + """ + task_presence: 0, 1 array + return: num_of_transitions, relative_transition_freq + """ + transitions = np.abs(np.diff(task_presence)) + num_of_transitions = np.sum(transitions) + relative_transition_freq = num_of_transitions / len(task_presence) + return num_of_transitions, relative_transition_freq + + +def noise_model(f, alpha=1.0): + # 1/f^alpha normalized to unit median (cheap default) + spec = 1.0 / np.maximum(f, 1e-6) ** alpha + med = np.median(spec[f > 0]) + return spec / med + + +def compute_periodicity_index( + event_labels, + TR_task, + fmin=0.0, + fmax=None, + no_hrf=False, +): + """ + Compute a noise-free periodicity index for a task timing time course. + + Parameters + ---------- + event_labels : array, shape (T,) + Event labels time course. + TR_task : float + Repetition time (seconds). + fmin, fmax : float + Frequency band (Hz) to consider. If fmax is None, Nyquist is used. + no_hrf : bool + If True, do not convolve with HRF. + + Returns + ------- + results : dict + { + 'periodicity_index': float in [0, 1], higher = more periodic, + 'spectral_entropy': float in [0, 1], lower = more periodic, + 'peak_freq': float, frequency of dominant peak (Hz), + 'peak_dominance': float in [0, 1], peak power / total power + } + """ + if no_hrf: + task_tc = np.multiply(event_labels != 0, 1) + else: + task_tc, _ = extract_task_presence( + event_labels=event_labels, + TR_task=TR_task, + TR_mri=TR_task, + TR_array=None, + binary=False, + binarizing_method="GMM", + no_hrf=False, + ) + task_tc = np.asarray(task_tc) + T = len(task_tc) + + # Detrend and mean-center + x = task_tc - np.mean(task_tc) + + # FFT + freqs = np.fft.rfftfreq(T, d=TR_task) + fft_vals = np.fft.rfft(x) + power = np.abs(fft_vals) ** 2 + + # Restrict frequency range + if fmax is None: + fmax = 0.5 / TR_task # Nyquist + mask = (freqs >= fmin) & (freqs <= fmax) + freqs = freqs[mask] + power = power[mask] + + # Avoid division by zero + if np.all(power == 0): + return { + "periodicity_index": 0.0, + "spectral_entropy": 1.0, + "peak_freq": 0.0, + "peak_dominance": 0.0, + } + + # Normalize spectrum to probability distribution + p = power / power.sum() + + # Spectral entropy (normalized to [0,1]) + eps = 1e-12 + H = -(p * np.log(p + eps)).sum() / np.log(len(p)) # in [0,1], higher = more "flat" + + # Dominant peak and its dominance + peak_idx = np.argmax(power) + peak_freq = freqs[peak_idx] + peak_power = power[peak_idx] + peak_dominance = peak_power / power.sum() # 0–1 + + # Define periodicity index: high when entropy is low and peak is dominant + periodicity_index = (1.0 - H) * peak_dominance + + return { + "periodicity_index": float(periodicity_index), + "spectral_entropy": float(H), + "peak_freq": float(peak_freq), + "peak_dominance": float(peak_dominance), + } + + +def compute_optimality_index( + event_labels, TR_task, TR_mri, fmin=0.0, fmax=None, alpha=1.0 +): + """ + Compute a Worsley-style optimality index (OI) and normalized OI. + + Uses HRF spectrum as the weighting term (no explicit noise model). + OI_norm compares the observed design to an ideal sinusoid with + matched in-band power placed at the optimal frequency. + + Returns: + -------- + { + "OI": float, + "OI_ideal": float, + "OI_norm": float, + "peak_freq": float + } + """ + + # ------------------------- + # 1. Preprocess task timing + # ------------------------- + task_tc = np.multiply(event_labels != 0, 1).astype(float).flatten() + T = len(task_tc) + + # ------------------------- + # 2. HRF Model + # ------------------------- + # same length as our task + time_length_HRF = T * TR_task + oversampling = TR_mri / TR_task + + hrf_tc = glm.first_level.spm_hrf( + tr=TR_mri, oversampling=oversampling, time_length=time_length_HRF, onset=0.0 + ) + hrf_tc = np.asarray(hrf_tc) + + # Pad or truncate HRF to length T + if len(hrf_tc) < T: + hrf_tc = np.pad(hrf_tc, (0, T - len(hrf_tc)), mode="constant") + else: + hrf_tc = hrf_tc[:T] + + # ------------------------- + # 3. Frequency grid + noise PSD + # ------------------------- + freqs = np.fft.rfftfreq(T, d=TR_task) + + # ------------------------- + # 4. FFT-based spectra + # ------------------------- + design_spectrum = np.abs(np.fft.rfft(task_tc)) ** 2 + hrf_spectrum = np.abs(np.fft.rfft(hrf_tc)) ** 2 + + # ------------------------- + # 5. Frequency mask + # ------------------------- + if fmax is None: + fmax = 0.5 / TR_task + + mask = (freqs >= float(fmin)) & (freqs <= fmax) + freqs_m = freqs[mask] + design_spectrum_m = design_spectrum[mask] + hrf_spectrum_m = hrf_spectrum[mask] + + eps = 1e-12 + snr_weight = hrf_spectrum_m + + # ------------------------- + # 6. ORIGINAL (TASK) OI + # ------------------------- + OI = np.sum(design_spectrum_m * snr_weight) + + # ------------------------- + # 7. IDEAL OI UPPER BOUND + # ------------------------- + + if freqs_m.size == 0: + # no nonzero frequencies in the band + return { + "OI": float(OI), + "OI_ideal": 0.0, + "OI_norm": 0.0, + "peak_freq": 0.0, + } + + # Report dominant task frequency for interpretability. + peak_idx = np.argmax(design_spectrum_m) + peak_freq = freqs_m[peak_idx] + + # Theoretical in-band upper bound for weighted spectral sum: + # sum(d_i * w_i) <= max(w_i) * sum(d_i), with d_i >= 0. + design_band_power = float(np.sum(design_spectrum_m)) + max_weight = float(np.max(snr_weight)) if snr_weight.size > 0 else 0.0 + OI_ideal = design_band_power * max_weight + + # ------------------------- + # 8. Normalized OI + # ------------------------- + if OI_ideal < eps: + OI_norm = 0.0 + else: + OI_norm = OI / OI_ideal + + return { + "OI": float(OI), + "OI_ideal": float(OI_ideal), + "OI_norm": float(OI_norm), + "peak_freq": float(peak_freq), + } + + +from scipy.ndimage import uniform_filter1d +from scipy.signal import find_peaks + + +def periodicity_autocorr(event_labels, TR_task, max_lag=None): + """ + Measure how periodic a 0/1 event label time course is using autocorrelation. + + Parameters + ---------- + event_labels : array-like + array of 0/1 labels (e.g., rest=0, task=1). + TR_task : float + Repetition time (seconds). + max_lag : int or None + Maximum lag to compute autocorrelation. If None, uses len(x)//2. + + Returns + ------- + periodicity : float + Strength of the strongest non-zero autocorrelation peak (in [−1, 1]). + best_lag : int + Lag (in samples) at which this peak occurs. + r : np.ndarray + Autocorrelation values from lag 0..max_lag. + """ + x, _ = extract_task_presence( + event_labels=event_labels, + TR_task=TR_task, + TR_mri=TR_task, + TR_array=None, + binary=False, + binarizing_method="GMM", + no_hrf=False, + ) + + # Optional: center to remove bias from unbalanced 0/1 ratio + x = x - x.mean() + + if max_lag is None: + max_lag = len(x) // 2 + + # r[0] = 1 by definition + r = acf(x, nlags=max_lag, fft=False) + + # Find true peaks (periodic peaks) --- + peaks, _ = find_peaks(r) + + if len(peaks) == 0: + return {"periodicity": 0.0, "best_lag": None, "r": r} + + # skip lag 0 + peaks = peaks[peaks > 0] + + if len(peaks) == 0: + return {"periodicity": 0.0, "best_lag": None, "r": r} + + best_lag = peaks[np.argmax(r[peaks])] + + # # Ignore lag 0, find strongest positive correlation + # r_nonzero = r[1:] + # best_lag = np.argmax(r_nonzero) + 1 + periodicity = r[best_lag] + + return { + "periodicity": periodicity, + "best_lag": best_lag, + "r": r, + } diff --git a/simul_dFC/README.rst b/simul_dFC/README.rst new file mode 100644 index 0000000..86eaf18 --- /dev/null +++ b/simul_dFC/README.rst @@ -0,0 +1,27 @@ +============================================ +PydFC: simul_dFC Module Documentation +============================================ + +The ``simul_dFC`` module generates **synthetic task-based fMRI data** for benchmarking dFC methods under controlled conditions. + +It uses `The Virtual Brain (TVB) `_ simulator to produce BOLD signals driven by a known task design, allowing ground-truth evaluation of dFC methods. + +Two task paradigms are supported: + +* **Real task-derived** (``tasks_info_ds003465.json``) — task timing extracted from an OpenNeuro dataset (ds003465) to drive the simulation. +* **Synthetic pulse-train** (``tasks_info_pulseTrain.json``) — parametric block designs with configurable onset, duration, and frequency. + +Running +------- + +Set ``VENV_PATH`` and ``PYDFC_CODE_DIR`` in the cluster configuration block at the top of the job script, then submit:: + + # SLURM + sbatch --array=1-N run_scripts_slurm/run_simulator.sh + + # SGE + qsub -t 1-N run_scripts_sge/run_simulator.sh + +The script expects a ``subj_list.txt`` (one subject ID per line), a ``dataset_info.json``, and a ``tasks_info.json`` in the same directory as the run script. + +Simulated outputs are consumed directly by the ``task_dFC`` pipeline starting at ``FCS_estimate.py``. diff --git a/simul_dFC/run_scripts_sge/run_simulator.sh b/simul_dFC/run_scripts_sge/run_simulator.sh new file mode 100644 index 0000000..f25fa25 --- /dev/null +++ b/simul_dFC/run_scripts_sge/run_simulator.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# +#$ -N simul_dfc_job +#$ -o logs/simul_out.txt +#$ -e logs/simul_err.txt +#$ -l h_rt=24:00:00 +#$ -l h_vmem=8g +#$ -t 1-200 +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# For conda environments, replace the two lines above with: +# CONDA_SH="/path/to/conda/etc/profile.d/conda.sh" +# CONDA_ENV="pydfc" +# ----------------------------------------------------------- + +SUBJECT_LIST="./subj_list.txt" +DATASET_INFO="./dataset_info.json" +TASKS_INFO="./tasks_info.json" + +SUBJECT_ID=`sed -n "${SGE_TASK_ID}p" $SUBJECT_LIST` +echo "Subject ID: $SUBJECT_ID" + +# Activate virtual environment +source "$VENV_PATH" +# For conda: source "$CONDA_SH" && conda activate "$CONDA_ENV" + +# Run Python script +python "$PYDFC_CODE_DIR/simul_dFC/task_data_simulator.py" \ +--dataset_info $DATASET_INFO \ +--tasks_info $TASKS_INFO \ +--participant_id $SUBJECT_ID + +# Deactivate environment +deactivate diff --git a/simul_dFC/run_scripts_slurm/run_simulator.sh b/simul_dFC/run_scripts_slurm/run_simulator.sh new file mode 100644 index 0000000..9dd5c5a --- /dev/null +++ b/simul_dFC/run_scripts_slurm/run_simulator.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# +#SBATCH --job-name=simul_dfc_job # Optional: Name of your job +#SBATCH --output=logs/simul_out.txt # Standard output log +#SBATCH --error=logs/simul_err.txt # Standard error log +#SBATCH --account=YOUR_ACCOUNT # Account +#SBATCH --time=24:00:00 # Walltime for each task (24 hours) +#SBATCH --mem=8G # Memory request per node +#SBATCH --array=1-200 # Task array specification + +SUBJECT_LIST="./subj_list.txt" +DATASET_INFO="./dataset_info.json" +TASKS_INFO="./tasks_info.json" + +SUBJECT_ID=`sed -n "${SLURM_ARRAY_TASK_ID}p" $SUBJECT_LIST` +echo "Subject ID: $SUBJECT_ID" + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +# Activate virtual environment +source "$VENV_PATH" + +# Run Python script +python "$PYDFC_CODE_DIR/simul_dFC/task_data_simulator.py" \ +--dataset_info $DATASET_INFO \ +--tasks_info $TASKS_INFO \ +--participant_id $SUBJECT_ID + +# Deactivate environment +deactivate diff --git a/simul_dFC/run_scripts_slurm/tasks_info_ds003465.json b/simul_dFC/run_scripts_slurm/tasks_info_ds003465.json new file mode 100644 index 0000000..81d4f88 --- /dev/null +++ b/simul_dFC/run_scripts_slurm/tasks_info_ds003465.json @@ -0,0 +1,42 @@ +{ + "task-Axcpt": { + "task_name": "task-Axcpt", + "task_data": "/path/to/your/data/ds003465/derivatives/ROI_timeseries/{subj_id}/ses-wave1bas/{subj_id}_ses-wave1bas_task-Axcpt_run-1_task-data.npy", + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-Cuedts": { + "task_name": "task-Cuedts", + "task_data": "/path/to/your/data/ds003465/derivatives/ROI_timeseries/{subj_id}/ses-wave1bas/{subj_id}_ses-wave1bas_task-Cuedts_run-1_task-data.npy", + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-Stern": { + "task_name": "task-Stern", + "task_data": "/path/to/your/data/ds003465/derivatives/ROI_timeseries/{subj_id}/ses-wave1bas/{subj_id}_ses-wave1bas_task-Stern_run-1_task-data.npy", + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-Stroop": { + "task_name": "task-Stroop", + "task_data": "/path/to/your/data/ds003465/derivatives/ROI_timeseries/{subj_id}/ses-wave1bas/{subj_id}_ses-wave1bas_task-Stroop_run-1_task-data.npy", + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + } +} diff --git a/simul_dFC/run_scripts_slurm/tasks_info_pulseTrain.json b/simul_dFC/run_scripts_slurm/tasks_info_pulseTrain.json new file mode 100644 index 0000000..9a153ec --- /dev/null +++ b/simul_dFC/run_scripts_slurm/tasks_info_pulseTrain.json @@ -0,0 +1,72 @@ +{ + "task-lowFreqLongRest": { + "task_name": "task-lowFreqLongRest", + "onset_time": 20.0, + "task_duration": 8.0, + "task_block_duration": 20.0, + "sim_length": 250e3, + "BOLD_period": 500, + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-lowFreqShortRest": { + "task_name": "task-lowFreqShortRest", + "onset_time": 20.0, + "task_duration": 12.0, + "task_block_duration": 20.0, + "sim_length": 250e3, + "BOLD_period": 500, + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-lowFreqShortTask": { + "task_name": "task-lowFreqShortTask", + "onset_time": 20.0, + "task_duration": 1.0, + "task_block_duration": 20.0, + "sim_length": 250e3, + "BOLD_period": 500, + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-highFreqLongRest": { + "task_name": "task-highFreqLongRest", + "onset_time": 20.0, + "task_duration": 1.0, + "task_block_duration": 5.0, + "sim_length": 250e3, + "BOLD_period": 500, + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + }, + "task-highFreqShortRest": { + "task_name": "task-highFreqShortRest", + "onset_time": 20.0, + "task_duration": 4.0, + "task_block_duration": 5.0, + "sim_length": 250e3, + "BOLD_period": 500, + "TAVG_period": 1.0, + "num_stimulated_regions": 5, + "global_conn_coupling_coef": 0.0126, + "D": 0.1, + "conn_speed": 1.0, + "dt": 0.5 + } +} diff --git a/simul_dFC/task_data_simulator.py b/simul_dFC/task_data_simulator.py new file mode 100644 index 0000000..df80ecc --- /dev/null +++ b/simul_dFC/task_data_simulator.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed March 20 2024 + +@author: mte +""" +import argparse +import json +import os +import traceback +import warnings + +import numpy as np +from tvb.simulator.lab import * + +from pydfc import simul_utils + +warnings.simplefilter("ignore") + +os.environ["MKL_NUM_THREADS"] = "16" +os.environ["NUMEXPR_NUM_THREADS"] = "16" +os.environ["OMP_NUM_THREADS"] = "16" +################################# Parameters #################################### + +# argparse +HELPTEXT = """ +Script to simulate task-based data. +""" +parser = argparse.ArgumentParser(description=HELPTEXT) + +parser.add_argument("--dataset_info", type=str, help="path to dataset info file") +parser.add_argument("--tasks_info", type=str, help="path to tasks info file") +parser.add_argument("--participant_id", type=str, help="participant id") + +args = parser.parse_args() + +dataset_info_file = args.dataset_info +tasks_info_file = args.tasks_info +participant_id = args.participant_id + +# Read dataset info +with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + +if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace("{dataset}", dataset_info["dataset"]) +else: + main_root = dataset_info["main_root"] + +if "{main_root}" in dataset_info["roi_root"]: + output_root = dataset_info["roi_root"].replace("{main_root}", main_root) +else: + output_root = dataset_info["roi_root"] + +# Read tasks info +with open(tasks_info_file, "r") as f: + all_tasks_info = json.load(f) + +print(f"subject-level simulation started running ... for subject: {participant_id} ...") + +for task in all_tasks_info: + + # the task_data file might not exist for some subjects, so we use a try-except block + try: + time_series, task_data = simul_utils.simulate_task_data( + participant_id, all_tasks_info[task] + ) + except Exception as e: + print(f"Error simulating task {task} for participant {participant_id}: {e}") + # print traceback + traceback.print_exc() + continue + + # save the time series and task data + output_file_prefix = f"{participant_id}_{task}" + if not os.path.exists(f"{output_root}/{participant_id}/"): + os.makedirs(f"{output_root}/{participant_id}/") + np.save( + f"{output_root}/{participant_id}/{output_file_prefix}_time-series.npy", + time_series, + ) + np.save( + f"{output_root}/{participant_id}/{output_file_prefix}_task-data.npy", task_data + ) + +print("****************** DONE ******************") +#################################################################################### diff --git a/task_dFC/FCS_estimate.py b/task_dFC/FCS_estimate.py index de4d738..17a45eb 100644 --- a/task_dFC/FCS_estimate.py +++ b/task_dFC/FCS_estimate.py @@ -1,149 +1,228 @@ +import argparse +import json import os import time +import traceback import warnings import numpy as np -from pydfc import MultiAnalysis, data_loader +from pydfc import data_loader, multi_analysis_utils warnings.simplefilter("ignore") -os.environ["MKL_NUM_THREADS"] = "16" -os.environ["NUMEXPR_NUM_THREADS"] = "16" -os.environ["OMP_NUM_THREADS"] = "16" - -################################# Parameters ################################# -# data paths -# main_root = '../../DATA/ds002785/' # for local -main_root = "../../../DATA/task-based/openneuro/ds002785" # for server -roi_root = f"{main_root}/derivatives/ROI_timeseries" -output_root = f"{main_root}/derivatives/fitted_MEASURES" - -# for consistency we use 0 for resting state -TASKS = [ - "task-restingstate", - "task-anticipation", - "task-emomatching", - "task-faces", - "task-gstroop", - "task-workingmemory", -] - -job_id = int(os.getenv("SGE_TASK_ID")) -TASK_id = job_id - 1 # SGE_TASK_ID starts from 1 not 0 -if TASK_id >= len(TASKS): - print("TASK_id out of TASKS") - exit() -task = TASKS[TASK_id] - -###### MEASUREMENT PARAMETERS ###### - -# W is in sec - -params_methods = { - # Sliding Parameters - "W": 44, - "n_overlap": 1.0, - "sw_method": "pear_corr", - "tapered_window": True, - # TIME_FREQ - "TF_method": "WTC", - # CLUSTERING AND DHMM - "clstr_base_measure": "SlidingWindow", - # HMM - "hmm_iter": 20, - "dhmm_obs_state_ratio": 16 / 24, - # State Parameters - "n_states": 12, - "n_subj_clstrs": 20, - # Parallelization Parameters - "n_jobs": 2, - "verbose": 0, - "backend": "loky", - # SESSION - "session": task, - # Hyper Parameters - "normalization": True, - "num_subj": None, # None or 216? - "num_time_point": None, # None or set? -} - -###### HYPER PARAMETERS ALTERNATIVE ###### - -MEASURES_name_lst = [ - "SlidingWindow", - "Time-Freq", - "CAP", - "ContinuousHMM", - "Windowless", - "Clustering", - "DiscreteHMM", -] - -alter_hparams = { - # 'session': ['Rest1_RL', 'Rest2_LR', 'Rest2_RL'], - # 'n_overlap': [0, 0.25, 0.75, 1], - # 'n_states': [6, 16], - # # 'normalization': [], - # 'num_subj': [50, 100, 200], - # 'num_select_nodes': [30, 50, 333], - # 'num_time_point': [800, 1000], - # 'Fs_ratio': [0.50, 0.75, 1.5], - # 'noise_ratio': [1.00, 2.00, 3.00], - # 'num_realization': [] -} - -###### MultiAnalysis PARAMETERS ###### - -params_multi_analysis = { - # Parallelization Parameters - "n_jobs": None, - "verbose": 0, - "backend": "loky", -} - -################################# LOAD DATA ################################# - -BOLD = data_loader.load_TS( - data_root=roi_root, file_name="time_series.npy", SESSIONs=task, subj_id2load=None -) - -################################# Visualize BOLD ################################# - -# for session in BOLD: -# BOLD.visualize(start_time=0, end_time=2000, nodes_lst=list(range(10)), -# save_image=False, output_root=None) - -################################ Measures of dFC ################################# - -MA = MultiAnalysis( - analysis_name=f"task-based-dFC-ds002785-{task}", **params_multi_analysis -) - -MEASURES_lst = MA.measures_initializer(MEASURES_name_lst, params_methods, alter_hparams) - -tic = time.time() -print("Measurement Started ...") - -################################# estimate FCS ################################# - -for MEASURE_id, measure in enumerate(MEASURES_lst): - - print("MEASURE: " + measure.measure_name) - print("FCS estimation started...") - - if measure.is_state_based: - measure.estimate_FCS(time_series=BOLD) - - # dFC_analyzer.estimate_group_FCS(time_series_dict=BOLD) - print("FCS estimation done.") - - # Save - if not os.path.exists(f"{output_root}/{task}"): - os.makedirs(f"{output_root}/{task}") - np.save(f"{output_root}/{task}/MEASURE_{str(MEASURE_id)}.npy", measure) - -print(f"Measurement required {time.time() - tic:0.3f} seconds.") -np.save(f"{output_root}/{task}/multi_analysis.npy", MA) - +######################################################################################## + + +def run_FCS_estimate( + params_methods, + MEASURES_name_lst, + alter_hparams, + params_multi_analysis, + task, + roi_root, + output_root, + session=None, + run=None, +): + if session is None: + output_dir = f"{output_root}" + else: + output_dir = f"{output_root}/{session}" + + if run is None: + print(f"TASK: {task} started ...") + if session is None: + BOLD_file_name = "{subj_id}_{task}_time-series.npy" + file_suffix = f"{task}" + else: + BOLD_file_name = "{subj_id}_{session}_{task}_time-series.npy" + file_suffix = f"{session}_{task}" + else: + print(f"TASK: {task}, RUN: {run} started ...") + if session is None: + BOLD_file_name = "{subj_id}_{task}_{run}_time-series.npy" + file_suffix = f"{task}_{run}" + else: + BOLD_file_name = "{subj_id}_{session}_{task}_{run}_time-series.npy" + file_suffix = f"{session}_{task}_{run}" + ################################# LOAD DATA ################################# + BOLD = data_loader.load_TS( + data_root=roi_root, + file_name=BOLD_file_name, + subj_id2load=None, + task=task, + session=session, + run=run, + ) + + if BOLD is None: + print(f"No BOLD data found for task: {task}, session: {session}, run: {run}.") + return + ################################ Measures of dFC ################################# + + MEASURES_lst, hyper_param_info = multi_analysis_utils.measures_initializer( + MEASURES_name_lst, params_methods, alter_hparams + ) + + # in this script we process only one measure + # if alter_hparams is not empty, we need to change the naming of the output files + # to differentiate between the measures + if len(MEASURES_lst) == 1: + only_one_measure = True + n_jobs = None + else: + only_one_measure = False + n_jobs = params_multi_analysis["n_jobs"] + + if not only_one_measure: + # we assume only one hyperparameter is altered + # alter_hparams is a dictionary with one key + # ow change the naming of the output files + assert len(alter_hparams) == 1, ( + "alter_hparams should have only one key, " + "but got more than one. This script is designed to process only one hyperparameter." + ) + hyper_param_name = [key for key in alter_hparams.keys()][0] + + tic = time.time() + print("Measurement Started ...") + + ################################# estimate FCS ################################# + + MEASURES_fit_lst = multi_analysis_utils.estimate_group_FCS( + time_series=BOLD, + MEASURES_lst=MEASURES_lst, + n_jobs=n_jobs, + verbose=params_multi_analysis["verbose"], + backend=params_multi_analysis["backend"], + ) + + if only_one_measure: + assert ( + len(MEASURES_fit_lst) == 1 + ), "Only one measure should be processed, but got more than one." + + # Save the fitted measures + for measure in MEASURES_fit_lst: + try: + if not os.path.exists(f"{output_dir}"): + os.makedirs(f"{output_dir}") + except OSError as err: + print(err) + if only_one_measure: + measure_name = measure.measure_name + else: + measure_name = f"{measure.measure_name}-{hyper_param_name}-{measure.params[hyper_param_name]}" + np.save(f"{output_dir}/MEASURE_{file_suffix}_{measure_name}.npy", measure) + + print(f"Measurement required {time.time() - tic:0.3f} seconds.") + + +######################################################################################## + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to fit dFC methods for a given task. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument("--dataset_info", type=str, help="path to dataset info file") + parser.add_argument("--methods_config", type=str, help="methods config file") + + args = parser.parse_args() + + dataset_info_file = args.dataset_info + methods_config_file = args.methods_config + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + # Read methods config + with open(methods_config_file, "r") as f: + methods_config = json.load(f) + + TASKS = dataset_info["TASKS"] + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + main_root = dataset_info["main_root"] + + if "{main_root}" in dataset_info["roi_root"]: + roi_root = dataset_info["roi_root"].replace("{main_root}", main_root) + else: + roi_root = dataset_info["roi_root"] + + if "{main_root}" in dataset_info["fitted_measures_root"]: + fitted_measures_root = dataset_info["fitted_measures_root"].replace( + "{main_root}", main_root + ) + else: + fitted_measures_root = dataset_info["fitted_measures_root"] + + # methods params + params_methods = methods_config["params_methods"] + MEASURES_name_lst = methods_config["MEASURES_name_lst"] + alter_hparams = methods_config["alter_hparams"] + params_multi_analysis = methods_config["params_multi_analysis"] + + # pick one method + job_id = os.getenv("SGE_TASK_ID") # for SGE + if job_id is None: + job_id = os.getenv("SLURM_ARRAY_TASK_ID") # for SLURM + job_id = int(job_id) + MEASURE_id = job_id - 1 # job_id starts from 1 not 0 + if MEASURE_id >= len(MEASURES_name_lst): + print("MEASURE_id out of MEASURES_name_lst range") + exit() + picked_measure_list = [MEASURES_name_lst[MEASURE_id]] # pick one method but as a list + + print( + f"FCS estimation CODE started running ... for measure: {picked_measure_list[0]} ..." + ) + + for session in SESSIONS: + for task in TASKS: + for run in RUNS[task]: + try: + run_FCS_estimate( + params_methods=params_methods, + MEASURES_name_lst=picked_measure_list, + alter_hparams=alter_hparams, + params_multi_analysis=params_multi_analysis, + task=task, + roi_root=roi_root, + output_root=fitted_measures_root, + session=session, + run=run, + ) + except Exception as e: + print( + f"Error in run_FCS_estimate for task: {task}, session: {session}, run: {run}, measure: {picked_measure_list[0]}, error: {e}" + ) + traceback.print_exc() + + print( + f"FCS estimation CODE finished running ... for measure: {picked_measure_list[0]} ..." + ) ################################################################################# diff --git a/task_dFC/ML.py b/task_dFC/ML.py new file mode 100644 index 0000000..271fa5d --- /dev/null +++ b/task_dFC/ML.py @@ -0,0 +1,252 @@ +import argparse +import json +import os +import traceback + +import numpy as np +from joblib import Parallel, delayed + +from pydfc.ml_utils import extract_task_features, task_presence_classification + +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + +####################################################################################### + + +def run_task_features_extraction( + TASKS, + RUNS, + SESSIONS, + roi_root, + dFC_root, + output_root, +): + for session in SESSIONS: + + # Extract task features without HRF effect + task_features = extract_task_features( + TASKS=TASKS, + RUNS=RUNS, + session=session, + roi_root=roi_root, + dFC_root=dFC_root, + no_hrf=True, + ) + + # Extract task features with HRF effect + task_features_hrf = extract_task_features( + TASKS=TASKS, + RUNS=RUNS, + session=session, + roi_root=roi_root, + dFC_root=dFC_root, + no_hrf=False, + ) + + if session is None: + folder = f"{output_root}/task_features" + else: + folder = f"{output_root}/task_features/{session}" + try: + if not os.path.exists(folder): + os.makedirs(folder) + except OSError as err: + print(err) + try: + if not os.path.exists(f"{folder}/task_features.npy"): + np.save(f"{folder}/task_features.npy", task_features) + if not os.path.exists(f"{folder}/task_features_hrf.npy"): + np.save(f"{folder}/task_features_hrf.npy", task_features_hrf) + except OSError as err: + print(err) + + +def classify_single_run( + task, run, session, dFC_id, roi_root, dFC_root, dynamic_pred, normalize_dFC +): + try: + ML_scores_new = task_presence_classification( + task=task, + dFC_id=dFC_id, + roi_root=roi_root, + dFC_root=dFC_root, + run=run, + session=session, + dynamic_pred=dynamic_pred, + normalize_dFC=normalize_dFC, + ) + return task, run, ML_scores_new + except Exception as e: + print(f"Error in task presence classification for {session} {task} {run}: {e}") + traceback.print_exc() + return task, run, None + + +def run_classification( + dFC_id, + TASKS, + RUNS, + SESSIONS, + roi_root, + dFC_root, + output_root, + dynamic_pred="no", + normalize_dFC=False, + n_jobs=-1, # Number of parallel jobs; -1 = all available cores +): + for session in SESSIONS: + if session is not None: + print(f"=================== {session} ===================") + + ML_scores = { + "group_lvl": {}, + "subj_lvl": {}, + } + + # Parallel execution + results = Parallel(n_jobs=n_jobs, verbose=0, backend="loky")( + delayed(classify_single_run)( + task, + run, + session, + dFC_id, + roi_root, + dFC_root, + dynamic_pred, + normalize_dFC, + ) + for task in TASKS + for run in RUNS[task] + ) + + # Aggregate results + for task, run, result in results: + if result is None: + continue + for key in result["group_lvl"]: + if key not in ML_scores["group_lvl"]: + ML_scores["group_lvl"][key] = [] + ML_scores["group_lvl"][key].extend(result["group_lvl"][key]) + for key in result["subj_lvl"]: + if key not in ML_scores["subj_lvl"]: + ML_scores["subj_lvl"][key] = [] + ML_scores["subj_lvl"][key].extend(result["subj_lvl"][key]) + + # Save output + folder = ( + f"{output_root}/classification" + if session is None + else f"{output_root}/classification/{session}" + ) + try: + os.makedirs(folder, exist_ok=True) + except OSError as err: + print(err) + + np.save(f"{folder}/ML_scores_classify_{dFC_id}.npy", ML_scores) + + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to apply Machine Learning on dFC results to predict task presence. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument("--dataset_info", type=str, help="path to dataset info file") + + args = parser.parse_args() + + dataset_info_file = args.dataset_info + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + print("Task presence prediction started ...") + + TASKS = dataset_info["TASKS"] + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + main_root = dataset_info["main_root"] + + if "{main_root}" in dataset_info["roi_root"]: + roi_root = dataset_info["roi_root"].replace("{main_root}", main_root) + else: + roi_root = dataset_info["roi_root"] + + if "{main_root}" in dataset_info["dFC_root"]: + dFC_root = dataset_info["dFC_root"].replace("{main_root}", main_root) + else: + dFC_root = dataset_info["dFC_root"] + + if "{main_root}" in dataset_info["ML_root"]: + ML_root = dataset_info["ML_root"].replace("{main_root}", main_root) + else: + ML_root = dataset_info["ML_root"] + + # # The task feature extraction will be executed multiple times in parallel redundantly + # try: + # run_task_features_extraction( + # TASKS=TASKS, + # RUNS=RUNS, + # SESSIONS=SESSIONS, + # roi_root=roi_root, + # dFC_root=dFC_root, + # output_root=ML_root, + # ) + # except Exception as e: + # print(f"Error in task features extraction: {e}") + # traceback.print_exc() + # print("Task features extraction finished.") + + job_id = os.getenv("SGE_TASK_ID") # for SGE + if job_id is None: + job_id = os.getenv("SLURM_ARRAY_TASK_ID") # for SLURM + job_id = int(job_id) + dFC_id = job_id - 1 # TASK_ID starts from 1 not 0 + + print(f"Task presence classification started for dFC ID {dFC_id}...") + try: + run_classification( + dFC_id=dFC_id, + TASKS=TASKS, + RUNS=RUNS, + SESSIONS=SESSIONS, + roi_root=roi_root, + dFC_root=dFC_root, + output_root=ML_root, + dynamic_pred="no", + normalize_dFC=False, + n_jobs=8, + ) + except Exception as e: + print(f"Error in classification for dFC ID {dFC_id}: {e}") + traceback.print_exc() + print(f"Task presence classification finished for dFC ID {dFC_id}.") + + print(f"Task presence prediction finished for dFC ID {dFC_id}.") + +####################################################################################### diff --git a/task_dFC/README.rst b/task_dFC/README.rst new file mode 100644 index 0000000..0de41ea --- /dev/null +++ b/task_dFC/README.rst @@ -0,0 +1,133 @@ +.. raw:: html + + GitHub Repository + +======================================================= +PydFC: task_dFC Module Documentation +======================================================= + +The ``task_dFC`` module provides a scalable, open-source Python solution for the **large-scale benchmarking and application of dynamic functional connectivity (dFC) methods**. + +Its core purpose is to apply end-to-end analytical workflows to fMRI data to assess the efficacy of various dFC methodologies in **predicting ongoing cognitive states** — specifically, distinguishing between moments of task engagement versus rest at the single repetition time (TR) resolution. + +Methods Implemented +------------------- + +The module supports a diverse selection of seven well-established dFC methodologies implemented within the PydFC toolbox: + +* **State-free Methods:** Designed to capture continuous fluctuations in connectivity. + + * Sliding Window (SW) + * Time-Frequency (TF) + +* **State-based Methods:** Designed to identify recurring, discrete connectivity patterns or states. + + * Co-Activation Patterns (CAP) + * Clustering (SWC) + * Continuous Hidden Markov Models (CHMM) + * Discrete Hidden Markov Models (DHMM) + * Windowless (WL) + +Prerequisites +------------- + +* Preprocessed fMRI data in BIDS format with ``events.tsv`` files (e.g., via fMRIPrep). +* PydFC installed (see the root ``README.rst``). +* For fMRIPrep preprocessing: ``nipoppy`` installed and configured. + +Configuration Files +------------------- + +Before running the pipeline, fill in the following JSON configuration files located in ``run_scripts_slurm/`` or ``run_scripts_sge/``: + +* ``dataset_info.json`` — dataset name, root paths, sessions, tasks, runs, and BOLD suffix. +* ``methods_config.json`` — dFC method parameters, method list, and parallelism settings. +* ``multi_dataset_info.json`` — paths and dataset lists for cross-dataset analysis. +* ``global_config.json`` — nipoppy configuration for fMRIPrep (containers, FreeSurfer license, TemplateFlow). + +Analysis Pipeline: Script-Based Workflow +----------------------------------------- + +The ``task_dFC`` workflow starts assuming that fMRI data (in BIDS format with ``events.tsv``) has undergone standard preprocessing (via fMRIPrep). The subsequent analysis is executed sequentially through the following scripts: + +1. ``nifti_to_roi_signal.py`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**Function:** Runs denoising and extracts regional BOLD time series from preprocessed NIfTI data. + +**Details:** Voxel-wise BOLD signals are parcellated, typically using an atlas such as the Schaefer 100-region atlas, yielding regional time series that serve as the input for dFC assessment. + +2. ``FCS_estimate.py`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**Function:** Estimates Functional Connectivity States (FCS). + +**Details:** This script fits the dFC model required by **state-based methodologies** (CAP, HMM, Clustering) that rely on identifying **group-level recurring patterns**. Must be run before ``dFC_assessment.py``. + +3. ``dFC_assessment.py`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**Function:** Computes individual-level dFC patterns. + +**Details:** Applies the seven implemented dFC methodologies (SW, TF, CAP, etc.) to the BOLD signals of each run and subject to obtain the corresponding high-dimensional dFC patterns. + +4. ``ML.py`` +~~~~~~~~~~~~~~~~~~~~ +**Function:** Implements the core machine learning pipeline, including cognitive state labeling, feature extraction, supervised classification, and separability analysis. + +**A. Task Presence Labeling** + +* Initial stimulus timings from ``events.tsv`` are convolved with a canonical **Hemodynamic Response Function (HRF)** to account for hemodynamic delay. +* The HRF-convolved signal is binarized using a **Gaussian Mixture Model (GMM)** to assign time points as "rest" or "task-present". This process critically identifies and removes ambiguous **"gray zone" time points** corresponding to transitions, improving classifier performance. + +**B. Feature Extraction and Reduction** + +* **State-free Methods (SW, TF):** DFC matrices are vectorized (e.g., 4950 connections for Schaefer 100-region atlas). **Laplacian Eigenmaps (LE)** dimensionality reduction is applied to make the high-dimensional discriminative information accessible to classifiers. +* **State-based Methods (CAP, HMM, etc.):** Features are derived from state probabilities, distances from states, or state weights. These resulting compositional features (shape ``(time, states)``) are transformed using an **isometric log-ratio (ILR) transformation**. + +**C. Prediction and Evaluation** + +* A **Support Vector Machine (SVM) with an RBF kernel** is trained to predict the cognitive state (rest vs. task) at the single-TR level. +* **Balanced Accuracy** is used as the primary metric, ensuring chance performance is 50%. +* **Cognitive State Separability** is quantified using the **Silhouette Index (SI)** to evaluate whether task and rest samples are intrinsically distinguishable in the feature space without supervision. + +5. ``generate_report.py`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**Function:** Summarizes classification efficacy and separability results for individual datasets and paradigms. + +**Details:** Generates figures, tables, and reports (e.g., heatmaps and boxplots) documenting Balanced Accuracy and SI scores across methods and paradigms. + +6. ``multi_dataset_analysis/`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +**Function:** Aggregates and compares results across multiple datasets and paradigms. + +**Details:** Facilitates **large-scale benchmarking** by calculating aggregate performance statistics across datasets and task paradigms. Run via ``run_scripts_slurm/run_across_dataset_analysis.sh``. + +Available scripts: + +* ``performance_predict.py`` — prediction accuracy across datasets and methods +* ``performance_factor.py`` — factors driving performance differences +* ``ml_results.py`` — ML pipeline result summaries +* ``dfc_visualization.py`` — dFC pattern visualizations +* ``embedding_visualization.py`` — low-dimensional embedding visualizations +* ``sample_matrix_visualization.py`` — sample dFC matrix plots +* ``task_presence_binarization.py`` — task label binarization diagnostics +* ``task_timing_stats.py`` — task timing statistics +* ``cohensd.py`` — effect size analysis + +Running the Pipeline +-------------------- + +Cluster job scripts are provided in two formats: + +* ``run_scripts_slurm/`` — for SLURM-based clusters (e.g., Compute Canada / Alliance) +* ``run_scripts_sge/`` — for SGE-based clusters + +Each script contains a **"Cluster configuration"** block at the top where you set your virtual environment path and pydfc code directory before submitting. Array jobs (``run_dFC.sh``, ``run_nifti_to_roi.sh``, ``run_fmriprep.sh``) expect a ``subj_list.txt`` file listing one subject ID per line. + +Typical submission order:: + + sbatch run_fmriprep.sh # (optional) if preprocessing not done + sbatch run_nifti_to_roi.sh + sbatch run_FCS.sh + sbatch --array=1-N run_dFC.sh + sbatch run_ML.sh + sbatch run_report.sh # (optional) + sbatch run_across_dataset_analysis.sh diff --git a/task_dFC/dFC_assessment.py b/task_dFC/dFC_assessment.py index a381f95..1269e37 100644 --- a/task_dFC/dFC_assessment.py +++ b/task_dFC/dFC_assessment.py @@ -1,10 +1,12 @@ +import argparse +import json import os import time import warnings import numpy as np -from pydfc import MultiAnalysis, data_loader +from pydfc import data_loader, multi_analysis_utils warnings.simplefilter("ignore") @@ -12,77 +14,102 @@ os.environ["NUMEXPR_NUM_THREADS"] = "16" os.environ["OMP_NUM_THREADS"] = "16" -################################# Parameters ################################# - -# Data parameters -# main_root = '../../DATA/ds002785/' # for local -main_root = "../../../DATA/task-based/openneuro/ds002785/" # for server - -# subjects used for dFC assessment do not need to be the same as those used for FCS_estimate -# you can set the new roi root and data load parameters here: -roi_root = f"{main_root}/derivatives/ROI_timeseries" -fitted_measures_root = f"{main_root}/derivatives/fitted_MEASURES" -output_root = f"{main_root}/derivatives/dFC_assessed" - -# for consistency we use 0 for resting state -TASKS = [ - "task-restingstate", - "task-anticipation", - "task-emomatching", - "task-faces", - "task-gstroop", - "task-workingmemory", -] - -# find all subjects across all tasks -SUBJECTS = data_loader.find_subj_list(data_root=roi_root, sessions=TASKS) - -# job_id selects the subject -job_id = int(os.getenv("SGE_TASK_ID")) -if job_id > len(SUBJECTS): - print("job_id > len(SUBJECTS)") - exit() -subj_id = SUBJECTS[job_id - 1] # SGE_TASK_ID starts from 1 not 0 - -for task in TASKS: - - MA = np.load( - f"{fitted_measures_root}/{task}/multi_analysis.npy", allow_pickle="TRUE" - ).item() - - # check if the subject has this task - SUBJECTS_with_this_task = data_loader.find_subj_list( - data_root=roi_root, sessions=[task] - ) - if not subj_id in SUBJECTS_with_this_task: - print(f"subject {subj_id} not in the list of subjects with task {task}") - continue - +################################# Functions ################################# + + +def run_dFC_assess( + subj_id, + task, + roi_root, + fitted_measures_root, + output_root, + params_multi_analysis, + session=None, + run=None, +): + if session is None: + output_dir = f"{output_root}/{subj_id}" + fitted_measures_dir = f"{fitted_measures_root}" + else: + output_dir = f"{output_root}/{subj_id}/{session}" + fitted_measures_dir = f"{fitted_measures_root}/{session}" + + if run is None: + if session is None: + print(f"Subject-level dFC assessment started for TASK: {task} ...") + input_root = f"{roi_root}/{subj_id}" + BOLD_file_name = "{subj_id}_{task}_time-series.npy" + file_suffix = f"{task}" + else: + print( + f"Subject-level dFC assessment started for Session {session}, TASK: {task} ..." + ) + input_root = f"{roi_root}/{subj_id}/{session}" + BOLD_file_name = "{subj_id}_{session}_{task}_time-series.npy" + file_suffix = f"{session}_{task}" + else: + if session is None: + print( + f"Subject-level dFC assessment started for TASK: {task}, RUN: {run} ..." + ) + input_root = f"{roi_root}/{subj_id}" + BOLD_file_name = "{subj_id}_{task}_{run}_time-series.npy" + file_suffix = f"{task}_{run}" + else: + print( + f"Subject-level dFC assessment started for Session {session}, TASK: {task}, RUN: {run} ..." + ) + input_root = f"{roi_root}/{subj_id}/{session}" + BOLD_file_name = "{subj_id}_{session}_{task}_{run}_time-series.npy" + file_suffix = f"{session}_{task}_{run}" + + # check if the subject has this task in roi_root + if not os.path.exists(input_root): + print(f"{input_root} not found in {roi_root}") + return + + ALL_ROI_FILES = os.listdir(f"{input_root}/") + ALL_ROI_FILES = [ + roi_file + for roi_file in ALL_ROI_FILES + if ("_time-series.npy" in roi_file) and (f"_{task}_" in roi_file) + ] + if session is not None: + ALL_ROI_FILES = [ + roi_file for roi_file in ALL_ROI_FILES if (f"_{session}_" in roi_file) + ] + if run is not None: + ALL_ROI_FILES = [ + roi_file for roi_file in ALL_ROI_FILES if (f"_{run}_" in roi_file) + ] + ALL_ROI_FILES.sort() + + # if there are no files for this task, return + if not len(ALL_ROI_FILES) >= 1: + print(f"No time series files found for {subj_id} {file_suffix}") + return ################################# LOAD FIT MEASURES ################################# - ALL_RECORDS = os.listdir(f"{fitted_measures_root}/{task}/") - ALL_RECORDS = [i for i in ALL_RECORDS if "MEASURE" in i] + ALL_RECORDS = os.listdir(f"{fitted_measures_dir}/") + ALL_RECORDS = [ + i for i in ALL_RECORDS if ("MEASURE" in i) and (f"_{file_suffix}_" in i) + ] ALL_RECORDS.sort() MEASURES_fit_lst = list() for s in ALL_RECORDS: - fit_measure = np.load( - f"{fitted_measures_root}/{task}/{s}", allow_pickle="TRUE" - ).item() + fit_measure = np.load(f"{fitted_measures_dir}/{s}", allow_pickle="TRUE").item() MEASURES_fit_lst.append(fit_measure) - MA.set_MEASURES_fit_lst(MEASURES_fit_lst) - print("fitted MEASURES loaded ...") + print("fitted MEASURES are loaded ...") ################################# LOAD DATA ################################# - print( - f"subject-level dFC assessment CODE started running ... for task {task} of subject {subj_id} ..." - ) - BOLD = data_loader.load_TS( data_root=roi_root, - file_name="time_series.npy", - SESSIONs=[task], + file_name=BOLD_file_name, subj_id2load=subj_id, + task=task, + session=session, + run=run, ) ################################# dFC ASSESSMENT ################################# @@ -91,18 +118,129 @@ print("Measurement Started ...") print("dFC estimation started...") - dFC_dict = MA.subj_lvl_dFC_assess(time_series_dict=BOLD) + dFC_dict = multi_analysis_utils.subj_lvl_dFC_assess( + time_series=BOLD, + MEASURES_fit_lst=MEASURES_fit_lst, + n_jobs=params_multi_analysis["n_jobs"], + verbose=params_multi_analysis["verbose"], + backend=params_multi_analysis["backend"], + ) print("dFC estimation done.") print(f"Measurement required {time.time() - tic:0.3f} seconds.") ################################# SAVE DATA ################################# - folder = f"{output_root}/{task}/{subj_id}" + folder = f"{output_dir}/" if not os.path.exists(folder): os.makedirs(folder) for dFC_id, dFC in enumerate(dFC_dict["dFC_lst"]): - np.save(f"{folder}/dFC_{str(dFC_id)}.npy", dFC) + + # Optional: cast each dFC to float32 to save space + dFC.FCSs_ = { + key: value.astype(np.float32, copy=False) for key, value in dFC.FCSs_.items() + } + + np.save(f"{folder}dFC_{file_suffix}_{dFC_id}.npy", dFC) + + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to assess dFC for a given participant. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument("--dataset_info", type=str, help="path to dataset info file") + parser.add_argument("--methods_config", type=str, help="methods config file") + parser.add_argument("--participant_id", type=str, help="participant id") + + args = parser.parse_args() + + dataset_info_file = args.dataset_info + methods_config_file = args.methods_config + participant_id = args.participant_id + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + # Read methods config + with open(methods_config_file, "r") as f: + methods_config = json.load(f) + + print( + f"subject-level dFC assessment CODE started running ... for subject: {participant_id} ..." + ) + + TASKS = dataset_info["TASKS"] + + if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + main_root = dataset_info["main_root"] + + if "{main_root}" in dataset_info["roi_root"]: + roi_root = dataset_info["roi_root"].replace("{main_root}", main_root) + else: + roi_root = dataset_info["roi_root"] + + if "{main_root}" in dataset_info["fitted_measures_root"]: + fitted_measures_root = dataset_info["fitted_measures_root"].replace( + "{main_root}", main_root + ) + else: + fitted_measures_root = dataset_info["fitted_measures_root"] + + if "{main_root}" in dataset_info["dFC_root"]: + output_root = dataset_info["dFC_root"].replace("{main_root}", main_root) + else: + output_root = dataset_info["dFC_root"] + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + params_multi_analysis = methods_config["params_multi_analysis"] + + for session in SESSIONS: + for task in TASKS: + for run in RUNS[task]: + try: + run_dFC_assess( + subj_id=participant_id, + task=task, + roi_root=roi_root, + fitted_measures_root=fitted_measures_root, + output_root=output_root, + params_multi_analysis=params_multi_analysis, + session=session, + run=run, + ) + except Exception as e: + print( + f"Error in dFC assessment for subject {participant_id}, task {task}, session {session}, run {run}: {e}" + ) + continue + + print( + f"subject-level dFC assessment CODE finished running for subject: {participant_id}" + ) ####################################################################################### diff --git a/task_dFC/generate_report.py b/task_dFC/generate_report.py new file mode 100644 index 0000000..3d33afc --- /dev/null +++ b/task_dFC/generate_report.py @@ -0,0 +1,1550 @@ +import argparse +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from nilearn import image, masking, plotting +from nilearn.glm.first_level import FirstLevelModel +from sklearn.cluster import KMeans +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from pydfc import DFC, data_loader, task_utils +from pydfc.dfc_utils import ( + TR_intersection, + dFC_mat2vec, + dFC_vec2mat, + rank_norm, + visualize_conn_mat_dict, +) +from pydfc.report_util import plot_classification_metrics, plot_clustering_metrics + +################################# Parameters #################################### + +fig_dpi = 120 +fig_bbox_inches = "tight" +fig_pad = 0.1 +show_title = True +save_fig_format = "png" # pdf, png, + +####################################################################################### + + +def load_dFC(dFC_root, subj, task, dFC_id, run=None, session=None): + """ + Load the dFC results for a given subject, task, dFC_id, run and session. + """ + if session is None: + if run is None: + dFC = np.load( + f"{dFC_root}/{subj}/dFC_{task}_{dFC_id}.npy", allow_pickle="TRUE" + ).item() + else: + dFC = np.load( + f"{dFC_root}/{subj}/dFC_{task}_{run}_{dFC_id}.npy", allow_pickle="TRUE" + ).item() + else: + if run is None: + dFC = np.load( + f"{dFC_root}/{subj}/{session}/dFC_{session}_{task}_{dFC_id}.npy", + allow_pickle="TRUE", + ).item() + else: + dFC = np.load( + f"{dFC_root}/{subj}/{session}/dFC_{session}_{task}_{run}_{dFC_id}.npy", + allow_pickle="TRUE", + ).item() + + return dFC + + +def load_task_data(roi_root, subj, task, run=None, session=None): + """ + Load the task data for a given subject, task and run. + """ + if session is None: + if run is None: + task_data = np.load( + f"{roi_root}/{subj}/{subj}_{task}_task-data.npy", allow_pickle="TRUE" + ).item() + else: + task_data = np.load( + f"{roi_root}/{subj}/{subj}_{task}_{run}_task-data.npy", + allow_pickle="TRUE", + ).item() + else: + if run is None: + task_data = np.load( + f"{roi_root}/{subj}/{session}/{subj}_{session}_{task}_task-data.npy", + allow_pickle="TRUE", + ).item() + else: + task_data = np.load( + f"{roi_root}/{subj}/{session}/{subj}_{session}_{task}_{run}_task-data.npy", + allow_pickle="TRUE", + ).item() + + return task_data + + +def get_func_data(fmriprep_root, subj, task, bold_suffix, run=None, session=None): + if session is None: + ALL_TASK_FILES = os.listdir(f"{fmriprep_root}/{subj}/func/") + else: + ALL_TASK_FILES = os.listdir(f"{fmriprep_root}/{subj}/{session}/func/") + + ALL_TASK_FILES = [ + file_i + for file_i in ALL_TASK_FILES + if (bold_suffix in file_i) and (f"_{task}_" in file_i) + ] + + if not len(ALL_TASK_FILES) >= 1: + return None + + if run is None: + task_file = ALL_TASK_FILES[0] + else: + task_file = [file_i for file_i in ALL_TASK_FILES if f"_{run}_" in file_i][0] + if session is None: + func_file = f"{fmriprep_root}/{subj}/func/{task_file}" + else: + func_file = f"{fmriprep_root}/{subj}/{session}/func/{task_file}" + + return func_file + + +# def plot_anatomical( +# fmriprep_root, +# subj, +# anat_suffix, +# session=None, +# ): +# anat_suffix = '_space-MNI152NLin2009cAsym_desc-preproc_T1w.nii.gz' +# anat_file = f"{fmriprep_root}/{subj}/anat/{subj}{anat_suffix}" +# display = plotting.plot_anat(anat_file, title="plot_anat") + + +# def plot_functional( +# fmriprep_root, +# subj, +# bold_suffix, +# task, +# session=None, +# run=None, +# ): +# if session is None: +# if run is None: +# task_file = f"{subj}_{task}{bold_suffix}" +# else: +# task_file = f"{subj}_{task}_{run}{bold_suffix}" +# func_file = f"{fmriprep_root}/{subj}/func/{task_file}" +# else: +# if run is None: +# task_file = f"{subj}_{session}_{task}{bold_suffix}" +# else: +# task_file = f"{subj}_{session}_{task}_{run}{bold_suffix}" +# func_file = f"{fmriprep_root}/{subj}/{session}/func/{task_file}" + +# # Compute voxel-wise mean functional image across time dimension. Now we have +# # functional image in 3D assigned in mean_func_img +# mean_func_img = image.mean_img(func_file) +# display = plotting.plot_anat(mean_func_img, title="plot_func") + + +def get_events_df(events, trial_type_label="trial_type", rest_labels=["rest", "Rest"]): + # find which column is the "onset" in the first row + onset_idx = np.where(events[0, :] == "onset")[0][0] + duration_idx = np.where(events[0, :] == "duration")[0][0] + if trial_type_label is not None: + trial_type_idx = np.where(events[0, :] == trial_type_label)[0][0] + + # assign the time between active onsets to 'rest' + events_new = [] + prev_onset = 0.0 + for i in range(1, events.shape[0]): + + if trial_type_label is not None: + if events[i, trial_type_idx] in rest_labels: + continue + + current_onset = float(events[i, onset_idx]) + current_duration = float(events[i, duration_idx]) + rest_duration = current_onset - prev_onset + if rest_duration > 0.0: + events_new.append([prev_onset, rest_duration, "rest"]) + events_new.append([current_onset, current_duration, "active"]) + prev_onset = current_onset + current_duration + + events_new = np.array(events_new) + + # convert to pandas dataframe + events_df = pd.DataFrame(events_new, columns=["onset", "duration", "trial_type"]) + + return events_df + + +def plot_glm( + fmriprep_root, + roi_root, + subj, + task, + bold_suffix, + trial_type_label, + rest_labels, + output_root, + run=None, + session=None, +): + + func_file = get_func_data( + fmriprep_root=fmriprep_root, + subj=subj, + task=task, + bold_suffix=bold_suffix, + run=run, + session=session, + ) + task_data = load_task_data(roi_root, subj, task, run, session) + TR_mri = task_data["TR_mri"] + + events_df = get_events_df( + events=task_data["events"], + trial_type_label=trial_type_label, + rest_labels=rest_labels, + ) + + # Make an average + mean_img = image.mean_img(func_file) + mask = masking.compute_epi_mask(mean_img) + + # Clean and smooth data + fmri_img = image.clean_img(func_file, standardize=False) + fmri_img = image.smooth_img(fmri_img, 5.0) + + fmri_glm = FirstLevelModel( + t_r=TR_mri, + drift_model="cosine", + signal_scaling=False, + mask_img=mask, + minimize_memory=False, + ) + + fmri_glm = fmri_glm.fit(fmri_img, events_df) + + z_map = fmri_glm.compute_contrast("active - rest") + + plotting.plot_stat_map(z_map, bg_img=mean_img, threshold=3.1) + + # save the figure + output_dir = f"{output_root}/subject_results/{subj}/GLM" + if session is not None: + output_dir = f"{output_dir}/{session}" + output_dir = f"{output_dir}/{task}" + if run is not None: + output_dir = f"{output_dir}/{run}" + output_dir = f"{output_dir}/" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig( + f"{output_dir}/glm.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + + +def plot_roi_signals( + roi_root, + subj, + task, + start_time, + end_time, + output_root, + nodes_list=range(0, 10), + session=None, + run=None, +): + if session is None: + if run is None: + file_name = "{subj_id}_{task}_time-series.npy" + else: + file_name = "{subj_id}_{task}_{run}_time-series.npy" + else: + if run is None: + file_name = "{subj_id}_{session}_{task}_time-series.npy" + else: + file_name = "{subj_id}_{session}_{task}_{run}_time-series.npy" + + task_data = load_task_data(roi_root, subj, task, run, session) + TR_mri = task_data["TR_mri"] + + BOLD = data_loader.load_TS( + data_root=roi_root, + file_name=file_name, + subj_id2load=subj, + task=task, + run=run, + session=session, + ) + + time = np.arange(0, BOLD.data.shape[1]) * TR_mri + start_TR = int(start_time / TR_mri) + end_TR = int(end_time / TR_mri) + # keep the figure width proportional to the number of time points + fig_width = int(2.5 * (end_time - start_time) / 2) + fig_width = min(fig_width, 500) + plt.figure(figsize=(fig_width, 5)) + for i in nodes_list: + plt.plot(time[start_TR:end_TR], BOLD.data[i, start_TR:end_TR], linewidth=4) + # put vertical lines at the start of each TR + for TR in range(start_TR, end_TR): + plt.axvline(x=TR * TR_mri, color="r", linestyle="--") + # show TR labels on the red lines with a small font and at the top + for TR in range(start_TR, end_TR): + plt.text(TR * TR_mri, 1.2, f"TR {TR}", fontsize=8, color="black", ha="center") + if show_title: + plt.title("ROI signals") + plt.xlabel("Time (s)") + + # save the figure + output_dir = f"{output_root}/subject_results/{subj}/ROI_signals" + if session is not None: + output_dir = f"{output_dir}/{session}" + output_dir = f"{output_dir}/{task}" + if run is not None: + output_dir = f"{output_dir}/{run}" + output_dir = f"{output_dir}/" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig( + f"{output_dir}/ROI_signals.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + + +def plot_event_labels( + roi_root, + subj, + task, + start_time, + end_time, + output_root, + run=None, + session=None, +): + task_data = load_task_data(roi_root, subj, task, run, session) + Fs_task = task_data["Fs_task"] + TR_task = 1 / Fs_task + # TR_mri = task_data["TR_mri"] + + time = np.arange(0, task_data["event_labels"].shape[0]) / Fs_task + start_timepoint = int(start_time / TR_task) + end_timepoint = int(end_time / TR_task) + # keep the figure width proportional to the number of time points + fig_width = int(2.5 * (end_time - start_time) / 2) + fig_width = min(fig_width, 500) + plt.figure(figsize=(fig_width, 5)) + plt.plot( + time[start_timepoint:end_timepoint], + task_data["event_labels"][start_timepoint:end_timepoint], + linewidth=4, + ) + plt.title("Event labels") + plt.xlabel("Time (s)") + + # save the figure + output_dir = f"{output_root}/subject_results/{subj}/event_labels" + if session is not None: + output_dir = f"{output_dir}/{session}" + output_dir = f"{output_dir}/{task}" + if run is not None: + output_dir = f"{output_dir}/{run}" + output_dir = f"{output_dir}/" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig( + f"{output_dir}/event_labels.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + + +def plot_task_presence( + roi_root, + subj, + task, + start_time, + end_time, + output_root, + run=None, + session=None, +): + task_data = load_task_data(roi_root, subj, task, run, session) + Fs_task = task_data["Fs_task"] + TR_task = 1 / Fs_task + TR_mri = task_data["TR_mri"] + Fs_mri = 1 / TR_mri + + task_presence_non_binarized, _ = task_utils.extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=TR_task, + TR_mri=task_data["TR_mri"], + binary=False, + ) + + task_presence, indices = task_utils.extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=TR_task, + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + ) + + time = np.arange(0, task_presence.shape[0]) / Fs_mri + start_TR = int(start_time / TR_mri) + end_TR = int(end_time / TR_mri) + # keep the figure width proportional to the number of time points in data + fig_width = int(2.5 * (end_time - start_time) / 2) + fig_width = min(fig_width, 500) + plt.figure(figsize=(fig_width, 5)) + plt.plot( + time[start_TR:end_TR], task_presence_non_binarized[start_TR:end_TR], linewidth=4 + ) + plt.plot(time[start_TR:end_TR], task_presence[start_TR:end_TR], linewidth=4) + + # put vertical lines at the start of each TR + for TR in range(start_TR, end_TR): + if TR in indices: + plt.axvline(x=TR * TR_mri, color="g", linestyle="--") + else: + plt.axvline(x=TR * TR_mri, color="r", linestyle="--") + # show TR labels on the red lines with a small font and at the top + for TR in range(start_TR, end_TR): + plt.text(TR * TR_mri, 1.2, f"TR {TR}", fontsize=8, color="black", ha="center") + plt.title("Task presence") + plt.xlabel("Time (s)") + + # save the figure + output_dir = f"{output_root}/subject_results/{subj}/task_presence" + if session is not None: + output_dir = f"{output_dir}/{session}" + output_dir = f"{output_dir}/{task}" + if run is not None: + output_dir = f"{output_dir}/{run}" + output_dir = f"{output_dir}/" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig( + f"{output_dir}/task_presence.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + + +# def plot_FCS(): +# visualize_FCS( +# measure, +# normalize=True, +# fix_lim=False, +# save_image=save_image, +# output_root=output_root + "FCS/", +# ) + + +def plot_dFC_matrices( + dFC_root, + subj, + task, + start_time, + end_time, + output_root, + run=None, + session=None, +): + """ + plot dFC matrices for a given subject, task, run, session, start_time and end_time + parameters: + ---------- + dFC_root: str, path to dFC results + subj: str, subject id + task: str, task name + start_time: float, start time in seconds + end_time: float, end time in seconds + """ + task_data = load_task_data(roi_root, subj, task, run, session) + TR_mri = task_data["TR_mri"] + + dFC_lst = list() + for dFC_id in range(0, 20): # change this to the number of dFCs you have + try: + dFC = load_dFC(dFC_root, subj, task, dFC_id, run, session) + dFC_lst.append(dFC) + except Exception: + pass + + TRs = TR_intersection(dFC_lst) + start_TR = int(start_time / TR_mri) + end_TR = int(end_time / TR_mri) + start_TR_idx = np.where(np.array(TRs) >= start_TR)[0][0] + end_TR_idx = np.where(np.array(TRs) <= end_TR)[0][-1] + # if the TR_mri is low which will cause the figure to be too wide, + # we will only plot a resampled version of the dFC matrices, e.g. to make it the same as TR_mri=2s + if TR_mri < 2: + TR_step = int(2 / TR_mri) + chosen_TRs = TRs[start_TR_idx:end_TR_idx:TR_step] + # raise warning if the TR_mri is low + print( + f"TR_mri is low ({TR_mri}s), the dFC matrices will be resampled to make the figure width reasonable" + ) + else: + chosen_TRs = TRs[start_TR_idx:end_TR_idx] + + output_dir = f"{output_root}/subject_results/{subj}/dFC_matrices" + if session is not None: + output_dir = f"{output_dir}/{session}" + output_dir = f"{output_dir}/{task}" + if run is not None: + output_dir = f"{output_dir}/{run}" + output_dir = f"{output_dir}/" + + for dFC in dFC_lst: + dFC.visualize_dFC( + TRs=chosen_TRs, + normalize=False, + rank_norm=True, + fix_lim=False, + save_image=True, + output_root=output_dir, + ) + + +def plot_ML_results( + ML_root, + output_root, + task, + run=None, + session=None, + ML_algorithms=["KNN"], + embedding="PCA", +): + """ + Plot the ML classification results plus SI score for a given task, run and session. + parameters: + ---------- + ML_root: str, path to ML results + output_root: str, path to save the figures + task: str, task name + run: int, run number + session: str, session name + ML_algorithms: list of str, list of ML algorithm name (default: KNN, other options: Logistic regression, SVM, Gradient Boosting, RF) + embedding: str, embedding method (default: PCA, other options: LE) + """ + # the ML_scores files are saved as ML_scores_classify_{dFC_id}.npy + # find all the ML_scores files in the directory + if session is None: + input_dir = f"{ML_root}/classification" + else: + input_dir = f"{ML_root}/classification/{session}" + ALL_ML_SCORES = os.listdir(input_dir) + ALL_ML_SCORES = [ + score_file for score_file in ALL_ML_SCORES if "ML_scores_classify" in score_file + ] + ALL_ML_SCORES.sort() + ML_scores = None + for score_file in ALL_ML_SCORES: + ML_scores_new = np.load(f"{input_dir}/{score_file}", allow_pickle="TRUE").item() + ML_scores_new = ML_scores_new["subj_lvl"] + if ML_scores is None: + ML_scores = ML_scores_new + else: + for key in ML_scores_new.keys(): + ML_scores[key].extend(ML_scores_new[key]) + + sns.set_context("paper", font_scale=1.0, rc={"lines.linewidth": 1.0}) + + sns.set_style("darkgrid") + + dataframe = pd.DataFrame(ML_scores) + if run is not None: + dataframe = dataframe[dataframe["run"] == run] + + dataframe = dataframe[dataframe["task"] == task] + dataframe = dataframe[dataframe["embedding"] == embedding] + + # save the figure + if session is None: + output_dir = f"{output_root}/group_results/classification" + else: + output_dir = f"{output_root}/group_results/classification/{session}" + + metrics = [ + # "accuracy", + "balanced accuracy", + "precision", + "recall", + # "f1", + # "tp", + # "tn", + # "fp", + # "fn", + # "average precision", + ] + + for ML_algorithm in ML_algorithms: + if ML_algorithm == "Logistic regression": + ML_algorithm_name = "LogReg" + elif ML_algorithm == "SVM": + ML_algorithm_name = "SVM" + elif ML_algorithm == "KNN": + ML_algorithm_name = "KNN" + elif ML_algorithm == "Random Forest": + ML_algorithm_name = "RF" + elif ML_algorithm == "Gradient Boosting": + ML_algorithm_name = "GBT" + + if run is None: + suffix = f"{ML_algorithm_name}_{task}_{embedding}" + else: + suffix = f"{ML_algorithm_name}_{task}_{run}_{embedding}" + + for metric in metrics: + plot_classification_metrics( + dataframe=dataframe, + ML_algorithm=ML_algorithm, + pred_metric=metric, + title=task, + suffix=suffix, + output_dir=output_dir, + ) + + # Clustering SI score + + # save the figure + if run is None: + suffix = f"{task}_{embedding}" + else: + suffix = f"{task}_{run}_{embedding}" + + if session is None: + output_dir = f"{output_root}/group_results/clustering" + else: + output_dir = f"{output_root}/group_results/clustering/{session}" + + plot_clustering_metrics( + dataframe=dataframe, + metric="SI", + title=task, + suffix=suffix, + output_dir=output_dir, + ) + + +def plot_visual_clstr_centroids( + ML_root, + output_root, + session=None, +): + """ """ + # the centroids files are saved as centroids_{session}_{task}_{run}_{measure_name}.npy + # find all the centroids files in the directory + if session is None: + input_dir = f"{ML_root}/centroids" + else: + input_dir = f"{ML_root}/centroids/{session}" + + output_dir = f"{output_root}/group_results/visual_clustering_centroids" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + ALL_CENTROID_RESULTS = os.listdir(input_dir) + ALL_CENTROID_RESULTS = [ + result_file for result_file in ALL_CENTROID_RESULTS if "centroids_" in result_file + ] + ALL_CENTROID_RESULTS.sort() + + for result_file in ALL_CENTROID_RESULTS: + centroids_results = np.load( + f"{input_dir}/{result_file}", allow_pickle="TRUE" + ).item() + centroids_mat = centroids_results["centroids_mat"] + co_occurrence_matrix = centroids_results["co_occurrence_matrix"] + cluster_label_percentage = centroids_results["cluster_label_percentage"] + task_label_percentage = centroids_results["task_label_percentage"] + + # result_file is centroids_{session}_{task}_{run}_{measure_name}.npy + # suffix is whatever comes after the centroids and before .npy + suffix = result_file.split("centroids_")[1].split(".npy")[0] + + centroids_dict = {} + for i, centroid_mat in enumerate(centroids_mat): + centroids_dict[f"Cluster {i + 1}"] = centroid_mat + + visualize_conn_mat_dict( + data=centroids_dict, + title=f"visual-centroids_{suffix}", + cmap="seismic", + normalize=True, + disp_diag=False, + save_image=True, + output_root=f"{output_dir}/", + center_0=True, + # node_networks=None, + ) + + # plot co-occurrence matrix and cluster label percentage and task label percentage + # as a seaborn heatmap with numbers in the cells + # as separate figures + + # plot co-occurrence matrix + plt.figure(figsize=(20, 10)) + sns.heatmap( + co_occurrence_matrix, + annot=True, + fmt=".0f", + cmap="Reds", + cbar_kws={"label": "Co-occurrence"}, + yticklabels=["rest", "task"], + xticklabels=[str(i + 1) for i in range(co_occurrence_matrix.shape[1])], + ) + plt.title("Co-occurrence matrix") + plt.xlabel("Cluster") + plt.ylabel("Task") + plt.savefig( + f"{output_dir}/co-occurrence-matrix_{suffix}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + plt.close() + + # plot cluster label percentage + plt.figure(figsize=(20, 10)) + sns.heatmap( + cluster_label_percentage, + annot=True, + fmt=".2f", + cmap="Reds", + cbar_kws={"label": "Percentage"}, + yticklabels=["rest", "task"], + xticklabels=[str(i + 1) for i in range(co_occurrence_matrix.shape[1])], + ) + plt.title("Cluster label percentage") + plt.xlabel("Cluster") + plt.ylabel("Task") + plt.savefig( + f"{output_dir}/cluster-label-percentage_{suffix}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + plt.close() + + # plot task label percentage + plt.figure(figsize=(20, 10)) + sns.heatmap( + task_label_percentage, + annot=True, + fmt=".2f", + cmap="Reds", + cbar_kws={"label": "Percentage"}, + yticklabels=["rest", "task"], + xticklabels=[str(i + 1) for i in range(co_occurrence_matrix.shape[1])], + ) + plt.title("Task label percentage") + plt.xlabel("Cluster") + plt.ylabel("Task") + plt.savefig( + f"{output_dir}/task-label-percentage_{suffix}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + plt.close() + + +def plot_task_presence_features( + ML_root, + output_root, + session=None, + run=None, +): + """ + Plot the task presence features for a given session and run. + Features for both with and without HRF are plotted. + for comparability of tasks, pass the same run number for all tasks + parameters: + ---------- + ML_root: str, path to ML results + output_root: str, path to save the figures + session: str, session name + run: int, run number + """ + if session is None: + task_features = np.load( + f"{ML_root}/task_features/task_features.npy", allow_pickle="TRUE" + ).item() + task_features_hrf = np.load( + f"{ML_root}/task_features/task_features_hrf.npy", allow_pickle="TRUE" + ).item() + else: + task_features = np.load( + f"{ML_root}/task_features/{session}/task_features.npy", allow_pickle="TRUE" + ).item() + task_features_hrf = np.load( + f"{ML_root}/task_features/{session}/task_features_hrf.npy", + allow_pickle="TRUE", + ).item() + + sns.set_context("paper", font_scale=1.0, rc={"lines.linewidth": 1.0}) + + sns.set_style("darkgrid") + + task_features_df = pd.DataFrame(task_features) + task_features_hrf_df = pd.DataFrame(task_features_hrf) + if run is not None: + task_features_df = task_features_df[task_features_df["run"] == run] + task_features_hrf_df = task_features_hrf_df[task_features_hrf_df["run"] == run] + + # FEATURES are columns in the dataframe except for 'task' and 'run' + FEATURES = list(task_features_df.columns) + FEATURES.remove("task") + FEATURES.remove("run") + + if session is None: + output_dir = f"{output_root}/group_results/task_presence_features" + else: + output_dir = f"{output_root}/group_results/task_presence_features/{session}" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + for i, feature in enumerate(FEATURES): + plt.figure(figsize=(10, 5)) + g = sns.pointplot( + data=task_features_df, + x="task", + y=feature, + errorbar="sd", + linestyle="none", + dodge=True, + capsize=0.1, + ) + plt.xlabel(g.get_xlabel(), fontweight="bold") + plt.ylabel(g.get_ylabel(), fontweight="bold") + plt.xticks(fontweight="bold") + plt.yticks(fontweight="bold") + + # save the figure + plt.savefig( + f"{output_dir}/task_presence_features_{feature}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + plt.close() + + plt.figure(figsize=(10, 5)) + g = sns.pointplot( + data=task_features_hrf_df, + x="task", + y=feature, + errorbar="sd", + linestyle="none", + dodge=True, + capsize=0.1, + ) + plt.xlabel(g.get_xlabel(), fontweight="bold") + plt.ylabel(g.get_ylabel(), fontweight="bold") + plt.xticks(fontweight="bold") + plt.yticks(fontweight="bold") + + # save the figure + plt.savefig( + f"{output_dir}/task_presence_features_hrf_{feature}.{save_fig_format}", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + plt.close() + + +def create_html_report_subj_results( + subj, + SESSIONS, + TASKS, + RUNS, + reports_root, +): + """ + This function creates an html report for the subject results + using the generated figures. + """ + + # create html report + subj_dir = f"{reports_root}/subject_results/{subj}" + file = open(f"{subj_dir}/report.html", "w") + file.write("\n") + file.write("\n") + file.write(f"Subject {subj} Results\n") + file.write("\n") + file.write("\n") + file.write(f"

Subject {subj} Results

\n") + for session in SESSIONS: + if session is not None: + file.write(f"

{session}

\n") + for task in TASKS: + file.write(f"

{task}

\n") + for run in RUNS[task]: + if run is not None: + file.write(f"

{run}

\n") + if session is not None: + session_task_run_dir = f"{session}/{task}" + else: + session_task_run_dir = f"{task}" + if run is not None: + session_task_run_dir = f"{session_task_run_dir}/{run}" + + img_height = 100 + + # display GLM + glm_img = f"{subj_dir}/GLM/{session_task_run_dir}/glm.png" + if os.path.exists(glm_img): + img = plt.imread(glm_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + glm_img = glm_img.replace(subj_dir, ".") + file.write( + f"GLM\n" + ) + file.write("
\n") + + # display ROI signals + ROI_signals_img = ( + f"{subj_dir}/ROI_signals/{session_task_run_dir}/ROI_signals.png" + ) + if os.path.exists(ROI_signals_img): + img = plt.imread(ROI_signals_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + ROI_signals_img = ROI_signals_img.replace(subj_dir, ".") + file.write( + f"ROI signals\n" + ) + file.write("
\n") + + # display event labels + event_labels_img = ( + f"{subj_dir}/event_labels/{session_task_run_dir}/event_labels.png" + ) + if os.path.exists(event_labels_img): + img = plt.imread(event_labels_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + event_labels_img = event_labels_img.replace(subj_dir, ".") + file.write( + f"Event labels\n" + ) + file.write("
\n") + + # display task presence + task_presence_img = ( + f"{subj_dir}/task_presence/{session_task_run_dir}/task_presence.png" + ) + if os.path.exists(task_presence_img): + img = plt.imread(task_presence_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + task_presence_img = task_presence_img.replace(subj_dir, ".") + file.write( + f"Task presence\n" + ) + file.write("
\n") + + # display dFC matrices + img_height = 45 + # for dFC matrices find all png files in the directory + dFC_matrices_dir = f"{subj_dir}/dFC_matrices/{session_task_run_dir}" + if os.path.exists(dFC_matrices_dir): + for file_name in os.listdir(dFC_matrices_dir): + if file_name.endswith(".png"): + file.write(f"

{file_name[:file_name.find('_dFC')]}

\n") + dFC_matrices_img = f"{dFC_matrices_dir}/{file_name}" + # get the original size of the image + img = plt.imread(dFC_matrices_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + dFC_matrices_img = dFC_matrices_img.replace(subj_dir, ".") + file.write( + f"{file_name}\n" + ) + file.write("
\n") + + file.write("\n") + file.write("\n") + file.close() + + +def create_html_report_group_results( + SESSIONS, + TASKS, + RUNS, + reports_root, +): + """ + This function creates an html report for the group results + using the generated figures. + """ + # create html report + group_dir = f"{reports_root}/group_results" + file = open(f"{group_dir}/report.html", "w") + file.write("\n") + file.write("\n") + file.write("Group Results\n") + file.write("\n") + file.write("\n") + file.write("

Group Results

\n") + + # task presence features + img_height = 300 + file.write("

Task Presence Features

\n") + for session in SESSIONS: + if session is not None: + file.write(f"

{session}

\n") + # display task presence features + if session is not None: + task_presence_features_dir = f"{group_dir}/task_presence_features/{session}" + else: + task_presence_features_dir = f"{group_dir}/task_presence_features" + + for condition in ["with_HRF", "without_HRF"]: + file.write(f"

{condition}

\n") + # find all png files in the directory + for file_name in os.listdir(task_presence_features_dir): + if file_name.endswith(".png"): + if (condition == "with_HRF" and "hrf" not in file_name) or ( + condition == "without_HRF" and "hrf" in file_name + ): + continue + task_presence_features_img = ( + f"{task_presence_features_dir}/{file_name}" + ) + # get the original size of the image + img = plt.imread(task_presence_features_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + task_presence_features_img = task_presence_features_img.replace( + group_dir, "." + ) + file.write( + f"Task presence features\n" + ) + + file.write("
\n") + + file.write("
\n") + + # classification results + metrics = [ + # "accuracy", + "balanced accuracy", + "precision", + "recall", + # "f1", + # "tp", + # "tn", + # "fp", + # "fn", + # "average precision", + ] + classification_models = {"LogReg": "Logistic Regression", "SVM": "SVM"} + img_height = 300 + file.write("

Classification Results

\n") + for session in SESSIONS: + if session is not None: + file.write(f"

{session}

\n") + for task in TASKS: + file.write(f"

{task}

\n") + for run in RUNS[task]: + if run is not None: + file.write(f"

{run}

\n") + if session is not None: + classification_dir = f"{group_dir}/classification/{session}" + else: + classification_dir = f"{group_dir}/classification" + + for model in classification_models: + file.write(f"

{classification_models[model]}

\n") + for embedding in ["PCA", "LE"]: + file.write(f"

{embedding}

\n") + for metric in metrics: + metric_no_space = metric.replace(" ", "_") + if run is None: + classification_img = f"{classification_dir}/classification_{metric_no_space}_{model}_{task}_{embedding}.png" + else: + classification_img = f"{classification_dir}/classification_{metric_no_space}_{model}_{task}_{run}_{embedding}.png" + if os.path.exists(classification_img): + img = plt.imread(classification_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + classification_img = classification_img.replace( + group_dir, "." + ) + file.write( + f"Classification results\n" + ) + + file.write("
\n") + + # clustering results + img_height = 300 + file.write("

Clustering Results

\n") + for session in SESSIONS: + if session is not None: + file.write(f"

{session}

\n") + for task in TASKS: + file.write(f"

{task}

\n") + for run in RUNS[task]: + if run is not None: + file.write(f"

{run}

\n") + if session is not None: + clustering_dir = f"{group_dir}/clustering/{session}" + else: + clustering_dir = f"{group_dir}/clustering" + + for embedding in ["PCA", "LE"]: + file.write(f"

{embedding}

\n") + # display clustering ARI results + if run is None: + clustering_img = ( + f"{clustering_dir}/clustering_SI_{task}_{embedding}.png" + ) + else: + clustering_img = ( + f"{clustering_dir}/clustering_SI_{task}_{run}_{embedding}.png" + ) + if os.path.exists(clustering_img): + img = plt.imread(clustering_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + clustering_img = clustering_img.replace(group_dir, ".") + file.write( + f"Clustering results\n" + ) + + file.write("
\n") + + # display visual clustering centroids + img_height = 300 + file.write("

Visual Clustering Centroids

\n") + # find all png files in the directory + visual_clustering_centroids_dir = f"{group_dir}/visual_clustering_centroids" + for session in SESSIONS: + if session is not None: + file.write(f"

{session}

\n") + for task in TASKS: + file.write(f"

{task}

\n") + for run in RUNS[task]: + if run is not None: + file.write(f"

{run}

\n") + + # visual-centroids_{session}_{task}_{run}_{measure_name}.png + all_centroids_img_files = os.listdir(visual_clustering_centroids_dir) + all_centroids_img_files = [ + centroids_img_file + for centroids_img_file in all_centroids_img_files + if "visual-centroids" in centroids_img_file + and f"_{task}" in centroids_img_file + ] + if session is not None: + all_centroids_img_files = [ + centroids_img_file + for centroids_img_file in all_centroids_img_files + if f"_{session}" in centroids_img_file + ] + if run is not None: + all_centroids_img_files = [ + centroids_img_file + for centroids_img_file in all_centroids_img_files + if f"_{run}" in centroids_img_file + ] + all_centroids_img_files.sort() + + for centroids_img_file in all_centroids_img_files: + # iterate over centroids images of different measures + centroid_img = ( + f"{visual_clustering_centroids_dir}/{centroids_img_file}" + ) + measure_name = centroids_img_file.split("_")[-1].split(".")[0] + file.write(f"

{measure_name}

\n") + # get the original size of the image + if os.path.exists(centroid_img): + img = plt.imread(centroid_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + centroid_img = centroid_img.replace(group_dir, ".") + file.write( + f"Visual clustering centroids\n" + ) + + # visual-centroids_{suffix}.png + suffix = centroids_img_file[ + centroids_img_file.find("visual-centroids_") + 17 : -4 + ] + + # display co-occurrence matrix + co_occurrence_matrix_img = f"{visual_clustering_centroids_dir}/co-occurrence-matrix_{suffix}.png" + if os.path.exists(co_occurrence_matrix_img): + img = plt.imread(co_occurrence_matrix_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + co_occurrence_matrix_img = co_occurrence_matrix_img.replace( + group_dir, "." + ) + file.write( + f"Co-occurrence matrix\n" + ) + + # display cluster label percentage + cluster_label_percentage_img = f"{visual_clustering_centroids_dir}/cluster-label-percentage_{suffix}.png" + if os.path.exists(cluster_label_percentage_img): + img = plt.imread(cluster_label_percentage_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + cluster_label_percentage_img = ( + cluster_label_percentage_img.replace(group_dir, ".") + ) + file.write( + f"Cluster label percentage\n" + ) + + # display task label percentage + task_label_percentage_img = f"{visual_clustering_centroids_dir}/task-label-percentage_{suffix}.png" + if os.path.exists(task_label_percentage_img): + img = plt.imread(task_label_percentage_img) + height, width, _ = img.shape + # change the width so that height equals img_height + width = int(width * img_height / height) + # replace the path to the image with a relative path + task_label_percentage_img = task_label_percentage_img.replace( + group_dir, "." + ) + file.write( + f"Task label percentage\n" + ) + + file.write("
\n") + + file.write("\n") + file.write("\n") + file.close() + + +####################################################################################### +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to generate a report of subject results. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument("--dataset_info", type=str, help="path to dataset info file") + parser.add_argument("--subj_list", type=str, help="path to subject list file") + + args = parser.parse_args() + + dataset_info_file = args.dataset_info + subj_list_file = args.subj_list + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + # Read subject list file, a txt file with one subject id per line + with open(subj_list_file, "r") as f: + SUBJECTS = f.read().splitlines() + + TASKS = dataset_info["TASKS"] + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + main_root = dataset_info["main_root"] + + if "{main_root}" in dataset_info["fmriprep_root"]: + fmriprep_root = dataset_info["fmriprep_root"].replace("{main_root}", main_root) + elif "{dataset}" in dataset_info["fmriprep_root"]: + fmriprep_root = dataset_info["fmriprep_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + fmriprep_root = dataset_info["fmriprep_root"] + + if "{main_root}" in dataset_info["roi_root"]: + roi_root = dataset_info["roi_root"].replace("{main_root}", main_root) + else: + roi_root = dataset_info["roi_root"] + + if "{main_root}" in dataset_info["dFC_root"]: + dFC_root = dataset_info["dFC_root"].replace("{main_root}", main_root) + else: + dFC_root = dataset_info["dFC_root"] + + if "{main_root}" in dataset_info["ML_root"]: + ML_root = dataset_info["ML_root"].replace("{main_root}", main_root) + else: + ML_root = dataset_info["ML_root"] + + if "{main_root}" in dataset_info["reports_root"]: + reports_root = dataset_info["reports_root"].replace("{main_root}", main_root) + else: + reports_root = dataset_info["reports_root"] + + print("Generating report...") + + # Generate report only 3 subjects + SUBJECTS = SUBJECTS[:3] + + start_time = 0 + end_time = 200 + + for subj in SUBJECTS: + for session in SESSIONS: + for task in TASKS: + for run in RUNS[task]: + + try: + plot_dFC_matrices( + dFC_root=dFC_root, + subj=subj, + task=task, + start_time=start_time, + end_time=end_time, + output_root=reports_root, + run=run, + session=session, + ) + except Exception as e: + print(f"Error in plotting dFC matrices: {e}") + + # try: + # plot_glm( + # fmriprep_root=fmriprep_root, + # roi_root=roi_root, + # subj=subj, + # task=task, + # bold_suffix=dataset_info["bold_suffix"], + # trial_type_label=dataset_info["trial_type_label"], + # rest_labels=dataset_info["rest_labels"], + # output_root=reports_root, + # run=run, + # session=session, + # ) + # except Exception as e: + # print(f"Error in plotting GLM: {e}") + + try: + plot_roi_signals( + roi_root=roi_root, + subj=subj, + task=task, + start_time=start_time, + end_time=end_time, + nodes_list=range(0, 10), + output_root=reports_root, + run=run, + session=session, + ) + except Exception as e: + print(f"Error in plotting ROI signals: {e}") + + try: + plot_event_labels( + roi_root=roi_root, + subj=subj, + task=task, + start_time=start_time, + end_time=end_time, + output_root=reports_root, + run=run, + session=session, + ) + except Exception as e: + print(f"Error in plotting event labels: {e}") + + try: + plot_task_presence( + roi_root=roi_root, + subj=subj, + task=task, + start_time=start_time, + end_time=end_time, + output_root=reports_root, + run=run, + session=session, + ) + except Exception as e: + print(f"Error in plotting task presence: {e}") + + # try: + # plot_dFC_clustering( + # dFC_root=dFC_root, + # subj=subj, + # task=task, + # start_time=start_time, + # end_time=end_time, + # output_root=reports_root, + # run=run, + # session=session, + # normalize_dFC=True, + # ) + # except Exception as e: + # print(f"Error in plotting dFC clustering: {e}") + # create html report + try: + create_html_report_subj_results( + subj=subj, + SESSIONS=SESSIONS, + TASKS=TASKS, + RUNS=RUNS, + reports_root=reports_root, + ) + except Exception as e: + print(f"Error in creating html report for subject results: {e}") + + # plot group results + # find the common run number for all tasks for task presence features + common_run = None + for task in TASKS: + if common_run is None: + common_run = RUNS[task][0] + else: + if RUNS[task][0] != common_run: + common_run = None + # raise warning + print( + "Warning: Tasks have different run numbers for task presence features!" + ) + break + + for session in SESSIONS: + try: + plot_task_presence_features( + ML_root=ML_root, + output_root=reports_root, + session=session, + run=common_run, + ) + except Exception as e: + print(f"Error in plotting task presence features: {e}") + + try: + plot_visual_clstr_centroids( + ML_root=ML_root, + output_root=reports_root, + session=session, + ) + except Exception as e: + print(f"Error in plotting visual clustering centroids: {e}") + + for task in TASKS: + for run in RUNS[task]: + for embedding in ["PCA", "LE"]: + try: + plot_ML_results( + ML_root=ML_root, + output_root=reports_root, + task=task, + run=run, + session=session, + ML_algorithms=["SVM", "Logistic regression"], + embedding=embedding, + ) + except Exception as e: + print(f"Error in plotting ML results for {embedding}: {e}") + + # create html report + try: + create_html_report_group_results( + SESSIONS=SESSIONS, + TASKS=TASKS, + RUNS=RUNS, + reports_root=reports_root, + ) + except Exception as e: + print(f"Error in creating html report for group results: {e}") + + print("Report generated successfully!") + +####################################################################################### diff --git a/task_dFC/multi_dataset_analysis/cohensd.py b/task_dFC/multi_dataset_analysis/cohensd.py new file mode 100644 index 0000000..3b892d1 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/cohensd.py @@ -0,0 +1,485 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.colors import to_rgba +from nilearn import datasets, plotting + +from pydfc import data_loader +from pydfc.ml_utils import find_available_subjects, load_task_data +from pydfc.task_utils import cohen_d_bold, extract_task_presence + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + build_experiment_display_info, +) + +####################################################################################### + + +_BOX_COLOR = "#4472C4" +_POINT_COLOR = "#C0392B" +_POINT_EDGE_COLOR = "#7B241C" +_BOX_OFFSET = -0.17 # box center relative to x-tick +_STRIP_OFFSET = 0.17 # point cloud center relative to x-tick +_BOX_WIDTH = 0.28 +_STRIP_JITTER = 0.09 + + +def plot_cohensd_per_experiment( + df, + experiment_order, + save_path, + y_col="abs_d", + y_label="|Cohen's d|", +): + """ + Boxplot (left of tick) + individual points (right of tick) per experiment. + + Boxes and points are spatially separated so neither buries the other. + Simulated data uses symlog y-scale to handle extreme outliers. + """ + fig_width = max(10, 0.7 * len(experiment_order)) + fig, ax = plt.subplots(figsize=(fig_width, 7)) + + n = len(experiment_order) + positions = np.arange(n) + exp_to_idx = {exp: i for i, exp in enumerate(experiment_order)} + + # --- Boxplot left of center --- + box_data = [ + df[df["experiment"] == exp][y_col].dropna().values for exp in experiment_order + ] + bp = ax.boxplot( + box_data, + positions=positions + _BOX_OFFSET, + widths=_BOX_WIDTH, + showfliers=False, + patch_artist=True, + medianprops=dict(color="#1A1A1A", linewidth=2.5), + boxprops=dict(linewidth=1.8), + whiskerprops=dict(linewidth=1.6), + capprops=dict(linewidth=1.6), + ) + for patch in bp["boxes"]: + patch.set_facecolor(to_rgba(_BOX_COLOR, 0.5)) + patch.set_edgecolor(_BOX_COLOR) + for line in bp["whiskers"] + bp["caps"]: + line.set_color(_BOX_COLOR) + + # --- Strip points right of center --- + rng = np.random.default_rng(42) + for exp in experiment_order: + vals = df[df["experiment"] == exp][y_col].dropna().values + if len(vals) == 0: + continue + x_jit = (exp_to_idx[exp] + _STRIP_OFFSET) + rng.uniform( + -_STRIP_JITTER, _STRIP_JITTER, len(vals) + ) + ax.scatter( + x_jit, + vals, + color=_POINT_COLOR, + alpha=0.55, + s=30, + linewidths=0.5, + edgecolors=_POINT_EDGE_COLOR, + zorder=3, + ) + + ax.set_xticks(positions) + ax.set_xticklabels(experiment_order) + ax.set_xlim(-0.6, n - 0.4) + + ax.set_ylim(bottom=0) + + ax.set_xlabel("Experiment", fontsize=13, fontweight="bold") + ax.set_ylabel(y_label, fontsize=13, fontweight="bold") + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=11) + plt.setp(ax.get_yticklabels(), fontsize=11) + sns.despine(ax=ax) + plt.tight_layout() + + os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True) + plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0.2, format="png") + plt.close() + + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to compute and visualize Cohen's d effect sizes for task vs. rest BOLD signals across multiple datasets. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + output_root = f"{multi_dataset_info['output_root']}/CohensD/{simul_or_real}" + + if not os.path.exists(output_root): + os.makedirs(output_root) + + # the dictionary to build the dataframe for visualization of Cohen's d across tasks + CohensD_across_task = { + "task": [], + "d_values": [], + "dataset": [], + "ROI": [], + } + # the dictionary to be used for the correlation with ML performance + CohensD_ML = { + "task": [], + "run": [], + "dataset": [], + "CohensD_max": [], + "CohensD_mean": [], + } + for dataset in DATASETS: + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for task in TASKS: + if task not in TASKS_to_include: + print(f"Skipping task {task} as it's not in the inclusion list.") + continue + d_values_all = [] + session = SESSIONS[ + 0 + ] # for now, only use the first session if multiple are present + print(f"Processing task: {task}") + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + dFC_id=None, + session=session, + ) + excluded_subjects = [] + for run in RUNS[task]: + d_values_run = [] + for subj in SUBJECTS: + try: + task_data = load_task_data( + roi_root=roi_root, + subj=subj, + task=task, + run=run, + session=session, + ) + except: + excluded_subjects.append(subj) + continue + + if run is None: + if session is None: + BOLD_file_name = "{subj_id}_{task}_time-series.npy" + else: + BOLD_file_name = "{subj_id}_{session}_{task}_time-series.npy" + else: + if session is None: + BOLD_file_name = "{subj_id}_{task}_{run}_time-series.npy" + else: + BOLD_file_name = ( + "{subj_id}_{session}_{task}_{run}_time-series.npy" + ) + try: + BOLD = data_loader.load_TS( + data_root=roi_root, + file_name=BOLD_file_name, + subj_id2load=subj, + task=task, + session=session, + run=run, + ) + except Exception as e: + print(f"Error loading BOLD data: {e}") + excluded_subjects.append(subj) + continue + BOLD_data = BOLD.data # np.ndarray (n_ROIs, n_TRs) + + Fs_task = task_data["Fs_task"] + TR_task = 1 / Fs_task + + TR_array = np.arange(0, BOLD_data.shape[1]) + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=TR_task, + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + no_hrf=False, + TR_array=TR_array, + ) + + # if n_TRs do not match, align them + if BOLD_data.shape[1] != task_presence.shape[0]: + print( + f"Before alignment, shape of task_presence: {task_presence.shape}, shape of BOLD_data: {BOLD_data.shape}" + ) + min_TRs = min(BOLD_data.shape[1], task_presence.shape[0]) + task_presence = task_presence[:min_TRs] + BOLD_data = BOLD_data[:, :min_TRs] + print( + f"After alignment, shape of task_presence: {task_presence.shape}, shape of BOLD_data: {BOLD_data.shape}" + ) + # also adjust indices + indices = [i for i in indices if i < min_TRs] + task_presence = task_presence[indices] # (n_TRs,) + BOLD_data = BOLD_data[:, indices] # (n_ROIs, n_TRs) + + assert BOLD_data.shape[1] == task_presence.shape[0] + + cohen_d = cohen_d_bold(X=BOLD_data.T, y=task_presence) # (n_ROIs,) + d_values_run.append(cohen_d) + d_values_all.append(cohen_d) + + d_values_run = np.array(d_values_run) # (n_subjects, n_ROIs) + assert ( + d_values_run.shape[1] == BOLD_data.shape[0] + ), f"Expected number of ROIs in d_values_run ({d_values_run.shape[1]}) to match BOLD_data ({BOLD_data.shape[0]})" + assert d_values_run.shape[0] == len(SUBJECTS) - len( + set(excluded_subjects) + ), f"Expected number of subjects in d_values_run ({d_values_run.shape[0]}) to match n_subjects ({len(SUBJECTS) - len(set(excluded_subjects))})" + + CohensD_ML["task"].append(task) + CohensD_ML["run"].append(run) + CohensD_ML["dataset"].append(dataset) + # MAX |d| across ROIs for this run after averaging across subjects + CohensD_ML["CohensD_max"].append( + np.nanmax(np.abs(np.nanmean(d_values_run, axis=0))) + ) + # MEAN |d| across ROIs for this run after averaging across subjects + CohensD_ML["CohensD_mean"].append( + np.nanmean(np.abs(np.nanmean(d_values_run, axis=0))) + ) + + if len(d_values_all) == 0: + print(f"No data found for task {task} in dataset {dataset}. Skipping.") + continue + d_values_all = np.array(d_values_all) # (runs x n_subjects, n_ROIs) + + avg_d_values = np.nanmean(d_values_all, axis=0) # (n_ROIs,) + CohensD_across_task["d_values"].extend(avg_d_values) + CohensD_across_task["task"].extend([task] * len(avg_d_values)) + CohensD_across_task["dataset"].extend([dataset] * len(avg_d_values)) + CohensD_across_task["ROI"].extend(BOLD.node_labels) + + # plot d values on a glass brain + if simul_or_real == "real": + coords = BOLD.locs + + template_img = datasets.load_mni152_template() + data = np.zeros(template_img.shape) + affine = template_img.affine + + # Create a small sphere for each coordinate + radius = 5 # in voxels + for c, d in zip(coords, avg_d_values): + ijk = np.round( + nib.affines.apply_affine(np.linalg.inv(affine), c) + ).astype(int) + x, y, z = ijk + for i in range(-radius, radius + 1): + for j in range(-radius, radius + 1): + for k in range(-radius, radius + 1): + if i**2 + j**2 + k**2 <= radius**2: + xi, yj, zk = x + i, y + j, z + k + if ( + (0 <= xi < data.shape[0]) + and (0 <= yj < data.shape[1]) + and (0 <= zk < data.shape[2]) + ): + data[xi, yj, zk] = d + + d_img = nib.Nifti1Image(data, affine) + + plotting.plot_glass_brain( + d_img, + display_mode="ortho", + colorbar=True, + plot_abs=False, + cmap="coolwarm", + vmax=np.max(avg_d_values), + ) + + plt.savefig( + f"{output_root}/cohensd_region_{task}.png", + dpi=120, + bbox_inches="tight", + pad_inches=0.1, + format="png", + ) + + plt.close() + + # Load Schaefer atlas (100 parcels) + schaefer = datasets.fetch_atlas_schaefer_2018(n_rois=100) + + # atlas_img is the path to the NIfTI file; load it + atlas_img = nib.load(schaefer["maps"]) + labels = schaefer["labels"] # list of labels + labels = [label.decode() for label in labels] + # check that the labels match BOLD.node_labels + assert all( + i == j for i, j in zip(labels, BOLD.node_labels) + ), "Labels do not match!" + + atlas_data = atlas_img.get_fdata() + cohen_img_data = np.zeros(atlas_data.shape) + + for i, d in enumerate(avg_d_values): + cohen_img_data[atlas_data == (i + 1)] = d # labels start from 1 + + cohen_img = nib.Nifti1Image(cohen_img_data, affine=atlas_img.affine) + + plotting.plot_glass_brain( + cohen_img, + display_mode="ortho", + colorbar=True, + cmap="coolwarm", + plot_abs=False, + vmax=np.max(avg_d_values), + ) + + plt.savefig( + f"{output_root}/cohensd_voxel_{task}.png", + dpi=120, + bbox_inches="tight", + pad_inches=0.1, + format="png", + ) + + plt.close() + + # Save the Cohen's d values for comparison with ML performance + np.save(f"{output_root}/CohensD_ML_{simul_or_real}.npy", CohensD_ML) + + # --- Across-task visualizations (ABSOLUTE Cohen's d) --- + sns.set_context("paper", font_scale=1.0, rc={"lines.linewidth": 1.2}) + sns.set_style("darkgrid") + + # Build dataframe if not already done + DF = pd.DataFrame.from_dict(CohensD_across_task) + + task_order_reference, task_to_experiment, _, _ = build_experiment_display_info( + tasks_iterable=DF["task"].unique().tolist(), + task_reference_order=TASKS_to_include, + simul_or_real=simul_or_real, + ) + DF["experiment"] = DF["task"].map(task_to_experiment) + + # Use absolute Cohen's d + DF["abs_d"] = DF["d_values"].abs() + + # Choose an order (sort tasks by their MAX |d| to align with Fig. 2) + max_abs_per_task = ( + DF.groupby("task")["abs_d"] + .max() + .sort_values(ascending=False) + .reset_index(name="abs_max") + ) + task_order = max_abs_per_task["task"].tolist() + experiment_order = [task_to_experiment[task] for task in task_order] + max_abs_per_task["experiment"] = max_abs_per_task["task"].map(task_to_experiment) + + # Dynamic width so labels don't collide (0.6 inch per task, min 14 inches) + fig_width = max(14, 0.6 * len(task_order)) + + # -------- Figure 1: Boxplot of |Cohen's d| per task with individual samples -------- + plot_cohensd_per_experiment( + df=DF, + experiment_order=experiment_order, + save_path=f"{output_root}/CohensD_abs_boxplot_with_samples_per_task.png", + y_col="abs_d", + y_label="|Cohen's d|", + ) + + # -------- Figure 2: Max |Cohen's d| across ROIs per task -------- + plt.figure(figsize=(fig_width, 6)) + + ax = sns.barplot( + data=max_abs_per_task, + x="experiment", + y="abs_max", + order=experiment_order, + ) + + # Optional: annotate bars with values (trim to 2 decimals) + for p in ax.patches: + height = p.get_height() + ax.annotate( + f"{height:.2f}", + (p.get_x() + p.get_width() / 2.0, height), + ha="center", + va="bottom", + xytext=(0, 2), + textcoords="offset points", + fontsize=8, + ) + + ax.set_xlabel("Experiment") + ax.set_ylabel("Max |Cohen's d|") + ax.set_ylim(bottom=0) + ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right") + plt.tight_layout() + + plt.savefig( + f"{output_root}/CohensD_abs_max_per_task.png", + dpi=150, + bbox_inches="tight", + pad_inches=0.2, + format="png", + ) + plt.close() diff --git a/task_dFC/multi_dataset_analysis/dfc_visualization.py b/task_dFC/multi_dataset_analysis/dfc_visualization.py new file mode 100644 index 0000000..c028069 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/dfc_visualization.py @@ -0,0 +1,173 @@ +import argparse +import json +import os +import re +import sys + +from pydfc.dfc_utils import TR_intersection, rank_norm +from pydfc.ml_utils import find_available_subjects, load_dFC + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + build_experiment_display_info, + figure_dfc_matrices_window_png, +) + +normalize_dFC = True + + +def discover_available_dfc_ids(dfc_root): + """Return the sorted dFC IDs found anywhere under ``dfc_root``.""" + dfc_ids = set() + for root, _, files in os.walk(dfc_root): + for file_name in files: + if not file_name.endswith(".npy"): + continue + match = re.search(r"_(\d+)\.npy$", file_name) + if match: + dfc_ids.add(int(match.group(1))) + return sorted(dfc_ids) + + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to make figures/tables from multi-dataset ML results. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + print("Multi-Dataset Analysis started ...") + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + output_root = f"{multi_dataset_info['output_root']}/dFC/{simul_or_real}" + + if not os.path.exists(output_root): + os.makedirs(output_root) + + _, task_to_experiment, _, _ = build_experiment_display_info( + tasks_iterable=TASKS_to_include, + task_reference_order=TASKS_to_include, + simul_or_real=simul_or_real, + ) + + for dataset in DATASETS: + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + DATA = {} + dFC_ids = discover_available_dfc_ids(dFC_root) + if len(dFC_ids) == 0: + print(f"No dFC files found under {dFC_root}; skipping dataset {dataset}.") + continue + + for dFC_id in dFC_ids: + for session in SESSIONS[:1]: # Only process the first session + for task_id, task in enumerate(TASKS): + for run in RUNS[task][:1]: # Only process the first run + print( + f"Processing dataset: {dataset}, task: {task}, run: {run}, session: {session}, dFC_id: {dFC_id}" + ) + + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + dFC_id=dFC_id, + session=session, + run=run, + ) + if len(SUBJECTS) == 0: + print( + f"No subjects found for dataset: {dataset}, task: {task}, run: {run}, session: {session}, dFC_id: {dFC_id}" + ) + continue + + subj = SUBJECTS[0] # Only process the first subject + + dFC = load_dFC( + dFC_root=dFC_root, + subj=subj, + task=task, + dFC_id=dFC_id, + run=run, + session=session, + ) + + if not task in DATA: + DATA[task] = {} + DATA[task][dFC.measure.measure_name] = dFC + + # visualize the dFC matrices for each task + for task in DATA.keys(): + # first find common TRs across measures + common_TRs = TR_intersection( + [DATA[task][measure_name] for measure_name in DATA[task]] + ) + + dFC_mat_dict = {} + for measure_name in DATA[task]: + dFC = DATA[task][measure_name] + dFC_mat = dFC.get_dFC_mat(TRs=common_TRs) + if normalize_dFC: + dFC_mat = rank_norm(dFC_mat) + dFC_mat_dict[measure_name] = dFC_mat + figure_dfc_matrices_window_png( + dFC_mat_dict, + common_TRs, + window_len=10, + cmap="plasma", + outfile=( + f"{output_root}/dFC_{dataset}_" + f"{task_to_experiment.get(task, task).replace(' ', '_').replace('/', '-')}_" + f"{task}_mid_10.png" + ), + dpi=600, + ) + + print(f"Saved data for dataset {dataset}") diff --git a/task_dFC/multi_dataset_analysis/embedding_visualization.py b/task_dFC/multi_dataset_analysis/embedding_visualization.py new file mode 100644 index 0000000..e3975d6 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/embedding_visualization.py @@ -0,0 +1,231 @@ +import argparse +import json +import os + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +from sklearn.decomposition import PCA +from sklearn.metrics import silhouette_score + +from pydfc.ml_utils import ( + LE_transform, + PLSEmbedder, + dFC_feature_extraction, + find_available_subjects, + process_SB_features, +) + +fig_dpi = 120 +fig_bbox_inches = "tight" +fig_pad = 0.1 +show_title = True +save_fig_format = "png" # pdf, png, + +normalize_dFC = False + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to analyze and visualize LE-transformed features across multiple datasets. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + + output_root = f"{multi_dataset_info['output_root']}/LE_embed/{simul_or_real}" + + if not os.path.exists(output_root): + os.makedirs(output_root) + + for dataset in DATASETS: + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for session in SESSIONS: + for task_id, task in enumerate(TASKS): + for run in RUNS[task][:1]: + for dFC_id in range(7): + try: + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + dFC_id=dFC_id, + session=session, + run=run, + ) + if len(SUBJECTS) == 0: + print( + f"No subjects found for task {task}, dFC_id {dFC_id}, session {session}, run {run}." + ) + continue + SUBJECTS = SUBJECTS[0:1] + print(f"Number of subjects: {len(SUBJECTS)}") + + ( + X, + _, + y, + _, + subj_label, + _, + measure_name, + measure_is_state_based, + ) = dFC_feature_extraction( + task=task, + train_subjects=SUBJECTS, + test_subjects=[], + dFC_id=dFC_id, + roi_root=roi_root, + dFC_root=dFC_root, + run=run, + session=session, + dynamic_pred="no", + normalize_dFC=normalize_dFC, + FCS_proba_for_SB=True, + ) + + assert ( + X.shape[0] == y.shape[0] + ), "Number of samples do not match." + assert ( + X.shape[0] == subj_label.shape[0] + ), "Number of samples do not match." + + if measure_is_state_based: + X = process_SB_features(X=X, measure_name=measure_name) + + print(f"Task: {task}") + print(measure_name) + print(X.shape, y.shape) + print(silhouette_score(X, y)) + + # embed the features + # n_components = "auto" + n_components = 3 + for embedding_method in ["PCA", "PLS", "LE"]: + if embedding_method == "PCA": + X_embedded = PCA( + n_components=n_components, + whiten=False, + svd_solver="full", + random_state=0, + ).fit_transform(X) + elif embedding_method == "PLS": + X_embedded = ( + PLSEmbedder( + n_components=n_components, scale=False + ) + .fit(X, y) + .transform(X) + ) + elif embedding_method == "LE": + X_embedded = LE_transform( + X, + n_components=n_components, + n_neighbors=125, + distance_metric="correlation", + ) + + # X_embedded = TSNE(n_components=n_components, learning_rate='auto', init='random', perplexity=125, metric="correlation").fit_transform(X) + print(silhouette_score(X_embedded, y)) + print(X_embedded.shape) + + # plot + # ---- publication style (light touch) ---- + mpl.rcParams.update( + { + "legend.fontsize": 10, + "axes.linewidth": 0.9, + "pdf.fonttype": 42, + "ps.fonttype": 42, # keep text as text in PDF/SVG + "savefig.bbox": "tight", + "savefig.dpi": 300, + "figure.dpi": 150, + } + ) + fig = plt.figure(figsize=(7, 7)) + ax = fig.add_subplot(111, projection="3d") + + colors = ("#B1B1B1", "#2F5BD3") + + for label in np.unique(y): + ax.scatter( + X_embedded[y == label, 0], + X_embedded[y == label, 1], + X_embedded[y == label, 2], + label=["rest", "task"][label], + s=50, + c=[colors[label]], + edgecolors="#202020", + linewidths=0.25, + depthshade=False, + ) + plt.legend() + + # remove tick labels + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + plt.savefig( + f"{output_root}/{embedding_method}_embed_{task}_{measure_name}.png", + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) + + plt.close() + except Exception as e: + print( + f"Error processing task {task}, dFC_id {dFC_id}, session {session}, run {run}: {e}" + ) + continue diff --git a/task_dFC/multi_dataset_analysis/helper_functions.py b/task_dFC/multi_dataset_analysis/helper_functions.py new file mode 100644 index 0000000..8f06217 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/helper_functions.py @@ -0,0 +1,1060 @@ +import re +from pathlib import Path + +import matplotlib as mpl +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.colors import ListedColormap +from scipy.cluster.hierarchy import leaves_list, linkage +from scipy.stats import ttest_ind +from sklearn.neighbors import NearestNeighbors + +# Curated palette of maximally distinct, publication-quality colors. +# Used for coloring high-performing experiments so each gets a clearly +# different hue even when only a few experiments are highlighted. +_VIBRANT_DISTINCT_COLORS = [ + "#E6194B", # vivid red + "#3CB44B", # vivid green + "#4363D8", # vivid blue + "#F58231", # vivid orange + "#911EB4", # vivid purple + "#42D4F4", # cyan + "#F032E6", # magenta + "#008080", # teal + "#9A6324", # brown + "#000075", # navy + "#808000", # olive + "#DC143C", # crimson +] + +###################### Publication style ###################### + + +def setup_pub_style(): + sns.set_theme(context="paper", style="whitegrid") + mpl.rcParams.update( + { + # Fonts & text + "font.size": 18, # base + "axes.titlesize": 12, + "axes.labelsize": 11, + "xtick.labelsize": 18, + "ytick.labelsize": 18, + "legend.fontsize": 9, + "figure.titlesize": 13, + "axes.titlepad": 8, + "axes.labelpad": 6, + # Lines/markers + "lines.linewidth": 1.5, + "lines.markersize": 5, + "axes.linewidth": 0.8, + "grid.linewidth": 0.6, + # Figure/layout + "figure.dpi": 150, # on-screen + "savefig.dpi": 1000, # export + "savefig.bbox": "tight", + "savefig.pad_inches": 0.04, + # Vector export: keep text as text in PDF/SVG + "pdf.fonttype": 42, + "ps.fonttype": 42, + } + ) + + +def savefig_pub(path_png_or_pdf: str): + Path(Path(path_png_or_pdf).parent).mkdir(parents=True, exist_ok=True) + plt.savefig(path_png_or_pdf) + # # Also export vector PDF alongside PNG unless you passed a .pdf + # p = Path(path_png_or_pdf) + # if p.suffix.lower() != ".pdf": + # plt.savefig(p.with_suffix(".pdf")) + + +###################### RDoC ###################### + +RDoC_MAP = { + "real": { + # --- Cognitive-Atlas–aligned domains (order on paper) --- + "DOMAIN_ORDER": [ + "Arousal & Regulatory Systems", + "Cognitive Systems", + "Negative Valence System", + "Positive Valence System", + "Sensorimotor Systems", + ], + # --- Map canonical task codes -> domain --- + "TASK2DOMAIN": { + # Language & Regulatory Systems + "emotionregulation": "Arousal & Regulatory Systems", + # Cognitive Systems + "audsem": "Cognitive Systems", + "visrhyme": "Cognitive Systems", + "vissem": "Cognitive Systems", + "visspell": "Cognitive Systems", + "arithmetic": "Cognitive Systems", + "stroop": "Cognitive Systems", + "cuedts": "Cognitive Systems", + "axcpt": "Cognitive Systems", + "matching": "Cognitive Systems", + "stern": "Cognitive Systems", + "st": "Cognitive Systems", + "vswm": "Cognitive Systems", + "expo": "Cognitive Systems", + "recall": "Cognitive Systems", + "feedback": "Cognitive Systems", + "ppalocalizer": "Cognitive Systems", + "localiser": "Cognitive Systems", + "localizer": "Cognitive Systems", + # Positive Valence System + "fribbids": "Positive Valence System", + "risk": "Positive Valence System", + "itc": "Positive Valence System", + # Negative Valence System + "fearlearning": "Negative Valence System", + "paingen": "Negative Valence System", + # Sensorimotor + "motor": "Sensorimotor Systems", + "execution": "Sensorimotor Systems", + "imagery": "Sensorimotor Systems", + "ihg": "Sensorimotor Systems", + }, + }, + "simulated": { + # --- Categories of simulated task paradigms --- + "DOMAIN_ORDER": [ + "Simulated Periodic", + "Strong Performance on Real Data", + "Weak Performance on Real Data", + ], + # --- Map task codes -> category --- + "TASK2DOMAIN": { + # Simulated Periodic + "lowfreqlongrest": "Simulated Periodic", + "lowfreqshortrest": "Simulated Periodic", + "lowfreqshorttask": "Simulated Periodic", + # Optimal Paradigm Design, Strong Performance on Real Data + "axcpt": "Strong Performance on Real Data", + "stern": "Strong Performance on Real Data", + "cuedts": "Strong Performance on Real Data", + "stroop": "Strong Performance on Real Data", + # Optimal Paradigm Design, Weak Performance on Real Data + "execution": "Weak Performance on Real Data", + "imagery": "Weak Performance on Real Data", + "localizer": "Weak Performance on Real Data", + "ppalocalizer": "Weak Performance on Real Data", + # Sub-Optimal Paradigm Design, Weak Performance on Real Data + "itc": "Weak Performance on Real Data", + "risk": "Weak Performance on Real Data", + }, + }, +} + +###################### ml_results ###################### + + +DEFAULT_EXPERIMENT_NAME_MAP = { + "real": { + "emotionregulation": "EXP.17", + "audsem": "EXP.3", + "visrhyme": "EXP.4", + "vissem": "EXP.5", + "visspell": "EXP.6", + "arithmetic": "EXP.23", + "stroop": "EXP.14", + "cuedts": "EXP.12", + "axcpt": "EXP.11", + "matching": "EXP.24", + "stern": "EXP.13", + "st": "EXP.28", + "vswm": "EXP.25", + "expo": "EXP.19", + "recall": "EXP.20", + "feedback": "EXP.21", + "ppalocalizer": "EXP.2", + "localiser": "EXP.26", + "localizer": "EXP.27", + "fribbids": "EXP.10", + "risk": "EXP.9", + "itc": "EXP.8", + "fearlearning": "EXP.1", + "paingen": "EXP.22", + "motor": "EXP.18", + "execution": "EXP.15", + "imagery": "EXP.16", + "ihg": "EXP.7", + }, + "simulated": { + "lowfreqlongrest": "EXP.S.29", + "lowfreqshortrest": "EXP.S.30", + "lowfreqshorttask": "EXP.S.31", + "axcpt": "EXP.S.11", + "cuedts": "EXP.S.12", + "stern": "EXP.S.13", + "stroop": "EXP.S.14", + "execution": "EXP.S.15", + "imagery": "EXP.S.16", + "localizer": "EXP.S.27", + "ppalocalizer": "EXP.S.2", + "itc": "EXP.S.8", + "risk": "EXP.S.9", + }, +} + + +def canon_task(task_str: str) -> str: + """strip 'task-' and non-letters, lowercase → canonical key""" + s = task_str.replace("task-", "") + s = re.sub(r"[^a-zA-Z]", "", s) + return s.lower() + + +def get_default_experiment_name_map(simul_or_real: str): + if simul_or_real not in DEFAULT_EXPERIMENT_NAME_MAP: + raise ValueError(f"Invalid simul_or_real: {simul_or_real}") + return DEFAULT_EXPERIMENT_NAME_MAP[simul_or_real].copy() + + +def get_present_task_order(tasks_iterable, task_reference_order): + present_tasks = list(dict.fromkeys(tasks_iterable)) + present_set = set(present_tasks) + ordered_tasks = [task for task in task_reference_order if task in present_set] + remaining_tasks = sorted( + [task for task in present_tasks if task not in ordered_tasks], + key=lambda task: task.lower(), + ) + return ordered_tasks + remaining_tasks + + +def _next_available_experiment_label(used_labels_lower): + index = 1 + while f"exp{index}" in used_labels_lower: + index += 1 + return f"exp{index}" + + +def build_experiment_display_info(tasks_iterable, task_reference_order, simul_or_real): + """ + Resolve task order, experiment labels, and a stable palette for ML result plots. + + Edit ``DEFAULT_EXPERIMENT_NAME_MAP`` above to change experiment labels. + Any task not listed there is auto-assigned the next available ``expN`` label. + """ + task_order = get_present_task_order(tasks_iterable, task_reference_order) + configured_map = get_default_experiment_name_map(simul_or_real) + + task_to_experiment = {} + used_labels = {} + used_labels_lower = set() + + for task in task_order: + experiment_label = configured_map.get(canon_task(task)) + if experiment_label is None: + experiment_label = _next_available_experiment_label(used_labels_lower) + + experiment_label_key = experiment_label.lower() + if experiment_label_key in used_labels: + raise ValueError( + "Experiment labels must be unique for the plotted tasks. " + f"Both '{used_labels[experiment_label_key]}' and '{task}' map to " + f"'{experiment_label}'." + ) + + task_to_experiment[task] = experiment_label + used_labels[experiment_label_key] = task + used_labels_lower.add(experiment_label_key) + + n = max(1, len(task_order)) + colors = [ + _VIBRANT_DISTINCT_COLORS[i % len(_VIBRANT_DISTINCT_COLORS)] for i in range(n) + ] + experiment_order = [task_to_experiment[task] for task in task_order] + experiment_palette = dict(zip(experiment_order, colors)) + + return task_order, task_to_experiment, experiment_order, experiment_palette + + +def relabel_heatmap_rows(matrix_df, annot_df, task_reference_order, task_to_experiment): + def _experiment_sort_key(exp_label): + match = re.match(r"(?i)^\s*exp\s*[._-]?\s*(\d+)\s*$", str(exp_label)) + if match: + return (0, int(match.group(1)), str(exp_label).lower()) + return (1, float("inf"), str(exp_label).lower()) + + row_order = get_present_task_order(matrix_df.index.tolist(), task_reference_order) + experiment_labels = [task_to_experiment[task] for task in row_order] + + relabeled_matrix = matrix_df.loc[row_order].copy() + relabeled_matrix.index = experiment_labels + + relabeled_annot = None + if annot_df is not None: + relabeled_annot = annot_df.loc[row_order].copy() + relabeled_annot.index = experiment_labels + + sorted_labels = sorted(relabeled_matrix.index.tolist(), key=_experiment_sort_key) + relabeled_matrix = relabeled_matrix.loc[sorted_labels] + if relabeled_annot is not None: + relabeled_annot = relabeled_annot.loc[sorted_labels] + + return relabeled_matrix, relabeled_annot, row_order + + +def boldify_axes(ax, xlabel=None, ylabel=None, rotate_xticks=35): + if xlabel is not None: + ax.set_xlabel(xlabel, fontweight="bold") + if ylabel is not None: + ax.set_ylabel(ylabel, fontweight="bold") + # dFC method names on x-axis + if rotate_xticks is not None: + plt.setp( + ax.get_xticklabels(), fontweight="bold", rotation=rotate_xticks, ha="right" + ) + else: + plt.setp(ax.get_xticklabels(), fontweight="bold") + + +def mean_ci_boot(y, n_boot=3000, ci=95, rng=None): + y = np.asarray(y, float) + y = y[~np.isnan(y)] + if y.size == 0: + return np.nan, np.nan, np.nan + m = float(np.mean(y)) + if y.size == 1: + return m, m, m + if rng is None: + rng = np.random.default_rng() # fresh entropy + idx = rng.integers(0, y.size, size=(n_boot, y.size)) + boots = np.mean(y[idx], axis=1) + lo = float(np.percentile(boots, (100 - ci) / 2)) + hi = float(np.percentile(boots, 100 - (100 - ci) / 2)) + return m, lo, hi + + +def summarize_methods_across_tasks( + df_plot, ycol, method_col="dFC method", ci_func=mean_ci_boot +): + """ + Return a DataFrame with columns: [method_col, 'mean','lo','hi']. + Robust to Pandas quirks; no MultiIndex/unnamed columns. + Assumes df_plot has one row per (task, method) already (your BEST table). + """ + rows = [] + for meth, s in df_plot.groupby(method_col, observed=True)[ycol]: + m, lo, hi = ci_func(s.values) + rows.append({method_col: meth, "mean": m, "lo": lo, "hi": hi}) + return pd.DataFrame(rows) + + +def overlay_method_mean_ci( + ax, + df_plot, + ycol, + method_col="dFC method", + line_halfwidth=0.30, + cap_halfwidth=0.12, + color="#222", + lower=None, + upper=None, + rng=None, +): + # map x positions from current ticks (call after you set/rotate xticklabels) + xticks = ax.get_xticks() + xlabs = [t.get_text() for t in ax.get_xticklabels()] + xpos = {lab: xticks[i] for i, lab in enumerate(xlabs)} + + # summarize robustly + summ = summarize_methods_across_tasks( + df_plot, ycol, method_col, ci_func=lambda y: mean_ci_boot(y, rng=rng) + ) + + # clip to metric bounds if provided + def clip(v): + if lower is not None: + v = max(lower, v) + if upper is not None: + v = min(upper, v) + return v + + for _, r in summ.iterrows(): + meth = r[method_col] + if meth not in xpos or np.isnan(r["mean"]): + continue + x = xpos[meth] + m = clip(r["mean"]) + lo = clip(r["lo"]) if not np.isnan(r["lo"]) else m + hi = clip(r["hi"]) if not np.isnan(r["hi"]) else m + + # mean line (thick) + CI whisker & caps (thin) + ax.hlines( + m, x - line_halfwidth, x + line_halfwidth, colors=color, lw=2.6, zorder=6 + ) + ax.vlines(x, lo, hi, colors=color, lw=1.2, alpha=0.9, zorder=5) + ax.hlines( + [lo, hi], + x - cap_halfwidth, + x + cap_halfwidth, + colors=color, + lw=1.2, + alpha=0.9, + zorder=5, + ) + + +###################### task_timing_stats ###################### + + +def as_long_df(d, value_col, task_col="task"): + rows = [] + for t, vals in d.items(): + for v in vals: + rows.append({task_col: t, value_col: v}) + return pd.DataFrame(rows) + + +# --- median labels with matching hue colors (log-safe) --- +def annotate_medians_by_geometry( + ax, + df_long, + x_col, + hue_col, + y_col, + x_order, + hue_order, + fmt="{:.0f}", # ints; change to "{:.2g}" if you prefer + y_nudge_factor=1.08, + bin_halfwidth=0.6, + bbox_alpha=0.9, +): + def _luminance(r, g, b): + # simple relative luminance for contrast + return 0.299 * r + 0.587 * g + 0.114 * b + + # collect box patches and centers + patches = [ + p for p in getattr(ax, "artists", []) if isinstance(p, mpl.patches.PathPatch) + ] + if not patches: + patches = [p for p in ax.patches if isinstance(p, mpl.patches.PathPatch)] + + boxes = [] + for p in patches: + verts = p.get_path().vertices + xs = verts[:, 0] + x_center = 0.5 * (xs.min() + xs.max()) + boxes.append((x_center, p)) + + if not boxes: + return + + # bin by x tick index (0..len(x_order)-1) + boxes_by_tick = {i: [] for i in range(len(x_order))} + for x_center, p in boxes: + idx = int(round(x_center)) + if idx in boxes_by_tick and abs(x_center - idx) <= bin_halfwidth: + boxes_by_tick[idx].append((x_center, p)) + + # medians from data + med_dict = df_long.groupby([x_col, hue_col])[y_col].median().to_dict() + + for i, task in enumerate(x_order): + group = boxes_by_tick.get(i, []) + if not group: + continue + # left->right inside this task bin + group.sort(key=lambda t: t[0]) + + for j, hue in enumerate(hue_order): + if j >= len(group): + break + x_center, patch = group[j] + med = med_dict.get((task, hue), np.nan) + if not (np.isfinite(med) and med > 0): + continue + + # extract the exact facecolor of this box (matches legend/palette) + fc = patch.get_facecolor() # RGBA + if fc is None or len(fc) < 3: + # fallback (rare): use current color cycle + fc = ax._get_lines.get_next_color() + # normalize to RGBA + fc = mpl.colors.to_rgba(fc) + + r, g, b, a = fc + # adjust alpha for the textbox so it’s legible + fc_box = (r, g, b, bbox_alpha) + + # choose black/white text for contrast + txt_color = "black" if _luminance(r, g, b) > 0.6 else "white" + + ax.text( + x_center, + med * y_nudge_factor, + fmt.format(med), + ha="center", + va="center", + fontsize=9, + fontweight="bold", + color=txt_color, + bbox=dict(boxstyle="round,pad=0.2", fc=fc_box, ec="none"), + zorder=100, + clip_on=False, + ) + + +# ---------- helpers: median ordering + median labeler (single-category boxplot) ---------- +def order_by_median_dict(d, reverse=True): + """Return (ordered_task_names, stats_dict) where stats_dict[task]=(median, std).""" + stats = {t: (np.median(vals), np.std(vals)) for t, vals in d.items() if len(vals) > 0} + ordered = sorted(stats.keys(), key=lambda t: stats[t][0], reverse=reverse) + return ordered, stats + + +def annotate_medians_single_boxplot( + ax, df_long, x_col, y_col, order, fmt="{:.2f}", box_alpha=0.90 +): + """ + Annotate the median for each category on a seaborn.boxplot *without hue*. + Places the number at the geometric center of each box, using the box facecolor for the label bg. + Call this AFTER setting any y-limits (so the nudge uses final limits). + """ + + # compute medians in plotting order + med = df_long.groupby(x_col)[y_col].median().reindex(order) + + # collect PathPatches for boxes (artists in most seaborn versions; fallback to patches) + patches = [ + p for p in getattr(ax, "artists", []) if isinstance(p, mpl.patches.PathPatch) + ] + if not patches: + patches = [p for p in ax.patches if isinstance(p, mpl.patches.PathPatch)] + + n = min(len(patches), len(order)) + ymin, ymax = ax.get_ylim() + dy = 0.02 * (ymax - ymin) # small additive nudge in data units + + for k in range(n): + patch = patches[k] + verts = patch.get_path().vertices + xs, _ = verts[:, 0], verts[:, 1] + x_center = 0.5 * (xs.min() + xs.max()) + + m = med.iloc[k] + if not np.isfinite(m): + continue + + # label background color = box facecolor (match legend/palette) + fc = patch.get_facecolor() + if fc is None or len(fc) < 3: + fc = mpl.colors.to_rgba("white", box_alpha) + else: + fc = (fc[0], fc[1], fc[2], box_alpha) + + # text color for contrast (simple luminance check) + lum = 0.299 * fc[0] + 0.587 * fc[1] + 0.114 * fc[2] + txt_color = "black" if lum > 0.6 else "white" + + # keep label inside the axis (avoid hitting the top bound) + y_text = min(m + dy, ymax - 0.01 * (ymax - ymin)) + + ax.text( + x_center, + y_text, + fmt.format(m), + ha="center", + va="center", + fontsize=9, + fontweight="bold", + color=txt_color, + bbox=dict(boxstyle="round,pad=0.2", fc=fc, ec="none"), + zorder=100, + clip_on=False, + ) + + +###################### task_presence_binarization ###################### + +###################### dfc_visualization ###################### + + +def _window_indices( + trs, window_len=8, center="middle", center_time=None, center_index=None, interval=None +): + T = len(trs) + trs = np.asarray(trs) + if interval is not None: + t0, t1 = interval + idxs = np.where((trs >= t0) & (trs <= t1))[0] + if len(idxs) == 0: + raise ValueError("interval produced no indices; check units.") + return idxs + if center_index is not None: + c = int(np.clip(center_index, 0, T - 1)) + elif center_time is not None: + c = int(np.argmin(np.abs(trs - center_time))) + else: + c = (T - 1) // 2 + half = window_len // 2 + start = max(0, c - half) + end = min(T, start + window_len) + start = max(0, end - window_len) + return np.arange(start, end, dtype=int) + + +def _common_limits(dfc_dict, robust_percentile=(2, 98), symmetric=True): + vals = [] + for A in dfc_dict.values(): + R = A.shape[1] + iu = np.triu_indices(R, 1) + vals.append(A[:, iu[0], iu[1]].ravel()) + lo, hi = np.percentile(np.concatenate(vals), robust_percentile) + if symmetric: + m = max(abs(lo), abs(hi)) + return -m, m + return lo, hi + + +def figure_dfc_matrices_window_png( + dfc_dict, + trs, + window_len=8, + center="middle", + center_time=None, + center_index=None, + interval=None, + cmap="coolwarm", + outfile="fig_dfc_window.png", + show_region_ticks=False, + region_labels=None, + draw_network_bounds=None, + dpi=600, + transparent=False, + # style knobs + method_label_size=11, + tr_label_size=10, + cbar_label_size=11, + rotate_method_labels=90, + method_label_pad=18, # << controls distance between method names and images + wspace=None, # << override column spacing if needed (None = auto) +): + import matplotlib as mpl + import matplotlib.pyplot as plt + import numpy as np + from matplotlib import gridspec + + mpl.rcParams.update( + { + "figure.dpi": dpi, + "savefig.dpi": dpi, + "pdf.fonttype": 42, + "ps.fonttype": 42, + "font.size": 8, + "axes.titlesize": tr_label_size, + "axes.labelsize": method_label_size, + } + ) + + methods = list(dfc_dict.keys()) + R = next(iter(dfc_dict.values())).shape[1] + + idxs = _window_indices( + trs, + window_len=window_len, + center=center, + center_time=center_time, + center_index=center_index, + interval=interval, + ) + + vmin, vmax = _common_limits(dfc_dict, robust_percentile=(2, 98), symmetric=True) + vmin = 0 + + # figure sizing + col_width = 1.6 + row_height = 1.5 + nrows, ncols = len(methods), len(idxs) + + fig = plt.figure(figsize=((ncols + 0.5) * col_width, nrows * row_height)) + + # spacing + auto_wspace = min(0.35, 0.12 + 0.01 * ncols) + wspace = auto_wspace if wspace is None else wspace + hspace = 0.25 + + # add a dedicated colorbar column on the far right + gs = gridspec.GridSpec( + nrows, + ncols + 1, + width_ratios=[1] * ncols + [0.06], # last slot = colorbar + hspace=hspace, + wspace=wspace, + ) + + last_im = None + for r, m in enumerate(methods): + A = dfc_dict[m] + for c, t_idx in enumerate(idxs): + ax = fig.add_subplot(gs[r, c]) + M = A[t_idx].copy() + np.fill_diagonal(M, np.nan) + im = ax.imshow(M, vmin=vmin, vmax=vmax, cmap=cmap, interpolation="none") + last_im = im + + if draw_network_bounds: + for b in draw_network_bounds: + ax.axhline(b - 0.5, linewidth=0.4, color="k") + ax.axvline(b - 0.5, linewidth=0.4, color="k") + + if show_region_ticks and region_labels is not None: + step = max(1, R // 16) + ticks = np.arange(0, R, step) + ax.set_xticks(ticks) + ax.set_yticks(ticks) + ax.set_xticklabels( + [region_labels[i] for i in ticks], rotation=90, fontsize=6 + ) + ax.set_yticklabels([region_labels[i] for i in ticks], fontsize=6) + else: + ax.set_xticks([]) + ax.set_yticks([]) + for s in ax.spines.values(): + s.set_visible(False) + + if r == 0: + label = ( + f"TR{trs[t_idx]}" + if np.issubdtype(np.asarray(trs).dtype, np.number) + else str(trs[t_idx]) + ) + ax.set_title(label, pad=6, fontsize=tr_label_size, fontweight="bold") + + if c == 0: + ax.set_ylabel( + m, + rotation=rotate_method_labels, + labelpad=method_label_pad, # << tighten/loosen here + va="center", + ha="center", + fontsize=method_label_size, + fontweight="bold", + ) + ax.yaxis.set_label_position("left") + + # colorbar in its own axis (no overlap) + cax = fig.add_subplot(gs[:, -1]) + cbar = fig.colorbar(last_im, cax=cax) + cbar.set_label("Connectivity", fontsize=cbar_label_size, fontweight="bold") + cbar.ax.tick_params(labelsize=max(8, cbar_label_size - 1)) + + fig.subplots_adjust(left=0.10, right=0.98, top=0.95, bottom=0.05) + fig.savefig( + outfile, + bbox_inches="tight", + pad_inches=0.02, + transparent=transparent, + facecolor="white", + ) + plt.close(fig) + print( + f"Saved {outfile} | TR columns: {len(idxs)} | vmin={vmin:.3f}, vmax={vmax:.3f} | dpi={dpi}" + ) + + +###################### sample_matrix plots ###################### + + +def nice_step(n, max_ticks=10): + """Return a 'nice' step (1-2-5x10^k) to keep ≤ max_ticks across [1..n].""" + if n <= 1: + return 1 + raw = max(1.0, n / max(2, (max_ticks - 1))) + exp = np.floor(np.log10(raw)) + frac = raw / (10**exp) + base = 1 if frac <= 1 else 2 if frac <= 2 else 5 if frac <= 5 else 10 + return int(base * (10**exp)) + + +def plot_samples_features( + X, + y, + *, + sample_order="original", # "original" | "label" | "label+cluster" | "cluster" + feature_order="original", # "original" | "tstat" + col_order_from_train=None, # optional np.ndarray (feature indices) to reuse on test + ZSCORE=True, + V_RANGE=None, + cmap="coolwarm", + title=None, + save_path=None, + show=True, +): + """ + X: (n_samples, n_features) matrix (features in columns) + y: (n_samples,) binary (0=rest, 1=task) + + Samples are shown along the horizontal axis (time-like), features along the vertical axis. + If feature_order == "tstat", a slim vertical t-stat bar is shown on the LEFT, + aligned 1:1 with feature rows (no top t-bar). + """ + # ---------- prep ---------- + X = np.asarray(X, float) + y = np.asarray(y) + n_samples, n_features = X.shape + + # z-score per feature + Xz = X.copy() + if ZSCORE: + mu = Xz.mean(axis=0, keepdims=True) + sd = Xz.std(axis=0, keepdims=True) + 1e-8 + Xz = (Xz - mu) / sd + + # ---------- feature order ---------- + if feature_order == "tstat": + if col_order_from_train is not None: + col_order = np.asarray(col_order_from_train, int) + t, _ = ttest_ind(Xz[y == 1], Xz[y == 0], axis=0, equal_var=False) + t_ord = t[col_order] + else: + t, _ = ttest_ind(Xz[y == 1], Xz[y == 0], axis=0, equal_var=False) + col_order = np.argsort(-np.abs(t)) # strongest contrast first + t_ord = t[col_order] + else: + col_order = np.arange(n_features) + t_ord = None # no t-stat bar + + # ---------- sample order ---------- + if sample_order == "original": + row_order = np.arange(n_samples) + split = np.sum(y == 0) + draw_separator = False + elif sample_order == "label": + rest_idx = np.where(y == 0)[0] + task_idx = np.where(y == 1)[0] + row_order = np.r_[rest_idx, task_idx] + split = len(rest_idx) + draw_separator = True + elif sample_order == "label+cluster": + + def order_rows(A): + if len(A) <= 2: + return np.arange(len(A)) + return leaves_list(linkage(A, method="average", metric="cosine")) + + rest_idx = np.where(y == 0)[0] + task_idx = np.where(y == 1)[0] + rest_order = rest_idx[order_rows(Xz[rest_idx])] if len(rest_idx) else rest_idx + task_order = task_idx[order_rows(Xz[task_idx])] if len(task_idx) else task_idx + row_order = np.r_[rest_order, task_order] + split = len(rest_order) + draw_separator = True + elif sample_order == "cluster": + + def order_rows(A): + if len(A) <= 2: + return np.arange(len(A)) + return leaves_list(linkage(A, method="average", metric="cosine")) + + all_idx = np.arange(n_samples) + # rest_order = rest_idx[order_rows(Xz[rest_idx])] if len(rest_idx) else rest_idx + # task_order = task_idx[order_rows(Xz[task_idx])] if len(task_idx) else task_idx + + row_order = all_idx[order_rows(Xz[all_idx])] + split = np.sum(y == 0) + draw_separator = False + else: + raise ValueError( + "sample_order must be one of {'original','label','label+cluster'}" + ) + + # ---------- figure & layout (no top t-bar) ---------- + # W = max(10, min(24, n_samples / 30)) + w_min = 12 + w_max = 24 + width_per_100 = 0.5 # additional width per 100 samples + W = float(np.clip(w_min + (n_samples / 100.0) * width_per_100, w_min, w_max)) + H = max(6, min(16, n_features / 30)) + fig = plt.figure(figsize=(W, H)) + + gs = fig.add_gridspec( + nrows=2, + ncols=1, + height_ratios=[1.0, 0.06], # main heatmap + class strip + hspace=0.08, + ) + ax_main = fig.add_subplot(gs[0, 0]) + ax_lab = fig.add_subplot(gs[1, 0]) + + # --- VRANGE --- + if V_RANGE is None: + Xflat = np.asarray(Xz, float).ravel() + lo, hi = np.nanpercentile(Xflat, [5, 95]) # robust to outliers; tweak if needed + V_RANGE = max(abs(lo), abs(hi)) # symmetric around 0 (for diverging cmap) + + # ---------- main heatmap ---------- + img = Xz[row_order, :][:, col_order].T # (features, samples) + ax_main.imshow( + img, aspect="auto", origin="lower", cmap=cmap, vmin=-V_RANGE, vmax=V_RANGE + ) + n_features = img.shape[0] + last_idx = n_features - 1 + + if n_features < 10: + # every feature: labels 1..n, positions 0..n-1 + labels_1based = np.arange(1, n_features + 1, dtype=int) + else: + step = nice_step(n_features, max_ticks=10) + # use round multiples of the step + labels_1based = list(np.arange(step, n_features + 1, step, dtype=int)) + # de-dup & sort (in case step == 1) + labels_1based = np.unique(labels_1based) + + # convert 1-based labels to 0-based tick positions + ticks_pos = labels_1based - 1 + + # lock y-limits so the last tick isn't clipped + ax_main.set_ylim(-0.5, last_idx + 0.5) + + # set ticks & labels + ax_main.set_yticks(ticks_pos) + ax_main.set_yticklabels([f"{v:d}" for v in labels_1based]) + ax_main.set_ylabel("feature", fontsize=18, fontweight="bold") + # ax_main.set_xlabel("sample", fontsize=18, fontweight="bold") + # ax_main.set_xticks([]) + ax_main.tick_params(axis="y", labelsize=18) + ax_main.tick_params(axis="x", labelsize=18) + + if draw_separator and 0 < split < n_samples: + ax_main.axvline(split - 0.5, color="k", lw=2) + + # ---------- bottom class strip ---------- + y_reordered = y[row_order] + cmap_lbl = ListedColormap( + [[0.85, 0.85, 0.85], [0.25, 0.5, 0.9]] + ) # rest=gray, task=blue + ax_lab.imshow( + y_reordered[None, :], aspect="auto", origin="lower", cmap=cmap_lbl, vmin=0, vmax=1 + ) + ax_lab.set_yticks([]) + ax_lab.set_xticks([]) + # ax_lab.set_title("class", fontsize=11, pad=2) + + # show class labels only when there is label grouping + if draw_separator: + n0 = (y_reordered == 0).sum() + n1 = (y_reordered == 1).sum() + if n0 > 0: + x0 = (n0 - 1) / 2.0 + ax_lab.annotate( + "rest (0)", + xy=(x0, -0.35), + xycoords=("data", "axes fraction"), + ha="center", + va="top", + fontsize=18, + fontweight="bold", + ) + if n1 > 0: + x1 = n0 + (n1 - 1) / 2.0 + ax_lab.annotate( + "task (1)", + xy=(x1, -0.35), + xycoords=("data", "axes fraction"), + ha="center", + va="top", + fontsize=18, + fontweight="bold", + ) + + # --- move the class bar (ax_lab) down a bit --- + fig.canvas.draw() # ensure positions are current + lab_box = ax_lab.get_position() # [x0, y0, width, height] in figure coords + down = 0.070 # how much to move down (figure fraction) + new_y0 = max(0.01, lab_box.y0 - down) # keep it inside the figure + ax_lab.set_position([lab_box.x0, new_y0, lab_box.width, lab_box.height]) + + # after you position ax_lab (i.e., after ax_lab.set_position([...])) + ax_lab.xaxis.set_label_position("top") + ax_lab.set_xlabel("sample", labelpad=4, fontweight="bold", fontsize=18) + # keep the strip clean + ax_lab.tick_params( + axis="x", which="both", length=0, labelbottom=False, labeltop=False + ) + + # (re)grab the updated box for the colorbar placement that comes next + lab_box = ax_lab.get_position() + + # ---------- LEFT vertical t-stat bar (only if feature_order=="tstat") ---------- + if t_ord is not None: + fig.canvas.draw() + main_box = ax_main.get_position() # figure coords + + tbar_left_width = 0.010 # ~2% fig width + tbar_left_pad = 0.035 / W * 24 # gap from main heatmap, proportional to fig width + + x0 = max(0.01, main_box.x0 - tbar_left_pad - tbar_left_width) + y0 = main_box.y0 + w = tbar_left_width + h = main_box.height + + ax_tleft = fig.add_axes([x0, y0, w, h]) + m = np.nanmax(np.abs(t_ord)) if np.isfinite(t_ord).any() else 1.0 + ax_tleft.imshow( + t_ord[:, None], origin="lower", aspect="auto", cmap=cmap, vmin=-m, vmax=m + ) + ax_tleft.set_xticks([]) + ax_tleft.set_yticks([]) + ax_tleft.set_title("t-stat", fontsize=11, pad=2, fontweight="bold") + + if title: + fig.suptitle(title, y=0.995, fontsize=12, fontweight="bold") + + if save_path: + fig.savefig(save_path, dpi=300, bbox_inches="tight", pad_inches=0.15) + if show: + plt.show() + else: + plt.close(fig) + + return dict(row_order=row_order, col_order=col_order) + + +def save_scalar_colorbar( + cmap="coolwarm", + vmin=-2.0, + vmax=2.0, # use the same V_RANGE you use in plots + label="z-scored feature value", + filename="zscore_colorbar.png", + orientation="horizontal", + figsize=(6, 0.4), # width, height in inches + dpi=300, + ticks=None, +): + """ + Saves a standalone scalar colorbar image you can reuse in the paper. + """ + # Make a dummy mappable with the correct colormap and limits + from matplotlib.cm import ScalarMappable + from matplotlib.colors import Normalize + + fig = plt.figure(figsize=figsize, dpi=dpi) + ax = fig.add_axes( + [0.05, 0.35, 0.90, 0.30] + if orientation == "horizontal" + else [0.35, 0.05, 0.30, 0.90] + ) + + sm = ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap) + sm.set_array([]) + + cb = plt.colorbar(sm, cax=ax, orientation=orientation) + cb.set_label(label, fontsize=18, fontweight="bold") + + if ticks is not None: + cb.set_ticks(ticks) + cb.set_ticklabels([str(t) for t in ticks]) + cb.ax.tick_params(labelsize=18) + + fig.savefig(filename, bbox_inches="tight", pad_inches=0.02) + plt.close(fig) diff --git a/task_dFC/multi_dataset_analysis/ml_results.py b/task_dFC/multi_dataset_analysis/ml_results.py new file mode 100644 index 0000000..f522886 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/ml_results.py @@ -0,0 +1,811 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.colors import to_rgba +from matplotlib.ticker import PercentFormatter + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + boldify_axes, + build_experiment_display_info, + relabel_heatmap_rows, + savefig_pub, + setup_pub_style, +) + +LEVEL = "group_lvl" +KEYS_NOT_TO_INCLUDE = [ + "Logistic regression permutation p_value", + "Logistic regression permutation score mean", + "Logistic regression permutation score std", + "SVM permutation p_value", + "SVM permutation score mean", + "SVM permutation score std", +] +GROUP = "test" +TARGETS = [ + ("PCA", "Logistic regression balanced accuracy"), + ("PLS", "Logistic regression balanced accuracy"), + ("PCA", "SVM balanced accuracy"), + ("PLS", "SVM balanced accuracy"), + ("PCA", "SI"), + ("PLS", "SI"), +] +TOP_EXPERIMENT_SHAPES = 3 +TOP_EXPERIMENT_MARKERS = ["*"] # star for all top experiments +COLOR_THRESHOLD = 60.0 +PER_METHOD_LABEL_SCORE_THRESHOLD = 55.0 +SIMULATED_METHOD_MEDIAN_ANNOTATION_THRESHOLD = 80.0 +NEUTRAL_COLOR = "#D49B9B" + + +def parse_args(): + helptext = """ + Script to make figures/tables from multi-dataset ML results. + """ + parser = argparse.ArgumentParser(description=helptext) + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + return parser.parse_args() + + +def read_json(json_file): + with open(json_file, "r") as f: + return json.load(f) + + +def get_analysis_config(multi_dataset_info, simul_or_real): + if simul_or_real == "real": + return multi_dataset_info["real_data"] + if simul_or_real == "simulated": + return multi_dataset_info["simulated_data"] + raise ValueError(f"Invalid simul_or_real: {simul_or_real}") + + +def get_classification_input_dir(ml_root, dataset_info): + sessions = dataset_info.get("SESSIONS") or [None] + session = sessions[0] + if session is None: + return f"{ml_root}/classification" + return f"{ml_root}/classification/{session}" + + +def filter_ml_scores(ml_scores_new, tasks_to_include): + filtered_scores = { + key: [] for key in ml_scores_new[LEVEL].keys() if key not in KEYS_NOT_TO_INCLUDE + } + + for index, task in enumerate(ml_scores_new[LEVEL]["task"]): + if task not in tasks_to_include: + continue + for key in filtered_scores: + filtered_scores[key].append(ml_scores_new[LEVEL][key][index]) + + return filtered_scores + + +def merge_ml_scores(all_ml_scores, ml_scores_new_updated): + if all_ml_scores is None: + return ml_scores_new_updated + + for key, values in ml_scores_new_updated.items(): + if key in all_ml_scores: + all_ml_scores[key].extend(values) + return all_ml_scores + + +def collect_all_ml_scores(main_root, datasets, tasks_to_include): + all_ml_scores = None + + for dataset in datasets: + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + ml_root = f"{main_root}/{dataset}/derivatives/ML" + dataset_info = read_json(dataset_info_file) + input_dir = get_classification_input_dir(ml_root, dataset_info) + + if not os.path.exists(input_dir): + print( + f"Input directory {input_dir} does not exist. Skipping dataset {dataset}." + ) + continue + + all_ml_scores_files = [ + filename + for filename in os.listdir(input_dir) + if "ML_scores_classify_" in filename + ] + + for filename in all_ml_scores_files: + try: + ml_scores_new = np.load( + f"{input_dir}/{filename}", allow_pickle=True + ).item() + filtered_scores = filter_ml_scores(ml_scores_new, tasks_to_include) + all_ml_scores = merge_ml_scores(all_ml_scores, filtered_scores) + except Exception as error: + print(f"Error loading {filename}: {error}") + continue + + return all_ml_scores + + +def validate_score_lengths(all_ml_scores): + if all_ml_scores is None: + return + + lengths = [len(values) for values in all_ml_scores.values()] + if len(set(lengths)) != 1: + print( + "Warning: Not all keys have the same length in ALL_ML_SCORES. " + f"key and length pairs: {dict(zip(all_ml_scores.keys(), lengths))}" + ) + + +def save_all_ml_scores(all_ml_scores, output_root, simul_or_real): + os.makedirs(output_root, exist_ok=True) + np.save(f"{output_root}/ALL_ML_SCORES_{simul_or_real}.npy", all_ml_scores) + + +def prepare_metric_dataframe( + all_ml_scores, tasks_to_include, embedding, metric, simul_or_real +): + df = pd.DataFrame.from_dict(all_ml_scores) + df = df[df["task"].isin(tasks_to_include)] + df = df[(df["embedding"] == embedding) & (df["group"] == GROUP)].copy() + + method_order = sorted(df["dFC method"].unique(), key=lambda method: method.lower()) + df["dFC method"] = pd.Categorical( + df["dFC method"], categories=method_order, ordered=True + ) + + task_order, task_to_experiment, experiment_order, experiment_palette = ( + build_experiment_display_info( + df["task"].unique(), + task_reference_order=tasks_to_include, + simul_or_real=simul_or_real, + ) + ) + df["experiment"] = df["task"].map(task_to_experiment) + + return ( + df, + method_order, + task_order, + task_to_experiment, + experiment_order, + experiment_palette, + ) + + +def build_best_and_multi_tables(df, metric): + counts_task = df.groupby("task")["run"].nunique() + multi_tasks = counts_task[counts_task > 1].index + df_multi = df[df["task"].isin(multi_tasks)].copy() + + df_best = ( + df.sort_values(["task", "dFC method", metric], ascending=[True, True, False]) + .drop_duplicates(subset=["task", "dFC method"], keep="first") + .rename(columns={metric: "score"}) + ) + + return df_best, df_multi + + +def get_pointplot_limits(metric): + if metric == "SI": + return -1.0, 1.0 + return 0.5, 1.0 + + +def convert_threshold_to_score_scale(threshold, metric): + if metric != "SI" and threshold > 1.0: + return threshold / 100.0 + return threshold + + +def get_heatmap_limits(metric): + if metric == "SI": + return None, 1.0, 0.0 + return 0.5 - 1e-6, 1.0, 0.5 + + +def style_boxplot(ax, box_edge): + for artist in ax.artists: + artist.set_edgecolor(box_edge) + facecolor = artist.get_facecolor() + artist.set_facecolor((facecolor[0], facecolor[1], facecolor[2], 0.12)) + for line in ax.lines: + line.set_color(box_edge) + line.set_alpha(0.5) + line.set_zorder(1) + + +def overlay_method_means(ax, df_best, lower, upper): + means = df_best.groupby("dFC method", observed=True)["score"].mean() + xticks = ax.get_xticks() + xticklabels = [tick.get_text() for tick in ax.get_xticklabels()] + x_positions = {label: xticks[index] for index, label in enumerate(xticklabels)} + + halfwidth = 0.1 + for method, mean_score in means.items(): + if method not in x_positions or pd.isna(mean_score): + continue + mean_score = min(upper, max(lower, mean_score)) + x_position = x_positions[method] + ax.hlines( + mean_score, + x_position - halfwidth, + x_position + halfwidth, + colors="#050505", + lw=2.4, + zorder=3, + ) + + +def finalize_marker_edges(ax): + for line in ax.lines: + try: + line.set_markeredgecolor("#222222") + line.set_markeredgewidth(0.8) + except Exception: + pass + + +def get_colored_experiment_mask(df_best, color_threshold=COLOR_THRESHOLD): + """Return set of experiments with max score >= color_threshold across all methods.""" + max_scores = df_best.groupby("experiment", observed=True)["score"].max() + return set(max_scores[max_scores >= color_threshold].index) + + +def get_top_experiments_by_mean(df_best, top_experiment_shapes=TOP_EXPERIMENT_SHAPES): + if top_experiment_shapes <= 0: + return [] + return ( + df_best.groupby("experiment", observed=True)["score"] + .mean() + .sort_values(ascending=False) + .head(top_experiment_shapes) + .index.tolist() + ) + + +def create_neutral_palette(experiment_order, colored_experiments, vibrant_palette): + """Create palette: neutral for non-colored experiments, vibrant for colored ones.""" + palette = {} + for exp in experiment_order: + if exp in colored_experiments: + palette[exp] = vibrant_palette[exp] + else: + palette[exp] = NEUTRAL_COLOR + return palette + + +def extract_pointplot_coordinates(ax, method_order, experiment_order, experiment_palette): + candidate_lines = [] + for line in ax.lines: + x_data = np.asarray(line.get_xdata(), dtype=float) + y_data = np.asarray(line.get_ydata(), dtype=float) + marker = line.get_marker() + if marker in {None, "", "None", " "}: + continue + if x_data.size != len(method_order) or y_data.size != len(method_order): + continue + candidate_lines.append(line) + + assigned_lines = { + experiment: candidate_lines[idx] + for idx, experiment in enumerate(experiment_order) + if idx < len(candidate_lines) + } + + coordinates = {} + for experiment, line in assigned_lines.items(): + x_data = np.asarray(line.get_xdata(), dtype=float) + y_data = np.asarray(line.get_ydata(), dtype=float) + coordinates[experiment] = {} + for method_index, method in enumerate(method_order): + y_value = y_data[method_index] + if np.isnan(y_value): + continue + coordinates[experiment][method] = (x_data[method_index], y_value) + return coordinates + + +def resize_colored_markers( + ax, + experiment_order, + colored_experiments, + method_order, + base_size=5, + colored_size=8, +): + """Make circles for colored (high-performing) experiments slightly bigger.""" + candidate_lines = [] + for line in ax.lines: + x_data = np.asarray(line.get_xdata(), dtype=float) + y_data = np.asarray(line.get_ydata(), dtype=float) + marker = line.get_marker() + if marker in {None, "", "None", " "}: + continue + if x_data.size != len(method_order) or y_data.size != len(method_order): + continue + candidate_lines.append(line) + + for idx, experiment in enumerate(experiment_order): + if idx >= len(candidate_lines): + break + size = colored_size if experiment in colored_experiments else base_size + candidate_lines[idx].set_markersize(size) + + +def overlay_top_experiment_shapes( + ax, + df_best, + point_coordinates, + shape_palette, + top_experiment_shapes, +): + if top_experiment_shapes <= 0: + return + + top_experiments = get_top_experiments_by_mean(df_best, top_experiment_shapes) + + for rank, experiment in enumerate(top_experiments): + if experiment not in point_coordinates: + continue + marker = TOP_EXPERIMENT_MARKERS[rank % len(TOP_EXPERIMENT_MARKERS)] + points = list(point_coordinates[experiment].values()) + if not points: + continue + x_vals = [pt[0] for pt in points] + y_vals = [pt[1] for pt in points] + ax.scatter( + x_vals, + y_vals, + marker=marker, + s=250, + c=shape_palette[experiment], + edgecolors="#111111", + linewidths=1.0, + zorder=8, + ) + + +def annotate_per_method_quartile( + ax, + df_best, + point_coordinates, + method_order, + colored_experiments, + metric, + simul_or_real, + score_threshold=PER_METHOD_LABEL_SCORE_THRESHOLD, +): + """ + Default behavior: + - annotate colored experiments in top quartile and above score_threshold. + + Simulated + non-SI override: + - if a method median is above SIMULATED_METHOD_MEDIAN_ANNOTATION_THRESHOLD, + annotate all experiments for that method. + """ + simulated_non_si = simul_or_real == "simulated" and metric != "SI" + if not colored_experiments and not simulated_non_si: + return + + simulated_median_threshold = convert_threshold_to_score_scale( + SIMULATED_METHOD_MEDIAN_ANNOTATION_THRESHOLD, metric + ) + + xticks = ax.get_xticks() + xticklabels = [t.get_text() for t in ax.get_xticklabels()] + method_positions = {lab: xticks[i] for i, lab in enumerate(xticklabels)} + + for method in method_order: + method_df = df_best[df_best["dFC method"] == method] + if method_df.empty: + continue + + scores = method_df["score"].values + quartile_threshold = np.percentile(scores, 75) + + if simulated_non_si and np.nanmedian(scores) > simulated_median_threshold: + qualify_rows = method_df + else: + qualify_rows = method_df[ + method_df["experiment"].isin(colored_experiments) + & (method_df["score"] > score_threshold) + & (method_df["score"] >= quartile_threshold) + ] + + method_center = method_positions[method] + + for _, row in qualify_rows.iterrows(): + experiment = row["experiment"] + if experiment not in point_coordinates: + continue + if method not in point_coordinates[experiment]: + continue + + x_value, y_value = point_coordinates[experiment][method] + + # Position text left or right based on point position + if x_value < method_center: + ha_align = "right" + x_offset = -10 + else: + ha_align = "left" + x_offset = 10 + + ax.annotate( + experiment, + xy=(x_value, y_value), + xytext=(x_offset, 0), + textcoords="offset points", + ha=ha_align, + va="center", + fontsize=7, + fontweight="bold", + color="#1A1A1A", + bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.75), + zorder=9, + ) + + +def plot_best_pointplot( + df_best, + method_order, + experiment_order, + experiment_palette, + output_root, + embedding, + metric, + simul_or_real, +): + # Keep the original width scaling so method spacing is unchanged; + # reduce only the height to improve aspect ratio. + plot_width = max(11, 0.6 * len(method_order)) + plot_height = 5.6 + figure, ax = plt.subplots(figsize=(plot_width, plot_height)) + + color_threshold = convert_threshold_to_score_scale(COLOR_THRESHOLD, metric) + label_threshold = convert_threshold_to_score_scale( + PER_METHOD_LABEL_SCORE_THRESHOLD, metric + ) + + top_experiments = get_top_experiments_by_mean(df_best, TOP_EXPERIMENT_SHAPES) + + # SI policy: color/annotate only star experiments. + if metric == "SI": + colored_experiments = set(top_experiments) + label_threshold = -np.inf + else: + # Identify experiments with high performance (>= COLOR_THRESHOLD) + colored_experiments = get_colored_experiment_mask(df_best, color_threshold) + + # Create neutral palette: vibrant for high performers, neutral for others + neutral_palette = create_neutral_palette( + experiment_order, colored_experiments, experiment_palette + ) + + box_face = to_rgba("#DE9995", 0.18) + box_edge = "#730800" + + sns.boxplot( + data=df_best, + x="dFC method", + y="score", + order=method_order, + whis=(5, 95), + fliersize=0, + linewidth=1.0, + width=0.2, + color=box_face, + ax=ax, + zorder=1, + ) + style_boxplot(ax, box_edge) + + lower, upper = get_pointplot_limits(metric) + overlay_method_means(ax, df_best, lower, upper) + + # Draw pointplot with neutral palette + sns.pointplot( + data=df_best, + x="dFC method", + y="score", + hue="experiment", + order=method_order, + hue_order=experiment_order, + dodge=0.4, + errorbar=None, + linestyles="", + markers="o", + palette=neutral_palette, + ax=ax, + zorder=6, + ) + finalize_marker_edges(ax) + resize_colored_markers(ax, experiment_order, colored_experiments, method_order) + + # Extract point coordinates from the pointplot + point_coordinates = extract_pointplot_coordinates( + ax, + method_order, + experiment_order, + neutral_palette, + ) + + # Overlay shapes for top 3 experiments using vibrant palette + overlay_top_experiment_shapes( + ax, + df_best, + point_coordinates, + neutral_palette, + top_experiment_shapes=TOP_EXPERIMENT_SHAPES, + ) + + # Annotate per-method quartile points + annotate_per_method_quartile( + ax, + df_best, + point_coordinates, + method_order, + colored_experiments=colored_experiments, + metric=metric, + simul_or_real=simul_or_real, + score_threshold=label_threshold, + ) + + ax.set_xlabel("dFC method") + ax.set_ylabel(metric) + if metric == "SI": + ax.set_ylim(top=1.02) + else: + ax.set_ylim(0.48, 1.02) + ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0, decimals=0)) + ax.grid(True, axis="y", color="#FFFFFF", alpha=0.85, linewidth=1.1) + sns.despine(ax=ax, top=True, right=True) + plt.setp(ax.get_xticklabels(), rotation=35, ha="right") + + if ax.legend_: + ax.legend_.remove() + + boldify_axes(ax, xlabel="dFC method", ylabel=metric) + figure.tight_layout() + + savefig_pub( + f"{output_root}/ML_scores_{embedding}_{metric}_{LEVEL}_{simul_or_real}_best.png" + ) + plt.close(figure) + + +def plot_best_heatmap( + df_best, + method_order, + task_order, + task_to_experiment, + output_root, + embedding, + metric, + simul_or_real, +): + matrix_best = df_best.pivot(index="task", columns="dFC method", values="score") + annot_best = df_best.assign( + label=lambda df_plot: df_plot["score"].map(lambda value: f"{value:.2f}") + ).pivot(index="task", columns="dFC method", values="label") + + matrix_best, annot_best, _ = relabel_heatmap_rows( + matrix_best, + annot_best, + task_reference_order=task_order, + task_to_experiment=task_to_experiment, + ) + col_order = [method for method in method_order if method in matrix_best.columns] + + if simul_or_real == "real": + width = max(10, 0.65 * len(col_order)) + height = max(6.0, 0.30 * len(matrix_best.index)) + else: + width = max(11, 11 / 7 * len(col_order)) + height = max(7.0, 0.35 * len(matrix_best.index)) + + figure, ax = plt.subplots(figsize=(width, height)) + vmin, vmax, center = get_heatmap_limits(metric) + heatmap = sns.heatmap( + matrix_best.loc[:, col_order], + vmin=vmin, + vmax=vmax, + center=center, + cmap="coolwarm", + annot=annot_best.loc[:, col_order], + fmt="", + annot_kws={"fontsize": 9, "fontweight": "bold", "linespacing": 1.15}, + cbar_kws={"shrink": 0.7, "pad": 0.02}, + ax=ax, + ) + colorbar = heatmap.collections[0].colorbar + colorbar.set_label(metric, fontsize=10, fontweight="bold") + colorbar.ax.tick_params(labelsize=9) + + boldify_axes(ax, xlabel="dFC method", ylabel="Experiment", rotate_xticks=35) + ax.set_xlabel("dFC method") + ax.set_ylabel("Experiment") + plt.setp(ax.get_xticklabels(), fontweight="bold", rotation=35, ha="right") + plt.setp(ax.get_yticklabels(), fontweight="bold") + sns.despine(ax=ax, top=True, right=True) + plt.tight_layout() + savefig_pub( + f"{output_root}/ML_scores_heatmap_{embedding}_{metric}_{LEVEL}_{simul_or_real}_best.png" + ) + plt.close(figure) + + +def build_across_heatmap_data(df_multi, metric, task_order, task_to_experiment): + summary = ( + df_multi.groupby(["task", "dFC method"], observed=True)[metric] + .agg(n="count", med="median", vmin="min", vmax="max") + .reset_index() + ) + + matrix_across = summary.pivot(index="task", columns="dFC method", values="med") + annot_across = summary.assign( + label=lambda df_plot: df_plot["vmin"].map(lambda value: f"{value:.2f}") + + "\u2013" + + df_plot["vmax"].map(lambda value: f"{value:.2f}") + + "\n" + + df_plot["n"].map(lambda value: f"n={value}") + ).pivot(index="task", columns="dFC method", values="label") + + return relabel_heatmap_rows( + matrix_across, + annot_across, + task_reference_order=task_order, + task_to_experiment=task_to_experiment, + ) + + +def plot_across_heatmap( + df_multi, + method_order, + task_order, + task_to_experiment, + output_root, + embedding, + metric, + simul_or_real, +): + if df_multi.empty: + print( + f"[ACROSS-RUN] No tasks with ≥2 runs for {embedding} / {metric} — skipping across-run figures." + ) + return + + matrix_across, annot_across, _ = build_across_heatmap_data( + df_multi, + metric, + task_order, + task_to_experiment, + ) + col_order = [method for method in method_order if method in matrix_across.columns] + width = max(9.0, 11 / 7 * len(col_order)) + height = max(7.0, 7 / 20 * len(matrix_across.index)) + + figure, ax = plt.subplots(figsize=(width, height)) + vmin, vmax, center = get_heatmap_limits(metric) + heatmap = sns.heatmap( + matrix_across.loc[:, col_order], + vmin=vmin, + vmax=vmax, + center=center, + cmap="coolwarm", + annot=annot_across.loc[:, col_order], + fmt="", + annot_kws={"fontsize": 9, "fontweight": "bold", "linespacing": 1.15}, + cbar_kws={"shrink": 0.7, "pad": 0.02}, + ax=ax, + ) + + colorbar = heatmap.collections[0].colorbar + colorbar.set_label(metric, fontsize=10, fontweight="bold") + boldify_axes(ax, xlabel="dFC method", ylabel="Experiment", rotate_xticks=35) + ax.set_xlabel("dFC method") + ax.set_ylabel("Experiment") + plt.setp(ax.get_xticklabels(), fontweight="bold", rotation=35, ha="right") + plt.setp(ax.get_yticklabels(), fontweight="bold") + sns.despine(ax=ax, top=True, right=True) + plt.tight_layout() + savefig_pub( + f"{output_root}/ML_scores_heatmap_{embedding}_{metric}_{LEVEL}_{simul_or_real}_across.png" + ) + plt.close(figure) + + +def generate_all_plots(all_ml_scores, tasks_to_include, output_root, simul_or_real): + sns.set_context("paper", font_scale=1.0, rc={"lines.linewidth": 1.2}) + sns.set_style("darkgrid") + + for embedding, metric in TARGETS: + ( + df, + method_order, + task_order, + task_to_experiment, + experiment_order, + experiment_palette, + ) = prepare_metric_dataframe( + all_ml_scores, + tasks_to_include, + embedding, + metric, + simul_or_real, + ) + df_best, df_multi = build_best_and_multi_tables(df, metric) + + plot_best_pointplot( + df_best, + method_order, + experiment_order, + experiment_palette, + output_root, + embedding, + metric, + simul_or_real, + ) + plot_best_heatmap( + df_best, + method_order, + task_order, + task_to_experiment, + output_root, + embedding, + metric, + simul_or_real, + ) + plot_across_heatmap( + df_multi, + method_order, + task_order, + task_to_experiment, + output_root, + embedding, + metric, + simul_or_real, + ) + + +def main(): + args = parse_args() + setup_pub_style() + + multi_dataset_info = read_json(args.multi_dataset_info) + analysis_config = get_analysis_config(multi_dataset_info, args.simul_or_real) + main_root = analysis_config["main_root"] + datasets = analysis_config["DATASETS"] + tasks_to_include = analysis_config["TASKS_to_include"] + output_root = f"{multi_dataset_info['output_root']}/ML_results" + + all_ml_scores = collect_all_ml_scores(main_root, datasets, tasks_to_include) + validate_score_lengths(all_ml_scores) + save_all_ml_scores(all_ml_scores, output_root, args.simul_or_real) + generate_all_plots( + all_ml_scores, + tasks_to_include, + output_root, + args.simul_or_real, + ) + + +if __name__ == "__main__": + main() diff --git a/task_dFC/multi_dataset_analysis/performance_factor.py b/task_dFC/multi_dataset_analysis/performance_factor.py new file mode 100644 index 0000000..6d66d19 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/performance_factor.py @@ -0,0 +1,1153 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.ticker import MultipleLocator + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + RDoC_MAP, + canon_task, + savefig_pub, + setup_pub_style, +) + +LEVEL = "group_lvl" +GROUP = "test" + +CLASSIFIER_METRIC_MAP = { + "Logistic regression": "Logistic regression balanced accuracy", + "SVM": "SVM balanced accuracy", +} + +TIMING_FEATURES = [ + "task_ratio_avg", + "transition_freq_avg", + "OI_median", + "rest_durations_median", + "task_durations_median", + "rest_durations_iqr", + "task_durations_iqr", +] + +COHEN_FEATURES = [ + "CohensD_max", + "CohensD_mean", +] + +TSNR_FEATURES = [ + "median_tsnr_avg_over_subjects", +] + +CORR_EXCLUDE_COLUMNS = { + "RDoC", + "task", + "run", + "dFC assessment method", + "classifier model", + "embedding", + "classification_balanced_accuracy", +} + +TOP_BOTTOM_QUANTILE = 0.2 +PERFORMANCE_GROUP_LABELS = ["Low", "Medium", "High"] + +DEFAULT_FACTOR_LABEL_MAP = { + "task_ratio_avg": "average task ratio", + "task_durations_iqr": "task duration IQR", + "task_durations_median": "task duration median", + "rest_durations_iqr": "rest duration IQR", + "rest_durations_median": "rest duration median", + "OI_median": "median OI", + "CohensD_mean": "mean Cohen's d", + "CohensD_max": "max Cohen's d", + "transition_freq_avg": "average transition frequency", + "median_tsnr_avg_over_subjects": "median tSNR averaged over subjects", +} + + +def get_domain_axis_label(simul_or_real): + return "RDoC domain" if simul_or_real == "real" else "Simulation design category" + + +def parse_args(): + helptext = """ + Build a unified run-level dataframe linking ML performance to task factors. + """ + parser = argparse.ArgumentParser(description=helptext) + parser.add_argument( + "--multi_dataset_info", + type=str, + required=True, + help="path to multi-dataset info file", + ) + parser.add_argument( + "--simul_or_real", + type=str, + required=True, + choices=["simulated", "real"], + help="Specify 'simulated' or 'real' data", + ) + return parser.parse_args() + + +def read_json(json_file): + with open(json_file, "r") as file_obj: + return json.load(file_obj) + + +def load_npy_dict(path, label): + assert os.path.exists(path), f"{label} file does not exist: {path}" + loaded = np.load(path, allow_pickle=True) + if isinstance(loaded, np.ndarray): + loaded = loaded.item() + assert isinstance(loaded, dict), f"{label} must be a dictionary. Got {type(loaded)}" + return loaded + + +def assert_required_keys(data_dict, required_keys, label): + missing = [key for key in required_keys if key not in data_dict] + assert not missing, f"Missing required keys in {label}: {missing}" + + +def dict_to_df(data_dict, label): + lengths = {key: len(value) for key, value in data_dict.items()} + unique_lengths = set(lengths.values()) + assert len(unique_lengths) == 1, ( + f"Inconsistent column lengths in {label}: {lengths}. " + "All arrays/lists must have equal length." + ) + return pd.DataFrame.from_dict(data_dict) + + +def normalize_run(value): + if value is None: + return "none" + if isinstance(value, float) and np.isnan(value): + return "none" + # TSV empty cells are read by pandas as NaN (float) handled above, + # but guard against empty strings too (e.g. after manual editing). + if str(value).strip() == "": + return "none" + return str(value).strip().lower() + + +def add_join_keys(df): + assert "task" in df.columns, "Expected column 'task'" + assert "run" in df.columns, "Expected column 'run'" + df = df.copy() + df["task_key"] = df["task"].astype(str).map(canon_task) + df["run_key"] = df["run"].map(normalize_run) + assert (df["task_key"].str.len() > 0).all(), "Found empty normalized task key" + return df + + +def get_paths(multi_dataset_info, simul_or_real): + output_root = multi_dataset_info["output_root"] + return { + "ml": f"{output_root}/ML_results/ALL_ML_SCORES_{simul_or_real}.npy", + "timing": f"{output_root}/task_timing_stats/{simul_or_real}/task_timing_stats_{simul_or_real}.npy", + "cohensd": f"{output_root}/CohensD/{simul_or_real}/CohensD_ML_{simul_or_real}.npy", + "tsnr": f"{output_root}/t-SNR/tsnr_summary_grouped.tsv", + "out_dir": f"{output_root}/performance_factor/{simul_or_real}", + } + + +def prepare_ml_df(ml_scores_all): + ml_scores = ml_scores_all + + required_keys = [ + "task", + "run", + "embedding", + "dFC method", + "group", + *CLASSIFIER_METRIC_MAP.values(), + ] + assert_required_keys(ml_scores, required_keys, "ALL_ML_SCORES") + + df_ml_wide = dict_to_df(ml_scores, "ALL_ML_SCORES") + df_ml_wide = df_ml_wide[df_ml_wide["group"] == GROUP].copy() + assert not df_ml_wide.empty, f"No ML rows found for group='{GROUP}'" + + if "dataset" in df_ml_wide.columns: + id_cols = ["dataset", "task", "run", "embedding", "dFC method", "group"] + else: + id_cols = ["task", "run", "embedding", "dFC method", "group"] + + classifier_frames = [] + for classifier, metric_key in CLASSIFIER_METRIC_MAP.items(): + frame = df_ml_wide[id_cols + [metric_key]].copy() + frame["classifier model"] = classifier + frame = frame.rename(columns={metric_key: "classification_balanced_accuracy"}) + classifier_frames.append(frame) + + df_ml = pd.concat(classifier_frames, ignore_index=True) + df_ml = df_ml.rename(columns={"dFC method": "dFC assessment method"}) + + score = df_ml["classification_balanced_accuracy"].astype(float) + assert np.isfinite(score).all(), "ML performance contains NaN/Inf values" + assert ((score >= 0.0) & (score <= 1.0)).all(), ( + "Expected balanced accuracy in [0, 1]. " + f"Observed min={score.min()}, max={score.max()}" + ) + + return add_join_keys(df_ml) + + +def prepare_timing_df(timing_dict): + required_keys = ["task", "run", *TIMING_FEATURES] + assert_required_keys(timing_dict, required_keys, "task_timing_stats") + df_timing = dict_to_df(timing_dict, "task_timing_stats") + + keep_cols = ["task", "run", *TIMING_FEATURES] + if "dataset" in df_timing.columns: + keep_cols = ["dataset", *keep_cols] + df_timing = df_timing[keep_cols].copy() + + for col in TIMING_FEATURES: + values = df_timing[col].astype(float) + assert np.isfinite(values).all(), f"Timing feature '{col}' contains NaN/Inf" + + return add_join_keys(df_timing) + + +def prepare_cohensd_df(cohensd_dict): + required_keys = ["task", "run", *COHEN_FEATURES] + assert_required_keys(cohensd_dict, required_keys, "CohensD_ML") + df_cohensd = dict_to_df(cohensd_dict, "CohensD_ML") + + keep_cols = ["task", "run", *COHEN_FEATURES] + if "dataset" in df_cohensd.columns: + keep_cols = ["dataset", *keep_cols] + df_cohensd = df_cohensd[keep_cols].copy() + + for col in COHEN_FEATURES: + values = df_cohensd[col].astype(float) + assert np.isfinite(values).all(), f"Cohen's D feature '{col}' contains NaN/Inf" + + return add_join_keys(df_cohensd) + + +def prepare_tsnr_df(tsnr_path): + assert os.path.exists(tsnr_path), f"tSNR file does not exist: {tsnr_path}" + df_tsnr = pd.read_csv(tsnr_path, sep="\t") + + required_cols = ["dataset", "task", "run", *TSNR_FEATURES] + missing = [col for col in required_cols if col not in df_tsnr.columns] + assert not missing, f"Missing required columns in tsnr_summary_grouped.tsv: {missing}" + + df_tsnr = df_tsnr[["dataset", "task", "run", *TSNR_FEATURES]].copy() + + # Validate tSNR values (allow NaN — runs with no data are left empty as specified) + tsnr_vals = df_tsnr["median_tsnr_avg_over_subjects"].astype(float) + assert ( + tsnr_vals.dropna() > 0 + ).all(), ( + "median_tsnr_avg_over_subjects contains non-positive values where data is present" + ) + + return add_join_keys(df_tsnr) + + +def choose_join_keys(df_ml, df_timing, df_cohensd, df_tsnr=None): + sources = [df_ml, df_timing, df_cohensd] + if df_tsnr is not None: + sources.append(df_tsnr) + + has_dataset_everywhere = all("dataset" in df.columns for df in sources) + + base_keys = ["task_key", "run_key"] + dataset_keys = ["dataset", *base_keys] + + timing_dupes_base = df_timing.duplicated(subset=base_keys).sum() + cohensd_dupes_base = df_cohensd.duplicated(subset=base_keys).sum() + tsnr_dupes_base = ( + df_tsnr.duplicated(subset=base_keys).sum() if df_tsnr is not None else 0 + ) + + if timing_dupes_base == 0 and cohensd_dupes_base == 0 and tsnr_dupes_base == 0: + return base_keys + + if has_dataset_everywhere: + timing_dupes_dataset = df_timing.duplicated(subset=dataset_keys).sum() + cohensd_dupes_dataset = df_cohensd.duplicated(subset=dataset_keys).sum() + tsnr_dupes_dataset = ( + df_tsnr.duplicated(subset=dataset_keys).sum() if df_tsnr is not None else 0 + ) + assert timing_dupes_dataset == 0, ( + "task_timing_stats still has duplicate rows per dataset/task/run after " + f"normalization. duplicate_count={timing_dupes_dataset}" + ) + assert cohensd_dupes_dataset == 0, ( + "CohensD_ML still has duplicate rows per dataset/task/run after " + f"normalization. duplicate_count={cohensd_dupes_dataset}" + ) + if df_tsnr is not None: + assert tsnr_dupes_dataset == 0, ( + "tsnr_summary_grouped still has duplicate rows per dataset/task/run after " + f"normalization. duplicate_count={tsnr_dupes_dataset}" + ) + return dataset_keys + + raise AssertionError( + "Ambiguous join on task/run (duplicates found), and dataset is not available " + "in all sources to disambiguate." + ) + + +def merge_with_checks(df_ml, df_timing, df_cohensd, join_keys, df_tsnr=None): + timing_cols = join_keys + TIMING_FEATURES + cohensd_cols = join_keys + COHEN_FEATURES + + df_merged = df_ml.merge( + df_timing[timing_cols], + on=join_keys, + how="left", + validate="many_to_one", + indicator="timing_merge", + ) + timing_unmatched = (df_merged["timing_merge"] != "both").sum() + assert ( + timing_unmatched == 0 + ), f"Could not match timing stats for {timing_unmatched} ML rows using keys {join_keys}" + df_merged = df_merged.drop(columns=["timing_merge"]) + + df_merged = df_merged.merge( + df_cohensd[cohensd_cols], + on=join_keys, + how="left", + validate="many_to_one", + indicator="cohensd_merge", + ) + cohensd_unmatched = (df_merged["cohensd_merge"] != "both").sum() + assert ( + cohensd_unmatched == 0 + ), f"Could not match Cohen's D stats for {cohensd_unmatched} ML rows using keys {join_keys}" + df_merged = df_merged.drop(columns=["cohensd_merge"]) + + if df_tsnr is not None: + tsnr_cols = join_keys + TSNR_FEATURES + # tSNR: left join — rows with no tSNR data (e.g. None-run datasets not in file) + # will have NaN, which is acceptable as specified. + df_merged = df_merged.merge( + df_tsnr[tsnr_cols], + on=join_keys, + how="left", + validate="many_to_one", + ) + + return df_merged + + +def add_rdoc(df, simul_or_real): + task_to_domain = RDoC_MAP[simul_or_real]["TASK2DOMAIN"] + df = df.copy() + df["RDoC"] = df["task_key"].map(task_to_domain) + + missing_mask = df["RDoC"].isna() + if missing_mask.any(): + missing_tasks = sorted(df.loc[missing_mask, "task"].astype(str).unique()) + raise AssertionError( + "Missing RDoC mapping for tasks (after canonicalization): " + f"{missing_tasks}. Update helper_functions.RDoC_MAP if needed." + ) + + return df + + +def finalize_columns(df): + cols = [] + if "dataset" in df.columns: + cols.append("dataset") + cols += [ + "task", + "run", + "RDoC", + *TIMING_FEATURES, + *COHEN_FEATURES, + "dFC assessment method", + "classifier model", + "embedding", + "classification_balanced_accuracy", + ] + + if all(col in df.columns for col in TSNR_FEATURES): + insert_at = cols.index("dFC assessment method") + cols[insert_at:insert_at] = TSNR_FEATURES + + missing = [col for col in cols if col not in df.columns] + assert not missing, f"Missing expected final columns: {missing}" + + out = df[cols].copy() + sort_cols = ["task", "run", "dFC assessment method", "classifier model", "embedding"] + if "dataset" in out.columns: + sort_cols = ["dataset", *sort_cols] + out = out.sort_values(sort_cols).reset_index(drop=True) + return out + + +def save_outputs(df, out_dir, simul_or_real): + os.makedirs(out_dir, exist_ok=True) + csv_path = f"{out_dir}/performance_factor_{simul_or_real}.csv" + pkl_path = f"{out_dir}/performance_factor_{simul_or_real}.pkl" + df.to_csv(csv_path, index=False) + df.to_pickle(pkl_path) + return csv_path, pkl_path + + +def build_correlation_table(df): + numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() + factor_cols = [ + col + for col in numeric_cols + if col not in CORR_EXCLUDE_COLUMNS and col != "classification_balanced_accuracy" + ] + assert factor_cols, "No numeric factor columns available for correlation analysis" + + rows = [] + for method, group_df in df.groupby("dFC assessment method", observed=True): + for factor in factor_cols: + pair_df = group_df[[factor, "classification_balanced_accuracy"]].dropna() + n_samples = len(pair_df) + + if ( + n_samples < 3 + or pair_df[factor].nunique(dropna=True) < 2 + or pair_df["classification_balanced_accuracy"].nunique(dropna=True) < 2 + ): + corr = np.nan + else: + corr = pair_df[factor].corr( + pair_df["classification_balanced_accuracy"], method="pearson" + ) + + rows.append( + { + "factor": factor, + "dFC assessment method": method, + "correlation": corr, + "n_samples": n_samples, + } + ) + + corr_df = pd.DataFrame(rows) + corr_df["factor"] = pd.Categorical( + corr_df["factor"], categories=factor_cols, ordered=True + ) + corr_df = corr_df.sort_values(["factor", "dFC assessment method"]).reset_index( + drop=True + ) + return corr_df + + +def get_numeric_factor_columns(df): + numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() + factor_cols = [ + col + for col in numeric_cols + if col not in CORR_EXCLUDE_COLUMNS and col != "classification_balanced_accuracy" + ] + assert factor_cols, "No numeric factor columns available for analysis" + return factor_cols + + +def plot_factor_correlation_pointplot(corr_df, out_dir, simul_or_real): + valid_df = corr_df.dropna(subset=["correlation"]).copy() + assert ( + not valid_df.empty + ), "All factor correlations are NaN; cannot generate correlation pointplot" + + n_factors = valid_df["factor"].nunique() + width = max(10, 0.75 * n_factors) + height = 7.0 + + figure, ax = plt.subplots(figsize=(width, height)) + sns.pointplot( + data=valid_df, + x="factor", + y="correlation", + hue="dFC assessment method", + dodge=0.4, + errorbar=None, + markers="o", + linestyles="", + ax=ax, + ) + + ax.axhline(0.0, color="#333333", linestyle="--", linewidth=1.0) + ax.set_ylim(-1.05, 1.05) + ax.set_xlabel("Factor") + ax.set_ylabel("Corr. with balanced accuracy") + plt.setp(ax.get_xticklabels(), rotation=35, ha="right") + ax.tick_params(axis="x", labelsize=11) + ax.tick_params(axis="y", labelsize=11) + + ax.yaxis.set_major_locator(MultipleLocator(0.25)) + ax.yaxis.set_minor_locator(MultipleLocator(0.125)) + ax.grid(True, axis="y", which="major", linestyle="-", alpha=0.4) + ax.grid(True, axis="y", which="minor", linestyle="--", alpha=0.22) + + ax.legend(title="dFC assessment method", frameon=True) + sns.despine(ax=ax, top=True, right=True) + figure.tight_layout() + + fig_path = f"{out_dir}/performance_factor_correlation_pointplot_{simul_or_real}.png" + savefig_pub(fig_path) + plt.close(figure) + return fig_path + + +def build_top_bottom_profile_table(df, quantile=TOP_BOTTOM_QUANTILE): + assert 0 < quantile < 0.5, "quantile must be in (0, 0.5)" + + factor_cols = get_numeric_factor_columns(df) + + rows = [] + for method, method_df in df.groupby("dFC assessment method", observed=True): + score = method_df["classification_balanced_accuracy"].astype(float) + low_thr = score.quantile(quantile) + high_thr = score.quantile(1 - quantile) + + bottom_df = method_df[score <= low_thr].copy() + top_df = method_df[score >= high_thr].copy() + + if len(top_df) < 3 or len(bottom_df) < 3: + print( + f"[TopBottom] Skipping method '{method}' due to too few samples " + f"(top={len(top_df)}, bottom={len(bottom_df)})." + ) + continue + + for factor in factor_cols: + top_vals = top_df[factor].astype(float).dropna() + bottom_vals = bottom_df[factor].astype(float).dropna() + n_top = len(top_vals) + n_bottom = len(bottom_vals) + + mean_top = np.nan if n_top == 0 else float(top_vals.mean()) + mean_bottom = np.nan if n_bottom == 0 else float(bottom_vals.mean()) + mean_diff = mean_top - mean_bottom if (n_top > 0 and n_bottom > 0) else np.nan + + cohens_d = np.nan + if n_top >= 2 and n_bottom >= 2: + var_top = float(np.var(top_vals, ddof=1)) + var_bottom = float(np.var(bottom_vals, ddof=1)) + pooled_num = ((n_top - 1) * var_top) + ((n_bottom - 1) * var_bottom) + pooled_den = n_top + n_bottom - 2 + if pooled_den > 0: + pooled_std = np.sqrt(pooled_num / pooled_den) + if np.isfinite(pooled_std) and pooled_std > 0: + cohens_d = mean_diff / pooled_std + + rows.append( + { + "factor": factor, + "dFC assessment method": method, + "mean_top": mean_top, + "mean_bottom": mean_bottom, + "mean_diff": mean_diff, + "cohens_d": cohens_d, + "n_top": n_top, + "n_bottom": n_bottom, + "low_threshold": float(low_thr), + "high_threshold": float(high_thr), + "n_method_total": int(len(method_df)), + } + ) + + assert rows, "No method had enough samples for top-vs-bottom profile analysis" + + profile_df = pd.DataFrame(rows) + profile_df["abs_cohens_d"] = profile_df["cohens_d"].abs() + profile_df = profile_df.sort_values( + ["abs_cohens_d", "factor", "dFC assessment method"], + ascending=[False, True, True], + ).reset_index(drop=True) + return profile_df + + +def plot_top_bottom_profile(profile_df, out_dir, simul_or_real, factor_label_map=None): + valid_df = profile_df.dropna(subset=["cohens_d"]).copy() + assert ( + not valid_df.empty + ), "No valid Cohen's d values available for top-vs-bottom profile plot" + + factor_label_map = factor_label_map or DEFAULT_FACTOR_LABEL_MAP + valid_df["factor_display"] = ( + valid_df["factor"].map(factor_label_map).fillna(valid_df["factor"].astype(str)) + ) + + factor_order = ( + valid_df.groupby("factor", observed=True)["abs_cohens_d"] + .max() + .sort_values(ascending=True) + .index.tolist() + ) + valid_df = valid_df.sort_values(["factor", "dFC assessment method"]) + factor_display_order = [ + factor_label_map.get(factor, factor) for factor in factor_order + ] + valid_df["factor_display"] = pd.Categorical( + valid_df["factor_display"], categories=factor_display_order, ordered=True + ) + + method_order = sorted(valid_df["dFC assessment method"].astype(str).unique()) + vivid_palette = sns.color_palette("tab10", n_colors=len(method_order)) + method_palette = {m: vivid_palette[i] for i, m in enumerate(method_order)} + + # More generous height keeps rows legible when many factors are shown. + height = max(7.0, 0.72 * len(factor_order)) + width = 16.5 + figure, ax = plt.subplots(figsize=(width, height)) + + # Alternating row bands make it easier to track each factor across methods. + for idx, factor in enumerate(factor_order): + if idx % 2 == 0: + ax.axhspan(idx - 0.5, idx + 0.5, color="#F4F6FA", alpha=0.75, zorder=0) + + sns.stripplot( + data=valid_df, + x="cohens_d", + y="factor_display", + hue="dFC assessment method", + order=factor_display_order, + hue_order=method_order, + palette=method_palette, + dodge=True, + jitter=0.08, + size=8.2, + linewidth=0.85, + edgecolor="white", + alpha=0.98, + ax=ax, + ) + + max_abs = float(np.nanmax(np.abs(valid_df["cohens_d"].values))) + x_pad = max(0.15, 0.12 * max_abs) + x_lim = max_abs + x_pad + + ax.axvline(0.0, color="#1F1F1F", linestyle="--", linewidth=1.5, zorder=3) + ax.set_xlim(-x_lim, x_lim) + ax.set_xlabel("Effect size (Cohen's d): Top 20% vs Bottom 20% within method") + ax.set_ylabel("Factor") + + ax.tick_params(axis="x", labelsize=12) + ax.tick_params(axis="y", labelsize=12) + + # Keep tick labels readable but not too sparse using a "nice number" step. + span = 2.0 * x_lim + target_ticks = 11 # aim for ~11 major ticks across full span + raw_major_step = span / (target_ticks - 1) + if raw_major_step <= 0: + major_step = 0.5 + else: + exponent = np.floor(np.log10(raw_major_step)) + base = 10.0**exponent + fraction = raw_major_step / base + if fraction <= 1.0: + nice_fraction = 1.0 + elif fraction <= 2.0: + nice_fraction = 2.0 + elif fraction <= 2.5: + nice_fraction = 2.5 + elif fraction <= 5.0: + nice_fraction = 5.0 + else: + nice_fraction = 10.0 + major_step = nice_fraction * base + + minor_step = major_step / 2.0 + ax.xaxis.set_major_locator(MultipleLocator(major_step)) + ax.xaxis.set_minor_locator(MultipleLocator(minor_step)) + ax.grid(True, axis="x", which="major", linestyle="-", linewidth=1.0, alpha=0.35) + ax.grid(True, axis="x", which="minor", linestyle="--", linewidth=0.8, alpha=0.2) + + legend = ax.legend( + title="dFC assessment method", + frameon=True, + loc="upper left", + bbox_to_anchor=(1.01, 1.0), + borderaxespad=0, + ) + if legend is not None: + legend.get_title().set_fontsize(12) + for txt in legend.get_texts(): + txt.set_fontsize(11) + + sns.despine(ax=ax, top=True, right=True) + figure.tight_layout(rect=[0.18, 0, 0.83, 1]) + + fig_path = f"{out_dir}/performance_top_bottom_profile_{simul_or_real}.png" + savefig_pub(fig_path) + plt.close(figure) + return fig_path + + +def _get_present_rdoc_order(df, simul_or_real): + domain_order = RDoC_MAP[simul_or_real]["DOMAIN_ORDER"] + present = set(df["RDoC"].dropna().astype(str).unique()) + ordered = [domain for domain in domain_order if domain in present] + remaining = sorted([domain for domain in present if domain not in ordered]) + return ordered + remaining + + +def add_performance_group(df): + df = df.copy() + score = df["classification_balanced_accuracy"].astype(float) + low_thr = score.quantile(0.25) + high_thr = score.quantile(0.75) + assert ( + low_thr < high_thr + ), "Performance-group thresholds collapsed; cannot form 25/50/25 groups" + + df["performance_group"] = pd.cut( + score, + bins=[-np.inf, low_thr, high_thr, np.inf], + labels=PERFORMANCE_GROUP_LABELS, + include_lowest=True, + ) + assert df["performance_group"].notna().all(), "Failed to assign performance groups" + return df, float(low_thr), float(high_thr) + + +def build_rdoc_performance_group_table(df, simul_or_real): + df_grouped, low_thr, high_thr = add_performance_group(df) + rdoc_order = _get_present_rdoc_order(df_grouped, simul_or_real) + assert rdoc_order, "No RDoC values found for RDoC-performance grouping" + + count_table = ( + df_grouped.groupby(["RDoC", "performance_group"], observed=True) + .size() + .unstack(fill_value=0) + .reindex(index=rdoc_order, columns=PERFORMANCE_GROUP_LABELS, fill_value=0) + ) + proportion_table = count_table.div(count_table.sum(axis=1), axis=0) + assert np.isclose( + proportion_table.sum(axis=1), 1.0 + ).all(), "RDoC performance-group proportions do not sum to 1" + + summary_long = ( + count_table.stack() + .rename("count") + .reset_index() + .rename(columns={"level_1": "performance_group"}) + ) + summary_long["proportion"] = [ + proportion_table.loc[row.RDoC, row.performance_group] + for row in summary_long.itertuples(index=False) + ] + summary_long["low_threshold"] = low_thr + summary_long["high_threshold"] = high_thr + return summary_long, count_table, proportion_table + + +def plot_rdoc_performance_group_stacked_bar( + proportion_table, + out_dir, + simul_or_real, + x_label="RDoC domain", + count_table=None, +): + width = max(10.0, 1.6 * len(proportion_table.index)) + figure, ax = plt.subplots(figsize=(width, 7.2)) + + palette = { + "Low": "#D1495B", + "Medium": "#F4D35E", + "High": "#2A9D8F", + } + proportion_pct = proportion_table.mul(100.0) + bottom = np.zeros(len(proportion_pct.index)) + + for label in PERFORMANCE_GROUP_LABELS: + values = proportion_pct[label].to_numpy() + counts = None + if count_table is not None and label in count_table.columns: + counts = count_table[label].to_numpy() + ax.bar( + proportion_pct.index, + values, + bottom=bottom, + label=label, + color=palette[label], + edgecolor="white", + linewidth=1.0, + ) + + # Annotate each stacked segment with sample count. + if counts is not None: + for i, (val, cnt) in enumerate(zip(values, counts)): + if cnt <= 0 or val <= 0: + continue + y = bottom[i] + 0.5 * val + # Skip tiny slivers to avoid clutter. + if val < 5.0: + continue + ax.text( + i, + y, + f"n={int(cnt)}", + ha="center", + va="center", + fontsize=9, + fontweight="bold", + color="#1F1F1F", + ) + + bottom += values + + for label in ax.get_xticklabels(): + label.set_rotation(25) + label.set_horizontalalignment("right") + label.set_fontsize(12) + label.set_fontweight("bold") + + ax.set_xlabel(x_label, fontweight="bold") + ax.set_ylabel("Samples (%)", fontweight="bold") + ax.set_ylim(0, 100) + ax.yaxis.set_major_locator(MultipleLocator(20)) + ax.yaxis.set_minor_locator(MultipleLocator(10)) + ax.grid(True, axis="y", which="major", linestyle="-", alpha=0.34) + ax.grid(True, axis="y", which="minor", linestyle="--", alpha=0.16) + ax.tick_params(axis="y", labelsize=12) + for label in ax.get_yticklabels(): + label.set_fontweight("bold") + legend = ax.legend( + title="Performance group", frameon=True, fontsize=11, title_fontsize=12 + ) + if legend is not None: + legend.get_title().set_fontweight("bold") + for txt in legend.get_texts(): + txt.set_fontweight("bold") + sns.despine(ax=ax, top=True, right=True) + figure.tight_layout() + + fig_path = f"{out_dir}/performance_group_by_rdoc_stacked_{simul_or_real}.png" + savefig_pub(fig_path) + plt.close(figure) + return fig_path + + +def plot_rdoc_performance_group_heatmap( + proportion_table, + out_dir, + simul_or_real, + x_label="Performance group", + count_table=None, +): + if count_table is not None: + count_view = count_table.loc[ + proportion_table.index, PERFORMANCE_GROUP_LABELS + ].astype(int) + annot_table = proportion_table.loc[:, PERFORMANCE_GROUP_LABELS].mul(100.0) + annot_table = annot_table.apply( + lambda col: [ + f"{pct:.1f}%\n(n={cnt})" + for pct, cnt in zip(col.values, count_view[col.name].values) + ] + ) + else: + annot_table = proportion_table.mul(100.0).applymap(lambda value: f"{value:.1f}%") + + figure, ax = plt.subplots(figsize=(8.6, max(5.2, 0.82 * len(proportion_table.index)))) + heatmap = sns.heatmap( + proportion_table.loc[:, PERFORMANCE_GROUP_LABELS], + cmap="crest", + vmin=0.0, + vmax=1.0, + annot=annot_table.loc[:, PERFORMANCE_GROUP_LABELS], + fmt="", + linewidths=0.9, + linecolor="white", + cbar_kws={"shrink": 0.84, "pad": 0.03}, + ax=ax, + ) + colorbar = heatmap.collections[0].colorbar + colorbar.set_label("Proportion", fontweight="bold", fontsize=12) + + ax.set_xlabel(x_label, fontweight="bold") + ax.set_ylabel("RDoC domain", fontweight="bold") + plt.setp(ax.get_xticklabels(), rotation=0, fontsize=12, fontweight="bold") + plt.setp(ax.get_yticklabels(), rotation=0, fontsize=12, fontweight="bold") + ax.set_title("RDoC composition by performance group", pad=12, fontweight="bold") + figure.tight_layout() + + fig_path = f"{out_dir}/performance_group_by_rdoc_heatmap_{simul_or_real}.png" + savefig_pub(fig_path) + plt.close(figure) + return fig_path + + +def plot_rdoc_overall_distribution(df, out_dir, simul_or_real, x_label="RDoC domain"): + rdoc_order = _get_present_rdoc_order(df, simul_or_real) + assert rdoc_order, "No RDoC values found for plotting" + + width = max(12.0, 1.55 * len(rdoc_order)) + height = 7.0 + figure, ax = plt.subplots(figsize=(width, height)) + + palette = sns.color_palette("Spectral", n_colors=len(rdoc_order)) + palette_map = {rdoc: palette[i] for i, rdoc in enumerate(rdoc_order)} + + sns.boxplot( + data=df, + x="RDoC", + y="classification_balanced_accuracy", + order=rdoc_order, + showfliers=False, + width=0.58, + palette=palette_map, + linewidth=1.2, + ax=ax, + ) + sns.stripplot( + data=df, + x="RDoC", + y="classification_balanced_accuracy", + order=rdoc_order, + palette=palette_map, + alpha=0.45, + size=3.1, + jitter=0.2, + ax=ax, + ) + + ax.set_xlabel(x_label, fontweight="bold") + ax.set_ylabel("Balanced accuracy", fontweight="bold") + ax.set_ylim(0.45, 1.02) + plt.setp( + ax.get_xticklabels(), rotation=25, ha="right", fontsize=12, fontweight="bold" + ) + ax.yaxis.set_major_locator(MultipleLocator(0.05)) + ax.yaxis.set_minor_locator(MultipleLocator(0.025)) + ax.tick_params(axis="y", labelsize=12) + for label in ax.get_yticklabels(): + label.set_fontweight("bold") + ax.grid(True, axis="y", which="major", linestyle="-", alpha=0.34) + ax.grid(True, axis="y", which="minor", linestyle="--", alpha=0.18) + sns.despine(ax=ax, top=True, right=True) + figure.tight_layout() + + fig_path = f"{out_dir}/performance_by_rdoc_overall_{simul_or_real}.png" + savefig_pub(fig_path) + plt.close(figure) + return fig_path + + +def plot_rdoc_faceted_distribution(df, out_dir, simul_or_real, x_label="RDoC domain"): + rdoc_order = _get_present_rdoc_order(df, simul_or_real) + assert rdoc_order, "No RDoC values found for plotting" + + combo_df = ( + df[["classifier model", "embedding"]] + .drop_duplicates() + .sort_values(["classifier model", "embedding"]) + ) + assert not combo_df.empty, "No classifier/embedding combinations found for plotting" + + n_methods = df["dFC assessment method"].nunique() + # Generous per-domain width so boxes never feel cramped + n_domains = len(rdoc_order) + # Each domain gets ~2.8 in; keep figure compact for manuscript layouts. + axes_width = max(17.0, 2.8 * n_domains) + # Small right margin only; legend now sits at the top-right of the full figure. + legend_width = 1.6 + total_width = axes_width + legend_width + # Height: keep panels open and readable + height = max(8.5, 0.42 * n_methods + 6.8) + + fig_paths = [] + for _, combo in combo_df.iterrows(): + classifier = combo["classifier model"] + embedding = combo["embedding"] + + sub_df = df[ + (df["classifier model"] == classifier) & (df["embedding"] == embedding) + ].copy() + if sub_df.empty: + continue + + figure, ax = plt.subplots(figsize=(total_width, height)) + + palette = sns.color_palette("tab10", n_colors=n_methods) + hue_order = sorted(sub_df["dFC assessment method"].dropna().astype(str).unique()) + method_palette = {method: palette[i] for i, method in enumerate(hue_order)} + + sns.boxplot( + data=sub_df, + x="RDoC", + y="classification_balanced_accuracy", + hue="dFC assessment method", + order=rdoc_order, + showfliers=False, + width=0.72, + linewidth=1.35, + palette=method_palette, + hue_order=hue_order, + ax=ax, + ) + + ax.set_ylim(0.45, 1.02) + ax.set_xlabel(x_label, labelpad=12, fontsize=14, fontweight="bold") + ax.set_ylabel("Balanced accuracy", labelpad=12, fontsize=14, fontweight="bold") + ax.set_title( + f"{classifier} | {embedding}", + fontweight="bold", + pad=14, + fontsize=15, + ) + ax.tick_params(axis="both", labelsize=12) + ax.yaxis.set_major_locator(MultipleLocator(0.05)) + ax.yaxis.set_minor_locator(MultipleLocator(0.025)) + ax.grid(True, axis="y", which="major", linestyle="-", alpha=0.36) + ax.grid(True, axis="y", which="minor", linestyle="--", alpha=0.20) + for label in ax.get_xticklabels(): + label.set_rotation(30) + label.set_horizontalalignment("right") + label.set_fontsize(13) + label.set_fontweight("bold") + for label in ax.get_yticklabels(): + label.set_fontweight("bold") + + handles, labels = ax.get_legend_handles_labels() + if handles: + ax.get_legend().remove() + figure.legend( + handles, + labels, + title="dFC assessment method", + title_fontsize=12, + fontsize=11, + frameon=True, + loc="upper right", + bbox_to_anchor=(0.995, 0.995), + ) + if figure.legends: + for legend in figure.legends: + legend.get_title().set_fontweight("bold") + for txt in legend.get_texts(): + txt.set_fontweight("bold") + + sns.despine(ax=ax, top=True, right=True) + # Leave a slim top/right margin for the figure-level legend. + figure.tight_layout(rect=[0, 0, 0.94, 0.96]) + + classifier_key = str(classifier).replace(" ", "_").replace("/", "-") + embedding_key = str(embedding).replace(" ", "_").replace("/", "-") + fig_path = ( + f"{out_dir}/performance_by_rdoc_{classifier_key}" + f"_{embedding_key}_{simul_or_real}.png" + ) + plt.savefig(fig_path, bbox_inches="tight", dpi=150) + plt.close(figure) + fig_paths.append(fig_path) + + assert fig_paths, "No RDoC per-combination figures were generated" + return fig_paths + + +def main(): + args = parse_args() + setup_pub_style() + + multi_dataset_info = read_json(args.multi_dataset_info) + paths = get_paths(multi_dataset_info, args.simul_or_real) + + ml_scores_all = load_npy_dict(paths["ml"], "ALL_ML_SCORES") + timing_dict = load_npy_dict(paths["timing"], "task_timing_stats") + cohensd_dict = load_npy_dict(paths["cohensd"], "CohensD_ML") + + df_ml = prepare_ml_df(ml_scores_all) + df_timing = prepare_timing_df(timing_dict) + df_cohensd = prepare_cohensd_df(cohensd_dict) + df_tsnr = None + if args.simul_or_real == "real": + df_tsnr = prepare_tsnr_df(paths["tsnr"]) + + join_keys = choose_join_keys(df_ml, df_timing, df_cohensd, df_tsnr) + print(f"Using join keys: {join_keys}") + + df = merge_with_checks(df_ml, df_timing, df_cohensd, join_keys, df_tsnr) + df = add_rdoc(df, args.simul_or_real) + df = finalize_columns(df) + + csv_path, pkl_path = save_outputs(df, paths["out_dir"], args.simul_or_real) + + corr_df = build_correlation_table(df) + corr_csv_path = ( + f"{paths['out_dir']}/performance_factor_correlations_{args.simul_or_real}.csv" + ) + corr_df.to_csv(corr_csv_path, index=False) + corr_fig_path = plot_factor_correlation_pointplot( + corr_df, paths["out_dir"], args.simul_or_real + ) + + profile_df = build_top_bottom_profile_table(df, quantile=TOP_BOTTOM_QUANTILE) + profile_csv_path = ( + f"{paths['out_dir']}/performance_top_bottom_profile_{args.simul_or_real}.csv" + ) + profile_df.to_csv(profile_csv_path, index=False) + profile_fig_path = plot_top_bottom_profile( + profile_df, + paths["out_dir"], + args.simul_or_real, + factor_label_map=DEFAULT_FACTOR_LABEL_MAP, + ) + + domain_x_label = get_domain_axis_label(args.simul_or_real) + + rdoc_overall_path = plot_rdoc_overall_distribution( + df, + paths["out_dir"], + args.simul_or_real, + x_label=domain_x_label, + ) + rdoc_faceted_paths = plot_rdoc_faceted_distribution( + df, + paths["out_dir"], + args.simul_or_real, + x_label=domain_x_label, + ) + rdoc_group_long_df, rdoc_group_count_table, rdoc_group_prop_table = ( + build_rdoc_performance_group_table(df, args.simul_or_real) + ) + rdoc_group_csv_path = ( + f"{paths['out_dir']}/performance_group_by_rdoc_{args.simul_or_real}.csv" + ) + rdoc_group_long_df.to_csv(rdoc_group_csv_path, index=False) + rdoc_group_bar_path = plot_rdoc_performance_group_stacked_bar( + rdoc_group_prop_table, + paths["out_dir"], + args.simul_or_real, + x_label=domain_x_label, + count_table=rdoc_group_count_table, + ) + rdoc_group_heatmap_path = plot_rdoc_performance_group_heatmap( + rdoc_group_prop_table, + paths["out_dir"], + args.simul_or_real, + count_table=rdoc_group_count_table, + ) + + print(f"Saved dataframe with shape: {df.shape}") + print(f"CSV: {csv_path}") + print(f"PKL: {pkl_path}") + print(f"Correlation CSV: {corr_csv_path}") + print(f"Correlation figure: {corr_fig_path}") + print(f"Top-bottom profile CSV: {profile_csv_path}") + print(f"Top-bottom profile figure: {profile_fig_path}") + print(f"RDoC overall figure: {rdoc_overall_path}") + print(f"RDoC per-combination figures: {len(rdoc_faceted_paths)} files") + print(f"RDoC performance-group CSV: {rdoc_group_csv_path}") + print(f"RDoC performance-group stacked bar: {rdoc_group_bar_path}") + print(f"RDoC performance-group heatmap: {rdoc_group_heatmap_path}") + + +if __name__ == "__main__": + main() diff --git a/task_dFC/multi_dataset_analysis/performance_predict.py b/task_dFC/multi_dataset_analysis/performance_predict.py new file mode 100644 index 0000000..36ea679 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/performance_predict.py @@ -0,0 +1,527 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from pydfc import data_loader +from pydfc.data_loader import find_subj_list +from pydfc.ml_utils import find_available_subjects, load_task_data +from pydfc.task_utils import ( + calc_relative_task_on, + calc_rest_duration, + calc_task_duration, + calc_transition_freq, + cohen_d_bold, + compute_optimality_index, + compute_periodicity_index, + extract_task_presence, + periodicity_autocorr, +) + +fig_bbox_inches = "tight" +fig_pad = 0.1 +show_title = False +save_fig_format = "png" # pdf, png, + +level = "group_lvl" +keys_not_to_include = [ + "Logistic regression permutation p_value", + "Logistic regression permutation score mean", + "Logistic regression permutation score std", + "SVM permutation p_value", + "SVM permutation score mean", + "SVM permutation score std", +] + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to predict performance based on task design features and BOLD signals across multiple datasets. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + output_root = ( + f"{multi_dataset_info['output_root']}/performance_predictor/{simul_or_real}" + ) + + if not os.path.exists(output_root): + os.makedirs(output_root) + + task_ratio_all = {} + transition_freq_all = {} + rest_durations_original_all = {} + task_durations_original_all = {} + rest_durations_all = {} + task_durations_all = {} + PI_all = {} + OI_all = {} + PAC_all = {} + for dataset in DATASETS: + + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for session in SESSIONS: + for task_id, task in enumerate(TASKS): + if not task in TASKS_to_include: + continue + for run in RUNS[task]: + SUBJECTS = find_subj_list(roi_root) + # print(f"Number of subjects: {len(SUBJECTS)}") + + for subj in SUBJECTS: + + try: + task_data = load_task_data( + roi_root=roi_root, + subj=subj, + task=task, + run=run, + session=session, + ) + except FileNotFoundError: + continue + + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=1 / task_data["Fs_task"], + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + no_hrf=False, + ) + + relative_task_on = calc_relative_task_on(task_presence[indices]) + num_of_transitions, relative_transition_freq = ( + calc_transition_freq(task_presence[indices]) + ) + # calculate rest and task durations based on original event labels + event_labels = np.multiply(task_data["event_labels"] != 0, 1) + rest_durations_original = calc_rest_duration( + event_labels, TR_mri=1 / task_data["Fs_task"] + ) + task_durations_original = calc_task_duration( + event_labels, TR_mri=1 / task_data["Fs_task"] + ) + # calculate rest and task durations based on binary task presence + rest_durations = calc_rest_duration( + task_presence[indices], TR_mri=task_data["TR_mri"] + ) + task_durations = calc_task_duration( + task_presence[indices], TR_mri=task_data["TR_mri"] + ) + # Periodicity Index (low entropy => high periodicity) + out = compute_periodicity_index( + event_labels=event_labels, + TR_task=1 / task_data["Fs_task"], + no_hrf=False, + ) + PI = out["periodicity_index"] + + # Optimality Index (how close the task design is to the optimal design) + out = compute_optimality_index( + event_labels=event_labels, + TR_task=1 / task_data["Fs_task"], + TR_mri=task_data["TR_mri"], + ) + OI = out["OI_norm"] + + # Periodicity via autocorrelation + out = periodicity_autocorr( + event_labels=event_labels, + TR_task=1 / task_data["Fs_task"], + ) + PAC = out["periodicity"] + + if not task in task_ratio_all: + task_ratio_all[task] = [] + if not task in transition_freq_all: + transition_freq_all[task] = [] + if not task in rest_durations_original_all: + rest_durations_original_all[task] = [] + if not task in task_durations_original_all: + task_durations_original_all[task] = [] + if not task in rest_durations_all: + rest_durations_all[task] = [] + if not task in task_durations_all: + task_durations_all[task] = [] + if not task in PI_all: + PI_all[task] = [] + if not task in OI_all: + OI_all[task] = [] + if not task in PAC_all: + PAC_all[task] = [] + task_ratio_all[task].append(relative_task_on) + transition_freq_all[task].append(relative_transition_freq) + # rest_durations and task_durations are lists + rest_durations_original_all[task].extend(rest_durations_original) + task_durations_original_all[task].extend(task_durations_original) + rest_durations_all[task].extend(rest_durations) + task_durations_all[task].extend(task_durations) + PI_all[task].append(PI) + OI_all[task].append(OI) + PAC_all[task].append(PAC) + + task_design_features = { + "task_ratio_all": task_ratio_all, + "transition_freq_all": transition_freq_all, + "rest_durations_original_all": rest_durations_original_all, + "task_durations_original_all": task_durations_original_all, + "rest_durations_all": rest_durations_all, + "task_durations_all": task_durations_all, + "PI_all": PI_all, + "OI_all": OI_all, + "PAC_all": PAC_all, + } + + CohensD_across_task = {} + for dataset in DATASETS: + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for task in TASKS: + if task not in TASKS_to_include: + print(f"Skipping task {task} as it's not in the inclusion list.") + continue + d_values_all = [] + for session in SESSIONS: + print(f"Processing task: {task}") + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + dFC_id=None, + session=session, + ) + for subj in SUBJECTS: + for run in RUNS[task]: + try: + task_data = load_task_data( + roi_root=roi_root, + subj=subj, + task=task, + run=run, + session=session, + ) + except: + continue + + if run is None: + if session is None: + BOLD_file_name = "{subj_id}_{task}_time-series.npy" + else: + BOLD_file_name = ( + "{subj_id}_{session}_{task}_time-series.npy" + ) + else: + if session is None: + BOLD_file_name = "{subj_id}_{task}_{run}_time-series.npy" + else: + BOLD_file_name = ( + "{subj_id}_{session}_{task}_{run}_time-series.npy" + ) + try: + BOLD = data_loader.load_TS( + data_root=roi_root, + file_name=BOLD_file_name, + subj_id2load=subj, + task=task, + session=session, + run=run, + ) + except Exception as e: + print(f"Error loading BOLD data: {e}") + continue + BOLD_data = BOLD.data # np.ndarray (n_ROIs, n_TRs) + + Fs_task = task_data["Fs_task"] + TR_task = 1 / Fs_task + + TR_array = np.arange(0, BOLD_data.shape[1]) + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=TR_task, + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + no_hrf=False, + TR_array=TR_array, + ) + + # if n_TRs do not match, align them + if BOLD_data.shape[1] != task_presence.shape[0]: + print( + f"Before alignment, shape of task_presence: {task_presence.shape}, shape of BOLD_data: {BOLD_data.shape}" + ) + min_TRs = min(BOLD_data.shape[1], task_presence.shape[0]) + task_presence = task_presence[:min_TRs] + BOLD_data = BOLD_data[:, :min_TRs] + print( + f"After alignment, shape of task_presence: {task_presence.shape}, shape of BOLD_data: {BOLD_data.shape}" + ) + # also adjust indices + indices = [i for i in indices if i < min_TRs] + task_presence = task_presence[indices] # (n_TRs,) + BOLD_data = BOLD_data[:, indices] # (n_ROIs, n_TRs) + + assert BOLD_data.shape[1] == task_presence.shape[0] + + cohen_d = cohen_d_bold(X=BOLD_data.T, y=task_presence) + d_values_all.append(cohen_d) + + if len(d_values_all) == 0: + print(f"No data found for task {task} in dataset {dataset}. Skipping.") + continue + d_values_all = np.array(d_values_all) # (n_subjectsxrunsxsessions, n_ROIs) + avg_d_values = np.nanmean(d_values_all, axis=0) # (n_ROIs,) + if not task in CohensD_across_task: + CohensD_across_task[task] = [] + CohensD_across_task[task].extend(avg_d_values) + + ALL_ML_SCORES = None + for dataset in DATASETS: + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + ML_root = f"{main_root}/{dataset}/derivatives/ML" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + # find all ML_scores_classify_dFC-id.npy in the ML_root/classfication/ folder + # for now we will only use the first session + session = SESSIONS[0] + if session is None: + input_dir = f"{ML_root}/classification" + else: + input_dir = f"{ML_root}/classification/{session}" + if not os.path.exists(input_dir): + print( + f"Input directory {input_dir} does not exist. Skipping dataset {dataset}." + ) + continue + ALL_ML_SCORES_FILES = os.listdir(input_dir) + ALL_ML_SCORES_FILES = [ + f for f in ALL_ML_SCORES_FILES if "ML_scores_classify_" in f + ] + for f in ALL_ML_SCORES_FILES: + try: + ML_scores_new = np.load(f"{input_dir}/{f}", allow_pickle=True).item() + # ML_scores_new_updated is a new dictionary with same keys as ML_scores_new but empty lists + ML_scores_new_updated = { + key: [] + for key in ML_scores_new[level].keys() + if key not in keys_not_to_include + } + for i in range(len(ML_scores_new[level]["task"])): + if task not in TASKS_to_include: + continue + + for key in ML_scores_new_updated.keys(): + ML_scores_new_updated[key].append(ML_scores_new[level][key][i]) + + if ALL_ML_SCORES is None: + ALL_ML_SCORES = ML_scores_new_updated + else: + for key in ML_scores_new_updated.keys(): + if key in ALL_ML_SCORES: + ALL_ML_SCORES[key].extend(ML_scores_new_updated[key]) + except Exception as e: + print(f"Error loading {f}: {e}") + continue + + # check that the lists in all keys have the same length + if ALL_ML_SCORES is not None: + lengths = [len(v) for v in ALL_ML_SCORES.values()] + if len(set(lengths)) != 1: + print( + f"Warning: Not all keys have the same length in ALL_ML_SCORES. key and length pairs: {dict(zip(ALL_ML_SCORES.keys(), lengths))}" + ) + + embedding = "LE" + metric = "SVM balanced accuracy" + GROUP = "test" + + METHODS = set(ALL_ML_SCORES["dFC method"]) + all_scores = {method: {} for method in METHODS} + for method in METHODS: + for i in range(len(ALL_ML_SCORES["task"])): + if ( + ALL_ML_SCORES["embedding"][i] == embedding + and ALL_ML_SCORES["group"][i] == GROUP + and ALL_ML_SCORES["dFC method"][i] == method + ): + if ALL_ML_SCORES["task"][i] not in all_scores[method]: + all_scores[method][ALL_ML_SCORES["task"][i]] = [] + all_scores[method][ALL_ML_SCORES["task"][i]].append( + ALL_ML_SCORES[metric][i] + ) + + # all_scores[][] is a list of scores across runs + for method in all_scores: + all_scores[method] = {k: np.array(v) for k, v in all_scores[method].items()} + + # we have task design features in task_design_features[task_ratio_all][task], task_design_features[transition_freq_all][task], task_design_features[rest_durations_all][task], task_design_features[task_durations_all][task] + # we have CohensD in CohensD_across_task[task] + # we have ML scores in all_scores[task] + + DATA = { + "task": [], + "task_ratio": [], + "transition_freq": [], + "rest_durations_original_mean": [], + "task_durations_original_mean": [], + "rest_durations_original_std": [], + "task_durations_original_std": [], + "rest_durations_mean": [], + "task_durations_mean": [], + "rest_durations_std": [], + "task_durations_std": [], + "PI_mean": [], + "OI_mean": [], + "PAC_mean": [], + "cohen_d_max": [], + } + for task in TASKS_to_include: + task_ratio = np.mean(task_design_features["task_ratio_all"][task]) + transition_freq = np.mean(task_design_features["transition_freq_all"][task]) + rest_durations_original_mean = np.mean( + task_design_features["rest_durations_original_all"][task] + ) + task_durations_original_mean = np.mean( + task_design_features["task_durations_original_all"][task] + ) + rest_durations_original_std = np.std( + task_design_features["rest_durations_original_all"][task] + ) + task_durations_original_std = np.std( + task_design_features["task_durations_original_all"][task] + ) + rest_durations_mean = np.mean(task_design_features["rest_durations_all"][task]) + task_durations_mean = np.mean(task_design_features["task_durations_all"][task]) + rest_durations_std = np.std(task_design_features["rest_durations_all"][task]) + task_durations_std = np.std(task_design_features["task_durations_all"][task]) + PI_mean = np.mean(PI_all[task]) + OI_mean = np.mean(OI_all[task]) + PAC_mean = np.mean(PAC_all[task]) + cohen_d_max = np.max(np.abs(CohensD_across_task[task])) + + DATA["task"].append(task) + DATA["task_ratio"].append(task_ratio) + DATA["transition_freq"].append(transition_freq) + DATA["rest_durations_original_mean"].append(rest_durations_original_mean) + DATA["task_durations_original_mean"].append(task_durations_original_mean) + DATA["rest_durations_original_std"].append(rest_durations_original_std) + DATA["task_durations_original_std"].append(task_durations_original_std) + DATA["rest_durations_mean"].append(rest_durations_mean) + DATA["task_durations_mean"].append(task_durations_mean) + DATA["rest_durations_std"].append(rest_durations_std) + DATA["task_durations_std"].append(task_durations_std) + DATA["PI_mean"].append(PI_mean) + DATA["OI_mean"].append(OI_mean) + DATA["PAC_mean"].append(PAC_mean) + DATA["cohen_d_max"].append(cohen_d_max) + + # Also add ML scores + for method in all_scores: + if f"classfication_score_{method}" not in DATA: + DATA[f"classfication_score_{method}"] = [] + if task in all_scores[method]: + score_mean = np.mean(all_scores[method][task]) + else: + score_mean = np.nan + DATA[f"classfication_score_{method}"].append(score_mean) + + # save DATA + np.save(f"{output_root}/performance_predictor_data.npy", DATA) diff --git a/task_dFC/multi_dataset_analysis/sample_matrix_visualization.py b/task_dFC/multi_dataset_analysis/sample_matrix_visualization.py new file mode 100644 index 0000000..7f3bfcd --- /dev/null +++ b/task_dFC/multi_dataset_analysis/sample_matrix_visualization.py @@ -0,0 +1,321 @@ +import argparse +import json +import os +import sys + +import numpy as np + +from pydfc.ml_utils import ( + PLSEmbedder, + dFC_feature_extraction, + find_available_subjects, + process_SB_features, + select_num_components_binary_groupcv, + subject_center, +) + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + plot_samples_features, + save_scalar_colorbar, +) + +use_raw_features = False # if True, use raw dFC features instead of embedded features +normalize_dFC = False +FCS_proba_for_SB = True +train_test_ratio = 0.8 + +if use_raw_features: + raw_or_embedded = "_raw" +else: + raw_or_embedded = "" + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to visualize the feature-sample matrix for each dataset, task, and dFC measure. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + # TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] # temporary !!!! + TASKS_to_include = [ + "task-Axcpt", + "task-CIC", + "task-Cuedts", + "task-feedback", + "task-IHG", + "task-matching", + "task-motor", + "task-Stern", + "task-Stroop", + ] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + + output_root = f"{multi_dataset_info['output_root']}/feature-sample/{simul_or_real}" + if not os.path.exists(output_root): + os.makedirs(output_root) + + for dataset in DATASETS: + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for dFC_id in range(0, 7): + DATA = {} + for session in SESSIONS[:1]: # Only process the first session + for task_id, task in enumerate(TASKS): + if task not in TASKS_to_include: + print(f"Skipping task: {task} as it is not in TASKS_to_include.") + continue + for run in RUNS[task][:1]: # Only process the first run + print( + f"Processing dataset: {dataset}, task: {task}, run: {run}, session: {session}, dFC_id: {dFC_id}" + ) + + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + run=run, + session=session, + dFC_id=dFC_id, + ) + + if task == "task-paingen": + # due to computational load, only use 100 subjects for this task + SUBJECTS = SUBJECTS[:100] + + if ( + task == "task-lowFreqLongRest" + or task == "task-lowFreqShortRest" + or task == "task-lowFreqShortTask" + ): + # due to computational load, only use 100 subjects for this task + SUBJECTS = SUBJECTS[:100] + + # randomly select train_test_ratio of the subjects for training + # and rest for testing using numpy.random.choice + train_subjects = np.random.choice( + SUBJECTS, int(train_test_ratio * len(SUBJECTS)), replace=False + ) + test_subjects = np.setdiff1d(SUBJECTS, train_subjects) + print( + f"Number of train subjects: {len(train_subjects)} and test subjects: {len(test_subjects)}" + ) + + ( + X_train, + X_test, + y_train, + y_test, + subj_label_train, + subj_label_test, + measure_name, + measure_is_state_based, + ) = dFC_feature_extraction( + task=task, + train_subjects=train_subjects, + test_subjects=test_subjects, + dFC_id=dFC_id, + roi_root=roi_root, + dFC_root=dFC_root, + run=run, + session=session, + dynamic_pred="no", + normalize_dFC=normalize_dFC, + FCS_proba_for_SB=FCS_proba_for_SB, # for state-based dFC features, we use FCS_proba + ) + + if measure_name is None: + print( + f"Skipping dataset: {dataset}, task: {task}, run: {run}, session: {session}, dFC_id: {dFC_id} due to no measure_name." + ) + continue + + if measure_is_state_based: + X_train = process_SB_features( + X=X_train, measure_name=measure_name + ) + X_test = process_SB_features( + X=X_test, measure_name=measure_name + ) + # center the data by subject before embedding to remove subject effects + # separately for train and test sets to avoid data leakage + # for both state-based and state-free methods + X_train_centered = subject_center( + X_train, subj_label_train, mode="demean" + ) + X_test_centered = subject_center( + X_test, subj_label_test, mode="demean" + ) + if not measure_is_state_based: + # embed dFC features using PLS regression, which is a supervised embedding method that finds the components that best explain the variance in the labels + best_n, _ = select_num_components_binary_groupcv( + X=X_train_centered, + y=y_train, + groups=subj_label_train, + embedding_method="PLS", + n_list=[ + 2, + 3, + 4, + 5, + 10, + 15, + 20, + 25, + 30, + 40, + 50, + ], # you can adjust this range based on your data + cv=5, # more stable + ) + pls = PLSEmbedder(n_components=best_n, scale=True) + # fit on train set + X_train_embedded = pls.fit_transform( + X_train_centered, y_train + ) + assert ( + X_train_embedded.shape[0] == y_train.shape[0] + ), "Number of samples do not match." + # only transform test set + if X_test is not None: + X_test_embedded = pls.transform(X_test_centered) + assert ( + X_test_embedded.shape[0] == y_test.shape[0] + ), "Number of samples do not match." + else: + X_test_embedded = None + else: + # for state-based measures, we skip the embedding step and just use the original features + X_train_embedded = X_train + X_test_embedded = X_test + + assert ( + task not in DATA + ), f"Task {task} already exists in DATA. Overwriting." + DATA[task] = { + "X_train": X_train, + "X_test": X_test, + "X_train_embedded": X_train_embedded, + "X_test_embedded": X_test_embedded, + "y_train": y_train, + "y_test": y_test, + "subj_label_train": subj_label_train, + "subj_label_test": subj_label_test, + "measure_name": measure_name, + } + # save the data + # save each task in a separate file and name the file as the task name, measure name, and dataset name + for task in DATA.keys(): + if use_raw_features: + X_train = DATA[task]["X_train"] + X_test = DATA[task]["X_test"] + else: + X_train = DATA[task]["X_train_embedded"] + X_test = DATA[task]["X_test_embedded"] + y_train = DATA[task]["y_train"] + y_test = DATA[task]["y_test"] + subj_label_train = DATA[task]["subj_label_train"] + subj_label_test = DATA[task]["subj_label_test"] + measure_name = DATA[task]["measure_name"] + + if X_train is None or X_test is None: + print(f"Skipping task {task} due to embedding error.") + continue + + if not os.path.exists(f"{output_root}/processed_data"): + os.makedirs(f"{output_root}/processed_data") + np.save( + f"{output_root}/processed_data/{dataset}_{task}_{measure_name}.npy", + DATA[task], + ) + + for group, X, y in zip( + ["train", "test"], [X_train, X_test], [y_train, y_test] + ): + # if the folder does not exist, create it + if not os.path.exists(f"{output_root}/{measure_name}"): + os.makedirs(f"{output_root}/{measure_name}") + + # A) Unsorted (your first vis, but rotated so time is horizontal) + plot_samples_features( + X, + y, + sample_order="original", + feature_order="original", + save_path=f"{output_root}/{measure_name}/feature-sample_{simul_or_real}_unsorted_{task}_{group}{raw_or_embedded}.png", + show=False, + ) + + # B) Label-sorted (your third vis) + plot_samples_features( + X, + y, + sample_order="label", + feature_order="original", + save_path=f"{output_root}/{measure_name}/feature-sample_{simul_or_real}_sorted-label_{task}_{group}{raw_or_embedded}.png", + show=False, + ) + + # C) clustering + plot_samples_features( + X, + y, + sample_order="cluster", + feature_order="original", + save_path=f"{output_root}/{measure_name}/feature-sample_{simul_or_real}_clustered-samples_{task}_{group}{raw_or_embedded}.png", + show=False, + ) + + save_scalar_colorbar( + cmap="coolwarm", + vmin=-1.6, + vmax=1.6, # use the same V_RANGE you use in plots + label="z-scored feature value", + filename=f"{output_root}/zscore_colorbar.png", + ) diff --git a/task_dFC/multi_dataset_analysis/task_presence_binarization.py b/task_dFC/multi_dataset_analysis/task_presence_binarization.py new file mode 100644 index 0000000..dac9547 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/task_presence_binarization.py @@ -0,0 +1,234 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np + +from pydfc.ml_utils import find_available_subjects, load_task_data +from pydfc.task_utils import extract_task_presence + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + build_experiment_display_info, +) + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to visualize task timing and binarization results for multiple datasets. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + print("Multi-Dataset Analysis started ...") + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + output_root = f"{multi_dataset_info['output_root']}/task_timing/{simul_or_real}" + + if not os.path.exists(output_root): + os.makedirs(output_root) + + _, task_to_experiment, _, _ = build_experiment_display_info( + tasks_iterable=TASKS_to_include, + task_reference_order=TASKS_to_include, + simul_or_real=simul_or_real, + ) + + for dataset in DATASETS: + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for task in TASKS: + features_all = None + for session in SESSIONS: + print(f"Processing task: {task}") + SUBJECTS = find_available_subjects( + dFC_root=dFC_root, + task=task, + dFC_id=None, + session=session, + ) + subj = SUBJECTS[0] + + run = RUNS[task][0] + try: + task_data = load_task_data( + roi_root=roi_root, subj=subj, task=task, run=run, session=session + ) + except: + continue + + stimulus_timing = np.multiply(task_data["event_labels"] != 0, 1) + + event_labels_all_task_hrf, _ = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=1 / task_data["Fs_task"], + TR_mri=1 / task_data["Fs_task"], + TR_array=None, + binary=False, + binarizing_method="GMM", + no_hrf=False, + ) + + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=1 / task_data["Fs_task"], + TR_mri=1 / task_data["Fs_task"], + TR_array=None, + binary=True, + binarizing_method="GMM", + no_hrf=False, + ) + # plot event_labels_all_task_hrf + plt.figure(figsize=(250, 10)) + print( + f"Fs_task: {task_data['Fs_task']}, TR_mri: {task_data['TR_mri']}, length of event_labels_all_task_hrf: {len(event_labels_all_task_hrf)}" + ) + plt.plot( + stimulus_timing, + label="Stimulus Timing", + color="#B8AD6F", + linewidth=15, + ) + # plt.plot(task_presence, label="Task Presence", color="blue", linewidth=3) + plt.plot( + event_labels_all_task_hrf, + label="HRF Convolved", + color="#010101", + linewidth=8, + ) + # plot a vertical dashed line at every TR_mri + for i in range( + 0, + len(event_labels_all_task_hrf), + int(task_data["TR_mri"] * task_data["Fs_task"]), + ): + plt.axvline(x=i, color="#c20707", linestyle="--", linewidth=5.0) + + # on_indices are index in indices where task_presence=1 + on_indices = indices[task_presence[indices] == 1] + # off_indices are index in indices where task_presence=0 + off_indices = indices[task_presence[indices] == 0] + plt.scatter( + on_indices, + event_labels_all_task_hrf[on_indices], + color="#7ab3dc", + label="on_indices", + s=300, + zorder=10, + ) + plt.scatter( + off_indices, + event_labels_all_task_hrf[off_indices], + color="#A8ACAD", + label="off_indices", + s=300, + zorder=10, + ) + + # remove all axis and spines, show only x axis + plt.gca().spines["top"].set_visible(False) + plt.gca().spines["right"].set_visible(False) + plt.gca().spines["left"].set_visible(False) + plt.gca().spines["bottom"].set_visible(True) + # increase bottom spine width + plt.gca().spines["bottom"].set_linewidth(5) + plt.gca().yaxis.set_visible(False) + plt.gca().xaxis.set_visible(True) + + # # set background color to lite pink + # plt.gca().set_facecolor("#F7EFEF") + + # set x ticks to be every TR_mri + step_factor = 1 + # if the length of event_labels_all_task_hrf > 6500, set step_factor to 5 + # to make the plot less crowded + if len(event_labels_all_task_hrf) > 6500: + step_factor = ( + np.ceil(len(event_labels_all_task_hrf) / 6500).astype(int) + 1 + ) + step = int( + round(task_data["TR_mri"] * task_data["Fs_task"] * step_factor) + ) + step = max(step, 1) # avoid step=0 + + ticks = np.arange(0, len(event_labels_all_task_hrf), step) + plt.gca().set_xticks(ticks) + + TR_labels = np.arange( + len(ticks) * step_factor, step=step_factor + ) # same length as ticks + # label each tick as time in seconds, TR_labels*TR_mri + time_labels = np.round(TR_labels * task_data["TR_mri"]).astype(int) + plt.gca().set_xticklabels(time_labels, fontsize=50) + plt.xlabel("Time (sec)", fontsize=60) + + experiment_label = task_to_experiment.get(task, task) + experiment_key = str(experiment_label).replace(" ", "_").replace("/", "-") + + plt.savefig( + f"{output_root}/task_timing_{experiment_key}_{task}.png", + dpi=120, + bbox_inches="tight", + pad_inches=0.1, + format="png", + ) + if task == "task-Localizer": + plt.savefig( + f"{output_root}/task_timing_{experiment_key}_{task}.svg", + dpi=120, + bbox_inches="tight", + pad_inches=0.1, + format="svg", + ) + + plt.close() diff --git a/task_dFC/multi_dataset_analysis/task_timing_stats.py b/task_dFC/multi_dataset_analysis/task_timing_stats.py new file mode 100644 index 0000000..b485b73 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/task_timing_stats.py @@ -0,0 +1,521 @@ +import argparse +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from pydfc.data_loader import find_subj_list +from pydfc.ml_utils import load_task_data +from pydfc.task_utils import ( + calc_relative_task_on, + calc_rest_duration, + calc_task_duration, + calc_transition_freq, + compute_optimality_index, + extract_task_presence, +) + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from helper_functions import ( # pyright: ignore[reportMissingImports] + annotate_medians_by_geometry, + annotate_medians_single_boxplot, + as_long_df, + build_experiment_display_info, + order_by_median_dict, + setup_pub_style, +) + +fig_bbox_inches = "tight" +fig_pad = 0.1 +show_title = False +save_fig_format = "png" # pdf, png, + +####################################################################################### + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to analyze and visualize task timing statistics across multiple datasets. + """ + + setup_pub_style() + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument( + "--multi_dataset_info", type=str, help="path to multi-dataset info file" + ) + parser.add_argument( + "--simul_or_real", type=str, help="Specify 'simulated' or 'real' data" + ) + + args = parser.parse_args() + + multi_dataset_info = args.multi_dataset_info + simul_or_real = args.simul_or_real + + # Read dataset info + with open(multi_dataset_info, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + main_root = multi_dataset_info["real_data"]["main_root"] + DATASETS = multi_dataset_info["real_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["real_data"]["TASKS_to_include"] + elif simul_or_real == "simulated": + main_root = multi_dataset_info["simulated_data"]["main_root"] + DATASETS = multi_dataset_info["simulated_data"]["DATASETS"] + TASKS_to_include = multi_dataset_info["simulated_data"]["TASKS_to_include"] + output_root = f"{multi_dataset_info['output_root']}/task_timing_stats/{simul_or_real}" + + if not os.path.exists(output_root): + os.makedirs(output_root) + + task_ratio_all = {} + transition_freq_all = {} + rest_durations_all = {} + task_durations_all = {} + OI_all = {} + DATA = { + "task": [], + "run": [], + "dataset": [], + "task_ratio_avg": [], + "transition_freq_avg": [], + "rest_durations_median": [], + "task_durations_median": [], + "rest_durations_iqr": [], + "task_durations_iqr": [], + "OI_median": [], + } + for dataset in DATASETS: + + print(f"Processing dataset: {dataset}") + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + TASKS = dataset_info["TASKS"] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + for session in SESSIONS[:1]: # process only the first session if multiple exist + for task_id, task in enumerate(TASKS): + if not task in TASKS_to_include: + continue + for run in RUNS[task]: + + task_ratio_run = [] + transition_freq_run = [] + rest_durations_run = [] + task_durations_run = [] + OI_run = [] + + SUBJECTS = find_subj_list(roi_root) + # print(f"Number of subjects: {len(SUBJECTS)}") + + for subj in SUBJECTS: + + try: + task_data = load_task_data( + roi_root=roi_root, + subj=subj, + task=task, + run=run, + session=session, + ) + except FileNotFoundError: + continue + + task_presence, indices = extract_task_presence( + event_labels=task_data["event_labels"], + TR_task=1 / task_data["Fs_task"], + TR_mri=task_data["TR_mri"], + binary=True, + binarizing_method="GMM", + no_hrf=False, + ) + + relative_task_on = calc_relative_task_on(task_presence[indices]) + num_of_transitions, relative_transition_freq = ( + calc_transition_freq(task_presence[indices]) + ) + # calculate rest and task durations based original event labels + event_labels = np.multiply(task_data["event_labels"] != 0, 1) + rest_durations = calc_rest_duration( + event_labels, TR_mri=1 / task_data["Fs_task"] + ) + task_durations = calc_task_duration( + event_labels, TR_mri=1 / task_data["Fs_task"] + ) + # calculate Optimality Index + out = compute_optimality_index( + event_labels=event_labels, + TR_task=1 / task_data["Fs_task"], + TR_mri=task_data["TR_mri"], + ) + OI = out["OI_norm"] + + task_ratio_run.append(relative_task_on) + transition_freq_run.append(relative_transition_freq) + rest_durations_run.extend(rest_durations) + task_durations_run.extend(task_durations) + OI_run.append(OI) + + # Aggregate stats across runs for this task and store in the all-run dictionaries for later plotting + if not task in task_ratio_all: + task_ratio_all[task] = [] + if not task in transition_freq_all: + transition_freq_all[task] = [] + if not task in rest_durations_all: + rest_durations_all[task] = [] + if not task in task_durations_all: + task_durations_all[task] = [] + if not task in OI_all: + OI_all[task] = [] + task_ratio_all[task].extend(task_ratio_run) + transition_freq_all[task].extend(transition_freq_run) + rest_durations_all[task].extend(rest_durations_run) + task_durations_all[task].extend(task_durations_run) + OI_all[task].extend(OI_run) + + # Aggregate run-level stats for this task and store in DATA for potential further analysis + DATA["task"].append(task) + DATA["run"].append(run) + DATA["dataset"].append(dataset) + DATA["task_ratio_avg"].append(np.nanmean(task_ratio_run)) + DATA["transition_freq_avg"].append(np.nanmean(transition_freq_run)) + DATA["rest_durations_median"].append(np.nanmedian(rest_durations_run)) + DATA["task_durations_median"].append(np.nanmedian(task_durations_run)) + q75_rest, q25_rest = np.percentile(rest_durations_run, [75, 25]) + iqr_rest = q75_rest - q25_rest + q75_task, q25_task = np.percentile(task_durations_run, [75, 25]) + iqr_task = q75_task - q25_task + DATA["rest_durations_iqr"].append(iqr_rest) + DATA["task_durations_iqr"].append(iqr_task) + DATA["OI_median"].append(np.nanmedian(OI_run)) + + np.save(f"{output_root}/task_timing_stats_{simul_or_real}.npy", DATA) + + all_tasks_present = sorted( + set(task_ratio_all) + | set(transition_freq_all) + | set(rest_durations_all) + | set(task_durations_all) + | set(OI_all) + ) + _, task_to_experiment, _, _ = build_experiment_display_info( + tasks_iterable=all_tasks_present, + task_reference_order=TASKS_to_include, + simul_or_real=simul_or_real, + ) + + # ========================= + # Paper-quality seaborn plots (patched) + # ========================= + + sns.set_theme(context="paper", style="darkgrid") + plt.rcParams.update( + { + "figure.dpi": 300, + "savefig.dpi": 500, + "pdf.fonttype": 42, + "ps.fonttype": 42, + "axes.labelweight": "bold", + "axes.titleweight": "bold", + "axes.labelsize": 14, + "axes.titlesize": 16, + "xtick.labelsize": 11, + "ytick.labelsize": 11, + "legend.fontsize": 12, + } + ) + + # ============================== + # 1) Task ratio (sorted by median) — BOX PLOT + median labels + # ============================== + order_ratio, stats_ratio = order_by_median_dict(task_ratio_all, reverse=True) + df_ratio = as_long_df(task_ratio_all, "task_ratio") + df_ratio = df_ratio[df_ratio["task"].isin(order_ratio)] + order_ratio_exp = [task_to_experiment[task] for task in order_ratio] + df_ratio["experiment"] = df_ratio["task"].map(task_to_experiment) + df_ratio["experiment"] = pd.Categorical( + df_ratio["experiment"], categories=order_ratio_exp, ordered=True + ) + + fig_w = max(15, 15 / 30 * len(order_ratio)) + plt.figure(figsize=(fig_w, 6)) + + ax = sns.boxplot( + data=df_ratio, + x="experiment", + y="task_ratio", + order=order_ratio_exp, + width=0.6, + linewidth=1, + showfliers=False, + ) + + ax.set_xlabel("Experiment") + ax.set_ylabel("Task ratio") + ax.set_ylim(0, 1) # keep ratios bounded + + # annotate medians (use integers if you prefer: fmt="{:.0f}") + annotate_medians_single_boxplot( + ax, + df_ratio, + x_col="experiment", + y_col="task_ratio", + order=order_ratio_exp, + fmt="{:.2f}", + box_alpha=0.6, + ) + + for label in ax.get_xticklabels(): + label.set_rotation(65) + label.set_horizontalalignment("right") + label.set_fontweight("bold") + if show_title: + ax.set_title("Task ratio per task (box + samples, ordered by median)", pad=12) + + plt.tight_layout() + plt.savefig( + f"{output_root}/task_ratio_{simul_or_real}.{save_fig_format}", + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + ) + plt.close() + + # ====================================== + # 2) Transition frequency (sorted by median) — BOX PLOT + median labels + # ====================================== + order_tf, stats_tf = order_by_median_dict(transition_freq_all, reverse=True) + df_tf = as_long_df(transition_freq_all, "transition_freq") + df_tf = df_tf[df_tf["task"].isin(order_tf)] + order_tf_exp = [task_to_experiment[task] for task in order_tf] + df_tf["experiment"] = df_tf["task"].map(task_to_experiment) + df_tf["experiment"] = pd.Categorical( + df_tf["experiment"], categories=order_tf_exp, ordered=True + ) + + fig_w = max(15, 15 / 30 * len(order_tf)) + plt.figure(figsize=(fig_w, 6)) + + ax = sns.boxplot( + data=df_tf, + x="experiment", + y="transition_freq", + order=order_tf_exp, + width=0.6, + linewidth=1, + showfliers=False, + ) + + ax.set_xlabel("Experiment") + ax.set_ylabel("Relative transition frequency") + + # annotate medians + annotate_medians_single_boxplot( + ax, + df_tf, + x_col="experiment", + y_col="transition_freq", + order=order_tf_exp, + fmt="{:.2f}", + box_alpha=0.6, + ) + + for label in ax.get_xticklabels(): + label.set_rotation(65) + label.set_horizontalalignment("right") + label.set_fontweight("bold") + if show_title: + ax.set_title( + "Transition frequency per task (box + samples, ordered by median)", pad=12 + ) + + plt.tight_layout() + plt.savefig( + f"{output_root}/transition_freq_{simul_or_real}.{save_fig_format}", + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + ) + plt.close() + + # ========================================================= + # 3) Rest vs Task durations: side-by-side per task paradigm (LOG SCALE) + # ========================================================= + df_rest = as_long_df(rest_durations_all, "duration") + df_rest["state"] = "Rest" + df_task = as_long_df(task_durations_all, "duration") + df_task["state"] = "Task" + df_dur = pd.concat([df_rest, df_task], ignore_index=True) + + # Order tasks by mean Task duration (change to Rest if you prefer) + order_dur, _ = order_by_median_dict(task_durations_all, reverse=True) + df_dur = df_dur[df_dur["task"].isin(order_dur)] + order_dur_exp = [task_to_experiment[task] for task in order_dur] + df_dur["experiment"] = df_dur["task"].map(task_to_experiment) + df_dur["experiment"] = pd.Categorical( + df_dur["experiment"], categories=order_dur_exp, ordered=True + ) + + # ---- LOG display handling (avoid -inf for zeros) ---- + # pick an adaptive epsilon based on the smallest positive value + pos = df_dur.loc[df_dur["duration"] > 0, "duration"] + if len(pos) == 0: + EPS = 1e-3 + else: + EPS = max(min(pos) / 10.0, 1e-3) # small but data-driven + df_dur["duration_plot"] = df_dur["duration"].clip(lower=EPS) + + fig_w = max(17, 17 / 30 * len(order_dur)) + plt.figure(figsize=(fig_w, 7)) + + # Boxplot on log scale (no fliers; jitters will show samples, incl. singletons) + ax = sns.boxplot( + data=df_dur, + x="experiment", + y="duration_plot", + hue="state", + order=order_dur_exp, + hue_order=["Rest", "Task"], + linewidth=1, + dodge=True, + showfliers=False, + width=0.6, + ) + + # Put y-axis on log scale (preserves wide dynamic range) + ax.set_yscale("log") + + # annotate medians on the median line (log-scale safe) + annotate_medians_by_geometry( + ax=ax, + df_long=df_dur, # the DF you plotted + x_col="experiment", + hue_col="state", + y_col="duration_plot", # the epsilon-clipped column you used for plotting + x_order=order_dur_exp, + hue_order=["Rest", "Task"], + fmt="{:.0f}", + y_nudge_factor=1.08, # bump if labels sit on the line in log-space + bin_halfwidth=0.6, # widen if categories are very tightly packed + bbox_alpha=0.6, # make label bg more opaque for legibility + ) + + # Clean up duplicated legends (boxplot + stripplot both add entries) + handles, labels = ax.get_legend_handles_labels() + # the first two unique handles correspond to Rest/Task once; keep those + unique = [] + seen = set() + for h, l in zip(handles, labels): + if l not in seen: + unique.append((h, l)) + seen.add(l) + # Keep only Rest/Task (first two) + handles_clean, labels_clean = ( + zip(*unique[:2]) if len(unique) >= 2 else (handles[:2], labels[:2]) + ) + ax.legend(handles_clean, labels_clean, title="", frameon=True, loc="upper right") + + ax.set_xlabel("Experiment") + ax.set_ylabel("Duration (sec, log scale)") + + for label in ax.get_xticklabels(): + label.set_rotation(65) + label.set_horizontalalignment("right") + label.set_fontweight("bold") + + if show_title: + ax.set_title("Rest vs Task durations per task (log scale; box + points)", pad=12) + + plt.tight_layout() + plt.savefig( + f"{output_root}/durations_rest_vs_task_{simul_or_real}.{save_fig_format}", + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + ) + plt.close() + + # ====================================== + # 4) Optimality Index (sorted by median) — BOX PLOT + median labels + # ====================================== + order_oi, stats_oi = order_by_median_dict(OI_all, reverse=True) + df_oi = as_long_df(OI_all, "OI_avg") + df_oi = df_oi[df_oi["task"].isin(order_oi)] + order_oi_exp = [task_to_experiment[task] for task in order_oi] + df_oi["experiment"] = df_oi["task"].map(task_to_experiment) + df_oi["experiment"] = pd.Categorical( + df_oi["experiment"], categories=order_oi_exp, ordered=True + ) + + fig_w = max(15, 15 / 30 * len(order_oi)) + plt.figure(figsize=(fig_w, 6)) + + ax = sns.boxplot( + data=df_oi, + x="experiment", + y="OI_avg", + order=order_oi_exp, + width=0.6, + linewidth=1, + showfliers=False, + ) + + ax.set_xlabel("Experiment") + ax.set_ylabel("Optimality Index") + oi_max = float(np.nanmax(df_oi["OI_avg"])) + oi_min = float(np.nanmin(df_oi["OI_avg"])) + if np.isfinite(oi_max) and np.isfinite(oi_min): + y_pad = max(0.03 * (oi_max - oi_min), 0.02) + ax.set_ylim(oi_min - y_pad, oi_max + y_pad) + + # annotate medians + annotate_medians_single_boxplot( + ax, + df_oi, + x_col="experiment", + y_col="OI_avg", + order=order_oi_exp, + fmt="{:.2f}", + box_alpha=0.6, + ) + + for label in ax.get_xticklabels(): + label.set_rotation(65) + label.set_horizontalalignment("right") + label.set_fontweight("bold") + if show_title: + ax.set_title( + "Optimality Index per task (box + samples, ordered by median)", pad=12 + ) + + plt.tight_layout() + plt.savefig( + f"{output_root}/optimality_index_{simul_or_real}.{save_fig_format}", + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + ) + plt.close() + + # ========================================================= diff --git a/task_dFC/multi_dataset_analysis/train_ts2vec_dfc_embeddings.py b/task_dFC/multi_dataset_analysis/train_ts2vec_dfc_embeddings.py new file mode 100644 index 0000000..471d303 --- /dev/null +++ b/task_dFC/multi_dataset_analysis/train_ts2vec_dfc_embeddings.py @@ -0,0 +1,787 @@ +import argparse +import inspect +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import pandas as pd + +from pydfc.ml_utils import ( + dFC_feature_extraction_subj_lvl, + find_available_subjects, + load_dFC, + load_task_data, +) + + +def str2bool(v: Any) -> bool: + if isinstance(v, bool): + return v + s = str(v).strip().lower() + if s in {"1", "true", "t", "yes", "y"}: + return True + if s in {"0", "false", "f", "no", "n"}: + return False + raise argparse.ArgumentTypeError(f"Invalid boolean value: {v}") + + +def parse_json_arg(value: Optional[str]) -> Dict[str, Any]: + if value is None or value.strip() == "": + return {} + parsed = json.loads(value) + if not isinstance(parsed, dict): + raise ValueError("JSON argument must be an object/dict.") + return parsed + + +def normalize_optional_token(value: Optional[str]) -> Optional[str]: + if value is None: + return None + if value in {"None", "none", "null"}: + return None + return value + + +def choose_subjects( + subjects: Sequence[str], + max_subjects_per_scan: Optional[int], + rng: np.random.Generator, +) -> List[str]: + subjects = sorted(list(subjects)) + if max_subjects_per_scan is None or len(subjects) <= max_subjects_per_scan: + return subjects + idx = rng.choice(len(subjects), size=max_subjects_per_scan, replace=False) + idx = np.sort(idx) + return [subjects[i] for i in idx] + + +def load_multi_dataset_spec( + multi_dataset_info_path: str, simul_or_real: str +) -> Tuple[Dict[str, Any], str, List[str], List[str]]: + with open(multi_dataset_info_path, "r") as f: + multi_dataset_info = json.load(f) + + if simul_or_real == "real": + spec = multi_dataset_info["real_data"] + elif simul_or_real == "simulated": + spec = multi_dataset_info["simulated_data"] + else: + raise ValueError("--simul_or_real must be 'real' or 'simulated'") + + main_root = spec["main_root"] + datasets = list(spec["DATASETS"]) + tasks_to_include = list(spec["TASKS_to_include"]) + return multi_dataset_info, main_root, datasets, tasks_to_include + + +def load_dataset_info( + dataset_info_file: str, +) -> Tuple[List[Optional[str]], List[str], Dict[str, List[Optional[str]]]]: + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + sessions = dataset_info.get("SESSIONS", None) + if sessions is None: + sessions = [None] + + tasks = dataset_info["TASKS"] + + runs = dataset_info.get("RUNS", None) + if runs is None: + runs = {task: [None] for task in tasks} + else: + runs = { + task: (runs[task] if runs[task] is not None else [None]) for task in tasks + } + + return sessions, tasks, runs + + +def prepare_ts2vec_input( + sequences: Sequence[np.ndarray], + seq_len_mode: str, + pad_value: float, + target_seq_len: Optional[int] = None, +) -> Tuple[np.ndarray, np.ndarray]: + if len(sequences) == 0: + raise ValueError("No sequences provided.") + + lengths = np.array([seq.shape[0] for seq in sequences], dtype=np.int32) + feature_dims = {seq.shape[1] for seq in sequences} + if len(feature_dims) != 1: + raise ValueError(f"Inconsistent feature dimensions found: {sorted(feature_dims)}") + feat_dim = next(iter(feature_dims)) + + if target_seq_len is None: + if seq_len_mode == "truncate_min": + target_seq_len = int(lengths.min()) + elif seq_len_mode == "pad_max": + target_seq_len = int(lengths.max()) + else: + raise ValueError(f"Unknown seq_len_mode: {seq_len_mode}") + target_seq_len = int(target_seq_len) + if target_seq_len <= 0: + raise ValueError("target_seq_len must be positive.") + + X = np.full((len(sequences), target_seq_len, feat_dim), pad_value, dtype=np.float32) + for i, seq in enumerate(sequences): + seq = seq.astype(np.float32, copy=False) + if seq.shape[0] >= target_seq_len: + X[i] = seq[:target_seq_len, :] + else: + X[i, : seq.shape[0], :] = seq + + return X, lengths + + +def standardize_ts2vec_input( + X: np.ndarray, eps: float = 1e-6 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + mean = X.mean(axis=(0, 1), keepdims=True) + std = X.std(axis=(0, 1), keepdims=True) + std = np.where(std < eps, 1.0, std) + Xz = (X - mean) / std + return Xz.astype(np.float32, copy=False), mean.squeeze(), std.squeeze() + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Load dFC feature sequences across multiple datasets and train a TS2Vec " + "model to learn embeddings." + ) + ) + parser.add_argument( + "--multi_dataset_info", + type=str, + required=True, + help="Path to task_dFC/run_scripts_slurm/multi_dataset_info.json", + ) + parser.add_argument( + "--simul_or_real", + type=str, + required=True, + choices=["real", "simulated"], + help="Which section of the multi-dataset config to use.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Output directory. Defaults to /TS2Vec/.", + ) + parser.add_argument( + "--dFC_ids", + type=int, + nargs="+", + required=True, + help="One or more dFC method IDs to process. A separate TS2Vec model is trained per compatible group.", + ) + parser.add_argument( + "--datasets", + type=str, + nargs="*", + default=None, + help="Optional subset of dataset IDs to include.", + ) + parser.add_argument( + "--tasks", + type=str, + nargs="*", + default=None, + help="Optional subset of task labels to include (e.g., task-Axcpt).", + ) + parser.add_argument( + "--sessions", + type=str, + nargs="*", + default=None, + help="Optional subset of session labels to include.", + ) + parser.add_argument( + "--runs", + type=str, + nargs="*", + default=None, + help="Optional subset of run labels to include.", + ) + parser.add_argument( + "--dynamic_pred", + type=str, + default="no", + choices=["no", "past", "past_and_future"], + help="Feature stacking mode reused from pydfc.ml_utils.dFC_feature_extraction_subj_lvl.", + ) + parser.add_argument( + "--normalize_dFC", + type=str2bool, + default=True, + help="Apply rank normalization to state-free dFC matrices before vectorization.", + ) + parser.add_argument( + "--FCS_proba_for_SB", + type=str2bool, + default=True, + help="For state-based dFC, use FCS probabilities instead of vectorized dFC matrices.", + ) + parser.add_argument( + "--min_seq_len", + type=int, + default=10, + help="Minimum sequence length (TRs) after feature extraction.", + ) + parser.add_argument( + "--max_subjects_per_scan", + type=int, + default=None, + help="Randomly subsample subjects per (dataset, session, task, run, dFC_id).", + ) + parser.add_argument( + "--max_total_sequences", + type=int, + default=None, + help="Optional global cap on number of sequences per TS2Vec training group.", + ) + parser.add_argument( + "--seq_len_mode", + type=str, + default="truncate_min", + choices=["truncate_min", "pad_max"], + help="How to make variable-length sequences compatible for TS2Vec input.", + ) + parser.add_argument( + "--target_seq_len", + type=int, + default=None, + help="Override sequence length used for training input (truncate/pad to this length).", + ) + parser.add_argument( + "--pad_value", + type=float, + default=0.0, + help="Padding value when --seq_len_mode=pad_max or --target_seq_len exceeds sequence length.", + ) + parser.add_argument( + "--standardize_features", + type=str2bool, + default=False, + help="Z-score features globally across sequences and timepoints before TS2Vec training.", + ) + parser.add_argument( + "--seed", type=int, default=0, help="Random seed for subsampling." + ) + + # TS2Vec common args (kept optional and overridable via JSON) + parser.add_argument( + "--device", type=str, default=None, help="TS2Vec device (e.g., cpu, cuda)." + ) + parser.add_argument( + "--output_dims", type=int, default=320, help="TS2Vec output embedding dimension." + ) + parser.add_argument( + "--hidden_dims", type=int, default=64, help="TS2Vec hidden dimension." + ) + parser.add_argument("--depth", type=int, default=10, help="TS2Vec encoder depth.") + parser.add_argument( + "--batch_size", type=int, default=8, help="TS2Vec fit batch size." + ) + parser.add_argument( + "--lr", type=float, default=1e-3, help="TS2Vec fit learning rate (if supported)." + ) + parser.add_argument( + "--max_train_length", + type=int, + default=None, + help="TS2Vec max_train_length (if supported).", + ) + parser.add_argument( + "--temporal_unit", + type=int, + default=0, + help="TS2Vec temporal_unit (if supported).", + ) + parser.add_argument( + "--n_epochs", type=int, default=50, help="Number of TS2Vec training epochs." + ) + + parser.add_argument( + "--ts2vec_init_json", + type=str, + default=None, + help="Extra JSON object of kwargs for TS2Vec(...) init. Overrides common args on key conflict.", + ) + parser.add_argument( + "--ts2vec_fit_json", + type=str, + default=None, + help="Extra JSON object of kwargs for model.fit(...). Overrides common args on key conflict.", + ) + parser.add_argument( + "--ts2vec_encode_json", + type=str, + default=None, + help="Extra JSON object of kwargs for model.encode(...).", + ) + + parser.add_argument( + "--encoding_window", + type=str, + default="full_series", + help="TS2Vec encode encoding_window. Use integer string for numeric window or full_series.", + ) + parser.add_argument( + "--save_timestep_embeddings", + type=str2bool, + default=False, + help="Also save per-timestep embeddings (can be large).", + ) + parser.add_argument( + "--save_model", + type=str2bool, + default=True, + help="Try to save the TS2Vec model if the package exposes model.save(...).", + ) + + return parser + + +def instantiate_ts2vec( + TS2Vec: Any, init_kwargs: Dict[str, Any] +) -> Tuple[Any, Dict[str, Any]]: + """ + Try a few progressively smaller init signatures to tolerate TS2Vec package variants. + Returns (model, effective_init_kwargs). + """ + candidate_kwargs = [dict(init_kwargs)] + optional_drop_order = ["temporal_unit", "max_train_length", "device"] + current = dict(init_kwargs) + for key in optional_drop_order: + if key in current: + current = dict(current) + current.pop(key, None) + candidate_kwargs.append(current) + + last_error = None + for kwargs in candidate_kwargs: + try: + return TS2Vec(**kwargs), kwargs + except TypeError as e: + last_error = e + continue + + raise TypeError( + f"Could not instantiate TS2Vec with tested kwargs variants: {last_error}" + ) + + +def fit_ts2vec_adaptive( + model: Any, X_ts2vec: np.ndarray, fit_kwargs: Dict[str, Any] +) -> Dict[str, Any]: + """ + Call TS2Vec.fit using only kwargs supported by the installed implementation. + Returns the effective kwargs that were actually used. + """ + try: + sig = inspect.signature(model.fit) + params = sig.parameters + accepts_var_kw = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + if accepts_var_kw: + _ = model.fit(X_ts2vec, **fit_kwargs) + return dict(fit_kwargs) + + allowed = { + name + for name, p in params.items() + if name != "self" + and p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + effective_fit_kwargs = {k: v for k, v in fit_kwargs.items() if k in allowed} + _ = model.fit(X_ts2vec, **effective_fit_kwargs) + return effective_fit_kwargs + except (TypeError, ValueError): + # Some wrapped methods don't expose a reliable signature. + last_error = None + candidate_kwarg_sets = [ + dict(fit_kwargs), + {k: v for k, v in fit_kwargs.items() if k != "batch_size"}, + {k: v for k, v in fit_kwargs.items() if k not in {"batch_size", "lr"}}, + {k: v for k, v in fit_kwargs.items() if k == "n_epochs"}, + {}, + ] + + seen = set() + for kw in candidate_kwarg_sets: + key = tuple(sorted((str(k), str(v)) for k, v in kw.items())) + if key in seen: + continue + seen.add(key) + try: + _ = model.fit(X_ts2vec, **kw) + return kw + except TypeError as e: + last_error = e + continue + + raise TypeError( + f"Could not call TS2Vec.fit with tested kwargs variants: {last_error}" + ) + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + + rng = np.random.default_rng(args.seed) + multi_dataset_info, main_root, datasets, tasks_to_include = load_multi_dataset_spec( + args.multi_dataset_info, args.simul_or_real + ) + + if args.datasets: + datasets = [d for d in datasets if d in set(args.datasets)] + task_filter = set(args.tasks) if args.tasks else set(tasks_to_include) + session_filter = ( + {normalize_optional_token(x) for x in args.sessions} if args.sessions else None + ) + run_filter = {normalize_optional_token(x) for x in args.runs} if args.runs else None + + if args.output_dir is None: + output_root = f"{multi_dataset_info['output_root']}/TS2Vec/{args.simul_or_real}" + else: + output_root = args.output_dir + Path(output_root).mkdir(parents=True, exist_ok=True) + + # group key -> payload + grouped_sequences: Dict[Tuple[int, str, int], List[np.ndarray]] = {} + grouped_targets: Dict[Tuple[int, str, int], List[np.ndarray]] = {} + grouped_meta: Dict[Tuple[int, str, int], List[Dict[str, Any]]] = {} + skipped_records: List[Dict[str, Any]] = [] + + total_loaded = 0 + print(f"Datasets to process: {datasets}") + for dataset in datasets: + dataset_info_file = f"{main_root}/{dataset}/codes/dataset_info.json" + roi_root = f"{main_root}/{dataset}/derivatives/ROI_timeseries" + dFC_root = f"{main_root}/{dataset}/derivatives/dFC_assessed" + + if not os.path.exists(dataset_info_file): + print( + f"Skipping dataset {dataset}: dataset_info.json not found at {dataset_info_file}" + ) + continue + + sessions, tasks, runs_map = load_dataset_info(dataset_info_file) + if session_filter is not None: + sessions = [s for s in sessions if s in session_filter] + + for session in sessions: + for task in tasks: + if task not in task_filter: + continue + runs = runs_map.get(task, [None]) + if run_filter is not None: + runs = [r for r in runs if r in run_filter] + + for run in runs: + for dFC_id in args.dFC_ids: + try: + subjects = find_available_subjects( + dFC_root=dFC_root, + task=task, + run=run, + session=session, + dFC_id=dFC_id, + ) + except FileNotFoundError: + print(f"Skipping missing dFC directory: {dFC_root}") + continue + + if len(subjects) == 0: + continue + subjects = choose_subjects( + subjects=subjects, + max_subjects_per_scan=args.max_subjects_per_scan, + rng=rng, + ) + + print( + "Loading " + f"dataset={dataset} session={session} task={task} run={run} " + f"dFC_id={dFC_id} n_subjects={len(subjects)}" + ) + for subj in subjects: + try: + dFC = load_dFC( + dFC_root=dFC_root, + subj=subj, + task=task, + dFC_id=dFC_id, + run=run, + session=session, + ) + task_data = load_task_data( + roi_root=roi_root, + subj=subj, + task=task, + run=run, + session=session, + ) + X_subj, y_subj = dFC_feature_extraction_subj_lvl( + dFC=dFC, + task_data=task_data, + dynamic_pred=args.dynamic_pred, + normalize_dFC=args.normalize_dFC, + FCS_proba_for_SB=args.FCS_proba_for_SB, + ) + except Exception as e: + skipped_records.append( + { + "dataset": dataset, + "session": session, + "task": task, + "run": run, + "dFC_id": dFC_id, + "subject": subj, + "reason": f"{type(e).__name__}: {e}", + } + ) + continue + + if X_subj.shape[0] < args.min_seq_len: + skipped_records.append( + { + "dataset": dataset, + "session": session, + "task": task, + "run": run, + "dFC_id": dFC_id, + "subject": subj, + "reason": f"seq_too_short({X_subj.shape[0]}<{args.min_seq_len})", + } + ) + continue + + measure_name = dFC.measure.measure_name + group_key = (dFC_id, measure_name, int(X_subj.shape[1])) + grouped_sequences.setdefault(group_key, []).append(X_subj) + grouped_targets.setdefault(group_key, []).append(y_subj) + grouped_meta.setdefault(group_key, []).append( + { + "dataset": dataset, + "session": session, + "task": task, + "run": run, + "dFC_id": dFC_id, + "subject": subj, + "measure_name": measure_name, + "seq_len_raw": int(X_subj.shape[0]), + "feature_dim": int(X_subj.shape[1]), + "task_presence_mean": float(np.mean(y_subj)), + } + ) + total_loaded += 1 + + print(f"Loaded sequences: {total_loaded}") + if total_loaded == 0: + raise RuntimeError("No sequences were loaded. Check filters/paths/dFC_ids.") + + # Lazy import to avoid making this script unusable when ts2vec is not installed. + try: + from ts2vec import TS2Vec # type: ignore + except ImportError as e: + raise ImportError( + "TS2Vec package is not installed. Install a compatible implementation " + "(commonly `pip install ts2vec`) and rerun." + ) from e + + ts2vec_init_extra = parse_json_arg(args.ts2vec_init_json) + ts2vec_fit_extra = parse_json_arg(args.ts2vec_fit_json) + ts2vec_encode_extra = parse_json_arg(args.ts2vec_encode_json) + + encoding_window: Any = args.encoding_window + if isinstance(encoding_window, str) and encoding_window.isdigit(): + encoding_window = int(encoding_window) + + run_summaries: List[Dict[str, Any]] = [] + + for group_key in sorted(grouped_sequences.keys(), key=lambda x: (x[0], x[1], x[2])): + dFC_id, measure_name, feature_dim = group_key + sequences = grouped_sequences[group_key] + targets = grouped_targets[group_key] + meta_rows = grouped_meta[group_key] + + if ( + args.max_total_sequences is not None + and len(sequences) > args.max_total_sequences + ): + idx = rng.choice(len(sequences), size=args.max_total_sequences, replace=False) + idx = np.sort(idx) + sequences = [sequences[i] for i in idx] + targets = [targets[i] for i in idx] + meta_rows = [meta_rows[i] for i in idx] + + if len(sequences) < 2: + print( + f"Skipping group dFC_id={dFC_id}, measure={measure_name}, feat_dim={feature_dim}: " + "need at least 2 sequences for training." + ) + continue + + X_ts2vec, raw_lengths = prepare_ts2vec_input( + sequences=sequences, + seq_len_mode=args.seq_len_mode, + pad_value=args.pad_value, + target_seq_len=args.target_seq_len, + ) + + feature_mean = None + feature_std = None + if args.standardize_features: + X_ts2vec, feature_mean, feature_std = standardize_ts2vec_input(X_ts2vec) + + print( + f"Training TS2Vec on group dFC_id={dFC_id}, measure={measure_name}, " + f"X.shape={X_ts2vec.shape}" + ) + + init_kwargs: Dict[str, Any] = { + "input_dims": int(feature_dim), + "output_dims": int(args.output_dims), + "hidden_dims": int(args.hidden_dims), + "depth": int(args.depth), + } + if args.device is not None: + init_kwargs["device"] = args.device + if args.max_train_length is not None: + init_kwargs["max_train_length"] = int(args.max_train_length) + if args.temporal_unit is not None: + init_kwargs["temporal_unit"] = int(args.temporal_unit) + init_kwargs.update(ts2vec_init_extra) + + fit_kwargs: Dict[str, Any] = { + "n_epochs": int(args.n_epochs), + "batch_size": int(args.batch_size), + "lr": float(args.lr), + } + fit_kwargs.update(ts2vec_fit_extra) + + model, effective_init_kwargs = instantiate_ts2vec(TS2Vec, init_kwargs) + fit_kwargs = fit_ts2vec_adaptive(model, X_ts2vec, fit_kwargs) + + encode_kwargs: Dict[str, Any] = {"encoding_window": encoding_window} + encode_kwargs.update(ts2vec_encode_extra) + full_series_embeddings = model.encode(X_ts2vec, **encode_kwargs) + + timestep_embeddings = None + if args.save_timestep_embeddings: + timestep_embeddings = model.encode(X_ts2vec) + + safe_measure = str(measure_name).replace(" ", "_") + group_dir = Path(output_root) / f"dFC_{dFC_id}_{safe_measure}_feat{feature_dim}" + group_dir.mkdir(parents=True, exist_ok=True) + + meta_df = pd.DataFrame(meta_rows).copy() + meta_df["seq_len_used"] = int(X_ts2vec.shape[1]) + meta_df["seq_len_raw"] = raw_lengths + meta_df.to_csv(group_dir / "sequence_metadata.csv", index=False) + + np.save( + group_dir / "full_series_embeddings.npy", np.asarray(full_series_embeddings) + ) + np.save(group_dir / "train_sequences_input.npy", X_ts2vec) + np.save( + group_dir / "task_presence_labels.npy", + np.array(targets, dtype=object), + allow_pickle=True, + ) + if timestep_embeddings is not None: + np.save( + group_dir / "timestep_embeddings.npy", np.asarray(timestep_embeddings) + ) + if feature_mean is not None and feature_std is not None: + np.save(group_dir / "feature_mean.npy", np.asarray(feature_mean)) + np.save(group_dir / "feature_std.npy", np.asarray(feature_std)) + + model_saved = False + if args.save_model and hasattr(model, "save"): + try: + model.save(str(group_dir / "ts2vec_model")) + model_saved = True + except Exception as e: + print(f"Could not save TS2Vec model for group {group_key}: {e}") + + config_to_save = { + "group_key": { + "dFC_id": int(dFC_id), + "measure_name": str(measure_name), + "feature_dim": int(feature_dim), + }, + "data": { + "n_sequences": int(len(sequences)), + "seq_len_mode": args.seq_len_mode, + "target_seq_len": int(X_ts2vec.shape[1]), + "raw_seq_len_min": int(raw_lengths.min()), + "raw_seq_len_max": int(raw_lengths.max()), + "standardize_features": bool(args.standardize_features), + }, + "loader_params": { + "simul_or_real": args.simul_or_real, + "datasets": datasets, + "task_filter": sorted(list(task_filter)), + "session_filter": ( + None if session_filter is None else sorted(list(session_filter)) + ), + "run_filter": None if run_filter is None else sorted(list(run_filter)), + "dFC_ids": [int(x) for x in args.dFC_ids], + "dynamic_pred": args.dynamic_pred, + "normalize_dFC": bool(args.normalize_dFC), + "FCS_proba_for_SB": bool(args.FCS_proba_for_SB), + "min_seq_len": int(args.min_seq_len), + "max_subjects_per_scan": args.max_subjects_per_scan, + "max_total_sequences": args.max_total_sequences, + "seed": int(args.seed), + }, + "ts2vec": { + "init_kwargs": effective_init_kwargs, + "fit_kwargs": fit_kwargs, + "encode_kwargs": encode_kwargs, + "model_saved": model_saved, + }, + } + with open(group_dir / "run_config.json", "w") as f: + json.dump(config_to_save, f, indent=2) + + run_summaries.append( + { + "dFC_id": int(dFC_id), + "measure_name": str(measure_name), + "feature_dim": int(feature_dim), + "n_sequences": int(len(sequences)), + "seq_len_used": int(X_ts2vec.shape[1]), + "embedding_shape": list(np.asarray(full_series_embeddings).shape), + "output_dir": str(group_dir), + } + ) + + # Avoid holding multiple large arrays/models longer than needed. + del model + + if skipped_records: + pd.DataFrame(skipped_records).to_csv( + Path(output_root) / "skipped_records.csv", index=False + ) + with open(Path(output_root) / "run_summary.json", "w") as f: + json.dump(run_summaries, f, indent=2) + + print(f"Finished. Outputs written to: {output_root}") + + +if __name__ == "__main__": + main() diff --git a/task_dFC/multi_dataset_analysis/tsnr.py b/task_dFC/multi_dataset_analysis/tsnr.py new file mode 100644 index 0000000..d61de5a --- /dev/null +++ b/task_dFC/multi_dataset_analysis/tsnr.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from helper_functions import ( + annotate_medians_single_boxplot, + build_experiment_display_info, + canon_task, + get_default_experiment_name_map, + order_by_median_dict, + setup_pub_style, +) + + +def _load_and_filter_tsnr_df(tsv_path: str) -> pd.DataFrame: + df = pd.read_csv(tsv_path, sep="\t", dtype=str) + + # Make sure expected columns exist + required_cols = {"dataset", "sub", "ses", "task", "run", "tsnr_median", "error"} + missing = required_cols - set(df.columns) + if missing: + raise ValueError(f"Missing required columns in TSV: {sorted(missing)}") + + # Normalize missing values + df = df.fillna("") + + # Keep only rows without errors + df = df[df["error"].astype(str).str.strip() == ""].copy() + + # Keep only the desired session for multi-session datasets + # ds005038 -> keep ses == "pre" + # ds003823 -> keep ses == "post" + mask_ds005038 = df["dataset"] == "ds005038" + mask_ds003823 = df["dataset"] == "ds003823" + + df = df[ + (~mask_ds005038 | (df["ses"] == "pre")) & (~mask_ds003823 | (df["ses"] == "post")) + ].copy() + + # Convert tSNR median to numeric + df["tsnr_median"] = pd.to_numeric(df["tsnr_median"], errors="coerce") + + # Drop rows where tsnr_median could not be parsed + df = df[df["tsnr_median"].notna()].copy() + + return df + + +def build_grouped_tsnr_summary(tsv_path: str) -> Path: + tsv_path = Path(tsv_path).resolve() + out_path = tsv_path.parent / "tsnr_summary_grouped.tsv" + + df = _load_and_filter_tsnr_df(str(tsv_path)) + + # Average over subjects for each dataset/task/run + out_df = ( + df.groupby(["dataset", "task", "run"], as_index=False)["tsnr_median"] + .mean() + .rename(columns={"tsnr_median": "median_tsnr_avg_over_subjects"}) + ) + + # Append prefixes + out_df["task"] = "task-" + out_df["task"].astype(str) + + def format_run(x): + if pd.isna(x) or str(x).strip() == "": + return None + return f"run-{x}" + + out_df["run"] = out_df["run"].apply(format_run) + + # Reorder columns exactly as requested + out_df = out_df[["dataset", "run", "task", "median_tsnr_avg_over_subjects"]] + + # Optional: round nicely + out_df["median_tsnr_avg_over_subjects"] = out_df[ + "median_tsnr_avg_over_subjects" + ].round(2) + + # Save in same directory + out_df.to_csv(out_path, sep="\t", index=False) + + return out_path + + +def build_tsnr_distribution_figure(tsv_path: str) -> Path: + tsv_path = Path(tsv_path).resolve() + fig_path = tsv_path.parent / "tsnr_median_distribution_by_exp.png" + + df = _load_and_filter_tsnr_df(str(tsv_path)) + if df.empty: + raise ValueError("No valid tSNR rows available to plot after filtering.") + + task_to_values = df.groupby("task")["tsnr_median"].apply(list).to_dict() + if not task_to_values: + raise ValueError("No task-wise tSNR values found for plotting.") + + tasks_present = sorted(task_to_values.keys()) + known_tasks = set(get_default_experiment_name_map("real").keys()) + unknown_tasks = sorted( + [task for task in tasks_present if canon_task(task) not in known_tasks] + ) + if unknown_tasks: + unknown_str = ", ".join(unknown_tasks) + raise ValueError( + "Found task(s) not mapped to EXP labels in real-data mapping: " + f"{unknown_str}. Remove these tasks from input TSV or add them to " + "DEFAULT_EXPERIMENT_NAME_MAP['real'] in helper_functions.py." + ) + + _, task_to_experiment, _, _ = build_experiment_display_info( + tasks_iterable=tasks_present, + task_reference_order=tasks_present, + simul_or_real="real", + ) + + order_task, _ = order_by_median_dict(task_to_values, reverse=True) + order_exp = [task_to_experiment[t] for t in order_task] + + df_plot = df.copy() + df_plot["experiment"] = df_plot["task"].map(task_to_experiment) + df_plot = df_plot[df_plot["task"].isin(order_task)].copy() + df_plot["experiment"] = pd.Categorical( + df_plot["experiment"], categories=order_exp, ordered=True + ) + + setup_pub_style() + sns.set_theme(context="paper", style="darkgrid") + + fig_w = max(14, 14 / 30 * len(order_exp)) + plt.figure(figsize=(fig_w, 6)) + ax = sns.boxplot( + data=df_plot, + x="experiment", + y="tsnr_median", + order=order_exp, + width=0.6, + linewidth=1, + showfliers=False, + ) + + annotate_medians_single_boxplot( + ax, + df_plot, + x_col="experiment", + y_col="tsnr_median", + order=order_exp, + fmt="{:.1f}", + box_alpha=0.6, + ) + + ax.set_xlabel("Experiment") + ax.set_ylabel("tSNR median") + for label in ax.get_xticklabels(): + label.set_rotation(65) + label.set_horizontalalignment("right") + label.set_fontweight("bold") + + plt.tight_layout() + plt.savefig(fig_path, bbox_inches="tight", pad_inches=0.1, dpi=500) + plt.close() + + return fig_path + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Build a grouped TSV from tsnr_summary.tsv and create a figure showing " + "tSNR median distributions per experiment (EXP)." + ) + ) + parser.add_argument( + "tsnr_summary_tsv", + help="Path to tsnr_summary.tsv", + ) + args = parser.parse_args() + + out_path = build_grouped_tsnr_summary(args.tsnr_summary_tsv) + fig_path = build_tsnr_distribution_figure(args.tsnr_summary_tsv) + print(f"[DONE] Wrote grouped TSV to: {out_path}") + print(f"[DONE] Wrote figure to: {fig_path}") + + +if __name__ == "__main__": + main() diff --git a/task_dFC/nifti_to_roi_signal.py b/task_dFC/nifti_to_roi_signal.py index 1e52cb8..b9dd9a7 100644 --- a/task_dFC/nifti_to_roi_signal.py +++ b/task_dFC/nifti_to_roi_signal.py @@ -1,3 +1,4 @@ +import argparse import json import os import warnings @@ -6,130 +7,339 @@ from pydfc import data_loader, task_utils -warnings.simplefilter("ignore") - -################################# Parameters ################################# -# data paths -# main_root = '../../DATA/ds002785' # for local -main_root = "../../../DATA/task-based/openneuro/ds002785" # for server -fmriprep_root = f"{main_root}/derivatives/fmriprep" -output_root = f"{main_root}/derivatives/ROI_timeseries" - -bold_suffix = "_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz" - -# for consistency we use 0 for resting state -TASKS = [ - "task-restingstate", - "task-anticipation", - "task-emomatching", - "task-faces", - "task-gstroop", - "task-workingmemory", -] - -# find all subjects -ALL_SUBJs = os.listdir(fmriprep_root) -ALL_SUBJs = [i for i in ALL_SUBJs if ("sub-" in i) and (not ".html" in i)] -ALL_SUBJs.sort() - -# pick the subject -job_id = int(os.getenv("SGE_TASK_ID")) -subj = ALL_SUBJs[job_id - 1] # SGE_TASK_ID starts from 1 not 0 - -print( - f"subject-level ROI signal extraction CODE started running ... for subject: {subj} ..." -) -################################# FIND THE FUNC FILE ################################# -for task in TASKS: +# warnings.simplefilter("ignore") + + +################################# FUNCTIONS ################################# +def run_roi_signal_extraction( + subj, + task, + bids_root, + fmriprep_root, + bold_suffix, + output_root, + session=None, + RUNS=[None], + trial_type_label="trial_type", + rest_labels=[], + denoising_strategy="simple", +): + """ + Extract ROI signals and task labels for a given subject and task + and optionally session. + """ + if session is None: + session_str = "" + else: + session_str = session # find the func file for this subject and task - ALL_TASK_FILES = os.listdir(f"{fmriprep_root}/{subj}/func/") + try: + if session is None: + ALL_TASK_FILES = os.listdir(f"{fmriprep_root}/{subj}/func/") + else: + ALL_TASK_FILES = os.listdir(f"{fmriprep_root}/{subj}/{session}/func/") + except FileNotFoundError: + warnings.warn(f"Subject {subj} {session_str} not found in {fmriprep_root}") + return + ALL_TASK_FILES = [ - i for i in ALL_TASK_FILES if (bold_suffix in i) and (task in i) + file_i + for file_i in ALL_TASK_FILES + if (bold_suffix in file_i) and (f"_{task}_" in file_i) ] # only keep the denoised files? or use the original files? - # print(ALL_TASK_FILES) - if not len(ALL_TASK_FILES) == 1: + + if not len(ALL_TASK_FILES) >= 1: # if the func file is not found, exclude the subject - print("Func file not found for " + subj + " " + task) - continue - fmriprep_file = f"{fmriprep_root}/{subj}/func/{ALL_TASK_FILES[0]}" - info_file = ( - f"{main_root}/{subj}/func/{ALL_TASK_FILES[0].replace(bold_suffix, '_bold.json')}" - ) + warnings.warn(f"Func file not found for {subj} {session_str} {task}") + return + + for run in RUNS: + if run is None: + task_file = ALL_TASK_FILES[0] + else: + task_file = [file_i for file_i in ALL_TASK_FILES if f"_{run}_" in file_i][0] + if session is None: + nifti_file = f"{fmriprep_root}/{subj}/func/{task_file}" + task_events_root = f"{bids_root}/{subj}/func" + else: + nifti_file = f"{fmriprep_root}/{subj}/{session}/func/{task_file}" + task_events_root = f"{bids_root}/{subj}/{session}/func" + # we need the info file to get the TR + # we can find the acquisition data in either the fmriprep folder + # or in the bids folder + # BUT for multi-echo data, we must use the fmriprep folder + # because the bids folder contains multiple files for each echo + # so first we check if the file exists in the fmriprep folder + # and if not, we check the bids folder + # the info file is the same as the nifti file but with a .json extension + info_file = nifti_file.replace(".nii.gz", ".json") + if not os.path.exists(info_file): + info_file = ( + f"{task_events_root}/{task_file.replace(bold_suffix, '_bold.json')}" + ) + + if os.path.exists(info_file): + f = open(info_file) + acquisition_data = json.load(f) + f.close() + else: + acquisition_data = None + + # in some cases the info file is common for all subjects and can be found in f"{bids_root}" + ALL_COMMON_FILES = os.listdir(f"{bids_root}/") + ALL_COMMON_FILES = [ + file_i + for file_i in ALL_COMMON_FILES + if (f"{task}_" in file_i) and ("_bold.json" in file_i) + ] + if len(ALL_COMMON_FILES) == 1: + global_info_file = f"{bids_root}/{ALL_COMMON_FILES[0]}" + f = open(global_info_file) + global_acquisition_data = json.load(f) + f.close() + else: + global_acquisition_data = None + + if global_acquisition_data is None and acquisition_data is None: + # if the acquisition_data is not found, exclude the subject + if run is None: + warnings.warn( + f"bold.json info file not found for {subj} {session_str} {task}" + ) + else: + warnings.warn( + f"bold.json info file not found for {subj} {session_str} {task} {run}" + ) + return + ################################# GET REPETITION TIME ######################### + TR_mri = None + # first check the acquisition_data + if acquisition_data is not None: + if "RepetitionTime" in acquisition_data: + TR_mri = acquisition_data["RepetitionTime"] + # if not found, check the global_acquisition_data + if TR_mri is None and global_acquisition_data is not None: + if "RepetitionTime" in global_acquisition_data: + TR_mri = global_acquisition_data["RepetitionTime"] + # if not found, print a warning and skip the subject + if TR_mri is None: + if run is None: + warnings.warn( + f"Repetition time not found for {subj} {session_str} {task}" + ) + else: + warnings.warn( + f"Repetition time not found for {subj} {session_str} {task} {run}" + ) + return + ################################# EXTRACT TIME SERIES ######################### + # extract ROI signals and convert to TIME_SERIES object + time_series = data_loader.nifti2timeseries( + nifti_file=nifti_file, + n_rois=100, + Fs=1 / TR_mri, + subj_id=subj, + confound_strategy=denoising_strategy, + standardize="zscore", + TS_name="BOLD", + session=task, + ) + num_time_mri = time_series.n_time + ################################# EXTRACT TASK LABELS ######################### + oversampling = 50 # more samples per TR than the func data to have a better event_labels time resolution - ################################# LOAD JSON INFO ######################### - # Opening JSON file as a dictionary - f = open(info_file) - acquisition_data = json.load(f) - f.close() - TR_mri = acquisition_data["RepetitionTime"] - ################################# EXTRACT TIME SERIES ######################### - # extract ROI signals and convert to TIME_SERIES object - time_series = data_loader.nifti2timeseries( - nifti_file=fmriprep_file, - n_rois=100, - Fs=1 / TR_mri, - subj_id=subj, - confound_strategy="no_motion", - standardize="zscore", - TS_name="BOLD", - session=task, - ) - num_time_mri = time_series.n_time - ################################# EXTRACT TASK LABELS ######################### - oversampling = 50 # more samples per TR than the func data to have a better event_labels time resolution - if task == "task-restingstate": - events = [] - event_types = ["rest"] - event_labels = np.zeros((int(num_time_mri * oversampling), 1)) - task_labels = np.zeros((int(num_time_mri * oversampling), 1)) - Fs_task = float(1 / TR_mri) * oversampling - else: - task_events_root = f"{main_root}/{subj}/func/" ALL_EVENTS_FILES = os.listdir(task_events_root) ALL_EVENTS_FILES = [ - i - for i in ALL_EVENTS_FILES - if (subj in i) and (task in i) and ("events.tsv" in i) + file_i + for file_i in ALL_EVENTS_FILES + if (f"{subj}_" in file_i) + and (f"_{task}_" in file_i) + and ("events.tsv" in file_i) ] + if not run is None: + ALL_EVENTS_FILES = [ + file_i for file_i in ALL_EVENTS_FILES if f"_{run}_" in file_i + ] + if not session is None: + ALL_EVENTS_FILES = [ + file_i for file_i in ALL_EVENTS_FILES if f"_{session}_" in file_i + ] + + events_file_exists = True if not len(ALL_EVENTS_FILES) == 1: - # if the events file is not found, exclude the subject - print("Events file not found for " + subj + " " + task) - continue - # load the tsv events file - events_file = task_events_root + ALL_EVENTS_FILES[0] - events = np.genfromtxt(events_file, delimiter="\t", dtype=str) - # get the task labels - event_types = ["rest"] + list(np.unique(events[1:, 2])) - event_labels, Fs_task = task_utils.events_time_to_labels( - events=events, - TR_mri=TR_mri, - num_time_mri=num_time_mri, - event_types=event_types, - oversampling=oversampling, - return_0_1=False, + # in some cases the event file is common for all subjects and can be found in f"{bids_root}" + ALL_EVENTS_FILES_COMMON = os.listdir(f"{bids_root}/") + ALL_EVENTS_FILES_COMMON = [ + file_i + for file_i in ALL_EVENTS_FILES_COMMON + if (f"{task}_" in file_i) and ("events.tsv" in file_i) + ] + if len(ALL_EVENTS_FILES_COMMON) == 1: + events_file = f"{bids_root}/{ALL_EVENTS_FILES_COMMON[0]}" + else: + # if the events file is not found, do not exclude the subject, only save time-series data + # this will allow including resting state files + if run is None: + warnings.warn( + f"Events file not found for {subj} {session_str} {task}" + ) + else: + warnings.warn( + f"Events file not found for {subj} {session_str} {task} {run}" + ) + events_file_exists = False + else: + events_file = f"{task_events_root}/{ALL_EVENTS_FILES[0]}" + + if events_file_exists: + # load the tsv events file + events = np.genfromtxt(events_file, delimiter="\t", dtype=str) + # get the event labels + event_labels, Fs_task, event_types = task_utils.events_time_to_labels( + events=events, + TR_mri=TR_mri, + num_time_mri=num_time_mri, + event_types=None, + oversampling=oversampling, + trial_type_label=trial_type_label, + rest_labels=rest_labels, + return_0_1=False, + ) + # fill task labels with task's index + task_labels = np.ones((int(num_time_mri * oversampling), 1)) * TASKS.index( + task + ) + ################################# SAVE ################################# + # save the ROI time series and task data + if events_file_exists: + task_data = { + "task": task, + "task_labels": task_labels, + "task_types": TASKS, + "event_labels": event_labels, + "event_types": event_types, + "events": events, + "Fs_task": Fs_task, + "TR_mri": TR_mri, + "num_time_mri": num_time_mri, + } + + if session is None: + subj_session_prefix = f"{subj}" + output_dir = f"{output_root}/{subj}" + else: + subj_session_prefix = f"{subj}_{session}" + output_dir = f"{output_root}/{subj}/{session}" + + if run is None: + output_file_prefix = f"{subj_session_prefix}_{task}" + else: + output_file_prefix = f"{subj_session_prefix}_{task}_{run}" + + if not os.path.exists(f"{output_dir}/"): + os.makedirs(f"{output_dir}/") + np.save(f"{output_dir}/{output_file_prefix}_time-series.npy", time_series) + if events_file_exists: + np.save(f"{output_dir}/{output_file_prefix}_task-data.npy", task_data) + + +######################################################################################## + +if __name__ == "__main__": + # argparse + HELPTEXT = """ + Script to convert nifti files to ROI signals for a given participant. + """ + + parser = argparse.ArgumentParser(description=HELPTEXT) + + parser.add_argument("--dataset_info", type=str, help="path to dataset info file") + parser.add_argument("--participant_id", type=str, help="participant id") + parser.add_argument( + "--denoising_strategy", type=str, default="simple", help="denoising strategy" + ) + + args = parser.parse_args() + + dataset_info_file = args.dataset_info + participant_id = args.participant_id + denoising_strategy = args.denoising_strategy + + # Read dataset info + with open(dataset_info_file, "r") as f: + dataset_info = json.load(f) + + print( + f"subject-level ROI signal extraction CODE started running ... for subject: {participant_id} ..." + ) + + TASKS = dataset_info["TASKS"] + + if "SESSIONS" in dataset_info: + SESSIONS = dataset_info["SESSIONS"] + else: + SESSIONS = None + if SESSIONS is None: + SESSIONS = [None] + + if "RUNS" in dataset_info: + RUNS = dataset_info["RUNS"] + else: + RUNS = None + if RUNS is None: + RUNS = {task: [None] for task in TASKS} + + if "{dataset}" in dataset_info["main_root"]: + main_root = dataset_info["main_root"].replace( + "{dataset}", dataset_info["dataset"] ) - # fill task labels with 0 (rest) and k (task's index) - task_labels = np.multiply(event_labels != 0, TASKS.index(task)) - ################################# SAVE ################################# - # save the ROI time series and task data - task_data = { - "task": task, - "task_labels": task_labels, - "task_types": TASKS, - "event_labels": event_labels, - "event_types": event_types, - "events": events, - "Fs_task": Fs_task, - "TR_mri": TR_mri, - "num_time_mri": num_time_mri, - } - subj_folder = f"{subj}_{task}" - if not os.path.exists(f"{output_root}/{subj_folder}/"): - os.makedirs(f"{output_root}/{subj_folder}/") - np.save(f"{output_root}/{subj_folder}/time_series.npy", time_series) - np.save(f"{output_root}/{subj_folder}/task_data.npy", task_data) - -print( - f"subject-level ROI signal extraction CODE finished running ... for subject: {subj} ..." -) + else: + main_root = dataset_info["main_root"] + + if "{main_root}" in dataset_info["bids_root"]: + bids_root = dataset_info["bids_root"].replace("{main_root}", main_root) + elif "{dataset}" in dataset_info["bids_root"]: + bids_root = dataset_info["bids_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + bids_root = dataset_info["bids_root"] + + if "{main_root}" in dataset_info["fmriprep_root"]: + fmriprep_root = dataset_info["fmriprep_root"].replace("{main_root}", main_root) + elif "{dataset}" in dataset_info["fmriprep_root"]: + fmriprep_root = dataset_info["fmriprep_root"].replace( + "{dataset}", dataset_info["dataset"] + ) + else: + fmriprep_root = dataset_info["fmriprep_root"] + + if "{main_root}" in dataset_info["roi_root"]: + output_root = dataset_info["roi_root"].replace("{main_root}", main_root) + else: + output_root = dataset_info["roi_root"] + + trial_type_label = dataset_info["trial_type_label"] + rest_labels = dataset_info["rest_labels"] + + for session in SESSIONS: + for task in TASKS: + run_roi_signal_extraction( + subj=participant_id, + task=task, + bids_root=bids_root, + fmriprep_root=fmriprep_root, + bold_suffix=dataset_info["bold_suffix"], + output_root=output_root, + session=session, + RUNS=RUNS[task], + trial_type_label=trial_type_label[task], + rest_labels=rest_labels[task], + denoising_strategy=denoising_strategy, + ) + + print( + f"subject-level ROI signal extraction CODE finished running ... for subject: {participant_id} ..." + ) + #################################################################### diff --git a/task_dFC/run_scripts_sge/dataset_info.json b/task_dFC/run_scripts_sge/dataset_info.json new file mode 100644 index 0000000..b01dbda --- /dev/null +++ b/task_dFC/run_scripts_sge/dataset_info.json @@ -0,0 +1,27 @@ +{ + "dataset" : "", + "main_root" : "/path/to/your/data/{dataset}", + "bids_root" : "/path/to/your/data/{dataset}/bids", + "fmriprep_root" : "/path/to/your/data/{dataset}/derivatives/fmriprep/23.1.3/output", + "roi_root" : "{main_root}/derivatives/ROI_timeseries", + "fitted_measures_root" : "{main_root}/derivatives/fitted_MEASURES", + "dFC_root" : "{main_root}/derivatives/dFC_assessed", + "ML_root" : "{main_root}/derivatives/ML", + "reports_root" : "{main_root}/derivatives/reports", + "bold_suffix" : "_space-MNI152NLin2009cAsym_res-2_desc-preproc_bold.nii.gz", + "SESSIONS" : [ + "ses-1" + ], + "TASKS" : [ + "task-A" + ], + "RUNS" : { + "task-A": ["run-01", "run-02", "run-03", "run-04", "run-05", "run-06"] + }, + "trial_type_label" : { + "task-A": "trial_type" + }, + "rest_labels" : { + "task-A": ["rest", "Rest"] + } +} diff --git a/task_dFC/run_scripts_sge/global_configs.json b/task_dFC/run_scripts_sge/global_configs.json new file mode 100644 index 0000000..04d15c4 --- /dev/null +++ b/task_dFC/run_scripts_sge/global_configs.json @@ -0,0 +1,170 @@ +{ + "DATASET_NAME": "", + "VISIT_IDS": [ + "", + "" + ], + "SESSION_IDS": [ + "", + "" + ], + "SUBSTITUTIONS": { + "[[NIPOPPY_DPATH_CONTAINERS]]": "/path/to/your/container_store/nipoppy", + "[[HEUDICONV_HEURISTIC_FILE]]": "", + "[[DCM2BIDS_CONFIG_FILE]]": "", + "[[FREESURFER_LICENSE_FILE]]": "/path/to/your/freesurfer/license.txt", + "[[TEMPLATEFLOW_HOME]]": "/path/to/your/templateflow" + }, + "DICOM_DIR_PARTICIPANT_FIRST": true, + "CONTAINER_CONFIG": { + "COMMAND": "apptainer", + "ARGS": [ + "--cleanenv" + ] + }, + "BIDS_PIPELINES": [ + { + "NAME": "heudiconv", + "VERSION": "0.12.2", + "CONTAINER_INFO": { + "FILE": "[[NIPOPPY_DPATH_CONTAINERS]]/[[PIPELINE_NAME]]_[[PIPELINE_VERSION]].sif", + "URI": "docker://nipy/[[PIPELINE_NAME]]:[[PIPELINE_VERSION]]" + }, + "STEPS": [ + { + "NAME": "prepare", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor.json" + }, + { + "NAME": "convert", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor.json", + "CONTAINER_CONFIG": { + "ARGS": [ + "--bind", + "[[HEUDICONV_HEURISTIC_FILE]]" + ] + }, + "UPDATE_DOUGHNUT": true + } + ] + }, + { + "NAME": "dcm2bids", + "VERSION": "3.1.0", + "CONTAINER_INFO": { + "FILE": "[[NIPOPPY_DPATH_CONTAINERS]]/[[PIPELINE_NAME]]_[[PIPELINE_VERSION]].sif", + "URI": "docker://unfmontreal/[[PIPELINE_NAME]]:[[PIPELINE_VERSION]]" + }, + "STEPS": [ + { + "NAME": "prepare", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor-dcm2bids_helper.json" + }, + { + "NAME": "convert", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor-dcm2bids.json", + "CONTAINER_CONFIG": { + "ARGS": [ + "--bind", + "[[DCM2BIDS_CONFIG_FILE]]" + ] + }, + "UPDATE_DOUGHNUT": true + } + ] + }, + { + "NAME": "bidscoin", + "VERSION": "4.3.2", + "STEPS": [ + { + "NAME": "prepare", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor-bidsmapper.json", + "ANALYSIS_LEVEL": "group" + }, + { + "NAME": "edit", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor-bidseditor.json", + "ANALYSIS_LEVEL": "group" + }, + { + "NAME": "convert", + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation-[[STEP_NAME]].json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor-bidscoiner.json", + "ANALYSIS_LEVEL": "participant", + "UPDATE_DOUGHNUT": true + } + ] + } + ], + "PROC_PIPELINES": [ + { + "NAME": "fmriprep", + "VERSION": "23.1.3", + "CONTAINER_INFO": { + "FILE": "[[NIPOPPY_DPATH_CONTAINERS]]/[[PIPELINE_NAME]]_[[PIPELINE_VERSION]].sif", + "URI": "docker://nipreps/[[PIPELINE_NAME]]:[[PIPELINE_VERSION]]" + }, + "CONTAINER_CONFIG": { + "ENV_VARS": { + "TEMPLATEFLOW_HOME": "[[TEMPLATEFLOW_HOME]]" + }, + "ARGS": [ + "--bind", + "[[FREESURFER_LICENSE_FILE]]", + "--bind", + "[[TEMPLATEFLOW_HOME]]" + ] + }, + "STEPS": [ + { + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation.json", + "GENERATE_PYBIDS_DATABASE": false, + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor.json", + "TRACKER_CONFIG_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/tracker_config.json" + } + ] + }, + { + "NAME": "freesurfer", + "VERSION": "7.3.2", + "DESCRIPTION": "Freesurfer version associated with fMRIPrep 23.1.3", + "STEPS": [ + { + "TRACKER_CONFIG_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/tracker_config.json" + } + ] + }, + { + "NAME": "mriqc", + "VERSION": "23.1.0", + "CONTAINER_INFO": { + "FILE": "[[NIPOPPY_DPATH_CONTAINERS]]/[[PIPELINE_NAME]]_[[PIPELINE_VERSION]].sif", + "URI": "docker://nipreps/[[PIPELINE_NAME]]:[[PIPELINE_VERSION]]" + }, + "CONTAINER_CONFIG": { + "ENV_VARS": { + "TEMPLATEFLOW_HOME": "[[TEMPLATEFLOW_HOME]]" + }, + "ARGS": [ + "--bind", + "[[TEMPLATEFLOW_HOME]]" + ] + }, + "STEPS": [ + { + "INVOCATION_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/invocation.json", + "DESCRIPTOR_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/descriptor.json", + "TRACKER_CONFIG_FILE": "[[NIPOPPY_DPATH_PIPELINES]]/[[PIPELINE_NAME]]-[[PIPELINE_VERSION]]/tracker_config.json" + } + ] + } + ], + "CUSTOM": {} +} diff --git a/task_dFC/run_scripts_sge/methods_config.json b/task_dFC/run_scripts_sge/methods_config.json new file mode 100644 index 0000000..722b4ff --- /dev/null +++ b/task_dFC/run_scripts_sge/methods_config.json @@ -0,0 +1,40 @@ +{ + "params_methods" : { + "W": 44, + "n_overlap": 1.0, + "sw_method": "pear_corr", + "tapered_window": true, + "TF_method": "WTC", + "clstr_base_measure": "SlidingWindow", + "clstr_distance": "manhattan", + "hmm_iter": 20, + "dhmm_obs_state_ratio": 0.666, + "n_states": 5, + "n_subj_clstrs": 10, + "verbose": 0, + "n_jobs_sw": 8, + "backend_sw": "threading", + "n_jobs_tf": 2, + "backend_tf": "loky", + "n_jobs_swc": null, + "backend_swc": null, + "normalization": true, + "num_subj": null, + "num_time_point": null + }, + "MEASURES_name_lst" : [ + "SlidingWindow", + "Time-Freq", + "CAP", + "ContinuousHMM", + "Windowless", + "Clustering", + "DiscreteHMM" + ], + "alter_hparams" : [], + "params_multi_analysis" : { + "n_jobs": 8, + "verbose": 0, + "backend": "loky" + } +} diff --git a/task_dFC/run_scripts_sge/multi_dataset_info.json b/task_dFC/run_scripts_sge/multi_dataset_info.json new file mode 100644 index 0000000..de0cf2b --- /dev/null +++ b/task_dFC/run_scripts_sge/multi_dataset_info.json @@ -0,0 +1,38 @@ +{ + "output_root": "/path/to/your/data/multi_dataset_analysis/results", + "real_data": { + "main_root": "/path/to/your/data/openneuro", + "DATASETS": [ + "ds001242", "ds002236", "ds002647", + "ds002843", "ds002994", + "ds003465", "ds003612", "ds003823", + "ds004044", "ds004349", "ds004359", + "ds004556", "ds004746", "ds004791", + "ds004848", "ds005038" + ], + "TASKS_to_include": [ + "task-arithmetic", "task-AudSem", "task-Axcpt", + "task-Cuedts", "task-emotionRegulation", "task-execution","task-expo", + "task-fearlearning", "task-feedback", "task-fribBids", "task-IHG", + "task-imagery", "task-itc", "task-localiser", "task-Localizer", + "task-matching", "task-motor", "task-paingen", "task-ppalocalizer", + "task-recall", "task-risk", "task-ST", "task-Stern", + "task-Stroop", "task-VisRhyme", "task-VisSem", "task-VisSpell", + "task-vswm" + ] + }, + "simulated_data": { + "main_root": "/path/to/your/data/simulated", + "DATASETS": [ + "ds000001", "ds000002", "ds000003", "ds000004", "ds000005", "ds000006" + ], + "TASKS_to_include": [ + "task-Axcpt", "task-Cuedts", "task-Stern", "task-Stroop", + "task-lowFreqLongRest", "task-lowFreqShortRest", "task-lowFreqShortTask", + "task-imagery", "task-execution", + "task-itc", "task-risk", + "task-Localizer", + "task-ppalocalizer" + ] + } +} diff --git a/task_dFC/run_scripts_sge/run_FCS.sh b/task_dFC/run_scripts_sge/run_FCS.sh new file mode 100644 index 0000000..9601a5c --- /dev/null +++ b/task_dFC/run_scripts_sge/run_FCS.sh @@ -0,0 +1,35 @@ +#!/bin/sh +# +#$ -N fit_fcs_job +#$ -o logs/fcs_out.txt +#$ -e logs/fcs_err.txt +#$ -l h_rt=168:00:00 +#$ -pe smp 8 +#$ -l h_vmem=8g +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# For conda environments, replace the two lines above with: +# CONDA_SH="/path/to/conda/etc/profile.d/conda.sh" +# CONDA_ENV="pydfc" +# ----------------------------------------------------------- + +DATASET_INFO="./dataset_info.json" +METHODS_CONFIG="./methods_config.json" + +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 + +# Activate virtual environment +source "$VENV_PATH" +# For conda: source "$CONDA_SH" && conda activate "$CONDA_ENV" + +python "$PYDFC_CODE_DIR/task_dFC/FCS_estimate.py" \ +--dataset_info $DATASET_INFO \ +--methods_config $METHODS_CONFIG + +deactivate diff --git a/task_dFC/run_scripts_sge/run_ML.sh b/task_dFC/run_scripts_sge/run_ML.sh new file mode 100644 index 0000000..2cc1339 --- /dev/null +++ b/task_dFC/run_scripts_sge/run_ML.sh @@ -0,0 +1,23 @@ +#!/bin/sh +# +#$ -N ml_job +#$ -o logs/ML_out.txt +#$ -e logs/ML_err.txt +#$ -pe smp 8 +#$ -l h_vmem=16g +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +DATASET_INFO="./dataset_info.json" + +# Activate virtual environment +source "$VENV_PATH" + +python "$PYDFC_CODE_DIR/task_dFC/ML.py" \ +--dataset_info $DATASET_INFO + +deactivate diff --git a/task_dFC/run_scripts_sge/run_across_dataset_analysis.sh b/task_dFC/run_scripts_sge/run_across_dataset_analysis.sh new file mode 100644 index 0000000..34a0fc8 --- /dev/null +++ b/task_dFC/run_scripts_sge/run_across_dataset_analysis.sh @@ -0,0 +1,47 @@ +#!/bin/sh +# +#$ -N across_dataset_analysis +#$ -o logs/across_dataset_analysis_out.txt +#$ -e logs/across_dataset_analysis_err.txt +#$ -l h_rt=05:00:00 +#$ -l h_vmem=32g +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +set -euo pipefail + +mkdir -p logs +source "$VENV_PATH" + +MULTI_DATASET_INFO="$PYDFC_CODE_DIR/task_dFC/run_scripts_sge/multi_dataset_info.json" + +SCRIPT_NAME=${1:-} +SIMUL_OR_REAL=${2:-real} +SCRIPT_DIR="$PYDFC_CODE_DIR/task_dFC/multi_dataset_analysis" +SCRIPT_PATH="$SCRIPT_DIR/$SCRIPT_NAME" + +if [ -z "$SCRIPT_NAME" ]; then + echo "Usage: qsub run_across_dataset_analysis.sh [real|simulated]" + exit 1 +fi + +if [ ! -f "$SCRIPT_PATH" ]; then + echo "Error: Script '$SCRIPT_PATH' not found." + exit 1 +fi + +case "$SCRIPT_NAME" in + performance_predict.py | performance_factor.py | ml_results.py | dfc_visualization.py | embedding_visualization.py | sample_matrix_visualization.py | task_presence_binarization.py | task_timing_stats.py | cohensd.py) + python "$SCRIPT_PATH" --multi_dataset_info "$MULTI_DATASET_INFO" --simul_or_real "$SIMUL_OR_REAL" + ;; + *) + echo "Unknown script: $SCRIPT_NAME" + exit 1 + ;; +esac + +deactivate diff --git a/task_dFC/run_scripts_sge/run_dFC.sh b/task_dFC/run_scripts_sge/run_dFC.sh new file mode 100644 index 0000000..463fbcf --- /dev/null +++ b/task_dFC/run_scripts_sge/run_dFC.sh @@ -0,0 +1,33 @@ +#!/bin/sh +# +#$ -N assess_dfc_job +#$ -o logs/dfc_out.txt +#$ -e logs/dfc_err.txt +#$ -l h_rt=24:00:00 +#$ -l h_vmem=32g +#$ -t 1-NSUBJECTS +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +SUBJECT_LIST="./subj_list.txt" +DATASET_INFO="./dataset_info.json" +METHODS_CONFIG="./methods_config.json" + +echo "Number subjects found: `cat $SUBJECT_LIST | wc -l`" + +SUBJECT_ID=`sed -n "${SGE_TASK_ID}p" $SUBJECT_LIST` +echo "Subject ID: $SUBJECT_ID" + +# Activate virtual environment +source "$VENV_PATH" + +python "$PYDFC_CODE_DIR/task_dFC/dFC_assessment.py" \ +--dataset_info $DATASET_INFO \ +--methods_config $METHODS_CONFIG \ +--participant_id $SUBJECT_ID + +deactivate diff --git a/task_dFC/run_scripts_sge/run_fmriprep.sh b/task_dFC/run_scripts_sge/run_fmriprep.sh new file mode 100644 index 0000000..075b581 --- /dev/null +++ b/task_dFC/run_scripts_sge/run_fmriprep.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# +#$ -N fmriprep_job +#$ -o logs/fmriprep_out.log +#$ -e logs/fmriprep_err.log +#$ -l h_vmem=16g +#$ -pe smp 8 +#$ -t 1-NSUBJECTS +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +NIPOPPY_VENV_PATH="/path/to/your/nipoppy_venv/bin/activate" +# ----------------------------------------------------------- + +module load apptainer + +source "$NIPOPPY_VENV_PATH" + +SUBJECT_LIST="./subj_list.txt" + +echo "Number subjects found: $(wc -l < $SUBJECT_LIST)" + +SUBJECT_ID=$(sed -n "${SGE_TASK_ID}p" $SUBJECT_LIST) +echo "Subject ID: $SUBJECT_ID" + +nipoppy run \ +"$(dirname "$(pwd)")" \ +--pipeline fmriprep \ +--participant-id $SUBJECT_ID + +deactivate diff --git a/task_dFC/run_scripts_sge/run_nifti_to_roi.sh b/task_dFC/run_scripts_sge/run_nifti_to_roi.sh new file mode 100644 index 0000000..95c4b73 --- /dev/null +++ b/task_dFC/run_scripts_sge/run_nifti_to_roi.sh @@ -0,0 +1,39 @@ +#!/bin/sh +# +#$ -N extract_roi_job +#$ -o logs/roi_out.txt +#$ -e logs/roi_err.txt +#$ -l h_rt=24:00:00 +#$ -l h_vmem=64g +#$ -t 1-NSUBJECTS +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +# ----------------------------- +# Inputs +# ----------------------------- +SUBJECT_LIST="./subj_list.txt" +DATASET_INFO="./dataset_info.json" +DENOISING_STRATEGY=${1:-simple} + +echo "Denoising strategy: $DENOISING_STRATEGY" +echo "Number of subjects: $(wc -l < "$SUBJECT_LIST")" + +SUBJECT_ID=$(sed -n "${SGE_TASK_ID}p" "$SUBJECT_LIST") +echo "Subject ID: $SUBJECT_ID" + +# ----------------------------- +# Environment +# ----------------------------- +source "$VENV_PATH" + +python "$PYDFC_CODE_DIR/task_dFC/nifti_to_roi_signal.py" \ + --dataset_info $DATASET_INFO \ + --participant_id $SUBJECT_ID \ + --denoising_strategy $DENOISING_STRATEGY + +deactivate diff --git a/task_dFC/run_scripts_sge/run_report.sh b/task_dFC/run_scripts_sge/run_report.sh new file mode 100644 index 0000000..a1c13b4 --- /dev/null +++ b/task_dFC/run_scripts_sge/run_report.sh @@ -0,0 +1,25 @@ +#!/bin/sh +# +#$ -N report_job +#$ -o logs/report_out.txt +#$ -e logs/report_err.txt +#$ -l h_rt=24:00:00 +#$ -l h_vmem=64g +#$ -q YOUR_QUEUE + +# ---- Cluster configuration (set these for your system) ---- +VENV_PATH="/path/to/your/venv/bin/activate" +PYDFC_CODE_DIR="/path/to/pydfc" +# ----------------------------------------------------------- + +DATASET_INFO="./dataset_info.json" +SUBJ_LIST="./subj_list.txt" + +# Activate virtual environment +source "$VENV_PATH" + +python "$PYDFC_CODE_DIR/task_dFC/generate_report.py" \ +--dataset_info $DATASET_INFO \ +--subj_list $SUBJ_LIST + +deactivate diff --git a/task_dFC/run_scripts_slurm/dataset_info.json b/task_dFC/run_scripts_slurm/dataset_info.json new file mode 100644 index 0000000..b01dbda --- /dev/null +++ b/task_dFC/run_scripts_slurm/dataset_info.json @@ -0,0 +1,27 @@ +{ + "dataset" : "", + "main_root" : "/path/to/your/data/{dataset}", + "bids_root" : "/path/to/your/data/{dataset}/bids", + "fmriprep_root" : "/path/to/your/data/{dataset}/derivatives/fmriprep/23.1.3/output", + "roi_root" : "{main_root}/derivatives/ROI_timeseries", + "fitted_measures_root" : "{main_root}/derivatives/fitted_MEASURES", + "dFC_root" : "{main_root}/derivatives/dFC_assessed", + "ML_root" : "{main_root}/derivatives/ML", + "reports_root" : "{main_root}/derivatives/reports", + "bold_suffix" : "_space-MNI152NLin2009cAsym_res-2_desc-preproc_bold.nii.gz", + "SESSIONS" : [ + "ses-1" + ], + "TASKS" : [ + "task-A" + ], + "RUNS" : { + "task-A": ["run-01", "run-02", "run-03", "run-04", "run-05", "run-06"] + }, + "trial_type_label" : { + "task-A": "trial_type" + }, + "rest_labels" : { + "task-A": ["rest", "Rest"] + } +} diff --git a/task_dFC/run_scripts_slurm/descriptor.json b/task_dFC/run_scripts_slurm/descriptor.json new file mode 100644 index 0000000..f039583 --- /dev/null +++ b/task_dFC/run_scripts_slurm/descriptor.json @@ -0,0 +1,668 @@ +{ + "name": "fmriprep", + "description": "fmriprep", + "tool-version": "23.1.3", + "schema-version": "0.5", + "command-line": "[[NIPOPPY_CONTAINER_COMMAND]] --bind $SLURM_TMPDIR:/work [[NIPOPPY_FPATH_CONTAINER]] [BIDS_DIR] [OUTPUT_DIR] [ANALYSIS_LEVEL] [SKIP_BIDS_VALIDATION] [PARTICIPANT_LABEL] [TASK_ID] [ECHO_IDX] [BIDS_FILTERS] [ANAT_DERIVATIVES] [BIDS_DATABASE_DIR] [NPROCS] [OMP_NTHREADS] [MEMORY_GB] [LOW_MEM] [USE_PLUGIN] [SLOPPY] [ANAT_ONLY] [BOILERPLATE_ONLY] [REPORTS_ONLY] [IGNORE] [OUTPUT_SPACES] [LONGITUDINAL] [BOLD2T1W_INIT] [BOLD2T1W_DOF] [USE_BBR] [SLICE_TIME_REF] [DUMMY_SCANS] [_RANDOM_SEED] [ME_T2S_FIT_METHOD] [OUTPUT_LAYOUT] [ME_OUTPUT_ECHOS] [MEDIAL_SURFACE_NAN] [PROJECT_GOODVOXELS] [MD_ONLY_BOILERPLATE] [CIFTI_OUTPUT] [USE_AROMA] [AROMA_MELODIC_DIM] [AROMA_ERR_ON_WARN] [REGRESSORS_ALL_COMPS] [REGRESSORS_FD_TH] [REGRESSORS_DVARS_TH] [SKULL_STRIP_TEMPLATE] [SKULL_STRIP_FIXED_SEED] [SKULL_STRIP_T1W] [FMAP_BSPLINE] [FMAP_NO_DEMEAN] [USE_SYN_SDC] [FORCE_SYN] [FS_LICENSE_FILE] [FS_SUBJECTS_DIR] [HIRES] [SKIP_RECONALL] [TRACK_CARBON] [COUNTRY_CODE] [VERSION] [VERBOSE_COUNT] [WORK_DIR] [CLEAN_WORKDIR] [RESOURCE_MONITOR] [CONFIG_FILE] [WRITE_GRAPH] [STOP_ON_FIRST_CRASH] [NOTRACK] [DEBUG]", + "inputs": [ + { + "id": "bids_dir", + "name": "bids_dir", + "description": "The root folder of a BIDS valid dataset (sub-XXXXX folders should be found at the top level in this folder).", + "optional": false, + "type": "String", + "value-key": "[BIDS_DIR]" + }, + { + "id": "output_dir", + "name": "output_dir", + "description": "The output path for the outcomes of preprocessing and visual reports", + "optional": false, + "type": "String", + "value-key": "[OUTPUT_DIR]" + }, + { + "id": "analysis_level", + "name": "analysis_level", + "description": "Processing stage to be run, only \"participant\" in the case of fMRIPrep (see BIDS-Apps specification).", + "optional": false, + "type": "String", + "value-key": "[ANALYSIS_LEVEL]", + "value-choices": [ + "participant" + ] + }, + { + "id": "skip_bids_validation", + "name": "skip_bids_validation", + "description": "Assume the input dataset is BIDS compliant and skip the validation", + "optional": true, + "type": "Flag", + "value-key": "[SKIP_BIDS_VALIDATION]", + "command-line-flag": "--skip_bids_validation" + }, + { + "id": "participant_label", + "name": "participant_label", + "description": "A space delimited list of participant identifiers or a single identifier (the sub- prefix can be removed)", + "optional": true, + "type": "String", + "value-key": "[PARTICIPANT_LABEL]", + "list": true, + "command-line-flag": "--participant-label" + }, + { + "id": "task_id", + "name": "task_id", + "description": "Select a specific task to be processed", + "optional": true, + "type": "String", + "value-key": "[TASK_ID]", + "command-line-flag": "-t" + }, + { + "id": "echo_idx", + "name": "echo_idx", + "description": "Select a specific echo to be processed in a multiecho series", + "optional": true, + "type": "Number", + "value-key": "[ECHO_IDX]", + "command-line-flag": "--echo-idx" + }, + { + "id": "bids_filters", + "name": "bids_filters", + "description": "A JSON file describing custom BIDS input filters using PyBIDS. For further details, please check out https://fmriprep.readthedocs.io/en/0/faq.html#how-do-I-select-only-certain-files-to-be-input-to-fMRIPrep", + "optional": true, + "type": "String", + "value-key": "[BIDS_FILTERS]", + "command-line-flag": "--bids-filter-file" + }, + { + "id": "anat_derivatives", + "name": "anat_derivatives", + "description": "Reuse the anatomical derivatives from another fMRIPrep run or calculated with an alternative processing tool (NOT RECOMMENDED).", + "optional": true, + "type": "String", + "value-key": "[ANAT_DERIVATIVES]", + "command-line-flag": "--anat-derivatives" + }, + { + "id": "bids_database_dir", + "name": "bids_database_dir", + "description": "Path to a PyBIDS database folder, for faster indexing (especially useful for large datasets). Will be created if not present.", + "optional": true, + "type": "String", + "value-key": "[BIDS_DATABASE_DIR]", + "command-line-flag": "--bids-database-dir" + }, + { + "id": "nprocs", + "name": "nprocs", + "description": "Maximum number of threads across all processes", + "optional": true, + "type": "String", + "value-key": "[NPROCS]", + "command-line-flag": "--nprocs" + }, + { + "id": "omp_nthreads", + "name": "omp_nthreads", + "description": "Maximum number of threads per-process", + "optional": true, + "type": "String", + "value-key": "[OMP_NTHREADS]", + "command-line-flag": "--omp-nthreads" + }, + { + "id": "memory_gb", + "name": "memory_gb", + "description": "Upper bound memory limit for fMRIPrep processes", + "optional": true, + "type": "String", + "value-key": "[MEMORY_GB]", + "command-line-flag": "--mem" + }, + { + "id": "low_mem", + "name": "low_mem", + "description": "Attempt to reduce memory usage (will increase disk usage in working directory)", + "optional": true, + "type": "Flag", + "value-key": "[LOW_MEM]", + "command-line-flag": "--low-mem" + }, + { + "id": "use_plugin", + "name": "use_plugin", + "description": "Nipype plugin configuration file", + "optional": true, + "type": "String", + "value-key": "[USE_PLUGIN]", + "command-line-flag": "--use-plugin" + }, + { + "id": "sloppy", + "name": "sloppy", + "description": "Use low-quality tools for speed - TESTING ONLY", + "optional": true, + "type": "Flag", + "value-key": "[SLOPPY]", + "command-line-flag": "--sloppy" + }, + { + "id": "anat_only", + "name": "anat_only", + "description": "Run anatomical workflows only", + "optional": true, + "type": "Flag", + "value-key": "[ANAT_ONLY]", + "command-line-flag": "--anat-only" + }, + { + "id": "boilerplate_only", + "name": "boilerplate_only", + "description": "Generate boilerplate only", + "optional": true, + "type": "Flag", + "value-key": "[BOILERPLATE_ONLY]", + "command-line-flag": "--boilerplate-only" + }, + { + "id": "reports_only", + "name": "reports_only", + "description": "Only generate reports, don't run workflows. This will only rerun report aggregation, not reportlet generation for specific nodes.", + "optional": true, + "type": "Flag", + "value-key": "[REPORTS_ONLY]", + "command-line-flag": "--reports-only" + }, + { + "id": "ignore", + "name": "ignore", + "description": "Ignore selected aspects of the input dataset to disable corresponding parts of the workflow (a space delimited list)", + "optional": true, + "type": "String", + "value-key": "[IGNORE]", + "list": true, + "value-choices": [ + "fieldmaps", + "slicetiming", + "sbref", + "t2w", + "flair" + ], + "command-line-flag": "--ignore" + }, + { + "id": "output_spaces", + "name": "output_spaces", + "description": "Standard and non-standard spaces to resample anatomical and functional images to. Standard spaces may be specified by the form ``[:cohort-