# Source code for gplately.grids

#
#    Copyright (C) 2024-2025 The University of Sydney, Australia
#
#    This program is free software; you can redistribute it and/or modify it under
#    the terms of the GNU General Public License, version 2, as published by
#    the Free Software Foundation.
#
#    This program is distributed in the hope that it will be useful, but WITHOUT
#    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
#    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
#    for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#

"""
This sub-module contains tools for working with MaskedArray, ndarray and netCDF4 rasters, as well as
gridded-data.

Some methods available in `grids`:

* Point data can be interpolated onto a raster or grid with Scipy using linear or
nearest-neighbour interpolation.
* Rasters can be resampled with a set of X and Y-direction spacings, and can be resized
using given X and Y resolutions.
* Grids with invalid (NaN-type) data cells can have their NaN entries replaced
with the values of their nearest valid neighbours.

Classes
-------
* RegularGridInterpolator
* Raster
"""

import copy
import logging
import math
import warnings
from multiprocessing import cpu_count
from typing import Tuple, Union

import matplotlib.colors
import matplotlib.pyplot as plt
import netCDF4
import numpy as np
import pygplates
from cartopy.crs import PlateCarree as _PlateCarree
from cartopy.mpl.geoaxes import GeoAxes as _GeoAxes
from rasterio.enums import MergeAlg
from rasterio.features import rasterize as _rasterize
from rasterio.transform import from_bounds as _from_bounds
from scipy.interpolate import RegularGridInterpolator as _RGI
from scipy.interpolate import griddata
from scipy.ndimage import distance_transform_edt, map_coordinates
from scipy.spatial import cKDTree as _cKDTree  # type: ignore
from scipy.spatial.transform import Rotation as _Rotation

from .geometry import pygplates_to_shapely
from .reconstruction import PlateReconstruction as _PlateReconstruction
from .tools import _deg2pixels, griddata_sphere

# Module logger, obtained under the name "gplately".
logger = logging.getLogger("gplately")

# Public names exported by ``from gplately.grids import *``.
# NOTE(review): "rasterise" and "rasterize" are both exported; presumably one is an
# alias of the other (their definitions are outside this view) — confirm.
__all__ = [
    "fill_raster",
    "read_netcdf_grid",
    "write_netcdf_grid",
    "RegularGridInterpolator",
    "sample_grid",
    "reconstruct_grid",
    "rasterise",
    "rasterize",
    "Raster",
    # "TimeRaster",
]


def fill_raster(data, invalid=None):
    """Search a grid of ``data`` for invalid cells (i.e NaN-type entries) and fill each
    invalid cell with the value of its nearest valid neighbour.

    .. note::

        Uses scipy's ``distance_transform_edt`` function to perform an Exact Euclidean
        Distance Transform (EEDT). This locates the nearest valid neighbours of an invalid
        ``data`` cell.

        An optional parameter, ``invalid``, is a binary ndarray with the same dimensions
        as ``data`` and the following entries:

        * 1 if its corresponding entry in ``data`` is of NaN-type;
        * 0 if not NaN-type

        This will be used to locate nearest neighbour fill values during the Exact Euclidean
        Distance Transform. If ``invalid`` is not passed to ``fill_raster``, it will be created
        for the user: a cell is considered invalid if it is NaN or, for a MaskedArray, if it
        is masked or equal to the array's ``fill_value``.

    Parameters
    ----------
    data : MaskedArray or array_like
        A MaskedArray (or any array-like) of data that may have invalid cells
        (i.e. entries of type NaN, or masked entries).

    invalid : ndarray, optional, default=None
        An ndarray with the same shape as ``data`` whose elements are 1 if its corresponding
        elements in ``data`` are of type ``NaN``, and 0 if its corresponding entries in ``data``
        are valid. An optional parameter - this will be created for the user if it isn't
        provided.

    Returns
    -------
    data : ndarray
        An updated ``data`` array where each invalid cell has been replaced with the value
        of its nearest valid neighbour. The input is not modified.
    """
    masked_array = hasattr(data, "fill_value")
    mask_fill_value = None
    masked_cells = None
    if masked_array:
        # Cells holding the fill value, and cells that are explicitly masked,
        # are both treated as missing.
        mask_fill_value = data.data == data.fill_value
        masked_cells = np.ma.getmaskarray(data)
        data = data.data.copy()
        # NaN can only be stored in inexact (float/complex) arrays; assigning it
        # to an integer array would raise ValueError.
        if np.issubdtype(data.dtype, np.inexact):
            data[mask_fill_value] = np.nan
    else:
        # np.array (rather than .copy()) also accepts lists and other array-likes;
        # subok=True keeps ndarray subclasses as they were.
        data = np.array(data, copy=True, subok=True)

    if invalid is None:
        invalid = np.isnan(data)
        if masked_array:
            invalid = invalid | mask_fill_value | masked_cells
    # For every cell, the indices of the nearest cell where `invalid` is False.
    ind = distance_transform_edt(invalid, return_distances=False, return_indices=True)
    return data[tuple(ind)]


def _realign_grid(array, lons, lats):
    """realigns grid to -180/180 and flips the array if the latitudinal coordinates are decreasing."""
    mask_lons = lons > 180

    # realign to -180/180
    if mask_lons.any():
        dlon = np.diff(lons).mean()
        array = np.hstack([array[:, mask_lons], array[:, ~mask_lons]])
        lons = np.hstack([lons[mask_lons] - 360 - dlon, lons[~mask_lons]])

    if lats[0] > lats[-1]:
        array = np.flipud(array)
        lats = lats[::-1]

    return array, lons, lats


def _guess_data_variable_name(cdf: netCDF4.Dataset, x_name: str, y_name: str) -> Union[str, None]:  # type: ignore
    """best effort to find out the data variable name"""
    vars = cdf.variables.keys()
    for var in vars:
        dimensions = cdf.variables[var].dimensions
        if len(dimensions) != 2:  # only consider two-dimensional data
            continue
        else:
            if dimensions[0] == y_name and dimensions[1] == x_name:
                return var
    return None


def read_netcdf_grid(
    filename,
    return_grids: bool = False,
    realign: bool = False,
    resample=None,
    resize=None,
    x_dimension_name: str = "",
    y_dimension_name: str = "",
    data_variable_name: str = "",
) -> Union[
    Tuple[np.ma.MaskedArray, np.ma.MaskedArray, np.ma.MaskedArray], np.ma.MaskedArray
]:
    """Read grid data from a NetCDF (.nc) file.

    Parameters
    ----------
    filename : str
        Full path to the ``netCDF`` raster file.
    return_grids : bool, optional, default=False
        If set to ``True``, returns lon, lat arrays associated with the grid data.
    realign : bool, optional, default=False
        If set to ``True``, realigns grid to -180/180 and flips the array if the
        latitudinal coordinates are decreasing.
    resample : tuple, optional, default=None
        If provided as ``resample = (spacingX, spacingY)``, the grid data will be
        resampled with these x and y resolutions.
    resize : tuple, optional, default=None
        If provided as ``resize = (resX, resY)``, the grid data will be resized to
        the number of columns (resX) and rows (resY).
    x_dimension_name : str, optional, default=""
        If the grid file uses common names, such as ``x``, ``lon``, ``lons`` or
        ``longitude``, you need not provide this parameter. Otherwise, you need to
        tell us what the x dimension name is.
    y_dimension_name : str, optional, default=""
        If the grid file uses common names, such as ``y``, ``lat``, ``lats`` or
        ``latitude``, you need not provide this parameter. Otherwise, you need to
        tell us what the y dimension name is.
    data_variable_name : str, optional, default=""
        GPlately will try its best to guess the data variable name. However, it
        would be much better if you tell us what the data variable name is.
        Otherwise, GPlately's guess may/may not be correct.

    Returns
    -------
    grid_z : `MaskedArray`_
        A `MaskedArray`_ object containing the grid data, with rows corresponding
        to latitudes and columns to longitudes. When ``realign`` is ``True`` the
        longitudes are re-aligned between -180 and 180 degrees.
    lon, lat : `MaskedArray`_
        When ``return_grids`` is ``True``, return two additional `MaskedArray`_
        objects containing the longitudes and latitudes of the grid data.

    Raises
    ------
    ValueError
        If the lon/lat coordinates or the data variable cannot be found, or if
        the data variable is not two-dimensional.

    .. _MaskedArray: https://numpy.org/doc/stable/reference/maskedarray.generic.html
    """

    def find_label(keys, labels):
        # first candidate label present in `keys`, or None
        for label in labels:
            if label in keys:
                return label
        return None

    # possible permutations of lon/lat/z
    label_lon = ["lon", "lons", "longitude", "x", "east", "easting", "eastings"]
    label_lat = ["lat", "lats", "latitude", "y", "north", "northing", "northings"]
    label_z = ["z", "data", "values", "Band1", "__xarray_dataarray_variable__"]

    # add capitalise and upper case permutations
    label_lon = (
        label_lon
        + [label.capitalize() for label in label_lon]
        + [label.upper() for label in label_lon]
    )
    label_lat = (
        label_lat
        + [label.capitalize() for label in label_lat]
        + [label.upper() for label in label_lat]
    )
    label_z = (
        label_z
        + [label.capitalize() for label in label_z]
        + [label.upper() for label in label_z]
    )

    with netCDF4.Dataset(filename, "r") as cdf:
        keys = cdf.variables.keys()

        # find the names of variables
        if data_variable_name:
            key_z = data_variable_name
        else:
            key_z = find_label(keys, label_z)
        if x_dimension_name:
            key_lon = x_dimension_name
        else:
            key_lon = find_label(keys, label_lon)
        if y_dimension_name:
            key_lat = y_dimension_name
        else:
            key_lat = find_label(keys, label_lat)

        if key_lon is None or key_lat is None:
            raise ValueError(
                f"Cannot find x,y or lon/lat coordinates in netcdf. The dimensions in the file are {cdf.dimensions.keys()}"
            )

        if key_z is None:
            key_z = _guess_data_variable_name(cdf, key_lon, key_lat)

        if key_z is None:
            raise ValueError(
                f"Cannot find z data in netcdf. The variables in the file are {cdf.variables.keys()}"
            )

        z_dimensions = cdf[key_z].dimensions
        if len(z_dimensions) != 2:
            # `dimensions` is a tuple of names, so it is formatted directly.
            raise ValueError(
                "The data in the netcdf file is not two-dimensional. This function can only handle two-dimensional data. "
                + f"The dimensions of the data variable are {z_dimensions}"
            )

        # extract data from cdf variables
        cdf_grid = cdf[key_z][:]
        cdf_lon = cdf[key_lon][:]
        cdf_lat = cdf[key_lat][:]

        # Data stored as (lon, lat) instead of (lat, lon): transpose so that rows
        # correspond to latitudes and columns to longitudes.
        if tuple(z_dimensions) == (key_lon, key_lat):
            cdf_grid = cdf_grid.T

        # fill missing values
        if hasattr(cdf[key_z], "missing_value") and np.issubdtype(
            cdf_grid.dtype, np.floating
        ):
            fill_value = cdf[key_z].missing_value
            # NOTE: rtol=0.1 is a loose (10%) relative tolerance, presumably to
            # catch fill values altered by reduced-precision storage.
            cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan

    # convert to boolean array
    if np.issubdtype(cdf_grid.dtype, np.integer):
        unique_grid = np.unique(cdf_grid)
        if len(unique_grid) == 2 and (unique_grid == [0, 1]).all():
            cdf_grid = cdf_grid.astype(bool)

    if realign:
        # realign longitudes to -180/180 dateline
        cdf_grid_z, cdf_lon, cdf_lat = _realign_grid(cdf_grid, cdf_lon, cdf_lat)
    else:
        cdf_grid_z = cdf_grid

    # resample
    if resample is not None:
        spacingX, spacingY = resample

        # don't resample if already the same resolution
        dX = np.diff(cdf_lon).mean()
        dY = np.diff(cdf_lat).mean()

        if spacingX != dX or spacingY != dY:
            lon_grid = np.arange(cdf_lon.min(), cdf_lon.max() + spacingX, spacingX)
            lat_grid = np.arange(cdf_lat.min(), cdf_lat.max() + spacingY, spacingY)
            lonq, latq = np.meshgrid(lon_grid, lat_grid)
            original_extent = (
                cdf_lon[0],
                cdf_lon[-1],
                cdf_lat[0],
                cdf_lat[-1],
            )
            cdf_grid_z = sample_grid(
                lonq,
                latq,
                cdf_grid_z,
                method="nearest",
                extent=original_extent,
                return_indices=False,
            )
            cdf_lon = lon_grid
            cdf_lat = lat_grid

    # resize
    if resize is not None:
        resX, resY = resize

        # don't resize if already the same shape
        if resX != cdf_grid_z.shape[1] or resY != cdf_grid_z.shape[0]:
            original_extent = (
                cdf_lon[0],
                cdf_lon[-1],
                cdf_lat[0],
                cdf_lat[-1],
            )
            lon_grid = np.linspace(original_extent[0], original_extent[1], resX)
            lat_grid = np.linspace(original_extent[2], original_extent[3], resY)
            lonq, latq = np.meshgrid(lon_grid, lat_grid)

            cdf_grid_z = sample_grid(
                lonq,
                latq,
                cdf_grid_z,
                method="nearest",
                extent=original_extent,
                return_indices=False,
            )
            cdf_lon = lon_grid
            cdf_lat = lat_grid

    if return_grids:
        return cdf_grid_z, cdf_lon, cdf_lat
    return cdf_grid_z
def write_netcdf_grid(
    filename,
    grid,
    extent: Union[tuple, str] = "global",
    significant_digits=None,
    fill_value: Union[float, None] = np.nan,
):
    """Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.

    Notes
    -----
    The written netCDF4 grid has the same latitudinal and longitudinal (row and column)
    dimensions as `grid`. It has three variables:

    * Latitudes of `grid` data
    * Longitudes of `grid` data
    * The data stored in `grid`

    However, the latitudes and longitudes of the grid returned to the user are constrained
    to those specified in `extent`. By default, `extent` assumes a global latitudinal and
    longitudinal span: `extent=[-180,180,-90,90]`.

    Parameters
    ----------
    filename : str
        The full path (including a filename and the ".nc" extension) to save the created
        netCDF4 `grid` to.
    grid : array-like
        An ndarray grid containing data to be written into a `netCDF` (.nc) file. Note: Rows
        correspond to the data's latitudes, while the columns correspond to the data's
        longitudes.
    extent : tuple, list or str, default="global"
        Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the
        lat and lon variables of the netCDF grid to. The string "global" is equivalent to
        `[-180, 180, -90, 90]`.
    significant_digits : int
        Applies lossy data compression up to a specified number of significant digits.
        This significantly reduces file size, but make sure the required precision is
        preserved in the saved netcdf file.
    fill_value : scalar, NoneType, default: np.nan
        Value used to fill in missing data. By default this is np.nan.

    Returns
    -------
    A netCDF grid will be saved to the path specified in `filename`.
    """
    import netCDF4

    from gplately import __version__ as _version

    # Accept any array-like, as documented; asanyarray keeps MaskedArrays intact.
    grid = np.asanyarray(grid)

    # isinstance check first: comparing an ndarray extent with a string would be
    # elementwise and ambiguous in a boolean context.
    if isinstance(extent, str) and extent == "global":
        extent = (-180, 180, -90, 90)
    else:
        extent = tuple(extent)
    assert len(extent) == 4, "specify the [min lon, max lon, min lat, max lat]"

    nrows, ncols = np.shape(grid)

    lon_grid = np.linspace(extent[0], extent[1], ncols)
    lat_grid = np.linspace(extent[2], extent[3], nrows)

    data_kwds = {"compression": "zlib", "complevel": 6}

    with netCDF4.Dataset(filename, "w", driver=None) as cdf:
        cdf.title = "Grid produced by gplately " + str(_version)
        cdf.createDimension("lon", lon_grid.size)
        cdf.createDimension("lat", lat_grid.size)
        cdf_lon = cdf.createVariable("lon", lon_grid.dtype, ("lon",), **data_kwds)
        cdf_lat = cdf.createVariable("lat", lat_grid.dtype, ("lat",), **data_kwds)
        cdf_lon[:] = lon_grid
        cdf_lat[:] = lat_grid

        # Units for Geographic Grid type
        cdf_lon.units = "degrees_east"
        cdf_lon.standard_name = "lon"
        cdf_lon.actual_range = [lon_grid[0], lon_grid[-1]]
        cdf_lat.units = "degrees_north"
        cdf_lat.standard_name = "lat"
        cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]]

        # create container variable for CRS: lon/lat WGS84 datum
        crso = cdf.createVariable("crs", "i4")
        crso.long_name = "Lon/Lat Coords in WGS84"
        crso.grid_mapping_name = "latitude_longitude"
        crso.longitude_of_prime_meridian = 0.0
        crso.semi_major_axis = 6378137.0
        crso.inverse_flattening = 298.257223563
        crso.spatial_ref = """GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.01745329251994328,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]"""

        # add more keyword arguments for quantizing data (z variable only; the
        # lon/lat variables were already created above)
        if significant_digits:
            # significant_digits needs to be >= 2 so that NaNs are preserved
            data_kwds["significant_digits"] = max(2, int(significant_digits))
            data_kwds["quantize_mode"] = "GranularBitRound"

        # boolean arrays need to be converted to integers
        # no such thing as a mask on a boolean array
        if grid.dtype == np.dtype(bool):
            grid = grid.astype("i1")
            fill_value = None

        cdf_data = cdf.createVariable("z", grid.dtype, ("lat", "lon"), **data_kwds)

        # netCDF4 uses the missing_value attribute as the default _FillValue
        # without this, _FillValue defaults to 9.969209968386869e+36
        if fill_value is not None:
            cdf_data.missing_value = fill_value
            valid_values = grid[grid != fill_value]
            # nanmin/nanmax raise on an empty selection (every cell == fill_value)
            if np.size(valid_values) > 0:
                cdf_data.actual_range = [
                    np.nanmin(valid_values),
                    np.nanmax(valid_values),
                ]
        else:
            # ensure min and max z values are properly registered
            cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)]

        cdf_data.standard_name = "z"
        cdf_data.grid_mapping = "crs"

        # write data
        cdf_data[:, :] = grid


class RegularGridInterpolator(_RGI):
    """A class to sample gridded data at a set of point coordinates using either linear or
    nearest-neighbour interpolation methods. It is a child class of `scipy 1.10`'s
    [`RegularGridInterpolator`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RegularGridInterpolator.html)
    class.

    This will only work for scipy version 1.10 onwards.

    Attributes
    ----------
    points : tuple of ndarrays of float with shapes (m1, ), …, (mn, )
        Each array contains point coordinates that define the regular grid in n dimensions.
    values : ndarray
        The data on a regular grid. Note: the number of rows corresponds to the number of
        point latitudes, while the number of columns corresponds to the number of point
        longitudes.
    method : str, default='linear'
        The method of interpolation to perform. Supported are "linear" and "nearest".
        Assumes "linear" by default.
    bounds_error : bool, default=False
        Choose whether to return a ValueError and terminate the interpolation if any
        provided sample points are out of grid bounds. By default, it is set to `False`.
        In this case, all out-of-bound point values are replaced with the `fill_value`
        (defined below) if supplied.
    fill_value : float, default=np.nan
        Used to replace point values that are out of grid bounds, provided that
        'bounds_error' is false.
    """

    def __init__(
        self, points, values, method="linear", bounds_error=False, fill_value=np.nan
    ):
        super().__init__(points, values, method, bounds_error, fill_value)

    def __call__(self, xi, method=None, return_indices=False, return_distances=False):
        """Samples gridded data at a set of point coordinates. Uses either a linear or
        nearest-neighbour interpolation `method`.

        Uses the gridded data specified in the sample_grid method parameter. Note: if any
        provided sample points are out of grid bounds and a corresponding error message was
        suppressed (by specifying bounds_error=False), all out-of-bound point values are
        replaced with the self.fill_value attribute ascribed to the RegularGridInterpolator
        object (if it exists). Terminates otherwise. Sample points containing NaN
        coordinates yield NaN (for inexact result dtypes), as in scipy 1.10's RGI object.

        Parameters
        ----------
        xi : ndarray of shape (..., ndim)
            The coordinates of points to sample the gridded data at.
        method : str, default=None
            The method of interpolation to perform. Supported are "linear" and "nearest".
            Uses the interpolator's own ``method`` if None provided.
        return_indices : bool, default=False
            Choose whether to return indices of neighbouring sampling points.
        return_distances : bool, default=False
            Choose whether to return normal distances between interpolated points and
            neighbouring sampling points.

        Returns
        -------
        output_tuple : tuple of ndarrays
            The first ndarray in the output tuple holds the interpolated grid data. If
            sample point distances and indices are required, these are returned as
            subsequent tuple elements.

        Raises
        ------
        ValueError
            * Raised if the string method supplied is not "linear" or "nearest".
            * Raised if the provided sample points for interpolation (xi) do not have
              the same dimensions as the supplied grid.
            * Raised if the provided sample points for interpolation include any point
              out of grid bounds. Alerts user which dimension (index) the point is
              located. Only raised if the RegularGridInterpolator attribute bounds_error
              is set to True. If suppressed, out-of-bound points are replaced with a set
              fill_value.
        """
        method = self.method if method is None else method
        if method not in ["linear", "nearest"]:
            raise ValueError("Method '%s' is not defined" % method)

        xi, xi_shape, ndim, nans, out_of_bounds = self._prepare_xi(xi)

        indices, norm_distances = self._find_indices(xi.T)
        if method == "linear":
            result = self._evaluate_linear(indices, norm_distances)
        else:  # method == "nearest"
            result = self._evaluate_nearest(indices, norm_distances)
        if not self.bounds_error and self.fill_value is not None:
            result[out_of_bounds] = self.fill_value

        # f(nan) = nan: NaN coordinates are not flagged as out of bounds and would
        # otherwise be snapped to a grid edge. Only inexact arrays can hold NaN.
        if np.any(nans) and np.issubdtype(result.dtype, np.inexact):
            result[nans] = np.nan

        interp_output = result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
        output_tuple = [interp_output]

        if return_indices:
            output_tuple.append(indices)
        if return_distances:
            output_tuple.append(norm_distances)

        if return_distances or return_indices:
            return tuple(output_tuple)
        return output_tuple[0]

    def _prepare_xi(self, xi):
        """Normalise `xi` to shape (npoints, ndim) and flag NaN / out-of-bounds points."""
        try:
            from scipy.interpolate.interpnd import _ndim_coords_from_arrays
        except ImportError:
            # SciPy 1.15 renamed interpnd to _interpnd (see https://github.com/scipy/scipy/pull/21754).
            from scipy.interpolate._interpnd import _ndim_coords_from_arrays

        ndim = len(self.grid)
        xi = _ndim_coords_from_arrays(xi, ndim=ndim)
        if xi.shape[-1] != len(self.grid):
            raise ValueError(
                "The requested sample points xi have dimension "
                f"{xi.shape[-1]} but this "
                f"RegularGridInterpolator has dimension {ndim}"
            )

        xi_shape = xi.shape
        xi = xi.reshape(-1, xi_shape[-1])

        # find nans in input
        nans = np.any(np.isnan(xi), axis=-1)

        if self.bounds_error:
            for i, p in enumerate(xi.T):
                if not np.logical_and(
                    np.all(self.grid[i][0] <= p), np.all(p <= self.grid[i][-1])
                ):
                    raise ValueError(
                        "One of the requested xi is out of bounds "
                        "in dimension %d" % i
                    )
            out_of_bounds = None
        else:
            out_of_bounds = self._find_out_of_bounds(xi.T)

        return xi, xi_shape, ndim, nans, out_of_bounds

    def _find_out_of_bounds(self, xi):
        """Boolean mask of points lying outside the grid in any dimension."""
        # check for out of bounds xi
        out_of_bounds = np.zeros((xi.shape[1]), dtype=bool)
        # iterate through dimensions
        for x, grid in zip(xi, self.grid):
            out_of_bounds += x < grid[0]
            out_of_bounds += x > grid[-1]
        return out_of_bounds

    def _find_indices(self, xi):
        """Index identifier outsourced from scipy 1.9's RegularGridInterpolator to ensure
        stable operations with all versions of scipy >1.0.
        """
        # find relevant edges between which xi are situated
        indices = []
        # compute distance to lower edge in unity units
        norm_distances = []
        # iterate through dimensions
        for x, grid in zip(xi, self.grid):
            i = np.searchsorted(grid, x) - 1
            i[i < 0] = 0
            i[i > grid.size - 2] = grid.size - 2
            indices.append(i)

            # compute norm_distances, incl length-1 grids,
            # where `grid[i+1] == grid[i]`
            denom = grid[i + 1] - grid[i]
            with np.errstate(divide="ignore", invalid="ignore"):
                norm_dist = np.where(denom != 0, (x - grid[i]) / denom, 0)
            norm_distances.append(norm_dist)

        return indices, norm_distances

    def _evaluate_linear(self, indices, norm_distances):
        """Linear interpolator outsourced from scipy 1.9's RegularGridInterpolator to ensure
        stable operations with all versions of scipy >1.0.
        """
        import itertools

        # slice for broadcasting over trailing dimensions in self.values
        vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))

        # Compute shifting up front before zipping everything together
        shift_norm_distances = [1 - yi for yi in norm_distances]
        shift_indices = [i + 1 for i in indices]

        # The formula for linear interpolation in 2d takes the form:
        # values = self.values[(i0, i1)] * (1 - y0) * (1 - y1) + \
        #          self.values[(i0, i1 + 1)] * (1 - y0) * y1 + \
        #          self.values[(i0 + 1, i1)] * y0 * (1 - y1) + \
        #          self.values[(i0 + 1, i1 + 1)] * y0 * y1
        # We pair i with 1 - yi (zipped1) and i + 1 with yi (zipped2)
        zipped1 = zip(indices, shift_norm_distances)
        zipped2 = zip(shift_indices, norm_distances)

        # Take all products of zipped1 and zipped2 and iterate over them
        # to get the terms in the above formula. This corresponds to iterating
        # over the vertices of a hypercube.
        hypercube = itertools.product(*zip(zipped1, zipped2))
        values = 0.0
        for h in hypercube:
            edge_indices, weights = zip(*h)
            weight = 1.0
            for w in weights:
                weight *= w
            values += np.asarray(self.values[edge_indices]) * weight[vslice]
        return values

    def _evaluate_nearest(self, indices, norm_distances):
        """Nearest neighbour interpolator outsourced from scipy 1.9's RegularGridInterpolator
        to ensure stable operations with all versions of scipy >1.0.
        """
        idx_res = [
            np.where(yi <= 0.5, i, i + 1) for i, yi in zip(indices, norm_distances)
        ]
        return self.values[tuple(idx_res)]


def sample_grid(
    lon,
    lat,
    grid,
    method="linear",
    extent: Union[tuple, str] = "global",
    origin=None,
    return_indices=False,
):
    """Sample point data with given `lon` and `lat` coordinates onto a `grid`
    using spline interpolation.

    Parameters
    ----------
    lon, lat : array_like
        The longitudes and latitudes of the points to interpolate onto the
        gridded data. Must be broadcastable to a common shape.
    grid : Raster or array_like
        An array whose elements define a grid. The number of rows corresponds
        to the number of point latitudes, while the number of columns
        corresponds to the number of point longitudes.
    method : str or int; default: 'linear'
        The order of spline interpolation. Must be an integer in the range
        0-5. 'nearest', 'linear', and 'cubic' are aliases for 0, 1, and 3,
        respectively.
    extent : str or 4-tuple, default: 'global'
        4-tuple to specify (min_lon, max_lon, min_lat, max_lat) extents
        of the raster. If no extents are supplied, full global extent
        [-180,180,-90,90] is assumed (equivalent to `extent='global'`).
        For array data with an upper-left origin, make sure `min_lat` is
        greater than `max_lat`, or specify `origin` parameter.
    origin : {'lower', 'upper'}, optional
        When `grid` is an array, use this parameter to specify the origin
        (upper left or lower left) of the data (overriding `extent`).
    return_indices : bool, default=False
        Whether to return the row and column indices of the nearest grid
        points.

    Returns
    -------
    numpy.ndarray
        The values interpolated at the input points, cast to the dtype of `grid`.
    indices : 2-tuple of numpy.ndarray
        The i- and j-indices of the nearest grid points to the input
        points, only present if `return_indices=True`.

    Raises
    ------
    ValueError
        If an invalid `method` is provided.
    RuntimeWarning
        If `lat` contains any invalid values outside of the interval
        [-90, 90]. Invalid values will be clipped to this interval.

    Notes
    -----
    If `return_indices` is set to `True`, the nearest array indices
    are returned as a tuple of arrays, in (i, j) or (lat, lon) format.

    An example output:

        # The first array holds the rows of the raster where point data spatially falls near.
        # The second array holds the columns of the raster where point data spatially falls near.
        sampled_indices = (array([1019, 1019, 1019, ..., 1086, 1086, 1087]), array([2237, 2237, 2237, ...,  983,  983,  983]))
    """
    order = {
        "nearest": 0,
        "linear": 1,
        "cubic": 3,
    }.get(method, method)
    if order not in {0, 1, 2, 3, 4, 5}:
        raise ValueError("Invalid `method` parameter: {}".format(method))

    if isinstance(grid, Raster):
        extent = grid.extent
        grid = np.array(grid.data)
    else:
        extent = _parse_extent_origin(extent, origin)
        grid = _check_grid(grid)

    # Do not wrap from North to South Pole (or vice versa)
    if np.any(np.abs(lat) > 90.0):
        warnings.warn(
            "Invalid values encountered in lat; clipping to [-90, 90]",
            RuntimeWarning,
        )
        lat = np.clip(lat, -90.0, 90.0)

    # grid spacing between the first and last rows/columns of the extent
    dx = (extent[1] - extent[0]) / (np.shape(grid)[1] - 1)
    dy = (extent[3] - extent[2]) / (np.shape(grid)[0] - 1)

    # fractional (row, column) positions of the sample points
    point_i = (lat - extent[2]) / dy
    point_j = (lon - extent[0]) / dx
    # np.vstack replaces np.row_stack, which is deprecated as of NumPy 2.0
    point_coords = np.vstack(
        (
            np.ravel(point_i),
            np.ravel(point_j),
        )
    )

    if np.ndim(grid) == 2:
        interpolated = map_coordinates(
            np.array(grid, dtype="float"),
            point_coords,
            order=order,
            mode="grid-wrap",
            prefilter=order > 1,
        )
        interpolated = np.reshape(interpolated, np.shape(lon))
    else:  # ndim(grid) == 3
        depth = np.shape(grid)[2]
        interpolated = []
        for k in range(depth):
            interpolated_k = map_coordinates(
                grid[..., k],
                point_coords,
                order=order,
                mode="grid-wrap",
                prefilter=order > 1,
            )
            interpolated.append(np.reshape(interpolated_k, np.shape(lon)))
        interpolated = np.stack(interpolated, axis=-1)
    interpolated = interpolated.astype(grid.dtype)

    if return_indices:
        indices = (
            np.rint(np.ravel(point_i)).astype(np.int_),
            np.rint(np.ravel(point_j)).astype(np.int_),
        )
        return interpolated, indices
    return interpolated
def reconstruct_grid(
    grid,
    partitioning_features,
    rotation_model,
    to_time,
    from_time=0.0,
    extent: Union[tuple, str] = "global",
    origin=None,
    fill_value=None,
    threads=1,
    anchor_plate_id=None,
    x_dimension_name: str = "",
    y_dimension_name: str = "",
    data_variable_name: str = "",
):
    """Reconstruct a gridded dataset to a given reconstruction time.

    .. note::

        Use :meth:`Raster.reconstruct` whenever possible. It would be better
        for :func:`reconstruct_grid` to be private.

    Parameters
    ----------
    grid : array_like, or str
        The grid to be reconstructed. If ``grid`` is a filename (``str`` or
        path-like), it will be loaded using :meth:`read_netcdf_grid`.
    partitioning_features : valid argument to pygplates.FeaturesFunctionArgument
        Features used to partition the ``grid`` by plate ID, usually a static
        polygons file. The ``partitioning_features`` may be a single
        ``pygplates.Feature`` object, a ``pygplates.FeatureCollection``, a
        filename (:class:`str`), or a (potentially nested) sequence of any
        combination of the above types.
    rotation_model : valid argument to pygplates.RotationModel
        The rotation model used to reconstruct the ``grid``.
        The ``rotation_model`` may be a ``pygplates.RotationModel`` object,
        a rotation ``pygplates.FeatureCollection``, a rotation filename
        (:class:`str`), a rotation ``pygplates.Feature``, a sequence of
        rotation features, or a (potentially nested) sequence of any
        combination of the above types.
    to_time : float
        Time to which ``grid`` will be reconstructed.
    from_time : float, default=0.0
        Time from which to reconstruct the ``grid``.
    extent : tuple or str, default="global"
        Extent of the ``grid``. Valid arguments are a tuple of the form
        (xmin, xmax, ymin, ymax), or the string "global", equivalent to
        (-180.0, 180.0, -90.0, 90.0).
    origin : {"upper", "lower"}, optional
        Origin of the ``grid`` - either lower-left or upper-left. By default,
        determined from `extent`.
    fill_value : float, int, str, or tuple, optional, default=None
        The value to be used for regions outside of ``partitioning_features``
        at ``to_time``. If not provided, this value will be determined based
        on the input.
    threads : int or str, default=1
        Number of threads to use for certain computationally heavy routines.
        "all" or "max" uses every available CPU.
    anchor_plate_id : int, optional, default=None
        ID of the anchored plate. By default, use the default anchor plate ID
        of ``rotation_model`` if it's a ``pygplates.RotationModel`` (otherwise zero).
    x_dimension_name : str, optional, default=""
        If the grid file uses common names, such as "x", "lon", "lons" or
        "longitude", you need not set this parameter. Otherwise, you need to
        tell us what the x dimension name is.
    y_dimension_name : str, optional, default=""
        If the grid file uses common names, such as "y", "lat", "lats" or
        "latitude", you need not set this parameter. Otherwise, you need to
        tell us what the y dimension name is.
    data_variable_name : str, optional, default=""
        The program will try its best to determine the data variable name.
        However, it would be better if you could tell us what the data variable
        name is. Otherwise, the program will guess. The result may/may not be
        correct.

    Returns
    -------
    numpy.ndarray
        The reconstructed grid. Areas for which no plate ID could be
        determined from ``partitioning_features`` will be filled with
        ``fill_value``.

    Raises
    ------
    TypeError
        If ``rotation_model`` is None while ``to_time != from_time``, or if a
        colour-name ``fill_value`` is given for a 2D grid.
    ValueError
        If ``threads`` is an unrecognised string, ``fill_value`` does not match
        the number of grid bands, or no rotation is found for a plate ID.

    .. note::

        For two-dimensional grids, ``fill_value`` should be a single
        number. The default value will be ``np.nan`` for float or
        complex types, the minimum value for integer types, and the
        maximum value for unsigned types.
        For RGB image grids, ``fill_value`` should be a 3-tuple RGB
        colour code or a matplotlib colour name. The default value
        will be black (0.0, 0.0, 0.0). Colour names are scaled to
        0-255 for integer grids.
        For RGBA image grids, ``fill_value`` should be a 4-tuple RGBA
        colour code or a matplotlib colour name. The default fill
        value will be transparent black (0.0, 0.0, 0.0, 0.0).
    """
    import os

    if isinstance(grid, (str, os.PathLike)):
        # Load grid data from file. Read errors propagate instead of the
        # filename itself being silently converted into a 0-d string array.
        grid = np.array(
            read_netcdf_grid(
                os.fspath(grid),
                x_dimension_name=x_dimension_name,
                y_dimension_name=y_dimension_name,
                data_variable_name=data_variable_name,
            )
        )
    else:
        grid = np.array(grid)  # copy grid data to array

    if to_time == from_time:
        return grid
    elif rotation_model is None:
        raise TypeError("`rotation_model` must be provided if `to_time` != `from_time`")

    extent = _parse_extent_origin(extent, origin)
    dtype = grid.dtype

    if isinstance(threads, str):
        if threads.lower() in {"all", "max"}:
            threads = cpu_count()
        else:
            raise ValueError("Invalid `threads` value: {}".format(threads))
    threads = min([int(threads), cpu_count()])
    threads = max([threads, 1])

    grid = grid.squeeze()
    grid = _check_grid(grid)

    # Determine fill_value
    if fill_value is None:
        if grid.ndim == 2:
            if dtype.kind == "i":
                fill_value = np.iinfo(dtype).min
            elif dtype.kind == "u":
                fill_value = np.iinfo(dtype).max
            else:  # dtype.kind in ("f", "c")
                fill_value = np.nan
        else:  # grid.ndim == 3
            if dtype.kind in ("i", "u"):
                fill_value = tuple([0] * grid.shape[2])
            else:  # dtype.kind == "f"
                fill_value = tuple([0.0] * grid.shape[2])
    if isinstance(fill_value, str):
        if grid.ndim == 2:
            raise TypeError("Invalid fill_value for 2D grid: {}".format(fill_value))
        fill_value = np.array(matplotlib.colors.to_rgba(fill_value))
        # Integer image grids (signed or unsigned) use a 0-255 colour scale,
        # matching the 255 alpha used below; to_rgba returns 0-1 floats.
        if dtype.kind in ("i", "u"):
            fill_value = (fill_value * 255.0).astype("u1")
        fill_value = np.clip(fill_value, 0, 255)
        fill_value = tuple(fill_value)[: grid.shape[2]]
    if (
        grid.ndim == 3
        and grid.shape[2] == 4
        and hasattr(fill_value, "__len__")
        and len(fill_value) == 3
    ):
        # give fill colour maximum alpha value if not specified
        fill_alpha = 255 if dtype.kind in ("i", "u") else 1.0
        fill_value = (*fill_value, fill_alpha)
    if np.size(fill_value) != np.atleast_3d(grid).shape[-1]:
        raise ValueError(
            "Shape mismatch: "
            + "fill_value size: {}".format(np.size(fill_value))
            + ", grid shape: {}".format(np.shape(grid))
        )

    xmin, xmax, ymin, ymax = extent
    ny, nx = grid.shape[:2]

    if isinstance(partitioning_features, pygplates.FeaturesFunctionArgument):
        partitioning_features = pygplates.FeatureCollection(
            partitioning_features.get_features()
        )
    elif not isinstance(partitioning_features, pygplates.FeatureCollection):
        partitioning_features = pygplates.FeatureCollection(
            pygplates.FeaturesFunctionArgument(partitioning_features).get_features()
        )

    if not isinstance(rotation_model, pygplates.RotationModel):
        rotation_model = pygplates.RotationModel(rotation_model)

    lons = np.linspace(xmin, xmax, nx)
    lats = np.linspace(ymin, ymax, ny)
    m_lons, m_lats = np.meshgrid(lons, lats)

    # only features that exist at both times can carry data between them
    valid_partitioning_features = [
        i
        for i in partitioning_features
        if i.is_valid_at_time(from_time) and i.is_valid_at_time(to_time)
    ]
    # plate IDs of the input cells (at from_time) ...
    plate_ids = rasterise(
        features=valid_partitioning_features,
        rotation_model=rotation_model,
        key="plate_id",
        time=from_time,
        extent=extent,
        shape=grid.shape[:2],
        origin=origin,
        anchor_plate_id=anchor_plate_id,
    )
    # ... and which output cells (at to_time) are covered by a plate
    valid_output_mask = (
        rasterise(
            features=valid_partitioning_features,
            rotation_model=rotation_model,
            key="plate_id",
            time=to_time,
            extent=extent,
            shape=grid.shape[:2],
            origin=origin,
            anchor_plate_id=anchor_plate_id,
        )
        != -1
    )
    valid_mask = plate_ids != -1
    valid_m_lons = m_lons[valid_mask]
    valid_m_lats = m_lats[valid_mask]
    valid_plate_ids = plate_ids[valid_mask]
    if grid.ndim == 2:
        valid_data = grid[valid_mask]
    else:
        valid_data = np.empty(
            (grid.shape[2], np.sum(valid_mask)),
            dtype=dtype,
        )
        for k in range(grid.shape[2]):
            valid_data[k, :] = grid[..., k][valid_mask]

    if grid.ndim == 2:
        output_grid = np.full(grid.shape, fill_value)
    else:
        output_grid = np.empty(grid.shape, dtype=dtype)
        for k in range(grid.shape[2]):
            output_grid[..., k] = fill_value[k]
    output_lons = m_lons[valid_output_mask]
    output_lats = m_lats[valid_output_mask]

    # Nothing to sample from, or nowhere to sample to: the result is all fill.
    # (An empty rotation array would otherwise make Rotation.from_rotvec fail.)
    if valid_plate_ids.size == 0 or output_lons.size == 0:
        return output_grid

    # one rotation vector (axis * angle in radians) per unique plate ID,
    # expanded back to one per valid input cell via the inverse indices
    unique_plate_ids, inv = np.unique(valid_plate_ids, return_inverse=True)
    unique_rotvecs = []
    for plate in unique_plate_ids:
        rot = rotation_model.get_rotation(
            to_time=float(to_time),
            from_time=float(from_time),
            moving_plate_id=int(plate),
            anchor_plate_id=anchor_plate_id,  # if None then uses default anchor plate of 'rotation_model'
        )
        if not isinstance(rot, pygplates.FiniteRotation):
            raise ValueError("No rotation found for plate ID: {}".format(plate))
        lat, lon, angle = rot.get_lat_lon_euler_pole_and_angle_degrees()
        angle = np.deg2rad(angle)
        vec = _lat_lon_to_vector(lat, lon, degrees=True)
        unique_rotvecs.append(vec * angle)
    rotations_array = np.array(unique_rotvecs)[inv]
    combined_rotations = _Rotation.from_rotvec(rotations_array)

    # rotate the valid input cells to to_time and index them for nearest lookup
    point_vecs = _lat_lon_to_vector(
        np.ravel(valid_m_lats),
        np.ravel(valid_m_lons),
        degrees=True,
    )
    rotated_vecs = combined_rotations.apply(point_vecs)
    tree = _cKDTree(rotated_vecs)

    output_vecs = _lat_lon_to_vector(
        output_lats,
        output_lons,
        degrees=True,
    )
    # Compatibility with older versions of SciPy:
    # 'n_jobs' argument was replaced with 'workers'
    try:
        _, indices = tree.query(
            output_vecs,
            k=1,
            workers=threads,
        )
    except TypeError as err:
        # SciPy reports e.g. "query() got an unexpected keyword argument
        # 'workers'" (lower-case), so match case-insensitively.
        message = str(err).lower()
        if "unexpected keyword argument" in message and "workers" in message:
            _, indices = tree.query(
                output_vecs,
                k=1,
                n_jobs=threads,
            )
        else:
            raise

    if grid.ndim == 2:
        output_data = valid_data[indices]
        output_grid[valid_output_mask] = output_data
    else:
        for k in range(grid.shape[2]):
            output_data = valid_data[k, indices]
            output_grid[..., k][valid_output_mask] = output_data
    return output_grid
def rasterise(
    features,
    rotation_model=None,
    key="plate_id",
    time=None,
    resx=1.0,
    resy=1.0,
    shape=None,
    extent: Union[tuple, str] = "global",
    origin=None,
    tessellate_degrees=0.1,
    anchor_plate_id=None,
):
    """Rasterise geometries or GPlates features at a given reconstruction time.

    This function is particularly useful for rasterising static polygons to
    extract a grid of plate IDs.

    Parameters
    ----------
    features : geometries or features
        `features` may be a single `pygplates.Feature`, a
        `pygplates.FeatureCollection`, a `str` filename, or a (potentially
        nested) sequence of any combination of the above types.
        Alternatively, `features` may also be a sequence of geometry types
        (`pygplates.GeometryOnSphere` or `pygplates.ReconstructionGeometry`).
        In this case, `rotation_model` and `time` will be ignored, and `key`
        must be an array_like of the same length as `features`.
    rotation_model : valid argument for pygplates.RotationModel, optional
        `rotation_model` may be a `pygplates.RotationModel`, a rotation
        feature collection (pygplates.FeatureCollection), a rotation filename
        (`str`), a rotation feature (`pygplates.Feature`), a sequence of
        rotation features, or a (potentially nested) sequence of any
        combination of the above types. If `time` is not given, a rotation
        model is not required.
    key : str or array_like, default "plate_id"
        The value used to create the rasterised grid. May be any of the
        following values:

        - "plate_id"
        - "conjugate_plate_id"
        - "from_age"
        - "to_age"
        - "left_plate"
        - "right_plate"

        Alternatively, `key` may be a sequence of the same length as
        `features` (or a single scalar, which is repeated for every
        geometry).
    time : float, optional
        Reconstruction time at which to perform rasterisation. If given,
        `rotation_model` must also be specified.
    resx, resy : float, default 1.0
        Resolution (in degrees) of the rasterised grid.
    shape : tuple, optional
        If given, the output grid will have the specified shape, overriding
        `resx` and `resy`.
    extent : tuple or "global", default "global"
        Extent of the rasterised grid. Valid arguments are a tuple of the
        form (xmin, xmax, ymin, ymax), or the string "global", equivalent to
        (-180.0, 180.0, -90.0, 90.0).
    origin : {"upper", "lower"}, optional
        Origin (upper-left or lower-left) of the output array. By default,
        determined from `extent`.
    tessellate_degrees : float, default 0.1
        Densify pyGPlates geometries to this resolution before conversion.
        Can be disabled by specifying `tessellate_degrees=None`, but this may
        provide inaccurate results for low-resolution input geometries.
    anchor_plate_id : int, optional
        Passed through to `pygplates.reconstruct` when `features` are
        reconstructed.

    Returns
    -------
    grid : numpy.ndarray
        The output array will have the shape specified in `shape`, if given.
        Row 0 corresponds to ``extent[2]`` and column 0 to ``extent[0]``
        (after `origin` has been applied to `extent`). The origin is
        therefore the lower-left corner unless the y-extent is given in
        decreasing order.

    Raises
    ------
    ValueError
        If an invalid `key` value is passed, if `key` is a string while
        `features` are geometries, or if an array_like `key` does not have
        one value per geometry.
    TypeError
        If `rotation_model` is not supplied and `time` is not `None`, or if
        an array_like `key` has an unsupported dtype.

    Notes
    -----
    This function is used by gplately.grids.reconstruct_grid to rasterise
    static polygons in order to extract their plate IDs.
    """
    valid_keys = {
        "plate_id",
        "conjugate_plate_id",
        "from_age",
        "to_age",
        "left_plate",
        "right_plate",
    }
    if isinstance(key, str):
        key = key.lower()
        if key not in valid_keys:
            raise ValueError(
                "Invalid key: {}".format(key)
                + "\nkey must be one of {}".format(valid_keys)
            )

    extent = _parse_extent_origin(extent, origin)
    minx, maxx, miny, maxy = extent

    # Step "backwards" along an axis whose extent is given in decreasing order.
    if minx > maxx:
        resx = -1.0 * np.abs(resx)
    if miny > maxy:
        resy = -1.0 * np.abs(resy)

    if shape is not None:
        lons = np.linspace(minx, maxx, shape[1], endpoint=True)
        lats = np.linspace(miny, maxy, shape[0], endpoint=True)
    else:
        lons = np.arange(minx, maxx + resx, resx)
        lats = np.arange(miny, maxy + resy, resy)
    nx = lons.size
    ny = lats.size

    try:
        features = pygplates.FeaturesFunctionArgument(features).get_features()
        geometries = None
    except Exception as err:
        if not str(err).startswith("Python argument types in"):
            # Not a Boost.Python.ArgumentError
            raise
        # `features` are geometries rather than GPlates features
        geometries = pygplates_to_shapely(
            features,
            tessellate_degrees=tessellate_degrees,
        )

    if geometries is not None and isinstance(key, str):
        # Attribute keys are read from reconstructed *features*; bare
        # geometries carry no such attributes, so the values would be empty.
        raise ValueError(
            "`key` must be an array_like of values (one per geometry) "
            + "when `features` are geometries, not '{}'".format(key)
        )

    reconstructed = []
    if geometries is None:
        if rotation_model is None:
            if time is not None:
                raise TypeError(
                    "Rotation model must be provided if `time` is not `None`"
                )
            # Identity rotation model: features stay at present-day positions
            rotation_model = pygplates.RotationModel(pygplates.Feature())
            time = 0.0
        if time is None:
            time = 0.0
        time = float(time)

        pygplates.reconstruct(  # type: ignore
            features,
            rotation_model,
            reconstructed,
            time,
            anchor_plate_id=anchor_plate_id,
        )
        geometries = pygplates_to_shapely(
            reconstructed,
            tessellate_degrees=tessellate_degrees,
        )
    if not isinstance(geometries, list):
        geometries = [geometries]

    if isinstance(key, str):
        values, fill_value, dtype = _get_rasterise_values(key, reconstructed)
    else:
        if not hasattr(key, "__len__"):
            key = [key] * len(geometries)
        if len(key) != len(geometries):
            raise ValueError(
                "Shape mismatch: len(key) = {}, ".format(len(key))
                + "len(geometries) = {}".format(len(geometries))
            )
        values = np.array(key)
        dtype = values.dtype
        if dtype.kind == "u":
            fill_value = np.iinfo(dtype).max
        elif dtype.kind == "i":
            fill_value = -1
        elif dtype.kind == "f":
            fill_value = np.nan
        else:
            raise TypeError("Unrecognised dtype for `key`: {}".format(dtype))

    return _rasterise_geometries(
        geometries=geometries,
        values=values,
        out_shape=(ny, nx),
        fill_value=fill_value,
        dtype=dtype,
        merge_alg=MergeAlg.replace,
        transform=_from_bounds(minx, miny, maxx, maxy, nx, ny),
    )


def _get_rasterise_values(
    key,
    reconstructed,
):
    """Extract one value per reconstructed feature geometry for a string `key`.

    Returns a ``(values, fill_value, dtype)`` tuple suitable for
    `_rasterise_geometries`. Raises ValueError for an unrecognised `key`.
    """
    # key -> (accessor on the pygplates.Feature, fill value, output dtype)
    extractors = {
        "plate_id": (lambda f: f.get_reconstruction_plate_id(), -1, np.int32),
        "conjugate_plate_id": (lambda f: f.get_conjugate_plate_id(), -1, np.int32),
        "from_age": (lambda f: f.get_valid_time()[0], np.nan, np.float32),
        "to_age": (lambda f: f.get_valid_time()[1], np.nan, np.float32),
        "left_plate": (lambda f: f.get_left_plate(), -1, np.int32),
        "right_plate": (lambda f: f.get_right_plate(), -1, np.int32),
    }
    if key not in extractors:
        raise ValueError(
            "Invalid key: {}".format(key)
            + "\nkey must be one of {}".format(set(extractors))
        )
    accessor, fill_value, dtype = extractors[key]
    values = [accessor(i.get_feature()) for i in reconstructed]
    return values, fill_value, dtype


def _rasterise_geometries(
    geometries,
    values,
    out_shape,
    fill_value,
    dtype,
    transform,
    merge_alg=MergeAlg.replace,
):
    """Burn ``values`` into a grid of ``out_shape`` using rasterio.

    The rasterio output is flipped vertically so that row 0 corresponds to
    the ``south`` bound passed to ``from_bounds`` when building ``transform``.
    """
    shapes = zip(geometries, values)
    out = _rasterize(
        shapes=shapes,
        out_shape=out_shape,
        fill=fill_value,
        dtype=dtype,
        merge_alg=merge_alg,
        transform=transform,
    )
    return np.flipud(out)


rasterize = rasterise


def _lat_lon_to_vector(lat, lon, degrees=False):
    """Convert (lat, lon) coordinates (degrees or radians) to vectors on the
    unit sphere.

    Returns a vector of shape (3,) if `lat` and `lon` are single values, else
    an array of shape (N, 3) containing N (x, y, z) row vectors, where N is
    the size of `lat` and `lon`.
    """
    lon = np.atleast_1d(lon).flatten()
    lat = np.atleast_1d(lat).flatten()
    if degrees:
        lat = np.deg2rad(lat)
        lon = np.deg2rad(lon)

    x = np.cos(lat) * np.cos(lon)
    y = np.cos(lat) * np.sin(lon)
    z = np.sin(lat)

    if x.size == 1:
        x = np.atleast_1d(np.squeeze(x))[0]
        y = np.atleast_1d(np.squeeze(y))[0]
        z = np.atleast_1d(np.squeeze(z))[0]
        return np.array((x, y, z))

    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))
    z = z.reshape((-1, 1))
    return np.hstack((x, y, z))


def _vector_to_lat_lon(
    x,
    y,
    z,
    degrees=False,
    return_array=False,
):
    """Convert one or more (x, y, z) vectors (on the unit sphere) to
    (lat, lon) coordinate pairs, in degrees or radians.

    A single vector yields a pair of scalars unless `return_array` is True;
    otherwise `lat` and `lon` are returned as column arrays of shape (N, 1).
    """
    x = np.atleast_1d(x).flatten()
    y = np.atleast_1d(y).flatten()
    z = np.atleast_1d(z).flatten()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        lat = np.arcsin(z)
    lon = np.arctan2(y, x)

    if degrees:
        lat = np.rad2deg(lat)
        lon = np.rad2deg(lon)

    if lat.size == 1 and not return_array:
        lat = np.atleast_1d(np.squeeze(lat))[0]
        lon = np.atleast_1d(np.squeeze(lon))[0]
        return (lat, lon)

    lat = lat.reshape((-1, 1))
    lon = lon.reshape((-1, 1))
    return lat, lon


def _check_grid_shape(data):
    """Check data is a 2D grid or a 3D RGB(A) image."""
    ndim = np.ndim(data)
    shape = np.shape(data)

    valid = True
    if ndim not in (2, 3):
        # ndim == 2: greyscale image/grid
        # ndim == 3: colour RGB(A) image
        valid = False

    if ndim == 3 and shape[2] not in (3, 4):
        # shape[2] == 3: colour image (RGB)
        # shape[2] == 4: colour image w/ transparency (RGBA)
        valid = False

    if not valid:
        raise ValueError("Invalid grid shape: {}".format(shape))


def _check_image_values(data):
    """Check values are within the correct range for an RGB(A) image.

    Valid ranges are [0, 1] for float images and [0, 255] for integer
    images. Signed-integer images are converted to ``uint8`` *after* the
    range check, so out-of-range values raise instead of silently wrapping.

    Raises
    ------
    ValueError
        If any value is negative, or above the maximum for the dtype.
    """
    kind = data.dtype.kind
    min_value = np.nanmin(data)
    max_value = np.nanmax(data)

    if min_value < 0:
        raise ValueError("Invalid value for RGB(A) image: {}".format(min_value))
    if (kind == "f" and max_value > 1.0) or (kind in ("i", "u") and max_value > 255):
        raise ValueError("Invalid value for RGB(A) image: {}".format(max_value))

    if kind == "i":
        data = data.astype("u1")
    return data


def _check_grid(data):
    """Check grid shape and values make sense.

    Returns the data as a ``numpy.ndarray``; 3D (RGB(A)) grids are also
    value-checked (and signed-integer images converted to ``uint8``).
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    _check_grid_shape(data)
    if data.ndim == 3:
        # data is an RGB(A) image
        data = _check_image_values(data)
    return data


def _parse_extent_origin(extent, origin):
    """Normalise ``extent`` to a 4-tuple of floats and apply ``origin``.

    Parameters
    ----------
    extent : 4-sequence, "global" or None
        ``(xmin, xmax, ymin, ymax)``; "global" (case-insensitive) and None
        both mean ``(-180.0, 180.0, -90.0, 90.0)``. Any array_like of length
        4 (including ``numpy.ndarray``) is accepted.
    origin : {"lower", "upper"} or None
        If "lower", the y-bounds are ordered increasing; if "upper",
        decreasing. Other values leave ``extent`` unchanged.

    Raises
    ------
    TypeError
        If ``extent`` is neither "global", None, nor a four-element sequence.
    """
    extent_error = "`extent` must be a four-element tuple, 'global', or None"
    if extent is None:
        extent = "global"
    # Check for strings explicitly: comparing an ndarray with a string would
    # produce an element-wise result whose truth value is ambiguous.
    if isinstance(extent, str):
        if extent.lower() != "global":
            raise TypeError(extent_error)
        extent = (-180.0, 180.0, -90.0, 90.0)
    elif len(extent) != 4:
        raise TypeError(extent_error)
    extent = tuple(float(i) for i in extent)

    if origin is not None:
        origin = str(origin).lower()
        if origin == "lower" and extent[2] > extent[3]:
            extent = (extent[0], extent[1], extent[3], extent[2])
        if origin == "upper" and extent[2] < extent[3]:
            extent = (extent[0], extent[1], extent[3], extent[2])
    return extent
[docs] class Raster(object): """The functionalities include sampling data at points using spline interpolation, resampling rasters with new X and Y-direction spacings and resizing rasters using new X and Y grid pixel resolutions. NaN-type data in rasters can be replaced with the values of their nearest valid neighbours. """
[docs] def __init__( self, data=None, plate_reconstruction=None, extent: Union[str, tuple] = "global", realign=False, resample=None, resize=None, time=0.0, origin=None, x_dimension_name: str = "", y_dimension_name: str = "", data_variable_name: str = "", **kwargs, ): """Constructor. Create a :class:`Raster` object. Parameters ---------- data : str or array-like The raster data, either as a file path (:class:`str`) or array data. plate_reconstruction : PlateReconstruction A :class:`PlateReconstruction` object to provide the following essential components for reconstructing points. * :py:attr:`PlateReconstruction.rotation_model` * :py:attr:`PlateReconstruction.topology_featues` * :py:attr:`PlateReconstruction.static_polygons` extent : str or 4-tuple, default: 'global' 4-tuple to specify (min_lon, max_lon, min_lat, max_lat) extents of the raster. If no extents are supplied, full global extent (-180, 180, -90, 90) is assumed (equivalent to ``extent='global'``). For array data with an upper-left origin, make sure ``min_lat`` is greater than ``max_lat``, or specify ``origin`` parameter. resample : 2-tuple, optional Optionally resample grid, pass spacing in X and Y direction as a 2-tuple e.g. resample=(spacingX, spacingY). resize : 2-tuple, optional Optionally resample grid to X-columns, Y-rows as a 2-tuple e.g. resample=(resX, resY). time : float, default: 0.0 The geological time the time-dependant raster data. origin : {'lower', 'upper'}, optional When ``data`` is an array, use this parameter to specify the origin (upper left or lower left) of the data (overriding ``extent``). x_dimension_name : str, optional, default="" If the grid file uses the comman names, such as ``x``, ``lon``, ``lons`` or ``longitude``, you need not to provide this parameter. Otherwise, you need to tell us what the x dimension name is. 
y_dimension_name : str, optional, default="" If the grid file uses the comman names, such as ``y``, ``lat``, ``lats`` or ``latitude``, you need not to provide this parameter. Otherwise, you need to tell us what the y dimension name is. data_variable_name : str, optional, default="" GPlately will try its best to guess the data variable name. However, it would be much better if you tell us what the data variable name is. Otherwise, GPlately's guess may/may not be correct. **kwargs Handle deprecated arguments such as ``PlateReconstruction_object``, ``filename``, and ``array``. """ if isinstance(data, self.__class__): self._data = data._data.copy() self.plate_reconstruction = data.plate_reconstruction self._lons = data._lons self._lats = data._lats self._time = data._time return if "PlateReconstruction_object" in kwargs.keys(): warnings.warn( "`PlateReconstruction_object` keyword argument has been " + "deprecated, use `plate_reconstruction` instead", DeprecationWarning, ) if plate_reconstruction is None: plate_reconstruction = kwargs.pop("PlateReconstruction_object") if "filename" in kwargs.keys() and "array" in kwargs.keys(): raise TypeError( "Both `filename` and `array` were provided; use " + "one or the other, or use the `data` argument" ) if "filename" in kwargs.keys(): warnings.warn( "`filename` keyword argument has been deprecated, " + "use `data` instead", DeprecationWarning, ) if data is None: data = kwargs.pop("filename") if "array" in kwargs.keys(): warnings.warn( "`array` keyword argument has been deprecated, " + "use `data` instead", DeprecationWarning, ) if data is None: data = kwargs.pop("array") for key in kwargs.keys(): raise TypeError( "Raster.__init__() got an unexpected keyword argument " + "'{}'".format(key) ) self.plate_reconstruction = plate_reconstruction if time < 0.0: raise ValueError("Invalid time: {}".format(time)) time = float(time) self._time = time if data is None: raise TypeError("`data` argument (or `filename` or `array`) is required") 
if isinstance(data, str): # Filename self._filename = data self._data, lons, lats = read_netcdf_grid( data, return_grids=True, realign=realign, resample=resample, resize=resize, x_dimension_name=x_dimension_name, y_dimension_name=y_dimension_name, data_variable_name=data_variable_name, ) self._lons = lons self._lats = lats else: # numpy array self._filename = None extent = _parse_extent_origin(extent, origin) data = _check_grid(data) self._data = np.array(data) self._lons = np.linspace(extent[0], extent[1], self.data.shape[1]) self._lats = np.linspace(extent[2], extent[3], self.data.shape[0]) if realign: # realign to -180,180 and flip grid self._data, self._lons, self._lats = _realign_grid( self._data, self._lons, self._lats ) if (not isinstance(data, str)) and (resample is not None): self.resample(*resample, inplace=True) if (not isinstance(data, str)) and (resize is not None): self.resize(*resize, inplace=True)
@property def time(self): """The geological time of the time-dependant raster data. :type: float """ return self._time @time.setter def time(self, new_time: float): """Set a new reconstruction time.""" try: new_time_f = float(new_time) except ValueError: raise ValueError(f"Invalid new reconstruction time: {new_time}") if new_time_f < 0.0: raise ValueError( f"The reconstruction time ({new_time_f}) must be greater than 0." ) if not math.isclose(self._time, new_time_f): self._time = new_time_f self.reconstruct(new_time_f, inplace=True) @property def data(self): """Array containing the raster data. This attribute can be modified after creating the :class:`Raster` object. :type: ndarray, shape (ny, nx) """ return self._data @data.setter def data(self, z): z = np.array(z) if z.shape != np.shape(self.data): raise ValueError( "Shape mismatch: old dimensions are {}, new are {}".format( np.shape(self.data), z.shape, ) ) self._data = z @property def lons(self): """The x-coordinates of the raster data. This attribute can be modified after creating the :class:`Raster` object. :type: ndarray, shape (nx,) """ return self._lons @lons.setter def lons(self, x): x = np.array(x).ravel() if x.size != np.shape(self.data)[1]: raise ValueError( "Shape mismatch: data x-dimension is {}, new value is {}".format( np.shape(self.data)[1], x.size, ) ) self._lons = x @property def lats(self): """The y-coordinates of the raster data. This attribute can be modified after creating the :class:`Raster` object. :type: ndarray, shape (ny,) """ return self._lats @lats.setter def lats(self, y): y = np.array(y).ravel() if y.size != np.shape(self.data)[0]: raise ValueError( "Shape mismatch: data y-dimension is {}, new value is {}".format( np.shape(self.data)[0], y.size, ) ) self._lats = y @property def extent(self): """The spatial extent ``(x0, x1, y0, y1)`` of the data. If not supplied, global extent ``(-180, 180, -90, 90)`` is assumed. If y0 < y1, the origin is the lower-left corner; else the upper-left. 
:type: tuple of 4 floats """ return ( float(self.lons[0]), float(self.lons[-1]), float(self.lats[0]), float(self.lats[-1]), ) @property def origin(self): """The origin (``lower`` or ``upper``) of the data array. :type: str """ if self.lats[0] < self.lats[-1]: return "lower" else: return "upper" @property def shape(self): """The shape of the data array.""" return np.shape(self.data) @property def size(self): """The size of the data array.""" return np.size(self.data) @property def dtype(self): """The data type of the array.""" return self.data.dtype @property def ndim(self): """The number of dimensions in the array.""" return np.ndim(self.data) @property def filename(self): """The filename used to create the :class:`Raster` object. If the object was created directly from an array, this attribute is ``None``. :type: str or None """ return self._filename @property def plate_reconstruction(self): """A :class:`PlateReconstruction` object to provide the following essential components for reconstructing points. * :py:attr:`PlateReconstruction.rotation_model` * :py:attr:`PlateReconstruction.topology_featues` * :py:attr:`PlateReconstruction.static_polygons` :type: PlateReconstruction """ return self._plate_reconstruction @plate_reconstruction.setter def plate_reconstruction(self, reconstruction): if reconstruction is None: # Remove `plate_reconstruction` attribute pass elif not isinstance(reconstruction, _PlateReconstruction): # Convert to a `PlateReconstruction` if possible try: reconstruction = _PlateReconstruction(*reconstruction) except Exception: reconstruction = _PlateReconstruction(reconstruction) self._plate_reconstruction = reconstruction
def copy(self):
    """Create an independent duplicate of this :class:`Raster`.

    The data array is copied, while the plate reconstruction object is
    shared with the original raster.

    Returns
    -------
    Raster
        A new raster with the same data, extent and time.
    """
    duplicated_data = self.data.copy()
    return Raster(
        data=duplicated_data,
        plate_reconstruction=self.plate_reconstruction,
        extent=self.extent,
        time=self.time,
    )
def interpolate(
    self,
    lons,
    lats,
    method="linear",
    return_indices=False,
):
    """Sample the raster at arbitrary points using spline interpolation.

    Parameters
    ----------
    lons, lats : array_like
        Longitudes and latitudes of the sample points. Must be broadcastable
        to a common shape.
    method : str or int; default: 'linear'
        The order of spline interpolation. Must be an integer in the range
        0-5. ``nearest``, ``linear``, and ``cubic`` are aliases for 0, 1,
        and 3, respectively.
    return_indices : bool, default=False
        Whether to also return the row and column indices of the nearest
        grid points.

    Returns
    -------
    numpy.ndarray
        The values interpolated at the input points.
    indices : 2-tuple of numpy.ndarray
        The i- and j-indices of the nearest grid points to the input points,
        only present if ``return_indices=True``.

    Raises
    ------
    ValueError
        If an invalid ``method`` is provided.
    RuntimeWarning
        If ``lats`` contains any invalid values outside of the interval
        [-90, 90]. Invalid values will be clipped to this interval.

    .. note::

        If ``return_indices`` is set to ``True``, the nearest array indices
        are returned as a tuple of arrays, in ``(i, j)`` or ``(lat, lon)``
        format. An example output:

        .. code:: console

            # The first array holds the rows of the raster where point data spatially falls near.
            # The second array holds the columns of the raster where point data spatially falls near.
            sampled_indices = (array([1019, 1019, 1019, ..., 1086, 1086, 1087]), array([2237, 2237, 2237, ...,  983,  983,  983]))
    """
    # All of the work is delegated to the module-level sampler, with this
    # raster passed as the grid.
    sampling_options = {"method": method, "return_indices": return_indices}
    return sample_grid(lon=lons, lat=lats, grid=self, **sampling_options)
def resample(self, spacingX, spacingY, method="linear", inplace=False):
    """Resample the grid with a new ``spacingX`` and ``spacingY``.

    .. note::

        Ultimately, :meth:`resample` changes the lat-lon resolution of the
        gridded data. The larger the x and y spacings given are, the larger
        the pixellation of raster data.

        :meth:`resample` creates new latitude and longitude arrays with the
        specified spacings in the X and Y directions (``spacingX`` and
        ``spacingY``), starting from the first coordinate of the current
        extent, and interpolates the data onto them. If ``inplace`` is
        ``True``, these arrays and the new data replace the ones currently
        attributed to the :class:`Raster` object.

    Parameters
    ----------
    spacingX, spacingY : float
        Spacing in the X and Y directions with which to resample. Only the
        magnitude is used; the sign is chosen from the raster's origin.
        The larger ``spacingX`` and ``spacingY`` are, the larger the raster
        pixels become (less resolved). To keep the raster's aspect
        consistent, set ``spacingX = spacingY``; otherwise, if for example
        ``spacingX > spacingY``, the raster will appear stretched
        longitudinally.
    method : str or int; default: 'linear'
        The order of spline interpolation. Must be an integer in the range
        0-5. 'nearest', 'linear', and 'cubic' are aliases for 0, 1, and 3,
        respectively.
    inplace : bool, default=False
        Overwrite the data (``self.data``), latitude array (``self.lats``)
        and longitude array (``self.lons``) of this :class:`Raster`.

    Returns
    -------
    Raster or None
        The resampled grid (with the same time as this raster), or ``None``
        if ``inplace`` is ``True``.
    """
    spacingX = np.abs(spacingX)
    spacingY = np.abs(spacingY)
    if self.origin == "upper":
        # Latitudes decrease down the rows, so step downwards.
        spacingY *= -1.0

    lons = np.arange(self.extent[0], self.extent[1] + spacingX, spacingX)
    lats = np.arange(self.extent[2], self.extent[3] + spacingY, spacingY)
    lonq, latq = np.meshgrid(lons, lats)

    data = self.interpolate(lonq, latq, method=method)
    if inplace:
        self._data = data
        self._lons = lons
        self._lats = lats
    else:
        # `time` must be passed by keyword: the fourth positional parameter
        # of Raster() is `realign`, not `time`.
        return Raster(data, self.plate_reconstruction, self.extent, time=self.time)
def resize(self, resX, resY, inplace=False, method="linear", return_array=False):
    """Resize the grid to ``resX`` columns and ``resY`` rows by interpolation.

    .. note::

        :meth:`resize` spans the current extent with ``resX`` evenly spaced
        longitudes and ``resY`` evenly spaced latitudes, and interpolates
        the data onto that mesh. If ``inplace`` is ``True``, the new
        coordinate arrays and data replace the ones currently attributed to
        the :class:`Raster` object.

    Parameters
    ----------
    resX, resY : int
        Number of columns (longitudes) and rows (latitudes) of the resized
        grid.
    inplace : bool, default=False
        Overwrite the data (``self.data``), latitude array (``self.lats``)
        and longitude array (``self.lons``) of this :class:`Raster`.
    method : str or int; default: 'linear'
        The order of spline interpolation. Must be an integer in the range
        0-5. 'nearest', 'linear', and 'cubic' are aliases for 0, 1, and 3,
        respectively.
    return_array : bool, default False
        Return a ``numpy.ndarray``, rather than a :class:`Raster` object.

    Returns
    -------
    Raster or numpy.ndarray
        The resized grid. A new :class:`Raster` is returned even when
        ``inplace`` is ``True`` (in which case :attr:`Raster.data` has also
        been overwritten).
    """
    x_start, x_stop, y_start, y_stop = self.extent
    new_lons = np.linspace(x_start, x_stop, resX)
    new_lats = np.linspace(y_start, y_stop, resY)
    mesh_lons, mesh_lats = np.meshgrid(new_lons, new_lats)
    resized = self.interpolate(mesh_lons, mesh_lats, method=method)

    if inplace:
        self._data = resized
        self._lons = new_lons
        self._lats = new_lats

    if return_array:
        return resized
    # The extent is read here, after any in-place update.
    return Raster(resized, self.plate_reconstruction, self.extent, time=self.time)
def fill_NaNs(self, inplace=False, return_array=False):
    """Replace NaN cells with the value of their nearest valid neighbour.

    Parameters
    ----------
    inplace : bool, default=False
        Overwrite the grid held in the ``data`` attribute with the filled
        grid.
    return_array : bool, default False
        Return a ``numpy.ndarray``, rather than a :class:`Raster`.

    Returns
    -------
    Raster or numpy.ndarray
        The filled grid. A new :class:`Raster` is returned even when
        ``inplace`` is ``True`` (in which case :attr:`Raster.data` has also
        been overwritten).
    """
    filled = fill_raster(self.data)
    if inplace:
        self._data = filled
    return (
        filled
        if return_array
        else Raster(filled, self.plate_reconstruction, self.extent, time=self.time)
    )
def save_to_netcdf4(self, filename, significant_digits=None, fill_value=np.nan):
    """Write this raster's data and extent to ``filename`` in netCDF4 format.

    ``filename`` should include the ".nc" extension; it is converted to
    ``str`` before use, so path-like objects are accepted.
    """
    output_path = str(filename)
    write_netcdf_grid(
        output_path,
        self.data,
        self.extent,
        significant_digits,
        fill_value,
    )
[docs] def reconstruct( self, time, fill_value=None, partitioning_features=None, threads=1, anchor_plate_id=None, inplace=False, return_array=False, ): """Reconstruct the raster from its initial time (``self.time``) to a new time. Parameters ---------- time : float Time to which the data will be reconstructed. fill_value : float, int, str, or tuple, optional The value to be used for regions outside of the static polygons at ``time``. By default (``fill_value=None``), this value will be determined based on the input. partitioning_features : sequence of Feature or str, optional The features used to partition the raster grid and assign plate IDs. By default, ``self.plate_reconstruction.static_polygons`` will be used, but alternatively any valid argument to ``pygplates.FeaturesFunctionArgument`` can be specified here. threads : int, default 1 Number of threads to use for certain computationally heavy routines. anchor_plate_id : int, optional ID of the anchored plate. By default, reconstructions are made with respect to the anchor plate ID specified in the :class:`PlateReconstruction` object. inplace : bool, default False Perform the reconstruction in-place (replace the raster's data with the reconstructed data). return_array : bool, default False Return a ``numpy.ndarray``, rather than a :class:`Raster`. Returns ------- Raster or np.ndarray The reconstructed grid. Areas for which no plate ID could be determined will be filled with ``fill_value``. .. note:: For two-dimensional grids, ``fill_value`` should be a single number. The default value will be ``np.nan`` for float or complex types, the minimum value for integer types, and the maximum value for unsigned types. For RGB image grids, ``fill_value`` should be a 3-tuple RGB colour code or a matplotlib colour string. The default value will be black (0.0, 0.0, 0.0) or (0, 0, 0). For RGBA image grids, ``fill_value`` should be a 4-tuple RGBA colour code or a matplotlib colour string. 
The default fill value will be transparent black (0.0, 0.0, 0.0, 0.0) or (0, 0, 0, 0). """ try: to_time_f = float(time) except ValueError: raise ValueError(f"Invalid reconstruction time: {time}") if to_time_f < 0.0: raise ValueError( f"The reconstruction time ({to_time_f}) must be greater than 0." ) # A valid PlateReconstruction object is required! assert self.plate_reconstruction is not None if partitioning_features is None: partitioning_features = self.plate_reconstruction.static_polygons result = reconstruct_grid( grid=self.data, partitioning_features=partitioning_features, rotation_model=self.plate_reconstruction.rotation_model, from_time=self.time, to_time=to_time_f, extent=self.extent, origin=self.origin, fill_value=fill_value, threads=threads, anchor_plate_id=anchor_plate_id, ) raster_rotation_model = self.plate_reconstruction.rotation_model # use the new reconstructed raster data to replace the current Raster obj # TODO: maybe need to put anchor_plate_id into rotation_model if it is not None if inplace: self.data = result self._time = to_time_f if ( anchor_plate_id is not None and raster_rotation_model and raster_rotation_model.get_default_anchor_plate_id() != anchor_plate_id ): self.plate_reconstruction.rotation_model = pygplates.RotationModel( raster_rotation_model, default_anchor_plate_id=anchor_plate_id ) if return_array: return result return self # create a new Raster obj to return if not return_array: result = Raster( data=result, plate_reconstruction=copy.deepcopy(self.plate_reconstruction), extent=self.extent, time=to_time_f, origin=self.origin, ) if ( anchor_plate_id is not None and raster_rotation_model and raster_rotation_model.get_default_anchor_plate_id() != anchor_plate_id ): result.plate_reconstruction.rotation_model = pygplates.RotationModel( raster_rotation_model, default_anchor_plate_id=anchor_plate_id ) return result
def imshow(self, ax=None, projection=None, **kwargs):
    """Draw the raster data as an image.

    An existing matplotlib ``Axes`` is reused when possible; otherwise a new
    one is created. The image ``origin`` and ``extent`` are derived from the
    raster itself, so they must not be passed by the caller.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw into.
    projection : cartopy.crs.Projection, optional
        Map projection to use. When ``ax`` is also given, it must be a
        GeoAxes whose ``projection`` equals this value.
    **kwargs : dict, optional
        Extra keyword arguments forwarded to ``Axes.imshow``. For GeoAxes,
        ``transform`` defaults to ``PlateCarree()``.

    Returns
    -------
    matplotlib.image.AxesImage

    Raises
    ------
    TypeError
        If ``origin`` or ``extent`` is passed in ``kwargs``.
    ValueError
        If ``ax`` and ``projection`` are both given but do not match.
    """
    reserved = [name for name in ("origin", "extent") if name in kwargs]
    if reserved:
        raise TypeError(
            "imshow got an unexpected keyword argument: {}".format(reserved[0])
        )

    if ax is not None:
        if projection is not None and not (
            isinstance(ax, _GeoAxes) and ax.projection == projection
        ):
            raise ValueError(
                "Both `projection` and `ax` were specified, but"
                + " `projection` does not match `ax.projection`"
            )
    else:
        figure_was_open = bool(plt.get_fignums())
        candidate = plt.gca()
        reusable = projection is None or (
            isinstance(candidate, _GeoAxes) and candidate.projection == projection
        )
        if reusable:
            ax = candidate
        else:
            # plt.gca() created a blank axes in a fresh figure; discard it
            # before creating the projected axes.
            if not figure_was_open:
                candidate.remove()
            ax = plt.axes(projection=projection)

    if isinstance(ax, _GeoAxes) and "transform" not in kwargs:
        kwargs["transform"] = _PlateCarree()

    image_extent = self.extent
    if self.origin == "upper":
        # Swap the y-limits handed to imshow for upper-origin rasters.
        image_extent = (
            image_extent[0],
            image_extent[1],
            image_extent[3],
            image_extent[2],
        )
    return ax.imshow(self.data, origin=self.origin, extent=image_extent, **kwargs)
# ``plot`` is an alias: both names refer to the same ``imshow`` method.
plot = imshow
def rotate_reference_frames(
    self,
    grid_spacing_degrees,
    reconstruction_time,
    from_rotation_features_or_model=None,  # filename(s), or pyGPlates feature(s)/collection(s) or a RotationModel
    to_rotation_features_or_model=None,  # filename(s), or pyGPlates feature(s)/collection(s) or a RotationModel
    from_rotation_reference_plate=0,
    to_rotation_reference_plate=0,
    non_reference_plate=701,
    output_name=None,
):
    """Rotate a grid defined in one plate model reference frame within a
    :class:`Raster` object to another plate reconstruction model reference frame.

    Parameters
    ----------
    grid_spacing_degrees : float
        The spacing (in degrees) for the output rotated grid. Must be > 0.
    reconstruction_time : float
        The time at which to rotate the input grid.
    from_rotation_features_or_model : str, list of str, or instance of pygplates.RotationModel
        A filename, or a list of filenames, or a pyGPlates RotationModel object
        that defines the rotation model that the input grid is currently
        associated with. Defaults to ``self.plate_reconstruction.rotation_model``.
    to_rotation_features_or_model : str, list of str, or instance of pygplates.RotationModel
        A filename, or a list of filenames, or a pyGPlates RotationModel object
        that defines the rotation model that the input grid shall be rotated
        with. Defaults to ``self.plate_reconstruction.rotation_model``.
    from_rotation_reference_plate : int, default = 0
        The current reference plate for the plate model the grid is defined in.
    to_rotation_reference_plate : int, default = 0
        The desired reference plate for the plate model the grid is being
        rotated to.
    non_reference_plate : int, default = 701
        An arbitrary placeholder plate with which to relate the "from" and
        "to" reference frames.
    output_name : str, default None
        If passed, the rotated grid is saved as a netCDF grid to this filename.

    Returns
    -------
    Raster
        A new :class:`Raster` holding the rotated grid, with the same extent
        as this raster.

    Raises
    ------
    ValueError
        If ``grid_spacing_degrees`` is not a positive number, or a rotation
        model is needed from ``self.plate_reconstruction`` but none is set.
    """
    # ``not >`` also rejects NaN; a zero/negative spacing would otherwise fail
    # later with an obscure division or linspace error.
    if not grid_spacing_degrees > 0:
        raise ValueError(
            f"grid_spacing_degrees must be positive, got {grid_spacing_degrees}"
        )
    if from_rotation_features_or_model is None:
        if self.plate_reconstruction is None:
            raise ValueError("Set a plate reconstruction model")
        from_rotation_features_or_model = self.plate_reconstruction.rotation_model
    if to_rotation_features_or_model is None:
        if self.plate_reconstruction is None:
            raise ValueError("Set a plate reconstruction model")
        to_rotation_features_or_model = self.plate_reconstruction.rotation_model

    # Finite rotation that maps points from the "from" frame into the "to" frame.
    from_rotation_model = pygplates.RotationModel(from_rotation_features_or_model)
    to_rotation_model = pygplates.RotationModel(to_rotation_features_or_model)
    from_rotation = from_rotation_model.get_rotation(
        reconstruction_time,
        non_reference_plate,
        anchor_plate_id=from_rotation_reference_plate,
    )
    to_rotation = to_rotation_model.get_rotation(
        reconstruction_time,
        non_reference_plate,
        anchor_plate_id=to_rotation_reference_plate,
    )
    reference_frame_conversion_rotation = to_rotation * from_rotation.get_inverse()

    # Resize the input grid to the requested output resolution before rotating.
    resize_x = _deg2pixels(grid_spacing_degrees, self.extent[0], self.extent[1])
    resize_y = _deg2pixels(grid_spacing_degrees, self.extent[2], self.extent[3])
    resized_input_grid = self.resize(resize_x, resize_y, inplace=False)

    # Flattened node coordinates and values of the resized grid.
    llons, llats = np.meshgrid(resized_input_grid.lons, resized_input_grid.lats)
    llons = llons.ravel()
    llats = llats.ravel()
    values = np.array(resized_input_grid.data).ravel()

    # Rotate all grid nodes into the target reference frame.
    input_points = pygplates.MultiPointOnSphere(
        (lat, lon) for lon, lat in zip(llons, llats)
    )
    output_points = reference_frame_conversion_rotation * input_points
    rotated_lat_lon = np.array(
        [point.to_lat_lon() for point in output_points], dtype=float
    ).reshape(-1, 2)
    out_lat = rotated_lat_lon[:, 0]
    out_lon = rotated_lat_lon[:, 1]

    # Regular output grid covering this raster's extent.
    extent_globe = self.extent
    n_lon = (
        int(np.floor((extent_globe[1] - extent_globe[0]) / grid_spacing_degrees)) + 1
    )
    n_lat = (
        int(np.floor((extent_globe[3] - extent_globe[2]) / grid_spacing_degrees)) + 1
    )
    grid_lon = np.linspace(extent_globe[0], extent_globe[1], n_lon)
    grid_lat = np.linspace(extent_globe[2], extent_globe[3], n_lat)
    X, Y = np.meshgrid(grid_lon, grid_lat)

    # Nearest-neighbour interpolation of the rotated values onto the grid.
    Z = griddata_sphere((out_lon, out_lat), values, (X, Y), method="nearest")

    if output_name:
        write_netcdf_grid(output_name, Z, extent=extent_globe)

    # Pass the extent so non-global rasters keep the correct georeferencing
    # (matching the extent written to the netCDF file above).
    return Raster(data=Z, extent=extent_globe)
def query(self, lons, lats, region_of_interest=None):
    """Return the grid values nearest to the given locations.

    The valid (non-NaN) grid nodes are indexed in a k-d tree of unit-sphere
    xyz points, treating the grid as grid-line registered (nodes at
    ``np.linspace`` of the extent). The tree is cached on the instance and is
    rebuilt whenever the raster's NaN pattern or extent has changed, so
    queries made after the data is modified do not use a stale tree.

    Parameters
    ----------
    lons : array-like or float
        Longitudes (degrees) of the query locations.
    lats : array-like or float
        Latitudes (degrees) of the query locations.
    region_of_interest : float, optional
        Search radius in km, measured as arc length along the Earth's
        surface. It is converted to the equivalent straight-line (chord)
        distance on the unit sphere. By default, half the diagonal of a grid
        cell (in degrees) is used.

    Returns
    -------
    numpy.ndarray
        For each location, the value of the nearest valid grid node, or NaN
        where no valid node lies within the search radius.
    """
    data_mask = ~np.isnan(self.data)
    extent = tuple(self.extent)
    cached_mask = getattr(self, "data_mask", None)
    tree_is_stale = (
        getattr(self, "spatial_cKDTree", None) is None
        or cached_mask is None
        or getattr(self, "_query_tree_extent", None) != extent
        or not np.array_equal(cached_mask, data_mask)
    )
    if tree_is_stale:
        x0, x1 = self.extent[0], self.extent[1]
        y0, y1 = self.extent[2], self.extent[3]
        yn = self.data.shape[0]
        xn = self.data.shape[1]
        # Grid-line registration assumed (not pixel registration):
        # http://www.soest.hawaii.edu/pwessel/courses/gg710-01/GMT_grid.pdf
        # NOTE(review): row 0 is mapped to y0 (origin "lower"); confirm the
        # behaviour expected for origin="upper" rasters.
        # TODO: support both grid-line and pixel registration.
        grid_x, grid_y = np.meshgrid(np.linspace(x0, x1, xn), np.linspace(y0, y1, yn))
        # Node spacing for grid-line registration is span / (n - 1).
        dx = (x1 - x0) / max(xn - 1, 1)
        dy = (y1 - y0) / max(yn - 1, 1)
        grid_cell_radius = math.hypot(dx, dy) / 2  # in degrees
        grid_points = np.asarray(
            [
                pygplates.PointOnSphere((float(p[1]), float(p[0]))).to_xyz()
                for p in np.dstack((grid_x, grid_y))[data_mask]
            ],
            dtype=float,
        ).reshape(-1, 3)
        logger.debug("building the spatial tree...")
        tree = _cKDTree(grid_points)
        # Publish cache state only after the tree is built successfully, so a
        # failed build cannot leave a new mask paired with an old tree.
        self.spatial_cKDTree = tree
        self.grid_cell_radius = grid_cell_radius
        self.data_mask = data_mask
        self._query_tree_extent = extent

    # Index len(valid values) is what cKDTree returns for "no neighbour within
    # distance_upper_bound"; it maps to the appended NaN.
    lookup = np.concatenate((self.data[data_mask], [math.nan]))

    query_points = [
        pygplates.PointOnSphere((float(p[1]), float(p[0]))).to_xyz()
        for p in zip(np.atleast_1d(lons), np.atleast_1d(lats))
    ]
    if not query_points:
        return lookup[:0]

    if region_of_interest is None:
        # Convert the arc length (degrees) to chord length on the unit sphere.
        roi = 2 * math.sin(math.radians(self.grid_cell_radius / 2.0))
    else:
        roi = 2 * math.sin(region_of_interest / pygplates.Earth.mean_radius_in_kms / 2.0)

    _dists, indices = self.spatial_cKDTree.query(
        query_points, k=1, distance_upper_bound=roi
    )
    return lookup[indices]
def clip_by_extent(self, extent):
    """Clip the raster to a given extent ``(x_min, x_max, y_min, y_max)``.

    The clip window is snapped outwards to the nearest grid lines, so the
    extent of the returned raster may be slightly larger than ``extent`` when
    its borders fall between two grid lines.

    Parameters
    ----------
    extent : tuple
        A tuple of 4 ``(min_lon, max_lon, min_lat, max_lat)``. It must lie
        inside ``self.extent`` and within [-180, 180] x [-90, 90].

    Returns
    -------
    Raster
        The clipped grid. Any trailing axes of ``self.data`` beyond the first
        two (e.g. RGB(A) channels) are kept.

    Raises
    ------
    ValueError
        If ``extent`` is malformed or out of range, lies outside
        ``self.extent``, or the raster has fewer than two grid lines along
        either axis.
    """
    rows, cols, new_extent = self._clip_by_extent_window(extent)
    y_len, x_len = self.data.shape[:2]
    logger.debug(f"the shape of raster data x:{x_len} y:{y_len}")
    return Raster(
        data=self.data[rows, cols],
        extent=new_extent,
    )


def _clip_by_extent_window(self, extent):
    """Return ``(row_slice, col_slice, snapped_extent)`` for ``clip_by_extent``.

    Validates ``extent`` and converts it into index slices of ``self.data``
    (grid-line registered). ``snapped_extent`` is the extent of the selected
    grid lines, expressed relative to ``self.extent``.
    """
    if (
        extent[0] >= extent[1]
        or extent[2] >= extent[3]
        or extent[0] < -180
        or extent[1] > 180
        or extent[2] < -90
        or extent[3] > 90
    ):
        raise ValueError(f"Invalid extent: {extent}")
    if (
        extent[0] < self.extent[0]
        or extent[1] > self.extent[1]
        or extent[2] < self.extent[2]
        or extent[3] > self.extent[3]
    ):
        raise ValueError(
            f"The given extent is out of scope. {extent} -- {self.extent}"
        )

    # Only the first two axes are spatial; extra axes (e.g. colour bands) are
    # carried through by the slices.
    y_len, x_len = self.data.shape[:2]
    if x_len < 2 or y_len < 2:
        raise ValueError(
            f"Cannot clip a raster with fewer than two grid lines per axis: {self.data.shape}"
        )

    x_span = self.extent[1] - self.extent[0]
    y_span = self.extent[3] - self.extent[2]
    # NOTE(review): row 0 is assumed to lie at self.extent[2] (origin "lower");
    # confirm the expected behaviour for origin="upper" rasters.
    # floor/ceil snap the window outwards to whole grid lines.
    x0 = math.floor((extent[0] - self.extent[0]) / x_span * (x_len - 1))
    x1 = math.ceil((extent[1] - self.extent[0]) / x_span * (x_len - 1))
    y0 = math.floor((extent[2] - self.extent[2]) / y_span * (y_len - 1))
    y1 = math.ceil((extent[3] - self.extent[2]) / y_span * (y_len - 1))

    # Offsets are relative to this raster's own extent (previously hard-coded
    # to -180/-90, which was wrong for any non-global raster).
    new_extent = (
        self.extent[0] + x0 * x_span / (x_len - 1),
        self.extent[0] + x1 * x_span / (x_len - 1),
        self.extent[2] + y0 * y_span / (y_len - 1),
        self.extent[2] + y1 * y_span / (y_len - 1),
    )
    return slice(y0, y1 + 1), slice(x0, x1 + 1), new_extent
def _clip_by_polygon(self, polygon):
    """Clip the raster by a polygon. Not implemented yet.

    Raises
    ------
    NotImplementedError
        Always; previously this stub silently returned ``None``.
    """
    raise NotImplementedError("Raster._clip_by_polygon has not been implemented yet.")


def __array__(self, dtype=None, copy=None):
    """Return the raster data as a NumPy array.

    Accepts the ``dtype`` and ``copy`` arguments of the NumPy ``__array__``
    protocol, so ``np.asarray(raster, dtype=...)`` works (NumPy 2 also
    passes ``copy``). When ``copy`` is not given, a new array is always
    returned.
    """
    if copy is None:
        return np.array(self.data, dtype=dtype)
    return np.array(self.data, dtype=dtype, copy=copy)


def _raster_binary_op(self, other, func, reflected=False):
    """Apply the binary function ``func`` to this raster's data and ``other``.

    If ``other`` is a Raster, a plain array is returned, since it is ambiguous
    which operand's properties should be kept. Otherwise a copy of this Raster
    holding the new data is returned. With ``reflected=True`` the operands are
    swapped, i.e. ``func(other, self.data)``.
    """
    if isinstance(other, Raster):
        if reflected:
            return func(other.data, self.data)
        return func(self.data, other.data)
    new_raster = self.copy()
    new_raster.data = func(other, self.data) if reflected else func(self.data, other)
    return new_raster


def __add__(self, other):
    return self._raster_binary_op(other, lambda a, b: a + b)


def __radd__(self, other):
    # Addition is treated as commutative: delegate to ``__add__``.
    return self + other


def __sub__(self, other):
    return self._raster_binary_op(other, lambda a, b: a - b)


def __rsub__(self, other):
    return self._raster_binary_op(other, lambda a, b: a - b, reflected=True)


def __mul__(self, other):
    return self._raster_binary_op(other, lambda a, b: a * b)


def __rmul__(self, other):
    # Multiplication is treated as commutative: delegate to ``__mul__``.
    return self * other


def __truediv__(self, other):
    return self._raster_binary_op(other, lambda a, b: a / b)


def __rtruediv__(self, other):
    return self._raster_binary_op(other, lambda a, b: a / b, reflected=True)


def __floordiv__(self, other):
    return self._raster_binary_op(other, lambda a, b: a // b)


def __rfloordiv__(self, other):
    return self._raster_binary_op(other, lambda a, b: a // b, reflected=True)


def __mod__(self, other):
    return self._raster_binary_op(other, lambda a, b: a % b)


def __rmod__(self, other):
    return self._raster_binary_op(other, lambda a, b: a % b, reflected=True)


def __pow__(self, other):
    return self._raster_binary_op(other, lambda a, b: a**b)


def __rpow__(self, other):
    return self._raster_binary_op(other, lambda a, b: a**b, reflected=True)
# class TimeRaster(Raster): # """A class for the temporal manipulation of raster data. To be added soon!""" # def __init__(self, PlateReconstruction_object=None, filename=None, array=None, extent=None, resample=None): # raise NotImplementedError( # "This class has not been implemented; use `Raster` instead" # ) # super(TimeRaster, self).__init__(PlateReconstruction_object)