{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3e2a6292-a3e4-4ef0-8208-ba93d0139bb5",
   "metadata": {},
   "source": [
    "# 🙆🏻 模型微调"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9e884eda-281c-45f6-a576-c6282529eccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X, y = load_breast_cancer(return_X_y=True, as_frame=True)\n",
    "\n",
    "# Relabel the first 15 samples as a synthetic third class so the binary task\n",
    "# becomes a 3-class one, matching `Dist=k_categorical(3)` used below.\n",
    "# Work on a copy and index positionally so the loaded Series is left untouched.\n",
    "y = y.copy()\n",
    "y.iloc[:15] = 2\n",
    "\n",
    "# A fixed seed keeps the split reproducible; stratifying keeps the rare\n",
    "# class 2 represented in both the train and the test split.\n",
    "X_cls_train, X_cls_test, Y_cls_train, Y_cls_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=42, stratify=y\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3386ca3-60f9-4971-ac3b-21eeb4e15f48",
   "metadata": {},
   "source": [
    "## 壹丨分阶段预测\n",
    "\n",
    "NGBoost 的分阶段预测（Staged Prediction）用于获取模型在每一轮提升迭代（boosting stage）之后的预测结果。借助它，用户可以在训练完成后查看模型在各个阶段的表现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa3eaba1-c929-42e9-9768-938a37545767",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ngboost import NGBClassifier\n",
    "from ngboost.distns import k_categorical, Bernoulli\n",
    "from ngboost.scores import LogScore\n",
    "\n",
    "# Three-class categorical output distribution scored with the log score,\n",
    "# 500 boosting iterations, per-iteration logging switched off.\n",
    "ngb_cls = NGBClassifier(\n",
    "    Dist=k_categorical(3),\n",
    "    Score=LogScore,\n",
    "    n_estimators=500,\n",
    "    verbose=False,\n",
    ").fit(X_cls_train, Y_cls_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c01041-267b-4cad-9aeb-9f67a3947ad6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# staged_predict returns the predictions of every boosting stage at once.\n",
    "# Displaying all of them floods the cell output, so keep the result in a\n",
    "# variable and preview only the stage count and the first final-stage labels.\n",
    "staged_preds = ngb_cls.staged_predict(X_cls_train)\n",
    "len(staged_preds), staged_preds[-1][:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbbfded6-667e-4cb2-8498-1749cd8a6e03",
   "metadata": {},
   "source": [
    "## 贰丨提前结束\n",
    "\n",
    "将一个整数 `early_stopping_rounds` 和一个验证集（`X_val`, `Y_val`）传递给 `fit()`，则当验证集损失连续 `early_stopping_rounds` 轮迭代上升时，算法会提前停止训练。\n",
    "\n",
    "验证集的样本权重可以通过 `val_sample_weight` 参数传给 `fit()`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2ae23fce-555f-4767-8ab0-9893165ff360",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[iter 0] loss=1.5641 val_loss=1.5444 scale=1.0000 norm=1.1066\n",
      "[iter 100] loss=1.1366 val_loss=1.1271 scale=2.0000 norm=1.5541\n",
      "[iter 200] loss=0.9167 val_loss=0.9139 scale=1.0000 norm=0.7047\n",
      "[iter 300] loss=0.7755 val_loss=0.7840 scale=1.0000 norm=0.6812\n",
      "[iter 400] loss=0.7027 val_loss=0.7235 scale=1.0000 norm=0.6786\n"
     ]
    }
   ],
   "source": [
    "from ngboost import NGBRegressor\n",
    "\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X, Y = fetch_california_housing(return_X_y=True, as_frame=True)\n",
    "\n",
    "# Fixed seeds keep both splits reproducible across kernel restarts.\n",
    "X_reg_train, X_reg_test, Y_reg_train, Y_reg_test = train_test_split(\n",
    "    X, Y, test_size=0.2, random_state=42\n",
    ")\n",
    "\n",
    "# Early stopping monitors a validation set; take it from the training data\n",
    "# so the test split stays unseen and usable for an unbiased final evaluation.\n",
    "X_reg_fit, X_reg_val, Y_reg_fit, Y_reg_val = train_test_split(\n",
    "    X_reg_train, Y_reg_train, test_size=0.2, random_state=42\n",
    ")\n",
    "\n",
    "_ = NGBRegressor().fit(\n",
    "    X_reg_fit,\n",
    "    Y_reg_fit,\n",
    "    X_val=X_reg_val,\n",
    "    Y_val=Y_reg_val,\n",
    "    early_stopping_rounds=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86dab6af-deb5-45a2-a645-8d873431c5f5",
   "metadata": {},
   "source": [
    "## 叁丨使用sklearn模型选择"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3fae1f9b-963f-4841-88d6-6929bfe39f5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Base': DecisionTreeRegressor(criterion='friedman_mse', max_depth=2), 'minibatch_frac': 1.0}\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from ngboost.distns import Exponential,Normal\n",
    "\n",
    "# Two candidate base learners that differ only in tree depth.\n",
    "shallow_tree = DecisionTreeRegressor(criterion='friedman_mse', max_depth=2)\n",
    "deeper_tree = DecisionTreeRegressor(criterion='friedman_mse', max_depth=4)\n",
    "\n",
    "# Search over the minibatch fraction and the base learner.\n",
    "param_grid = dict(\n",
    "    minibatch_frac=[1.0, 0.5],\n",
    "    Base=[shallow_tree, deeper_tree],\n",
    ")\n",
    "\n",
    "ngb = NGBRegressor(Dist=Normal, verbose=False)\n",
    "\n",
    "# 5-fold cross-validated search over every combination in the grid.\n",
    "grid_search = GridSearchCV(ngb, param_grid=param_grid, cv=5)\n",
    "grid_search.fit(X_reg_train, Y_reg_train)\n",
    "print(grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4a476f-48b8-4dfc-85d1-fa6f0787bd5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa45b788-e6cf-4672-93fa-93d1975ea394",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d308e2c3-8a35-47b3-8348-a09ba0739b55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
