diff --git a/docs/source/api.rst b/docs/source/api.rst index 44a7dab7a..de23148e2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -313,6 +313,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_core_number rustworkx.digraph_complement rustworkx.digraph_union + rustworkx.digraph_difference rustworkx.digraph_tensor_product rustworkx.digraph_cartesian_product rustworkx.digraph_random_layout @@ -369,6 +370,7 @@ typed API based on the data type. rustworkx.graph_core_number rustworkx.graph_complement rustworkx.graph_union + rustworkx.graph_difference rustworkx.graph_tensor_product rustworkx.graph_token_swapper rustworkx.graph_cartesian_product diff --git a/releasenotes/notes/add-graph-difference-9916bf3d612f0b1a.yaml b/releasenotes/notes/add-graph-difference-9916bf3d612f0b1a.yaml new file mode 100644 index 000000000..db931deb5 --- /dev/null +++ b/releasenotes/notes/add-graph-difference-9916bf3d612f0b1a.yaml @@ -0,0 +1,31 @@ +--- +features: + - | + Add two new functions which calculates the difference of two graphs :func:`~rustworkx.graph_difference` + for undirected graphs and :func:`~rustworkx.digraph_difference` for directed graphs. For example: + + .. jupyter-execute:: + + import rustworkx + from rustworkx.visualization import mpl_draw + + graph_1 = rustworkx.PyGraph() + graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_1.extend_from_weighted_edge_list([(0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + (0, 2, "e_5"), + (1, 3, "e_6"), + ]) + graph_2 = rustworkx.PyGraph() + graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_2.extend_from_weighted_edge_list([(0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + ]) + + graph_difference = rustworkx.graph_difference(graph_1, graph_2) + + mpl_draw(graph_difference) \ No newline at end of file diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 40e3cd6c7..ab479f774 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2462,6 +2462,38 @@ def _graph_all_pairs_bellman_ford_shortest_path(graph, edge_cost_fn): return graph_all_pairs_bellman_ford_shortest_paths(graph, edge_cost_fn) +@functools.singledispatch +def difference( + first, + second, +): + """Return a new PyGraph that is the difference from two input + graph objects + :param first: The first graph object + :param second: The second graph object + :returns: A new graph object that is the difference of ``second`` and + ``first``. It's worth noting the weight/data payload objects are + passed by reference from ``first`` to this new object. + :rtype: :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph` + """ + raise TypeError("Invalid Input Type %s for graph" % type(first)) + + +@difference.register(PyDiGraph) +def _digraph_difference( + first, + second, +): + return digraph_difference(first, second) + + +@difference.register(PyGraph) +def _graph_difference( + first, + second, +): + return graph_difference(first, second) + @functools.singledispatch def node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, edge_attrs=None): """Generate a JSON object representing a graph in a node-link format diff --git a/src/difference.rs b/src/difference.rs new file mode 100644 index 000000000..45f4a0aed --- /dev/null +++ b/src/difference.rs @@ -0,0 +1,116 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use crate::{digraph, graph, StablePyGraph}; + +use hashbrown::HashSet; + +use petgraph::visit::{EdgeRef, IntoEdgeReferences}; +use petgraph::{algo, EdgeType}; + +use pyo3::exceptions::PyIndexError; +use pyo3::prelude::*; +use pyo3::Python; + +fn difference( + py: Python, + first: &StablePyGraph, + second: &StablePyGraph, +) -> PyResult> { + let indexes_first = first.node_indices().collect::>(); + let indexes_second = second.node_indices().collect::>(); + + if indexes_first != indexes_second { + return Err(PyIndexError::new_err( + "Node sets of the graphs should be equal", + )); + } + + let mut final_graph = StablePyGraph::::with_capacity( + first.node_count(), + first.edge_count() - second.edge_count(), + ); + + for node in first.node_indices() { + let weight = &first[node]; + final_graph.add_node(weight.clone_ref(py)); + } + + for e in first.edge_references() { + let has_edge = second.find_edge(e.source(), e.target()); + + match has_edge { + Some(_x) => continue, + None => final_graph.add_edge(e.source(), e.target(), e.weight().clone_ref(py)), + }; + } + + Ok(final_graph) +} + +/// Return a new PyGraph that is the difference from two input +/// PyGraph objects +/// +/// :param PyGraph first: The first undirected graph object +/// :param PyGraph second: The second undirected graph object +/// +/// :returns: A new PyGraph object that is the difference of ``first`` +/// and ``second``. It's worth noting the weight/data payload objects are +/// passed by reference from ``first`` graph to this new object. +/// +/// :rtype: :class:`~rustworkx.PyGraph` +#[pyfunction()] +#[pyo3(text_signature = "(first, second, /)")] +pub fn graph_difference( + py: Python, + first: &graph::PyGraph, + second: &graph::PyGraph, +) -> PyResult { + let out_graph = difference(py, &first.graph, &second.graph)?; + + Ok(graph::PyGraph { + graph: out_graph, + multigraph: true, + node_removed: false, + attrs: py.None(), + }) +} + +/// Return a new PyDiGraph that is the difference from two input +/// PyGraph objects +/// +/// :param PyGraph first: The first undirected graph object +/// :param PyGraph second: The second undirected graph object +/// +/// :returns: A new PyDiGraph object that is the difference of ``first`` +/// and ``second``. It's worth noting the weight/data payload objects are +/// passed by reference from ``first`` graph to this new object. +/// +/// :rtype: :class:`~rustworkx.PyDiGraph` +#[pyfunction()] +#[pyo3(text_signature = "(first, second, /)")] +pub fn digraph_difference( + py: Python, + first: &digraph::PyDiGraph, + second: &digraph::PyDiGraph, +) -> PyResult { + let out_graph = difference(py, &first.graph, &second.graph)?; + + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }) +} diff --git a/src/lib.rs b/src/lib.rs index 7f941e2e6..ce197aa97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ mod centrality; mod coloring; mod connectivity; mod dag_algo; +mod difference; mod digraph; mod dot_utils; mod generators; @@ -43,6 +44,7 @@ use centrality::*; use coloring::*; use connectivity::*; use dag_algo::*; +use difference::*; use graphml::*; use isomorphism::*; use json::*; @@ -372,6 +374,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(graph_union))?; m.add_wrapped(wrap_pyfunction!(digraph_cartesian_product))?; m.add_wrapped(wrap_pyfunction!(graph_cartesian_product))?; + m.add_wrapped(wrap_pyfunction!(digraph_difference))?; + m.add_wrapped(wrap_pyfunction!(graph_difference))?; m.add_wrapped(wrap_pyfunction!(topological_sort))?; m.add_wrapped(wrap_pyfunction!(descendants))?; m.add_wrapped(wrap_pyfunction!(ancestors))?; diff --git a/tests/rustworkx_tests/digraph/test_difference.py b/tests/rustworkx_tests/digraph/test_difference.py new file mode 100644 index 000000000..d2cae9916 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_difference.py @@ -0,0 +1,63 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestDifference(unittest.TestCase): + def test_null_difference_null(self): + graph_1 = rustworkx.PyDiGraph() + graph_2 = rustworkx.PyDiGraph() + + graph_difference = rustworkx.digraph_difference(graph_1, graph_2) + + self.assertEqual(graph_difference.num_nodes(), 0) + self.assertEqual(graph_difference.num_edges(), 0) + + def test_difference_non_matching(self): + graph_1 = rustworkx.generators.directed_path_graph(2) + graph_2 = rustworkx.generators.directed_path_graph(3) + + with self.assertRaises(IndexError): + _ = rustworkx.digraph_difference(graph_1, graph_2) + + def test_difference_weights_edges(self): + graph_1 = rustworkx.PyDiGraph() + graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_1.extend_from_weighted_edge_list( + [ + (0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + (0, 2, "e_5"), + (1, 3, "e_6"), + ] + ) + graph_2 = rustworkx.PyDiGraph() + graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_2.extend_from_weighted_edge_list( + [ + (0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + ] + ) + + graph_difference = rustworkx.digraph_difference(graph_1, graph_2) + + expected_edges = [(0, 2, "e_5"), (1, 3, "e_6")] + self.assertEqual(graph_difference.num_nodes(), 4) + self.assertEqual(graph_difference.num_edges(), 2) + self.assertEqual(graph_difference.weighted_edge_list(), expected_edges) diff --git a/tests/rustworkx_tests/graph/test_difference.py b/tests/rustworkx_tests/graph/test_difference.py new file mode 100644 index 000000000..96d8516a2 --- /dev/null +++ b/tests/rustworkx_tests/graph/test_difference.py @@ -0,0 +1,63 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestDifference(unittest.TestCase): + def test_null_difference_null(self): + graph_1 = rustworkx.PyGraph() + graph_2 = rustworkx.PyGraph() + + graph_difference = rustworkx.graph_difference(graph_1, graph_2) + + self.assertEqual(graph_difference.num_nodes(), 0) + self.assertEqual(graph_difference.num_edges(), 0) + + def test_difference_non_matching(self): + graph_1 = rustworkx.generators.path_graph(2) + graph_2 = rustworkx.generators.path_graph(3) + + with self.assertRaises(IndexError): + _ = rustworkx.graph_difference(graph_1, graph_2) + + def test_difference_weights(self): + graph_1 = rustworkx.PyGraph() + graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_1.extend_from_weighted_edge_list( + [ + (0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + (0, 2, "e_5"), + (1, 3, "e_6"), + ] + ) + graph_2 = rustworkx.PyGraph() + graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"]) + graph_2.extend_from_weighted_edge_list( + [ + (0, 1, "e_1"), + (1, 2, "e_2"), + (2, 3, "e_3"), + (3, 0, "e_4"), + ] + ) + + graph_difference = rustworkx.graph_difference(graph_1, graph_2) + + expected_edges = [(0, 2, "e_5"), (1, 3, "e_6")] + self.assertEqual(graph_difference.num_nodes(), 4) + self.assertEqual(graph_difference.num_edges(), 2) + self.assertEqual(graph_difference.weighted_edge_list(), expected_edges)