diff --git a/graphconstructor/operators/__init__.py b/graphconstructor/operators/__init__.py
index 9e3c09e..5427bb3 100644
--- a/graphconstructor/operators/__init__.py
+++ b/graphconstructor/operators/__init__.py
@@ -4,6 +4,7 @@
 from .knn_selector import KNNSelector
 from .locally_adaptive_sparsification import LocallyAdaptiveSparsification
 from .marginal_likelihood import MarginalLikelihoodFilter
+from .metric_distance import MetricDistanceFilter
 from .minimum_spanning_tree import MinimumSpanningTree
 from .noise_corrected import NoiseCorrected
 from .weight_threshold import WeightThreshold
@@ -16,6 +17,7 @@
     "KNNSelector",
     "LocallyAdaptiveSparsification",
     "MarginalLikelihoodFilter",
+    "MetricDistanceFilter",
     "MinimumSpanningTree",
     "NoiseCorrected",
     "WeightThreshold",
diff --git a/graphconstructor/operators/metric_distance.py b/graphconstructor/operators/metric_distance.py
new file mode 100644
index 0000000..dfff9b8
--- /dev/null
+++ b/graphconstructor/operators/metric_distance.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import Literal
+import networkx as nx
+from distanceclosure.dijkstra import single_source_dijkstra_path_length
+from networkx.algorithms.shortest_paths.weighted import _weight_function
+from ..graph import Graph
+from .base import GraphOperator
+
+
+Mode = Literal["distance", "similarity"]
+
+
+@dataclass(slots=True)
+class MetricDistanceFilter(GraphOperator):
+    """
+    Metric distance backbone filter for undirected weighted graphs.
+
+    An edge (u, v) is removed when the Dijkstra path length between u and v,
+    computed on the ``weight`` edge attribute with ``disjunction=sum``,
+    is strictly smaller than the weight of the direct edge.
+
+    Code: https://github.com/CASCI-lab/distanceclosure/blob/master/distanceclosure/backbone.py
+
+    Parameters
+    ----------
+    weight : str, optional
+        Edge property containing distance values, by default 'weight'
+    distortion : bool, optional
+        Whether to compute and return distortion values, by default False.
+        When True, ``apply`` returns a ``(Graph, dict)`` tuple.
+    verbose : bool, optional
+        Prints statements as it computes, by default False
+    mode : {"distance", "similarity"}, optional
+        Mode given to the returned ``Graph``, by default 'distance'
+    """
+
+    weight: str = "weight"
+    distortion: bool = False
+    verbose: bool = False
+    mode: Mode = "distance"
+    supported_modes = ["distance", "similarity"]
+
+    @staticmethod
+    def _compute_distortions(D: nx.Graph, B: nx.Graph, weight="weight", disjunction=sum) -> dict:
+        """
+        Compute the distortion of every edge of ``D`` that is absent from ``B``.
+
+        Returns a dict mapping ``(u, v)`` to the direct edge weight divided by
+        the Dijkstra path length from ``u`` to ``v`` in ``B``.
+        """
+        removed = D.copy()
+        removed.remove_edges_from(B.edges())
+        weight_function = _weight_function(B, weight)
+
+        svals = {}
+        for u in removed.nodes():
+            metric_dist = single_source_dijkstra_path_length(
+                B, source=u, weight_function=weight_function, disjunction=disjunction
+            )
+            for v in removed.neighbors(u):
+                svals[(u, v)] = removed[u][v][weight] / metric_dist[v]
+
+        return svals
+
+    def _directed_filter(self, G: Graph) -> Graph:
+        raise NotImplementedError("MetricDistanceFilter is defined only for undirected graphs.")
+
+    def _undirected_filter(self, D: Graph):
+        """Return the metric backbone of ``D``, plus distortions if ``self.distortion`` is set."""
+        disjunction = sum
+
+        original = D.to_networkx()
+        backbone = original.copy()
+        weight_function = _weight_function(backbone, self.weight)
+        total = backbone.number_of_nodes()
+
+        # Visiting order (ascending weighted degree) is fixed before any edge is removed.
+        by_degree = sorted(backbone.degree(weight=self.weight), key=lambda x: x[1])
+        for i, (u, _) in enumerate(by_degree, start=1):
+            if self.verbose:
+                print(f"Backbone: Dijkstra: {i:d} of {total:d} ({i / total:.2%})")
+
+            metric_dist = single_source_dijkstra_path_length(
+                backbone, source=u, weight_function=weight_function, disjunction=disjunction
+            )
+            # Snapshot the neighbors: edges are removed while looping over them.
+            for v in list(backbone.neighbors(u)):
+                if metric_dist[v] < backbone[u][v][self.weight]:
+                    backbone.remove_edge(u, v)
+
+        # Export the attribute the filter ran on; the networkx default reads
+        # "weight" regardless of ``self.weight``.
+        sparse_adj = nx.to_scipy_sparse_array(backbone, weight=self.weight)
+        # NOTE(review): the output carries ``self.mode``, not the input graph's mode - confirm intended.
+        filtered = Graph(sparse_adj, False, True, self.mode)
+        if self.distortion:
+            svals = self._compute_distortions(original, backbone, weight=self.weight, disjunction=disjunction)
+            return filtered, svals
+        return filtered
+
+    def apply(self, G: Graph) -> Graph:
+        """Apply the filter to ``G``; directed graphs raise ``NotImplementedError``."""
+        self._check_mode_supported(G)
+        if G.directed:
+            return self._directed_filter(G)
+        return self._undirected_filter(G)
diff --git a/pyproject.toml b/pyproject.toml
index 51059b1..8f86069 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ networkx = ">3.4.0"
 pandas = ">=2.1.1"
 scipy= ">=1.14.0"
 matplotlib= ">=3.8.0"
+distanceclosure = ">=0.5.0"
 
 [tool.poetry.group.dev.dependencies]
 decorator = "^5.1.1"
diff --git a/tests/test_metric_distance.py b/tests/test_metric_distance.py
new file mode 100644
index 0000000..8de5865
--- /dev/null
+++ b/tests/test_metric_distance.py
@@ -0,0 +1,126 @@
+import networkx as nx
+import numpy as np
+import pytest
+import scipy.sparse as sp
+from graphconstructor import Graph
+from graphconstructor.operators import MetricDistanceFilter
+
+
+def _csr(data, rows, cols, n):
+    return sp.csr_matrix(
+        (np.asarray(data, float), (np.asarray(rows, int), np.asarray(cols, int))),
+        shape=(n, n),
+    )
+
+
+def simple_undirected_graph():
+    A = _csr(
+        data=[0.5, 0.5, 0.3, 0.3, 0.8, 0.8],
+        rows=[0, 1, 0, 2, 1, 2],
+        cols=[1, 0, 2, 0, 2, 1],
+        n=3,
+    )
+
+    return Graph.from_csr(A, directed=False, weighted=True, mode="similarity")
+
+
+def simple_directed_graph():
+    A = _csr(
+        data=[0.5, 0.5, 0.3],
+        rows=[0, 0, 1],
+        cols=[1, 2, 2],
+        n=3,
+    )
+
+    return Graph.from_csr(A, directed=True, weighted=True, mode="similarity")
+
+
+def test_basic_undirected_filtering():
+    G0 = simple_undirected_graph()
+
+    out = MetricDistanceFilter(distortion=False, verbose=False).apply(G0)
+
+    assert isinstance(out, Graph)
+    assert not out.directed
+    assert out.weighted
+
+    original_edges = G0.to_networkx().number_of_edges()
+    result_edges = out.to_networkx().number_of_edges()
+    assert result_edges <= original_edges
+
+
+def test_undirected_filtering_distortion():
+    G0 = simple_undirected_graph()
+
+    out = MetricDistanceFilter(distortion=True, verbose=False).apply(G0)
+
+    assert isinstance(out, tuple)
+    assert len(out) == 2
+
+    filtered_graph, svals = out
+    assert isinstance(filtered_graph, Graph)
+    assert isinstance(svals, dict)
+
+    if svals:
+        key = next(iter(svals.keys()))
+        assert isinstance(key, tuple)
+        assert len(key) == 2
+
+
+def test_directed_graph_not_implemented():
+    G0 = simple_directed_graph()
+    with pytest.raises(NotImplementedError):
+        MetricDistanceFilter().apply(G0)
+
+
+def test_edge_removal_logic():
+    G0 = simple_undirected_graph()
+    out = MetricDistanceFilter().apply(G0)
+
+    original_nx = G0.to_networkx()
+    out_nx = out.to_networkx()
+
+    assert out_nx.number_of_edges() <= original_nx.number_of_edges()
+
+    if nx.is_connected(original_nx):
+        assert nx.is_connected(out_nx)
+
+
+def test_isolated_nodes():
+    A = _csr(
+        data=[0.5, 0.5],
+        rows=[0, 1],
+        cols=[1, 0],
+        n=3,
+    )
+    G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
+    out = MetricDistanceFilter().apply(G0)
+
+    assert out.to_networkx().number_of_nodes() == 3
+    assert 2 in out.to_networkx().nodes()
+
+
+def test_empty_graph():
+    A = _csr(data=[], rows=[], cols=[], n=3)
+    G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
+
+    out = MetricDistanceFilter().apply(G0)
+
+    assert out.to_networkx().number_of_edges() == 0
+    assert out.to_networkx().number_of_nodes() == 3
+
+
+def test_semi_metric_edge_is_removed():
+    # The direct edge 1-2 (0.9) is longer than the detour 1-0-2 (0.5 + 0.3).
+    A = _csr(
+        data=[0.5, 0.5, 0.3, 0.3, 0.9, 0.9],
+        rows=[0, 1, 0, 2, 1, 2],
+        cols=[1, 0, 2, 0, 2, 1],
+        n=3,
+    )
+    G0 = Graph.from_csr(A, directed=False, weighted=True, mode="distance")
+
+    out_nx = MetricDistanceFilter().apply(G0).to_networkx()
+
+    assert out_nx.number_of_edges() == 2
+    assert not out_nx.has_edge(1, 2)