Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Graph difference #571

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ the functions from the explicitly typed based on the data type.
rustworkx.digraph_core_number
rustworkx.digraph_complement
rustworkx.digraph_union
rustworkx.digraph_difference
rustworkx.digraph_tensor_product
rustworkx.digraph_cartesian_product
rustworkx.digraph_random_layout
Expand Down Expand Up @@ -369,6 +370,7 @@ typed API based on the data type.
rustworkx.graph_core_number
rustworkx.graph_complement
rustworkx.graph_union
rustworkx.graph_difference
rustworkx.graph_tensor_product
rustworkx.graph_token_swapper
rustworkx.graph_cartesian_product
Expand Down
31 changes: 31 additions & 0 deletions releasenotes/notes/add-graph-difference-9916bf3d612f0b1a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
features:
- |
Add two new functions which calculates the difference of two graphs :func:`~rustworkx.graph_difference`
for undirected graphs and :func:`~rustworkx.digraph_difference` for directed graphs. For example:

.. jupyter-execute::

import rustworkx
from rustworkx.visualization import mpl_draw

graph_1 = rustworkx.PyGraph()
graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_1.extend_from_weighted_edge_list([(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
(0, 2, "e_5"),
(1, 3, "e_6"),
])
graph_2 = rustworkx.PyGraph()
graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_2.extend_from_weighted_edge_list([(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
])

graph_difference = rustworkx.graph_difference(graph_1, graph_2)

mpl_draw(graph_difference)
32 changes: 32 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,38 @@ def _graph_all_pairs_bellman_ford_shortest_path(graph, edge_cost_fn):
return graph_all_pairs_bellman_ford_shortest_paths(graph, edge_cost_fn)


@functools.singledispatch
def difference(
first,
second,
):
"""Return a new PyGraph that is the difference from two input
graph objects
:param first: The first graph object
:param second: The second graph object
:returns: A new graph object that is the difference of ``second`` and
``first``. It's worth noting the weight/data payload objects are
passed by reference from ``first`` to this new object.
:rtype: :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
"""
raise TypeError("Invalid Input Type %s for graph" % type(first))


@difference.register(PyDiGraph)
def _digraph_difference(
first,
second,
):
return digraph_difference(first, second)


@difference.register(PyGraph)
def _graph_difference(
first,
second,
):
return graph_difference(first, second)

@functools.singledispatch
def node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, edge_attrs=None):
"""Generate a JSON object representing a graph in a node-link format
Expand Down
116 changes: 116 additions & 0 deletions src/difference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use crate::{digraph, graph, StablePyGraph};

use hashbrown::HashSet;

use petgraph::visit::{EdgeRef, IntoEdgeReferences};
use petgraph::{algo, EdgeType};

use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;
use pyo3::Python;

fn difference<Ty: EdgeType>(
py: Python,
first: &StablePyGraph<Ty>,
second: &StablePyGraph<Ty>,
) -> PyResult<StablePyGraph<Ty>> {
let indexes_first = first.node_indices().collect::<HashSet<_>>();
let indexes_second = second.node_indices().collect::<HashSet<_>>();

if indexes_first != indexes_second {
return Err(PyIndexError::new_err(
"Node sets of the graphs should be equal",
));
}

let mut final_graph = StablePyGraph::<Ty>::with_capacity(
first.node_count(),
first.edge_count() - second.edge_count(),
);

for node in first.node_indices() {
let weight = &first[node];
final_graph.add_node(weight.clone_ref(py));
}
Comment on lines +43 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the node weights are different between first and second? Do we care at all that if first[0] != second[0] we only preserve the payload for first[0]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should have a callback to handle weights from both graphs?


for e in first.edge_references() {
let has_edge = second.find_edge(e.source(), e.target());

match has_edge {
Some(_x) => continue,
None => final_graph.add_edge(e.source(), e.target(), e.weight().clone_ref(py)),
};
}

Ok(final_graph)
}

/// Return a new PyGraph that is the difference from two input
/// PyGraph objects
///
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the big thing missing here is the constraints on the input types (they have to have identical node indices). This probably should mention that and also maybe have an example and/or an explanation of how the difference is computed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I'll add examples and explain how we compute the difference in the docs.

/// :param PyGraph first: The first undirected graph object
/// :param PyGraph second: The second undirected graph object
///
/// :returns: A new PyGraph object that is the difference of ``first``
/// and ``second``. It's worth noting the weight/data payload objects are
/// passed by reference from ``first`` graph to this new object.
///
/// :rtype: :class:`~rustworkx.PyGraph`
#[pyfunction()]
#[pyo3(text_signature = "(first, second, /)")]
pub fn graph_difference(
py: Python,
first: &graph::PyGraph,
second: &graph::PyGraph,
) -> PyResult<graph::PyGraph> {
let out_graph = difference(py, &first.graph, &second.graph)?;

Ok(graph::PyGraph {
graph: out_graph,
multigraph: true,
node_removed: false,
attrs: py.None(),
})
}

/// Return a new PyDiGraph that is the difference from two input
/// PyGraph objects
///
/// :param PyGraph first: The first undirected graph object
/// :param PyGraph second: The second undirected graph object
///
/// :returns: A new PyDiGraph object that is the difference of ``first``
/// and ``second``. It's worth noting the weight/data payload objects are
/// passed by reference from ``first`` graph to this new object.
///
/// :rtype: :class:`~rustworkx.PyDiGraph`
#[pyfunction()]
#[pyo3(text_signature = "(first, second, /)")]
pub fn digraph_difference(
py: Python,
first: &digraph::PyDiGraph,
second: &digraph::PyDiGraph,
) -> PyResult<digraph::PyDiGraph> {
let out_graph = difference(py, &first.graph, &second.graph)?;

Ok(digraph::PyDiGraph {
graph: out_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
})
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod centrality;
mod coloring;
mod connectivity;
mod dag_algo;
mod difference;
mod digraph;
mod dot_utils;
mod generators;
Expand Down Expand Up @@ -43,6 +44,7 @@ use centrality::*;
use coloring::*;
use connectivity::*;
use dag_algo::*;
use difference::*;
use graphml::*;
use isomorphism::*;
use json::*;
Expand Down Expand Up @@ -372,6 +374,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(graph_union))?;
m.add_wrapped(wrap_pyfunction!(digraph_cartesian_product))?;
m.add_wrapped(wrap_pyfunction!(graph_cartesian_product))?;
m.add_wrapped(wrap_pyfunction!(digraph_difference))?;
m.add_wrapped(wrap_pyfunction!(graph_difference))?;
m.add_wrapped(wrap_pyfunction!(topological_sort))?;
m.add_wrapped(wrap_pyfunction!(descendants))?;
m.add_wrapped(wrap_pyfunction!(ancestors))?;
Expand Down
63 changes: 63 additions & 0 deletions tests/rustworkx_tests/digraph/test_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import rustworkx


class TestDifference(unittest.TestCase):
def test_null_difference_null(self):
graph_1 = rustworkx.PyDiGraph()
graph_2 = rustworkx.PyDiGraph()

graph_difference = rustworkx.digraph_difference(graph_1, graph_2)

self.assertEqual(graph_difference.num_nodes(), 0)
self.assertEqual(graph_difference.num_edges(), 0)

def test_difference_non_matching(self):
graph_1 = rustworkx.generators.directed_path_graph(2)
graph_2 = rustworkx.generators.directed_path_graph(3)

with self.assertRaises(IndexError):
_ = rustworkx.digraph_difference(graph_1, graph_2)

def test_difference_weights_edges(self):
graph_1 = rustworkx.PyDiGraph()
graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_1.extend_from_weighted_edge_list(
[
(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
(0, 2, "e_5"),
(1, 3, "e_6"),
]
)
graph_2 = rustworkx.PyDiGraph()
graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_2.extend_from_weighted_edge_list(
[
(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
]
)

graph_difference = rustworkx.digraph_difference(graph_1, graph_2)

expected_edges = [(0, 2, "e_5"), (1, 3, "e_6")]
self.assertEqual(graph_difference.num_nodes(), 4)
self.assertEqual(graph_difference.num_edges(), 2)
self.assertEqual(graph_difference.weighted_edge_list(), expected_edges)
63 changes: 63 additions & 0 deletions tests/rustworkx_tests/graph/test_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import rustworkx


class TestDifference(unittest.TestCase):
def test_null_difference_null(self):
graph_1 = rustworkx.PyGraph()
graph_2 = rustworkx.PyGraph()

graph_difference = rustworkx.graph_difference(graph_1, graph_2)

self.assertEqual(graph_difference.num_nodes(), 0)
self.assertEqual(graph_difference.num_edges(), 0)

def test_difference_non_matching(self):
graph_1 = rustworkx.generators.path_graph(2)
graph_2 = rustworkx.generators.path_graph(3)

with self.assertRaises(IndexError):
_ = rustworkx.graph_difference(graph_1, graph_2)

def test_difference_weights(self):
graph_1 = rustworkx.PyGraph()
graph_1.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_1.extend_from_weighted_edge_list(
[
(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
(0, 2, "e_5"),
(1, 3, "e_6"),
]
)
graph_2 = rustworkx.PyGraph()
graph_2.add_nodes_from(["a_1", "a_2", "a_3", "a_4"])
graph_2.extend_from_weighted_edge_list(
[
(0, 1, "e_1"),
(1, 2, "e_2"),
(2, 3, "e_3"),
(3, 0, "e_4"),
]
)

graph_difference = rustworkx.graph_difference(graph_1, graph_2)

expected_edges = [(0, 2, "e_5"), (1, 3, "e_6")]
self.assertEqual(graph_difference.num_nodes(), 4)
self.assertEqual(graph_difference.num_edges(), 2)
self.assertEqual(graph_difference.weighted_edge_list(), expected_edges)
Loading