{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c9784ee967924331",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# 🦺 随机网格搜索\n",
    "\n",
    "sklearn中网格搜索优化方法包括两类：\n",
    "1. 调整搜索空间。\n",
    "2. 调整每次训练的数据\n",
    "\n",
    "调整搜索空间：挑选出部分参数组合，构造参数子空间，并只在子空间中搜索。如`n_estimators`5个参数与`max_depth`6个参数构成的5×6=30的空间中随机选择参数组合子空间，并只在这些上面进行搜索。相同全域空间下，可以更快；相同训练次数下，可以覆盖更大空间；得到的最小损失与网格搜索的最小损失很接近\n",
    "\n",
    "`sklearn.model_selection.RandomizedSearchCV()`\n",
    "\n",
    "| Name                  | Description                                                  |\n",
    "| --------------------- | ------------------------------------------------------------ |\n",
    "| `estimator`           | 评估器、调参对象                                             |\n",
    "| `param_distributions` | 全域参数空间，`dict`、`list` of `dict`                       |\n",
    "| `n_iter`              | 迭代次数                                                     |\n",
    "| `scoring`             | 评价指标，支持多个输出                                       |\n",
    "| `n_jobs`              | 线程数                                                       |\n",
    "| `refit`               | 挑选评估指标和最佳参数，在完整数据集上训练                   |\n",
    "| `cv`                  | 交叉验证折数                                                 |\n",
    "| `verbose`             | 输出工作日志                                                 |\n",
    "| `pre_dispatch`        | 多任务并行时任务划分数量                                     |\n",
    "| `random_state`        | 随机数种子                                                   |\n",
    "| `error_score`         | 网格搜索报错时返回结果，选择`raise`时直接报错并中断训练过程，其他情况会显示警告信息后继续完成训练 |\n",
    "| `return_train_score`  | 交叉验证是否显示训练集中参数得分                             |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b1c5c22d7cebcdb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:31.615850847Z",
     "start_time": "2023-11-22T14:14:27.879758143Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['data', 'target', 'frame', 'target_names', 'feature_names', 'DESCR'])\n"
     ]
    }
   ],
   "source": [
    "# 导入加利福尼亚房价数据\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "data = fetch_california_housing()\n",
    "print(data.keys())\n",
    "\n",
    "# 划分数据集\n",
    "x = data['data']\n",
    "y = data['target']\n",
    "train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2, random_state=22)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b42625c25e8d8664",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:37.129280456Z",
     "start_time": "2023-11-22T14:14:37.094823197Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import KFold, RandomizedSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a3d743c21dd3e53c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:39.414282388Z",
     "start_time": "2023-11-22T14:14:39.404915041Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def count_space(param_ranges):\n",
    "    \"\"\"计算全域参数空间大小\"\"\"\n",
    "    space_size = 1\n",
    "    for param_range in param_ranges.values():\n",
    "        space_size *= len(param_range)\n",
    "    return space_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4474833bedc2e3a5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:41.606581828Z",
     "start_time": "2023-11-22T14:14:41.586322207Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# 创造参数空间\n",
    "param_grid_sample = {'criterion': ['squared_error', 'poisson'],\n",
    "                     'n_estimators': [*range(20, 100, 5)],\n",
    "                     'max_depth': [*range(10, 25, 2)],\n",
    "                     'max_features': ['log2', 'sqrt', 16, 32, 64, 'auto'],\n",
    "                     'min_impurity_decrease': [*np.arange(0, 5, 10)]}\n",
    "\n",
    "# 建立回归器、交叉验证\n",
    "reg = RandomForestRegressor(random_state=22, verbose=True, n_jobs=-1)\n",
    "cv = KFold(n_splits=5, shuffle=True, random_state=22)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6c3103c0bde07ee",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:44.716265792Z",
     "start_time": "2023-11-22T14:14:44.698340152Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1536"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算全域参数空间大小，是能够抽样的最大值\n",
    "count_space(param_grid_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e3c148aa1ea40cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:14:47.107560566Z",
     "start_time": "2023-11-22T14:14:47.097743759Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# 创建随机搜索评估器\n",
    "search = RandomizedSearchCV(estimator=reg,\n",
    "                            param_distributions=param_grid_sample,\n",
    "                            n_iter=200,  # 子空间大小设置为全域一半左右\n",
    "                            scoring='neg_mean_squared_error',\n",
    "                            verbose=True,\n",
    "                            cv=cv,\n",
    "                            random_state=22,\n",
    "                            n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165c76b4460808",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:18:09.742142984Z",
     "start_time": "2023-11-22T14:14:55.726068739Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 训练随机搜索评估器\n",
    "start = time.time()\n",
    "search.fit(train_x, train_y)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "22faac0139e8318",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:18:16.990582955Z",
     "start_time": "2023-11-22T14:18:16.987082793Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最优参数: {'n_estimators': 90, 'min_impurity_decrease': 0, 'max_features': 'log2', 'max_depth': 18, 'criterion': 'squared_error'}\n",
      "最优算法结果(RMSE): 0.4991504254535447\n"
     ]
    }
   ],
   "source": [
    "# 查看最优的参数\n",
    "print(f'最优参数: {search.best_params_}\\n'\n",
    "      f'最优算法结果(RMSE): {abs(search.best_score_) ** 0.5}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e087eebe89ac8916",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-22T14:18:22.883403189Z",
     "start_time": "2023-11-22T14:18:19.915946361Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8223297122582989"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 根据最优参数重建模型\n",
    "best_reg = RandomForestRegressor(n_estimators=85, max_depth=24, max_features='log2', min_impurity_decrease=0,\n",
    "                                 criterion='squared_error')\n",
    "best_reg.fit(train_x, train_y)\n",
    "best_reg.score(test_x, test_y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
