Source code for mascdb.utils_img

# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2025 MASCDB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the MIT License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# You should have received a copy of the MIT License
# along with this program.  If not, see <https://opensource.org/license/mit/>.
# -----------------------------------------------------------------------------.
"""Utilities for image processing."""
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar
from skimage import exposure
from skimage.filters import rank
from skimage.morphology import rectangle

# Performs Gamma Correction on the input image.
# skimage.exposure.adjust_gamma(image[, ...])

# Performs Logarithmic correction on the input image.
# skimage.exposure.adjust_log(image[, gain, inv])

# Performs Sigmoid Correction on the input image
# skimage.exposure.adjust_sigmoid(image[, ...])

# http://www.janeriksolem.net/histogram-equalization-with-python-and.html


#################################################
### Workhorse to compute 2D image descriptors ###
#################################################
def _compute_2Dimage_descriptors(da, fun, labels, x="x", y="y", fun_kwargs=None, dask="parallelized"):
    """Compute a set of scalar descriptors for each 2D image of a DataArray.

    ``fun`` is applied independently to every 2D ``(x, y)`` slice of ``da``
    (through ``xr.apply_ufunc`` with ``vectorize=True``) and must return a
    1D array with one value per entry of ``labels``.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray containing 2D images along the ``x`` and ``y`` dimensions.
    fun : callable
        Function receiving a 2D numpy array whose axes are ordered ``(x, y)``
        and returning a 1D array of ``len(labels)`` descriptors.
    labels : str or list of str
        Name(s) of the descriptors returned by ``fun``.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    fun_kwargs : dict, optional
        Additional keyword arguments passed to ``fun``. Default is None.
    dask : str, optional
        Value forwarded to the ``dask`` argument of ``xr.apply_ufunc``.
        Default is "parallelized".

    Returns
    -------
    xarray.DataArray
        float64 DataArray in which ``x`` and ``y`` are replaced by a
        ``descriptor`` dimension whose coordinates are ``labels``.
        The result is computed eagerly (with a dask progress bar).

    Raises
    ------
    TypeError
        If ``da`` is not a DataArray, if ``x``/``y`` are not strings or if
        ``labels`` is neither a string nor a list.
    ValueError
        If ``labels`` is empty or contains non-string items, or if ``x``/``y``
        are not dimensions of ``da``.
    """
    # Checks arguments
    if fun_kwargs is None:
        fun_kwargs = {}
    if not isinstance(da, xr.DataArray):
        raise TypeError("Expecting a xr.DataArray.")
    if not isinstance(x, str):
        raise TypeError("'x' must be a string indicating the width dimension name of the DataArray.")
    if not isinstance(y, str):
        raise TypeError("'y' must be a string indicating the height dimension name of the DataArray.")
    if not isinstance(labels, (str, list)):
        raise TypeError("Descriptor 'labels' must be a str or list of strings.")
    if isinstance(labels, str):
        labels = [labels]
    # Validate every label before numpy coerces mixed types to strings,
    # and reject an empty list explicitly.
    if len(labels) == 0 or not all(isinstance(label, str) for label in labels):
        raise ValueError("Descriptor 'labels' must be a list of strings.")
    labels = np.array(labels)
    # -----------------------------------------------------------------------.
    # Retrieve DataArray dimension original order
    dims = da.dims
    # -----------------------------------------------------------------------.
    # Check x and y are dimension of the DataArray
    if x not in dims:
        raise ValueError(f"x={x!r} is not a dimension of the DataArray")
    if y not in dims:
        raise ValueError(f"y={y!r} is not a dimension of the DataArray")
    # -----------------------------------------------------------------------.
    ### Retrieve the non-image dimensions, keeping their original order
    # (iterating a set would make the stacking order non-deterministic).
    unstacked_dims = [dim for dim in dims if dim not in (x, y)]
    if len(unstacked_dims) == 0:
        # - Only x and y: a single image, nothing to stack
        stack_dict = {}
        da_stacked = da
    elif len(unstacked_dims) == 1:
        # - A single additional dimension: move it last
        stack_dict = {}
        da_stacked = da.transpose(..., unstacked_dims[0])
    else:
        # - Several additional dimensions: stack them into a new last dimension
        img_id = "img_id"
        stack_dict = {img_id: unstacked_dims}
        da_stacked = da.stack(stack_dict).transpose(..., img_id)
    # -----------------------------------------------------------------------.
    ### Check the function
    # TODO: checks that len(labels) = len(arr)

    # -----------------------------------------------------------------------.
    ### Compute descriptors for each 2D image
    da_stacked = xr.apply_ufunc(
        fun,
        da_stacked,
        input_core_dims=[[x, y]],
        output_core_dims=[["descriptor"]],  # returned data has one dimension
        kwargs=fun_kwargs,
        dask=dask,
        vectorize=True,  # fun only handles a single 2D image at a time
        dask_gufunc_kwargs={"output_sizes": {"descriptor": len(labels)}},
        output_dtypes=["float64"],
    )

    # Compute the descriptors
    with ProgressBar():
        da_stacked = da_stacked.compute()

    # Add descriptor coordinates
    da_stacked = da_stacked.assign_coords({"descriptor": labels})

    # -----------------------------------------------------------------------.
    # Restore the stacked dimensions (if any were stacked)
    if stack_dict:
        da_stacked = da_stacked.unstack(stack_dict)
    return da_stacked


##----------------------------------------------------------------------------.
#####################################
### Workhorse to modify 2D images ###
#####################################
def apply_2Dimage_fun(da, fun, x="x", y="y", fun_kwargs=None):
    """Apply a function to each 2D image in a DataArray.

    This function applies a user-defined function to 2D images stored in a
    xarray DataArray. It handles DataArrays with multiple dimensions by
    stacking and unstacking as needed, ensuring the function is applied to
    each 2D image independently.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray containing 2D images.
    fun : callable
        Function to apply to each 2D image. It receives a 2D numpy array whose
        axes are ordered ``(x, y)`` and should return a 2D numpy array of the
        same shape and of the same dtype as ``da``.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    fun_kwargs : dict, optional
        Additional keyword arguments to pass to the function. Default is None.

    Returns
    -------
    xarray.DataArray
        DataArray with the function applied to each 2D image, maintaining
        original dimensions. Its declared dtype is the dtype of ``da``.

    Raises
    ------
    TypeError
        If da is not a xarray.DataArray or if x/y are not strings.
    ValueError
        If x or y are not dimensions of the DataArray.
    """
    # Checks arguments
    if fun_kwargs is None:
        fun_kwargs = {}
    if not isinstance(da, xr.DataArray):
        raise TypeError("Expecting a xr.DataArray.")
    if not isinstance(x, str):
        raise TypeError("'x' must be a string indicating the width dimension name of the DataArray.")
    if not isinstance(y, str):
        raise TypeError("'y' must be a string indicating the height dimension name of the DataArray.")
    # -----------------------------------------------------------------------.
    # Retrieve DataArray dimension original order
    dims = da.dims
    # -----------------------------------------------------------------------.
    # Check x and y are dimension of the DataArray
    if x not in dims:
        raise ValueError(f"x={x!r} is not a dimension of the DataArray")
    if y not in dims:
        raise ValueError(f"y={y!r} is not a dimension of the DataArray")
    # -----------------------------------------------------------------------.
    ### Retrieve the non-image dimensions, keeping their original order
    # (iterating a set would make the stacking order non-deterministic).
    unstacked_dims = [dim for dim in dims if dim not in (x, y)]
    if len(unstacked_dims) == 0:
        # - Only x and y: a single image, nothing to stack
        stack_dict = {}
        da_stacked = da
    elif len(unstacked_dims) == 1:
        # - A single additional dimension: move it last
        stack_dict = {}
        da_stacked = da.transpose(..., unstacked_dims[0])
    else:
        # - Several additional dimensions: stack them into a new last dimension
        img_id = "img_id"
        stack_dict = {img_id: unstacked_dims}
        da_stacked = da.stack(stack_dict).transpose(..., img_id)
    # -----------------------------------------------------------------------.
    ### Apply the function to each 2D image
    da_stacked = xr.apply_ufunc(
        fun,
        da_stacked,
        input_core_dims=[[x, y]],
        output_core_dims=[[x, y]],
        kwargs=fun_kwargs,
        dask="parallelized",
        vectorize=True,  # fun only handles a single 2D image at a time
        # Read the dtype attribute: ``.values`` would load the whole
        # (possibly dask-backed) array into memory just to get its dtype.
        output_dtypes=[da_stacked.dtype],
    )
    # -----------------------------------------------------------------------.
    # Unstack back to original dimensions
    if stack_dict:
        da_stacked = da_stacked.unstack(stack_dict)
    return da_stacked.transpose(*dims)
##----------------------------------------------------------------------------. ################## ### Zoom utils ### ################## def _internal_bbox(img): rows = np.any(img, axis=1) cols = np.any(img, axis=0) rmin, rmax = np.where(rows)[0][[0, -1]] cmin, cmax = np.where(cols)[0][[0, -1]] return rmin, rmax, cmin, cmax def _get_zoomed_image(img): rmin, rmax, cmin, cmax = _internal_bbox(img) zoom_img = img[rmin : rmax + 1, cmin : cmax + 1] return zoom_img def _center_image(img, nrow, ncol): r, c = img.shape col_incr = int((ncol - c) / 2) row_incr = int((nrow - r) / 2) arr = np.zeros((nrow, ncol)) arr[slice(row_incr, row_incr + r), slice(col_incr, col_incr + c)] = img return arr
def xri_zoom(da, x="x", y="y", squared=False):
    """Zoom into 2D images by cropping to non-zero regions and centering.

    This function removes zero-valued borders from images, crops to the
    smallest bounding box containing all non-zero pixels, and centers the
    result. Optionally creates square images.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray containing 2D images.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    squared : bool, optional
        If True, output images will be square (same height and width).
        If False, output images maintain their aspect ratio. Default is False.

    Returns
    -------
    xarray.DataArray
        DataArray with zoomed and centered images, with the same dimension
        order and dtype as ``da``. The size along ``y``/``x`` is the largest
        zoomed extent across all images.

    Raises
    ------
    TypeError
        If da is not a xarray.DataArray or if x/y are not strings.
    ValueError
        If x or y are not dimensions of the DataArray.

    Notes
    -----
    The data are loaded into memory. When ``squared=True`` and the square
    side exceeds the input size along ``x`` or ``y``, the DataArray is first
    zero-padded along that dimension with ``DataArray.pad``; the coordinate
    values of the added positions are those produced by xarray padding.
    """
    # Checks arguments
    if not isinstance(da, xr.DataArray):
        raise TypeError("Expecting a xr.DataArray.")
    if not isinstance(x, str):
        raise TypeError("'x' must be a string indicating the width dimension name of the DataArray.")
    if not isinstance(y, str):
        raise TypeError("'y' must be a string indicating the height dimension name of the DataArray.")
    # -----------------------------------------------------------------------.
    # Retrieve DataArray dimension original order
    dims = da.dims
    # -----------------------------------------------------------------------.
    # Check x and y are dimension of the DataArray
    if x not in dims:
        raise ValueError(f"x={x!r} is not a dimension of the DataArray")
    if y not in dims:
        raise ValueError(f"y={y!r} is not a dimension of the DataArray")
    # -----------------------------------------------------------------------.
    ### Retrieve the non-image dimensions, keeping their original order
    unstacked_dims = [dim for dim in dims if dim not in (x, y)]
    ## Enforce (..., y, x) dimension order so that each image is (rows, cols)
    da = da.transpose(..., y, x)
    src_dtype = da.dtype
    if len(unstacked_dims) == 0:
        # - Only x and y: a single image
        img_id = None
        stack_dict = {}
        da_stacked = da
    elif len(unstacked_dims) == 1:
        # - A single additional dimension: move it last
        img_id = unstacked_dims[0]
        stack_dict = {}
        da_stacked = da.transpose(..., img_id)
    else:
        # - Several additional dimensions: stack them into a new last dimension
        img_id = "img_id"
        stack_dict = {img_id: unstacked_dims}
        da_stacked = da.stack(stack_dict).transpose(..., img_id)
    # -----------------------------------------------------------------------.
    # Extract the list of 2D images (load the data once, not once per image)
    arr = da_stacked.values
    l_imgs = [arr] if img_id is None else [arr[..., i] for i in range(arr.shape[-1])]
    if not l_imgs:
        # Empty image dimension: nothing to zoom
        return da.transpose(*dims)
    # -----------------------------------------------------------------------.
    # Zoom each image to the bounding box of its non-zero pixels
    l_zoomed = [_get_zoomed_image(img) for img in l_imgs]
    # Output size: largest zoomed extent along each axis
    r_max = max(img.shape[0] for img in l_zoomed)
    c_max = max(img.shape[1] for img in l_zoomed)
    if squared:
        r_max = c_max = max(r_max, c_max)
    # Center all images and cast back to the input dtype
    l_zoomed = [_center_image(img, nrow=r_max, ncol=c_max) for img in l_zoomed]
    new_values = l_zoomed[0] if img_id is None else np.stack(l_zoomed, axis=-1)
    new_values = new_values.astype(src_dtype, copy=False)
    # -----------------------------------------------------------------------.
    # Resize along the user-provided x and y dimension names.
    # A squared output may be larger than the input: pad first in that case.
    pad_y = max(r_max - da_stacked.sizes[y], 0)
    pad_x = max(c_max - da_stacked.sizes[x], 0)
    if pad_y or pad_x:
        da_stacked = da_stacked.pad({y: (0, pad_y), x: (0, pad_x)}, constant_values=0)
    da_stacked = da_stacked.isel({y: slice(0, r_max), x: slice(0, c_max)})
    da_stacked = da_stacked.copy(deep=False, data=new_values)
    # -----------------------------------------------------------------------.
    # Unstack back to original dimensions
    if stack_dict:
        da_stacked = da_stacked.unstack(stack_dict)
    return da_stacked.transpose(*dims)
##----------------------------------------------------------------------------.
############################################
### Functions to enhance image contrast ####
############################################
def _contrast_stretching(img, pmin=1, pmax=99):
    """Stretch the contrast of a 2D image between two intensity percentiles.

    Parameters
    ----------
    img : numpy.ndarray
        2D image array.
    pmin, pmax : float, optional
        Percentiles (0-100) defining the input intensity range passed to
        ``skimage.exposure.rescale_intensity``. They are computed over all
        pixels, zeros included. Defaults are 1 and 99.

    Returns
    -------
    numpy.ndarray
        Stretched image with the dtype of ``img``; pixels that were 0 stay 0.
    """
    # Get source dtype
    src_dtype = img.dtype
    # Retrieve img mask
    img_mask = img == 0
    # Intensity range delimited by the requested percentiles
    v_low, v_high = np.percentile(img, (pmin, pmax))
    # Perform contrast stretching using percentile values as the intensity range
    img = exposure.rescale_intensity(img, in_range=(v_low, v_high))
    # Change dtype
    img = img.astype(src_dtype)
    # Mask regions that were 0 before stretching
    img[img_mask] = 0
    return img


def _hist_equalization(img, adaptive=False, nbins=256, kernel_size=None, clip_limit=0.03):
    """Apply global or adaptive histogram equalization to an image.

    Parameters
    ----------
    img : arr
        Image array
    adaptive : bool, optional
        If False, it employs the classical histogram equalization, with the
        histogram computed over the strictly positive pixels only.
        If True, it employs Contrast Limited Adaptive Histogram Equalization (CLAHE).
        CLAHE uses histograms computed over different tile regions of the image.
        Local details can therefore be enhanced even in regions that are darker
        or lighter than most of the image.
        The default is False.
    nbins: int, optional
        Number of bins for image histogram.
        Note: this argument is ignored for integer images, for which each integer is its own bin.
    kernel_size: int or array-like, optional
        Argument used by CLAHE.
        Defines the shape of contextual regions used in the algorithm.
        By default, kernel_size is 1/8 of image height by 1/8 of its width.
    clip_limit: float, optional
        Argument used by CLAHE.
        Clipping limit, normalized between 0 and 1 (higher values give more contrast).
        The default is 0.03.

    Returns
    -------
    img : arr
        Image array after histogram equalization, with the dtype of the input.
        The equalized values (between 0 and 1) are scaled by 255; pixels that
        were 0 stay 0. A non-adaptive call on an image without any positive
        pixel returns a copy of the input.
    """
    # Get source dtype
    src_dtype = img.dtype
    # Retrieve img mask
    img_mask = img == 0
    # Retrieve mask for hist equalization
    hist_mask = img > 0
    # Perform histogram equalization
    # - The output values are between 0 and 1 !!!
    if not adaptive:
        # skimage cannot build a histogram from an empty masked selection
        if not hist_mask.any():
            return img.copy()
        img = exposure.equalize_hist(img, nbins=nbins, mask=hist_mask)
    else:
        img = exposure.equalize_adapthist(img, clip_limit=clip_limit, kernel_size=kernel_size, nbins=nbins)
    # Rescale to 0-255
    img = img * 255
    # Change dtype
    img = img.astype(src_dtype)
    # Mask regions that were 0 before stretching
    img[img_mask] = 0
    return img


def _local_hist_equalization(img, footprint=None):
    """Equalize an image using local histograms.

    Parameters
    ----------
    img : arr (uint8, uint16)
        Image array
    footprint: array
        The neighborhood expressed as an ndarray of 1 and 0.
        By default it uses a rectangle of size 1/8 of the image shape along
        each axis (at least 1 pixel per side).
        Custom footprints can be easily generated using skimage.morphology functions
        such as <rectangle, disk, square,star, diamond, octagon,...>

    Returns
    -------
    img : arr
        Image array after equalization, with the dtype of the input;
        pixels that were 0 stay 0.
    """
    # Get source dtype
    src_dtype = img.dtype
    # Define footprint if None
    if footprint is None:
        nrows, ncols = img.shape
        # Keep at least one pixel per side so small images get a valid footprint
        footprint = rectangle(max(nrows // 8, 1), max(ncols // 8, 1))
    # Retrieve img mask
    img_mask = img == 0
    # Retrieve mask for hist equalization
    hist_mask = img > 0
    # Perform local equalization
    img = rank.equalize(np.array(img), footprint, mask=hist_mask)
    # Change dtype
    img = img.astype(src_dtype)
    # Mask regions that were 0 before stretching
    img[img_mask] = 0
    return img


### Wrappers
def xri_contrast_stretching(da, x="x", y="y", pmin=2, pmax=98):
    """Stretch the contrast of every 2D image with percentile-based rescaling.

    Each image is processed independently: the intensities found at the
    ``pmin``-th and ``pmax``-th percentiles delimit the input range that is
    remapped, widening the dynamic range of the image.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray containing 2D images.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    pmin : float, optional
        Lower percentile used for the remapping. Default is 2.
    pmax : float, optional
        Upper percentile used for the remapping. Default is 98.

    Returns
    -------
    xarray.DataArray
        DataArray of contrast-stretched images.

    Notes
    -----
    Pixels equal to zero are left at zero.
    """
    return apply_2Dimage_fun(
        da,
        _contrast_stretching,
        x=x,
        y=y,
        fun_kwargs=dict(pmin=pmin, pmax=pmax),
    )
def xri_hist_equalization(da, x="x", y="y", nbins=256, adaptive=False, kernel_size=None, clip_limit=0.01):
    """Equalize the histogram of every 2D image, globally or adaptively.

    Histogram equalization spreads pixel intensities towards a uniform
    distribution to enhance contrast. With ``adaptive=True``, Contrast
    Limited Adaptive Histogram Equalization (CLAHE) is used instead: it works
    on histograms of local tiles, which better enhances local details.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray containing 2D images.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    nbins : int, optional
        Number of histogram bins (ignored for integer images, where every
        integer value is its own bin). Default is 256.
    adaptive : bool, optional
        False for classical equalization, True for CLAHE. Default is False.
    kernel_size : int or array-like, optional
        Shape of the CLAHE contextual regions. By default 1/8 of the image
        height by 1/8 of the image width.
    clip_limit : float, optional
        CLAHE clipping limit, normalized between 0 and 1; higher values give
        more contrast. Default is 0.01.

    Returns
    -------
    xarray.DataArray
        DataArray of histogram-equalized images.

    Notes
    -----
    Pixels equal to zero are left at zero.
    """
    equalization_options = dict(
        nbins=nbins,
        adaptive=adaptive,
        kernel_size=kernel_size,
        clip_limit=clip_limit,
    )
    return apply_2Dimage_fun(da, _hist_equalization, x=x, y=y, fun_kwargs=equalization_options)
def xri_local_hist_equalization(da, x="x", y="y", footprint=None):
    """Equalize every 2D image using histograms of a local neighborhood.

    Instead of a single histogram per image, the equalization of each pixel
    relies on the histogram of the neighborhood defined by ``footprint``,
    which enhances local details more than global equalization.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray containing 2D images with dtype uint8 or uint16.
    x : str, optional
        Name of the width dimension. Default is "x".
    y : str, optional
        Name of the height dimension. Default is "y".
    footprint : numpy.ndarray, optional
        Neighborhood expressed as an ndarray of 1's and 0's. By default a
        rectangle of 1/8 of the image height and width. Custom footprints can
        be built with skimage.morphology (e.g. rectangle, disk, square, star,
        diamond, octagon).

    Returns
    -------
    xarray.DataArray
        DataArray of locally equalized images.

    Notes
    -----
    Pixels equal to zero are left at zero. Input images should be uint8 or
    uint16.
    """
    return apply_2Dimage_fun(
        da,
        _local_hist_equalization,
        x=x,
        y=y,
        fun_kwargs=dict(footprint=footprint),
    )