diff --git a/README.rst b/README.rst index e774b97..f9ba123 100644 --- a/README.rst +++ b/README.rst @@ -165,3 +165,21 @@ If you are new to **pydfc**, we recommend starting with: This optional AI-assisted workflow is designed to complement — not replace — the documentation and example scripts. + +Generating New dFC Methods with AI +----------------------------------- + +You can ask an AI coding assistant (Claude, Copilot, Codex, etc.) to implement +brand-new dFC methods and add them directly to ``pydfc``. Just describe what +you want at whatever level of specificity feels right: + +- *"Generate 5 new creative dFC methods."* +- *"Generate 3 new state-based methods."* +- *"Implement a dFC method based on Granger causality."* +- *"Add a method that uses Riemannian geometry on covariance matrices."* +- *"Here is a paper — implement the method it describes."* (paste the PDF or text) + +The AI will read the existing codebase, follow the conventions in +``docs/ADDING_DFC_METHODS.md``, write the new method file, and register it in +``pydfc/dfc_methods/__init__.py`` so it works immediately alongside all other +methods. diff --git a/algorithm_similarity.py b/algorithm_similarity.py index bf3ae38..f8d526d 100644 --- a/algorithm_similarity.py +++ b/algorithm_similarity.py @@ -26,6 +26,7 @@ import re import sys from pathlib import Path + import numpy as np # Only calls whose resolved root module starts with one of these are kept as @@ -129,7 +130,7 @@ def _make_unique_labels(filepaths): def _hierarchical_cluster_order(matrix, cluster_method="average"): """Return indices that order similar methods next to each other.""" - from scipy.cluster.hierarchy import linkage, leaves_list + from scipy.cluster.hierarchy import leaves_list, linkage from scipy.spatial.distance import squareform if matrix.shape[0] < 2: @@ -159,7 +160,9 @@ def plot_similarity_heatmap( matrix = matrix[np.ix_(order, order)] labels = [labels[i] for i in order] - fig, ax = plt.subplots(figsize=(max(8, 0.45 * len(labels)), max(6, 0.45 * len(labels)))) + fig, ax = plt.subplots( + figsize=(max(8, 0.45 * len(labels)), max(6, 0.45 * len(labels))) + ) image = ax.imshow(matrix, vmin=0.0, vmax=1.0, cmap="viridis", aspect="equal") fig.colorbar(image, ax=ax, label="AS") @@ -187,7 +190,9 @@ def save_similarity_outputs(output_dir, labels, source_paths, matrix, table): with open(output_dir / "AS_jaccard_source_paths.json", "w", encoding="utf-8") as f: json.dump(source_paths, f, indent=2) - with open(output_dir / "AS_jaccard_pairs.csv", "w", newline="", encoding="utf-8") as f: + with open( + output_dir / "AS_jaccard_pairs.csv", "w", newline="", encoding="utf-8" + ) as f: fieldnames = [ "method_a", "method_b", @@ -206,9 +211,13 @@ def save_similarity_outputs(output_dir, labels, source_paths, matrix, table): try: fig, _ = plot_similarity_heatmap(matrix, labels, cluster=True) except ImportError: - print("Skipping heatmap export because matplotlib is not available in this environment.") + print( + "Skipping heatmap export because matplotlib is not available in this environment." + ) else: - fig.savefig(str(output_dir / "AS_jaccard_heatmap.png"), dpi=200, bbox_inches="tight") + fig.savefig( + str(output_dir / "AS_jaccard_heatmap.png"), dpi=200, bbox_inches="tight" + ) import matplotlib.pyplot as plt plt.close(fig) @@ -221,7 +230,9 @@ def load_similarity_outputs(output_dir): labels = np.load(output_dir / "AS_jaccard_names.npy", allow_pickle=True).tolist() with open(output_dir / "AS_jaccard_source_paths.json", "r", encoding="utf-8") as f: source_paths = json.load(f) - with open(output_dir / "AS_jaccard_pairs.csv", "r", newline="", encoding="utf-8") as f: + with open( + output_dir / "AS_jaccard_pairs.csv", "r", newline="", encoding="utf-8" + ) as f: table = list(csv.DictReader(f)) return labels, source_paths, matrix, table @@ -264,7 +275,9 @@ def main(filepaths): print(f"{method_a:35s} vs {method_b:35s} AS = {sim:.3f} shared = {shared}") - save_similarity_outputs("algorithm_similarity_results", names, source_paths, alg_sim, pairwise_rows) + save_similarity_outputs( + "algorithm_similarity_results", names, source_paths, alg_sim, pairwise_rows + ) print("Saved outputs to algorithm_similarity_results/") diff --git a/compute_feature_similarity.py b/compute_feature_similarity.py index 8b64bd0..ebed76b 100644 --- a/compute_feature_similarity.py +++ b/compute_feature_similarity.py @@ -1,11 +1,11 @@ +import pickle import sys -import numpy as np -from pathlib import Path from collections import defaultdict +from pathlib import Path -from pydfc.comparison import SimilarityAssessment # pip install pydfc +import numpy as np -import pickle +from pydfc.comparison import SimilarityAssessment # pip install pydfc # FULL PATH usually looks like: # "{path_to_datasets}/{dataset_id}/derivatives/dFC_assessed/{subject_id}/{session_id}/*.npy" @@ -17,7 +17,7 @@ print("Missing a path to the datasets directory") print("Usage: sbatch run_dfc.sh ") sys.exit(1) - + path_to_datasets = sys.argv[1] root = Path(path_to_datasets) @@ -28,25 +28,21 @@ # where matrix.shape = (1, num_methods, num_methods) and contains the similarity values between methods similarity = defaultdict( - lambda: defaultdict( - lambda: defaultdict( - lambda: defaultdict(dict) - ) - ) + lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) ) for dataset_dir in root.iterdir(): if not dataset_dir.is_dir(): continue - + if not dataset_dir.name.startswith("ds"): continue dataset_id = dataset_dir.name dfc_dir = dataset_dir / "derivatives" / "dFC_assessed" - + if not dfc_dir.is_dir(): print(f"Skipping {dataset_id} since /derivatives/dFC_assessed not found") continue @@ -61,104 +57,108 @@ # If no session folders, treat the subject directory as the session directory # to avoid file path issues. If this case, session_id will be set to None later. session_dirs = [ - p for p in subject_dir.iterdir() - if p.is_dir() and p.name.startswith("ses-") + p for p in subject_dir.iterdir() if p.is_dir() and p.name.startswith("ses-") ] if not session_dirs: session_dirs = [subject_dir] - for session_dir in session_dirs: - + # Group files by identifier files_by_identifier = defaultdict(list) for npy_file in session_dir.glob("dFC_*.npy"): - filename = npy_file.stem # removed .npy - - _, rest = filename.split("_", 1) # e.g., "dFC", "ses-wave1bas_task-Stroop_run-2_24" - identifier, method_number = rest.rsplit("_", 1) # e.g., "ses-wave1bas_task-Stroop_run-2", "24" + filename = npy_file.stem # removed .npy - files_by_identifier[identifier].append( - (int(method_number), npy_file) - ) + _, rest = filename.split( + "_", 1 + ) # e.g., "dFC", "ses-wave1bas_task-Stroop_run-2_24" + identifier, method_number = rest.rsplit( + "_", 1 + ) # e.g., "ses-wave1bas_task-Stroop_run-2", "24" + files_by_identifier[identifier].append((int(method_number), npy_file)) # Process one identifier at a time (similarity across methods) for identifier, file_info in files_by_identifier.items(): - + # Initialize session_id and run_id as None in case they don't exist session_id = None run_id = None task_id = None # must exist, see check later to catch error. - + # Get session, task, and run from identifier (if they exist) for part in identifier.split("_"): - if part.startswith("ses-"): # e.g., "ses-wave1bas" + if part.startswith("ses-"): # e.g., "ses-wave1bas" session_id = part - elif part.startswith("run-"): # e.g., "run-2" + elif part.startswith("run-"): # e.g., "run-2" run_id = part elif part.startswith("task-"): # e.g., "task-Stroop" task_id = part - + else: - print(f"Warning: Unrecognized part '{part}' in identifier '{identifier}' \ - of subject '{subject_id}' in dataset '{dataset_id}'. Ignoring this part.") - + print( + f"Warning: Unrecognized part '{part}' in identifier '{identifier}' \ + of subject '{subject_id}' in dataset '{dataset_id}'. Ignoring this part." + ) + if task_id is None: - print(f"Error: task_id not found in identifier '{identifier}' of subject '{subject_id}' \ - in dataset '{dataset_id}'. Skipping this file.") + print( + f"Error: task_id not found in identifier '{identifier}' of subject '{subject_id}' \ + in dataset '{dataset_id}'. Skipping this file." + ) continue # Sort methods numerically file_info.sort(key=lambda x: x[0]) method_numbers = [] - - # This is a list of the dFC objects from various methods - # that share the same identifier i.e., they came from the same + + # This is a list of the dFC objects from various methods + # that share the same identifier i.e., they came from the same # BOLD time series, but they were computed using different methods # Each dFC in the list is recognized as a dFC object by pydfc dFC_lst = [] for method_num, path in file_info: method_numbers.append(method_num) - dFC_lst.append( - np.load(path, allow_pickle=True).item() - ) - - # Note: type(output) = dict with + dFC_lst.append(np.load(path, allow_pickle=True).item()) + + # Note: type(output) = dict with # dict_keys(['measure_lst', 'TS_info_lst', 'common_TRs', 'time_record_dict', 'all']) similarity_assessment = SimilarityAssessment(dFC_lst=dFC_lst) output = similarity_assessment.assess_similarity_fast(dFC_lst=dFC_lst) - - + similarity[dataset_id][subject_id][session_id][run_id][task_id] = { "matrix": output, "methods": method_numbers, } - + print(f"Finished processing subject {subject_id} in dataset {dataset_id}") - -output_dir = Path("/home/kinichen/scratch/data/pydfc_validator/similarity_assessments_complete") + +output_dir = Path( + "/home/kinichen/scratch/data/pydfc_validator/similarity_assessments_complete" +) output_dir.mkdir(parents=True, exist_ok=True) output_file = output_dir / "similarity.pkl" + # Convert to normal dict for pickling. Need to do recursively because of the nested defaultdicts. def to_dict(d): if isinstance(d, defaultdict): return {k: to_dict(v) for k, v in d.items()} return d + similarity = to_dict(similarity) with open(output_file, "wb") as f: pickle.dump(similarity, f) - - -print(f"Saved results to: {output_file}") \ No newline at end of file + + +print(f"Saved results to: {output_file}") diff --git a/docs/ADDING_DFC_METHODS.md b/docs/ADDING_DFC_METHODS.md index 9d5844f..c139aeb 100644 --- a/docs/ADDING_DFC_METHODS.md +++ b/docs/ADDING_DFC_METHODS.md @@ -259,6 +259,42 @@ time_series = self.manipulate_time_series4FCS(time_series) and include any FCS-only parameters such as `num_subj` if the method uses them. +## ML Pipeline Registration (State-Based Methods Only) + +If the new method is state-based (`is_state_based = True`), you **must** also +register it in `pydfc/ml_utils.py` inside `process_SB_features`. This function +applies the correct feature transformation before classification. Omitting this +step causes the function to return `None`, which crashes the ML pipeline with a +`TypeError` at `subject_center`. + +Determine which branch your method belongs to: + +- **Softmax → ILR** (`if` branch, methods like `CAP`, `Clustering`): use this + when `FCS_proba` stores raw distances or dissimilarity scores that must first + be converted to a probability simplex via softmax. +- **ILR only** (`elif` branch, methods like `GaussianMixtureStates`, + `ContinuousHMM`, `NMFStates`): use this when `FCS_proba` already contains + proper probabilities (non-negative, rows summing to 1). + +Add the method name to the correct branch: + +```python +# pydfc/ml_utils.py — process_SB_features +elif measure_name in [ + "ContinuousHMM", + ... + "NMFStates", # ← add your method here if FCS_proba rows sum to 1 + ... +]: + X_transformed = ilr_transform(X) +``` + +A quick check: inspect `estimate_dFC` in the method file and look at how +`FCS_proba` is set. If it is produced by a row-wise normalization +(`/ row_sums`) or a soft-assignment model (GMM, HMM posterior), it belongs in +the ILR-only branch. If it stores distances or un-normalized scores, it belongs +in the softmax + ILR branch. + ## Package Export After adding a method file, update: @@ -376,3 +412,7 @@ against established methods rather than interpreted in isolation. subclass. - Importing optional dependencies at package level in a way that breaks unrelated methods. +- For state-based methods: forgetting to add the method name to `process_SB_features` + in `pydfc/ml_utils.py`. The function silently returns `None` if the method is + missing from both branches, crashing the ML pipeline. See the + "ML Pipeline Registration" section above. diff --git a/feature_similarity_heatmaps.ipynb b/feature_similarity_heatmaps.ipynb index d93d0e4..470bd15 100644 --- a/feature_similarity_heatmaps.ipynb +++ b/feature_similarity_heatmaps.ipynb @@ -50,8 +50,10 @@ "\n", "with open(f\"{root}/similarity.pkl\", \"rb\") as f:\n", " similarity = pickle.load(f)\n", - " \n", - "print(similarity.keys()) # layer 1 of hierarchy is datasets, then subjects, sessions, runs, tasks, etc. (pydFC objects)" + "\n", + "print(\n", + " similarity.keys()\n", + ") # layer 1 of hierarchy is datasets, then subjects, sessions, runs, tasks, etc. (pydFC objects)" ] }, { @@ -129,7 +131,9 @@ "# print(methods_num_ex) # list of method numbers (indices) used in the similarity assessment\n", "\n", "measures = sim_ex[\"matrix\"][\"measure_lst\"]\n", - "methods = [method.MEASURE_NAME for method in measures] # extract method names from the pydfc dfc_methods objects\n", + "methods = [\n", + " method.MEASURE_NAME for method in measures\n", + "] # extract method names from the pydfc dfc_methods objects\n", "print(methods[:5])\n", "print(len(methods))" ] @@ -141,9 +145,10 @@ "metadata": {}, "outputs": [], "source": [ - "######### Helper functions to collect and aggregate similarity matrices based on filters \n", + "######### Helper functions to collect and aggregate similarity matrices based on filters\n", "# for various levels (dataset, subject, session, run, task) #########\n", "\n", + "\n", "def collect_similarity_matrices(\n", " similarity: dict,\n", " dataset_id=None,\n", @@ -155,7 +160,7 @@ " metric=\"spearman\",\n", "):\n", " \"\"\"\n", - " Collect all similarity matrices matching the specified filters. If a filter is None, \n", + " Collect all similarity matrices matching the specified filters. If a filter is None,\n", " it matches all values for that level and aggregates over/across it.\n", "\n", " Returns:\n", @@ -189,19 +194,12 @@ " if task_id is not None and task != task_id:\n", " continue\n", "\n", - " matrices.append(\n", - " task_data[\"matrix\"][similarity_key][metric]\n", - " )\n", + " matrices.append(task_data[\"matrix\"][similarity_key][metric])\n", "\n", " return matrices\n", "\n", "\n", - "\n", - "\n", - "def aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - "):\n", + "def aggregate_similarity_matrices(matrices, aggregation=\"mean\"):\n", " \"\"\"\n", " Parameters\n", " ----------\n", @@ -232,14 +230,11 @@ " aggregated = np.std(arr, axis=0)\n", "\n", " else:\n", - " raise ValueError(\n", - " f\"Unknown aggregation: {aggregation}\"\n", - " )\n", + " raise ValueError(f\"Unknown aggregation: {aggregation}\")\n", "\n", " return aggregated, len(matrices)\n", "\n", "\n", - "\n", "def plot_similarity_heatmap(\n", " matrix,\n", " aggregation_size=None,\n", @@ -249,11 +244,11 @@ " figsize=(10, 8),\n", " cmap=\"viridis\",\n", " cluster=True,\n", - " cluster_method=\"average\"\n", + " cluster_method=\"average\",\n", "):\n", - " \n", + "\n", " matrix = np.squeeze(matrix)\n", - " \n", + "\n", " # Optional hierarchical clustering to reorder methods based on similarity to each other\n", " if cluster:\n", "\n", @@ -277,12 +272,8 @@ " matrix = matrix[np.ix_(order, order)]\n", "\n", " # Reorder labels\n", - " method_names = [\n", - " method_names[i]\n", - " for i in order\n", - " ]\n", - " \n", - " \n", + " method_names = [method_names[i] for i in order]\n", + "\n", " plt.figure(figsize=figsize)\n", "\n", " sns.heatmap(\n", @@ -292,7 +283,7 @@ " yticklabels=method_names,\n", " cmap=cmap,\n", " )\n", - " \n", + "\n", " if aggregation_size is not None:\n", " title += f\" (n={aggregation_size})\"\n", "\n", @@ -335,18 +326,15 @@ " subject_id=subject_id,\n", " session_id=session_id,\n", " run_id=run_id,\n", - " task_id=task_id\n", + " task_id=task_id,\n", ")\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"mean\")\n", "\n", "plot_similarity_heatmap(\n", " aggregated,\n", " aggregation_size,\n", - " title=f\"Similarity for {dataset_id} {subject_id} {session_id} {run_id} {task_id}\"\n", + " title=f\"Similarity for {dataset_id} {subject_id} {session_id} {run_id} {task_id}\",\n", ")" ] }, @@ -372,21 +360,12 @@ "\n", "task_id = \"task-Axcpt\"\n", "\n", - "matrices = collect_similarity_matrices(\n", - " similarity,\n", - " dataset_id=dataset_id,\n", - " task_id=task_id\n", - ")\n", + "matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id, task_id=task_id)\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"mean\")\n", "\n", "plot_similarity_heatmap(\n", - " aggregated,\n", - " aggregation_size,\n", - " title=f\"Mean similarity for {dataset_id} {task_id}\"\n", + " aggregated, aggregation_size, title=f\"Mean similarity for {dataset_id} {task_id}\"\n", ")" ] }, @@ -412,21 +391,12 @@ "\n", "task_id = \"task-Cuedts\"\n", "\n", - "matrices = collect_similarity_matrices(\n", - " similarity,\n", - " dataset_id=dataset_id,\n", - " task_id=task_id\n", - ")\n", + "matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id, task_id=task_id)\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"mean\")\n", "\n", "plot_similarity_heatmap(\n", - " aggregated,\n", - " aggregation_size,\n", - " title=f\"Mean similarity for {dataset_id} {task_id}\"\n", + " aggregated, aggregation_size, title=f\"Mean similarity for {dataset_id} {task_id}\"\n", ")" ] }, @@ -452,22 +422,15 @@ "\n", "task_id = \"task-Stroop\"\n", "\n", - "matrices = collect_similarity_matrices(\n", - " similarity,\n", - " dataset_id=dataset_id,\n", - " task_id=task_id\n", - ")\n", + "matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id, task_id=task_id)\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"mean\")\n", "\n", "plot_similarity_heatmap(\n", " aggregated,\n", " aggregation_size,\n", " title=f\"Mean similarity for {dataset_id} {task_id}\",\n", - " cluster=False\n", + " cluster=False,\n", ")" ] }, @@ -491,20 +454,12 @@ "source": [ "### 3.a) Average over everything for a specific dataset ###\n", "\n", - "matrices = collect_similarity_matrices(\n", - " similarity,\n", - " dataset_id=dataset_id\n", - ")\n", + "matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id)\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"mean\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"mean\")\n", "\n", "plot_similarity_heatmap(\n", - " aggregated,\n", - " aggregation_size,\n", - " title=f\"Mean similarity for {dataset_id}\"\n", + " aggregated, aggregation_size, title=f\"Mean similarity for {dataset_id}\"\n", ")" ] }, @@ -529,20 +484,14 @@ "### 3.b) Standard deviation over everything for a specific dataset ###\n", "# Measures which method pairs are more stable vs. more variable across filters\n", "\n", - "matrices = collect_similarity_matrices(\n", - " similarity,\n", - " dataset_id=dataset_id\n", - ")\n", + "matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id)\n", "\n", - "aggregated, aggregation_size = aggregate_similarity_matrices(\n", - " matrices,\n", - " aggregation=\"std\"\n", - ")\n", + "aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation=\"std\")\n", "\n", "plot_similarity_heatmap(\n", " aggregated,\n", " aggregation_size,\n", - " title=f\"Standard deviation of similarity for {dataset_id}\"\n", + " title=f\"Standard deviation of similarity for {dataset_id}\",\n", ")" ] } diff --git a/feature_similarity_heatmaps_run.py b/feature_similarity_heatmaps_run.py index d5bf010..29ab3c3 100644 --- a/feature_similarity_heatmaps_run.py +++ b/feature_similarity_heatmaps_run.py @@ -1,14 +1,14 @@ # %% +# %% +import os import pickle -import numpy as np + import matplotlib.pyplot as plt +import numpy as np import seaborn as sns -from scipy.cluster.hierarchy import linkage, leaves_list +from scipy.cluster.hierarchy import leaves_list, linkage from scipy.spatial.distance import squareform - -# %% -import os os.makedirs("feature_similarity_results", exist_ok=True) os.makedirs("feature_similarity_results/pdf", exist_ok=True) os.makedirs("feature_similarity_results/jpg", exist_ok=True) @@ -19,8 +19,10 @@ with open(f"{root}/similarity.pkl", "rb") as f: similarity = pickle.load(f) - -print(similarity.keys()) # layer 1 of hierarchy is datasets, then subjects, sessions, runs, tasks, etc. (pydFC objects) + +print( + similarity.keys() +) # layer 1 of hierarchy is datasets, then subjects, sessions, runs, tasks, etc. (pydFC objects) # %% @@ -39,9 +41,10 @@ # %% -######### Helper functions to collect and aggregate similarity matrices based on filters +######### Helper functions to collect and aggregate similarity matrices based on filters # for various levels (dataset, subject, session, run, task) ######### + def collect_similarity_matrices( similarity: dict, dataset_id=None, @@ -53,7 +56,7 @@ def collect_similarity_matrices( metric="spearman", ): """ - Collect all similarity matrices matching the specified filters. If a filter is None, + Collect all similarity matrices matching the specified filters. If a filter is None, it matches all values for that level and aggregates over/across it. Returns: @@ -87,19 +90,12 @@ def collect_similarity_matrices( if task_id is not None and task != task_id: continue - matrices.append( - task_data["matrix"][similarity_key][metric] - ) + matrices.append(task_data["matrix"][similarity_key][metric]) return matrices - - -def aggregate_similarity_matrices( - matrices, - aggregation="mean" -): +def aggregate_similarity_matrices(matrices, aggregation="mean"): """ Parameters ---------- @@ -130,14 +126,11 @@ def aggregate_similarity_matrices( aggregated = np.std(arr, axis=0) else: - raise ValueError( - f"Unknown aggregation: {aggregation}" - ) + raise ValueError(f"Unknown aggregation: {aggregation}") return aggregated, len(matrices) - def plot_similarity_heatmap( matrix, aggregation_size=None, @@ -147,11 +140,11 @@ def plot_similarity_heatmap( figsize=(10, 8), cmap="viridis", cluster=True, - cluster_method="average" + cluster_method="average", ): - + matrix = np.squeeze(matrix) - + # Optional hierarchical clustering to reorder methods based on similarity to each other if cluster: @@ -175,12 +168,8 @@ def plot_similarity_heatmap( matrix = matrix[np.ix_(order, order)] # Reorder labels - method_names = [ - method_names[i] - for i in order - ] - - + method_names = [method_names[i] for i in order] + plt.figure(figsize=figsize) sns.heatmap( @@ -190,7 +179,7 @@ def plot_similarity_heatmap( yticklabels=method_names, cmap=cmap, ) - + if aggregation_size is not None: title += f" (n={aggregation_size})" @@ -200,15 +189,13 @@ def plot_similarity_heatmap( plt.yticks(rotation=0, fontsize=6) plt.tight_layout() - + # For running .py, save fig plt.savefig(f"feature_similarity/pdf/{title}.pdf", bbox_inches="tight") plt.savefig(f"feature_similarity/jpg/{title}.jpg", bbox_inches="tight") plt.close() - - # %% ### Average over everything (all subjects, sessions, runs, and datasets) for a specific TASK ### @@ -224,23 +211,14 @@ def plot_similarity_heatmap( ) for task_id in task_ids: - - matrices = collect_similarity_matrices( - similarity, - task_id=task_id - ) - aggregated, aggregation_size = aggregate_similarity_matrices( - matrices, - aggregation="mean" - ) + matrices = collect_similarity_matrices(similarity, task_id=task_id) - plot_similarity_heatmap( - aggregated, - aggregation_size, - title=f"{task_id}" + aggregated, aggregation_size = aggregate_similarity_matrices( + matrices, aggregation="mean" ) + plot_similarity_heatmap(aggregated, aggregation_size, title=f"{task_id}") # %% @@ -250,66 +228,41 @@ def plot_similarity_heatmap( for dataset_id in dataset_ids: - matrices = collect_similarity_matrices( - similarity, - dataset_id=dataset_id - ) + matrices = collect_similarity_matrices(similarity, dataset_id=dataset_id) aggregated, aggregation_size = aggregate_similarity_matrices( - matrices, - aggregation="mean" + matrices, aggregation="mean" ) - plot_similarity_heatmap( - aggregated, - aggregation_size, - title=f"{dataset_id}" - ) - - + plot_similarity_heatmap(aggregated, aggregation_size, title=f"{dataset_id}") # %% ### Average over EVERYTHING ### -matrices = collect_similarity_matrices( - similarity -) +matrices = collect_similarity_matrices(similarity) -aggregated, aggregation_size = aggregate_similarity_matrices( - matrices, - aggregation="mean" -) +aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation="mean") plot_similarity_heatmap( - aggregated, - aggregation_size, - title=f"Mean dFC feature similarity between methods" + aggregated, aggregation_size, title="Mean dFC feature similarity between methods" ) - # %% ### Standard deviation over EVERYTHING ### # Measures which method pairs are more stable vs. more variable across filters -matrices = collect_similarity_matrices( - similarity -) +matrices = collect_similarity_matrices(similarity) -aggregated, aggregation_size = aggregate_similarity_matrices( - matrices, - aggregation="std" -) +aggregated, aggregation_size = aggregate_similarity_matrices(matrices, aggregation="std") plot_similarity_heatmap( aggregated, aggregation_size, - title=f"Standard deviation of dFC feature similarity between methods" + title="Standard deviation of dFC feature similarity between methods", ) - - -print("Complete! Figures saved to feature_similarity_results/") \ No newline at end of file +print("Complete! Figures saved to feature_similarity_results/") diff --git a/pydfc/ml_utils.py b/pydfc/ml_utils.py index 1d30d1c..38bc429 100644 --- a/pydfc/ml_utils.py +++ b/pydfc/ml_utils.py @@ -1922,6 +1922,7 @@ def process_SB_features(X, measure_name): "MiniBatchKMeansStates", "GaussianMixtureStates", "BayesianGaussianMixtureStates", + "NMFStates", "SpectralStates", "BirchStates", "AgglomerativeStates", diff --git a/similarity_compute.py b/similarity_compute.py index 8e20049..17c6ca5 100644 --- a/similarity_compute.py +++ b/similarity_compute.py @@ -1,11 +1,11 @@ +import pickle import sys -import numpy as np -from pathlib import Path from collections import defaultdict +from pathlib import Path -from pydfc.comparison import SimilarityAssessment # pip install pydfc +import numpy as np -import pickle +from pydfc.comparison import SimilarityAssessment # pip install pydfc # FULL PATH usually looks like: # "{path_to_datasets}/{dataset_id}/derivatives/dFC_assessed/{subject_id}/{session_id}/*.npy" @@ -17,7 +17,7 @@ print("Missing a path to the datasets directory") print("Usage: sbatch run_dfc.sh ") sys.exit(1) - + path_to_datasets = sys.argv[1] root = Path(path_to_datasets) @@ -28,11 +28,7 @@ # where matrix.shape = (1, num_methods, num_methods) and contains the similarity values between methods similarity = defaultdict( - lambda: defaultdict( - lambda: defaultdict( - lambda: defaultdict(dict) - ) - ) + lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) ) for dataset_dir in root.iterdir(): @@ -43,7 +39,7 @@ dataset_id = dataset_dir.name dfc_dir = dataset_dir / "derivatives" / "dFC_assessed" - + if not dfc_dir.is_dir(): print(f"Skipping {dataset_id} since /derivatives/dFC_assessed not found") continue @@ -58,102 +54,104 @@ # If no session folders, treat the subject directory as the session directory # to avoid file path issues. If this case, session_id will be set to None later. session_dirs = [ - p for p in subject_dir.iterdir() - if p.is_dir() and p.name.startswith("ses-") + p for p in subject_dir.iterdir() if p.is_dir() and p.name.startswith("ses-") ] if not session_dirs: session_dirs = [subject_dir] - for session_dir in session_dirs: - + # Group files by identifier files_by_identifier = defaultdict(list) for npy_file in session_dir.glob("dFC_*.npy"): - filename = npy_file.stem # removed .npy - - _, rest = filename.split("_", 1) # e.g., "dFC", "ses-wave1bas_task-Stroop_run-2_24" - identifier, method_number = rest.rsplit("_", 1) # e.g., "ses-wave1bas_task-Stroop_run-2", "24" + filename = npy_file.stem # removed .npy - files_by_identifier[identifier].append( - (int(method_number), npy_file) - ) + _, rest = filename.split( + "_", 1 + ) # e.g., "dFC", "ses-wave1bas_task-Stroop_run-2_24" + identifier, method_number = rest.rsplit( + "_", 1 + ) # e.g., "ses-wave1bas_task-Stroop_run-2", "24" + files_by_identifier[identifier].append((int(method_number), npy_file)) # Process one identifier at a time (similarity across methods) for identifier, file_info in files_by_identifier.items(): - + # Initialize session_id and run_id as None in case they don't exist session_id = None run_id = None task_id = None # must exist, see check later to catch error. - + # Get session, task, and run from identifier (if they exist) for part in identifier.split("_"): - if part.startswith("ses-"): # e.g., "ses-wave1bas" + if part.startswith("ses-"): # e.g., "ses-wave1bas" session_id = part - elif part.startswith("run-"): # e.g., "run-2" + elif part.startswith("run-"): # e.g., "run-2" run_id = part elif part.startswith("task-"): # e.g., "task-Stroop" task_id = part - + else: - print(f"Warning: Unrecognized part '{part}' in identifier '{identifier}' \ - of subject '{subject_id}' in dataset '{dataset_id}'. Ignoring this part.") - + print( + f"Warning: Unrecognized part '{part}' in identifier '{identifier}' \ + of subject '{subject_id}' in dataset '{dataset_id}'. Ignoring this part." + ) + if task_id is None: - print(f"Error: task_id not found in identifier '{identifier}' of subject '{subject_id}' \ - in dataset '{dataset_id}'. Skipping this file.") + print( + f"Error: task_id not found in identifier '{identifier}' of subject '{subject_id}' \ + in dataset '{dataset_id}'. Skipping this file." + ) continue # Sort methods numerically file_info.sort(key=lambda x: x[0]) method_numbers = [] - - # This is a list of the dFC objects from various methods - # that share the same identifier i.e., they came from the same + + # This is a list of the dFC objects from various methods + # that share the same identifier i.e., they came from the same # BOLD time series, but they were computed using different methods # Each dFC in the list is recognized as a dFC object by pydfc dFC_lst = [] for method_num, path in file_info: method_numbers.append(method_num) - dFC_lst.append( - np.load(path, allow_pickle=True).item() - ) - + dFC_lst.append(np.load(path, allow_pickle=True).item()) + similarity_assessment = SimilarityAssessment(dFC_lst=dFC_lst) output = similarity_assessment.assess_similarity_fast(dFC_lst=dFC_lst) - - + similarity[dataset_id][subject_id][session_id][run_id][task_id] = { "matrix": output, "methods": method_numbers, } - + print(f"Finished processing subject {subject_id} in dataset {dataset_id}") - + output_dir = root / "similarity_assessments" output_dir.mkdir(parents=True, exist_ok=True) output_file = output_dir / "similarity.pkl" + # Convert to normal dict for pickling. Need to do recursively because of the nested defaultdicts. def to_dict(d): if isinstance(d, defaultdict): return {k: to_dict(v) for k, v in d.items()} return d + similarity = to_dict(similarity) with open(output_file, "wb") as f: pickle.dump(similarity, f) - - -print(f"Saved results to: {output_file}") \ No newline at end of file + + +print(f"Saved results to: {output_file}") diff --git a/task_dFC/multi_dataset_analysis/ml_results.py b/task_dFC/multi_dataset_analysis/ml_results.py index 0f4a5ef..61e0aca 100644 --- a/task_dFC/multi_dataset_analysis/ml_results.py +++ b/task_dFC/multi_dataset_analysis/ml_results.py @@ -252,13 +252,12 @@ def style_boxplot(ax, box_edge): line.set_zorder(1) -def overlay_method_means(ax, df_best, lower, upper): +def overlay_method_means(ax, df_best, lower, upper, lw=2.4, halfwidth=0.25): means = df_best.groupby("dFC method", observed=True)["score"].mean() yticks = ax.get_yticks() yticklabels = [tick.get_text() for tick in ax.get_yticklabels()] y_positions = {label: yticks[index] for index, label in enumerate(yticklabels)} - halfwidth = 0.25 for method, mean_score in means.items(): if method not in y_positions or pd.isna(mean_score): continue @@ -269,7 +268,7 @@ def overlay_method_means(ax, df_best, lower, upper): y_pos - halfwidth, y_pos + halfwidth, colors="#050505", - lw=2.4, + lw=lw, zorder=3, ) @@ -666,9 +665,9 @@ def plot_lollipop_pointplot( metric, simul_or_real, ): - method_medians = df_best.groupby("dFC method", observed=True)["score"].median() + method_means = df_best.groupby("dFC method", observed=True)["score"].mean() method_order_sorted = ( - method_medians.reindex(method_order).sort_values(ascending=True).index.tolist() + method_means.reindex(method_order).sort_values(ascending=True).index.tolist() ) plot_width = 10 @@ -683,10 +682,6 @@ def plot_lollipop_pointplot( else: colored_experiments = get_colored_experiment_mask(df_best, color_threshold) - neutral_palette = create_neutral_palette( - experiment_order, colored_experiments, experiment_palette - ) - box_edge = "#730800" # Lollipop: 5th–95th percentile range line + median dot per method @@ -702,7 +697,6 @@ def plot_lollipop_pointplot( ax.scatter(med, i, color=box_edge, s=28, zorder=2, linewidths=0) lower, upper = get_pointplot_limits(metric) - overlay_method_means(ax, df_best, lower, upper) sns.pointplot( data=df_best, @@ -715,21 +709,27 @@ def plot_lollipop_pointplot( errorbar=None, linestyles="", markers="o", - palette=neutral_palette, + palette=experiment_palette, ax=ax, zorder=6, ) finalize_marker_edges(ax) - resize_colored_markers(ax, experiment_order, colored_experiments, method_order_sorted) + # Only starred (top) experiments get large markers; all others stay small + resize_colored_markers( + ax, experiment_order, set(top_experiments), method_order_sorted + ) + + # Called after pointplot so yticks are populated + overlay_method_means(ax, df_best, lower, upper, lw=2.8, halfwidth=0.30) point_coordinates = extract_pointplot_coordinates( - ax, method_order_sorted, experiment_order, neutral_palette + ax, method_order_sorted, experiment_order, experiment_palette ) overlay_top_experiment_shapes( ax, df_best, point_coordinates, - neutral_palette, + experiment_palette, top_experiment_shapes=TOP_EXPERIMENT_SHAPES, ) @@ -747,7 +747,7 @@ def plot_lollipop_pointplot( plt.setp(ax.get_xticklabels(), fontsize=12) _highlight_nonaigm_labels(ax) _build_experiment_legend( - ax, experiment_order, neutral_palette, colored_experiments, top_experiments + ax, experiment_order, experiment_palette, colored_experiments, top_experiments ) figure.tight_layout()