diff --git a/HT_LLM/similarity/AS_FS_scatterplot.py b/HT_LLM/similarity/AS_FS_scatterplot.py new file mode 100644 index 0000000..218430c --- /dev/null +++ b/HT_LLM/similarity/AS_FS_scatterplot.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +algorithm_pairs = pd.read_csv( + "HT_LLM/similarity/algorithm_similarity_results/AS_BOO_weighted_jaccard_pairs.csv" +) + +feature_pairs = pd.read_csv("HT_LLM/similarity/feature_similarity_results/FS_pairs.csv") + +scatter_df = feature_pairs.merge( + algorithm_pairs[["pair_key", "algorithm_similarity"]], + on="pair_key", + how="inner", +) + +plt.figure(figsize=(7, 6)) + +ax = sns.scatterplot( + data=scatter_df, + x="algorithm_similarity", + y="feature_similarity", + alpha=0.7, +) + + +# Label certain points by their method pair name for identification +counter = 0 +for idx, row in scatter_df.iterrows(): + # Only label points satisfying this condition + if row["algorithm_similarity"] > 0.5: + ax.text( + row["algorithm_similarity"] + 0.01, # Add a slight x-offset manually + row["feature_similarity"] + 0.01, # Add a slight y-offset manually + row["pair_key"], # The labelled text is the method pair's name + color="red", # Highlight color + weight="bold", + ) + counter += 1 + +# Sanity checks for correct number of method pairs (not missing any) +print("Feature pairs:", len(feature_pairs)) +print("Algorithm pairs:", len(algorithm_pairs)) +print("Merged pairs:", len(scatter_df)) +print("Plotted points:", counter) + +plt.xlim(0, 1) +plt.ylim(-0.2, 1) +plt.axvline(x=0.5) +plt.axhline(y=0.5) + +plt.xlabel("Algorithm similarity") +plt.ylabel("dFC feature similarity") +plt.title("Algorithm similarity vs. 
dFC feature similarity for method pairs") + +plt.tight_layout() +plt.savefig("HT_LLM/similarity/algorithm_vs_feature_similarity_scatter.png", dpi=600) +plt.savefig("HT_LLM/similarity/algorithm_vs_feature_similarity_scatter.pdf") +plt.show() diff --git a/HT_LLM/similarity/algorithm_similarity.py b/HT_LLM/similarity/algorithm_similarity.py index 5bf2751..3daf4e4 100644 --- a/HT_LLM/similarity/algorithm_similarity.py +++ b/HT_LLM/similarity/algorithm_similarity.py @@ -156,11 +156,13 @@ def weighted_jaccard_similarity(counts_a, counts_b): overlap = sum(min(counts_a[op], counts_b[op]) for op in operations) union = sum(max(counts_a[op], counts_b[op]) for op in operations) - similarity = overlap / union if not operations: # neither script captured any operations from tracked libraries similarity = 0.0 + else: + similarity = overlap / union + return overlap, union, similarity @@ -209,6 +211,11 @@ def _make_unique_labels(filepaths): return labels +def make_pair_key(method_a, method_b): + """Stable key for joining method-pair outputs across scripts for AS vs FS scatterplot.""" + return "+".join(sorted([method_a, method_b])) + + def _hierarchical_cluster_order(matrix, cluster_method="average"): """Return indices that order similar methods next to each other.""" @@ -300,11 +307,12 @@ def save_similarity_outputs(output_dir, labels, matrix, table): output_dir / f"AS_{METRIC_NAME}_pairs.csv", "w", newline="", encoding="utf-8" ) as f: fieldnames = [ + "pair_key", "method_a", "method_b", "source_a", "source_b", - "similarity", + "algorithm_similarity", "weighted_overlap", "weighted_union", "n_shared_distinct", @@ -393,11 +401,12 @@ def main(filepaths): } pairwise_rows.append( { + "pair_key": make_pair_key(method_a, method_b), "method_a": method_a, "method_b": method_b, "source_a": source_paths[i], "source_b": source_paths[j], - "similarity": similarity, + "algorithm_similarity": similarity, "weighted_overlap": weighted_overlap, "weighted_union": weighted_union, "n_shared_distinct": 
len(shared), diff --git a/HT_LLM/similarity/heatmaps_feature_similarity.py b/HT_LLM/similarity/heatmaps_feature_similarity.py index 2fb45b7..3bd3df7 100644 --- a/HT_LLM/similarity/heatmaps_feature_similarity.py +++ b/HT_LLM/similarity/heatmaps_feature_similarity.py @@ -3,9 +3,12 @@ # $ salloc --account=def- --mem=128G --cpus-per-task=8 --time=4:00:00 # or submit a batch job +import csv + # %% import os import pickle +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -13,9 +16,10 @@ from scipy.cluster.hierarchy import leaves_list, linkage from scipy.spatial.distance import squareform -DEFAULT_OUTPUT_DIR = "HT_LLM/similarity/feature_similarity_results" -os.makedirs(f"{DEFAULT_OUTPUT_DIR}/pdf", exist_ok=True) -os.makedirs(f"{DEFAULT_OUTPUT_DIR}/png", exist_ok=True) +DEFAULT_OUTPUT_DIR = Path("HT_LLM/similarity/feature_similarity_results") + +os.makedirs(DEFAULT_OUTPUT_DIR / "pdf", exist_ok=True) +os.makedirs(DEFAULT_OUTPUT_DIR / "png", exist_ok=True) NON_AIGM_SET = { @@ -64,6 +68,11 @@ # for various levels (dataset, subject, session, run, task) ######### +def make_pair_key(method_a, method_b): + """Stable key for joining method-pair outputs across scripts for AS vs FS scatterplot.""" + return "+".join(sorted([method_a, method_b])) + + def collect_similarity_matrices( similarity: dict, dataset_id=None, @@ -150,23 +159,101 @@ def aggregate_similarity_matrices(matrices, aggregation="mean"): return aggregated, len(matrices) +def save_feature_similarity_outputs( + matrix, + method_names, + aggregation_size, + output_name, + aggregation="mean", + similarity_key="all", + metric="spearman", +): + """ + Save feature similarity outputs in both matrix and tidy pairwise formats. + + The pairwise CSV is designed to be merged with algorithm similarity outputs + using pair_key in the AS vs FS scatterplot. 
+ """ + + matrix = np.squeeze(matrix) + method_names = list(method_names) + + # Note: Saved matrix is in the original methods order, not the reordered version for plotting + np.save(DEFAULT_OUTPUT_DIR / f"{output_name}_matrix.npy", matrix) + np.save( + DEFAULT_OUTPUT_DIR / f"{output_name}_method_names.npy", + np.array(method_names, dtype=object), + ) + pairwise_path = DEFAULT_OUTPUT_DIR / f"{output_name}_pairs.csv" + + rows = [] + + for i in range(len(method_names)): + for j in range(i + 1, len(method_names)): + method_a = method_names[i] + method_b = method_names[j] + + rows.append( + { + "pair_key": make_pair_key(method_a, method_b), + "method_a": method_a, + "method_b": method_b, + "feature_similarity": matrix[i, j], + "aggregation": aggregation, + "aggregation_size": aggregation_size, + "similarity_key": similarity_key, + "metric": metric, + "output_name": output_name, + } + ) + + with open(pairwise_path, "w", newline="", encoding="utf-8") as f: + fieldnames = [ + "pair_key", + "method_a", + "method_b", + "feature_similarity", + "aggregation", + "aggregation_size", + "similarity_key", + "metric", + "output_name", + ] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + def plot_similarity_heatmap( matrix, aggregation_size=None, method_names=methods, + ordered_method_names=None, title="Similarity Heatmap", annot=False, figsize=(10, 8), cluster=False, cluster_method="average", ): - method_names = list(method_names) + matrix = np.squeeze(matrix) + + # Use given ordering of methods, i.e., do not do hierarchical clustering + if ordered_method_names is not None: + if cluster: + raise ValueError( + "When using a given methods order, it doesn't make sense to also do clustering."
+ ) + ordered_method_names = list(ordered_method_names) + order = [method_names.index(name) for name in ordered_method_names] + + # Reorder matrix and labels + matrix = matrix[np.ix_(order, order)] + method_names = [method_names[i] for i in order] # Highlight non-AIGM names in a different color highlight_color = NON_AIGM_COLOR highlight_method_names = NON_AIGM_SET - matrix = np.squeeze(matrix) # Optional hierarchical clustering to reorder methods based on similarity to each other if cluster: @@ -300,6 +387,16 @@ def plot_similarity_heatmap( cluster=True, ) +save_feature_similarity_outputs( + matrix=aggregated, + method_names=methods, # original (unclustered) order, because `aggregated` was not reordered by the hierarchical clustering in the plotting function + aggregation_size=aggregation_size, + output_name="FS", + aggregation="mean", + similarity_key="all", + metric="spearman", +) + # %% ### Standard deviation over EVERYTHING ### @@ -314,7 +411,8 @@ def plot_similarity_heatmap( aggregated, aggregation_size, title="Standard deviation of dFC feature similarity between methods", - method_names=methods_order, + method_names=methods, + ordered_method_names=methods_order, )