Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions dimod/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
import collections.abc as abc
import inspect
import itertools
import warnings

from functools import wraps
from numbers import Integral

from dimod.exceptions import BinaryQuadraticModelStructureError, WriteableError
from dimod.exceptions import BinaryQuadraticModelStructureError
from dimod.utilities import new_variable_label
from dimod.vartypes import as_vartype

Expand Down Expand Up @@ -230,7 +228,8 @@ def graph_argument(*arg_names, **options):

The wrapped function accepts either an integer n, interpreted as a
complete graph of size n, a nodes/edges pair, a sequence of edges, or a
NetworkX graph. The argument is converted into a nodes/edges 2-tuple.
NetworkX graph. The argument is converted into a nodes/edges 2-tuple, or
a NetworkX graph if ``as_networkx`` option is set.

Args:
*arg_names (optional, default='G'):
Expand All @@ -239,21 +238,33 @@ def graph_argument(*arg_names, **options):
allow_None (bool, optional, default=False):
If True, None can be passed through as an input graph.

as_networkx (book, optional, default=False):
If True, return a NetworkX graph.

"""

# by default, constrain only one argument, the 'G`
if not arg_names:
arg_names = ['G']

# we only allow one option allow_None
# we only allow two options
allow_None = options.pop("allow_None", False)
as_networkx = options.pop("as_networkx", False)
if options:
# to keep it consistent with python3
# behaviour like graph_argument(*arg_names, allow_None=False)
key, _ = options.popitem()
msg = "graph_argument() for an unexpected keyword argument '{}'".format(key)
raise TypeError(msg)

# if user asks for a nx graph, we require nx
if as_networkx:
try:
import networkx as nx
except ImportError:
raise RuntimeError("graph_argument() with 'as_networkx=True' "
"requires NetworkX installed")

def _graph_arg(f):
argspec = inspect.getfullargspec(f)

Expand All @@ -265,6 +276,10 @@ def _enforce_single_arg(name, args, kwargs):

if hasattr(G, 'edges') and hasattr(G, 'nodes'):
# networkx or perhaps a named tuple
if as_networkx and isinstance(G, nx.Graph):
# short-circuit the conversion to nx graph
return

kwargs[name] = (list(G.nodes), list(G.edges))

elif _is_integer(G):
Expand Down Expand Up @@ -304,6 +319,13 @@ def _enforce_single_arg(name, args, kwargs):
else:
raise ValueError('Unexpected graph input form')

if as_networkx:
nodes, edges = kwargs[name]
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)
kwargs[name] = G

return

@wraps(f)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Add support for NetworkX graph output during function argument coercion via
``@graph_argument(..., as_networkx=True)``.
40 changes: 39 additions & 1 deletion tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def f(G=None):
f()

def test_allow_None_True(self):

@graph_argument('G', allow_None=True)
def f(G=None):
return G
Expand All @@ -312,6 +311,45 @@ def test_other_kwarg(self):
def f(G):
pass

@unittest.skipUnless(_networkx, "no networkx installed")
def test_as_networkx_graph(self):
from networkx.utils import graphs_equal

@graph_argument('g', as_networkx=True)
def f(g):
return g

with self.subTest('nx.Graph pass-through'):
g = nx.complete_graph(3)
G = f(g)

self.assertIsInstance(G, nx.Graph)
self.assertIs(G, g)

with self.subTest('int'):
G = f(3)

self.assertIsInstance(G, nx.Graph)
self.assertTrue(graphs_equal(G, nx.complete_graph(3)))

with self.subTest('nodes, edges'):
nodes = [1, 2, 3]
edges = [(1, 2), (2, 3)]
g = nx.Graph()
g.add_nodes_from(nodes)
g.add_edges_from(edges)
G = f((nodes, edges))

self.assertIsInstance(G, nx.Graph)
self.assertTrue(graphs_equal(G, g))

with self.subTest('edges'):
edges = [(0, 1), (2, 3), (0, 2)]
G = f(edges)

self.assertIsInstance(G, nx.Graph)
self.assertTrue(graphs_equal(G, nx.Graph(edges)))


class TestForwardingMethod(unittest.TestCase):
def setUp(self):
Expand Down