Module ablation.pytorch_explanations

Expand source code
from typing import Optional

import numpy as np
import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum.attr import DeepLift, DeepLiftShap, IntegratedGradients, KernelShap, Lime
from captum.attr._core.lime import get_exp_kernel_similarity_function
from pytorch_lightning import seed_everything

from .baseline import (
    ConstantBaseline,
    ManyToOneBaseline,
    OneToOneBaseline,
    SampleBaseline,
)
from .utils.logging import timing
from .utils.model import _as_numpy, _predict_proba_fn, _torch_float, _torch_long

CAPTUM_METHODS = {
    "integrated_gradients": IntegratedGradients,
    "deep_lift": DeepLift,
    "deep_shap": DeepLiftShap,
    "kernel_shap": KernelShap,
    "lime": Lime,
}
SAMPLE_BASED_METHODS = ["deep_shap"]


@timing
def captum_explanation(
    method: str,
    model: torch.nn.Module,
    X: np.ndarray,
    baseline: np.ndarray,
    target: Optional[np.ndarray] = None,
    random_state: Optional[int] = None,
    **kwargs,
):
    """Collection of captum based explanation methods following the
        On-Baselines paper.

        Baselines either are constant (one sample), have a one-to-one
        relationship with the observations (max_distance), or are a sample
        from a larger distribution that needs to be averaged over.

    Args:
        method (str): name of explanation method
        model (torch.nn.Module): torch model
        X (np.ndarray): Observations to derive explainations from
        baseline (np.ndarray): Baseline array.
        target (np.ndarray, optional): If multiclass, should be array of integers.
                If it is a binary problem it should be None. Defaults to None.
    """
    seed_everything(random_state)

    if method == "random":
        return np.random.rand(*X.shape)

    X = _torch_float(X)
    baseline_type = type(baseline)
    baseline = _torch_float(baseline)
    y_pred = model(X)

    if target is None and y_pred.size(1) > 1:
        # Use predicted class for multiclass
        target = y_pred.argmax(-1)
    elif target is not None:
        # Use target if provided
        target = _torch_long(target)

    if method not in CAPTUM_METHODS:
        raise ValueError(
            f"Method {method} not in available explanation methods. "
            f"""Choose one of "{'", "'.join(CAPTUM_METHODS.keys())}" """
        )

    explainer_args = {}

    if method == "lime":
        explainer_args.update(
            {
                "interpretable_model": SkLearnLinearRegression(),
                "similarity_func": get_exp_kernel_similarity_function(
                    distance_mode="euclidean",
                    kernel_width=np.sqrt(X.size(1)) * 0.75,
                ),
            }
        )
    if method == "deep_shap" and baseline_type in [
        ConstantBaseline,
        OneToOneBaseline,
        ManyToOneBaseline,
    ]:
        # Deep shap cannot be used for constant, one to one baselines, or many to one baselines
        method = "deep_lift"

    exp = CAPTUM_METHODS[method](model, **explainer_args)

    if baseline_type == ManyToOneBaseline:
        # Many to one baseline
        # baseline shape: (# observations in X, # baseline samples per observation, # features)
        attr = torch.zeros_like(X, device=X.device)
        for b in baseline.permute(1, 0, 2):
            attr += exp.attribute(inputs=X, baselines=b, target=target, **kwargs)
        return _as_numpy(attr / baseline.size(1))

    elif baseline_type == SampleBaseline and method not in SAMPLE_BASED_METHODS:
        # Non-sample based method and sample baseline
        # baseline shape: (# baseline samples, # features)
        attr = torch.zeros_like(X, device=X.device)
        for b in baseline:
            attr += exp.attribute(
                inputs=X, baselines=b.unsqueeze(0), target=target, **kwargs
            )
        return _as_numpy(attr / baseline.size(0))

    return _as_numpy(
        exp.attribute(inputs=X, baselines=baseline, target=target, **kwargs)
    )