Skip to content

Commit

Permalink
Merge pull request #97 from boothmanrylan:transform-coordinates
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592884814
  • Loading branch information
Xee authors committed Dec 21, 2023
2 parents 9b8c50a + 24ff582 commit 7bb1407
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ tests = [
"absl-py",
"pytest",
"pyink",
"rasterio",
"rioxarray",
]
examples = [
"apache_beam[gcp]",
Expand Down
12 changes: 5 additions & 7 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,11 @@ def _get_tile_from_ee(
tile_index, band_id = tile_index
bbox = self.project(
(tile_index[0], 0, tile_index[1], 1)
if band_id == 'longitude'
if band_id == 'x'
else (0, tile_index[0], 1, tile_index[1])
)
tile_idx = slice(tile_index[0], tile_index[1])
target_image = ee.Image.pixelLonLat()
target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg))
return tile_idx, self.image_to_array(
target_image, grid=bbox, dtype=np.float32, bandIds=[band_id]
)
Expand All @@ -574,9 +574,7 @@ def _process_coordinate_data(
self._get_tile_from_ee,
list(zip(data, itertools.cycle([coordinate_type]))),
):
tiles[i] = (
arr.tolist() if coordinate_type == 'longitude' else arr.tolist()[0]
)
tiles[i] = arr.tolist() if coordinate_type == 'x' else arr.tolist()[0]
return np.concatenate(tiles)

def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
Expand Down Expand Up @@ -605,11 +603,11 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]:

lon_total_tile = math.ceil(v0.shape[1] / width_chunk)
lon = self._process_coordinate_data(
lon_total_tile, width_chunk, v0.shape[1], 'longitude'
lon_total_tile, width_chunk, v0.shape[1], 'x'
)
lat_total_tile = math.ceil(v0.shape[2] / height_chunk)
lat = self._process_coordinate_data(
lat_total_tile, height_chunk, v0.shape[2], 'latitude'
lat_total_tile, height_chunk, v0.shape[2], 'y'
)

width_coord = np.squeeze(lon)
Expand Down
81 changes: 81 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import os
import pathlib
import tempfile

from absl.testing import absltest
from google.auth import identity_pool
Expand All @@ -26,6 +27,13 @@

import ee

_SKIP_RASTERIO_TESTS = False
try:
import rasterio # pylint: disable=g-import-not-at-top
import rioxarray # pylint: disable=g-import-not-at-top,unused-import
except ImportError:
_SKIP_RASTERIO_TESTS = True

_CREDENTIALS_PATH_KEY = 'GOOGLE_APPLICATION_CREDENTIALS'
_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
Expand Down Expand Up @@ -397,6 +405,79 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_projected_dataset_to_raster(self):
# ensure that a projected dataset written to a raster intersects with the
# point used to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, 'test.tif')

crs = 'epsg:32610'
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=10,
crs=crs,
geometry=geom,
)

ds = ds.isel(time=0).transpose('Y', 'X')
ds.rio.set_spatial_dims(x_dim='X', y_dim='Y', inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_dataset_to_raster(self):
# ensure that a dataset written to a raster intersects with the point used
# to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, 'test.tif')

crs = 'EPSG:4326'
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=0.0025,
geometry=geom,
)

ds = ds.isel(time=0).transpose('lat', 'lon')
ds.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())


if __name__ == '__main__':
absltest.main()

0 comments on commit 7bb1407

Please sign in to comment.