Module ablation.perturb

Expand source code
import numpy as np

from . import distributions
from .distributions import (
    constant,
    constant_mean,
    constant_median,
    gaussian,
    gaussian_blur,
    marginal,
    max_distance,
)

PERTURBATIONS = [
    "gaussian_blur",
    "constant",
    "constant_mean",
    "constant_median",
    "gaussian",
    "marginal",
    "max_distance",
]


def generate_perturbation_distribution(
    method: str, X: np.ndarray, X_obs: np.ndarray, random_state=42, **kwargs
) -> np.ndarray:
    """Generate perturbation distribution to be used for ablation

    Args:
        method (str): Perturbation method
        X (np.ndarray): Data for source distribution
        X_obs (np.ndarray): Data observations to perturb
        random_state (Optional[int]): random seed


    Returns:
        np.ndarray: Perturbed dataset
    """
    np.random.seed(random_state)

    if method == "gaussian_blur":
        perturbation = gaussian_blur(X=X_obs, **kwargs)
    elif method == "constant":
        perturbation = constant(X=X_obs, **kwargs)
    elif method == "constant_mean":
        perturbation = constant_mean(X=X, **kwargs)
    elif method == "constant_median":
        perturbation = constant_median(X=X, **kwargs)
    elif method == "max_distance":
        perturbation = max_distance(X=X, X_obs=X_obs, **kwargs)
    elif method == "gaussian":
        perturbation = gaussian(X=X_obs, **kwargs)
    elif method == "marginal":
        perturbation = marginal(X=X, X_obs=X_obs, **kwargs)
    else:
        raise ValueError(f"Perturbation method '{method}' does not exist!")

    if method in distributions.CONSTANT:
        return np.tile(perturbation, (len(X_obs), 1))

    return perturbation

Functions

def generate_perturbation_distribution(method: str, X: numpy.ndarray, X_obs: numpy.ndarray, random_state=42, **kwargs) ‑> numpy.ndarray

Generate perturbation distribution to be used for ablation

Args

method : str
Perturbation method
X : np.ndarray
Data for source distribution
X_obs : np.ndarray
Data observations to perturb
random_state : Optional[int]
random seed

Returns

np.ndarray
Perturbed dataset
Expand source code
def generate_perturbation_distribution(
    method: str, X: np.ndarray, X_obs: np.ndarray, random_state=42, **kwargs
) -> np.ndarray:
    """Generate perturbation distribution to be used for ablation

    Args:
        method (str): Perturbation method
        X (np.ndarray): Data for source distribution
        X_obs (np.ndarray): Data observations to perturb
        random_state (Optional[int]): random seed


    Returns:
        np.ndarray: Perturbed dataset
    """
    np.random.seed(random_state)

    if method == "gaussian_blur":
        perturbation = gaussian_blur(X=X_obs, **kwargs)
    elif method == "constant":
        perturbation = constant(X=X_obs, **kwargs)
    elif method == "constant_mean":
        perturbation = constant_mean(X=X, **kwargs)
    elif method == "constant_median":
        perturbation = constant_median(X=X, **kwargs)
    elif method == "max_distance":
        perturbation = max_distance(X=X, X_obs=X_obs, **kwargs)
    elif method == "gaussian":
        perturbation = gaussian(X=X_obs, **kwargs)
    elif method == "marginal":
        perturbation = marginal(X=X, X_obs=X_obs, **kwargs)
    else:
        raise ValueError(f"Perturbation method '{method}' does not exist!")

    if method in distributions.CONSTANT:
        return np.tile(perturbation, (len(X_obs), 1))

    return perturbation