Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
elbeejay committed Jun 14, 2024
1 parent 0ebbde7 commit 4e4fb5b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 44 deletions.
12 changes: 5 additions & 7 deletions src/GOSTnets/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,16 +695,14 @@ def graph_nodes_intersecting_polygon(G, polygons, crs=None):
a list of the nodes intersecting the polygons
"""
if type(G) == nx.classes.multidigraph.MultiDiGraph:
if isinstance(G, nx.Graph):
graph_gdf = node_gdf_from_graph(G)

elif type(G) == gpd.geodataframe.GeoDataFrame:
elif isinstance(G, gpd.GeoDataFrame):
graph_gdf = G

else:
raise ValueError("Expecting a graph or node geodataframe for G!")

if type(polygons) != gpd.geodataframe.GeoDataFrame:
if type(polygons) != gpd.GeoDataFrame:
raise ValueError("Expecting a geodataframe for polygon(s)!")

if (crs is not None) and (graph_gdf.crs != crs):
Expand Down Expand Up @@ -751,13 +749,13 @@ def graph_edges_intersecting_polygon(G, polygons, mode, crs=None, fast=True):
a list of the edges intersecting the polygons
"""
if type(G) == nx.classes.multidigraph.MultiDiGraph:
if isinstance(G, nx.Graph):
node_graph_gdf = node_gdf_from_graph(G)
edge_graph_gdf = edge_gdf_from_graph(G)
else:
raise ValueError("Expecting a graph for G!")

if type(polygons) != gpd.geodataframe.GeoDataFrame:
if type(polygons) != gpd.GeoDataFrame:
raise ValueError("Expecting a geodataframe for polygon(s)!")

if (crs is not None) and (node_graph_gdf.crs != crs):
Expand Down
108 changes: 71 additions & 37 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,30 @@ def mocked_convert(x, y, z, f, g):
return ["one", "two", "three"]


@mock.patch("GOSTnets.core.convert", mocked_convert)
def test_combo_csv_to_graph():
"""Test the combo_csv_to_graph function."""
# create the test csv object
f_df = pd.DataFrame(
data={"u_col": ["a", "b", "c"], "v_col": ["x", "y", "z"], "geo_col": [1, 2, 3]}
)
# write to buffer
s_buf = io.StringIO()
f_df.to_csv(s_buf)
s_buf.seek(0)
# call function
G = core.combo_csv_to_graph(
s_buf, u_tag="u_col", v_tag="v_col", geometry_tag="geo_col"
)
# assertions
assert isinstance(G, nx.MultiDiGraph)
assert "a" in G.nodes
assert "x" in G.nodes
class TestGraphCreationFunctions:
@mock.patch("GOSTnets.core.convert", mocked_convert)
def test_combo_csv_to_graph(self):
"""Test the combo_csv_to_graph function."""
# create the test csv object
f_df = pd.DataFrame(
data={
"u_col": ["a", "b", "c"],
"v_col": ["x", "y", "z"],
"geo_col": [1, 2, 3],
}
)
# write to buffer
s_buf = io.StringIO()
f_df.to_csv(s_buf)
s_buf.seek(0)
# call function
G = core.combo_csv_to_graph(
s_buf, u_tag="u_col", v_tag="v_col", geometry_tag="geo_col"
)
# assertions
assert isinstance(G, nx.MultiDiGraph)
assert "a" in G.nodes
assert "x" in G.nodes


def test_edges_and_nodes_gdf_to_graph():
Expand All @@ -118,24 +123,53 @@ def test_edges_and_nodes_csv_to_graph():
pass


def test_node_gdf_from_graph():
"""Test the node_gdf_from_graph function."""
pass


def test_edge_gdf_from_graph():
"""Test the edge_gdf_from_graph function."""
pass


def test_graph_nodes_intersecting_polgyon():
"""Test the graph_nodes_intersecting_polgyon function."""
pass


def test_graph_edges_intersecting_polgyon():
"""Test the graph_edges_intersecting_polgyon function."""
pass
class TestGDFfromGraph:
# create graph to use for tests
G = nx.Graph()
# add some nodes w/ (x, y) attributes
G.add_node(1, x=0, y=0)
G.add_node(2, x=-1, y=0.3)
G.add_node(3, x=2, y=0.17)
G.add_node(4, x=4, y=0.255)
G.add_node(5, x=5, y=0.03)
# create some edges
G.add_edge(1, 2)
G.add_edge(4, 5)
# define polygon
poly = Polygon([(-1, -1), (0.0, 1.0), (1.0, 0.0)])
poly_gdf = gpd.GeoDataFrame({"x": 1, "geometry": poly}, index=[1])

def test_node_gdf_from_graph(self):
"""Test the node_gdf_from_graph function."""
node_gdf = core.node_gdf_from_graph(self.G)
assert isinstance(node_gdf, gpd.GeoDataFrame)
assert "x" in node_gdf.columns
assert "y" in node_gdf.columns
assert "geometry" in node_gdf.columns
assert node_gdf.shape[0] == 5

def test_edge_gdf_from_graph(self):
"""Test the edge_gdf_from_graph function."""
edge_gdf = core.edge_gdf_from_graph(self.G)
assert edge_gdf.shape[0] == 2
assert isinstance(edge_gdf, gpd.GeoDataFrame)
assert "stnode" in edge_gdf.columns
assert "endnode" in edge_gdf.columns
assert "geometry" in edge_gdf.columns

# def test_graph_nodes_intersecting_polgyon(self):
# """Test the graph_nodes_intersecting_polygon function."""
# # call function
# int_list = core.graph_nodes_intersecting_polygon(self.G, self.poly_gdf)
# import pdb

# pdb.set_trace()

# def test_graph_edges_intersecting_polgyon(self):
# """Test the graph_edges_intersecting_polygon function."""
# int_list = core.graph_edges_intersecting_polygon(
# self.G, self.poly_gdf, mode="contains"
# )


def test_sample_raster():
Expand Down

0 comments on commit 4e4fb5b

Please sign in to comment.