diff --git a/src/GOSTnets/core.py b/src/GOSTnets/core.py index 38646b5..57f633b 100644 --- a/src/GOSTnets/core.py +++ b/src/GOSTnets/core.py @@ -695,16 +695,14 @@ def graph_nodes_intersecting_polygon(G, polygons, crs=None): a list of the nodes intersecting the polygons """ - if type(G) == nx.classes.multidigraph.MultiDiGraph: + if isinstance(G, nx.Graph): graph_gdf = node_gdf_from_graph(G) - - elif type(G) == gpd.geodataframe.GeoDataFrame: + elif isinstance(G, gpd.GeoDataFrame): graph_gdf = G - else: raise ValueError("Expecting a graph or node geodataframe for G!") - if type(polygons) != gpd.geodataframe.GeoDataFrame: + if type(polygons) != gpd.GeoDataFrame: raise ValueError("Expecting a geodataframe for polygon(s)!") if (crs is not None) and (graph_gdf.crs != crs): @@ -751,13 +749,13 @@ def graph_edges_intersecting_polygon(G, polygons, mode, crs=None, fast=True): a list of the edges intersecting the polygons """ - if type(G) == nx.classes.multidigraph.MultiDiGraph: + if isinstance(G, nx.Graph): node_graph_gdf = node_gdf_from_graph(G) edge_graph_gdf = edge_gdf_from_graph(G) else: raise ValueError("Expecting a graph for G!") - if type(polygons) != gpd.geodataframe.GeoDataFrame: + if type(polygons) != gpd.GeoDataFrame: raise ValueError("Expecting a geodataframe for polygon(s)!") if (crs is not None) and (node_graph_gdf.crs != crs): diff --git a/tests/test_core.py b/tests/test_core.py index c2ed75b..a36601e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -87,25 +87,30 @@ def mocked_convert(x, y, z, f, g): return ["one", "two", "three"] -@mock.patch("GOSTnets.core.convert", mocked_convert) -def test_combo_csv_to_graph(): - """Test the combo_csv_to_graph function.""" - # create the test csv object - f_df = pd.DataFrame( - data={"u_col": ["a", "b", "c"], "v_col": ["x", "y", "z"], "geo_col": [1, 2, 3]} - ) - # write to buffer - s_buf = io.StringIO() - f_df.to_csv(s_buf) - s_buf.seek(0) - # call function - G = core.combo_csv_to_graph( - s_buf, u_tag="u_col", v_tag="v_col", geometry_tag="geo_col" - ) - # assertions - assert isinstance(G, nx.MultiDiGraph) - assert "a" in G.nodes - assert "x" in G.nodes +class TestGraphCreationFunctions: + @mock.patch("GOSTnets.core.convert", mocked_convert) + def test_combo_csv_to_graph(self): + """Test the combo_csv_to_graph function.""" + # create the test csv object + f_df = pd.DataFrame( + data={ + "u_col": ["a", "b", "c"], + "v_col": ["x", "y", "z"], + "geo_col": [1, 2, 3], + } + ) + # write to buffer + s_buf = io.StringIO() + f_df.to_csv(s_buf) + s_buf.seek(0) + # call function + G = core.combo_csv_to_graph( + s_buf, u_tag="u_col", v_tag="v_col", geometry_tag="geo_col" + ) + # assertions + assert isinstance(G, nx.MultiDiGraph) + assert "a" in G.nodes + assert "x" in G.nodes def test_edges_and_nodes_gdf_to_graph(): @@ -118,24 +123,53 @@ def test_edges_and_nodes_csv_to_graph(): pass -def test_node_gdf_from_graph(): - """Test the node_gdf_from_graph function.""" - pass - - -def test_edge_gdf_from_graph(): - """Test the edge_gdf_from_graph function.""" - pass - - -def test_graph_nodes_intersecting_polgyon(): - """Test the graph_nodes_intersecting_polgyon function.""" - pass - - -def test_graph_edges_intersecting_polgyon(): - """Test the graph_edges_intersecting_polgyon function.""" - pass +class TestGDFfromGraph: + # create graph to use for tests + G = nx.Graph() + # add some nodes w/ (x, y) attributes + G.add_node(1, x=0, y=0) + G.add_node(2, x=-1, y=0.3) + G.add_node(3, x=2, y=0.17) + G.add_node(4, x=4, y=0.255) + G.add_node(5, x=5, y=0.03) + # create some edges + G.add_edge(1, 2) + G.add_edge(4, 5) + # define polygon + poly = Polygon([(-1, -1), (0.0, 1.0), (1.0, 0.0)]) + poly_gdf = gpd.GeoDataFrame({"x": 1, "geometry": poly}, index=[1]) + + def test_node_gdf_from_graph(self): + """Test the node_gdf_from_graph function.""" + node_gdf = core.node_gdf_from_graph(self.G) + assert isinstance(node_gdf, gpd.GeoDataFrame) + assert "x" in node_gdf.columns + assert "y" in node_gdf.columns + assert "geometry" in node_gdf.columns + assert node_gdf.shape[0] == 5 + + def test_edge_gdf_from_graph(self): + """Test the edge_gdf_from_graph function.""" + edge_gdf = core.edge_gdf_from_graph(self.G) + assert edge_gdf.shape[0] == 2 + assert isinstance(edge_gdf, gpd.GeoDataFrame) + assert "stnode" in edge_gdf.columns + assert "endnode" in edge_gdf.columns + assert "geometry" in edge_gdf.columns + + # def test_graph_nodes_intersecting_polgyon(self): + # """Test the graph_nodes_intersecting_polygon function.""" + # # call function + # int_list = core.graph_nodes_intersecting_polygon(self.G, self.poly_gdf) + # import pdb + + # pdb.set_trace() + + # def test_graph_edges_intersecting_polgyon(self): + # """Test the graph_edges_intersecting_polygon function.""" + # int_list = core.graph_edges_intersecting_polygon( + # self.G, self.poly_gdf, mode="contains" + # ) def test_sample_raster():