Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Point access speedup #10

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions examples/plot_point_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
from math import log10, floor, pow
import sys

def plot_point_dataset(netcdf_point_utils,
variable_to_map,
Expand Down Expand Up @@ -69,7 +70,6 @@ def rescale_array(input_np_array, new_range_min=0, new_range_max=1):
utm_zone = abs(utm_zone)
projection = ccrs.UTM(zone=utm_zone, southern_hemisphere=southern_hemisphere)
print('utm_zone = {}'.format(utm_zone))
#print(utm_coords)

variable = netcdf_point_utils.netcdf_dataset.variables[variable_to_map]

Expand All @@ -91,8 +91,7 @@ def rescale_array(input_np_array, new_range_min=0, new_range_max=1):
utm_coords = utm_coords[spatial_mask]

print('{} points in UTM bounding box: {}'.format(np.count_nonzero(spatial_mask), utm_bbox))
#print(utm_coords)


colour_array = rescale_array(variable[spatial_mask], 0, 1)

fig = plt.figure(figsize=(30,30))
Expand All @@ -103,9 +102,9 @@ def rescale_array(input_np_array, new_range_min=0, new_range_max=1):

#map_image = cimgt.OSM() # https://www.openstreetmap.org/about
#map_image = cimgt.StamenTerrain() # http://maps.stamen.com/
map_image = cimgt.QuadtreeTiles()
#print(map_image.__dict__)
ax.add_image(map_image, 10)
#map_image = cimgt.GoogleTiles(style='satellite')

#ax.add_image(map_image, 10)

# Compute and set regular tick spacing
range_x = utm_bbox[2] - utm_bbox[0]
Expand Down Expand Up @@ -139,31 +138,32 @@ def rescale_array(input_np_array, new_range_min=0, new_range_max=1):
try: # not all variables have units. These will fail on the try and produce the map without tick labels.
cb = plt.colorbar(sc, ticks=[0, 1])
cb.ax.set_yticklabels([str(np.min(variable[spatial_mask])), str(np.max(variable[spatial_mask]))]) # vertically oriented colorbar

cb.set_label("{} {}".format(variable.long_name, variable.units))
except:
pass

print("show")
plt.show()


def main():
'''
main function for quick and dirty testing
'''
# Create NetCDFPointUtils object for specified netCDF dataset
netcdf_path = 'http://dapds00.nci.org.au/thredds/dodsC/uc0/rr2_dev/axi547/ground_gravity/point_datasets/201780.nc'
#netcdf_path = 'E:\\Temp\\gravity_point_test\\201780.nc'

netcdf_dataset = netCDF4.Dataset(netcdf_path)
nc_path = sys.argv[1]
variable_to_plot = sys.argv[2]

netcdf_dataset = netCDF4.Dataset(nc_path)
npu = NetCDFPointUtils(netcdf_dataset)

print('1D Point variables:\n\t{}'.format('\n\t'.join([key for key, value in netcdf_dataset.variables.items()
if value.dimensions == ('point',)])))
# Plot spatial subset
plot_point_dataset(npu,
'Bouguer',
utm_bbox=[630000,7980000,680000,8030000],
variable_to_plot,
#utm_bbox=[660000,7080000,680000,7330000],
colour_scheme='gist_heat',
point_size=50
point_size=30,
point_step=100
)

if __name__ == '__main__':
Expand Down
106 changes: 63 additions & 43 deletions geophys_utils/_netcdf_grid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,8 @@ def get_value_at_coords(self, coordinates, wkt=None,
@parameter max_bytes: Maximum number of bytes to read in a single query. Defaults to NetCDFGridUtils.DEFAULT_MAX_BYTES
@parameter variable_name: NetCDF variable_name if not default data variable
'''
# Use arbitrary maximum request size of NetCDFGridUtils.DEFAULT_MAX_BYTES
# (500,000,000 bytes => 11180 points per query)
#TODO: Find a better way of overcoming the netCDF problem where whole rows & columns are retrieved
max_bytes = max_bytes or 100 # NetCDFGridUtils.DEFAULT_MAX_BYTES
# Use small max request size by default
max_bytes = max_bytes or 100

if variable_name:
data_variable = self.netcdf_dataset.variables[variable_name]
Expand All @@ -256,47 +254,39 @@ def get_value_at_coords(self, coordinates, wkt=None,
no_data_value = data_variable._FillValue

indices = np.array(self.get_indices_from_coords(coordinates, wkt))

# return data_variable[indices[:,0], indices[:,1]].diagonal() # This could get too big

# Allow for the fact that the NetCDF advanced indexing will pull back
# n^2 cells rather than n
max_points = max(
int(math.sqrt(max_bytes / data_variable.dtype.itemsize)), 1)
try:
# Make this a vectorised operation for speed (one query for as many
# points as possible)
# Array of valid index pairs only
index_array = np.array(
[index_pair for index_pair in indices if index_pair is not None])
assert len(index_array.shape) == 2 and index_array.shape[
1] == 2, 'Not an iterable containing index pairs'
# Boolean mask indicating which index pairs are valid
mask_array = np.array([(index_pair is not None)
for index_pair in indices])
# Array of values read from variable
value_array = np.ones(shape=(len(index_array)),
dtype=data_variable.dtype) * no_data_value
# Final result array including no-data for invalid index pairs
result_array = np.ones(
shape=(len(mask_array)), dtype=data_variable.dtype) * no_data_value
start_index = 0
end_index = min(max_points, len(index_array))
while True:
# N.B: ".diagonal()" is required because NetCDF doesn't do advanced indexing exactly like numpy
# Hack is required to take values from leading diagonal. Requires n^2 elements retrieved instead of n. Not good, but better than whole array
# TODO: Think of a better way of doing this
value_array[start_index:end_index] = data_variable[
(index_array[start_index:end_index, 0], index_array[start_index:end_index, 1])].diagonal()
if end_index == len(index_array): # Finished
break
start_index = end_index
end_index = min(start_index + max_points, len(index_array))

result_array[mask_array] = value_array
return list(result_array)
except:
if indices.ndim == 1: #single coordinate pair
return data_variable[indices[0], indices[1]]

# Make this a vectorised operation for speed (one query for as many
# points as possible)
# Array of valid index pairs only
index_array = np.array(
[index_pair for index_pair in indices if index_pair is not None])
assert len(index_array.shape) == 2 and index_array.shape[
1] == 2, 'Not an iterable containing index pairs'
# Boolean mask indicating which index pairs are valid
mask_array = np.array([(index_pair is not None)
for index_pair in indices])
# Array of values read from variable
value_array = np.ones(shape=(len(index_array)),
dtype=data_variable.dtype) * no_data_value
# Final result array including no-data for invalid index pairs
result_array = np.ones(
shape=(len(mask_array)), dtype=data_variable.dtype) * no_data_value
start_index = 0
while start_index < len(index_array):
#read up to max_bytes of data_variable containing as many of the next points in
#index_array as possible
end_index, bbox = _get_query_params(index_array, start_index, data_variable, max_bytes)
query_result = data_variable[bbox[0]:bbox[1]+1, bbox[2]:bbox[3]+1]
residx0 = index_array[start_index:end_index+1,0] - bbox[0]
residx1 = index_array[start_index:end_index+1,1] - bbox[2]
value_array[start_index:end_index+1] = query_result[residx0, residx1]
start_index = end_index + 1

result_array[mask_array] = value_array
return list(result_array)

def get_interpolated_value_at_coords(
self, coordinates, wkt=None, max_bytes=None, variable_name=None):
Expand Down Expand Up @@ -903,6 +893,36 @@ def iterate_through_data_chunks_and_find_mins_and_maxs(self, variable, num_of_ch
logger.debug("current_min: {}".format(current_min))
return current_min, current_max

def _get_query_params(index_array, start_index, data_variable, max_bytes):
#find the maximum ending index to use for a request under max_bytes
# and the associated bounding box
if start_index >= len(index_array):
return start_index
end_index = start_index
dbytes = data_variable.dtype.itemsize
nbytes = dbytes
bbox = [index_array[start_index,0], index_array[start_index,0],
index_array[start_index,1], index_array[start_index,1]]
while end_index < len(index_array) and nbytes <= max_bytes:
end_index += 1
if end_index >= len(index_array):
break
x = index_array[end_index,0]
y = index_array[end_index,1]
new_bbox = bbox.copy()
if x < bbox[0]:
new_bbox[0] = x
elif x > bbox[1]:
new_bbox[1] = x
if y < bbox[2]:
new_bbox[2] = y
elif y > bbox[3]:
new_bbox[3] = y
nbytes = dbytes*(new_bbox[1]-new_bbox[0]+1)*(new_bbox[3]-new_bbox[2]+1)
if nbytes > max_bytes:
break
bbox=new_bbox
return max(start_index, end_index - 1), bbox

def main():
'''
Expand Down
18 changes: 17 additions & 1 deletion geophys_utils/test/test_netcdf_grid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@
netcdf_grid_utils = None

NC_PATH = 'test_grid.nc'
MAX_BYTES = 1600
MAX_BYTES_SM = 100
MAX_BYTES_LG = 5000
MAX_ERROR = 0.000001
TEST_COORDS = (148.213, -36.015)
TEST_MULTI_COORDS = np.array([[148.213, -36.015], [148.516, -35.316]])
TEST_MANY_COORDS = np.array([[148.484, -35.352],
[148.328, -35.428], [148.436, -35.744], [148.300, -35.436]])
TEST_INDICES = [1, 1]
TEST_MULTI_INDICES = [[1, 1], [176, 77]]
TEST_FRACTIONAL_INDICES = [1.25, 1.25]
TEST_VALUE = -99999.
TEST_MULTI_VALUES = [-99999.0, -134.711334229]
TEST_MANY_VALS = [-136.13321, -31.7626, -58.755764, -90.484276]
TEST_INTERPOLATED_VALUE = -99997.6171875

class TestNetCDFGridUtilsConstructor(unittest.TestCase):
Expand Down Expand Up @@ -82,6 +86,18 @@ def test_get_value_at_coords(self):
multi_values = netcdf_grid_utils.get_value_at_coords(TEST_MULTI_COORDS)
assert (np.abs(np.array(multi_values) - np.array(TEST_MULTI_VALUES)) < MAX_ERROR).all(), 'Incorrect retrieved value: {} instead of {}'.format(multi_values, TEST_MULTI_VALUES)

print('Testing get_value_at_coords with long coordinate list {}'.format(TEST_MANY_COORDS))
many_values = netcdf_grid_utils.get_value_at_coords(TEST_MANY_COORDS)
assert (np.abs(np.array(many_values) - np.array(TEST_MANY_VALS)) < MAX_ERROR).all(), 'Incorrect retrieved value: {} instead of {}'.format(many_values, TEST_MANY_VALS)

print('Testing get_value_at_coords with long coordinate list {} and request size {} bytes'.format(TEST_MANY_COORDS, MAX_BYTES_SM))
many_values = netcdf_grid_utils.get_value_at_coords(TEST_MANY_COORDS, max_bytes=MAX_BYTES_SM)
assert (np.abs(np.array(many_values) - np.array(TEST_MANY_VALS)) < MAX_ERROR).all(), 'Incorrect retrieved value: {} instead of {}'.format(many_values, TEST_MANY_VALS)

print('Testing get_value_at_coords with long coordinate list {} and request size {} bytes'.format(TEST_MANY_COORDS, MAX_BYTES_LG))
many_values = netcdf_grid_utils.get_value_at_coords(TEST_MANY_COORDS, max_bytes=MAX_BYTES_LG)
assert (np.abs(np.array(many_values) - np.array(TEST_MANY_VALS)) < MAX_ERROR).all(), 'Incorrect retrieved value: {} instead of {}'.format(many_values, TEST_MANY_VALS)

def test_get_interpolated_value_at_coords(self):
print('Testing get_interpolated_value_at_coords function')
interpolated_value = netcdf_grid_utils.get_interpolated_value_at_coords(TEST_COORDS)
Expand Down