{ "cells": [ { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Set a schema on a project\n", "\n", "\"Log a ``rubicon_ml`` experiment with a ``rubicon_schema``\" showed how ``rubicon_schema`` can\n", "infer schema from the object to log - sometimes, this may not be possible and a schema may need to be set manually\n", "\n", "## Select a schema\n", "\n", "View all available schema" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['sklearn__RandomForestClassifier',\n", " 'xgboost__XGBClassifier',\n", " 'xgboost__DaskXGBClassifier']" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from rubicon_ml.schema import registry\n", "\n", "available_schema = registry.available_schema()\n", "available_schema" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load a schema" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'artifacts': ['self'],\n", " 'compatibility': {'scikit-learn': {'max_version': None,\n", " 'min_version': '1.0.2'}},\n", " 'docs_url': 'https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html',\n", " 'features': [{'importances_attr': 'feature_importances_',\n", " 'names_attr': 'feature_names_in_',\n", " 'optional': True}],\n", " 'metrics': [{'name': 'classes', 'value_attr': 'classes_'},\n", " {'name': 'n_classes', 'value_attr': 'n_classes_'},\n", " {'name': 'n_features_in', 'value_attr': 'n_features_in_'},\n", " {'name': 'n_outputs', 'value_attr': 'n_outputs_'},\n", " {'name': 'oob_decision_function',\n", " 'optional': True,\n", " 'value_attr': 'oob_decision_function_'},\n", " {'name': 'oob_score',\n", " 'optional': True,\n", " 'value_attr': 'oob_score_'}],\n", " 'name': 'sklearn__RandomForestClassifier',\n", " 'parameters': [{'name': 'bootstrap', 'value_attr': 'bootstrap'},\n", " 
{'name': 'ccp_alpha', 'value_attr': 'ccp_alpha'},\n", " {'name': 'class_weight', 'value_attr': 'class_weight'},\n", " {'name': 'criterion', 'value_attr': 'criterion'},\n", " {'name': 'max_depth', 'value_attr': 'max_depth'},\n", " {'name': 'max_features', 'value_attr': 'max_features'},\n", " {'name': 'min_impurity_decrease',\n", " 'value_attr': 'min_impurity_decrease'},\n", " {'name': 'max_leaf_nodes', 'value_attr': 'max_leaf_nodes'},\n", " {'name': 'max_samples', 'value_attr': 'max_samples'},\n", " {'name': 'min_samples_split',\n", " 'value_attr': 'min_samples_split'},\n", " {'name': 'min_samples_leaf', 'value_attr': 'min_samples_leaf'},\n", " {'name': 'min_weight_fraction_leaf',\n", " 'value_attr': 'min_weight_fraction_leaf'},\n", " {'name': 'n_estimators', 'value_attr': 'n_estimators'},\n", " {'name': 'oob_score', 'value_attr': 'oob_score'},\n", " {'name': 'random_state', 'value_attr': 'random_state'}],\n", " 'verison': '1.0.0'}\n" ] } ], "source": [ "import pprint\n", "\n", "rfc_schema = registry.get_schema(\"sklearn__RandomForestClassifier\")\n", "pprint.pprint(rfc_schema)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Apply the schema to a project\n", "\n", "Create a ``rubicon_ml`` project" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from rubicon_ml import Rubicon\n", "\n", "rubicon = Rubicon(persistence=\"memory\")\n", "project = rubicon.create_project(name=\"apply schema\")\n", "project" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the schema on the project" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "project.set_schema(rfc_schema)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, ``log_with_schema`` will use ``rfc_schema`` instead of trying to infer a schema from the object being logged." ] } ]