Module ablation.utils.model

Expand source code
from functools import partial
from typing import Callable, Union

import numpy as np
import torch
from sklearn.base import ClassifierMixin


def _torch_long(x: np.array, device=None) -> torch.tensor:
    return torch.tensor(x, device=device).long()


def _torch_float(x: np.array, device=None) -> torch.tensor:
    return torch.tensor(x, device=device).float()


def _as_numpy(x: torch.tensor) -> np.ndarray:
    return x.cpu().detach().numpy().astype(float)


def _predict_proba_model_type(
    X: np.ndarray, model: Union[torch.nn.Module, ClassifierMixin, Callable],
) -> np.ndarray:
    """Check the type of ``model`` and predict probabilities accordingly.

    Dispatch order:

    1. scikit-learn classifiers (``ClassifierMixin``): ``model.predict_proba(X)``.
    2. ``torch.nn.Module``: forward pass on ``X`` converted to a float tensor
       placed on the device of the module's first parameter (default device
       if the module has no parameters), run under ``torch.no_grad()``.
    3. Any other callable: ``model(X)``.
    4. Any other object exposing ``predict_proba`` (e.g. a scikit-learn
       ``Pipeline``): ``model.predict_proba(X)``.

    The raw prediction is then post-processed by shape:

    - 1-D output is returned unchanged;
    - ``(n, 1)`` output is flattened to ``(n,)``;
    - ``(n, 2)`` output is reduced to its second column (for two-class
      ``predict_proba`` output, the probability of the second class);
    - any other shape is returned unchanged.

    Args:
        X (np.ndarray): input features
        model (Union[torch.nn.Module, ClassifierMixin, Callable]): model

    Raises:
        ValueError: if model type is unsupported

    Returns:
        np.ndarray: prediction
    """
    if isinstance(model, ClassifierMixin):
        pred = model.predict_proba(X)
    elif isinstance(model, torch.nn.Module):
        # Build the input on the model's own device so GPU-resident models work;
        # parameter-less modules fall back to the default device (as before).
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        # Inference only: the output is detached anyway, so skip building the
        # autograd graph.
        with torch.no_grad():
            pred = _as_numpy(model(_torch_float(X, device=device)))
    elif callable(model):
        pred = model(X)
    elif hasattr(model, "predict_proba"):
        # Non-callable estimators that are not ClassifierMixin subclasses
        # (e.g. sklearn Pipeline) but still expose predict_proba.
        pred = model.predict_proba(X)
    else:
        raise ValueError("Model type not supported")

    # Arbitrary callables may return plain Python sequences; normalise them so
    # the shape inspection below works. Array-like results are left untouched.
    if not hasattr(pred, "shape"):
        pred = np.asarray(pred)

    output_shape = pred.shape

    if len(output_shape) == 1:
        return pred
    elif output_shape[1] == 1:
        return pred.flatten()
    elif output_shape[1] == 2:
        return pred[:, 1]

    return pred


def _predict_model_type(
    X: np.array, model: Union[torch.nn.Module, ClassifierMixin, Callable],
) -> np.ndarray:
    """Predict classes for ``X`` with any supported model type.

    Scores come from ``_predict_proba_model_type``. A 1-D score vector is
    rounded element-wise with ``np.round``. Otherwise the index of the
    largest entry along the last axis is returned.

    Args:
        X (np.array): numpy array
        model (Union[torch.nn.Module, ClassifierMixin, Callable]): model

    Returns:
        np.ndarray: prediction
    """
    scores = _predict_proba_model_type(X, model)

    if len(scores.shape) != 1:
        # One column per class: choose the highest-scoring column.
        return np.argmax(scores, axis=-1)

    # NOTE: np.round rounds half to even, so an exact 0.5 maps to 0.0.
    return np.round(scores)


def _predict_proba_fn(
    model: Union[torch.nn.Module, ClassifierMixin, Callable]
) -> Callable:
    """Bind ``model`` into a uniform "predict probabilities" callable.

    The returned object is a ``functools.partial`` of
    ``_predict_proba_model_type`` with ``model`` fixed as a keyword argument.
    It can therefore be invoked as ``fn(X)``.

    Args:
        model (Union[torch.nn.Module, ClassifierMixin, Callable]): model

    Returns:
        Callable: generalized predict proba function
    """
    bound_predict_proba = partial(_predict_proba_model_type, model=model)
    return bound_predict_proba


def _predict_fn(
    model: Union[torch.nn.Module, ClassifierMixin, Callable]
) -> Callable:
    """Generalize all model predict functions into one class-predicting callable.

    Returns a ``functools.partial`` of ``_predict_model_type`` with ``model``
    fixed as a keyword argument, so the result can be called as ``fn(X)``.

    Args:
        model (Union[torch.nn.Module, ClassifierMixin, Callable]): model

    Returns:
        Callable: generalized predict function (``X`` -> class predictions)
    """
    return partial(_predict_model_type, model=model)