{ "cells": [ { "cell_type": "markdown", "id": "6898ac1c-c6e3-40d8-a787-c70f2b4e0b03", "metadata": { "tags": [] }, "source": [ "# Register a custom schema\n", "\n", "A ``rubicon_schema`` can be constructed within a Python session in addition to being read from\n", "the registry's YAML files.\n", "\n", "## Define additional metadata to log\n", "\n", "Set an additional environment variable to record with our ``rubicon_schema``." ] }, { "cell_type": "code", "execution_count": 1, "id": "75a9fb48-2c0f-4fdc-91df-9105fde2892f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AWS\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"RUNTIME_ENV\"] = \"AWS\"\n", "\n", "! echo $RUNTIME_ENV" ] }, { "cell_type": "markdown", "id": "6ce08db5-c532-4f1a-9357-61a6b3d1eadd", "metadata": {}, "source": [ "## Construct a custom schema\n", "\n", "Create a dictionary representation of the new, custom schema. This new schema will extend\n", "the existing ``RandomForestClassifier`` schema with an additional parameter that logs the\n", "new environment variable.\n", "\n", "**Note:** The ``extends`` key is not required. Custom schemas do not need to extend existing schemas." ] }, { "cell_type": "code", "execution_count": 2, "id": "020156b3-d0b2-4b99-8c5a-bfb60e02612d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'extends': 'sklearn__RandomForestClassifier',\n", " 'name': 'sklearn__RandomForestClassifier__ext',\n", " 'parameters': [{'name': 'runtime_environment', 'value_env': 'RUNTIME_ENV'}]}\n" ] } ], "source": [ "import pprint\n", "\n", "extended_schema = {\n", "    \"name\": \"sklearn__RandomForestClassifier__ext\",\n", "    \"extends\": \"sklearn__RandomForestClassifier\",\n", "    \"parameters\": [\n", "        # the parameter's value is read from the named environment variable\n", "        {\"name\": \"runtime_environment\", \"value_env\": \"RUNTIME_ENV\"},\n", "    ],\n", "}\n", "pprint.pprint(extended_schema)" ] }, { "cell_type": "markdown", "id": 
"ae6d5dfa-e511-4901-ae61-f6f2aea55cb0", "metadata": {}, "source": [ "## Apply a custom schema to a project\n", "\n", "Create a ``rubicon_ml`` project" ] }, { "cell_type": "code", "execution_count": 3, "id": "6fde24fc-1ab3-49fc-8ea4-fb3573e3bb29", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from rubicon_ml import Rubicon\n", "\n", "rubicon = Rubicon(persistence=\"memory\", auto_git_enabled=True)\n", "project = rubicon.create_project(name=\"apply schema\")\n", "project" ] }, { "cell_type": "markdown", "id": "b9c46f5b-27da-42bb-b2d2-4761e35cdd4a", "metadata": { "tags": [] }, "source": [ "Apply the custom schema to the project" ] }, { "cell_type": "code", "execution_count": 4, "id": "e91c3d60-806a-49d4-a4b8-ec13489e6a11", "metadata": { "tags": [] }, "outputs": [], "source": [ "project.set_schema(extended_schema)" ] }, { "cell_type": "markdown", "id": "885b3e61-7875-445e-993f-3359bd4bb7ad", "metadata": {}, "source": [ "## Log model metadata with a custom schema\n", "\n", "Load a training dataset" ] }, { "cell_type": "code", "execution_count": 5, "id": "f71158d7-208d-4094-92b9-94b49a45cb6b", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sklearn.datasets import load_wine\n", "\n", "X, y = load_wine(return_X_y=True, as_frame=True)" ] }, { "cell_type": "markdown", "id": "7b779808-771f-4c40-8250-a347e3b67c19", "metadata": { "tags": [] }, "source": [ "Train an instance of the model the schema represents" ] }, { "cell_type": "code", "execution_count": 6, "id": "838d254b-de2a-4155-909b-707728f343d9", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RandomForestClassifier(ccp_alpha=0.005, criterion='log_loss',\n", " max_features='log2', n_estimators=24, oob_score=True,\n", " random_state=121)\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "rfc = 
RandomForestClassifier(\n", "    ccp_alpha=5e-3,\n", "    criterion=\"log_loss\",\n", "    max_features=\"log2\",\n", "    n_estimators=24,\n", "    oob_score=True,\n", "    random_state=121,\n", ")\n", "rfc.fit(X, y)\n", "\n", "print(rfc)" ] }, { "cell_type": "markdown", "id": "60e4c75b-1b92-4b8e-b938-81603162f2f4", "metadata": {}, "source": [ "Log the model metadata defined in the base ``RandomForestClassifier`` schema, plus the additional parameter\n", "from the environment, to a new experiment in ``project`` with ``project.log_with_schema``." ] }, { "cell_type": "code", "execution_count": 7, "id": "34341f0d-32a8-4a39-aaf6-dad8ccc8bf1b", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "experiment = project.log_with_schema(\n", "    rfc,\n", "    experiment_kwargs={\n", "        \"name\": \"log with extended schema\",\n", "        \"model_name\": \"RandomForestClassifier\",\n", "        \"description\": \"logged with an extended `rubicon_schema`\",\n", "    },\n", ")\n", "experiment" ] }, { "cell_type": "markdown", "id": "793daa8e-693b-4d2e-8c31-c71cd236291e", "metadata": {}, "source": [ "## View the experiment's logged metadata\n", "\n", "The experiment contains all the data represented in the base ``RandomForestClassifier`` schema plus the\n", "additional parameter from the environment." ] }, { "cell_type": "code", "execution_count": 8, "id": "c656f695-5a30-4333-9aa8-a206f52a6d31", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bootstrap: True\n", "ccp_alpha: 0.005\n", "class_weight: None\n", "criterion: log_loss\n", "max_depth: None\n", "max_features: log2\n", "min_impurity_decrease: 0.0\n", "max_leaf_nodes: None\n", "max_samples: None\n", "min_samples_split: 2\n", "min_samples_leaf: 1\n", "min_weight_fraction_leaf: 0.0\n", "n_estimators: 24\n", "oob_score: True\n", "random_state: 121\n", "runtime_environment: AWS\n" ] } ], "source": [ "for parameter in 
experiment.parameters():\n", "    print(f\"{parameter.name}: {parameter.value}\")" ] }, { "cell_type": "markdown", "id": "d33f1585-00b3-42e6-9741-ac849a6cc8a9", "metadata": {}, "source": [ "Clean up by removing the environment variable set earlier." ] }, { "cell_type": "code", "execution_count": 9, "id": "ec8e4159-0f97-4c4d-923a-b8f283184b66", "metadata": { "tags": [] }, "outputs": [], "source": [ "del os.environ[\"RUNTIME_ENV\"]" ] }, { "cell_type": "markdown", "id": "d4757cb2-00e8-4ba1-aa64-04959dfea5d8", "metadata": {}, "source": [ "## Persisting and sharing a custom schema\n", "\n", "To share a custom schema with all ``rubicon_schema`` users, see the \"Contribute a ``rubicon_schema``\" section." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }