Link Plots to Features
rubicon_ml makes it easy to log plots as artifacts and link them to features. In this example, we'll walk through creating a feature dependence plot with the shap package and saving it to an artifact.
Before getting started, we’ll have to install some dependencies for this example.
[1]:
! pip install matplotlib kaleido Pillow shap "numba>=0.56.2"
Set Up
First, let's create a Rubicon project and a pipeline with rubicon_ml.sklearn.make_pipeline.
[2]:
import shap
from sklearn.datasets import load_wine
from sklearn.ensemble import GradientBoostingRegressor
from rubicon_ml import Rubicon
from rubicon_ml.sklearn import make_pipeline
# Keep everything in memory for this example; nothing is written to disk.
rubicon = Rubicon(persistence="memory")
project = rubicon.get_or_create_project("Logging Feature Plots")
# The wine dataset has 13 numeric features.
X, y = load_wine(return_X_y=True)
reg = GradientBoostingRegressor(random_state=1)
# make_pipeline wraps the estimator in a RubiconPipeline tied to `project`.
pipeline = make_pipeline(project, reg)
pipeline.fit(X, y)
[2]:
RubiconPipeline(project=<rubicon_ml.client.project.Project object at 0x167a37c40>,
steps=[('gradientboostingregressor',
GradientBoostingRegressor(random_state=1))])
Generating Data
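After fitting the pipeline, we can use shap.Explainer to generate SHAP values to plot later. Indexing the pipeline with pipeline[0] returns its first step, the fitted GradientBoostingRegressor, which is what the explainer receives. For more information on generating SHAP values, see the shap.Explainer documentation.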
[3]:
explainer = shap.Explainer(pipeline[0])
shap_values = explainer.shap_values(X)
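For tree models like this one, shap computes SHAP values with an efficient tree-specific algorithm, so you never compute them by hand here. To make the idea concrete, the sketch below is a hypothetical, brute-force illustration (not what shap.Explainer runs): it computes exact Shapley values for a toy model by enumerating every coalition of features, filling features outside a coalition from a single baseline point. The helper name `exact_shapley_values` and the single-baseline value function are assumptions made for this illustration.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Sequence


def exact_shapley_values(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    baseline: Sequence[float],
) -> list[float]:
    """Brute-force Shapley values for a single input `x` (toy illustration only).

    Features in a coalition take their values from `x`; all other features
    take their values from `baseline`. The cost is exponential in len(x).
    """
    n = len(x)

    def value(coalition: frozenset) -> float:
        # Hybrid input: x for features in the coalition, baseline elsewhere.
        z = [x[j] if j in coalition else baseline[j] for j in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # Marginal contribution of feature i when added to coalition S.
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


def toy_model(z: Sequence[float]) -> float:
    # Two main effects plus an interaction between features 0 and 1.
    return 2.0 * z[0] + 3.0 * z[1] + z[0] * z[1]


phi = exact_shapley_values(toy_model, [1.0, 2.0], [0.0, 0.0])
print(phi)  # [3.0, 7.0]
```

The two values sum to toy_model([1.0, 2.0]) - toy_model([0.0, 0.0]) = 10, and the interaction term is split evenly between the two features. A dependence plot for feature i scatters column i of X against column i of shap_values.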
Plotting
The shap_values generated in the cell above can be passed to shap.dependence_plot to create a dependence plot, which plots a feature's values against its SHAP values. Passing interaction_index=None turns off coloring the points by a second, interacting feature, and show=False keeps shap from displaying the plot right away so that pl.gcf() (with matplotlib.pyplot imported as pl) can capture the current figure in a variable. Using matplotlib and io, each figure is then saved to an in-memory PNG byte representation. Here, each feature and its plot are logged with rubicon_ml.Feature and rubicon_ml.Artifact, respectively, under the same name. Both are logged to the same rubicon_ml.Experiment, which was created by calling pipeline.fit.
[4]:
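import io
import matplotlib.pyplot as pl
# The experiment that pipeline.fit created and logged to.
experiment = pipeline.experiment
for i in range(reg.n_features_in_):
    feature_name = f"feature {i}"
    # Log the feature, tagged with its own name.
    experiment.log_feature(name=feature_name, tags=[feature_name])
    # Draw the dependence plot without showing it, then grab the figure.
    shap.dependence_plot(i, shap_values, X, interaction_index=None, show=False)
    fig = pl.gcf()
    # Render the figure to PNG bytes in memory.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    # Log the PNG bytes as an artifact with the same name as the feature.
    experiment.log_artifact(
        data_bytes=buf.read(), name=feature_name, tags=[feature_name],
    )
    buf.close()
    # Close the figure so open figures don't accumulate across iterations.
    pl.close(fig)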
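The buf.seek(0) call in the loop matters: writing to an io.BytesIO leaves the stream position at the end of the buffer, so a bare read() would return empty bytes. The standalone sketch below shows that pitfall and wraps the write, rewind, read pattern in a small helper. The name `buffer_to_bytes` is hypothetical; in the notebook, its `write` argument would be something like `lambda buf: fig.savefig(buf, format="png")`. Calling buf.getvalue() instead of seek-then-read is an equivalent alternative.

```python
import io
from typing import Callable


def buffer_to_bytes(write: Callable[[io.BytesIO], object]) -> bytes:
    """Run `write(buf)` against a fresh in-memory buffer and return what it wrote."""
    with io.BytesIO() as buf:
        write(buf)
        # Writing leaves the position at the end of the buffer, so rewind
        # before reading; otherwise read() would return b"".
        buf.seek(0)
        return buf.read()


# The pitfall that buf.seek(0) avoids.
demo = io.BytesIO()
demo.write(b"payload")
print(demo.read())      # b'' because the position is at the end
print(demo.getvalue())  # b'payload' because getvalue() ignores the position
demo.close()

print(buffer_to_bytes(lambda buf: buf.write(b"payload")))  # b'payload'
```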
Retrieving Logged Plots and Features Programmatically
Finally, we can retrieve each feature and its associated plot artifact by name. Because each artifact was logged under the same name as its feature, we look up each artifact with feature.name. After retrieving an artifact's PNG bytes, io and PIL can render them as an image.
[5]:
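import io
from IPython.display import display
from PIL import Image
experiment = pipeline.experiment
for feature in experiment.features():
    # Each artifact was logged under the same name as its feature.
    artifact = experiment.artifact(name=feature.name)
    # Wrap the stored PNG bytes in a file-like object for PIL.
    buf = io.BytesIO(artifact.data)
    scatter_plot_image = Image.open(buf)
    print(feature.name)
    display(scatter_plot_image)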
feature 0
feature 1
feature 2
feature 3
feature 4
feature 5
feature 6
feature 7
feature 8
feature 9
feature 10
feature 11
feature 12
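If you want to sanity-check a retrieved artifact without PIL, the PNG format itself makes that easy: every PNG file starts with a fixed 8-byte signature, followed by an IHDR chunk whose first 8 data bytes are the image width and height as big-endian unsigned 32-bit integers. The sketch below reads those fields with the standard library. The helper name `png_dimensions` is hypothetical; in the retrieval loop it could be called as `png_dimensions(artifact.data)`.

```python
import struct

# Fixed 8-byte signature at the start of every PNG file.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_dimensions(data: bytes) -> tuple[int, int]:
    """Return (width, height) from a PNG's IHDR chunk.

    Raises ValueError if `data` does not start with the PNG signature
    followed by an IHDR chunk.
    """
    if data[:8] != PNG_SIGNATURE:
        raise ValueError("not PNG data")
    # The first chunk: 4-byte length, 4-byte type (must be IHDR),
    # then width and height as big-endian unsigned 32-bit integers.
    if len(data) < 24 or data[12:16] != b"IHDR":
        raise ValueError("missing IHDR chunk")
    width, height = struct.unpack(">II", data[16:24])
    return width, height


# A hand-built PNG header (signature plus the start of an IHDR chunk) for demonstration.
header = (
    PNG_SIGNATURE
    + struct.pack(">I", 13)
    + b"IHDR"
    + struct.pack(">II", 640, 480)
    + b"\x08\x06\x00\x00\x00"
)
print(png_dimensions(header))  # (640, 480)
```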