Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions HT_LLM/similarity/AS_FS_scatterplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

algorithm_pairs = pd.read_csv(
"HT_LLM/similarity/algorithm_similarity_results/AS_BOO_weighted_jaccard_pairs.csv"
)

feature_pairs = pd.read_csv(
"HT_LLM/similarity/feature_similarity_results/FS_pairs.csv"
)

scatter_df = feature_pairs.merge(
algorithm_pairs[["pair_key", "algorithm_similarity"]],
on="pair_key",
how="inner",
)

plt.figure(figsize=(7, 6))

ax = sns.scatterplot(
data=scatter_df,
x="algorithm_similarity",
y="feature_similarity",
alpha=0.7,

)


# Label certain points by their method pair name for identification
counter = 0
for idx, row in scatter_df.iterrows():
# Only label points satisfying this condition
if row['algorithm_similarity'] > 0.5:
ax.text(
row['algorithm_similarity'] + 0.01, # Add a slight x-offset manually
row['feature_similarity'] + 0.01, # Add a slight y-offset manually
row['pair_key'], # The labelled text is the method pair's name
color='red', # Highlight color
weight='bold'
)
counter += 1

# Sanity checks for correct number of method pairs (not missing any)
print("Feature pairs:", len(feature_pairs))
print("Algorithm pairs:", len(algorithm_pairs))
print("Merged pairs:", len(scatter_df))
print("Plotted points:", counter)

plt.xlim(0, 1)
plt.ylim(-0.2, 1)
plt.axvline(x=0.5)
plt.axhline(y=0.5)

plt.xlabel("Algorithm similarity")
plt.ylabel("dFC feature similarity")
plt.title("Algorithm similarity vs. dFC feature similarity for method pairs")

plt.tight_layout()
plt.savefig("HT_LLM/similarity/algorithm_vs_feature_similarity_scatter.png", dpi=600)
plt.savefig("HT_LLM/similarity/algorithm_vs_feature_similarity_scatter.pdf")
plt.show()
14 changes: 11 additions & 3 deletions HT_LLM/similarity/algorithm_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def weighted_jaccard_similarity(counts_a, counts_b):

overlap = sum(min(counts_a[op], counts_b[op]) for op in operations)
union = sum(max(counts_a[op], counts_b[op]) for op in operations)
similarity = overlap / union

if not operations: # neither script captured any operations from tracked libraries
similarity = 0.0

else:
similarity = overlap / union

return overlap, union, similarity

Expand Down Expand Up @@ -208,6 +210,10 @@ def _make_unique_labels(filepaths):
labels.append(base if counts[base] == 1 else f"{base}_{counts[base]}")
return labels

def make_pair_key(method_a, method_b):
"""Stable key for joining method-pair outputs across scripts for AS vs FS scatterplot."""
return "+".join(sorted([method_a, method_b]))


def _hierarchical_cluster_order(matrix, cluster_method="average"):
"""Return indices that order similar methods next to each other."""
Expand Down Expand Up @@ -300,11 +306,12 @@ def save_similarity_outputs(output_dir, labels, matrix, table):
output_dir / f"AS_{METRIC_NAME}_pairs.csv", "w", newline="", encoding="utf-8"
) as f:
fieldnames = [
"pair_key",
"method_a",
"method_b",
"source_a",
"source_b",
"similarity",
"algorithm_similarity",
"weighted_overlap",
"weighted_union",
"n_shared_distinct",
Expand Down Expand Up @@ -393,11 +400,12 @@ def main(filepaths):
}
pairwise_rows.append(
{
"pair_key": make_pair_key(method_a, method_b),
"method_a": method_a,
"method_b": method_b,
"source_a": source_paths[i],
"source_b": source_paths[j],
"similarity": similarity,
"algorithm_similarity": similarity,
"weighted_overlap": weighted_overlap,
"weighted_union": weighted_union,
"n_shared_distinct": len(shared),
Expand Down
104 changes: 98 additions & 6 deletions HT_LLM/similarity/heatmaps_feature_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
import os
import pickle

import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import squareform

DEFAULT_OUTPUT_DIR = "HT_LLM/similarity/feature_similarity_results"
os.makedirs(f"{DEFAULT_OUTPUT_DIR}/pdf", exist_ok=True)
os.makedirs(f"{DEFAULT_OUTPUT_DIR}/png", exist_ok=True)
DEFAULT_OUTPUT_DIR = Path("HT_LLM/similarity/feature_similarity_results")

os.makedirs(DEFAULT_OUTPUT_DIR / "pdf", exist_ok=True)
os.makedirs(DEFAULT_OUTPUT_DIR / "png", exist_ok=True)


NON_AIGM_SET = {
Expand Down Expand Up @@ -63,6 +67,10 @@
######### Helper functions to collect and aggregate similarity matrices based on filters
# for various levels (dataset, subject, session, run, task) #########

def make_pair_key(method_a, method_b):
"""Stable key for joining method-pair outputs across scripts for AS vs FS scatterplot."""
return "+".join(sorted([method_a, method_b]))


def collect_similarity_matrices(
similarity: dict,
Expand Down Expand Up @@ -150,23 +158,96 @@ def aggregate_similarity_matrices(matrices, aggregation="mean"):
return aggregated, len(matrices)


def save_feature_similarity_outputs(
matrix,
method_names,
aggregation_size,
output_name,
aggregation="mean",
similarity_key="all",
metric="spearman",
):
"""
Save feature similarity outputs in both matrix and tidy pairwise formats.

The pairwise CSV is designed to be merged with algorithm similarity outputs
using pair_key in the AS vs FS scatterplot.
"""

matrix = np.squeeze(matrix)
method_names = list(method_names)

# Note: Saved matrix is in the original ethods order, not the reordered version for plotting
np.save(DEFAULT_OUTPUT_DIR / f"{output_name}_matrix.npy", matrix)
np.save(DEFAULT_OUTPUT_DIR / f"{output_name}_method_names.npy", np.array(method_names, dtype=object))
pairwise_path = DEFAULT_OUTPUT_DIR / f"{output_name}_pairs.csv"

rows = []

for i in range(len(method_names)):
for j in range(i + 1, len(method_names)):
method_a = method_names[i]
method_b = method_names[j]

rows.append(
{
"pair_key": make_pair_key(method_a, method_b),
"method_a": method_a,
"method_b": method_b,
"feature_similarity": matrix[i, j],
"aggregation": aggregation,
"aggregation_size": aggregation_size,
"similarity_key": similarity_key,
"metric": metric,
"output_name": output_name,
}
)

with open(pairwise_path, "w", newline="", encoding="utf-8") as f:
fieldnames = [
"pair_key",
"method_a",
"method_b",
"feature_similarity",
"aggregation",
"aggregation_size",
"similarity_key",
"metric",
"output_name",
]
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)


def plot_similarity_heatmap(
matrix,
aggregation_size=None,
method_names=methods,
ordered_method_names=None,
title="Similarity Heatmap",
annot=False,
figsize=(10, 8),
cluster=False,
cluster_method="average",
):

method_names = list(method_names)
matrix = np.squeeze(matrix)

# Use given ordering of methods, i.e., do not do hierarchical clustering
if ordered_method_names is not None:
if cluster:
raise ValueError("When using a given methods order, it doesn't make sense to also do clustering.")
ordered_method_names = list(ordered_method_names)
order = [method_names.index(name) for name in ordered_method_names]

# Reorder matrix and labels
matrix = matrix[np.ix_(order, order)]
method_names = [method_names[i] for i in order]

# Highlight non-AIGM names in a different color
highlight_color = NON_AIGM_COLOR
highlight_method_names = NON_AIGM_SET
matrix = np.squeeze(matrix)

# Optional hierarchical clustering to reorder methods based on similarity to each other
if cluster:
Expand Down Expand Up @@ -300,6 +381,16 @@ def plot_similarity_heatmap(
cluster=True,
)

save_feature_similarity_outputs(
matrix=aggregated,
method_names=methods, # not reordered since aggregated did not go through hierarchical clustering in plotting function
aggregation_size=aggregation_size,
output_name="FS",
aggregation="mean",
similarity_key="all",
metric="spearman",
)


# %%
### Standard deviation over EVERYTHING ###
Expand All @@ -314,7 +405,8 @@ def plot_similarity_heatmap(
aggregated,
aggregation_size,
title="Standard deviation of dFC feature similarity between methods",
method_names=methods_order,
method_names=methods,
ordered_method_names=methods_order
)


Expand Down
Loading