From 02a22950680d48fc52757efcc66b6e51ffcb9958 Mon Sep 17 00:00:00 2001 From: Radomir Stevanovic Date: Fri, 8 May 2026 21:42:49 +0200 Subject: [PATCH] Add networkx graph output support to `@graph_argument` --- dimod/decorators.py | 32 ++++++++++++--- ...h_argument-decorator-283515d9b29828a1.yaml | 5 +++ tests/test_decorators.py | 40 ++++++++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 releasenotes/notes/add-nx-graph-output-to-graph_argument-decorator-283515d9b29828a1.yaml diff --git a/dimod/decorators.py b/dimod/decorators.py index 336dcecf0..774483c91 100644 --- a/dimod/decorators.py +++ b/dimod/decorators.py @@ -16,12 +16,10 @@ import collections.abc as abc import inspect import itertools -import warnings from functools import wraps -from numbers import Integral -from dimod.exceptions import BinaryQuadraticModelStructureError, WriteableError +from dimod.exceptions import BinaryQuadraticModelStructureError from dimod.utilities import new_variable_label from dimod.vartypes import as_vartype @@ -230,7 +228,8 @@ def graph_argument(*arg_names, **options): The wrapped function accepts either an integer n, interpreted as a complete graph of size n, a nodes/edges pair, a sequence of edges, or a - NetworkX graph. The argument is converted into a nodes/edges 2-tuple. + NetworkX graph. The argument is converted into a nodes/edges 2-tuple, or + a NetworkX graph if ``as_networkx`` option is set. Args: *arg_names (optional, default='G'): @@ -239,14 +238,18 @@ def graph_argument(*arg_names, **options): allow_None (bool, optional, default=False): If True, None can be passed through as an input graph. + as_networkx (book, optional, default=False): + If True, return a NetworkX graph. + """ # by default, constrain only one argument, the 'G` if not arg_names: arg_names = ['G'] - # we only allow one option allow_None + # we only allow two options allow_None = options.pop("allow_None", False) + as_networkx = options.pop("as_networkx", False) if options: # to keep it consistent with python3 # behaviour like graph_argument(*arg_names, allow_None=False) @@ -254,6 +257,14 @@ def graph_argument(*arg_names, **options): msg = "graph_argument() for an unexpected keyword argument '{}'".format(key) raise TypeError(msg) + # if user asks for a nx graph, we require nx + if as_networkx: + try: + import networkx as nx + except ImportError: + raise RuntimeError("graph_argument() with 'as_networkx=True' " + "requires NetworkX installed") + def _graph_arg(f): argspec = inspect.getfullargspec(f) @@ -265,6 +276,10 @@ def _enforce_single_arg(name, args, kwargs): if hasattr(G, 'edges') and hasattr(G, 'nodes'): # networkx or perhaps a named tuple + if as_networkx and isinstance(G, nx.Graph): + # short-circuit the conversion to nx graph + return + kwargs[name] = (list(G.nodes), list(G.edges)) elif _is_integer(G): @@ -304,6 +319,13 @@ def _enforce_single_arg(name, args, kwargs): else: raise ValueError('Unexpected graph input form') + if as_networkx: + nodes, edges = kwargs[name] + G = nx.Graph() + G.add_nodes_from(nodes) + G.add_edges_from(edges) + kwargs[name] = G + return @wraps(f) diff --git a/releasenotes/notes/add-nx-graph-output-to-graph_argument-decorator-283515d9b29828a1.yaml b/releasenotes/notes/add-nx-graph-output-to-graph_argument-decorator-283515d9b29828a1.yaml new file mode 100644 index 000000000..f2e3ca09e --- /dev/null +++ b/releasenotes/notes/add-nx-graph-output-to-graph_argument-decorator-283515d9b29828a1.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Add support for NetworkX graph output during function argument coercion via + ``@graph_argument(..., as_networkx=True)``. diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 71cda2812..9d649e5f1 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -299,7 +299,6 @@ def f(G=None): f() def test_allow_None_True(self): - @graph_argument('G', allow_None=True) def f(G=None): return G @@ -312,6 +311,45 @@ def test_other_kwarg(self): def f(G): pass + @unittest.skipUnless(_networkx, "no networkx installed") + def test_as_networkx_graph(self): + from networkx.utils import graphs_equal + + @graph_argument('g', as_networkx=True) + def f(g): + return g + + with self.subTest('nx.Graph pass-through'): + g = nx.complete_graph(3) + G = f(g) + + self.assertIsInstance(G, nx.Graph) + self.assertIs(G, g) + + with self.subTest('int'): + G = f(3) + + self.assertIsInstance(G, nx.Graph) + self.assertTrue(graphs_equal(G, nx.complete_graph(3))) + + with self.subTest('nodes, edges'): + nodes = [1, 2, 3] + edges = [(1, 2), (2, 3)] + g = nx.Graph() + g.add_nodes_from(nodes) + g.add_edges_from(edges) + G = f((nodes, edges)) + + self.assertIsInstance(G, nx.Graph) + self.assertTrue(graphs_equal(G, g)) + + with self.subTest('edges'): + edges = [(0, 1), (2, 3), (0, 2)] + G = f(edges) + + self.assertIsInstance(G, nx.Graph) + self.assertTrue(graphs_equal(G, nx.Graph(edges))) + class TestForwardingMethod(unittest.TestCase): def setUp(self):