Source code for mascdb.api

# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2025 MASCDB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the MIT License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# You should have received a copy of the MIT License
# along with this program.  If not, see <https://opensource.org/license/mit/>.
# -----------------------------------------------------------------------------.
"""MASCDB API."""
import copy
import os
import shutil

import dask
import numpy as np
import pandas as pd
import xarray as xr

from mascdb.utils_aux import (
    get_melting_class_name_dict,
    get_precip_class_name_dict,
    get_riming_class_name_dict,
    get_snowflake_class_name_dict,
    get_vars_class,
    var_explanations,
    var_units,
)
from mascdb.utils_event import _define_event_id, _get_timesteps_duration
from mascdb.utils_img import (
    _compute_2Dimage_descriptors,
    xri_contrast_stretching,
    xri_hist_equalization,
    xri_local_hist_equalization,
    xri_zoom,
)


####--------------------------------------------------------------------------.
###############
#### Checks ###
###############
def _check_cam_id(cam_id):
    """Return cam_id integer."""
    if not isinstance(cam_id, (int, np.int64, list)):
        raise TypeError("'cam_id', if specified, must be an integer or list (of length 1).")
    if isinstance(cam_id, list):
        if len(cam_id) != 1:
            raise ValueError("Expecting a single value for 'cam_id'.")
        cam_id = int(cam_id[0])
    # Check value validity
    if cam_id not in [0, 1, 2]:
        raise ValueError("Valid values of 'cam_id' are [0,1,2].")
    # Return integer cam_id
    return cam_id


def _check_index(index, vmax):
    """Return index integer."""
    if not isinstance(index, (int, np.int64, list)):
        raise TypeError("'index', if specified, must be an integer or list (of length 1).")
    if isinstance(index, list):
        if len(index) != 1:
            raise ValueError("Expecting a single value for 'index'.")
        index = int(index[0])
    # Check value validity
    if index < 0:
        raise ValueError("'index' must be a positive integer.")
    if index > vmax:
        raise ValueError(f"The largest 'index' can be {vmax}")
    # Return integer index
    return index


def _check_indices(indices, vmax):
    """Return index integer."""
    if not isinstance(indices, (int, np.int64, list)):
        raise TypeError("'indices', if specified, must be an integer or list (of length 1).")
    if isinstance(indices, (int, np.int64)):
        indices = [indices]
    # Check value validity
    if any(not isinstance(index, (int, np.int64)) for index in indices):
        raise ValueError("All 'indices' values must be integers.")
    if any(index < 0 for index in indices):
        raise ValueError("All 'indices' values must be positive integers.")
    if any(index > vmax for index in indices):
        raise ValueError(f"The largest 'indices' value can be {vmax}")
    # Return list of indices
    return indices


def _check_n_triplets(n_triplets, vmax):
    if not isinstance(n_triplets, int):
        raise TypeError("'n_triplets' must be an integer.")
    if n_triplets < 1:
        raise ValueError("'n_triplets' must be at least 1.")
    if n_triplets > vmax:
        raise ValueError(f"'n_triplets' must be maximum {vmax}.")
    if n_triplets > 10:
        raise ValueError("It's not recommended to plot more than 10 triplets of images.")


def _check_n_images(n_images, vmax):
    if not isinstance(n_images, int):
        raise TypeError("'n_images' must be an integer.")
    if n_images < 1:
        raise ValueError("'n_images' must be at least 1.")
    if n_images > vmax:
        raise ValueError(f"'n_images' must be maximum {vmax}.")
    if n_images > 25:
        raise ValueError("It's not recommended to plot more than 25 images.")


def _check_random(random):
    if not isinstance(random, bool):
        raise TypeError("'random' must be either True or False.")


def _check_zoom(zoom):
    if not isinstance(zoom, bool):
        raise TypeError("'zoom' must be either True or False.")


def _check_enhancement(enhancement):
    if not isinstance(enhancement, (type(None), str)):
        raise TypeError("'enhancement' must be a string (or None).")
    if isinstance(enhancement, str):
        valid_enhancements = ["histogram_equalization", "contrast_stretching", "local_equalization"]
        if enhancement not in valid_enhancements:
            raise ValueError(f"{enhancement!r} is not a valid enhancement. " f"Specify one of {valid_enhancements}")


def _check_isel_idx(idx, vmax):
    # Return a numpy array of positional idx
    if not isinstance(idx, (int, list, slice, pd.Series, np.ndarray)):
        raise ValueError("isel expect a slice object, an integer or a list/pd.Series/np.array of int or boolean.")
    # Reformat all types to unique format (numpy array of integers)
    if isinstance(idx, int):
        idx = np.array([idx])
    if isinstance(idx, list):
        idx = np.array(idx)
        if idx.dtype.name not in ["bool", "int64"]:
            raise ValueError("Expecting values in the idx list to be of 'bool' or 'int' type.")
    if isinstance(idx, slice):
        idx = np.arange(idx.start, idx.stop, idx.step)

    if isinstance(idx, pd.Series):
        idx = np.where(idx.values)[0] if idx.dtype.name in ["bool", "boolean"] else np.array(idx.values)

    if isinstance(idx, np.ndarray):
        if idx.dtype.name in ["bool", "boolean"]:
            idx = np.where(idx)[0]
        if idx.dtype.name != "int64":
            raise ValueError("Expecting idx np.array to be of 'bool' or 'int64' type.")

    # --------------------------------------------------------------------.
    # Check idx validity
    if np.any(idx > vmax):
        raise ValueError(f"The maximum positional idx is {vmax}")
    if np.any(idx < 0):
        raise ValueError("The positional idx must be positive integers.")

    # --------------------------------------------------------------------.
    # Return idx
    return idx


def _check_sel_ids(ids, valid_ids):
    # Return a numpy array of str
    if not isinstance(ids, (str, list, pd.Series, np.ndarray)):
        raise ValueError("sel expect a string or a list/pd.Series/np.array of str")
    if valid_ids.dtype.name == "object":
        valid_ids = valid_ids.astype(str)
    # Reformat all types to unique format (numpy array of strings)
    if isinstance(ids, str):
        ids = np.array([ids])
    if isinstance(ids, list):
        ids = np.array(ids)
        if ids.dtype.name == "object":
            ids = ids.astype(str)
        if not ids.dtype.name.startswith("str"):
            raise ValueError("Expecting values in the list to be strings.")
    if isinstance(ids, pd.Series):
        ids = np.array(ids.values)
    if isinstance(ids, np.ndarray):
        if ids.dtype.name == "object":
            ids = ids.astype(str)
        if not ids.dtype.name.startswith("str"):
            raise ValueError("Expecting values in the np.array to be strings.")
    # --------------------------------------------------------------------.
    # Check idx validity
    invalid_ids = ids[np.isin(ids, valid_ids, invert=True)]
    if len(invalid_ids) > 0:
        raise ValueError(f"The following ids are not valid: {invalid_ids.tolist()}")
    if len(ids) == 0:
        raise ValueError("THIS SHOULD NOT OCCUR ...")
    # --------------------------------------------------------------------.
    return ids


def _check_timedelta(timedelta):
    if not isinstance(timedelta, (pd.Timedelta, np.timedelta64)):
        raise TypeError("'timedelta' must be a pd.Timedelta or np.timedelta64 instance.")
    return timedelta


def _check_df(df, name=None):
    if not isinstance(df, (pd.DataFrame, pd.Series)):
        if name is not None:
            raise TypeError(f"{name} is not a pd.DataFrame or pd.Series")
        raise TypeError("Expecting a pd.DataFrame or pd.Series")
    if isinstance(df, pd.Series):
        df = df.to_frame()
    return df


def _check_columns(columns):
    if not isinstance(columns, (str, list, np.ndarray)):
        raise TypeError("'columns' must be a string or list/np.array of strings.")
    if isinstance(columns, str):
        columns = [columns]
    if isinstance(columns, list):
        columns = np.array(columns)
    if isinstance(columns, np.ndarray):
        if columns.dtype.name == "object":
            columns = columns.astype(str)
        if not columns.dtype.name.startswith("str"):
            raise ValueError("Expecting columns in the np.array to be strings.")
    return columns


def _check_df_source(df_source):
    if not isinstance(df_source, str):
        raise TypeError("'df_source' must be a string. Either 'triplet', 'cam0','cam1', 'cam2'.")
    valid_source = ["triplet", "cam0", "cam1", "cam2"]
    if df_source not in valid_source:
        raise ValueError(f"Valid 'df_source' values are {valid_source}")


def _get_df_values(self, column, df_source="triplet"):
    if df_source == "triplet":
        return self._triplet[column].to_numpy()
    if df_source == "cam0":
        return self._cam0[column].to_numpy()
    if df_source == "cam1":
        return self._cam1[column].to_numpy()
    if df_source == "cam2":
        return self._cam2[column].to_numpy()
    raise ValueError("Invalid 'source'.")


def _count_occurrence(x):
    return [dict(zip(list(x.value_counts().keys()), list(x.value_counts())))]


def _convert_object_to_string(df):
    idx_object = df.dtypes.to_numpy() == "object"
    columns = df.columns[np.where(idx_object)]
    for column in columns:
        df[column] = df[column].astype("string")
    return df


def _read_parquet(fpath):
    df = pd.read_parquet(fpath).set_index("flake_id", drop=False)
    # - Ensure categorical/object columns are encoded as string
    df = _convert_object_to_string(df)
    return df


####-----------------------------------------------------------------------------.
[docs] class MASC_DB: """ Read MASCDB database from a specific directory. Parameters ---------- dir_path : str Filepath to a directory storing a MASCDB. 5 files are expected in the directory: - MASCdb_cam<0/1/2>.parquet - MASCdb_triplet.parquet - MASCdb.zarr Returns ------- MASCDB MASCDB class instance. """ ##################### #### Read MASCDB ### ##################### def __init__(self, dir_path): """Initialize MASC_DB object. It reads 4 parquet databases as well as the zarr database of MASC greyscale images. Returns ------- MASCDB MASCDB class instance. """ # DEBUG # return zarr_store_fpath = os.path.join(dir_path, "MASCdb.zarr") cam0_fpath = os.path.join(dir_path, "MASCdb_cam0.parquet") cam1_fpath = os.path.join(dir_path, "MASCdb_cam1.parquet") cam2_fpath = os.path.join(dir_path, "MASCdb_cam2.parquet") triplet_fpath = os.path.join(dir_path, "MASCdb_triplet.parquet") # - Check if the Zarr DirectoryStore has not been unzipped if not os.path.exists(zarr_store_fpath): zarr_zipstore_fpath = zarr_store_fpath + ".zip" if os.path.exists(zarr_zipstore_fpath): raise ValueError(f"You need to unzip {zarr_zipstore_fpath}") # - Read image dataset da = xr.open_zarr(zarr_store_fpath)["data"] da["flake_id"] = da["flake_id"].astype(str) da.name = "MASC Images" self._da = da # - Read cam dataframes self._cam0 = _read_parquet(cam0_fpath) self._cam1 = _read_parquet(cam1_fpath) self._cam2 = _read_parquet(cam2_fpath) # - Read triplet triplet = _read_parquet(triplet_fpath) # - Ensure event duration in ns if "event_duration" in list(triplet.columns): triplet["event_duration"] = triplet["event_duration"].astype( "timedelta64[ns]", ) # astype("m8[ns]") self._triplet = triplet # ------------ # - Define number of triplets self._n_triplets = len(self._triplet) # - Save source MASCDB directory self._dir_path = dir_path # - Add default events self._define_events(max_interval_without_images=np.timedelta64(4, "h"), unit="ns") ####----------------------------------------------------------------------. ######################### #### Builtins methods ### ######################### def __len__(self): """Return number of triplets in the MASCDB.""" return self._n_triplets def __str__(self): """Return string representation of MASCDB object.""" print("MASCDB data structure:") print("-------------------------------------------------------------------------") print("- mascdb.da:") print(self._da) print("-------------------------------------------------------------------------") print("- mascdb.cam0, mascdb.cam1, mascdb.cam2:") print(self._cam0) print("-------------------------------------------------------------------------") print("- mascdb.triplet") print(self._triplet) print("-------------------------------------------------------------------------") print("- mascdb.env") print(self.env) print("-------------------------------------------------------------------------") print("- mascdb.bs") print(self.bs) print("-------------------------------------------------------------------------") print("- mascdb.gan3d") print(self.gan3d) print("-------------------------------------------------------------------------") return "" def __repr__(self): """Return string representation of MASCDB object.""" return self.__str__() ####----------------------------------------------------------------------. ##################### #### Write MASCDB ### #####################
[docs] def save(self, dir_path, force=False): """ Save MASCDB object to disk into 4 parquet files and one Zarr store. Parameters ---------- dir_path : str Directory path where to save the current MASCDB database. force : bool Default is False If dir_path is the same as the source path of MASCDB object, force=True should allows to overwrite the original source database. """ # TODO: if overwriting, put all DataArray in memory first... otherwise deleting on disk remove lazy loaded data # - Check there are data to save if self._n_triplets == 0: raise ValueError("Nothing to save. No data left in the MASCDB.") # - Check dir_path if dir_path == self._dir_path: if force: print(f"- Overwriting existing 'source' MASCDB at {dir_path}") shutil.rmtree(dir_path) else: raise ValueError( f"If you want to overwrite the existing MASCDB at {dir_path}," "please specify force=True", ) if os.path.exists(dir_path): if force: print(f"- Replacing content of directory {dir_path}.") shutil.rmtree(dir_path) else: raise ValueError( f"A directory already exists at {dir_path}." "Please specify force=True if you want to overwrite it.", ) # ---------------------------------------------------------------------. # - Create directory os.makedirs(dir_path) # - Define fpath of databases zarr_store_fpath = os.path.join(dir_path, "MASCdb.zarr") cam0_fpath = os.path.join(dir_path, "MASCdb_cam0.parquet") cam1_fpath = os.path.join(dir_path, "MASCdb_cam1.parquet") cam2_fpath = os.path.join(dir_path, "MASCdb_cam2.parquet") triplet_fpath = os.path.join(dir_path, "MASCdb_triplet.parquet") # - Ensure "correct" chunks of DataArray da = self.da new_chunks = [max(chunk) for chunk in da.chunks] da = da.chunk(new_chunks) # - Write databases # ------------ # Temporary solution because timedelta cannot be saved # currently to parquet: https://issues.apache.org/jira/browse/ARROW-6780 # - event_duration timedelta is converted to int # - It assume no other timedelta columns are present in dataframes triplet = self.triplet triplet["event_duration"] = triplet["event_duration"].astype("timedelta64[ns]").view(int) # ------------ ds = da.to_dataset(name="data") ds.to_zarr(zarr_store_fpath) self._cam0.to_parquet(cam0_fpath, engine="auto") self._cam1.to_parquet(cam1_fpath, engine="auto") self._cam2.to_parquet(cam2_fpath, engine="auto") triplet.to_parquet(triplet_fpath, engine="auto")
# ---------------------------------------------------------------------. ####----------------------------------------------------------------------. ################### #### Subsetting ### ###################
[docs] def isel(self, idx): """Positional-index subsetting of MASCDB DataArray and MASCDB DataFrames. Parameters ---------- idx : (numpy.ndarray, list, int) List or np.ndarray of integer/boolean values used as positional indices for subsetting. Returns ------- MASCDB MASCDB class instance subsetted class instance subsetted (or index-based reordered). """ # ---------------------------------------------------------------------. # Check valid (integer) idx idx = _check_isel_idx(idx, vmax=self._n_triplets - 1) # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. ### Subset all datasets # - DataArray with dask.config.set(**{"array.slicing.split_large_chunks": False}): db._da = db._da.isel(flake_id=idx) # - Dataframes # if isinstance(idx[0], bool): # db._cam0 = db._cam0[idx] # self._cam1 = self._cam1[idx] # self._cam2 = self._cam2[idx] # self._triplet = self._triplet[idx] # else: db._cam0 = db._cam0.iloc[idx] db._cam1 = db._cam1.iloc[idx] db._cam2 = db._cam2.iloc[idx] db._triplet = db._triplet.iloc[idx] ##--------------------------------------------------------------------. # Update number of triplets db._n_triplets = len(db._triplet) ##--------------------------------------------------------------------. return db
[docs] def sel(self, flake_ids): """Subset MASCDB based on specified flake_ids. Parameters ---------- flake_ids : numpy.ndarray, list, str List or np.ndarray of string specifying flake_id values to subset. Returns ------- MASCDB MASCDB class instance subsetted. """ # ---------------------------------------------------------------------. # Check valid flake_ids valid_flake_ids = self._da["flake_id"].to_numpy() flake_ids = _check_sel_ids(flake_ids, valid_ids=valid_flake_ids) # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. ### Subset all datasets # - DataArray with dask.config.set(**{"array.slicing.split_large_chunks": False}): db._da = db._da.sel(flake_id=flake_ids) # - Dataframes db._cam0 = db._cam0.loc[flake_ids] db._cam1 = db._cam1.loc[flake_ids] db._cam2 = db._cam2.loc[flake_ids] db._triplet = db._triplet.loc[flake_ids] ##--------------------------------------------------------------------. # Update number of triplets db._n_triplets = len(db._triplet) ##--------------------------------------------------------------------. return db
[docs] def sample_n(self, n=10): """Sample randomly 'n' flakes in the current MASCDB object. Parameters ---------- n : int, float, optional Number of samples to extract The default is 10. Returns ------- MASCDB MASCDB object with n sampled flakes. """ if n > len(self): raise ValueError(f"The MASCDB instance has currently only {len(self)} triplets.") idx = np.random.choice(self._n_triplets, n) return self.isel(idx)
[docs] def first(self, n=1): """Extract first 'n' flakes in the database. Parameters ---------- n : int,float, optional Number of samples to extract The default is 1. Returns ------- MASCDB MASCDB object containing only the n first flakes of the current database """ if n > len(self): raise ValueError(f"The MASCDB instance has currently only {len(self)} triplets.") idx = np.arange(n) return self.isel(idx)
[docs] def last(self, n=1): """Extract last 'n' flakes in the database. Parameters ---------- n : int,float, optional Number of samples to extract The default is 1 Returns ------- MASCDB MASCDB object containing only the n last flakes of the current database. """ if n > len(self): raise ValueError(f"The MASCDB instance has currently only {len(self)} triplets.") idx = np.arange(self._n_triplets - 1, self._n_triplets - n - 1, step=-1) return self.isel(idx)
[docs] def head(self, n=10): """Extract first 'n' flakes in the database. Parameters ---------- n : int,float, optional Number of samples to extract The default is 10. Returns ------- MASCDB MASCDB object containing only the n first flakes of the current database. """ n = min(self._n_triplets, n) idx = np.arange(n) return self.isel(idx)
[docs] def tail(self, n=10): """Extract last 'n' flakes in the database. Parameters ---------- n : int,float, optional Number of samples to extract The default is 10. Returns ------- MASCDB MASCDB object containing only the n last flakes of the current database. """ n = min(self._n_triplets, n) idx = np.arange(self._n_triplets - 1, self._n_triplets - n - 1, step=-1) return self.isel(idx)
####----------------------------------------------------------------------. ################ #### Sorting ### ################
[docs] def arrange(self, expression, decreasing=True): """Reorder the MASCDB based on the DataFrame column values specified with expression. Parameters ---------- expression : str Expression specifying the DataFrame and column used to sort the MASCDB. The expression must have the following pattern '<df_name>.<column_name>' . Valid df_names are : ['cam0', 'cam1','cam2','triplet','bs','env','gan3d','flake','labels'] . decreasing : bool, optional Whether to sort MASCDB by increasing or decreasing values of the DataFrame column. The default is True. Returns ------- MASCDB MASCDB object sorted. """ # Check expression type if not isinstance(expression, str): raise TypeError("'expression' must be a string.") # ------------------------------. # Retrieve db name and column split_expression = expression.split(".") db_name = split_expression[0] db_column = split_expression[1] # Check valid format if len(split_expression) != 2: raise ValueError( "An invalid 'expression' has been specified.\n" "The expected format is <cam*/triplet/env/bs>.<column_name>.", ) # Check valid db valid_db = ["cam0", "cam1", "cam2", "triplet", "bs", "env", "gan3d", "flake", "labels"] if db_name not in valid_db: raise ValueError(f"The first component must be one of {valid_db}") # ------------------------------. # Get db db = getattr(self, db_name) # Check valid column valid_columns = list(db.columns) if db_column not in valid_columns: raise ValueError(f"{db_column!r} is not a column of {db_name!r}. Valid columns are {valid_columns}") # ------------------------------. # Retrieve sorting idx idx = db[db_column].to_numpy().argsort() if decreasing: idx = idx[::-1] # ------------------------------. # Return sorted object return self.isel(idx)
[docs] def select_max(self, expression, n=10): """Select 'n' triplets with maximum values of a given DataFrame column.""" return self.arrange(expression, decreasing=True).isel(np.arange(min(n, self._n_triplets)))
[docs] def select_min(self, expression, n=10): """Select 'n' triplets with minimum values of a given DataFrame column.""" return self.arrange(expression, decreasing=False).isel(np.arange(min(n, self._n_triplets)))
####----------------------------------------------------------------------. ########################## #### Data explanation #### ##########################
[docs] def get_var_units(self, varname): """ Get units of a given variable. Parameters ---------- varname : str String specifying a single column of the cam or triplet dataframe. Returns ------- str Abbreviated units of the variable. """ if not isinstance(varname, str): raise TypeError("'varname' must be a string") units = var_units() if varname in units: return units[varname] raise ValueError( f"{varname} units are not currently available. " f"Units are available for {list(units.keys())}", )
[docs] def get_var_explanation(self, varname): """ Get verbose explanation of a given variable. It includes DOI of reference paper whenever relevant. Parameters ---------- varname : str String specifying a single column of the cam or triplet dataframe. Returns ------- str Verbose explanation of the variable. """ if not isinstance(varname, str): raise TypeError("'varname' must be a string") explanations = var_explanations() if varname in explanations: return explanations[varname] raise ValueError( f"{varname} verbose explanation is not currently available. " f"Explanations are available for {list(explanations.keys())}", )
####----------------------------------------------------------------------. ################# #### Filters #### #################
[docs] def select_campaign(self, campaign): """ Select MASCDB data of specific campaigns. Parameters ---------- campaign : (str, list) String or list of string specifying MASCDB campaigns to select. Returns ------- MASCDB MASCDB class instance with data of specific campaigns. """ if not isinstance(campaign, (list, str)): raise TypeError("'campaign' must be a string or a list of strings.") if isinstance(campaign, str): campaign = [campaign] # Convert to numpy array with str type (not object...) campaign = np.array(campaign).astype(str) campaigns_arr = self._triplet["campaign"].to_numpy().astype(str) valid_campaigns = np.unique(campaigns_arr) invalid_campaigns_arg = campaign[np.isin(campaign, valid_campaigns, invert=True)] if len(invalid_campaigns_arg) > 0: raise ValueError( f"{invalid_campaigns_arg.tolist()} is not a campaign of the current mascdb. " f"Valid campaign names are {valid_campaigns.tolist()}", ) idx = np.isin(campaigns_arr, campaign) return self.isel(idx)
[docs] def discard_campaign(self, campaign): """ Discard MASCDB data from specific campaigns. Parameters ---------- campaign : (str, list) String or list of string specifying MASCDB campaigns to discard. Returns ------- MASCDB MASCDB class instance with data of specific campaigns. """ if not isinstance(campaign, (list, str)): raise TypeError("'campaign' must be a string or a list of strings.") if isinstance(campaign, str): campaign = [campaign] # Convert to numpy array with str type (not object...) campaign = np.array(campaign).astype(str) campaigns_arr = self._triplet["campaign"].to_numpy().astype(str) valid_campaigns = np.unique(campaigns_arr) invalid_campaigns_arg = campaign[np.isin(campaign, valid_campaigns, invert=True)] if len(invalid_campaigns_arg) > 0: raise ValueError( f"{invalid_campaigns_arg.tolist()} is already not a campaign of the current mascdb. " f"Current mascdb has campaign names {valid_campaigns.tolist()}", ) idx = np.isin(campaigns_arr, campaign, invert=True) return self.isel(idx)
[docs] def select_snowflake_class(self, values, method="Praz2017", invert=False, df_source="triplet"): """ Select MASCDB data with specific snowflake classes. Parameters ---------- values : (str, int, list) Values specifying the snowflake classes to select. If integers, it assumes snowflake_class_id. If strings, it assumes snowflake_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_snowflake_class_name_dict(method)'. method : str, optional Method used to determine snowflake_class. The default is 'Praz2017'. invert : bool, optional If True, instead of selecting it discard the specified snowflake_class. The default is False. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific snowflake classes. """ # ---------------------------------------------------------------------. ## Check default args if not isinstance(invert, bool): raise TypeError("'invert' must be either True or False'.") _check_df_source(df_source) # ---------------------------------------------------------------------. ## Check values if not isinstance(values, (int, str, list, np.ndarray)): raise TypeError("'values' must be either (list of) integers (for class ids) or str (for class names).") # Convert to numpy array object values = np.array([values]) if isinstance(values, (int, str)) else np.array(values) # If values are integers --> Assume it provide the class id if isinstance(values[0].item(), int): valid_names = list(get_snowflake_class_name_dict(method=method).values()) # id column = "snowflake_class_id" # If values are str --> Assume it provide the class name elif isinstance(values[0].item(), str): valid_names = list(get_snowflake_class_name_dict(method=method).keys()) # name column = "snowflake_class_name" else: raise TypeError("'values' must be either integers (for class ids) or str (for class names).") # ---------------------------------------------------------------------. # Retrieve column values (by default from triplet df) arr = _get_df_values(self, df_source=df_source, column=column) # Check values are valid invalid_values = values[np.isin(values, valid_names, invert=True)] if len(invalid_values) > 0: raise ValueError( f"{invalid_values.tolist()} is not a {column} of the current mascdb. " f"Current mascdb has {column} values {valid_names}", ) # ---------------------------------------------------------------------. # Subset the mascdb idx = np.isin(arr, values, invert=invert) return self.isel(idx)
[docs] def select_riming_class(self, values, method="Praz2017", invert=False, df_source="triplet"): """ Select MASCDB data with specific riming classes. Parameters ---------- values : (str, int, list) Values specifying the riming classes to select. If integers, it assumes riming_class_id. If strings, it assumes riming_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_riming_class_name_dict(method)'. method : str, optional Method used to determine riming_class. The default is 'Praz2017'. invert : bool, optional If True, instead of selecting it discard the specified riming_class. The default is False. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific riming classes. """ # ---------------------------------------------------------------------. ## Check default args if not isinstance(invert, bool): raise TypeError("'invert' must be either True or False'.") _check_df_source(df_source) # ---------------------------------------------------------------------. ## Check values if not isinstance(values, (int, str, list, np.ndarray)): raise TypeError("'values' must be either (list of) integers (for class ids) or str (for class names).") # Convert to numpy array object values = np.array([values]) if isinstance(values, (int, str)) else np.array(values) # If values are integers --> Assume it provide the class id if isinstance(values[0].item(), int): valid_names = list(get_riming_class_name_dict(method=method).values()) # id column = "riming_class_id" # If values are str --> Assume it provide the class name elif isinstance(values[0].item(), str): valid_names = list(get_riming_class_name_dict(method=method).keys()) # name column = "riming_class_name" else: raise TypeError("'values' must be either integers (for class ids) or str (for class names).") # ---------------------------------------------------------------------. # Retrieve column values (by default from triplet df) arr = _get_df_values(self, df_source=df_source, column=column) # Check values are valid invalid_values = values[np.isin(values, valid_names, invert=True)] if len(invalid_values) > 0: raise ValueError( f"{invalid_values.tolist()} is not a {column} of the current mascdb. " f"Current mascdb has {column} values {valid_names}", ) # ---------------------------------------------------------------------. # Subset the mascdb idx = np.isin(arr, values, invert=invert) return self.isel(idx)
[docs] def select_melting_class(self, values, method="Praz2017", invert=False, df_source="triplet"): """ Select MASCDB data with specific melting classes. Parameters ---------- values : (str, int, list) Values specifying the melting classes to select. If integers, it assumes melting_class_id. If strings, it assumes melting_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_melting_class_name_dict(method)'. method : str, optional Method used to determine melting_class. The default is 'Praz2017'. invert : bool, optional If True, instead of selecting it discard the specified melting_class_id. The default is False. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific melting classes. """ # ---------------------------------------------------------------------. ## Check default args if not isinstance(invert, bool): raise TypeError("'invert' must be either True or False'.") _check_df_source(df_source) # ---------------------------------------------------------------------. ## Check values if not isinstance(values, (int, str, list, np.ndarray)): raise TypeError("'values' must be either (list of) integers (for class ids) or str (for class names).") # Convert to numpy array object values = np.array([values]) if isinstance(values, (int, str)) else np.array(values) # If values are integers --> Assume it provide the class id if isinstance(values[0].item(), int): valid_names = list(get_melting_class_name_dict(method=method).values()) # id column = "melting_class_id" # If values are str --> Assume it provide the class name elif isinstance(values[0].item(), str): valid_names = list(get_melting_class_name_dict(method=method).keys()) # name column = "melting_class_name" else: raise TypeError("'values' must be either integers (for class ids) or str (for class names).") # ---------------------------------------------------------------------. # Retrieve column values (by default from triplet df) arr = _get_df_values(self, df_source=df_source, column=column) # Check values are valid invalid_values = values[np.isin(values, valid_names, invert=True)] if len(invalid_values) > 0: raise ValueError( f"{invalid_values.tolist()} is not a {column} of the current mascdb. " f"Current mascdb has {column} values {valid_names}", ) # ---------------------------------------------------------------------. # Subset the mascdb idx = np.isin(arr, values, invert=invert) return self.isel(idx)
[docs] def select_precip_class(self, values, method="Schaer2020", invert=False): """ Select MASCDB data with specific precipitation types. Parameters ---------- values : (str, int, list) Values specifying the precipitation classes to select. If integers, it assumes bs_precip_class_id. If strings, it assumes bs_precip_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_precip_class_name_dict(method)'. method : str, optional Method used to determine bs_precip_class. The default is 'Schaer2020'. invert : bool, optional If True, instead of selecting it discard the specified bs_precip_class. The default is False. Returns ------- MASCDB MASCDB class instance with specific precipitation classes. """ if not isinstance(values, (int, str, list, np.ndarray)): raise TypeError("'values' must be either (list of) integers (for class ids) or str (for class names).") # Convert to numpy array object values = np.array([values]) if isinstance(values, (int, str)) else np.array(values) # If values are integers --> Assume it provide the class id if isinstance(values[0].item(), int): valid_names = list(get_precip_class_name_dict(method=method).values()) # id column = "bs_precip_class_id" # If values are str --> Assume it provide the class name elif isinstance(values[0].item(), str): valid_names = list(get_precip_class_name_dict(method=method).keys()) # name column = "bs_precip_class_name" else: raise TypeError("'values' must be either integers (for class ids) or str (for class names).") # ---------------------------------------------------------------------. # Retrieve triplet column values arr = self._triplet[column].to_numpy() # Check values are valid invalid_values = values[np.isin(values, valid_names, invert=True)] if len(invalid_values) > 0: raise ValueError( f"{invalid_values.tolist()} is not a {column} of the current mascdb. " f"Current mascdb has {column} values {valid_names}", ) # ---------------------------------------------------------------------. # Subset the mascdb idx = np.isin(arr, values, invert=invert) return self.isel(idx)
[docs] def discard_snowflake_class(self, values, method="Praz2017", df_source="triplet"): """ Discard MASCDB data with specific snowflake classes. Parameters ---------- values : (str, int, list) Values specifying the snowflake classes to discard. If integers, it assumes snowflake_class_id. If strings, it assumes snowflake_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_snowflake_class_name_dict(method)'. method : str, optional Method used to determine snowflake_class. The default is 'Praz2017'. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific snowflake classes. """ return self.select_snowflake_class(values=values, method=method, invert=True, df_source=df_source)
[docs] def discard_melting_class(self, values, method="Praz2017", df_source="triplet"): """ Discard MASCDB data with specific melting classes. Parameters ---------- values : (str, int, list) Values specifying the melting classes to discard. If integers, it assumes melting_class_id. If strings, it assumes melting_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_melting_class_name_dict(method)'. method : str, optional Method used to determine melting_class. The default is 'Praz2017'. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific melting classes. """ return self.select_melting_class(values=values, method=method, invert=True, df_source=df_source)
[docs] def discard_riming_class(self, values, method="Praz2017", df_source="triplet"): """ Discard MASCDB data with specific riming classes. Parameters ---------- values : (str, int, list) Values specifying the riming classes to discard. If integers, it assumes riming_class_id. If strings, it assumes riming_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_riming_class_name_dict(method)'. method : str, optional Method used to determine riming_class. The default is 'Praz2017'. df_source: str, optional The dataframe from which retrieve the class. Either 'cam0', 'cam1', 'cam2' or 'triplet'. The default is 'triplet'. Returns ------- MASCDB MASCDB class instance with specific riming classes. """ return self.select_riming_class(values=values, method=method, invert=True, df_source=df_source)
[docs] def discard_precip_class(self, values, method="Schaer2020"): """Discard MASCDB data with specific precipitation types. Parameters ---------- values : (str, int, list) Values specifying the precipitation classes to discard. If integers, it assumes bs_precip_class_id. If strings, it assumes bs_precip_class_name. Valid values can be retrieved by calling 'mascdb.utils_aux.get_precip_class_name_dict(method)'. method : str, optional Method used to determine bs_precip_class. The default is 'Schaer2020'. Returns ------- MASCDB MASCDB class instance with specific precipitation classes. """ return self.select_precip_class(values=values, method=method, invert=True)
####----------------------------------------------------------------------. ################# #### Getters #### ################# # The following properties are used to avoid accidental modification in place by the user @property def da(self): """DataArray of MASC images and attributes.""" return self._da.copy() @property def cam0(self): """Dataframe of snowflake attributes for CAM0 view.""" return self._cam0.copy() @property def cam1(self): """Dataframe of snowflake attributes for CAM1 view.""" return self._cam1.copy() @property def cam2(self): """Dataframe of snowflake attributes for CAM2 view.""" return self._cam2.copy() @property def triplet(self): """Dataframe of snowflake attributes valid for the triplet of images.""" return self._triplet.copy() # The following properties are just utils @property def env(self): """Dataframe of environmental (env_*) attributes.""" columns = list(self._triplet.columns) env_variables = [column for column in columns if column.startswith("env_")] env_db = self._triplet[[*env_variables]].copy() env_db.columns = [column.strip("env_") for column in env_variables] return env_db @property def bs(self): """Dataframe of blowing-snow estimation (bs_*) attributes (Schaer et al 2020).""" columns = list(self._triplet.columns) bs_variables = [column for column in columns if column.startswith("bs_")] bs_db = self._triplet[[*bs_variables]].copy() bs_db.columns = [column.strip("bs_") for column in bs_variables] return bs_db @property def gan3d(self): """Dataframe of 3d reconstruction (gan3d_*) attributes (Leinonen et al, 2021).""" columns = list(self._triplet.columns) gan3d_variables = [column for column in columns if column.startswith("gan3d_")] gan3d_db = self._triplet[[*gan3d_variables]].copy() gan3d_db.columns = [column.replace("gan3d_", "") for column in gan3d_variables] return gan3d_db @property def flake(self): """Dataframe of flake/triplet (flake_*) attributes.""" columns = list(self._triplet.columns) flake_variables = [column for column in columns if column.startswith("flake_")] flake_db = self._triplet[[*flake_variables]].copy() flake_db.columns = [column.strip("flake_") for column in flake_variables] return flake_db @property def labels(self): """Dataframe of hydrometeor classification, riming and melting attributes (Praz et al 2017).""" labels_variables = get_vars_class() labels_db = self._triplet[[*labels_variables]].copy() return labels_db @property def event(self): """Dataframe summarizing separate (according to setting) precip. events.""" # columns = list(self._triplet.columns) # event_columns = [column for column in columns if column.startswith("event_")] event_columns = [ "event_id", "event_duration", "event_n_triplets", "campaign", "datetime", "latitude", "longitude", "altitude", ] event_db = self._triplet[event_columns].groupby("event_id").first().reset_index() # Compute month and year event_db["month"] = event_db["datetime"].dt.month event_db["year"] = event_db["datetime"].dt.year # Compute start_time and end_time start_time = self._triplet[["event_id", "datetime"]].groupby("event_id").min() start_time.columns = ["start_time"] end_time = self._triplet[["event_id", "datetime"]].groupby("event_id").max() end_time.columns = ["end_time"] _ = event_db.drop(columns="datetime", inplace=True) event_db = event_db.merge(start_time, left_on="event_id", right_index=True) event_db = event_db.merge(end_time, left_on="event_id", right_index=True) return event_db @property def campaign(self): """Dataframe summarizing information of different field campaigns.""" # ----------------------------------------------. # Retrieve data df_event = self.event df_triplet = self._triplet c_id = "campaign" # ----------------------------------------------. # Compute location info columns = ["latitude", "longitude", "altitude", "campaign"] info_location = df_event[columns].groupby("campaign").first() # ----------------------------------------------. # Compute number of triplets # - n_triplet (sum) n_triplets = df_triplet.groupby(c_id)["flake_id"].count() n_triplets.name = "n_triplets" # ----------------------------------------------. ## Compute event summary n_events = df_event[["event_id", c_id]].groupby(c_id).count() n_events.columns = ["n_events"] start_time = df_event[["start_time", c_id]].groupby(c_id).min() end_time = df_event[["end_time", c_id]].groupby(c_id).max() event_duration_stats = df_event.groupby(c_id).agg({"event_duration": ["min", "mean", "max", "sum"]}) event_duration_stats.columns = [ "event_duration_min", "event_duration_mean", "event_duration_max", "total_event_duration", ] # ----------------------------------------------. ### Compute class occurrence snowflake_class_counts = ( df_triplet[["snowflake_class_name", c_id]] .groupby(c_id)["snowflake_class_name"] .apply(_count_occurrence) .apply(lambda x: x[0]) ) riming_class_counts = ( df_triplet[["riming_class_name", c_id]] .groupby(c_id)["riming_class_name"] .apply(_count_occurrence) .apply(lambda x: x[0]) ) melting_class_counts = ( df_triplet[["melting_class_name", c_id]] .groupby(c_id)["melting_class_name"] .apply(_count_occurrence) .apply(lambda x: x[0]) ) precipitation_class_counts = ( df_triplet[["bs_precip_class_name", c_id]] .groupby(c_id)["bs_precip_class_name"] .apply(_count_occurrence) .apply(lambda x: x[0]) ) snowflake_class_counts.name = "snowflake_class" riming_class_counts.name = "riming_class" melting_class_counts.name = "melting_class" precipitation_class_counts.name = "precipitation_class" # ----------------------------------------------. ## Compute other time infos # months_list = df_triplet.groupby(c_id)['datetime'].apply(lambda x: list(np.unique(x.dt.month_name()))) # months_list.name = "months" # years_list = df_triplet.groupby(c_id)['datetime'].apply(lambda x: list(np.unique(x.dt.year))) # years_list.name = "years" # years_months_list = df_triplet.groupby(c_id)['datetime'].apply( # lambda x: list(np.unique(x.dt.strftime('%Y-%m')))) # years_months_list.name = "years_months" # ----------------------------------------------. # Define summary dataframe summary = pd.merge(start_time, end_time, right_index=True, left_index=True) summary = summary.join(info_location) summary = summary.join(n_triplets) summary = summary.join(n_events) summary = summary.join(event_duration_stats) summary = summary.join(snowflake_class_counts) summary = summary.join(riming_class_counts) summary = summary.join(melting_class_counts) summary = summary.join(precipitation_class_counts) # ----------------------------------------------. return summary @property def full_db(self): """Dataframe including cam0, cam1, cam2 and triplet stacked.""" # TODO: check same order as ds_images ... maybe add cam_id and campaign args # Add cam_id to each cam db l_cams = [self.cam0, self.cam1, self.cam2] for i, cam in enumerate(l_cams): cam["cam_id"] = i # Merge cam(s) db into full_db = pd.concat(l_cams) # Add triplet variables to fulldb labels_vars = get_vars_class() vars_not_add = ["flake_quality_xhi", "flake_n_roi", "flake_Dmax", "flake_id", *labels_vars] triplet = self._triplet.drop(columns=vars_not_add) full_db = full_db.merge(triplet, how="left") return full_db
[docs] def ds_images(self, cam_id=None, campaign=None, img_id="img_id"): """Return xarray DataArray of images.""" # ----------------------------------------------------------------------. # Subset by campaign if campaign is not None: if isinstance(campaign, str): campaign = [campaign] campaign = np.array(campaign).astype(str) db_campaigns = self._triplet["campaign"].to_numpy().astype(str) valid_campaigns = np.unique(db_campaigns) invalid_campaigns_arg = campaign[np.isin(campaign, valid_campaigns, invert=True)] if len(invalid_campaigns_arg) > 0: raise ValueError( f"{invalid_campaigns_arg.tolist()} is not a campaign of the current mascdb. " f"Valid campaign names are {valid_campaigns.tolist()}", ) idx = np.isin(db_campaigns, campaign) da = self.isel(idx).da else: da = self.da # ----------------------------------------------------------------------. # Subset cam images if cam_id is not None: da = da.isel(cam_id=cam_id) # ----------------------------------------------------------------------. ### Retrieve dimensions to eventually stack along a new third dimension dims = list(da.dims) unstacked_dims = list(set(dims).difference(["x", "y"])) # - If only x and y, add third dimension img_id if len(unstacked_dims) == 0: stack_dict = {} da_stacked = da.expand_dims(img_id, axis=-1) # - If there is already a third dimension, transpose to the last elif len(unstacked_dims) == 1: da = da.rename({unstacked_dims[0]: img_id}) da_stacked = da.transpose(..., img_id) # - If there is more than 3 dimensions, stack it all into a new third dimension elif len(unstacked_dims) > 1: stack_dict = {img_id: unstacked_dims} # Stack all additional dimensions into a 3D array with all img_id in the last dimension da_stacked = da.stack(stack_dict).transpose(..., img_id) else: raise NotImplementedError return da_stacked
####----------------------------------------------------------------------. ###################### #### Event utils ##### ###################### def _add_event_n_triplets(self): if "event_id" not in list(self._triplet.columns): raise ValueError("First define 'event_id' using mascdb.define_event_id().") event_triplets = self._triplet.groupby("event_id").size() event_triplets.name = "event_n_triplets" # - Remove existing column to remerge if "event_n_triplets" in list(self._triplet.columns): _ = self._triplet.drop(columns="event_n_triplets", inplace=True) # - Add column self._triplet = self._triplet.merge(event_triplets, on="event_id").set_index("flake_id", drop=False) def _add_event_duration(self, unit="ns"): if "event_id" not in list(self._triplet.columns): raise ValueError("First define 'event_id' using mascdb.define_event_id().") event_durations = self._triplet.groupby("event_id").apply( lambda x: _get_timesteps_duration(x.datetime, unit=unit), include_groups=False, ) event_durations.name = "event_duration" # - Remove existing column to remerge if "event_duration" in list(self._triplet.columns): _ = self._triplet.drop(columns="event_duration", inplace=True) # - Add column self._triplet = self._triplet.merge(event_durations, on="event_id").set_index("flake_id", drop=False) def _define_events(self, max_interval_without_images=None, unit="ns"): max_interval_without_images = ( np.timedelta64(4, "h") if max_interval_without_images is None else max_interval_without_images ) # This function modify in place # ----------------------------------------------------------. # - Extract relevant columns from triplet db db = self._triplet[["campaign", "datetime"]].copy() # - Retrieve campaign_ids campaign_ids = np.unique(db["campaign"]) # ----------------------------------------------------------. # - Retrieve event_id column db["event_id"] = -1 max_event_id = 0 for campaign_id in campaign_ids: # - Retrieve row index of specific campaign idx_campaign = db["campaign"] == campaign_id # - Define event_ids for the campaign campaign_event_ids = _define_event_id( timesteps=db.loc[idx_campaign, "datetime"], maximum_interval_without_timesteps=max_interval_without_images, ) # - Add offset to ensure having an unique event_id across all campaigns campaign_event_ids = campaign_event_ids + max_event_id # - Add event_id to the campaign subset of the database db.loc[idx_campaign, "event_id"] = campaign_event_ids # - Update the current maximum event_id max_event_id = max(campaign_event_ids) + 1 # ----------------------------------------------------------. # - Add event_id column to all databases self._cam0["event_id"] = db["event_id"] self._cam1["event_id"] = db["event_id"] self._cam2["event_id"] = db["event_id"] self._triplet["event_id"] = db["event_id"] # ----------------------------------------------------------. # - Add duration and n_triplets for each event to triplet db self._add_event_duration(unit=unit) self._add_event_n_triplets() ##-------------------------------------------------- ## Filtering events utils
[docs] def select_events_with_n_triplets(self, min=0, max=np.inf): """ Select events with number of triplets between min and max. Parameters ---------- min : int, optional Minimum number of triplets. The default is 0. max : int, optional Maximum number of triplets. The default is np.inf. Returns ------- MASCDB MASCDB class instance """ ## Check min and max values validity if not isinstance(min, int): raise TypeError("'min' must be an integer.") if not isinstance(max, (int, float)): raise TypeError("'max' must be an integer (or np.inf).") if min < 0: raise ValueError("'min' must be an integer larger or equal to 0.") if max < 1: raise ValueError("'max' must be an integer larger or equal to 1.") if isinstance(max, float) and max != np.inf: raise ValueError("'max' must be an integer (or np.inf).") if min > max: raise ValueError("'min' must be smaller than 'max'.") if max < min: raise ValueError("'max' must be larger than 'min'.") # ---------------------------------------------------------------------. # Retrieve subset index df_event = self.event idx_event_ids = (df_event["event_n_triplets"] >= min) & (df_event["event_n_triplets"] <= max) event_ids_subset = df_event.loc[idx_event_ids, "event_id"].to_numpy() idx_bool_subset = np.isin(self._triplet["event_id"].to_numpy(), event_ids_subset) # ---------------------------------------------------------------------. # Subset the data and return return self.isel(idx_bool_subset)
## ------------------------------------------------------------------------.
[docs] def select_events_with_duration(self, min=None, max=None): """ Select events with duration between min and max. Parameters ---------- min : (numpy.timedelta64, pandas.Timedelta), optional Minimum duration. The default is 0 ns. max : (numpy.timedelta64, pandas.Timedelta), optional Maximum duration. The default is 1 year. Returns ------- MASCDB MASCDB class instance """ min = np.timedelta64(0, "ns") if min is None else min max = np.timedelta64(365, "D") if max is None else max # Check min and max values validity if not isinstance(min, (np.timedelta64, pd.Timedelta)): raise TypeError("'min' must be a np.timedelta64 or pd.Timedelta object.") if not isinstance(max, (np.timedelta64, pd.Timedelta)): raise TypeError("'max' must be a np.timedelta64 or pd.Timedelta object.") if isinstance(min, pd.Timedelta): min = min.to_numpy() if isinstance(max, pd.Timedelta): max = max.to_numpy() if min < np.timedelta64(0, "ns"): raise ValueError("'min' must be a positive timedelta object") if max < np.timedelta64(0, "ns"): raise ValueError("'max' must be a positive timedelta object (larger than 0).") if min > max: raise ValueError("'min' must be smaller than 'max'.") if max < min: raise ValueError("'max' must be larger than 'min'.") # ---------------------------------------------------------------------. # Retrieve subset index df_event = self.event idx_event_ids = (df_event["event_duration"] >= min) & (df_event["event_duration"] <= max) subset_event_ids = df_event.loc[idx_event_ids, "event_id"].to_numpy() idx_subset = np.isin(self._triplet["event_id"].to_numpy(), subset_event_ids) # Subset the data and return return self.isel(idx_subset)
[docs] def select_events_longest(self, n=1): """ Select MASCDB data corresponding to the 'n' events with longest duration. Parameters ---------- n : int, optional The number of events to retrieve. The default is 1. Returns ------- MASCDB MASCDB class instance """ longest_event_ids = self.arrange("triplet.event_duration", decreasing=True)._triplet["event_id"].iloc[0:n] idx_longest_events = np.isin(self._triplet["event_id"].to_numpy(), longest_event_ids) return self.isel(idx_longest_events)
[docs] def select_events_shortest(self, n=1): """ Select MASCDB data corresponding to the 'n' events with shortest duration. Parameters ---------- n : int, optional The number of events to retrieve. The default is 1. Returns ------- MASCDB MASCDB class instance """ shortest_event_ids = self.arrange("triplet.event_duration", decreasing=False)._triplet["event_id"].iloc[0:n] idx_shortest_events = np.isin(self._triplet["event_id"].to_numpy(), shortest_event_ids) return self.isel(idx_shortest_events)
##-------------------------------------------------- ## Redefine events utils
[docs] def redefine_events( self, max_interval_without_images=None, min_duration=None, max_duration=None, min_n_triplets=None, max_n_triplets=None, unit="ns", ): """ Enable selection and custom definition of an 'event'. If <min/max>_<duration/n_triplets> are specified, the MASCDB will likely be subsetted. Parameters ---------- max_interval_without_images : (numpy.timedelta64, pandas.Timedelta), optional Maximum interval of time without images to consider consecutive images to belong the same event. The default is np.timedelta64(4,'h'). min_duration : (numpy.timedelta64, pandas.Timedelta), optional Minimum duration of an event to retained. The default is numpy.timedelta64(0,'ns'). max_duration : (numpy.timedelta64, pandas.Timedelta), optional Maximum duration of an event to retained. The default is numpy.timedelta64(365,'D'). min_n_triplets : int, optional Minimum number of triplets within an event to retain the event. The default is 0. max_n_triplets : int, optional Maximum number of triplets within an event to retain the event.. The default is Inf. unit : str, optional Unit of timedelta to consider for events definition. The default is "ns". Returns ------- MASCDB MASCDB class instance with the custom event definition. """ max_interval_without_images = ( np.timedelta64(4, "h") if max_interval_without_images is None else max_interval_without_images ) # Copy new instance db = copy.deepcopy(self) # Define event_id db._define_events(max_interval_without_images=max_interval_without_images, unit=unit) # ----------------------------------------------------------. EVENT_FILTERING = False # Select only events with specific min/max n_triplets and duration if (min_n_triplets is not None) or (max_n_triplets is not None): EVENT_FILTERING = True if min_n_triplets is None: min_n_triplets = 0 if max_n_triplets is None: max_n_triplets = np.inf db = db.select_events_with_n_triplets(min=min_n_triplets, max=max_n_triplets) if (min_duration is not None) or (max_duration is not None): EVENT_FILTERING = True if min_duration is None: min_duration = np.timedelta64(0, "ns") if max_duration is None: max_duration = np.timedelta64(365, "D") db = db.select_events_with_duration(min=min_duration, max=max_duration) # ----------------------------------------------------------. # Ensure event_id incremental order (0,1,..,n_events) if filtering out events if EVENT_FILTERING: event_ids = db._triplet["event_id"].to_numpy() event_ids_new = np.unique(event_ids, return_inverse=True)[1] db._cam0["event_id"] = event_ids_new db._cam1["event_id"] = event_ids_new db._cam2["event_id"] = event_ids_new db._triplet["event_id"] = event_ids_new # Return the object return db
####------------------------------------------------------------------------. ################################# #### Image plotting routines #### #################################
[docs] def plot_triplets( self, indices=None, random=False, n_triplets=1, enhancement="histogram_equalization", zoom=True, squared=True, wspace=0.01, hspace=0.01, **kwargs, ): """ Plotting routine to display specific triplets of MASC snowflake images. By default: - images are enhanced with histogram_equalization and zoomed. - 'n_triplets' and 'random' are effective only if 'indices' are not specified. - If indices are unspecified, the chosen triplets correspond to the first 'n_triplets' of MASCDB. Parameters ---------- indices : (int, list), optional Integer list of rows to display. The default is None. random : bool, optional Specify if the displayed MASCDB triplets must be chosen randomly. It's effective only if 'indices' are not specified. The default is False. n_triplets : int, optional Specify the number of MASCDB triplets to be displayed. It's effective only if 'indices' are not specified. The default is 1. enhancement : str, optional Type of enhancement to use to improve the image quality. Valid enhancements are : [None, "histogram_equalization", "contrast_stretching", "local_equalization"] The default is "histogram_equalization". zoom : bool, optional Specify if zooming close to the snowflake bounding box. The image shape is defined by selecting the smallest possible shapes across all the snowflakes to be plotted The default is True. squared : bool, optional Specify if the zoomed images must have equal height,width. The default is True. hspace : float Define the space across images in the vertical dimension. The default is 0.01. wspace : float Define the space across images in the horizontal dimension. The default is 0.01. **kwargs : dict Optional arguments to be passed to DataArray.plot. Returns ------- xarray.plot.facetgrid.FacetGrid FacetGrid object for additional customization """ # --------------------------------------------------. # Retrieve number of valid index n_idxs = len(self._triplet.index) if n_idxs == 0: raise ValueError("No data to plot.") # --------------------------------------------------. # Check args _check_random(random) _check_zoom(zoom) _check_enhancement(enhancement) _check_n_triplets(n_triplets, vmax=n_idxs) # --------------------------------------------------. # Define index if is not provided if indices is None: indices = list(np.random.choice(n_idxs, n_triplets)) if random else list(np.arange(0, n_triplets)) # --------------------------------------------------. # Check validity of indices indices = _check_indices(indices, vmax=n_idxs - 1) # --------------------------------------------------. # Subset triplet(s) images da_subset = self._da.isel(flake_id=indices).transpose(..., "cam_id", "flake_id") # --------------------------------------------------. # Apply enhancements if enhancement is not None: if enhancement == "histogram_equalization": da_subset = xri_hist_equalization(da_subset, adaptive=True) elif enhancement == "contrast_stretching": da_subset = xri_contrast_stretching(da_subset, pmin=2, pmax=98) elif enhancement == "local_equalization": da_subset = xri_local_hist_equalization(da_subset) # --------------------------------------------------. # Zoom all images to same extent if zoom: da_subset = xri_zoom(da_subset, squared=squared) # --------------------------------------------------. # Plot triplet(s) row = "flake_id" if len(indices) > 1 else None p = da_subset.plot( x="x", y="y", col="cam_id", row=row, aspect=1, yincrease=False, cmap="gray", add_colorbar=False, vmin=0, vmax=255, **kwargs, ) # Nice layout for _i, ax in enumerate(p.axes.flat): ax.set_xlabel("") ax.set_ylabel("") ax.set_axis_off() p.fig.subplots_adjust(wspace=wspace, hspace=hspace) # --------------------------------------------------. return p
[docs] def plot_flake( self, cam_id=None, index=None, random=False, enhancement="histogram_equalization", zoom=True, squared=True, ax=None, **kwargs, ): """ Plotting routine to display a specific MASC snowflake image. By default: - The image is enhanced with histogram_equalization and zoomed. - 'random' is effective only if 'index' is not specified. - If index is unspecified, it plot an image of the first MASCDB triplet. Parameters ---------- cam_id : int, optional The camera from which display the snowflake image. If not specified, the camera is randomly chosen. Valid cam_id values are 0, 1 and 2. The default is None. index : int, optional Row index of the MASCDB triplet image to display. The default is None. random : bool, optional Specify if the displayed MASCDB image must be chosen randomly. It's effective only if 'index' is not specified. The default is False. enhancement : str, optional Type of enhancement to use to improve the image quality. Valid enhancements are : [None, "histogram_equalization", "contrast_stretching", "local_equalization"] The default is "histogram_equalization". zoom : bool, optional Specify if zooming close to the snowflake bounding box. The image shape is defined by selecting the smallest possible shape to include the entire snowflake. The default is True. squared : bool, optional Specify if the zoomed images must have equal height,width. The default is True. ax: matplotlib.axes.Axes, optional Optional matplotlib axis on which to plot the image. The default is None. **kwargs : dict Optional arguments to be passed to DataArray.plot. """ # Check args _check_random(random) _check_zoom(zoom) _check_enhancement(enhancement) # --------------------------------------------------. # Retrieve number of valid index n_idxs = len(self._triplet.index) if n_idxs == 0: raise ValueError("No data to plot.") # --------------------------------------------------. # Define index if is not provided if index is None: index = next(iter(np.random.choice(n_idxs, 1))) if random else 0 if cam_id is None: cam_id = list(np.random.choice([0, 1, 2], 1)) if random else 1 # --------------------------------------------------. # Check validty of cam_id and index cam_id = _check_cam_id(cam_id) index = _check_index(index, vmax=n_idxs - 1) # --------------------------------------------------. # Subset triplet(s) images # - If cam_id is an integer (instead of list of length 1), then the cam_id dimension is dropped) da_img = self._da.isel(flake_id=index, cam_id=cam_id) # --------------------------------------------------. # Apply enhancements if enhancement is not None: if enhancement == "histogram_equalization": da_img = xri_hist_equalization(da_img, adaptive=True) elif enhancement == "contrast_stretching": da_img = xri_contrast_stretching(da_img, pmin=2, pmax=98) elif enhancement == "local_equalization": da_img = xri_local_hist_equalization(da_img) # --------------------------------------------------. # Zoom all images to same extent if zoom: da_img = xri_zoom(da_img, squared=squared) # --------------------------------------------------. # Plot single image # - TODO: 'aspect' cannot be specified without 'size' p = da_img.plot( x="x", y="y", ax=ax, yincrease=False, cmap="gray", add_colorbar=False, vmin=0, vmax=255, **kwargs, ) # --------------------------------------------------. return p
[docs] def plot_flakes( self, cam_id=None, indices=None, random=False, n_images=9, col_wrap=3, enhancement="histogram_equalization", zoom=True, squared=True, hspace=0.1, wspace=0.1, **kwargs, ): """ Plotting routine to display MASC snowflake images. By default: - images are enhanced with histogram_equalization and zoomed. - 'n_images' and 'random' are effective only if 'indices' are not specified. - If indices are unspecified: * If cam_id is unspecified: it displays the first 'n_images' from a randomly selected camera of MASCDB. * If cam_id specify 1 camera: it displays the first 'n_images' of the specified camera of MASCDB. * If cam_id specifies more than 1 camera: it displays the first 'n_images' of each of the specified camera of MASCDB. Parameters ---------- cam_id : (int, list), optional The camera(s) from which display the snowflake images. If not specified, a single camera is randomly chosen. If specified, it can be any subset of the 3 camera. Valid cam_id values are 0, 1 and 2. The default is None. indices : (int, list), optional Integer list of rows to display. The default is None. random : bool, optional Specify if the displayed MASCDB images must be chosen randomly. It's effective only if 'indices' are not specified. The default is False. n_images : int, optional Specify the number of MASCDB images to be displayed for each camera. It's effective only if 'indices' are not specified. The default is 1. enhancement : str, optional Type of enhancement to use to improve the image quality. Valid enhancements are : [None, "histogram_equalization", "contrast_stretching", "local_equalization"] The default is "histogram_equalization". zoom : bool, optional Specify if zooming close to the snowflake bounding box. The image shape is defined by selecting the smallest possible shapes across all the snowflakes to be plotted. The default is True. squared : bool, optional Specify if the zoomed images must have equal height,width. The default is True. hspace : float Define the space across images in the vertical dimension. The default is 0.1. wspace : float Define the space across images in the horizontal dimension. The default is 0.1. **kwargs : dict Optional arguments to be passed to DataArray.plot. Returns ------- xarray.plot.facetgrid.FacetGrid FacetGrid object for additional customization """ # -------------------------------------------------- # Check args n_idxs = len(self) _check_random(random) _check_zoom(zoom) _check_enhancement(enhancement) # -------------------------------------------------- # Check cam_id if cam_id is None: cam_id = np.random.choice([0, 1, 2], 1).tolist() if isinstance(cam_id, int): cam_id = [cam_id] # -------------------------------------------------- # Define indices if is not provided _check_n_images(n_images, vmax=n_idxs) if indices is None: indices = list(np.random.choice(n_idxs, n_images)) if random else list(np.arange(0, n_images)) # -------------------------------------------------- # Check indices and recompute n_images indices = _check_indices(indices, vmax=n_idxs - 1) if isinstance(indices, int): indices = [indices] n_images = len(indices) * len(cam_id) # -------------------------------------------------- # If a single flake is specified, plot it with plot_flake if len(indices) == 1 and len(cam_id) == 1: print("It's recommended to use 'plot_flake()' to plot a single image.") return self.plot_flake( index=indices[0], cam_id=cam_id, random=random, enhancement=enhancement, zoom=zoom, **kwargs, ) # --------------------------------------------------. # Retrieve DataArray and subset cam_id da = self.da da = da.isel(cam_id=cam_id, flake_id=indices) # --------------------------------------------------. # If more than 1 between cam_id and flake_id dimensions are present --> Stack dims = list(da.dims) unstacked_dims = list(set(dims).difference(["x", "y"])) # - If only x and y, do nothing if len(unstacked_dims) == 0: stack_dict = {} da_stacked = da n_idxs = 1 # - If there is already a third dimension, transpose to the last elif len(unstacked_dims) == 1: img_id = unstacked_dims[0] da_stacked = da n_idxs = len(da_stacked[img_id]) # stack_dict = {} # da_stacked = da.stack(stack_dict).transpose(..., img_id) # n_idxs = len(da_stacked[img_id]) # - If there is more than 3 dimensions, stack it all into a new third dimension elif len(unstacked_dims) > 1: img_id = "img_id" stack_dict = {img_id: ("cam_id", "flake_id")} # stack_dict = {img_id: unstacked_dims} # Stack all additional dimensions into a 3D array with all img_id in the last dimension da_stacked = da.stack(stack_dict).transpose(..., img_id) n_idxs = len(da_stacked[img_id]) else: raise NotImplementedError # --------------------------------------------------. da_subset = da_stacked # --------------------------------------------------. # Apply enhancements if enhancement is not None: if enhancement == "histogram_equalization": da_subset = xri_hist_equalization(da_subset, adaptive=True) elif enhancement == "contrast_stretching": da_subset = xri_contrast_stretching(da_subset, pmin=2, pmax=98) elif enhancement == "local_equalization": da_subset = xri_local_hist_equalization(da_subset) # --------------------------------------------------. # Zoom all images to same extent if zoom: da_subset = xri_zoom(da_subset, squared=squared) # --------------------------------------------------. # Retrieve title from stacked dimension xr_indexes = da_subset.img_id.xindexes[img_id] FLAG_MULTI_INDEX = False if isinstance(xr_indexes, xr.core.indexes.PandasMultiIndex): FLAG_MULTI_INDEX = True pd_indexes = xr_indexes.to_pandas_index() names = list(pd_indexes.names) titles = [] for i in range(n_images): tmp_str_list = ", ".join([str(pd_indexes.get_level_values(name)[i]) for name in names]) # tmp_str_list = ", ".join([name + ": " + str(pd_indexes.get_level_values(name)[i]) for name in names]) titles.append(tmp_str_list) # --------------------------------------------------. # Plot flakes(s) row = img_id # if len(indices) > 1 else None p = da_subset.plot( x="x", y="y", row=row, col_wrap=col_wrap, aspect=1, yincrease=False, cmap="gray", add_colorbar=False, vmin=0, vmax=255, **kwargs, ) # Nice layout for i, ax in enumerate(p.axes.flat): ax.set_xlabel("") ax.set_ylabel("") ax.set_axis_off() if FLAG_MULTI_INDEX and i < len(titles): ax.set_title(titles[i]) p.fig.subplots_adjust(wspace=wspace, hspace=hspace) # --------------------------------------------------. return p
####-----------------------------------------------------------------------. ####################### #### MASCDB Updates ### #######################
[docs] def compute_2Dimage_descriptors(self, fun, labels, fun_kwargs=None, force=False, dask="parallelized"): """Compute and add user-specific image descriptors to the CAM dataframes. It requires the specification of a function ('fun') expecting the image 2D array and returning the descriptor(s) value(s). It also require the specification of the expected descriptors names ('labels'). Parameters ---------- fun : callable A function computing the descriptor(s) of a 2D image. The function must expects a grayscale 2D array and return the descriptor(s) value(s). labels : (str, list) String or list of string specifying the descriptor names computed by 'fun'. These labels will become the columns added to cam dataframe. fun_kwargs : dict, optional Optional arguments to be passed to 'fun'. The default is None. force : bool, optional force=True enable to overwrite existing descriptors present in the cam dataframes. The default is False. dask : str, optional Option to be passed to xr.apply_u_func. The default is "parallelized". Returns ------- MASCDB MASCDB class instance with new descriptors in cam dataframes. """ # ---------------------------------------------------------------------. # Check if specified labels are already columns of mascdb.cam* existing_cam_columns = list(self._cam0.columns) overwritten_columns = list(np.array(existing_cam_columns)[np.isin(existing_cam_columns, labels)]) if not force and len(overwritten_columns) > 0: raise ValueError( f"Columns {overwritten_columns} would be overwritten. " "Specify force=True if you want to overwrite existing columns.", ) # ---------------------------------------------------------------------. # Compute descriptors da_descriptors = _compute_2Dimage_descriptors( da=self.da, fun=fun, labels=labels, fun_kwargs=fun_kwargs, dask=dask, ) # ---------------------------------------------------------------------. # Retrieve cam dataframes cam0 = da_descriptors.isel(cam_id=0).to_dataset("descriptor").to_pandas().drop(columns="cam_id") cam1 = da_descriptors.isel(cam_id=1).to_dataset("descriptor").to_pandas().drop(columns="cam_id") cam2 = da_descriptors.isel(cam_id=2).to_dataset("descriptor").to_pandas().drop(columns="cam_id") # ---------------------------------------------------------------------. # Attach to new mascdb instance new_mascdb = self.add_cam_columns(cam0=cam0, cam1=cam1, cam2=cam2, force=force, complete=True) # Return new mascdb instance return new_mascdb
[docs] def add_cam_columns(self, cam0, cam1, cam2, force=False, complete=True): """ Method allowing to safely add columns to cam dataframes of MASCDB. Parameters ---------- cam0 : pandas.DataFrame pd.DataFrame with index 'flake_id' . cam1 : pandas.DataFrame pd.DataFrame with index 'flake_id' . cam2 : pandas.DataFrame pd.DataFrame with index 'flake_id' . force : bool, optional Whether to overwrite existing column of mascdb. The default is False. complete : bool, optional Whether to merge only when the cam dataframes have same 'flake_id' of the current mascdb. The default is True. Returns ------- MASCDB MASCDB class instance """ # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. # Check all cam* are pd.DataFrame cam0 = _check_df(cam0, name="cam0") cam1 = _check_df(cam1, name="cam1") cam2 = _check_df(cam2, name="cam2") # Check length is the same across all dataframes n_cam0 = len(cam0) n_cam1 = len(cam1) n_cam2 = len(cam2) if not n_cam0 == n_cam1: raise ValueError(f"cam0 has {n_cam0} rows, while cam1 has {n_cam1} rows.") if not n_cam0 == n_cam2: raise ValueError(f"cam0 has {n_cam0} rows, while cam2 has {n_cam2} rows.") # Check columns are the same across all dataframes cam0_columns = np.sort(list(cam0.columns)) cam1_columns = np.sort(list(cam1.columns)) cam2_columns = np.sort(list(cam2.columns)) if not np.array_equal(cam0_columns, cam1_columns): raise ValueError("cam0 and cam1 does not have the same column names.") if not np.array_equal(cam0_columns, cam2_columns): raise ValueError("cam0 and cam2 does not have the same column names.") # ---------------------------------------------------------------------. ### - Check cam flake_id match each others cam0_flake_ids = np.sort(cam0.index.to_numpy().astype(str)) cam1_flake_ids = np.sort(cam0.index.to_numpy().astype(str)) cam2_flake_ids = np.sort(cam0.index.to_numpy().astype(str)) if not np.array_equal(cam0_flake_ids, cam1_flake_ids): raise ValueError("cam0 and cam1 does not have the same 'flake_id' index.") if not np.array_equal(cam0_flake_ids, cam2_flake_ids): raise ValueError("cam0 and cam2 does not have the same 'flake_id' index.") # ---------------------------------------------------------------------. # Check if column names already exist in mascdb.cam* existing_cam_columns = list(db._cam0.columns) overwritten_columns = list(np.array(existing_cam_columns)[np.isin(existing_cam_columns, cam0_columns)]) if not force and len(overwritten_columns) > 0: raise ValueError( f"Columns {overwritten_columns} would be overwritten. " "Specify force=True if you want to overwrite existing columns.", ) # ---------------------------------------------------------------------. # Drop columns that must be overwritten if len(overwritten_columns) > 0: _ = db._cam0.drop(columns=overwritten_columns, inplace=True) _ = db._cam1.drop(columns=overwritten_columns, inplace=True) _ = db._cam2.drop(columns=overwritten_columns, inplace=True) # ---------------------------------------------------------------------. # Ensure columns order is the same across all dataframes cam0 = cam0[cam0_columns] cam1 = cam1[cam1_columns] cam1 = cam2[cam2_columns] # ---------------------------------------------------------------------. ### - Check flake_id match between mascdb and provided cam dataframes new_flake_ids = cam0_flake_ids existing_flake_ids = db._cam0.index.to_numpy().astype(str) missing_flake_ids = existing_flake_ids[np.isin(existing_flake_ids, new_flake_ids, invert=True)] matching_flake_ids = new_flake_ids[np.isin(new_flake_ids, existing_flake_ids)] non_matching_flake_ids = new_flake_ids[np.isin(new_flake_ids, existing_flake_ids, invert=True)] # ----------------------------------------------------------------------. # Check at least 1 flake_id match if len(matching_flake_ids) == 0: raise ValueError("No matching 'flake_id' between current mascdb and provided cam dataframes.") # ---------------------------------------------------------------------. # Check flake_id index and number of rows of new cam correspond to existing one if complete: # Check that there are the same flake_id if len(missing_flake_ids) > 0: msg = ( f"There are {len(missing_flake_ids)} flake_id missing in the provided cam dataframes. \n " "If you want to still merge the new columns, specify complete=False. \n " "New columns with non-matching rows will be filled by NaN." ) raise ValueError(msg) # Check number of rows if db._n_triplets != n_cam0: msg = ( f"The provided cam dataframes have {n_cam0} rows, while " f"the current mascdb has {db._n_triplets} rows. \n " "If you want to still merge the new columns, specify complete=False. \n " "New columns with non-matching rows will be filled by NaN." ) raise ValueError(msg) # ---------------------------------------------------------------------. # Print a message if some flake_id does not have a match if len(non_matching_flake_ids) > 0: msg = ( f"There are {len(non_matching_flake_ids)} flake_id in the provided cam dataframes which " "will not be merged to the mascdb because of non-matching flake_id." ) print(msg) # ---------------------------------------------------------------------. # Join data db._cam0 = db._cam0.merge(cam0, left_index=True, right_index=True, how="left") db._cam1 = db._cam1.merge(cam1, left_index=True, right_index=True, how="left") db._cam2 = db._cam2.merge(cam2, left_index=True, right_index=True, how="left") # ---------------------------------------------------------------------. # Return the new mascdb return db
[docs] def add_triplet_columns(self, df, force=False, complete=True): """ Method allowing to safely add columns to cam dataframes of MASCDB. Parameters ---------- df : pandas.DataFrame pd.DataFrame with index 'flake_id' . force : bool, optional Whether to overwrite existing column of mascdb. The default is False. complete : bool, optional Whether to merge only when the provided dataframe has the same 'flake_id' of the current mascdb triplet dataframe. The default is True. Returns ------- MASCDB MASCDB class instance """ # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. # Check df is pd.DataFrame df = _check_df(df, name="df") # Check length is the same across all dataframes n_df = len(df) # Check columns are the same across all dataframes df_columns = list(df.columns) # ---------------------------------------------------------------------. ### - Check df flake_id match each others df_flake_ids = np.sort(df.index.to_numpy().astype(str)) # ---------------------------------------------------------------------. # Check if column names already exist in mascdb.triplet existing_triplet_columns = list(db._triplet.columns) overwritten_columns = list(np.array(existing_triplet_columns)[np.isin(existing_triplet_columns, df_columns)]) if not force and len(overwritten_columns) > 0: raise ValueError( f"Columns {overwritten_columns} would be overwritten. " "Specify force=True if you want to overwrite existing columns.", ) # ---------------------------------------------------------------------. # Drop columns that must be overwritten if len(overwritten_columns) > 0: _ = db._triplet.drop(columns=overwritten_columns, inplace=True) # ---------------------------------------------------------------------. ### - Check flake_id match between mascdb and provided dataframes new_flake_ids = df_flake_ids existing_flake_ids = db._triplet.index.to_numpy().astype(str) missing_flake_ids = existing_flake_ids[np.isin(existing_flake_ids, new_flake_ids, invert=True)] matching_flake_ids = new_flake_ids[np.isin(new_flake_ids, existing_flake_ids)] non_matching_flake_ids = new_flake_ids[np.isin(new_flake_ids, existing_flake_ids, invert=True)] # ---------------------------------------------------------------------. # Check at least 1 flake_id match if len(matching_flake_ids) == 0: raise ValueError("No matching 'flake_id' between current mascdb and provided dataframe.") # ---------------------------------------------------------------------. # Check flake_id index and number of rows of new cam correspond to existing one if complete: # Check that there are the same flake_id if len(missing_flake_ids) > 0: msg = ( f"There are {len(missing_flake_ids)} flake_id missing in the provided dataframe. \n " "If you want to still merge the new columns, specify complete=False. \n " "New columns with non-matching rows will be filled by NaN." ) raise ValueError(msg) # Check number of rows if db._n_triplets != n_df: msg = ( f"The provided dataframe have {n_df} rows, while " f"the current mascdb has {db._n_triplets} rows. \n " "If you want to still merge the new columns, specify complete=False. \n " "New columns with non-matching rows will be filled by NaN." ) raise ValueError(msg) # ---------------------------------------------------------------------. # Print a message if some flake_id does not have a match if len(non_matching_flake_ids) > 0: msg = ( f"There are {len(non_matching_flake_ids)} flake_id in the provided df dataframe which " "will not be merged to the mascdb because of non-matching flake_id." ) print(msg) # ---------------------------------------------------------------------. # Join data db._triplet = db._triplet.merge(df, left_index=True, right_index=True, how="left") # ---------------------------------------------------------------------. # Return the new mascdb return db
[docs] def drop_cam_columns(self, columns): """ Method allowing to safely remove columns from all cam dataframes of MASCDB. Parameters ---------- columns : list List with column names of MASCDB cam dataframes to be removed Returns ------- MASCDB MASCDB class instance """ # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. # Check columns columns = _check_columns(columns) # Check columns are valid columns current_columns = np.array(list(db._cam0.columns)) invalid_columns = np.array(columns)[np.isin(columns, current_columns, invert=True)] if len(invalid_columns) > 0: raise ValueError(f"{invalid_columns.tolist()} are not columns of cam dataframes.") # ---------------------------------------------------------------------. # - Remove columns columns = columns.tolist() _ = db._cam0.drop(columns=columns, inplace=True) _ = db._cam1.drop(columns=columns, inplace=True) _ = db._cam2.drop(columns=columns, inplace=True) # ---------------------------------------------------------------------. return db
[docs] def drop_triplet_columns(self, columns): """ Method allowing to safely remove columns from the MASCDB triplet dataframe. Parameters ---------- columns : list List with column names of cam dataframes of MASCDB to be removed Returns ------- MASCDB MASCDBclass instance """ # ---------------------------------------------------------------------. # Copy new instance db = copy.deepcopy(self) # ---------------------------------------------------------------------. # Check columns columns = _check_columns(columns) # Check columns are valid columns current_columns = np.array(list(db._triplet.columns)) invalid_columns = np.array(columns)[np.isin(columns, current_columns, invert=True)] if len(invalid_columns) > 0: raise ValueError(f"{invalid_columns.tolist()} are not columns of cam dataframes.") # ---------------------------------------------------------------------. # - Remove columns columns = columns.tolist() _ = db._triplet.drop(columns=columns, inplace=True) # ---------------------------------------------------------------------. return db