Module ablation.experiment

Source code
import logging
import os
import pickle
import shutil
import time
from copy import copy
from typing import Any, Dict, List, Union
from warnings import warn

import numpy as np
import pandas as pd
import yaml

from .ablation import Ablation
from .baseline import generate_baseline_distribution
from .dataset import load_data
from .explanations import Explanations
from .perturb import generate_perturbation_distribution
from .pytorch_explanations import captum_explanation
from .pytorch_model import load_model, train
from .utils.evaluate import eval_model_performance
from .utils.logging import logger
from .utils.model import _predict_fn
from .utils.transform import le_to_ohe, ohe_to_le


class Config:
    def __init__(
        self,
        dataset_name: str,
        model_type: str,
        perturbation_config: Dict[str, Dict[str, int]],
        baseline_config: Dict[str, Dict[str, int]],
        explanation_methods: List[str],
        ablation_args: Dict[str, Any],
        dataset_sample_perc: float = 1.0,
        dataset_n_random_features: int = 0,
        n_trials: int = 1,
        path: Union[str, os.PathLike] = "tmp",
        load=False,
        rerun_ablation=False,
    ):
        """Experiment config

        Args:
            dataset_name (str): name of dataset
            model_type (str): model type ('nn' or 'linear')
            perturbation_config (Dict[str, Dict[str, int]]): dict of perturbation names associated with dicts of extra arguments
            baseline_config (Dict[str, Dict[str, int]]): dict of baseline names associated with dicts of extra arguments
            explanation_methods (List[str]): list of explanation methods
            ablation_args (Dict[str, Any]): dict of ablation arguments
            dataset_sample_perc (float): proportion of the dataset to use for the experiment. Defaults to 1.0 (the full dataset).
            dataset_n_random_features (int): number of random features to add to the dataset as a sanity check. Defaults to 0.
            n_trials (int): number of random seeds (trials) to run the experiment with. Defaults to 1.
            path (str, optional): path for saving results, the model, and intermediate calculations. Defaults to "tmp".
            load (bool, optional): if True, load a previously saved experiment from path. Defaults to False.
            rerun_ablation (bool, optional): if True, rerun the ablation when loading from path. Defaults to False.
        """
        self.dataset_name = dataset_name
        self.model_type = model_type
        self.dataset_sample_perc = dataset_sample_perc
        self.dataset_n_random_features = dataset_n_random_features
        self.perturbation_config = perturbation_config
        self.baseline_config = baseline_config
        self.explanation_methods = explanation_methods
        self.ablation_args = ablation_args
        self.n_trials = n_trials
        self.path = path
        self.load = load
        self.rerun_ablation = rerun_ablation

        self._check()

    def _check(self):

        assert (
            isinstance(self.n_trials, int) and self.n_trials >= 1
        ), "n_trials must be an integer >=1"

        assert (
            len(self.perturbation_config) > 0
        ), "Must specify at least one perturbation"

        assert (
            len(self.explanation_methods) > 0
        ), "Must specify at least one explanation method"

        assert len(self.baseline_config) > 0, "Must specify at least one baseline"

        assert self.model_type in [
            "nn",
            "linear",
        ], "model_type must be type 'nn' or 'linear'"

    @classmethod
    def from_yaml_file(cls, path: Union[str, os.PathLike]):
        with open(path, "r") as f:
            config = yaml.safe_load(f)
        return cls(**config)

    @classmethod
    def from_dict(cls, dict_):
        return cls(**dict_)

    def to_dict(self):
        return copy(self.__dict__)

    def save(self):
        with open(os.path.join(self.path, "config.yml"), "w") as f:
            yaml.dump(self.__dict__, f, allow_unicode=True)

    @classmethod
    def load(cls, path: Union[str, os.PathLike]):
        return cls.from_yaml_file(os.path.join(path, "config.yml"))

    @property
    def result_name(self):
        """Name of results file based on ablation args"""
        explanation = "local" if self.ablation_args["local"] else "global"
        return f"results-{explanation}"

    def diff(self, other: object) -> List[str]:
        """Return the names of experiment arguments that differ from another Config."""
        experiment_args = [
            "dataset_name",
            "perturbation_config",
            "baseline_config",
            "explanation_methods",
            "ablation_args",
            "dataset_sample_perc",
            "dataset_n_random_features",
            "n_trials",
            "model_type",
        ]
        _self = self.__dict__
        _other = other.__dict__
        diff = [arg for arg in experiment_args if _self[arg] != _other[arg]]

        return diff

    def __eq__(self, other: object) -> bool:
        return len(self.diff(other)) == 0


class Experiment:
    def __init__(self, config: Config):
        """Experiment runner

        Args:
            config (Config): configuration
        """
        self.config = config

        self._clean_dir()
        self._config_check()
        self.config.save()
        self._set_logging()

        self.dataset = load_data(
            self.config.dataset_name,
            self.config.dataset_sample_perc,
            self.config.dataset_n_random_features,
        )

        if self.config.load:
            self.model = load_model(self.config.path)
            self.label_shuffed_model = load_model(
                self.config.path, prefix="label_shuffed_model"
            )
        else:
            self.model = train(
                self.dataset,
                max_epochs=100,
                path=self.config.path,
                model_type=self.config.model_type,
            )
            self.label_shuffed_model = train(
                self.dataset,
                max_epochs=100,
                path=self.config.path,
                model_type=self.config.model_type,
                prefix="label_shuffed_model",
                shuffle_labels=True,
            )

    def _set_logging(self):
        fh = logging.FileHandler(os.path.join(self.config.path, "experiment.log"))
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    def _config_check(self):
        self.rerun_ablation = self.config.rerun_ablation

        if self.config.load:
            if not os.path.exists(self.config.path):
                raise IOError(f"Path does not exist: {self.config.path}")

            old_config = Config.load(self.config.path)
            diff = self.config.diff(old_config)

            if diff == ["ablation_args"]:
                warn(
                    "Recomputing ablations from previous experiment with settings: "
                    f"{', '.join([f'{k}: {v}' for k,v in self.config.ablation_args.items()])}."
                )
                self.rerun_ablation = True

            elif len(diff) > 0:
                # TODO: Add more functionality to rerun components of the experiment
                raise ValueError(
                    f"Configuration file doesn't match. The following arguments have changed: {', '.join(diff)}."
                    # f"Using config.yml from {self.config.path}."
                    # "Set load = False to run new config."
                )

    def _clean_dir(self):
        """Clean experiment directory"""
        if os.path.exists(self.config.path) and not self.config.load:
            shutil.rmtree(self.config.path, ignore_errors=True)

        if not os.path.exists(self.config.path):
            os.makedirs(self.config.path)

    def _load_if_exists(self, name):
        """Load a pickled object from the experiment path if it exists, else return None"""
        file = os.path.join(self.config.path, f"{name}.pkl")
        if os.path.exists(file):
            with open(file, "rb") as f:
                return pickle.load(f)
        return None

    def _save(self, obj, name):
        """Save object as a pkl file"""
        file = os.path.join(self.config.path, f"{name}.pkl")
        with open(file, "wb") as f:
            pickle.dump(obj, f)

    def _compute_model_sanity_checks(self):
        self.model_sanity_checks = self._load_if_exists("model_sanity_checks")

        if self.model_sanity_checks is None:
            self.model_sanity_checks = {}

            self.model_sanity_checks["label_shuffled"] = {
                scoring_method: eval_model_performance(
                    self.label_shuffed_model,
                    self.dataset.X_test,
                    self.dataset.y_test,
                    scoring_method=scoring_method,
                )
                for scoring_method in self.config.ablation_args["scoring_methods"]
            }

            self._save(self.model_sanity_checks, "model_sanity_checks")

        # TODO: Do we want also the default probabilities?
        # _, counts = np.unique(self.dataset.y_train, return_counts=True)
        # naive_pred = lambda x: np.tile(counts / sum(counts), (len(x), 1))
        # self.naive_performance = eval_model_performance(
        #     naive_pred,
        #     self.dataset.X_test,
        #     self.dataset.y_test,
        #     scoring_method=config.ablation_args["scoring_method"],
        # )

    def _compute_perturbations(
        self, perturbation_config: Dict[str, Dict[str, int]]
    ) -> None:
        """Compute perturbations

        Args:
            perturbation_config (Dict[str, Dict[str, int]]): dict of perturbation names associated with dicts of extra arguments
        """
        self.perturbations = self._load_if_exists("perturbations")
        if self.perturbations is None:
            self.perturbations = {
                trial: {
                    name: generate_perturbation_distribution(
                        name,
                        self.dataset.X_train,
                        self.dataset.X_test,
                        random_state=trial,
                        agg_map=self.dataset.agg_map,
                        **kwargs,
                    )
                    for name, kwargs in perturbation_config.items()
                }
                for trial in range(self.config.n_trials)
            }
            self._save(self.perturbations, "perturbations")

    def _compute_baselines(self, baseline_config: Dict[str, Dict[str, int]]) -> None:
        """Compute baselines

        Args:
            baseline_config (Dict[str, Dict[str, int]]): dict of baseline names associated with dicts of extra arguments
        """

        self.baselines = self._load_if_exists("baselines")

        if self.baselines is None:
            numpy_model = _predict_fn(self.model)
            y = numpy_model(self.dataset.X_train)
            y_obs = numpy_model(self.dataset.X_test)

            self.baselines = {
                trial: {
                    name: generate_baseline_distribution(
                        name,
                        self.dataset.X_train,
                        self.dataset.X_test,
                        y,
                        y_obs,
                        random_state=trial,
                        agg_map=self.dataset.agg_map,
                        **kwargs,
                    )
                    for name, kwargs in baseline_config.items()
                }
                for trial in range(self.config.n_trials)
            }
            self._save(self.baselines, "baselines")

    def _compute_explanations(
        self,
        explanation_methods: List[str],
        computed_baselines: Dict[int, Dict[str, np.ndarray]],
    ) -> None:
        """Compute explanations

        Args:
            explanation_methods (List[str]): list of explanation methods
            computed_baselines (Dict[int, Dict[str, np.ndarray]]): per-trial dict mapping baseline names to baseline arrays
        """

        # TODO, we want to support raw explanations in addition to Explanations object
        self.explanations = self._load_if_exists("explanations")

        if self.explanations is None:
            self.explanations = {}
            non_random_exp_methods = [m for m in explanation_methods if m != "random"]
            for trial, baselines in computed_baselines.items():
                self.explanations[trial] = {}
                for method in non_random_exp_methods:
                    self.explanations[trial][method] = {}
                    for bname, baseline in baselines.items():
                        logger.info(
                            f"Running explanation: {method} | baseline: {bname}"
                        )

                        # Calculate the explanation values for set of observations
                        explanations = Explanations(
                            explanation_values=captum_explanation(
                                method,
                                self.model,
                                self.dataset.X_test,
                                baseline,
                                random_state=trial,
                            ),
                            agg_map=self.dataset.agg_map,
                        )

                        self.explanations[trial][method][bname] = explanations

                    if "random" in explanation_methods:
                        logger.info("Running random explanation")
                        explanations = Explanations(
                            explanation_values=captum_explanation(
                                "random",
                                self.model,
                                self.dataset.X_test,
                                baseline,
                                random_state=trial,
                            ),
                            agg_map=self.dataset.agg_map,
                        )
                        self.explanations[trial][method][
                            "random explanation"
                        ] = explanations

            self._save(self.explanations, "explanations")

    def run_exp(self):
        """Run the experiment and return a DataFrame of ablation results"""

        self._compute_model_sanity_checks()
        self._compute_perturbations(self.config.perturbation_config)
        self._compute_baselines(self.config.baseline_config)
        self._compute_explanations(self.config.explanation_methods, self.baselines)

        self.results = self._load_if_exists(self.config.result_name)

        if self.results is None or self.rerun_ablation:

            trials = []
            for trial in range(self.config.n_trials):
                logger.info(f"Running ablation trial {trial}")
                np.random.seed(trial)
                comb = []
                for p_name, perturb in self.perturbations[trial].items():
                    for exp_name, exp_dict in self.explanations[trial].items():
                        for b_name, exp in exp_dict.items():

                            # TODO: debug kernelshap and lime for overflow
                            # Below is currently a quick fix to exclude these samples
                            overflow_idx = np.unique(
                                np.where(exp.data("sparse") > 10)[0]
                            )
                            if len(overflow_idx) > 0:
                                logger.warning(
                                    f"Overflow warning (perturb: {p_name}, exp: {exp_name}, baseline: {b_name}):\n"
                                    f"{len(overflow_idx)} samples removed. "
                                    f"Indices: {','.join(overflow_idx.astype(str))}"
                                )

                            abl = Ablation(
                                perturbation_distribution=np.delete(
                                    perturb, overflow_idx, 0
                                ),
                                model=self.model,
                                dataset=self.dataset,
                                X=np.delete(self.dataset.X_test, overflow_idx, 0),
                                y=np.delete(self.dataset.y_test, overflow_idx, 0),
                                explanation_values=np.delete(
                                    exp.data("sparse"), overflow_idx, 0
                                ),
                                explanation_values_dense=np.delete(
                                    exp.data("dense"), overflow_idx, 0
                                ),
                                random_feat_idx=self.dataset.dense_random_feat_idx,
                                **self.config.ablation_args,
                            )
                            # abl = Ablation(
                            #     perturbation_distribution=perturb,
                            #     model=self.model,
                            #     X=self.dataset.X_test,
                            #     y=self.dataset.y_test,
                            #     explanation_values=exp,
                            #     random_feat_idx=self.dataset.random_feat_idx,
                            #     **self.config.ablation_args,
                            # )

                            result = abl.ablate_features()

                            n_obs = len(result["scores"])
                            result["explanation_method"] = [exp_name] * n_obs
                            result["baseline"] = [b_name] * n_obs
                            result["perturbation"] = [p_name] * n_obs
                            (
                                result["random_sanity_check_idx"],
                                result["random_sanity_check_perc"],
                            ) = abl.random_sanity_check_idx()
                            result[
                                "random_sanity_check_value"
                            ] = abl.random_sanity_check_value()
                            comb.append(result)

                trial_df = pd.concat(comb)
                trial_df["trial"] = trial
                trials.append(trial_df)

            self.results = pd.concat(trials).reset_index(drop=True)
            self._save(self.results, self.config.result_name)

        return self.results

Classes

class Config (dataset_name: str, model_type: str, perturbation_config: Dict[str, Dict[str, int]], baseline_config: Dict[str, Dict[str, int]], explanation_methods: List[str], ablation_args: Dict[str, Any], dataset_sample_perc: float = 1.0, dataset_n_random_features: int = 0, n_trials: int = 1, path: Union[str, os.PathLike] = 'tmp', load=False, rerun_ablation=False)

Experiment config

Args

dataset_name : str
name of dataset
model_type : str
model type ('nn' or 'linear')
perturbation_config : Dict[str, Dict[str, int]]
dict of perturbation names associated with dicts of extra arguments
baseline_config : Dict[str, Dict[str, int]]
dict of baseline names associated with dicts of extra arguments
explanation_methods : List[str]
list of explanation methods
ablation_args : Dict[str, Any]
dict of ablation arguments
dataset_sample_perc : float
proportion of the dataset to use for the experiment. Defaults to 1.0 (the full dataset).
dataset_n_random_features : int
number of random features to add to the dataset as a sanity check. Defaults to 0.
n_trials : int
number of random seeds (trials) to run the experiment with. Defaults to 1.
path : str, optional
path for saving results, the model, and intermediate calculations. Defaults to "tmp".
load : bool, optional
If True, load a previously saved experiment from path. Defaults to False.
rerun_ablation : bool, optional
If True, rerun the ablation when loading from path. Defaults to False.
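
The sketch below shows one way a Config might be constructed. It is illustrative only: the dataset name, perturbation and baseline names, explanation methods, and any ablation_args keys other than "local" and "scoring_methods" (which this module reads directly) are hypothetical placeholders whose accepted values depend on the rest of the package.

from ablation.experiment import Config

config = Config(
    dataset_name="adult",                     # hypothetical dataset name
    model_type="nn",                          # 'nn' or 'linear'
    perturbation_config={"marginal": {}},     # perturbation name -> extra kwargs (name is a placeholder)
    baseline_config={"training_mean": {}},    # baseline name -> extra kwargs (name is a placeholder)
    explanation_methods=["shap", "random"],   # explanation method names (placeholders)
    ablation_args={
        "local": True,                        # read by Config.result_name
        "scoring_methods": ["accuracy"],      # read by the model sanity checks
    },
    n_trials=2,
    path="tmp/demo",
)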

Static methods

def from_dict(dict_)
def from_yaml_file(path: Union[str, os.PathLike])
def load(path: Union[str, os.PathLike])
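
Continuing from the Config sketch above, save() writes config.yml into config.path, and load() / from_yaml_file() read it back; __eq__ then compares the experiment-defining arguments. Creating the directory is manual here only because Experiment normally creates the path itself.

import os

os.makedirs(config.path, exist_ok=True)
config.save()                                  # writes tmp/demo/config.yml

restored = Config.load("tmp/demo")             # reads tmp/demo/config.yml
assert restored == config                      # no diff in the experiment arguments
restored_too = Config.from_yaml_file("tmp/demo/config.yml")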

Instance variables

var result_name

Name of results file based on ablation args
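
The property only encodes the "local" flag of ablation_args into a file name, so with the Config sketch above:

config.ablation_args["local"] = True
config.result_name                             # "results-local"
config.ablation_args["local"] = False
config.result_name                             # "results-global"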

Methods

def diff(self, other: object) ‑> List[str]

Return the names of experiment arguments that differ from another Config.
def save(self)
def to_dict(self)

class Experiment (config: Config)

Experiment runner

Args

config : Config
configuration
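
A minimal end-to-end sketch, still using placeholder configuration values: constructing Experiment cleans (or reuses) config.path, loads the dataset, and trains a model plus a label-shuffled sanity-check model (or loads both when config.load is True); run_exp() then returns a pandas DataFrame of ablation results and pickles it under config.result_name.

from ablation.experiment import Config, Experiment

config = Config.from_yaml_file("configs/demo.yml")   # hypothetical config file
experiment = Experiment(config)                      # trains or loads the models
results = experiment.run_exp()                       # pandas DataFrame of ablation results

# A later run with load=True reuses the saved model, perturbations, baselines, and
# explanations from config.path; changing only ablation_args triggers a recomputed ablation.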

Methods

def run_exp(self)

Run the experiment and return a DataFrame of ablation results
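
Each row of the returned DataFrame is one ablation observation; run_exp adds the explanation_method, baseline, perturbation, and trial columns (plus the random sanity-check columns), while the score columns themselves come from Ablation.ablate_features. A quick aggregation over the tagging columns, assuming the results from the sketches above:

counts = results.groupby(["trial", "perturbation", "explanation_method", "baseline"]).size()
print(counts)          # observations per perturbation/explanation/baseline combination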
